diff --git a/execution_engine/execution_engine.py b/execution_engine/execution_engine.py index 9bef1e7f..87a9ad30 100644 --- a/execution_engine/execution_engine.py +++ b/execution_engine/execution_engine.py @@ -434,6 +434,23 @@ def execute_recommendation( assert start_datetime.tzinfo, "start_datetime must be timezone-aware" assert end_datetime.tzinfo, "end_datetime must be timezone-aware" + # We cut off microseconds to ensure second-level precision. + # In Python, timestamps are usually floored when cast to int, + # but in PostgreSQL they might be rounded. To avoid subtle bugs + # where a 1-second difference appears due to different rounding + # strategies, we explicitly strip microseconds and assert that + # the timestamp is exactly an integer. + start_datetime = start_datetime.replace(microsecond=0) + end_datetime = end_datetime.replace(microsecond=0) + + assert start_datetime.timestamp() == int( + start_datetime.timestamp() + ), f"start_datetime still contains sub-second precision: {start_datetime}" + + assert end_datetime.timestamp() == int( + end_datetime.timestamp() + ), f"end_datetime still contains sub-second precision: {end_datetime}" + date_format = "%Y-%m-%d %H:%M:%S %z" logging.info( diff --git a/execution_engine/execution_graph/graph.py b/execution_engine/execution_graph/graph.py index 792a034c..b4acdc85 100644 --- a/execution_engine/execution_graph/graph.py +++ b/execution_engine/execution_graph/graph.py @@ -104,6 +104,10 @@ def traverse( traverse(expr, category=category) + # make sure base node has still the correct category - it might have changed duing traversal if the base node + # is used in an interval_criterion of a TemporalCount operator + graph.nodes[base_node].update({"category": CohortCategory.BASE}) + if hash(expr) != expr_hash: raise ValueError("Expression has been modified during traversal") diff --git a/execution_engine/omop/criterion/abstract.py b/execution_engine/omop/criterion/abstract.py index 8ceaece8..bc7e5397 100644 --- 
a/execution_engine/omop/criterion/abstract.py +++ b/execution_engine/omop/criterion/abstract.py @@ -27,7 +27,8 @@ from execution_engine.util.interval import IntervalType from execution_engine.util.serializable import SerializableDataClassABC from execution_engine.util.sql import SelectInto, select_into -from execution_engine.util.types import PersonIntervals, TimeRange +from execution_engine.util.types import PersonIntervals +from execution_engine.util.types.timerange import TimeRange __all__ = [ "Criterion", diff --git a/execution_engine/omop/criterion/point_in_time.py b/execution_engine/omop/criterion/point_in_time.py index d79342a0..c493a567 100644 --- a/execution_engine/omop/criterion/point_in_time.py +++ b/execution_engine/omop/criterion/point_in_time.py @@ -10,7 +10,8 @@ from execution_engine.omop.criterion.concept import ConceptCriterion from execution_engine.task.process import get_processing_module from execution_engine.util.interval import IntervalType -from execution_engine.util.types import PersonIntervals, TimeRange, Timing +from execution_engine.util.types import PersonIntervals, Timing +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import Value process = get_processing_module() diff --git a/execution_engine/omop/db/celida/tables.py b/execution_engine/omop/db/celida/tables.py index 0db17e26..b434d038 100644 --- a/execution_engine/omop/db/celida/tables.py +++ b/execution_engine/omop/db/celida/tables.py @@ -27,8 +27,8 @@ from execution_engine.util.interval import IntervalType # Use the "public" schema so that tables in different schemas can -# these enums easily without introducing depedencies between the -# respective schemas. Note that replicate the enum definitions in each +# use these enums easily without introducing dependencies between the +# respective schemas. 
Note that replicating the enum definitions in each # schema would not work when data must be exchanged between the # schemas because enum definitions in separate schemas, even if # identical in terms of enum values, are considered distinct and @@ -37,7 +37,6 @@ CohortCategoryEnum = Enum(CohortCategory, name="cohort_category", schema="public") - class Recommendation(Base): # noqa: D101 __tablename__ = "recommendation" __table_args__ = {"schema": SCHEMA_NAME} @@ -170,7 +169,9 @@ class ResultInterval(Base): # noqa: D101 interval_start: Mapped[datetime] interval_end: Mapped[datetime] interval_type = mapped_column(IntervalTypeEnum) - + interval_ratio: Mapped[float] = mapped_column( + nullable=True + ) execution_run: Mapped["ExecutionRun"] = relationship( primaryjoin="ResultInterval.run_id == ExecutionRun.run_id", ) diff --git a/execution_engine/omop/db/celida/views.py b/execution_engine/omop/db/celida/views.py index 19f547da..ffafdfa8 100644 --- a/execution_engine/omop/db/celida/views.py +++ b/execution_engine/omop/db/celida/views.py @@ -199,6 +199,7 @@ def interval_result_view() -> Select: rri.c.interval_type, rri.c.interval_start, rri.c.interval_end, + rri.c.interval_ratio, ) .select_from(rri) .outerjoin(pip, (rri.c.pi_pair_id == pip.c.pi_pair_id)) diff --git a/execution_engine/omop/sqlclient.py b/execution_engine/omop/sqlclient.py index b5488029..ee1e8e81 100644 --- a/execution_engine/omop/sqlclient.py +++ b/execution_engine/omop/sqlclient.py @@ -58,6 +58,35 @@ def _enable_database_triggers( cursor.close() +def datetime_cols_to_epoch(stmt: sqlalchemy.Select) -> sqlalchemy.Select: + """ + Given a SQLAlchemy 2.0 Select that has columns labeled 'interval_start' + or 'interval_end', replace those column expressions with + EXTRACT(EPOCH FROM )::BIGINT so they become integer timestamps. + + Returns a new Select object with the replaced columns. 
+ """ + new_columns = [] + + for col in stmt.selected_columns: + label = getattr(col, "name") + + if label in ("interval_start", "interval_end"): + # We'll wrap col in EXTRACT(EPOCH FROM col)::BIGINT, + new_col = ( + sqlalchemy.func.extract("epoch", col) + .cast(sqlalchemy.BigInteger) + .label(label) + ) + new_columns.append(new_col) + else: + new_columns.append(col) + + new_stmt = stmt.with_only_columns(*new_columns, maintain_column_froms=True) + + return new_stmt + + class OMOPSQLClient: """A client for the OMOP SQL database. diff --git a/execution_engine/task/process/__init__.py b/execution_engine/task/process/__init__.py index b9e4444e..eed88a7a 100644 --- a/execution_engine/task/process/__init__.py +++ b/execution_engine/task/process/__init__.py @@ -3,6 +3,10 @@ import sys import types from collections import namedtuple +from typing import TypeVar + +from execution_engine.util.interval import IntervalType +from execution_engine.util.types.timerange import TimeRange def get_processing_module( @@ -39,6 +43,35 @@ def get_processing_module( Interval = namedtuple("Interval", ["lower", "upper", "type"]) IntervalWithCount = namedtuple("IntervalWithCount", ["lower", "upper", "type", "count"]) -IntervalWithTypeCounts = namedtuple( - "IntervalWithTypeCounts", ["lower", "upper", "counts"] -) + +AnyInterval = Interval | IntervalWithCount +GeneralizedInterval = None | AnyInterval + +TInterval = TypeVar("TInterval", bound=AnyInterval) + + +def interval_like(interval: TInterval, start: int, end: int) -> TInterval: + """ + Return a copy of the given interval with its lower and upper bounds replaced. + + Args: + interval (I): The interval to copy. Must be one of Interval or IntervalWithCount. + start (datetime): The new lower bound. + end (datetime): The new upper bound. + + Returns: + I: A copy of the interval with updated lower and upper bounds. 
+ """ + + return interval._replace(lower=start, upper=end) # type: ignore[return-value] + + +def timerange_to_interval(tr: TimeRange, type_: IntervalType) -> Interval: + """ + Converts a timerange to an interval with the supplied type. + """ + return Interval( + lower=int(tr.start.timestamp()), + upper=int(tr.end.timestamp()), + type=type_, + ) diff --git a/execution_engine/task/process/rectangle.py b/execution_engine/task/process/rectangle.py index 91bef1e8..53236ff1 100644 --- a/execution_engine/task/process/rectangle.py +++ b/execution_engine/task/process/rectangle.py @@ -2,7 +2,8 @@ import importlib import logging import os -from typing import Callable, cast +from collections import defaultdict +from typing import Callable, Dict, List, Set, cast import numpy as np import pendulum @@ -10,9 +11,20 @@ from sqlalchemy import CursorResult from execution_engine.util.interval import IntervalType, interval_datetime -from execution_engine.util.types import TimeRange -from . import Interval, IntervalWithCount +from ...util.types.timerange import TimeRange +from . import ( + GeneralizedInterval, + Interval, + IntervalWithCount, + interval_like, + timerange_to_interval, +) + +IntervalConstructor = Callable[ + [int, int, List[GeneralizedInterval]], GeneralizedInterval +] +SameResult = Callable[[List[GeneralizedInterval], List[GeneralizedInterval]], bool] PROCESS_RECTANGLE_VERSION = os.getenv("PROCESS_RECTANGLE_VERSION", "auto") @@ -64,7 +76,7 @@ def result_to_intervals(result: CursorResult) -> PersonIntervals: """ Converts the result of the interval operations to a list of intervals. 
""" - person_interval = {} + person_interval = defaultdict(list) for row in result: if row.interval_end < row.interval_start: @@ -76,15 +88,12 @@ def result_to_intervals(result: CursorResult) -> PersonIntervals: raise ValueError("Interval end is None") interval = Interval( - row.interval_start.timestamp(), - row.interval_end.timestamp(), + row.interval_start, + row.interval_end, row.interval_type, ) - if row.person_id not in person_interval: - person_interval[row.person_id] = [interval] - else: - person_interval[row.person_id].append(interval) + person_interval[row.person_id].append(interval) for person_id in person_interval: person_interval[person_id] = _impl.union_rects(person_interval[person_id]) @@ -214,10 +223,10 @@ def forward_fill( if observation_window is not None: last_interval = result[person_id][-1] - if last_interval.upper < observation_window.end.timestamp(): + if last_interval.upper < int(observation_window.end.timestamp()): result[person_id][-1] = Interval( last_interval.lower, - observation_window.end.timestamp(), + int(observation_window.end.timestamp()), last_interval.type, ) @@ -302,15 +311,11 @@ def complementary_intervals( """ interval_type_missing_persons = interval_type - baseline_interval = Interval( - observation_window.start.timestamp(), - observation_window.end.timestamp(), - interval_type_missing_persons, + baseline_interval = timerange_to_interval( + observation_window, type_=interval_type_missing_persons ) - observation_window_mask = Interval( - observation_window.start.timestamp(), - observation_window.end.timestamp(), - IntervalType.least_intersection_priority(), + observation_window_mask = timerange_to_interval( + observation_window, type_=IntervalType.least_intersection_priority() ) result = {} @@ -423,112 +428,6 @@ def union_intervals(data: list[PersonIntervals]) -> PersonIntervals: return _process_intervals(data, _impl.union_interval_lists) -def interval_to_interval_with_count(interval: Interval) -> IntervalWithCount: - """ - 
Converts an Interval to an IntervalWithCount. - """ - return IntervalWithCount(interval.lower, interval.upper, interval.type, 1) - - -def intervals_to_intervals_with_count( - intervals: list[Interval], -) -> list[IntervalWithCount]: - """ - Converts a list of Intervals to a list of IntervalWithCount. - """ - return [interval_to_interval_with_count(interval) for interval in intervals] - - -def count_intervals(data: list[PersonIntervals]) -> PersonIntervalsWithCount: - """ - Counts the intervals per dict key in the list. - - :param data: A list of dict of intervals. - :return: A dict with the unioned intervals. - """ - if not len(data): - return dict() - - # assert dfs is a list of dataframes - assert isinstance(data, list) and all( - isinstance(arr, dict) for arr in data - ), "data must be a list of dicts" - - result = {} - - for arr in data: - if not len(arr): - # if the operation is union, an empty dataframe can be ignored - continue - - for group_keys, intervals in arr.items(): - intervals_with_count = intervals_to_intervals_with_count(intervals) - intervals_with_count = _impl.union_rects_with_count(intervals_with_count) - if group_keys not in result: - result[group_keys] = intervals_with_count - else: - result[group_keys] = _impl.union_rects_with_count( - result[group_keys] + intervals_with_count - ) - - return result - - -def filter_count_intervals( - data: PersonIntervalsWithCount, - min_count: int | None, - max_count: int | None, - keep_no_data: bool = True, - keep_not_applicable: bool = True, -) -> PersonIntervals: - """ - Filters the intervals per dict key in the list by count. - - :param data: A list of dict of intervals. - :param min_count: The minimum count of the intervals. - :param max_count: The maximum count of the intervals. - :param keep_no_data: Whether to keep NO_DATA intervals (irrespective of the count). - :param keep_not_applicable: Whether to keep NOT_APPLICABLE intervals (irrespective of the count). 
- :return: A dict with the unioned intervals. - """ - - result: PersonIntervals = {} - - interval_filter = [] - - if keep_no_data: - interval_filter.append(IntervalType.NO_DATA) - if keep_not_applicable: - interval_filter.append(IntervalType.NOT_APPLICABLE) - - if min_count is None and max_count is None: - raise ValueError("min_count and max_count cannot both be None") - elif min_count is not None and max_count is not None: - for person_id in data: - result[person_id] = [ - Interval(interval.lower, interval.upper, interval.type) - for interval in data[person_id] - if min_count <= interval.count <= max_count - or interval.type in interval_filter - ] - elif min_count is not None: - for person_id in data: - result[person_id] = [ - Interval(interval.lower, interval.upper, interval.type) - for interval in data[person_id] - if min_count <= interval.count or interval.type in interval_filter - ] - elif max_count is not None: - for person_id in data: - result[person_id] = [ - Interval(interval.lower, interval.upper, interval.type) - for interval in data[person_id] - if interval.count <= max_count or interval.type in interval_filter - ] - - return result - - def intersect_intervals(data: list[PersonIntervals]) -> PersonIntervals: """ Intersects the intervals per dict key in the list. @@ -547,7 +446,7 @@ def intersect_intervals(data: list[PersonIntervals]) -> PersonIntervals: def mask_intervals( data: PersonIntervals, mask: PersonIntervals, -) -> PersonIntervals: +) -> Dict[int, List[GeneralizedInterval]]: """ Masks the intervals in the dict per key. 
@@ -569,18 +468,21 @@ def mask_intervals( for interval in intervals ] for person_id, intervals in mask.items() + if person_id in data } - result = {} - for person_id in data: - # intersect every interval in data with every interval in mask - person_result = _impl.intersect_interval_lists( - data[person_id], person_mask[person_id] - ) - if not person_result: - continue + def intersection_interval( + start: int, end: int, intervals: List[GeneralizedInterval] + ) -> GeneralizedInterval: - result[person_id] = person_result + left_interval, right_interval = intervals + + if left_interval is not None and right_interval is not None: + return interval_like(right_interval, start, end) + + return None + + result = find_rectangles([person_mask, data], intersection_interval) return result @@ -671,10 +573,14 @@ def create_time_intervals( end_datetime = end_datetime.in_timezone(timezone) # Prepare to collect intervals - intervals = [] + intervals: list[Interval] = [] previous_end = None - def add_interval(interval_start, interval_end, interval_type): + def add_interval( + interval_start: datetime.datetime, + interval_end: datetime.datetime, + interval_type: IntervalType, + ) -> None: nonlocal previous_end effective_start = max(interval_start, start_datetime) effective_end = min(interval_end, end_datetime) @@ -686,11 +592,11 @@ def add_interval(interval_start, interval_end, interval_type): # method) which can result in an incorrect event order for # touching intervals. 
if previous_end is not None: - assert previous_end < effective_start + assert previous_end < effective_start # type: ignore[unreachable] intervals.append( Interval( - lower=effective_start.timestamp(), - upper=effective_end.timestamp(), + lower=int(effective_start.timestamp()), + upper=int(effective_end.timestamp()), type=interval_type, ) ) @@ -721,7 +627,6 @@ def add_interval(interval_start, interval_end, interval_type): # Create the interval with the specified interval_type if it # overlaps the main datetime range, otherwise fill the day # with an interval of type "not applicable". - # TODO: what about intervals "before" the main datetime range? if end_interval < start_datetime: # completely before datetime range day_start = timezone.localize( datetime.datetime.combine(current_date, datetime.time(0, 0, 0)) @@ -753,23 +658,6 @@ def add_interval(interval_start, interval_end, interval_type): return intervals -def find_overlapping_windows( - windows: list[Interval], data: PersonIntervals -) -> PersonIntervals: - """ - Returns a list of windows that overlap with any interval in the intervals list. A window is included in the - result if it overlaps in any part with any of the given intervals, not just where they intersect. The entire - window is returned, not just the overlapping segment. - - Note that a single, common list of windows is used for all persons. - - :param windows: A list of windows, where each window is defined as an interval. - :param data: The dict with intervals that are checked for overlap with the windows. - :return: A list of windows that have any overlap with the intervals. 
- """ - return {key: _impl.find_overlapping_windows(windows, data[key]) for key in data} - - def find_overlapping_personal_windows( windows: PersonIntervals, data: PersonIntervals ) -> PersonIntervals: @@ -805,14 +693,48 @@ def find_overlapping_personal_windows( return result -def find_rectangles_with_count(data: list[PersonIntervals]) -> PersonIntervals: +def find_rectangles( + data: list[PersonIntervals], + interval_constructor: IntervalConstructor, + is_same_result: SameResult | None = None, +) -> Dict[int, List[GeneralizedInterval]]: + """ + Constructs new intervals ("time slices") by combining multiple parallel tracks of intervals. + + This iterates over all intervals for each person across the given `data` list. Whenever an + interval starts or ends on any track, that boundary can produce a new interval. For each + interval, we invoke `interval_constructor(start, end, active_intervals)` to decide how + to label it (e.g., POSITIVE, NEGATIVE). + + If `is_same_result` is provided, it’s used to decide whether two adjacent slices have the + same "type" (so we can merge them). If not provided, a default routine is used that merges + slices if they have the same object identity and the same result. + + :param data: A list of dictionaries, each mapping a person ID to a list of intervals. + :param interval_constructor: A callable that takes the time boundaries and the corresponding intervals + from each source and returns a new interval or None. + :param is_same_result: Optional helper for merging adjacent intervals of the same type to avoid + unnecessary fragmentation. + :return: A dictionary mapping each person ID to the newly constructed intervals. + """ + # TODO(jmoringe): can this use _process_interval? 
if len(data) == 0: return {} - else: - keys = data[0].keys() - return { - key: _impl.find_rectangles_with_count( - [intervals[key] for intervals in data] - ) - for key in keys - } + + # Collect all person IDs across all tracks + keys: Set[int] = set() + result: Dict[int, List[GeneralizedInterval]] = dict() + + for track in data: + keys |= track.keys() + + for key in keys: + key_result = _impl.find_rectangles( + [intervals.get(key, []) for intervals in data], + interval_constructor, + is_same_result=is_same_result, + ) + if len(key_result) > 0: + result[key] = key_result + + return result diff --git a/execution_engine/task/process/rectangle_cython.pyx b/execution_engine/task/process/rectangle_cython.pyx index 688b7fd3..caad697e 100644 --- a/execution_engine/task/process/rectangle_cython.pyx +++ b/execution_engine/task/process/rectangle_cython.pyx @@ -1,25 +1,33 @@ -import copy -from functools import reduce -import datetime -from collections import namedtuple +import typing +from functools import cmp_to_key +from typing import Callable -cimport numpy as np - -import numpy as np from sortedcontainers import SortedDict -from execution_engine.task.process import Interval, IntervalWithCount, IntervalWithTypeCounts +from execution_engine.task.process import AnyInterval, Interval, IntervalWithCount +from execution_engine.task.process.rectangle import IntervalConstructor, SameResult from execution_engine.util.interval import IntervalType +from libc.limits cimport SCHAR_MIN +from libc.stdint cimport int64_t + + +# A sentinel for “negative infinity” in 64-bit integer space +cdef int64_t NEG_INF = -9223372036854775807 # -2^63 +cdef int64_t POS_INF = 9223372036854775807 # 2^63 - 1 + DEF SCHAR_MIN = -128 DEF SCHAR_MAX = 127 MODULE_IMPLEMENTATION = "cython" +IntervalEvent = typing.Tuple[int, bool, AnyInterval] +IntervalEventOnTrack = typing.Tuple[int, bool, AnyInterval, int] + def intervals_to_events( - intervals: list[Interval], + intervals: list[AnyInterval], closing_offset: 
int = 1, -) -> list[tuple[int, bool, IntervalType]]: +) -> list[IntervalEvent]: """ Converts the intervals to a list of events. @@ -28,78 +36,16 @@ def intervals_to_events( :param intervals: The intervals. :return: The events. """ - events = [(i.lower, True, i.type) for i in intervals] + [ - (i.upper + closing_offset, False, i.type) for i in intervals - ] - return sorted( - events, - key=lambda i: (i[0]), - ) - -def intervals_with_count_to_events( - intervals: list[IntervalWithCount], -) -> list[tuple[int, bool, IntervalType, int]]: - """ - Converts the intervals to a list of events. - - The events are a sorted list of the opening/closing points of all rectangles. - - :param intervals: The intervals. - :return: The events. - """ - events = [(i.lower, True, i.type, i.count) for i in intervals] + [ - (i.upper + 1, False, i.type, i.count) for i in intervals - ] - return sorted( - events, - key=lambda i: (i[0]), - ) - - -def intersect_rects(list[Interval] intervals) -> list[Interval]: - cdef double x_min = -np.inf - cdef signed char y_min = SCHAR_MAX - cdef double end_point = np.inf - - if not len(intervals): - return [] - - order = IntervalType.intersection_priority() - events = intervals_to_events(intervals) - - for cur_x, start_point, y_max_type in events: - - y_max = order.index(y_max_type) - - if start_point: - if end_point < cur_x: - # we already hit an endpoint and here starts a new one, so the intersection is empty - return [] - - if y_max < y_min: - y_min = y_max - - if cur_x > x_min: - # this point is further than the previously found one, so reset the intersection's start point - x_min = cur_x - - else: - # we found and endpoint - if cur_x > end_point: - # this endpoint lies behind another endpoint, we know we can stop - if x_min > end_point - 1: - return [] - return [Interval(lower=x_min, upper=end_point - 1, type=order[y_min])] - end_point = cur_x - - return [Interval(lower=x_min, upper=end_point - 1, type=order[y_min])] + events = [ (i.lower, True, i) for 
i in intervals ] \ + + [ (i.upper + closing_offset, False, i) for i in intervals ] + return sorted(events,key=lambda i: i[0]) def union_rects(list[Interval] intervals) -> list[Interval]: - cdef double last_x = -np.inf - cdef double last_x_closed = -np.inf - cdef double cur_x = -np.inf - cdef double first_x + cdef int64_t last_x = NEG_INF + cdef int64_t last_x_closed = NEG_INF + cdef int64_t cur_x = NEG_INF + cdef int64_t first_x cdef signed char max_open_y = SCHAR_MIN #cdef signed char open_y[len(intervals)] @@ -116,7 +62,8 @@ def union_rects(list[Interval] intervals) -> list[Interval]: union = [] - for x_min, start_point, y_max_type in events: + for x_min, start_point, interval in events: + y_max_type = interval.type y_max = order.index(y_max_type) if x_min > cur_x: @@ -182,85 +129,6 @@ def union_rects(list[Interval] intervals) -> list[Interval]: last_x = cur_x # start new output rectangle return union -def union_rects_with_count(list[IntervalWithCount] intervals) -> list[IntervalWithCount]: - cdef double last_x_start = -np.inf - cdef double last_x_end; - cdef double previous_x_visited = -np.inf - cdef double first_x - cdef signed char max_open_y = SCHAR_MIN - - if not intervals: - return [] - - order = IntervalType.union_priority()[::-1] - - events = intervals_with_count_to_events(intervals) - - union = [] - - last_x_start = -np.inf # holds the x_min of the currently open output rectangle - last_x_end = events[0][0] # x variable of the last closed interval (we start with the first x, so we - # don't close the first rectangle at the first x) - previous_x_visited = -np.inf - open_y = SortedDict() - - def get_y_max() -> IntervalType | None: - max_key = None - for key in reversed(open_y): - if open_y[key] > 0: - max_key = key - break - return max_key - - for x, start_point, y_type, count_event in events: - y = order.index(y_type) - if start_point: - y_max = get_y_max() - - if x > previous_x_visited and y_max is None: - # no currently open rectangles - last_x_start = 
x # start new output rectangle - elif y >= y_max: - if x == last_x_end or x == last_x_start: - # we already closed a rectangle at this x, so we don't need to start a new one - open_y[y] = open_y.get(y, 0) + count_event - continue - - union.append( - IntervalWithCount( - lower=last_x_start, upper=x - 1, type=order[y_max], count=open_y[y_max] - ) - ) - last_x_end = x - last_x_start = x - - open_y[y] = open_y.get(y, 0) + count_event - - else: - open_y[y] = max(open_y.get(y, 0) - count_event, 0) - - y_max = get_y_max() - - if (y_max is None or (open_y and y_max <= y)) and x > last_x_end: - if y_max is None or y_max < y: - # the closing rectangle has a higher y_max than the currently open ones - count = count_event - else: - # the closing rectangle has the same y_max as the currently open ones - count = open_y[y] + count_event - - union.append( - IntervalWithCount( - lower=last_x_start, upper=x - 1, type=order[y], count=count - ) - ) # close the previous rectangle at y_max - last_x_end = x - last_x_start = x # start new output rectangle - - previous_x_visited = x - - return merge_adjacent_intervals(union) - def merge_adjacent_intervals(intervals: list[IntervalWithCount]) -> list[IntervalWithCount]: """ Merges adjacent intervals in a list of IntervalWithCount namedtuples if they have the same 'type' and 'count'. 
@@ -319,6 +187,46 @@ def merge_adjacent_intervals(intervals: list[IntervalWithCount]) -> list[Interva return merged_intervals + +def intersect_rects(list[Interval] intervals) -> list[Interval]: + cdef int64_t x_min = NEG_INF + cdef signed char y_min = SCHAR_MAX + cdef int64_t end_point = POS_INF + + if not len(intervals): + return [] + + order = IntervalType.intersection_priority() + events = intervals_to_events(intervals) + + for cur_x, start_point, interval in events: + y_max_type = interval.type + y_max = order.index(y_max_type) + + if start_point: + if end_point < cur_x: + # we already hit an endpoint and here starts a new one, so the intersection is empty + return [] + + if y_max < y_min: + y_min = y_max + + if cur_x > x_min: + # this point is further than the previously found one, so reset the intersection's start point + x_min = cur_x + + else: + # we found and endpoint + if cur_x > end_point: + # this endpoint lies behind another endpoint, we know we can stop + if x_min > end_point - 1: + return [] + return [Interval(lower=x_min, upper=end_point - 1, type=order[y_min])] + end_point = cur_x + + return [Interval(lower=x_min, upper=end_point - 1, type=order[y_min])] + + def intersect_interval_lists( left: list[Interval], right: list[Interval] ) -> list[Interval]: @@ -347,159 +255,273 @@ def union_interval_lists( return union_rects(left + right) -def find_overlapping_windows( - windows: list[Interval], intervals: list[Interval] -) -> list[Interval]: +def default_is_same_result(interval_constructor: IntervalConstructor): """ - Returns a list of windows that overlap with any interval in the intervals list. A window is included in the - result if it overlaps in any part with any of the given intervals, not just where they intersect. The entire - window is returned, not just the overlapping segment. - - :param windows: A list of windows, where each window is defined as an interval. 
- :param intervals: A list of intervals that are checked for overlap with the windows. - :return: A list of windows that have any overlap with the intervals. + Creates an 'is_same_result' function that determines whether two sets of active intervals + produce the same resulting interval when passed to 'interval_constructor'. + + The returned function calls: + interval_constructor(0, 0, active_intervals1) + and + interval_constructor(0, 0, active_intervals2) + and checks if the results are equal. If they match, we say they represent the “same” result. + + :param interval_constructor: An interval constructor function. + :return: + A function 'is_same_result' that compares the results of two different sets of active + intervals by invoking 'interval_constructor' on each and checking for equality. """ - # Convert all intervals and windows into events - window_events = intervals_to_events(windows, closing_offset=0) - interval_events = intervals_to_events(intervals, closing_offset=0) - - # Here we collect interval for the intersecting windows - intersecting_windows = [] - def add_segment(start, end, interval_type): - intersecting_windows.append(Interval(start, end, interval_type)) - - # State and "event handler" functions for state transitions: - # inside/not inside window, inside/not inside at least one - # interval. 
- previous_event = None - window_state = False - any_satisfied_in_window = False - satisfied_interval_type = None - def window_open(event_time, interval_type): - nonlocal previous_event, window_state, any_satisfied_in_window - assert not window_state - window_state = interval_type - if satisfied_interval_type is not None: - any_satisfied_in_window = satisfied_interval_type - previous_event = event_time - def window_close(event_time): - nonlocal previous_event, window_state, any_satisfied_in_window - assert window_state - if window_state == IntervalType.NOT_APPLICABLE: - interval_type = IntervalType.NOT_APPLICABLE - add_segment(previous_event, event_time, interval_type) - elif any_satisfied_in_window == False: - pass - else: - interval_type = any_satisfied_in_window - add_segment(previous_event, event_time, interval_type) - window_state = False - any_satisfied_in_window = False - previous_event = event_time - def interval_satisfied(event_time, interval_type): - nonlocal satisfied_interval_type, any_satisfied_in_window - # If we are inside a window, remember that we saw an interval - # of type interval_type. Do not overwrite previously seen - # higher priority types with lower priority types - if not (satisfied_interval_type in [IntervalType.POSITIVE, IntervalType.NEGATIVE]): - satisfied_interval_type = interval_type - if window_state: - # Priorities: POSITIVE > NEGATIVE > NO_DATA or NOT_APPLICABLE > no value - if any_satisfied_in_window == IntervalType.POSITIVE: - pass - elif any_satisfied_in_window == IntervalType.NEGATIVE: - if interval_type == IntervalType.POSITIVE: - any_satisfied_in_window = interval_type - else: - any_satisfied_in_window = interval_type - def interval_unsatisfied(event_time): - nonlocal satisfied_interval_type - satisfied_interval_type = None - - # Use two indices to traverse the two sorted lists of events in - # parallel. Call event handler functions for state transitions. 
- def interleaved_events(): - w_idx, i_idx = 0, 0 - while True: - window_event = window_events[w_idx] if w_idx < len(window_events) else None - interval_event = interval_events[i_idx] if i_idx < len(interval_events) else None - # When tied in terms of event time, use the following - # priority for reporting events: - # window open > interval open > interval close > window close - if window_event and (not interval_event or window_event[0] < interval_event[0] or ( - window_event[0] == interval_event[0] and window_event[1])): - w_idx += 1 - yield window_event, None - elif interval_event: - i_idx += 1 - yield None, interval_event - else: - break - active_intervals = 0 - for window_event, interval_event in interleaved_events(): - if window_event: - time, open_, interval_type = window_event - window_open(time, interval_type) if open_ else window_close(time) - else: - time, open_, type_ = interval_event - active_intervals += (1 if open_ else -1) - if active_intervals == 0: # 1 -> 0 transition - interval_unsatisfied(time) - elif active_intervals == 1: # 0 -> 1 transition - interval_satisfied(time, type_) - - # Return the list of unique intersecting windows - return intersecting_windows - -def find_rectangles_with_count(all_intervals: list[list[Interval]]) -> list[IntervalWithTypeCounts]: + def is_same_result( + active_intervals1: List[GeneralizedInterval], + active_intervals2: List[GeneralizedInterval], + ) -> bool: + """ + Compares the resulting intervals for two sets of active intervals. + + :param active_intervals1: + A list of intervals (or None) describing the first track’s active intervals. + :param active_intervals2: + A list of intervals (or None) describing the second track’s active intervals. + :return: + True if 'interval_constructor(0, 0, ...)' produces the same interval for + both sets, otherwise False. 
+ """ + # When we have to decide whether to extend a result interval + # or start a new one, we compare the state for the existing + # result interval with the new state. The states are derived + # from the respective lists of active intervals by calling + # interval_constructor (with fake points in time). + return (interval_constructor(0, 0, active_intervals1) + == interval_constructor(0, 0, active_intervals2)) + return is_same_result + +def find_rectangles( + all_intervals: list[list[AnyInterval]], + interval_constructor: IntervalConstructor, + is_same_result: SameResult | None = None, +) -> list[AnyInterval]: """ - For multiple parallel "tracks" of intervals, identify temporal - intervals in which no change occurs on any "track". For each such - interval, report the number of active intervals grouped by - interval type across all "tracks". When there is no interval on a - track for a given temporal interval, act as if a negative interval - was present there. + Low-level engine for interval construction. + + For multiple parallel "tracks" of intervals, identify segments of time + in which no change occurs on any "track". For each such segment, + call `interval_constructor(start, end, active_intervals)` to determine + how to represent the interval in the overall result. To this end, + interval_constructor receives a list "active" intervals the + elements of which are either None or an interval from + all_intervals and returns either None or an interval. The returned + None values and intervals are further processed into the overall + return value by merging adjacent intervals without "payload" + change. :param all_intervals: A list of intervals that are checked for overlap with the windows. - :return: A list of windows that have any overlap with the intervals. + :param interval_constructor: A callable that accepts a start time, + an end time and a list of "active" + intervals and returns None or an + interval. 
The list of active + intervals has the same length as + all_intervals and each element is + either None or an element from the + corresponding list in all_intervals. + :return: A list of intervals computed by interval_constructor such + that adjacent intervals (i.e. without gaps between them) + have different "payloads". """ - # Convert all intervals into a list of events sorted by - # time. Multiple events at the same point in time are not a - # problem here: since we simply count the number of "active" - # intervals the result does not depend on the order in which we - # process the events. + if is_same_result is None: + is_same_result = default_is_same_result(interval_constructor) + + # Convert all intervals into a single list of events sorted by + # time. Multiple events at the same point in time can be problem + # here: If an interval open event and an interval close event on + # the same track happen at the same time (which happens for + # adjacent intervals on that track), we must order the close event + # before the open event, otherwise our tracking of active + # intervals would get confused. track_count = len(all_intervals) - events = reduce(lambda acc, intervals: acc + intervals_to_events(intervals, closing_offset=0), - all_intervals, - []) - events.sort(key=lambda i: i[0]) - # The result will be a list of intervals + events = [ + (time, event, interval, j) + for j, intervals in enumerate(all_intervals) + for interval in intervals + for (time, event) in [(interval.lower, True), (interval.upper, False)] + ] + event_count = len(events) + + if event_count == 0: + return [] + + def compare_events( + event1: IntervalEventOnTrack, event2: IntervalEventOnTrack + ) -> int: + """ + Sorting comparator to ensure we process events in the correct order: + - earlier time first + - if same time and same track, close events before open events + (so we don't incorrectly treat a consecutive interval on the same track + as overlapping). 
+ + Index of event1 and event2: + - [0]: time of event + - [1]: opening (True) or closing (False) + - [2]: the interval to which the event belongs + - [3]: track index + """ + if event1[0] < event2[0]: # event1 is earlier + return -1 + elif event2[0] < event1[0]: # event2 is earlier + return 1 + elif event1[3] == event2[3]: # at the same time and on same track, + if event1[2] == event2[2]: # same interval (we don't check for "is" because they might be different objects, but still represent the same interval) + return ( + -1 if (event1[1] is True) else 1 + ) # sort open events before open events + else: # different intervals + return ( + -1 if (event1[1] is False) else 1 + ) # sort close events before open events + else: # at the same time, but different tracks => any order is fine + return 1 + + # Sort events chronologically according to compare_events + events.sort(key=cmp_to_key(compare_events)) + + active_intervals: list[GeneralizedInterval] = [None] * track_count + + def finalize_interval( + interval_start_time: int, + current_time: int, + interval_start_state: List[GeneralizedInterval], + ) -> None: + """ + Appends a new time slice (interval_start_time -> current_time) to 'result_intervals', + ensuring we don't create duplicate adjacency boundaries if the previous slice ends + exactly where the new one starts. 
+ """ + if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + # Adjust the previous slice so it doesn't overlap or duplicate + result_intervals[-1] = ( + previous_result[0], + previous_result[1] - 1, + previous_result[2], + ) + + # Now finalize the current slice + result_intervals.append( + (interval_start_time, current_time, interval_start_state) + ) + + + def process_events_for_point_in_time( + index: int, point_time: int + ) -> Tuple[int, int, int] | None: + """ + Consumes events that occur at `point_time` (or effectively that boundary), + updating 'active_intervals' for whichever track is opening or closing + intervals at that time. + + Returns: (new_index, new_time, copy_of_active_intervals, high_time) + - new_index: the index of the first event not processed (because it's after point_time) + - new_time: the time of that next event + - copy_of_active_intervals: a snapshot of 'active_intervals' after processing + - high_time: the highest time covered by these events (may be the same as point_time + or point_time + 1 if we consider inclusive boundaries). + + If we run out of events entirely, returns (None, None, None, None). + """ + high_time = point_time + any_open = False + + for i in range(index, event_count): + time, open_, interval, track = events[i] + # Since points in time for intervals are quantized to whole + # seconds and intervals are closed (inclusive) for both start + # and end points, two adjacent intervals like + # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] + # have no gap between them and can be considered a single + # continuous interval [START_TIME1, END_TIME2]. 
+ + point_interval_closing = any_open and not open_ and interval.lower == interval.upper == point_time + + if ((point_time == time) and not point_interval_closing) or (open_ and (point_time == time - 1)): + if time > high_time: + high_time = time + any_open |= open_ + else: + # As soon as we find an event that’s clearly beyond the cluster at point_time, + # we break and return + return ( + i, + time, + high_time if any_open else high_time + 1, + ) + + # Opening => set this track’s active interval to the new interval + # Closing => set it to None + active_intervals[track] = interval if open_ else None + + # If we exit the loop fully, we used all events + return None + + # Step through event "clusters" with a common point in time and + # emit result intervals with unchanged interval "payload". + index: int | None = 0 + time: int | None = events[index][0] + interval_start_time: int = time + result_intervals: list[tuple[int, int, List[GeneralizedInterval]]] = [] + + if time is None: + # No events at all + return [] + + # process the event at index 0 at the first timepoint + res = process_events_for_point_in_time(index, time) + + if res is None: + return [] + + index, time, high_time = res + + interval_start_state = active_intervals.copy() + + # The main loop: step through event clusters + while True: + res = process_events_for_point_in_time(index, time) + if res is None: + # No more events => finalize the last slice and break + finalize_interval(interval_start_time, time, interval_start_state) + break + + new_index, new_time, high_time = res + + # Diagram for this program point: + # |___potential_result_interval___|| | + # index new_index + # interval_start_time time new_time + # interval_start_state maybe_end_state + # high_time + + # We have a region from [interval_start_time, time) or [interval_start_time, time] + # with 'interval_start_state' as the active intervals. + # Decide if we finalize that region or if we can merge with the next region. 
+ if not is_same_result(interval_start_state, active_intervals): + # If the active intervals changed, finalize the old slice + finalize_interval(interval_start_time, time, interval_start_state) + + # Update interval start info. + interval_start_time = high_time + interval_start_state = active_intervals.copy() + + index, time = new_index, new_time + result = [] - def add_segment(start, end, type_counts): - # We consider the period between 23:59:59 of a day and - # 00:00:00 of the following day to be empty. - if not (start == end - 1 and datetime.datetime.fromtimestamp(end).time() == datetime.time(0, 0, 0)): - # Assume implicit negative intervals: increase the count - # for the negative type as needed so that the overall - # count is equal to track_count. - missing = track_count - sum(type_counts.values()) - if missing > 0: - type_counts[IntervalType.NEGATIVE] = type_counts.get(IntervalType.NEGATIVE, 0) + missing - result.append(IntervalWithTypeCounts(start, end, type_counts)) - - # Step through events and emit result intervals whenever the - # counts change. 
- counts = dict() - previous_time = events[0][0] - for (time, open_, interval_type) in events: - if previous_time is None: - previous_time = time - elif not previous_time == time: - add_segment(previous_time, time, copy.copy(counts)) - previous_time = time - - old_count = counts.get(interval_type, 0) - counts[interval_type] = old_count + (1 if open_ else -1) + + # Finally, convert the (start, end, intervals) slices into actual Interval objects + for start, end, intervals in result_intervals: + interval = interval_constructor(start, end, intervals) + + if interval is not None: + result.append(interval) return result diff --git a/execution_engine/task/process/rectangle_python.py b/execution_engine/task/process/rectangle_python.py index bc0609b9..8d92c83e 100644 --- a/execution_engine/task/process/rectangle_python.py +++ b/execution_engine/task/process/rectangle_python.py @@ -1,15 +1,40 @@ -import copy -import datetime -from functools import reduce +from functools import cmp_to_key +from typing import List, Tuple, cast import numpy as np -from sortedcontainers import SortedDict, SortedList - -from execution_engine.task.process import Interval, IntervalWithCount, IntervalWithTypeCounts +from sortedcontainers import SortedList + +from execution_engine.task.process import ( + AnyInterval, + GeneralizedInterval, + Interval, + IntervalWithCount, +) +from execution_engine.task.process.rectangle import IntervalConstructor, SameResult from execution_engine.util.interval import IntervalType MODULE_IMPLEMENTATION = "python" +IntervalEvent = Tuple[int, bool, AnyInterval] +IntervalEventOnTrack = Tuple[int, bool, AnyInterval, int] + + +def intervals_to_events( + intervals: list[Interval], closing_offset: int = 1 +) -> list[IntervalEvent]: + """ + Converts the intervals to a list of events. + + The events are a sorted list of the opening/closing points of all rectangles. + + :param intervals: The intervals. + :return: The events. 
+ """ + events = [(i.lower, True, i) for i in intervals] + [ + (i.upper + closing_offset, False, i) for i in intervals + ] + return sorted(events, key=lambda i: i[0]) + def union_rects(intervals: list[Interval]) -> list[Interval]: """ @@ -30,7 +55,8 @@ def union_rects(intervals: list[Interval]) -> list[Interval]: cur_x = -np.inf open_y = SortedList() - for x_min, start_point, y_max in events: + for x_min, start_point, interval in events: + y_max = interval.type if x_min > cur_x: # previously unvisited x cur_x = x_min @@ -88,89 +114,6 @@ def union_rects(intervals: list[Interval]) -> list[Interval]: return union -def union_rects_with_count( - intervals: list[IntervalWithCount], -) -> list[IntervalWithCount]: - """ - Unions the intervals while keeping track of the count of overlapping intervals of the same type. - """ - - if not len(intervals): - return [] - - with IntervalType.union_order(): - events = intervals_with_count_to_events(intervals) - - union = [] - - last_x_start = -np.inf # holds the x_min of the currently open output rectangle - last_x_end = events[0][ - 0 - ] # x variable of the last closed interval (we start with the first x, so we - # don't close the first rectangle at the first x) - previous_x_visited = -np.inf - open_y = SortedDict() - - def get_y_max() -> IntervalType | None: - max_key = None - for key in reversed(open_y): - if open_y[key] > 0: - max_key = key - break - return max_key - - for x, start_point, y, count_event in events: - if start_point: - y_max = get_y_max() - - if x > previous_x_visited and y_max is None: - # no currently open rectangles - last_x_start = x # start new output rectangle - elif y >= y_max: - if x == last_x_end or x == last_x_start: - # we already closed a rectangle at this x, so we don't need to start a new one - open_y[y] = open_y.get(y, 0) + count_event - continue - - union.append( - IntervalWithCount( - lower=last_x_start, - upper=x - 1, - type=y_max, - count=open_y[y_max], - ) - ) - last_x_end = x - last_x_start = 
x - - open_y[y] = open_y.get(y, 0) + count_event - - else: - open_y[y] = max(open_y.get(y, 0) - count_event, 0) - - y_max = get_y_max() - - if (y_max is None or (open_y and y_max <= y)) and x > last_x_end: - if y_max is None or y_max < y: - # the closing rectangle has a higher y_max than the currently open ones - count = count_event - else: - # the closing rectangle has the same y_max as the currently open ones - count = open_y[y] + count_event - - union.append( - IntervalWithCount( - lower=last_x_start, upper=x - 1, type=y, count=count - ) - ) # close the previous rectangle at y_max - last_x_end = x - last_x_start = x # start new output rectangle - - previous_x_visited = x - - return merge_adjacent_intervals(union) - - def merge_adjacent_intervals( intervals: list[IntervalWithCount], ) -> list[IntervalWithCount]: @@ -249,7 +192,8 @@ def intersect_rects(intervals: list[Interval]) -> list[Interval]: y_min = np.inf end_point = np.inf - for cur_x, start_point, y_max in events: + for cur_x, start_point, interval in events: + y_max = interval.type if start_point: if end_point < cur_x: # we already hit an endpoint and here starts a new one, so the intersection is empty @@ -274,46 +218,6 @@ def intersect_rects(intervals: list[Interval]) -> list[Interval]: return [Interval(lower=x_min, upper=end_point - 1, type=y_min)] -def intervals_to_events( - intervals: list[Interval], closing_offset: int = 1 -) -> list[tuple[int, bool, IntervalType]]: - """ - Converts the intervals to a list of events. - - The events are a sorted list of the opening/closing points of all rectangles. - - :param intervals: The intervals. - :return: The events. 
- """ - events = [(i.lower, True, i.type) for i in intervals] + [ - (i.upper + closing_offset, False, i.type) for i in intervals - ] - return sorted( - events, - key=lambda i: (i[0]), - ) - - -def intervals_with_count_to_events( - intervals: list[IntervalWithCount], -) -> list[tuple[int, bool, IntervalType, int]]: - """ - Converts the intervals to a list of events. - - The events are a sorted list of the opening/closing points of all rectangles. - - :param intervals: The intervals. - :return: The events. - """ - events = [(i.lower, True, i.type, i.count) for i in intervals] + [ - (i.upper + 1, False, i.type, i.count) for i in intervals - ] - return sorted( - events, - key=lambda i: (i[0]), - ) - - def intersect_interval_lists( left: list[Interval], right: list[Interval] ) -> list[Interval]: @@ -340,159 +244,284 @@ def union_interval_lists(left: list[Interval], right: list[Interval]) -> list[In return union_rects(left + right) -def find_overlapping_windows( - windows: list[Interval], intervals: list[Interval] -) -> list[Interval]: +def default_is_same_result(interval_constructor: IntervalConstructor) -> SameResult: """ - Returns a list of windows that overlap with any interval in the intervals list. A window is included in the - result if it overlaps in any part with any of the given intervals, not just where they intersect. The entire - window is returned, not just the overlapping segment. - - :param windows: A list of windows, where each window is defined as an interval. - :param intervals: A list of intervals that are checked for overlap with the windows. - :return: A list of windows that have any overlap with the intervals. + Creates an 'is_same_result' function that determines whether two sets of active intervals + produce the same resulting interval when passed to 'interval_constructor'. + + The returned function calls: + interval_constructor(0, 0, active_intervals1) + and + interval_constructor(0, 0, active_intervals2) + and checks if the results are equal. 
If they match, we say they represent the “same” result. + + :param interval_constructor: An interval constructor function. + :return: + A function 'is_same_result' that compares the results of two different sets of active + intervals by invoking 'interval_constructor' on each and checking for equality. """ - # Convert all intervals and windows into events - window_events = intervals_to_events(windows, closing_offset=0) - interval_events = intervals_to_events(intervals, closing_offset=0) - - # Here we collect interval for the intersecting windows - intersecting_windows = [] - def add_segment(start, end, interval_type): - intersecting_windows.append(Interval(start, end, interval_type)) - - # State and "event handler" functions for state transitions: - # inside/not inside window, inside/not inside at least one - # interval. - previous_event = None - window_state = False - any_satisfied_in_window = False - satisfied_interval_type = None - def window_open(event_time, interval_type): - nonlocal previous_event, window_state, any_satisfied_in_window - assert not window_state - window_state = interval_type - if satisfied_interval_type is not None: - any_satisfied_in_window = satisfied_interval_type - previous_event = event_time - def window_close(event_time): - nonlocal previous_event, window_state, any_satisfied_in_window - assert window_state - if window_state == IntervalType.NOT_APPLICABLE: - interval_type = IntervalType.NOT_APPLICABLE - add_segment(previous_event, event_time, interval_type) - elif any_satisfied_in_window == False: - pass - else: - interval_type = any_satisfied_in_window - add_segment(previous_event, event_time, interval_type) - window_state = False - any_satisfied_in_window = False - previous_event = event_time - def interval_satisfied(event_time, interval_type): - nonlocal satisfied_interval_type, any_satisfied_in_window - # If we are inside a window, remember that we saw an interval - # of type interval_type. 
Do not overwrite previously seen - # higher priority types with lower priority types - if not (satisfied_interval_type in [IntervalType.POSITIVE, IntervalType.NEGATIVE]): - satisfied_interval_type = interval_type - if window_state: - # Priorities: POSITIVE > NEGATIVE > NO_DATA or NOT_APPLICABLE > no value - if any_satisfied_in_window == IntervalType.POSITIVE: - pass - elif any_satisfied_in_window == IntervalType.NEGATIVE: - if interval_type == IntervalType.POSITIVE: - any_satisfied_in_window = interval_type - else: - any_satisfied_in_window = interval_type - def interval_unsatisfied(event_time): - nonlocal satisfied_interval_type - satisfied_interval_type = None - - # Use two indices to traverse the two sorted lists of events in - # parallel. Call event handler functions for state transitions. - def interleaved_events(): - w_idx, i_idx = 0, 0 - while True: - window_event = window_events[w_idx] if w_idx < len(window_events) else None - interval_event = interval_events[i_idx] if i_idx < len(interval_events) else None - # When tied in terms of event time, use the following - # priority for reporting events: - # window open > interval open > interval close > window close - if window_event and (not interval_event or window_event[0] < interval_event[0] or ( - window_event[0] == interval_event[0] and window_event[1])): - w_idx += 1 - yield window_event, None - elif interval_event: - i_idx += 1 - yield None, interval_event - else: - break - active_intervals = 0 - for window_event, interval_event in interleaved_events(): - if window_event: - time, open_, interval_type = window_event - window_open(time, interval_type) if open_ else window_close(time) - else: - time, open_, type_ = interval_event - active_intervals += (1 if open_ else -1) - if active_intervals == 0: # 1 -> 0 transition - interval_unsatisfied(time) - elif active_intervals == 1: # 0 -> 1 transition - interval_satisfied(time, type_) - - # Return the list of unique intersecting windows - return 
intersecting_windows -def find_rectangles_with_count(all_intervals: list[list[Interval]]) -> list[IntervalWithTypeCounts]: + def is_same_result( + active_intervals1: List[GeneralizedInterval], + active_intervals2: List[GeneralizedInterval], + ) -> bool: + """ + Compares the resulting intervals for two sets of active intervals. + + :param active_intervals1: + A list of intervals (or None) describing the first track’s active intervals. + :param active_intervals2: + A list of intervals (or None) describing the second track’s active intervals. + :return: + True if 'interval_constructor(0, 0, ...)' produces the same interval for + both sets, otherwise False. + """ + # When we have to decide whether to extend a result interval + # or start a new one, we compare the state for the existing + # result interval with the new state. The states are derived + # from the respective lists of active intervals by calling + # interval_constructor (with fake points in time). + return interval_constructor(0, 0, active_intervals1) == interval_constructor( + 0, 0, active_intervals2 + ) + + return is_same_result + + +def find_rectangles( + all_intervals: list[list[AnyInterval]], + interval_constructor: IntervalConstructor, + is_same_result: SameResult | None = None, +) -> list[AnyInterval]: """ - For multiple parallel "tracks" of intervals, identify temporal - intervals in which no change occurs on any "track". For each such - interval, report the number of active intervals grouped by - interval type across all "tracks". When there is no interval on a - track for a given temporal interval, act as if a negative interval - was present there. + Low-level engine for interval construction. + + For multiple parallel "tracks" of intervals, identify segments of time + in which no change occurs on any "track". For each such segment, + call `interval_constructor(start, end, active_intervals)` to determine + how to represent the interval in the overall result. 
To this end, + interval_constructor receives a list "active" intervals the + elements of which are either None or an interval from + all_intervals and returns either None or an interval. The returned + None values and intervals are further processed into the overall + return value by merging adjacent intervals without "payload" + change. :param all_intervals: A list of intervals that are checked for overlap with the windows. - :return: A list of windows that have any overlap with the intervals. + :param interval_constructor: A callable that accepts a start time, + an end time and a list of "active" + intervals and returns None or an + interval. The list of active + intervals has the same length as + all_intervals and each element is + either None or an element from the + corresponding list in all_intervals. + :return: A list of intervals computed by interval_constructor such + that adjacent intervals (i.e. without gaps between them) + have different "payloads". """ - # Convert all intervals into a list of events sorted by - # time. Multiple events at the same point in time are not a - # problem here: since we simply count the number of "active" - # intervals the result does not depend on the order in which we - # process the events. + if is_same_result is None: + is_same_result = default_is_same_result(interval_constructor) + + # Convert all intervals into a single list of events sorted by + # time. Multiple events at the same point in time can be problem + # here: If an interval open event and an interval close event on + # the same track happen at the same time (which happens for + # adjacent intervals on that track), we must order the close event + # before the open event, otherwise our tracking of active + # intervals would get confused. 
track_count = len(all_intervals) - events = reduce(lambda acc, intervals: acc + intervals_to_events(intervals, closing_offset=0), - all_intervals, - []) - events.sort(key=lambda i: i[0]) - # The result will be a list of intervals + events: list[IntervalEventOnTrack] = [ + (time, event, interval, j) + for j, intervals in enumerate(all_intervals) + for interval in intervals + for (time, event) in [(interval.lower, True), (interval.upper, False)] + ] + event_count = len(events) + + if event_count == 0: + return [] + + def compare_events( + event1: IntervalEventOnTrack, event2: IntervalEventOnTrack + ) -> int: + """ + Sorting comparator to ensure we process events in the correct order: + - earlier time first + - if same time and same track, close events before open events + (so we don't incorrectly treat a consecutive interval on the same track + as overlapping). + + Index of event1 and event2: + - [0]: time of event + - [1]: opening (True) or closing (False) + - [2]: the interval to which the event belongs + - [3]: track index + """ + if event1[0] < event2[0]: # event1 is earlier + return -1 + elif event2[0] < event1[0]: # event2 is earlier + return 1 + elif event1[3] == event2[3]: # at the same time and on same track, + if ( + event1[2] == event2[2] + ): # same interval (we don't check for "is" because they might be different objects, but still represent the same interval) + return ( + -1 if (event1[1] is True) else 1 + ) # sort open events before close events + else: # different intervals + return ( + -1 if (event1[1] is False) else 1 + ) # sort close events before open events + else: # at the same time, but different tracks => any order is fine + return 1 + + # Sort events chronologically according to compare_events + events.sort(key=cmp_to_key(compare_events)) + + active_intervals: list[GeneralizedInterval] = [None] * track_count + + def finalize_interval( + interval_start_time: int, + current_time: int, + interval_start_state: List[GeneralizedInterval], + ) -> 
None: + """ + Appends a new time slice (interval_start_time -> current_time) to 'result_intervals', + ensuring we don't create duplicate adjacency boundaries if the previous slice ends + exactly where the new one starts. + """ + if len(result_intervals) > 0: + previous_result = result_intervals[-1] + if previous_result[1] == interval_start_time: + # Adjust the previous slice so it doesn't overlap or duplicate + result_intervals[-1] = ( + previous_result[0], + previous_result[1] - 1, + previous_result[2], + ) + + # Now finalize the current slice + result_intervals.append( + (interval_start_time, current_time, interval_start_state) + ) + + def process_events_for_point_in_time( + index: int, point_time: int + ) -> Tuple[int, int, int] | None: + """ + Consumes events that occur at `point_time` (or effectively that boundary), + updating 'active_intervals' for whichever track is opening or closing + intervals at that time. + + Returns: (new_index, new_time, high_time) + - new_index: the index of the first event not processed (because it's after point_time) + - new_time: the time of that next event + - high_time: the highest time covered by these events (may be the same as point_time + or point_time + 1 if we consider inclusive boundaries). + + If we run out of events entirely, returns None. + """ + high_time = point_time + any_open = False + + for i in range(index, event_count): + time, open_, interval, track = events[i] + # Since points in time for intervals are quantized to whole + # seconds and intervals are closed (inclusive) for both start + # and end points, two adjacent intervals like + # [START_TIME1, 10:59:59] [11:00:00, END_TIME2] + # have no gap between them and can be considered a single + # continuous interval [START_TIME1, END_TIME2]. 
+ + point_interval_closing = ( + any_open + and not open_ + and interval.lower == interval.upper == point_time + ) + + if ((point_time == time) and not point_interval_closing) or ( + open_ and (point_time == time - 1) + ): + if time > high_time: + high_time = time + any_open |= open_ + else: + # As soon as we find an event that’s clearly beyond the cluster at point_time, + # we break and return + return ( + i, + time, + high_time if any_open else high_time + 1, + ) + + # Opening => set this track’s active interval to the new interval + # Closing => set it to None + active_intervals[track] = interval if open_ else None + + # If we exit the loop fully, we used all events + return None + + # Step through event "clusters" with a common point in time and + # emit result intervals with unchanged interval "payload". + index: int | None = 0 + time: int | None = events[index][0] # type: ignore[index] + interval_start_time: int = cast(int, time) + result_intervals: list[tuple[int, int, List[GeneralizedInterval]]] = [] + + if time is None: + # No events at all + return [] + + # process the event at index 0 at the first timepoint + res = process_events_for_point_in_time(cast(int, index), cast(int, time)) + + if res is None: + return [] + + index, time, high_time = res + + interval_start_state = active_intervals.copy() + + # The main loop: step through event clusters + while True: + res = process_events_for_point_in_time(index, time) + if res is None: + # No more events => finalize the last slice and break + finalize_interval(interval_start_time, time, interval_start_state) + break + + new_index, new_time, high_time = res + + # Diagram for this program point: + # |___potential_result_interval___|| | + # index new_index + # interval_start_time time new_time + # interval_start_state maybe_end_state + # high_time + + # We have a region from [interval_start_time, time) or [interval_start_time, time] + # with 'interval_start_state' as the active intervals. 
+ # Decide if we finalize that region or if we can merge with the next region. + if not is_same_result(interval_start_state, active_intervals): + # If the active intervals changed, finalize the old slice + finalize_interval(interval_start_time, time, interval_start_state) + + # Update interval start info. + interval_start_time = high_time + interval_start_state = active_intervals.copy() + + index, time = new_index, new_time + result = [] - def add_segment(start, end, type_counts): - # We consider the period between 23:59:59 of a day and - # 00:00:00 of the following day to be empty. - if not (start == end - 1 and datetime.datetime.fromtimestamp(end).time() == datetime.time(0, 0, 0)): - # Assume implicit negative intervals: increase the count - # for the negative type as needed so that the overall - # count is equal to track_count. - missing = track_count - sum(type_counts.values()) - if missing > 0: - type_counts[IntervalType.NEGATIVE] = type_counts.get(IntervalType.NEGATIVE, 0) + missing - result.append(IntervalWithTypeCounts(start, end, type_counts)) - - # Step through events and emit result intervals whenever the - # counts change. 
- counts = dict() - previous_time = events[0][0] - for (time, open_, interval_type) in events: - if previous_time is None: - previous_time = time - elif not previous_time == time: - add_segment(previous_time, time, copy.copy(counts)) - previous_time = time - - old_count = counts.get(interval_type, 0) - counts[interval_type] = old_count + (1 if open_ else -1) + + # Finally, convert the (start, end, intervals) slices into actual Interval objects + for start, end, intervals in result_intervals: + interval = interval_constructor(start, end, intervals) + + if interval is not None: + result.append(interval) return result diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index daf28fa1..78d541c4 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -1,8 +1,11 @@ import base64 +import copy import datetime import json import logging +from collections import Counter from enum import Enum, auto +from typing import Callable, List, Type, cast from sqlalchemy.exc import DBAPIError, IntegrityError, ProgrammingError, SQLAlchemyError @@ -10,15 +13,74 @@ from execution_engine.constants import CohortCategory from execution_engine.omop.criterion.abstract import Criterion from execution_engine.omop.db.celida.tables import ResultInterval -from execution_engine.omop.sqlclient import OMOPSQLClient +from execution_engine.omop.sqlclient import OMOPSQLClient, datetime_cols_to_epoch from execution_engine.settings import get_config -from execution_engine.task.process import Interval, get_processing_module +from execution_engine.task.process import ( + AnyInterval, + GeneralizedInterval, + Interval, + IntervalWithCount, + get_processing_module, + interval_like, + timerange_to_interval, +) from execution_engine.util.enum import TimeIntervalType from execution_engine.util.interval import IntervalType -from execution_engine.util.types import PersonIntervals, TimeRange +from execution_engine.util.types import PersonIntervals +from 
execution_engine.util.types.timerange import TimeRange process = get_processing_module() +COUNT_TYPES = ( + logic.MinCount, + logic.ExactCount, + logic.CappedMinCount, +) + + +def default_interval_union_with_count( + start: int, end: int, intervals: List[GeneralizedInterval] +) -> IntervalWithCount: + """ + Default interval counting function to be used in logic.Or + """ + result_type = None + result_count = 0 + for interval in intervals: + if interval is None: + interval_type, interval_count = IntervalType.NEGATIVE, 0 + else: + interval_type, interval_count = interval.type, cast(int, interval.count) + if ( + ( + interval_type is IntervalType.POSITIVE + and result_type is not IntervalType.POSITIVE + ) + or ( + interval_type is IntervalType.NO_DATA + and result_type is not IntervalType.POSITIVE + and result_type is not IntervalType.NO_DATA + ) + or ( + interval_type is IntervalType.NEGATIVE + and (result_type is IntervalType.NOT_APPLICABLE or result_type is None) + ) + or (interval_type is IntervalType.NOT_APPLICABLE and result_type is None) + ): + result_type = interval_type + result_count = 0 + result_count += interval_count + return IntervalWithCount(start, end, result_type, result_count) + + +def default_interval_intersect_with_count( + start: int, end: int, intervals: List[GeneralizedInterval] +) -> IntervalWithCount: + """ + Default interval counting function to be used in logic.And + """ + raise NotImplementedError() + def get_engine() -> OMOPSQLClient: """ @@ -93,19 +155,16 @@ def find_base_task(task: Task) -> Task: return find_base_task(self) - def select_predecessor_result( - self, expr: logic.BaseExpr, data: list[PersonIntervals] - ) -> PersonIntervals: + def get_predecessor_data_index(self, expr: logic.BaseExpr) -> int: """ - Select the result results of the predecessor task from the given expression. + Get the index of the predecessor data from the given expression. This is required in expressions where order is important, e.g. 
in BinaryNonCommutativeOperator. As the nx.DiGraph (and by inheritance, ExecutionGraph) does not store the order of the predecessors, we need to find the predecessor task by its expression and select the result from the data. :param expr: The expression of the predecessor task. - :param data: The input data. - :return: The result of the predecessor task. + :return: The index of expr in the data of predecessor results """ if len(self.dependencies) == 0: raise ValueError("Task has no dependencies.") @@ -117,7 +176,42 @@ def select_predecessor_result( f"Task with expression '{str(expr)}' not found in dependencies." ) - return data[idx] + return idx + + def select_predecessor_result( + self, expr: logic.BaseExpr, data: list[PersonIntervals] + ) -> PersonIntervals: + """ + Select the result of the predecessor task from the given expression. + + This is required in expressions where order is important, e.g. in BinaryNonCommutativeOperator. + As the nx.DiGraph (and by inheritance, ExecutionGraph) does not store the order of the predecessors, + we need to find the predecessor task by its expression and select the result from the data. + + :param expr: The expression of the predecessor task. + :param data: The input data. + :return: The result of the predecessor task. + """ + return data[self.get_predecessor_data_index(expr)] + + def receives_only_count_inputs(self) -> bool: + """ + Indicates whether this task only receives inputs from expressions that perform counting and thus return + IntervalWithCount. 
+ """ + # all arguments are count types + if all(isinstance(parent, COUNT_TYPES) for parent in self.expr.args): + return True + + # all arguments are logic.BinaryNonCommutativeOperator, and all of their "right" children are count types + if all( + isinstance(parent, logic.BinaryNonCommutativeOperator) + and isinstance(parent.right, COUNT_TYPES) + for parent in self.expr.args + ): + return True + + return False def run( self, @@ -226,8 +320,10 @@ def handle_criterion( :param observation_window: The observation window. :return: A DataFrame with the result of the query. """ + engine = get_engine() query = criterion.create_query() + query = datetime_cols_to_epoch(query) engine.log_query(query, params=bind_params) logging.debug(f"Running query - '{criterion.description()}'") @@ -267,6 +363,7 @@ def handle_unary_logical_operator( :return: A DataFrame with the inverted intervals. """ assert self.expr.is_Not, "Dependency is not a Not expression." + assert len(data) == 1, "Unary operators require only one input" result = process.invert_intervals( data[0], @@ -294,42 +391,114 @@ def handle_binary_logical_operator( return data[0] if isinstance(self.expr, (logic.And, logic.NonSimplifiableAnd)): - result = process.intersect_intervals(data) + + if self.receives_only_count_inputs() and hasattr( + self.expr, "count_intervals" + ): + # we check if there are custom data preparatory and interval counting functions and use these + prepare_func = getattr(self.expr, "prepare_data", None) + if prepare_func: + data = prepare_func(self, data) + func = getattr( + self.expr, "count_intervals", default_interval_intersect_with_count + ) + result = process.find_rectangles(data, func) + + else: + result = process.intersect_intervals(data) + elif isinstance(self.expr, (logic.Or, logic.NonSimplifiableOr)): - result = process.union_intervals(data) + if self.receives_only_count_inputs() and hasattr( + self.expr, "count_intervals" + ): + # we check if there are custom data preparatory and interval 
counting functions and use these + prepare_func = getattr(self.expr, "prepare_data", None) + if prepare_func: + data = prepare_func(self, data) + func = getattr( + self.expr, "count_intervals", default_interval_union_with_count + ) + result = process.find_rectangles(data, func) + else: + result = process.union_intervals(data) + elif isinstance(self.expr, logic.Count): - result = process.count_intervals(data) - result = process.filter_count_intervals( - result, - min_count=self.expr.count_min, - max_count=self.expr.count_max, - ) - elif isinstance(self.expr, logic.CappedCount): - intervals_with_count = process.find_rectangles_with_count(data) - result = dict() - for key, intervals in intervals_with_count.items(): + count_min = self.expr.count_min + count_max = self.expr.count_max + if count_min is None and count_max is None: + raise ValueError("count_min and count_max cannot both be None") + if count_min is None: + count_min = 0 + + def interval_counts( + start: int, end: int, intervals: List[GeneralizedInterval] + ) -> GeneralizedInterval: + + # Count the different interval types. None represents + # implicit negative intervals and is counted as such. + counts = Counter( + (interval.type if interval else IntervalType.NEGATIVE) + for interval in intervals + ) + + # Either the count constraints or the interval type + # with the highest "union priority" determines the + # result. 
+ positive_count = counts[IntervalType.POSITIVE] + if positive_count > 0 or count_min == 0: + if count_min == 0: + if positive_count <= count_max: # type: ignore[operator] + return Interval(start, end, IntervalType.POSITIVE) + else: + return None # Implicit negative interval + else: + min_good = count_min <= positive_count + max_good = (count_max is None) or (positive_count <= count_max) + interval_type = ( + IntervalType.POSITIVE + if (min_good and max_good) + else IntervalType.NEGATIVE + ) + ratio = positive_count / count_min + return IntervalWithCount(start, end, interval_type, ratio) + if counts[IntervalType.NO_DATA] > 0: + return IntervalWithCount(start, end, IntervalType.NO_DATA, 0) + if counts[IntervalType.NOT_APPLICABLE] > 0: + return IntervalWithCount(start, end, IntervalType.NOT_APPLICABLE, 0) + if counts[IntervalType.NEGATIVE] > 0: + return IntervalWithCount(start, end, IntervalType.NEGATIVE, 0) - key_result = [] + raise ValueError("No intervals of any kind found") + result = process.find_rectangles(data, interval_counts) + + elif isinstance(self.expr, logic.CappedCount): + + def interval_counts( + start: int, end: int, intervals: List[AnyInterval] + ) -> GeneralizedInterval: + positive_count = 0 + not_applicable_count = 0 for interval in intervals: - counts = interval.counts - not_applicable_count = counts.get(IntervalType.NOT_APPLICABLE, 0) + if interval is None or interval.type == IntervalType.NEGATIVE: + pass + elif interval.type == IntervalType.POSITIVE: + positive_count += 1 + elif interval.type == IntervalType.NOT_APPLICABLE: + not_applicable_count += 1 + # we require at least one positive interval to be present in any case (hence the max(1, ...)) + effective_count_min = min( + self.expr.count_min, max(1, len(intervals) - not_applicable_count) # type: ignore[attr-defined] + ) + if positive_count >= effective_count_min: + effective_type = IntervalType.POSITIVE + else: + effective_type = IntervalType.NEGATIVE + ratio = positive_count / 
effective_count_min + return IntervalWithCount(start, end, effective_type, ratio) + + result = process.find_rectangles(data, interval_counts) - # we require at least one positive interval to be present in any case (hence the max(1, ...)) - effective_count_min = max( - 1, self.expr.count_min - not_applicable_count - ) - positive_count = counts.get(IntervalType.POSITIVE, 0) - effective_type = ( - IntervalType.POSITIVE - if positive_count >= effective_count_min - else IntervalType.NEGATIVE - ) - key_result.append( - Interval(interval.lower, interval.upper, effective_type) - ) - result[key] = key_result - return result elif isinstance(self.expr, logic.AllOrNone): raise NotImplementedError("AllOrNone is not implemented yet.") else: @@ -399,62 +568,103 @@ def handle_left_dependent_toggle( self.expr, (logic.LeftDependentToggle, logic.ConditionalFilter) ), "Dependency is not a LeftDependentToggle or ConditionalFilter expression." - # data[0] is the left dependency (i.e. P) - # data[1] is the right dependency (i.e. I) - - data_p = process.select_type(left, IntervalType.POSITIVE) - - if isinstance(self.expr, logic.LeftDependentToggle): - interval_type = IntervalType.NOT_APPLICABLE - elif isinstance(self.expr, logic.ConditionalFilter): - interval_type = IntervalType.NEGATIVE - - result_not_p = process.complementary_intervals( - data_p, - reference=base_data, - observation_window=observation_window, - interval_type=interval_type, + # window_intervals extends the result to the correct temporal + # range; Its type is not important. 
+ # use a tuple for windows to make sure it is immutable (and can be shared by all persons) + windows = ( + timerange_to_interval(observation_window, type_=IntervalType.POSITIVE), ) + window_intervals = {key: windows for key in left.keys()} - result_p_and_i = process.intersect_intervals([data_p, right]) - - result = process.concat_intervals([result_not_p, result_p_and_i]) + if isinstance(self.expr, logic.LeftDependentToggle): + fill_type = IntervalType.NOT_APPLICABLE + else: + assert isinstance(self.expr, logic.ConditionalFilter) + fill_type = IntervalType.NEGATIVE - # fill remaining time with NEGATIVE - result_no_data = process.complementary_intervals( - result, - reference=base_data, - observation_window=observation_window, - interval_type=IntervalType.NEGATIVE, + interval_type: ( + Callable[[int, int, IntervalType], IntervalWithCount] | Type[Interval] ) - result = process.concat_intervals([result, result_no_data]) - - return result + # if all incoming data of "right" are count types, we will create a count type as well, to + # allow summing in the next layer + if isinstance(self.expr.right, logic.Expr) and all( + isinstance(parent, COUNT_TYPES) for parent in self.expr.right.args + ): + interval_type = lambda start, end, fill_type: IntervalWithCount( + start, end, fill_type, None + ) + else: + interval_type = Interval + + def new_interval( + start: int, end: int, intervals: List[GeneralizedInterval] + ) -> GeneralizedInterval: + left_interval, right_interval, observation_window_ = intervals + if (left_interval is None) or left_interval.type != IntervalType.POSITIVE: + # no left_interval or not positive -> use fill type + return interval_type(start, end, fill_type) + elif right_interval is not None: + return interval_like(right_interval, start, end) + else: # left_interval but not right_interval -> implicit negative + return None + + return process.find_rectangles([left, right, window_intervals], new_interval) def handle_temporal_operator( self, data: 
list[PersonIntervals], observation_window: TimeRange ) -> PersonIntervals: """ - Handles a TemporalCount operator. + Handles a TemporalCount operator, which checks whether a certain condition + (e.g., an event or observation) occurs a specific number of times in each defined test interval. + + Note: Currently, only TemporalMinCount(*, threshold=1) is supported. + + This method can do one of two main things: + 1) If `self.expr.interval_criterion` is set, it will intersect ("find overlapping") + a set of personal indicator windows with your data, returning intervals that + represent when the criterion is met within those windows. + 2) If `interval_criterion` is not set, it will create time slices (intervals) for + the specified `TimeIntervalType` (e.g., morning shift) within `observation_window` + and determine whether the data meets the condition inside each slice. + + For instance, if your expression says 'TemporalCount(ANY_TIME, count_min=1)', then + the entire observation_window is treated as one big interval, and we only need to + check if we see at least one positive data interval in that timeframe. + + :param data: A list of dictionaries mapping person ID -> list of intervals. + Typically, the intervals from your workflow or dataset. + :param observation_window: The overall time range we’re interested in. + :return: A dictionary mapping each person to a list of resulting intervals + (e.g., each labeled with POSITIVE, NEGATIVE, NOT_APPLICABLE, etc.). + """ + assert isinstance(self.expr, logic.TemporalCount) - May be used to aggregate multiple criteria in a temporal manner, e.g. to count the number of times a certain - condition is met within a certain time frame (e.g. morning shift). 
+ if self.expr.interval_criterion is not None: + # If we have an interval_criterion, we expect exactly two data streams: + # (1) The "main" data + # (2) The "indicator" data (e.g., personal windows or ICU periods) + assert ( + len(data) == 2 + ), f"TemporalCount with indicator criterion requires exactly two input streams, got {len(data)}" - :param data: The input data. - :param observation_window: The observation window. - :return: A DataFrame with the merged intervals. - """ + indicator_personal_windows = data.pop( + self.get_predecessor_data_index(self.expr.interval_criterion) + ) + + assert ( + len(data) == 1 + ), f"TemporalCount requires exactly one input streams, got {len(data)}" - data_p = data[0] - # data_p = process.select_type(data[0], IntervalType.POSITIVE) - # data_p = {key: val for key, val in data_p.items() if val} + data_arg = data[0] def get_start_end_from_interval_type( type_: TimeIntervalType, ) -> tuple[datetime.time, datetime.time]: """ - Returns the start and end time for a given TimeIntervalType, read from the configuration. + Reads the config for the given TimeIntervalType, returning its start and end times. + + For example, 'MORNING_SHIFT' might map to 06:00 - 14:00, if configured that way. """ try: cnf = getattr(get_config().time_intervals, type_.value) @@ -464,28 +674,32 @@ def get_start_end_from_interval_type( assert isinstance(self.expr, logic.TemporalCount), "Invalid expression type" - if self.expr.interval_criterion is not None: + if self.expr.count_min != 1 or self.expr.count_max is not None: + raise NotImplementedError( + "Currently, only TemporalMinCount(*, threshold=1) is supported." 
+ ) - # last element is the indicator windows - assert ( - len(data) >= 2 - ), "TemporalCount with indicator criterion requires at least two inputs" - data, indicator_personal_windows = data[:-1], data[-1] + if self.expr.interval_criterion is not None: + # Filter out only the POSITIVE intervals from the data + data_positive = process.select_type(data_arg, IntervalType.POSITIVE) + # Overlap the personal windows with the data intervals to see if the condition + # is met for each relevant chunk in the personal window. result = process.find_overlapping_personal_windows( - indicator_personal_windows, data_p + indicator_personal_windows, data_positive ) else: - + # If we have no `interval_criterion`, we must create the intervals ourselves + # from the observation_window (e.g. ANY_TIME vs MORNING_SHIFT). if self.expr.interval_type == TimeIntervalType.ANY_TIME: - indicator_windows = [ - Interval( - lower=observation_window.start.timestamp(), - upper=observation_window.end.timestamp(), - type=IntervalType.POSITIVE, - ) - ] + # Just one interval covering the entire observation window + indicator_windows = ( + timerange_to_interval( + observation_window, type_=IntervalType.POSITIVE + ), + ) else: + # If we do have a known interval type, or explicit start/end times, build them: if self.expr.interval_type is not None: start_time, end_time = get_start_end_from_interval_type( self.expr.interval_type @@ -497,6 +711,8 @@ def get_start_end_from_interval_type( else: raise ValueError("Invalid time interval settings") + # Create repeated intervals for each day from observation_window.start + # up to observation_window.end, using e.g. "06:00 - 14:00" if it's morning, etc. 
indicator_windows = process.create_time_intervals( start_datetime=observation_window.start, end_datetime=observation_window.end, @@ -506,7 +722,125 @@ def get_start_end_from_interval_type( timezone=get_config().timezone, ) - result = process.find_overlapping_windows(indicator_windows, data_p) + # We'll track each "window interval" by its object ID, storing the best + # result type found so far (NEGATIVE, POSITIVE, or NOT_APPLICABLE). + window_types: dict[int, IntervalType] = dict() + + def update_window_type( + window_interval: GeneralizedInterval, data_interval: GeneralizedInterval + ) -> IntervalType: + """ + Called whenever we cross a boundary in time. + + window_interval: The current "indicator" interval (e.g., a morning slice). + data_interval: The data interval from data_arg that overlaps with this moment, + could be POSITIVE, NEGATIVE, NO_DATA, NOT_APPLICABLE (or None, + which is interpreted as NEGATIVE). + + This function decides how to update the 'window interval type' based on new info. + """ + current_type = window_types.get( + id(window_interval), IntervalType.NOT_APPLICABLE + ) + + # If the data interval is NEGATIVE (or equivalently, None) or NO_DATA, we treat it as NEGATIVE + if ( + data_interval is None + or data_interval.type is IntervalType.NO_DATA + or data_interval.type is IntervalType.NEGATIVE + ): + # Set current_type to NEGATIVE if it is not already POSITIVE (because a POSITIVE data interval + # was passed earlier) + if current_type is IntervalType.NOT_APPLICABLE: + current_type = IntervalType.NEGATIVE + + elif data_interval.type is IntervalType.POSITIVE: + # If the data is POSITIVE, set the window to POSITIVE, overriding any negative state. + current_type = IntervalType.POSITIVE + + window_types[id(window_interval)] = current_type + + return current_type + + def is_same_interval( + left_intervals: List[GeneralizedInterval], + right_intervals: List[GeneralizedInterval], + ) -> bool: + """ + Helper used by ` process.find_rectangles()`. 
+ + The framework calls this to decide if two adjacent intervals share + the same "payload" and can be merged into a single interval. + + 'left_intervals' and 'right_intervals' are each a pair [window_interval, data_interval], + for the previous block vs. the new block. We update the 'window_types' dict with + whatever is found in the new block, and return True if we can keep merging them. + """ + left_window_interval, left_data_interval = left_intervals + right_window_interval, right_data_interval = right_intervals + + # If the next block's window interval is None, it means we're off-limits + # or there's no overlap. We update the left block's type and return False. + if right_window_interval is None: + if left_window_interval is None: + return True + else: + update_window_type(left_window_interval, left_data_interval) + return False + else: + # We do have a right_window_interval, so update it with the right_data_interval + update_window_type(right_window_interval, right_data_interval) + + # If the left window is None, can't be the same interval + if left_window_interval is None: + return False + else: + # If left_window_interval and right_window_interval are literally + # the same object, we can treat them as the same interval + if left_window_interval is right_window_interval: + return True + else: + # They’re different intervals => finalize left and move on + update_window_type(left_window_interval, left_data_interval) + return False + + def result_interval( + start: int, end: int, intervals: List[AnyInterval] + ) -> AnyInterval: + """ + Called at the end of building each interval, to produce the final interval + object that will go into the result. + + 'intervals' is [window_interval, data_interval], where either can be None. + + We look up window_interval in window_types to see whether we've determined + it is POSITIVE, NEGATIVE, or NOT_APPLICABLE. 
+ """ + window_interval, data_interval = intervals + if ( + window_interval is None + or window_interval.type is IntervalType.NOT_APPLICABLE + ): + return Interval(start, end, IntervalType.NOT_APPLICABLE) + else: + window_type = window_types.get(id(window_interval), None) + if window_type is None: + window_type = update_window_type(window_interval, data_interval) + return Interval(start, end, window_type) + + # Make separate copies of the intervals for each person so + # that the object identity of each interval is unique and + # can be used as a dictionary key. + person_indicator_windows = { + key: [copy.copy(window) for window in indicator_windows] + for key in data_arg.keys() + } + + result = process.find_rectangles( + [person_indicator_windows, data_arg], + result_interval, + is_same_result=is_same_interval, + ) return result @@ -527,17 +861,30 @@ def insert_negative_intervals( :param observation_window: The observation window. :return: A DataFrame with the merged intervals. """ - - data_negative = process.complementary_intervals( - data, - reference=base_data, - observation_window=observation_window, - interval_type=IntervalType.NEGATIVE, + # window_intervals extends the result to the correct temporal + # range and forces results to be computed for patients that + # are not represented in data; The interval types in + # window_intervals are not important. 
+ # use a tuple for windows to make sure it is immutable (and can be shared by all persons) + windows = ( + timerange_to_interval(observation_window, type_=IntervalType.POSITIVE), ) + all_keys = data.keys() | base_data.keys() + window_intervals = {key: windows for key in all_keys} + + def create_interval( + start: int, end: int, intervals: List[GeneralizedInterval] + ) -> GeneralizedInterval: + interval, window_interval = intervals + if interval is not None: + return interval_like(interval, start, end) + else: + # Explicit representation of negative intervals is + # required here because the database views do not + # understand the implicit representation. + return Interval(start, end, IntervalType.NEGATIVE) - result = process.concat_intervals([data, data_negative]) - - return result + return process.find_rectangles([data, window_intervals], create_interval) def store_result_in_db( self, @@ -579,6 +926,19 @@ def store_result_in_db( cohort_category=self.category, ) + def interval_data(interval: AnyInterval) -> dict: + data = dict( + interval_start=interval.lower, + interval_end=interval.upper, + interval_type=interval.type, + ) + if isinstance(interval, Interval): + data["interval_ratio"] = None + else: + assert isinstance(interval, IntervalWithCount) + data["interval_ratio"] = interval.count + return data + try: with get_engine().begin() as conn: conn.execute( @@ -586,9 +946,7 @@ def store_result_in_db( [ { "person_id": person_id, - "interval_start": normalized_interval.lower, - "interval_end": normalized_interval.upper, - "interval_type": normalized_interval.type, + **interval_data(normalized_interval), **params, } for person_id, intervals in result.items() diff --git a/execution_engine/util/interval/typed_interval.py b/execution_engine/util/interval/typed_interval.py index aa35696f..d9d89cb0 100644 --- a/execution_engine/util/interval/typed_interval.py +++ b/execution_engine/util/interval/typed_interval.py @@ -82,9 +82,14 @@ class IntervalType(StrEnum): 
criterion/combination of criteria is/are not satisfied. """ - __union_priority_order: list[str] = [POSITIVE, NO_DATA, NOT_APPLICABLE, NEGATIVE] + # UNION PRIORITY ORDER + # from Mar 24, 2025: changed order of NEGATIVE and NOT_APPLICABLE, because when OR combining two population/inter- + # vention expressions, and one has NOT_APPLICABLE whereas the other has NEGATIVE, the overall result should be + # NEGATIVE, not NOT_APPLICABLE. + __union_priority_order: list[str] = [POSITIVE, NO_DATA, NEGATIVE, NOT_APPLICABLE] """Union priority order starting with the highest priority.""" + # INTERSECTION PRIORITY ORDER # until Jan 13, 2024: # POSITIVE has higher priority than NO_DATA, as in measurements we return NO_DATA intervals for all intervals # inbetween measurements (and outside), and these are &-ed with the POSITIVE intervals for e.g. conditions. diff --git a/execution_engine/util/logic.py b/execution_engine/util/logic.py index 0d11f587..84c82696 100644 --- a/execution_engine/util/logic.py +++ b/execution_engine/util/logic.py @@ -1011,6 +1011,14 @@ class LeftDependentToggle(BinaryNonCommutativeOperator): """ A LeftDependentToggle object represents a logical AND operation if the left operand is positive, otherwise it returns NOT_APPLICABLE. + + | left | right | Result | + |----------|----------|----------| + | NEGATIVE | * | NOT_APPLICABLE | + | NO_DATA | * | NOT_APPLICABLE | + | POSITIVE | POSITIVE | POSITIVE | + | POSITIVE | NEGATIVE | NEGATIVE | + | POSITIVE | NO_DATA | NO_DATA | """ @@ -1019,7 +1027,6 @@ class ConditionalFilter(BinaryNonCommutativeOperator): A ConditionalFilter object returns the right operand if the left operand is POSITIVE, and NEGATIVE otherwise - A conditional filter returns `right` iff `left` is POSITIVE, otherwise NEGATIVE. 
| left | right | Result | diff --git a/execution_engine/util/types.py b/execution_engine/util/types/__init__.py similarity index 68% rename from execution_engine/util/types.py rename to execution_engine/util/types/__init__.py index 7ff9982a..5fffe8b2 100644 --- a/execution_engine/util/types.py +++ b/execution_engine/util/types/__init__.py @@ -1,96 +1,14 @@ -from datetime import date, datetime, timedelta from typing import Any -import pendulum -import pytz from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from execution_engine.task.process import AnyInterval from execution_engine.util import serializable from execution_engine.util.enum import TimeUnit -from execution_engine.util.interval import ( - DateTimeInterval, - IntervalType, - interval_datetime, -) from execution_engine.util.value import ValueNumber, ValueNumeric from execution_engine.util.value.time import ValueCount, ValueDuration, ValuePeriod -PersonIntervals = dict[int, Any] - - -@serializable.register_class -class TimeRange(BaseModel): - """ - A time range. - """ - - start: datetime - end: datetime - name: str | None = None - - @field_validator("start", "end") - def check_timezone(cls, v: datetime) -> datetime: - """ - Check that the start, end parameters are timezone-aware. - """ - if not v.tzinfo: - raise ValueError("Datetime object must be timezone-aware") - - # workaround to fix pd.testing.assert_frame_equal errors when tzinfo is of type - # pydantic_core._pydantic_core.TzInfo, which somehow triggers an error when comparing a dataframe with a that - # tz in a datetime column vs a pytz.UTC tz column. - if v.tzinfo == pytz.UTC: - return v.astimezone(pytz.utc) - - return v - - @classmethod - def from_tuple( - cls, dt: tuple[datetime | str, datetime | str], name: str | None = None - ) -> "TimeRange": - """ - Create a time range from a tuple of datetimes. 
- """ - return cls(start=dt[0], end=dt[1], name=name) - - @property - def period(self) -> pendulum.Interval: - """ - Get the period of the time range. - """ - return pendulum.interval(start=self.start.date(), end=self.end.date()) - - def date_range(self) -> set[date]: - """ - Get the date range of the time range. - """ - return set(self.period.range("days")) - - @property - def duration(self) -> timedelta: - """ - Get the duration of the time range. - """ - return self.end - self.start - - def interval(self, type_: IntervalType) -> DateTimeInterval: - """ - Get the interval of the time range. - - :param type_: The type of interval to get. - :return: The interval. - """ - return interval_datetime(self.start, self.end, type_=type_) - - def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, datetime]: - """ - Get the dictionary representation of the time range. - """ - prefix = self.name + "_" if self.name else "" - return { - prefix + "start_datetime": self.start, - prefix + "end_datetime": self.end, - } +PersonIntervals = dict[int, AnyInterval] @serializable.register_class diff --git a/execution_engine/util/types/timerange.py b/execution_engine/util/types/timerange.py new file mode 100644 index 00000000..b1ed30bd --- /dev/null +++ b/execution_engine/util/types/timerange.py @@ -0,0 +1,88 @@ +from datetime import date, datetime, timedelta +from typing import Any + +import pendulum +import pytz +from pydantic import BaseModel, field_validator + +from execution_engine.util import serializable +from execution_engine.util.interval import ( + DateTimeInterval, + IntervalType, + interval_datetime, +) + + +@serializable.register_class +class TimeRange(BaseModel): + """ + A time range. + """ + + start: datetime + end: datetime + name: str | None = None + + @field_validator("start", "end") + def check_timezone(cls, v: datetime) -> datetime: + """ + Check that the start, end parameters are timezone-aware. 
+ """ + if not v.tzinfo: + raise ValueError("Datetime object must be timezone-aware") + + # workaround to fix pd.testing.assert_frame_equal errors when tzinfo is of type + # pydantic_core._pydantic_core.TzInfo, which somehow triggers an error when comparing a dataframe with a that + # tz in a datetime column vs a pytz.UTC tz column. + if v.tzinfo == pytz.UTC: + return v.astimezone(pytz.utc) + + return v + + @classmethod + def from_tuple( + cls, dt: tuple[datetime | str, datetime | str], name: str | None = None + ) -> "TimeRange": + """ + Create a time range from a tuple of datetimes. + """ + return cls(start=dt[0], end=dt[1], name=name) + + @property + def period(self) -> pendulum.Interval: + """ + Get the period of the time range. + """ + return pendulum.interval(start=self.start.date(), end=self.end.date()) + + def date_range(self) -> set[date]: + """ + Get the date range of the time range. + """ + return set(self.period.range("days")) + + @property + def duration(self) -> timedelta: + """ + Get the duration of the time range. + """ + return self.end - self.start + + def interval(self, type_: IntervalType) -> DateTimeInterval: + """ + Get the interval of the time range. + + :param type_: The type of interval to get. + :return: The interval. + """ + return interval_datetime(self.start, self.end, type_=type_) + + def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, datetime]: + """ + Get the dictionary representation of the time range. + """ + prefix = self.name + "_" if self.name else "" + return { + prefix + "start_datetime": self.start, + prefix + "end_datetime": self.end, + } diff --git a/execution_engine/util/value/value.py b/execution_engine/util/value/value.py index fcda3dc6..5b6c8fb6 100644 --- a/execution_engine/util/value/value.py +++ b/execution_engine/util/value/value.py @@ -362,4 +362,4 @@ def __str__(self) -> str: """ Get the string representation of the value. 
""" - return f"Value == {repr(self.value)}" + return f"Value == {str(self.value)}" diff --git a/setup.py b/setup.py index 67dac9a2..1d65c572 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,6 @@ setup( name="Interval Extension", - ext_modules=cythonize(ext_modules), + ext_modules=cythonize(ext_modules, compiler_directives={"language_level": "3"}), include_dirs=[numpy.get_include()], ) diff --git a/tests/_fixtures/concept.py b/tests/_fixtures/concept.py index fd14e338..b5ef855d 100644 --- a/tests/_fixtures/concept.py +++ b/tests/_fixtures/concept.py @@ -51,14 +51,14 @@ def unit_concept(): invalid_reason=None, ) -concept_delir_screening = Concept( # TODO(jmoringe): copied from above (concept_artificial_respiration) - concept_id=4230167, +concept_delir_screening = Concept( + concept_id=4196006, # TODO(jmoringe): made-up id and code concept_name="Delir Screening", - domain_id="Procedure", + domain_id="Measurement", vocabulary_id="SNOMED", concept_class_id="Procedure", standard_concept="S", - concept_code="40617009", + concept_code="431182000", invalid_reason=None, ) @@ -283,6 +283,18 @@ def unit_concept(): invalid_reason=None, ) +concept_body_height: Concept = Concept( + concept_id=3036277, + concept_name="Body height", + domain_id="Measurement", + vocabulary_id="LOINC", + concept_class_id="Clinical Observation", + standard_concept="S", + concept_code="8302-2", + invalid_reason=None, +) + + """ The following list of concepts are heparin drugs and all of them directly map to heparin as ingredient (via ancestor, not relationship !) 
diff --git a/tests/_fixtures/omop_fixture.py b/tests/_fixtures/omop_fixture.py index c682bc55..49707d8d 100644 --- a/tests/_fixtures/omop_fixture.py +++ b/tests/_fixtures/omop_fixture.py @@ -10,7 +10,7 @@ from sqlalchemy import create_engine, event, text from sqlalchemy.orm.session import sessionmaker -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange logging.basicConfig() logger = logging.getLogger() diff --git a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py index 6c83d4b2..0f5d2ec2 100644 --- a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py @@ -11,9 +11,11 @@ from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.criterion.noop import NoopCriterion from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence -from execution_engine.task.process import get_processing_module +from execution_engine.task.process import Interval, get_processing_module from execution_engine.util import logic -from execution_engine.util.types import Dosage, TimeRange +from execution_engine.util.interval import IntervalType +from execution_engine.util.types import Dosage +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueNumber from tests._fixtures.concept import ( concept_artificial_respiration, @@ -48,6 +50,19 @@ def intervals_to_df(result, by=None): return df +@pytest.fixture(params=["cython", "python"], scope="session") +def process_module(request): + module = get_processing_module("rectangle", version=request.param) + assert module._impl.MODULE_IMPLEMENTATION == request.param + return module + + +class ProcessTest: + @pytest.fixture(autouse=True) + def setup_method(self, 
process_module): + self.process = process_module + + class TestExpr: """ Test class for testing Expr @@ -127,7 +142,7 @@ def test_expr_contains_criteria(self, mock_criteria): assert expr.args[i] == mock_criteria[i] -class TestCriterionCombinationDatabase(TestCriterion): +class TestCriterionCombinationDatabase(TestCriterion, ProcessTest): """ Test class for testing criterion combinations on the database. """ @@ -177,6 +192,7 @@ def run_criteria_test( base_criterion, observation_window, persons, + result_mode: str = "full_day", ): c = sympy.parse_expr(combination) @@ -208,6 +224,8 @@ def run_criteria_test( cls = lambda *args: logic.AllOrNone(*args) elif c.func.name == "ConditionalFilter": cls = lambda *args: logic.ConditionalFilter(*args) + elif c.func.name == "LeftDependentToggle": + cls = lambda *args: logic.LeftDependentToggle(*args) else: raise ValueError(f"Unknown operator {c.func}") else: @@ -240,6 +258,11 @@ def run_criteria_test( right=symbols[str(c.args[1])], ) + elif hasattr(c.func, "name") and c.func.name == "LeftDependentToggle": + comb = logic.LeftDependentToggle( + left=symbols[str(c.args[0])], + right=symbols[str(c.args[1])], + ) else: comb = cls( *[symbols[str(symbol)] for symbol in c.atoms() if not symbol.is_number] @@ -261,18 +284,47 @@ def run_criteria_test( observation_window=observation_window, ) - df = self.fetch_full_day_result( - db_session, - pi_pair_id=self.pi_pair_id, - criterion_id=None, - category=CohortCategory.POPULATION, - ) - - for person in persons: - result = df.query(f"person_id=={person.person_id}") - expected_result = date_set(expected[person.person_id]) + match result_mode: + case "full_day": + df = self.fetch_full_day_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION, + ) + case "partial_day": + df = self.fetch_partial_day_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION, + ) + case "exact": + df = 
self.fetch_interval_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION, + ) + + if result_mode in ["full_day", "partial_day"]: + for person in persons: + result = df.query(f"person_id=={person.person_id}") + expected_result = date_set(expected[person.person_id]) + + assert ( + set(pd.to_datetime(result["valid_date"]).dt.date) == expected_result + ) + else: + for person in persons: + result = df.query(f"person_id=={person.person_id}") + result_intervals = [ + Interval(row.interval_start, row.interval_end, row.interval_type) + for _, row in result.iterrows() + ] + expected_result = expected[person.person_id] - assert set(pd.to_datetime(result["valid_date"]).dt.date) == expected_result + assert result_intervals == expected_result class TestCriterionCombinationResultShortObservationWindow( @@ -922,3 +974,428 @@ def test_combination_on_database( observation_window, persons, ) + + +class TestCriterionCombinationLeftDependentToggle(TestCriterionCombinationDatabase): + """ + Test class for testing criterion combinations on the database. 
+ """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + start="2023-03-01 04:00:00Z", end="2023-03-04 18:00:00Z", name="observation" + ) + + @pytest.fixture + def criteria(self, db_session): + c1 = ConditionOccurrence( + concept=concept_covid19, + ) + + c2 = ProcedureOccurrence(concept=concept_artificial_respiration) + + c3 = Measurement( + concept=concept_body_weight, + value=ValueNumber(value_min=70, unit=concept_unit_kg), + ) + + c1.set_id(1) + c2.set_id(2) + c3.set_id(3) + + self.register_criterion(c1, db_session) + self.register_criterion(c2, db_session) + self.register_criterion(c3, db_session) + + return [c1, c2, c3] + + @pytest.fixture + def patient_events(self, db_session, person_visit): + _, visit_occurrence = person_visit[0] + + e1 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e2 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2023-03-02 06:00:00+01:00"), + condition_end_datetime=pendulum.parse("2023-03-03 18:00:00+01:00"), + ) + + db_session.add_all([e1, e2]) + + db_session.commit() + + @pytest.mark.parametrize( + "combination,expected", + [ + ( + "LeftDependentToggle(c1, c2)", + { + 1: [ + Interval( + pendulum.parse("2023-03-01 10:36:24+0100", tz="CET"), + pendulum.parse("2023-03-02 05:59:59+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + Interval( + pendulum.parse("2023-03-02 06:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 11:59:59+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + type=IntervalType.POSITIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:01+0100", tz="CET"), + 
pendulum.parse("2023-03-03 18:00:00+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-03 18:00:01+0100", tz="CET"), + pendulum.parse("2023-03-04 19:00:00+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + ], + 2: [ + Interval( + lower=pendulum.parse("2023-03-02 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + 3: [ + Interval( + lower=pendulum.parse("2023-03-03 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + }, + ), + ], + ) + def test_combination_on_database( + self, + person_visit, + db_session, + base_criterion, + patient_events, + criteria, + combination, + expected, + observation_window, + ): + persons = [pv[0] for pv in person_visit] + self.run_criteria_test( + combination, + expected, + db_session, + criteria, + base_criterion, + observation_window, + persons, + result_mode="exact", + ) + + +class TestCriterionCombinationLeftDependentToggleMultipleSameTime( + TestCriterionCombinationDatabase +): + """ + Test class for testing criterion combinations on the database. 
+ """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + start="2023-03-01 04:00:00Z", end="2023-03-04 18:00:00Z", name="observation" + ) + + @pytest.fixture + def criteria(self, db_session): + c1 = ConditionOccurrence( + concept=concept_covid19, + ) + + c2 = ProcedureOccurrence(concept=concept_artificial_respiration) + + c3 = Measurement( + concept=concept_body_weight, + value=ValueNumber(value_min=70, unit=concept_unit_kg), + ) + + c1.set_id(1) + c2.set_id(2) + c3.set_id(3) + + self.register_criterion(c1, db_session) + self.register_criterion(c2, db_session) + self.register_criterion(c3, db_session) + + return [c1, c2, c3] + + @pytest.fixture + def patient_events(self, db_session, person_visit): + _, visit_occurrence = person_visit[0] + + e1 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e2 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e3 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e4 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2023-03-02 06:00:00+01:00"), + condition_end_datetime=pendulum.parse("2023-03-03 18:00:00+01:00"), + ) + + db_session.add_all([e1, e2, e3, e4]) + + db_session.commit() + + @pytest.mark.parametrize( + "combination,expected", + [ + ( + "LeftDependentToggle(c1, c2)", + { + 1: [ + Interval( + pendulum.parse("2023-03-01 10:36:24+0100", tz="CET"), + 
pendulum.parse("2023-03-02 05:59:59+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + Interval( + pendulum.parse("2023-03-02 06:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 11:59:59+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + type=IntervalType.POSITIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:01+0100", tz="CET"), + pendulum.parse("2023-03-03 18:00:00+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-03 18:00:01+0100", tz="CET"), + pendulum.parse("2023-03-04 19:00:00+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + ], + 2: [ + Interval( + lower=pendulum.parse("2023-03-02 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + 3: [ + Interval( + lower=pendulum.parse("2023-03-03 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + }, + ), + ], + ) + def test_combination_on_database( + self, + person_visit, + db_session, + base_criterion, + patient_events, + criteria, + combination, + expected, + observation_window, + ): + persons = [pv[0] for pv in person_visit] + self.run_criteria_test( + combination, + expected, + db_session, + criteria, + base_criterion, + observation_window, + persons, + result_mode="exact", + ) + + +class TestCriterionCombinationLeftDependentToggleMultipleOverlapping( + TestCriterionCombinationDatabase +): + """ + Test class for testing criterion combinations on the database. 
+ """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + start="2023-03-01 04:00:00Z", end="2023-03-04 18:00:00Z", name="observation" + ) + + @pytest.fixture + def criteria(self, db_session): + c1 = ConditionOccurrence( + concept=concept_covid19, + ) + + c2 = ProcedureOccurrence(concept=concept_artificial_respiration) + + c3 = Measurement( + concept=concept_body_weight, + value=ValueNumber(value_min=70, unit=concept_unit_kg), + ) + + c1.set_id(1) + c2.set_id(2) + c3.set_id(3) + + self.register_criterion(c1, db_session) + self.register_criterion(c2, db_session) + self.register_criterion(c3, db_session) + + return [c1, c2, c3] + + @pytest.fixture + def patient_events(self, db_session, person_visit): + _, visit_occurrence = person_visit[0] + + e1 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + ) + + e2 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:01+01:00"), + ) + + e3 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_artificial_respiration.concept_id, + start_datetime=pendulum.parse("2023-03-02 12:00:00+01:00"), + end_datetime=pendulum.parse("2023-03-02 12:00:02+01:00"), + ) + + e4 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2023-03-02 06:00:00+01:00"), + condition_end_datetime=pendulum.parse("2023-03-03 18:00:00+01:00"), + ) + + db_session.add_all([e1, e2, e3, e4]) + + db_session.commit() + + @pytest.mark.parametrize( + "combination,expected", + [ + ( + "LeftDependentToggle(c1, c2)", + { + 1: [ + Interval( + pendulum.parse("2023-03-01 10:36:24+0100", tz="CET"), + 
pendulum.parse("2023-03-02 05:59:59+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + Interval( + pendulum.parse("2023-03-02 06:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 11:59:59+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:00+0100", tz="CET"), + pendulum.parse("2023-03-02 12:00:02+0100", tz="CET"), + type=IntervalType.POSITIVE, + ), + Interval( + pendulum.parse("2023-03-02 12:00:03+0100", tz="CET"), + pendulum.parse("2023-03-03 18:00:00+0100", tz="CET"), + type=IntervalType.NEGATIVE, + ), + Interval( + pendulum.parse("2023-03-03 18:00:01+0100", tz="CET"), + pendulum.parse("2023-03-04 19:00:00+0100", tz="CET"), + type=IntervalType.NOT_APPLICABLE, + ), + ], + 2: [ + Interval( + lower=pendulum.parse("2023-03-02 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + 3: [ + Interval( + lower=pendulum.parse("2023-03-03 09:36:24+0000", tz="UTC"), + upper=pendulum.parse("2023-03-04 18:00:00+0000", tz="UTC"), + type=IntervalType.NEGATIVE, + ) + ], + }, + ), + ], + ) + def test_combination_on_database( + self, + person_visit, + db_session, + base_criterion, + patient_events, + criteria, + combination, + expected, + observation_window, + ): + persons = [pv[0] for pv in person_visit] + self.run_criteria_test( + combination, + expected, + db_session, + criteria, + base_criterion, + observation_window, + persons, + result_mode="exact", + ) diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 9190b680..6d6878ca 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -1,6 +1,5 @@ import datetime -import pandas as pd import pendulum import pytest import sqlalchemy as sa @@ -18,15 +17,19 
@@ from execution_engine.util import logic, temporal_logic_util from execution_engine.util.enum import TimeIntervalType from execution_engine.util.interval import IntervalType -from execution_engine.util.types import Dosage, TimeRange +from execution_engine.util.types import Dosage +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueNumber from tests._fixtures.concept import ( concept_artificial_respiration, + concept_body_height, concept_body_weight, concept_covid19, concept_delir_screening, concept_heparin_ingredient, concept_surgical_procedure, + concept_tidal_volume, + concept_unit_cm, concept_unit_kg, concept_unit_mg, ) @@ -39,18 +42,20 @@ create_procedure, create_visit, ) -from tests.functions import intervals_to_df as intervals_to_df_orig from tests.mocks.criterion import MockCriterion -process = get_processing_module() +@pytest.fixture(params=["cython", "python"], scope="session") +def process_module(request): + module = get_processing_module("rectangle", version=request.param) + assert module._impl.MODULE_IMPLEMENTATION == request.param + return module -def intervals_to_df(result, by=None): - df = intervals_to_df_orig(result, by, process.normalize_interval) - for col in df.columns: - if isinstance(df[col].dtype, pd.DatetimeTZDtype): - df[col] = df[col].dt.tz_convert("Europe/Berlin") - return df + +class ProcessTest: + @pytest.fixture(autouse=True) + def setup_method(self, process_module): + self.process = process_module class TestFixedWindowTemporalIndicatorCombination: @@ -100,6 +105,7 @@ def test_criterion_combination_from_dict(self, mock_criteria): expr = logic.Expr.from_dict(expr_dict) + assert isinstance(expr, logic.TemporalMinCount) assert len(expr.args) == len(mock_criteria) assert expr.start_time == datetime.time(8, 0) assert expr.end_time == datetime.time(16, 0) @@ -123,6 +129,7 @@ def test_criterion_combination_from_dict(self, mock_criteria): expr = logic.Expr.from_dict(expr_dict) + assert 
isinstance(expr, logic.TemporalMinCount) assert len(expr.args) == len(mock_criteria) assert expr.start_time is None assert expr.end_time is None @@ -132,9 +139,6 @@ def test_criterion_combination_from_dict(self, mock_criteria): for idx, criterion in enumerate(expr.args): assert str(criterion) == str(mock_criteria[idx]) - # @pytest.mark.skip( - # reason="the repr does not return arguments in a consistent manner" - # ) def test_repr(self, mock_criteria): expr = temporal_logic_util.MorningShift(mock_criteria[0]) @@ -200,7 +204,7 @@ def test_expr_contains_criteria(self, mock_criteria): concept=concept_covid19, ) -c3 = ProcedureOccurrence( +artificial_respiration = ProcedureOccurrence( concept=concept_artificial_respiration, ) @@ -208,6 +212,10 @@ def test_expr_contains_criteria(self, mock_criteria): concept=concept_surgical_procedure, ) +delir_screening = ProcedureOccurrence( + concept=concept_delir_screening, +) + bodyweight_measurement_without_forward_fill = Measurement( concept=concept_body_weight, value=ValueNumber.parse("<=110", unit=concept_unit_kg), @@ -221,12 +229,38 @@ def test_expr_contains_criteria(self, mock_criteria): static=False, ) -delir_screening = ProcedureOccurrence( - concept=concept_delir_screening, +body_height_measurement_without_forward_fill = Measurement( + concept=concept_body_height, + value=ValueNumber.parse("<=110", unit=concept_unit_cm), + static=False, + forward_fill=False, +) + +body_height_measurement_with_forward_fill = Measurement( + concept=concept_body_height, + value=ValueNumber.parse("<=110", unit=concept_unit_cm), + static=False, ) +tidal_volume_measurement_without_forward_fill = Measurement( + concept=concept_tidal_volume, + value=ValueNumber.parse( + "<=110", unit=concept_unit_cm + ), # TODO(jmoringe): copied; does not make sense + static=False, + forward_fill=False, +) + +tidal_volume_measurement_with_forward_fill = Measurement( + concept=concept_tidal_volume, + value=ValueNumber.parse( + "<=110", unit=concept_unit_cm + ), # 
TODO(jmoringe): copied; does not make sense + static=False, +) -class TestCriterionCombinationDatabase(TestCriterion): + +class TestCriterionCombinationDatabase(TestCriterion, ProcessTest): """ Test class for testing criterion combinations on the database. """ @@ -242,9 +276,14 @@ def criteria(self, db_session): criteria = [ c1, c2, - c3, + artificial_respiration, + c4, bodyweight_measurement_without_forward_fill, bodyweight_measurement_with_forward_fill, + body_height_measurement_without_forward_fill, + body_height_measurement_with_forward_fill, + tidal_volume_measurement_without_forward_fill, + tidal_volume_measurement_with_forward_fill, delir_screening, ] for i, c in enumerate(criteria): @@ -384,7 +423,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.Day( - c3, + artificial_respiration, ), { 1: { @@ -461,7 +500,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.Presence( - c3, + artificial_respiration, start_time=datetime.time(8, 30), end_time=datetime.time(16, 59), ), @@ -469,7 +508,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.Presence( - c3, + artificial_respiration, start_time=datetime.time(17, 30), end_time=datetime.time(22, 00), ), @@ -523,7 +562,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.MorningShift( - c3, + artificial_respiration, ), {1: set(), 2: set(), 3: set()}, ), @@ -570,7 +609,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.AfternoonShift( - c3, + artificial_respiration, ), { 1: { @@ -622,7 +661,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.NightShift( - c3, + artificial_respiration, ), {1: set(), 2: set(), 3: set()}, ), @@ -665,7 +704,7 @@ def patient_events(self, db_session, person_visit): ), ( temporal_logic_util.NightShiftBeforeMidnight( - c3, + artificial_respiration, ), {1: set(), 2: set(), 3: set()}, ), @@ -704,7 +743,7 @@ def 
patient_events(self, db_session, person_visit): ), ( temporal_logic_util.NightShiftAfterMidnight( - c3, + artificial_respiration, ), {1: set(), 2: set(), 3: set()}, ), @@ -817,7 +856,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.Day( - c3, + artificial_respiration, ), { 1: { @@ -909,7 +948,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.Presence( - c3, + artificial_respiration, start_time=datetime.time(17, 30), end_time=datetime.time(22, 00), ), @@ -1011,7 +1050,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.MorningShift( - c3, + artificial_respiration, ), { 1: { @@ -1115,7 +1154,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.AfternoonShift( - c3, + artificial_respiration, ), { 1: { @@ -1219,7 +1258,7 @@ def patient_events(self, db_session, visit_occurrence): ), ( temporal_logic_util.NightShift( - c3, + artificial_respiration, ), { 1: { @@ -1798,3 +1837,694 @@ def test_at_least_combination_on_database_no_measurements( ) ) assert result_tuples == expected[person.person_id] + + +class TestIntervalRatio(TestCriterionCombinationDatabase): + """This test ensures that counting criteria with minimum count + thresholds adapt to the temporal interval of the population + criterion. + + As a concrete test case, this class applies an intervention + criterion that requires a procedure to be performed in at least + two of three shifts for each day. However, if the hospital stay + ends during a given day, shifts on that day but outside the + hospital stay should not count towards the threshold of the + criterion. 
+ """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + name="observation", start="2025-02-18 13:55:00Z", end="2025-02-23 11:00:00Z" + ) + + def patient_events(self, db_session, visit_occurrence): + c1 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2025-02-19 08:00:00+01:00"), + condition_end_datetime=pendulum.parse("2025-02-23 02:00:00+01:00"), + ) + # One screen on the 19th + e1_night = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-19 23:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-19 23:01:00+01:00"), + ) + # Two screen on the 20th + e2_morn = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-20 08:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-20 08:01:00+01:00"), + ) + e2_late = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-20 15:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-20 15:01:00+01:00"), + ) + # Three screenings on the 21st + e3_morn = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-21 08:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-21 08:01:00+01:00"), + ) + e3_late = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-21 15:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-21 15:01:00+01:00"), + ) + e3_night = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-21 23:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-21 23:01:00+01:00"), + ) + # 
Four screenings on the 22nd
+ ( + IntervalType.NOT_APPLICABLE, + "nan", # workaround, is actually really a nan value + pendulum.parse("2025-02-18 16:55:00Z"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 1 / 3, # one this day, only 3 shifts are possible, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-19 23:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0.5, + pendulum.parse("2025-02-20 00:00:00+01:00"), + pendulum.parse("2025-02-20 23:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0.75, + pendulum.parse("2025-02-21 00:00:00+01:00"), + pendulum.parse("2025-02-21 23:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 1.0, + pendulum.parse("2025-02-22 00:00:00+01:00"), + pendulum.parse("2025-02-22 23:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0, + pendulum.parse("2025-02-23 00:00:00+01:00"), + pendulum.parse("2025-02-23 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", # workaround, is actually really a nan value + pendulum.parse("2025-02-23 02:00:01+01:00"), + pendulum.parse("2025-02-23 04:30:00Z"), + ), + ] + }, + ), + ], + ) + def test_interval_ratio_on_database( + self, + person, + db_session, + population, + intervention, + base_criterion, + expected, + observation_window, + criteria, + ): + persons = [person[0]] # only one person + vos = [ + create_visit( + person_id=person.person_id, + visit_start_datetime=observation_window.start + + datetime.timedelta(hours=3), + visit_end_datetime=observation_window.end + - datetime.timedelta(hours=6.5), + visit_concept_id=concepts.INTENSIVE_CARE, + ) + for person in persons + ] + + self.patient_events(db_session, vos[0]) + + db_session.add_all(vos) + db_session.commit() + + self.insert_expression( + db_session, population, intervention, base_criterion, observation_window + ) + + df = self.fetch_interval_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION_INTERVENTION, + ) + + for person in persons: + 
result = df.query(f"person_id=={person.person_id}") + result_tuples = list( + result[ + [ + "interval_type", + "interval_ratio", + "interval_start", + "interval_end", + ] + ] + .fillna("nan") + .itertuples(index=False, name=None) + ) + + for result_tuple, expected_tuple in zip( + result_tuples, expected[person.person_id] + ): + assert result_tuple == expected_tuple + + +class TestIndicatorWindowsMulitplePatients(TestCriterionCombinationDatabase): + """ + This test ensures that the data TemporalCount operator works + independently between persons within a PersonIntervals data set. + + This is mostly a regression test since at one point the exact + problem of cross-talk between the data structures for different + persons caused the operator to return incorrect results. + """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + name="observation", + start="2025-02-18 14:55:00+01:00", + end="2025-02-22 12:00:00+01:00", + ) + + def patient_events(self, db_session, visit_occurrence): + person_id = visit_occurrence.person_id + events = [] + c1 = create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2025-02-19 08:00:00+01:00"), + condition_end_datetime=pendulum.parse("2025-02-21 02:00:00+01:00"), + ) + events.append(c1) + if person_id == 1: + e1 = create_procedure( + vo=visit_occurrence, + procedure_concept_id=concept_delir_screening.concept_id, + start_datetime=pendulum.parse("2025-02-19 18:00:00+01:00"), + end_datetime=pendulum.parse("2025-02-19 18:01:00+01:00"), + ) + events.append(e1) + db_session.add_all(events) + db_session.commit() + + @pytest.mark.parametrize( + "population,intervention,expected", + [ + ( + logic.And(c2), # population + temporal_logic_util.Day(criterion=delir_screening), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + 
IntervalType.POSITIVE, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-19 23:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + pendulum.parse("2025-02-20 00:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + # If cross-talk between the data structures + # for different persons occurs, parts of the + # following interval may turn positive because + # of the results for the first person. + ( + IntervalType.NEGATIVE, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, + ), + ], + ) + def test_multiple_patients_on_database( + self, + person, + db_session, + population, + intervention, + base_criterion, + expected, + observation_window, + criteria, + ): + persons = person[:2] + vos = [] + for person in persons: + visit = create_visit( + person_id=person.person_id, + visit_start_datetime=observation_window.start + + datetime.timedelta(hours=3), + visit_end_datetime=observation_window.end + - datetime.timedelta(hours=6.5), + visit_concept_id=concepts.INTENSIVE_CARE, + ) + vos.append(visit) + self.patient_events(db_session, visit) + + db_session.add_all(vos) + db_session.commit() + + self.insert_expression( + db_session, population, intervention, base_criterion, observation_window + ) + + df = self.fetch_interval_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION_INTERVENTION, + ) + + for person in persons: + result = df.query(f"person_id=={person.person_id}") + result_tuples = list( + result[["interval_type", 
"interval_start", "interval_end"]] + .fillna("nan") + .itertuples(index=False, name=None) + ) + + for result_tuple, expected_tuple in zip( + result_tuples, expected[person.person_id] + ): + assert result_tuple == expected_tuple + + +class TestCountOnIndicatorWindows(TestCriterionCombinationDatabase): + """ + This test checks the behavior of the logical Count operator for + different thresholds and different kinds of inputs. Of particular + interest is the computed count attribute of the result intervals + and the behavior for edge cases regarding the count thresholds. + """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + name="observation", + start="2025-02-18 14:55:00+01:00", + end="2025-02-22 12:00:00+01:00", + ) + + def patient_events(self, db_session, visit_occurrence): + person_id = visit_occurrence.person_id + events = [ + create_condition( + vo=visit_occurrence, + condition_concept_id=concept_covid19.concept_id, + condition_start_datetime=pendulum.parse("2025-02-19 08:00:00+01:00"), + condition_end_datetime=pendulum.parse("2025-02-21 02:00:00+01:00"), + ) + ] + if person_id == 1: + events.append( + create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_body_weight.concept_id, + measurement_datetime=pendulum.parse("2025-02-19 18:00:00+01:00"), + value_as_number=90, + unit_concept_id=concept_unit_kg.concept_id, + ) + ) + events.append( + create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_body_height.concept_id, + measurement_datetime=pendulum.parse("2025-02-20 07:00:00+01:00"), + value_as_number=90, + unit_concept_id=concept_unit_cm.concept_id, + ) + ) + events.append( + create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_tidal_volume.concept_id, + measurement_datetime=pendulum.parse("2025-02-20 18:00:00+01:00"), + value_as_number=90, + unit_concept_id=concept_unit_cm.concept_id, + ) + ) + db_session.add_all(events) + db_session.commit() + + 
@pytest.mark.parametrize( + "population,intervention,expected", + [ + ( + logic.And(c2), # population + logic.MinCount( + temporal_logic_util.AnyTime( + bodyweight_measurement_without_forward_fill + ), + temporal_logic_util.AnyTime( + body_height_measurement_without_forward_fill + ), + temporal_logic_util.AnyTime( + tidal_volume_measurement_without_forward_fill + ), + threshold=1, + ), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 3, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, + ), + ( + logic.And(c2), # population + logic.MinCount( + bodyweight_measurement_with_forward_fill, + body_height_measurement_with_forward_fill, + tidal_volume_measurement_with_forward_fill, + threshold=2, + ), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NO_DATA, + 0.0, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-19 17:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + 0.5, + pendulum.parse("2025-02-19 18:00:00+01:00"), + pendulum.parse("2025-02-20 06:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 1, + pendulum.parse("2025-02-20 07:00:00+01:00"), + 
pendulum.parse("2025-02-20 17:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + 1.5, + pendulum.parse("2025-02-20 18:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.NO_DATA, + 0, + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, + ), + ( + logic.And(c2), # population + logic.MaxCount( + bodyweight_measurement_with_forward_fill, + body_height_measurement_with_forward_fill, + tidal_volume_measurement_with_forward_fill, + threshold=2, + ), + { + 1: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + "nan", + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-20 17:59:59+01:00"), + ), + ( + IntervalType.NEGATIVE, + "nan", + pendulum.parse("2025-02-20 18:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + 2: [ + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-18 17:55:00+01:00"), + pendulum.parse("2025-02-19 07:59:59+01:00"), + ), + ( + IntervalType.POSITIVE, + "nan", + pendulum.parse("2025-02-19 08:00:00+01:00"), + pendulum.parse("2025-02-21 02:00:00+01:00"), + ), + ( + IntervalType.NOT_APPLICABLE, + "nan", + pendulum.parse("2025-02-21 02:00:01+01:00"), + pendulum.parse("2025-02-22 05:30:00+01:00"), + ), + ], + }, + 
), + ], + ) + def test_combination_on_database( + self, + person, + db_session, + population, + intervention, + base_criterion, + expected, + observation_window, + criteria, + ): + persons = person[:2] + vos = [] + for person in persons: + visit = create_visit( + person_id=person.person_id, + visit_start_datetime=observation_window.start + + datetime.timedelta(hours=3), + visit_end_datetime=observation_window.end + - datetime.timedelta(hours=6.5), + visit_concept_id=concepts.INTENSIVE_CARE, + ) + vos.append(visit) + self.patient_events(db_session, visit) + + db_session.add_all(vos) + db_session.commit() + + self.insert_expression( + db_session, population, intervention, base_criterion, observation_window + ) + + df = self.fetch_interval_result( + db_session, + pi_pair_id=self.pi_pair_id, + criterion_id=None, + category=CohortCategory.POPULATION_INTERVENTION, + ) + + for person in persons: + result = df.query(f"person_id=={person.person_id}") + result_tuples = list( + result[ + [ + "interval_type", + "interval_ratio", + "interval_start", + "interval_end", + ] + ] + .fillna("nan") + .itertuples(index=False, name=None) + ) + + for result_tuple, expected_tuple in zip( + result_tuples, expected[person.person_id] + ): + assert result_tuple == expected_tuple diff --git a/tests/execution_engine/omop/criterion/test_criterion.py b/tests/execution_engine/omop/criterion/test_criterion.py index 092e2e9c..ee96c222 100644 --- a/tests/execution_engine/omop/criterion/test_criterion.py +++ b/tests/execution_engine/omop/criterion/test_criterion.py @@ -21,6 +21,7 @@ partial_day_coverage, ) from execution_engine.omop.db.omop.tables import Person +from execution_engine.omop.sqlclient import datetime_cols_to_epoch from execution_engine.task import ( # noqa: F401 -- required for the mock.patch below runner, task, @@ -29,7 +30,7 @@ from execution_engine.util import datetime_converter, logic from execution_engine.util.db import add_result_insert from execution_engine.util.interval import 
IntervalType -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueConcept, ValueNumber from tests._fixtures.omop_fixture import celida_recommendation from tests._testdata import concepts @@ -281,6 +282,7 @@ def insert_criterion(self, db_session, criterion, observation_window: TimeRange) self.register_criterion(criterion, db_session) query = criterion.create_query() + query = datetime_cols_to_epoch(query) result = db_session.connection().execute( query, parameters=observation_window.model_dump() | {"run_id": self.run_id} diff --git a/tests/execution_engine/omop/criterion/test_occurrence_criterion.py b/tests/execution_engine/omop/criterion/test_occurrence_criterion.py index 029a9ece..07ac99ed 100644 --- a/tests/execution_engine/omop/criterion/test_occurrence_criterion.py +++ b/tests/execution_engine/omop/criterion/test_occurrence_criterion.py @@ -4,7 +4,7 @@ import pytest from execution_engine.util.interval import IntervalType, TypedInterval -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange from tests._fixtures.omop_fixture import disable_postgres_trigger from tests.execution_engine.omop.criterion.test_criterion import TestCriterion, date_set @@ -160,55 +160,53 @@ def test_multiple_occurrences_multiple_days( @pytest.mark.parametrize( "test_cases", [ - ( - [ - { - "time_range": [ # non-overlapping - ("2023-03-03 08:00:00Z", "2023-03-03 16:00:00Z"), - ("2023-03-04 09:00:00Z", "2023-03-06 15:00:00Z"), - ("2023-03-08 10:00:00Z", "2023-03-09 18:00:00Z"), - ], - "expected": { - "2023-03-03", - "2023-03-04", - "2023-03-05", - "2023-03-06", - "2023-03-08", - "2023-03-09", - }, + [ + { + "time_range": [ # non-overlapping + ("2023-03-03 08:00:00Z", "2023-03-03 16:00:00Z"), + ("2023-03-04 09:00:00Z", "2023-03-06 15:00:00Z"), + ("2023-03-08 10:00:00Z", "2023-03-09 18:00:00Z"), + ], + "expected": { + "2023-03-03", + 
"2023-03-04", + "2023-03-05", + "2023-03-06", + "2023-03-08", + "2023-03-09", }, - { - "time_range": [ # exact overlap - ("2023-03-01 08:00:00Z", "2023-03-02 16:00:00Z"), - ("2023-03-02 16:00:00Z", "2023-03-03 23:59:00Z"), - ("2023-03-03 23:59:00Z", "2023-03-04 11:00:00Z"), - ], - "expected": { - "2023-03-01", - "2023-03-02", - "2023-03-03", - "2023-03-04", - }, + }, + { + "time_range": [ # exact overlap + ("2023-03-01 08:00:00Z", "2023-03-02 16:00:00Z"), + ("2023-03-02 16:00:00Z", "2023-03-03 23:59:00Z"), + ("2023-03-03 23:59:00Z", "2023-03-04 11:00:00Z"), + ], + "expected": { + "2023-03-01", + "2023-03-02", + "2023-03-03", + "2023-03-04", }, - { - "time_range": [ # overlap by some margin - ("2023-03-01 08:00:00Z", "2023-03-03 16:00:00Z"), - ("2023-03-03 12:00:00Z", "2023-03-05 20:00:00Z"), - ("2023-03-06 10:00:00Z", "2023-03-08 18:00:00Z"), - ], - "expected": { - "2023-03-01", - "2023-03-02", - "2023-03-03", - "2023-03-04", - "2023-03-05", - "2023-03-06", - "2023-03-07", - "2023-03-08", - }, + }, + { + "time_range": [ # overlap by some margin + ("2023-03-01 08:00:00Z", "2023-03-03 16:00:00Z"), + ("2023-03-03 12:00:00Z", "2023-03-05 20:00:00Z"), + ("2023-03-06 10:00:00Z", "2023-03-08 18:00:00Z"), + ], + "expected": { + "2023-03-01", + "2023-03-02", + "2023-03-03", + "2023-03-04", + "2023-03-05", + "2023-03-06", + "2023-03-07", + "2023-03-08", }, - ] - ) + }, + ] ], ) def test_multiple_persons( diff --git a/tests/execution_engine/omop/criterion/test_procedure_occurrence.py b/tests/execution_engine/omop/criterion/test_procedure_occurrence.py index 0caa9861..aa2a9c05 100644 --- a/tests/execution_engine/omop/criterion/test_procedure_occurrence.py +++ b/tests/execution_engine/omop/criterion/test_procedure_occurrence.py @@ -4,7 +4,8 @@ from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.util.enum import TimeUnit -from execution_engine.util.types import TimeRange, 
Timing +from execution_engine.util.types import Timing +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueNumber from execution_engine.util.value.time import ValueDuration from tests.execution_engine.omop.criterion.test_occurrence_criterion import Occurrence diff --git a/tests/execution_engine/omop/db/celida/test_triggers.py b/tests/execution_engine/omop/db/celida/test_triggers.py index 911f6577..ad75dea2 100644 --- a/tests/execution_engine/omop/db/celida/test_triggers.py +++ b/tests/execution_engine/omop/db/celida/test_triggers.py @@ -8,7 +8,7 @@ from execution_engine.omop.db.omop.schema import SCHEMA_NAME as OMOP_SCHEMA_NAME from execution_engine.omop.db.omop.tables import Person from execution_engine.util.interval import IntervalType as T -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange from tests._fixtures.omop_fixture import celida_recommendation diff --git a/tests/execution_engine/task/process/test_rectangle.py b/tests/execution_engine/task/process/test_rectangle.py index 97fd2c32..5203cf21 100644 --- a/tests/execution_engine/task/process/test_rectangle.py +++ b/tests/execution_engine/task/process/test_rectangle.py @@ -1,6 +1,3 @@ -import random -from datetime import time - import pandas as pd import pendulum import pytest @@ -10,9 +7,12 @@ Interval, IntervalWithCount, get_processing_module, + interval_like, ) +from execution_engine.util.interval import IntervalType from execution_engine.util.interval import IntervalType as T -from execution_engine.util.types import PersonIntervals, TimeRange +from execution_engine.util.types import PersonIntervals +from execution_engine.util.types.timerange import TimeRange from tests.functions import df_from_str from tests.functions import intervals_to_df as intervals_to_df_original from tests.functions import parse_dt @@ -1262,406 +1262,6 @@ def test_union_rect_timezones(self): ), "Failed: Mixed intervals 
not handled correctly" -class TestUnionRectWithCount(ProcessTest): - def test_union_rect_with_count_negative_duration(self): - intervals = [ - IntervalWithCount(lower=5, upper=3, type=T.POSITIVE, count=1), - ] - expected_intervals = [] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - def test_union_special_cases(self): - intervals = [ - IntervalWithCount(lower=180, upper=190, type=T.POSITIVE, count=1), - IntervalWithCount(lower=180, upper=180, type=T.POSITIVE, count=2), - IntervalWithCount(lower=190, upper=200, type=T.POSITIVE, count=1), - IntervalWithCount(lower=180, upper=189, type=T.POSITIVE, count=1), - IntervalWithCount(lower=190, upper=190, type=T.POSITIVE, count=2), - IntervalWithCount(lower=191, upper=200, type=T.POSITIVE, count=1), - ] - - expected_intervals = [ - IntervalWithCount(lower=180, upper=180, type=T.POSITIVE, count=4), - IntervalWithCount(lower=181, upper=189, type=T.POSITIVE, count=2), - IntervalWithCount(lower=190, upper=190, type=T.POSITIVE, count=4), - IntervalWithCount(lower=191, upper=200, type=T.POSITIVE, count=2), - ] - - result = self.process._impl.union_rects_with_count(intervals) - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=1), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=1), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=2), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - @pytest.mark.parametrize("factor", [1, 2, 3]) - def test_union_rect_with_count_adjacent(self, factor): - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - 
expected_intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - @pytest.mark.parametrize("factor", [1, 2, 3]) - def test_union_rect_with_count_one(self, factor): - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=3, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=3, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=2 * factor), - ] - - result = 
self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=5, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=5, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=2 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=5, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - ] - 
- expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - 
IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple intervals with the same start and end - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=4 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple intervals with the same start, different end - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=3, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=5, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=4 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple intervals with the same end, different start, with other types inbetween - intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=8, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=3, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, 
upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=5, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=5, upper=6, type=T.NEGATIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=4 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=6, upper=6, type=T.NEGATIVE, count=2 * factor), - IntervalWithCount(lower=7, upper=8, type=T.NEGATIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple intervals with the same end, different start - intervals = [ - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=6, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=1, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=2, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, count=4 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple starts and end at the same position inbetween - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, 
type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=1, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=2, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=6 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=6, upper=6, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=7, upper=7, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple starts and end at the same position inbetween, other types inbetween - intervals = [ - IntervalWithCount(lower=1, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=1, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=2, type=T.POSITIVE, count=2 
* factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=6 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=6, upper=6, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=7, upper=7, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - # multiple starts and end at the same position inbetween, other types inbetween (random order) - intervals = [ - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=1, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=5, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=7, type=T.NEGATIVE, count=1 * factor), - IntervalWithCount(lower=4, upper=6, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=4, type=T.POSITIVE, count=1 * factor), - ] - - expected_intervals = [ - IntervalWithCount(lower=1, upper=1, type=T.POSITIVE, count=1 * factor), - IntervalWithCount(lower=2, upper=2, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=3, upper=3, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=4, upper=4, type=T.POSITIVE, count=6 * factor), - IntervalWithCount(lower=5, upper=5, type=T.POSITIVE, count=3 * factor), - IntervalWithCount(lower=6, upper=6, type=T.POSITIVE, count=2 * factor), - IntervalWithCount(lower=7, 
upper=7, type=T.POSITIVE, count=1 * factor), - ] - - result = self.process._impl.union_rects_with_count(intervals) - - assert result == expected_intervals - - class TestMergeAdjacentIntervals(ProcessTest): def test_empty_list(self): """Test that an empty list returns an empty list.""" @@ -2174,1348 +1774,73 @@ def test_union_intervals_no_data_negative(self): pd.testing.assert_frame_equal(result, expected_df) -class TestCountIntervals(ProcessTest): - def test_count_intervals_empty_dataframe_list(self): - result = self.process.count_intervals([]) - assert ( - not result - ), "Failed: Empty list of DataFrames should return an empty DataFrame" - - def test_count_intervals_single_dataframe(self): - df = pd.DataFrame( - { - "person_id": ["A", "A"], - "interval_start": pd.to_datetime( - ["2020-01-01 12:00:00+00:00", "2020-01-02 12:00:00+00:00"] - ), - "interval_end": pd.to_datetime( - ["2020-01-02 18:00:00+00:00", "2020-01-03 18:00:00+00:00"] - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "person_id": ["A", "A", "A"], - "interval_start": pd.to_datetime( - [ - "2020-01-01 12:00:00+00:00", - "2020-01-02 12:00:00+00:00", - "2020-01-02 18:00:01+00:00", - ] - ), - "interval_end": pd.to_datetime( - [ - "2020-01-02 11:59:59+00:00", - "2020-01-02 18:00:00+00:00", - "2020-01-03 18:00:00+00:00", - ] - ), - "interval_type": [T.POSITIVE, T.POSITIVE, T.POSITIVE], - "interval_count": [1, 2, 1], - } - ) - - result = self.process.count_intervals([df_to_person_interval_tuple(df)]) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) +class TestFindRectangles(ProcessTest): - def test_count_intervals_overlapping_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-05 18:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": 
["A"], - "interval_start": pd.to_datetime(["2020-01-04 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-06 12:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "person_id": ["A", "A", "A"], - "interval_start": pd.to_datetime( - [ - "2020-01-01 12:00:00+00:00", - "2020-01-04 12:00:00+00:00", - "2020-01-05 18:00:01+00:00", - ] - ), - "interval_end": pd.to_datetime( - [ - "2020-01-04 11:59:59+00:00", - "2020-01-05 18:00:00+00:00", - "2020-01-06 12:00:00+00:00", - ] - ), - "interval_type": [T.POSITIVE, T.POSITIVE, T.POSITIVE], - "interval_count": [1, 2, 1], - } - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_non_overlapping_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-03 13:30:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-03 13:31:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-04 18:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = ( - pd.concat([df1, df2]).reset_index(drop=True).assign(interval_count=1) - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_group_by_multiple_columns(self): - df1 = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": 
pd.to_datetime(["2020-01-02 12:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-02 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-03 12:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "group1": ["A", "A", "A"], - "group2": ["B", "B", "B"], - "interval_start": pd.to_datetime( - [ - "2020-01-01 12:00:00+00:00", - "2020-01-02 12:00:00+00:00", - "2020-01-02 12:00:01+00:00", - ] - ), - "interval_end": pd.to_datetime( - [ - "2020-01-02 11:59:59+00:00", - "2020-01-02 12:00:00+00:00", - "2020-01-03 12:00:00+00:00", - ] - ), - "interval_type": [T.POSITIVE, T.POSITIVE, T.POSITIVE], - "interval_count": [1, 2, 1], - } - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["group1", "group2"]), - df_to_person_interval_tuple(df2, by=["group1", "group2"]), - ] - ) - result = self.intervals_to_df(result, ["group1", "group2"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_adjacent_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-03 13:30:59+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-03 13:31:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-04 18:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 12:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-04 18:00:00+00:00"]), - "interval_type": [T.POSITIVE], - "interval_count": [1], - } - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] 
- ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_with_timezone(self): - data1 = { - "person_id": [1, 1], - "concept_id": ["A", "A"], - "interval_start": pd.to_datetime( - ["2023-01-01T00:00:00Z", "2023-01-02T00:00:00Z"], utc=True + @pytest.mark.parametrize( + "right_intervals, expected", + ( + ( + [Interval(4, 4, IntervalType.POSITIVE)], + Interval(4, 4, IntervalType.POSITIVE), ), - "interval_end": pd.to_datetime( - ["2023-01-01T12:00:00Z", "2023-01-02T12:00:00Z"], utc=True + ( + [Interval(4, 5, IntervalType.POSITIVE)], + Interval(4, 5, IntervalType.POSITIVE), ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - data2 = { - "person_id": [1, 1], - "concept_id": ["A", "B"], - "interval_start": pd.to_datetime( - ["2023-01-01T06:00:00Z", "2023-01-03T00:00:00Z"], utc=True + ( + [Interval(4, 6, IntervalType.POSITIVE)], + Interval(4, 6, IntervalType.POSITIVE), ), - "interval_end": pd.to_datetime( - ["2023-01-01T18:00:00Z", "2023-01-03T12:00:00Z"], utc=True + ( + [ + Interval(4, 4, IntervalType.POSITIVE), + Interval(4, 4, IntervalType.POSITIVE), + ], + Interval(4, 4, IntervalType.POSITIVE), ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2) - - expected_data = { - "person_id": [1, 1, 1, 1, 1], - "concept_id": ["A", "A", "A", "A", "B"], - "interval_start": pd.to_datetime( + # the following test currently (2025-03-28) fails, but this input is artificial, as only criteria could yield + # multiple overlapping intervals "on the same track", but these are always union-ed before propagation anyway, + # so find_rectangles with the new_interval function below should never receive such intervals. 
+ # ([Interval(4, 5, IntervalType.POSITIVE), Interval(4, 6, IntervalType.POSITIVE)], Interval(4, 6, IntervalType.POSITIVE)), + ( [ - "2023-01-01T00:00:00Z", - "2023-01-01T06:00:00Z", - "2023-01-01T12:00:01Z", - "2023-01-02T00:00:00Z", - "2023-01-03T00:00:00Z", + Interval(4, 5, IntervalType.POSITIVE), + Interval(5, 6, IntervalType.POSITIVE), ], - utc=True, + Interval(4, 6, IntervalType.POSITIVE), ), - "interval_end": pd.to_datetime( + ( [ - "2023-01-01T05:59:59Z", - "2023-01-01T12:00:00Z", - "2023-01-01T18:00:00Z", - "2023-01-02T12:00:00Z", - "2023-01-03T12:00:00Z", + Interval(4, 5, IntervalType.POSITIVE), + Interval(6, 6, IntervalType.POSITIVE), ], - utc=True, + Interval(4, 6, IntervalType.POSITIVE), ), - "interval_type": [ - T.POSITIVE, - T.POSITIVE, - T.POSITIVE, - T.POSITIVE, - T.POSITIVE, - ], - "interval_count": [1, 2, 1, 1, 1], - } - expected_df = pd.DataFrame(expected_data) - - result_df = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id", "concept_id"]), - df_to_person_interval_tuple(df2, by=["person_id", "concept_id"]), - ] - ) - result_df = self.intervals_to_df(result_df, ["person_id", "concept_id"]) - - pd.testing.assert_frame_equal(result_df, expected_df) - - def test_count_intervals_group_by_multiple_columns_complex_data2(self): - data1 = """ - group1 group2 interval_start interval_end interval_type - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - B 1 2024-01-01 17:00:00 2024-01-01 18:00:00 POSITIVE - B 1 2024-01-01 19:00:00 2024-01-01 20:00:00 POSITIVE - """ - df1 = df_from_str(data1) - - data2 = """ - group1 group2 interval_start interval_end interval_type - B 1 2024-01-01 18:00:00 2024-01-01 19:00:00 POSITIVE - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - """ - df2 = df_from_str(data2) - - expected_data = """ - group1 group2 interval_start interval_end interval_type interval_count - B 1 
2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE 2 - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE 2 - B 1 2024-01-01 17:00:00 2024-01-01 17:59:59 POSITIVE 1 - B 1 2024-01-01 18:00:00 2024-01-01 18:00:00 POSITIVE 2 - B 1 2024-01-01 18:00:01 2024-01-01 18:59:59 POSITIVE 1 - B 1 2024-01-01 19:00:00 2024-01-01 19:00:00 POSITIVE 2 - B 1 2024-01-01 19:00:01 2024-01-01 20:00:00 POSITIVE 1 - """ - expected_df = ( - df_from_str(expected_data) - .sort_values(by=["group1", "group2", "interval_start", "interval_end"]) - .reset_index(drop=True) - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["group1", "group2"]), - df_to_person_interval_tuple(df2, by=["group1", "group2"]), + ), + ) + def test_counting(self, right_intervals, expected): + def new_interval(start: int, end: int, intervals): + left_interval, right_interval, observation_window_ = intervals + if (left_interval is None) or left_interval.type != IntervalType.POSITIVE: + # no left_interval or not positive -> use fill type + return Interval(start, end, IntervalType.NOT_APPLICABLE) + elif right_interval is not None: + return interval_like(right_interval, start, end) + else: # left_interval but not right_interval -> implicit negative + return None + + left = {1: [Interval(2, 8, IntervalType.POSITIVE)]} + right = {1: right_intervals} + window_intervals = {1: [Interval(0, 10, IntervalType.POSITIVE)]} + + result = self.process.find_rectangles( + [left, right, window_intervals], new_interval + ) + + assert result == { + 1: [ + Interval(lower=0, upper=1, type=IntervalType.NOT_APPLICABLE), + expected, + Interval(lower=9, upper=10, type=IntervalType.NOT_APPLICABLE), ] - ) - result = ( - self.intervals_to_df(result, ["group1", "group2"]) - .sort_values(by=["group1", "group2", "interval_start", "interval_end"]) - .reset_index(drop=True) - ) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_group_by_multiple_columns_complex_data(self): - data1 = """ - 
group1 group2 interval_start interval_end interval_type - A 1 2023-01-01 12:00:00 2023-01-02 12:00:00 POSITIVE - A 1 2023-01-02 05:00:00 2023-01-03 05:00:00 POSITIVE - A 1 2023-01-03 06:00:00 2023-01-03 12:00:00 POSITIVE - A 1 2023-01-03 13:00:00 2023-01-04 12:00:00 POSITIVE - A 1 2023-01-04 18:00:00 2023-01-04 20:00:00 POSITIVE - A 1 2023-01-05 06:00:00 2023-01-05 23:59:00 POSITIVE - A 2 2023-02-01 12:59:00 2023-02-01 12:59:00 POSITIVE - A 2 2023-02-01 12:59:01 2023-02-01 12:59:01 POSITIVE - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - B 1 2024-01-01 17:00:00 2024-01-01 18:00:00 POSITIVE - B 1 2024-01-01 19:00:00 2024-01-01 20:00:00 POSITIVE - """ - df1 = df_from_str(data1) - - data2 = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-03 12:00:01 2023-01-03 12:59:59 POSITIVE - B 2 2024-02-01 12:00:00 2024-02-01 13:00:00 POSITIVE - B 2 2024-02-01 13:00:00 2024-02-01 14:00:00 POSITIVE - B 2 2024-02-01 15:00:00 2024-02-01 16:00:00 POSITIVE - B 1 2024-01-01 18:00:00 2024-01-01 19:00:00 POSITIVE - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - """ - df2 = df_from_str(data2) - - data3 = """ - group1 group2 interval_start interval_end interval_type - A 2 2023-02-01 06:00:00 2023-02-01 06:00:00 POSITIVE - A 2 2023-02-01 06:00:02 2023-02-01 12:58:58 POSITIVE - A 1 2023-01-04 22:00:00 2023-01-05 02:00:00 POSITIVE - B 3 2023-03-04 22:00:00 2023-03-05 02:00:00 POSITIVE - B 2 2024-02-01 15:00:00 2024-02-01 16:00:01 POSITIVE - """ - df3 = df_from_str(data3) - - expected_data = """ - group1 group2 interval_start interval_end interval_type interval_count - A 1 2023-01-01 12:00:00 2023-01-02 04:59:59 POSITIVE 1 - A 1 2023-01-02 05:00:00 2023-01-02 12:00:00 POSITIVE 2 - A 1 2023-01-02 12:00:01 2023-01-03 05:00:00 POSITIVE 1 - A 1 2023-01-03 06:00:00 2023-01-04 12:00:00 POSITIVE 1 - A 1 2023-01-04 18:00:00 2023-01-04 20:00:00 
POSITIVE 1 - A 1 2023-01-04 22:00:00 2023-01-05 02:00:00 POSITIVE 1 - A 1 2023-01-05 06:00:00 2023-01-05 23:59:00 POSITIVE 1 - A 2 2023-02-01 12:59:00 2023-02-01 12:59:01 POSITIVE 1 - A 2 2023-02-01 06:00:00 2023-02-01 06:00:00 POSITIVE 1 - A 2 2023-02-01 06:00:02 2023-02-01 12:58:58 POSITIVE 1 - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE 2 - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE 2 - B 1 2024-01-01 17:00:00 2024-01-01 17:59:59 POSITIVE 1 - B 1 2024-01-01 18:00:00 2024-01-01 18:00:00 POSITIVE 2 - B 1 2024-01-01 18:00:01 2024-01-01 18:59:59 POSITIVE 1 - B 1 2024-01-01 19:00:00 2024-01-01 19:00:00 POSITIVE 2 - B 1 2024-01-01 19:00:01 2024-01-01 20:00:00 POSITIVE 1 - B 2 2024-02-01 12:00:00 2024-02-01 12:59:59 POSITIVE 1 - B 2 2024-02-01 13:00:00 2024-02-01 13:00:00 POSITIVE 2 - B 2 2024-02-01 13:00:01 2024-02-01 14:00:00 POSITIVE 1 - B 2 2024-02-01 15:00:00 2024-02-01 16:00:00 POSITIVE 2 - B 2 2024-02-01 16:00:01 2024-02-01 16:00:01 POSITIVE 1 - B 3 2023-03-04 22:00:00 2023-03-05 02:00:00 POSITIVE 1 - """ - expected_df = ( - df_from_str(expected_data) - .sort_values(by=["group1", "group2", "interval_start", "interval_end"]) - .reset_index(drop=True) - ) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["group1", "group2"]), - df_to_person_interval_tuple(df2, by=["group1", "group2"]), - df_to_person_interval_tuple(df3, by=["group1", "group2"]), - ] - ) - result = ( - self.intervals_to_df(result, ["group1", "group2"]) - .sort_values(by=["group1", "group2", "interval_start", "interval_end"]) - .reset_index(drop=True) - ) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_edge_case(self): - data1 = """ - person_id interval_start interval_end interval_type - 30833 2023-03-02 13:00:01+00:00 2023-03-02 14:00:00+00:00 POSITIVE - 30833 2023-03-02 14:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE - """ - df1 = df_from_str(data1) - - data2 = """ - person_id interval_start interval_end 
interval_type - 30833 2023-03-02 14:00:01+00:00 2023-03-02 15:00:00+00:00 POSITIVE - """ - df2 = df_from_str(data2) - - expected_data = """ - person_id interval_start interval_end interval_type interval_count - 30833 2023-03-02 13:00:01+00:00 2023-03-02 15:00:00+00:00 POSITIVE 1 - 30833 2023-03-02 15:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE 1 - """ - expected_df = df_from_str(expected_data) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - data1 = """ - person_id interval_start interval_end interval_type - 30833 2023-03-02 13:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE - """ - df1 = df_from_str(data1) - - data2 = """ - person_id interval_start interval_end interval_type - 30833 2023-03-02 13:00:01+00:00 2023-03-02 14:00:00+00:00 POSITIVE - 30833 2023-03-02 14:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE - """ - df2 = df_from_str(data2) - - expected_data = """ - person_id interval_start interval_end interval_type interval_count - 30833 2023-03-02 13:00:01+00:00 2023-03-02 14:00:00+00:00 POSITIVE 1 - 30833 2023-03-02 14:00:01+00:00 2023-03-02 19:00:00+00:00 NEGATIVE 2 - """ - expected_df = df_from_str(expected_data) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_count_intervals_edge_case_int(self): - intervals1 = { - 1: [ - Interval(lower=1, upper=4, type=T.NEGATIVE), - ] - } - - intervals2 = { - 1: [ - Interval(lower=1, upper=2, type=T.POSITIVE), - Interval(lower=3, upper=4, type=T.NEGATIVE), - ] - } - - expected_intervals = { - 1: [ - IntervalWithCount(lower=1, upper=2, type=T.POSITIVE, count=1), 
- IntervalWithCount(lower=3, upper=4, type=T.NEGATIVE, count=2), - ] - } - - result = self.process.count_intervals([intervals1, intervals2]) - - assert len(result) == len(expected_intervals) - assert result == expected_intervals - - def test_count_intervals_no_data_negative_int(self): - intervals1 = [ - Interval(lower=1, upper=2, type=T.NO_DATA), - Interval(lower=3, upper=4, type=T.POSITIVE), - Interval(lower=5, upper=6, type=T.NO_DATA), - ] - - intervals2 = [ - Interval(lower=1, upper=2, type=T.NO_DATA), - Interval(lower=3, upper=4, type=T.NEGATIVE), - Interval(lower=5, upper=6, type=T.NO_DATA), - ] - - expected_intervals = { - 1: [ - IntervalWithCount(lower=1, upper=2, type=T.NO_DATA, count=2), - IntervalWithCount(lower=3, upper=4, type=T.POSITIVE, count=1), - IntervalWithCount(lower=5, upper=6, type=T.NO_DATA, count=2), - ] - } - - result = self.process.count_intervals([{1: intervals1}, {1: intervals2}]) - - assert len(result) == len(expected_intervals) - assert result == expected_intervals - - def test_count_intervals_no_data_negative(self): - data1 = """ - person_id interval_start interval_end interval_type - 30748 2023-02-26 07:00:00+00:00 2023-03-02 12:59:59+00:00 NO_DATA - 30748 2023-03-02 13:00:00+00:00 2023-03-02 14:00:00+00:00 POSITIVE - 30748 2023-03-02 14:00:01+00:00 2023-04-03 23:00:00+00:00 NO_DATA - """ - df1 = df_from_str(data1) - - data2 = """ - person_id interval_start interval_end interval_type - 30748 2023-02-26 07:00:00+00:00 2023-03-02 12:59:59+00:00 NO_DATA - 30748 2023-03-02 13:00:00+00:00 2023-03-02 14:00:00+00:00 NEGATIVE - 30748 2023-03-02 14:00:01+00:00 2023-04-03 23:00:00+00:00 NO_DATA - """ - - df2 = df_from_str(data2) - - expected_data = """ - person_id interval_start interval_end interval_type interval_count - 30748 2023-02-26 07:00:00+00:00 2023-03-02 12:59:59+00:00 NO_DATA 2 - 30748 2023-03-02 13:00:00+00:00 2023-03-02 14:00:00+00:00 POSITIVE 1 - 30748 2023-03-02 14:00:01+00:00 2023-04-03 23:00:00+00:00 NO_DATA 2 - """ - 
expected_df = df_from_str(expected_data) - - result = self.process.count_intervals( - [ - df_to_person_interval_tuple(df1, by=["person_id"]), - df_to_person_interval_tuple(df2, by=["person_id"]), - ] - ) - result = self.intervals_to_df(result, ["person_id"]) - - pd.testing.assert_frame_equal(result, expected_df) - - -class TestIntersectIntervals(ProcessTest): - def test_intersect_intervals_empty_dataframe_list(self): - result = self.process.intersect_intervals([]) - assert ( - not result - ), "Failed: Empty list of DataFrames should return an empty DataFrame" - - def test_intersect_intervals_single_dataframe(self): - df = pd.DataFrame( - { - "person_id": ["A", "A"], - "interval_start": pd.to_datetime( - ["2020-01-01 08:00:00+00:00", "2020-01-02 09:00:00+00:00"] - ), - "interval_end": pd.to_datetime( - ["2020-01-01 10:00:00+00:00", "2020-01-02 11:00:00+00:00"] - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - ) - by = ["person_id"] - result = self.process.intersect_intervals( - [df_to_person_interval_tuple(df, by=by)] - ) - result = self.intervals_to_df(result, by=by) - expected_df = df.copy() - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_intervals_intersecting_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 08:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 10:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 11:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 10:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - - by = ["person_id"] - result = self.process.intersect_intervals( - [ - 
df_to_person_interval_tuple(df1, by=by), - df_to_person_interval_tuple(df2, by=by), - ] - ) - result = self.intervals_to_df(result, by=by) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_intervals_no_intersecting_intervals(self): - df1 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 08:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "person_id": ["A"], - "interval_start": pd.to_datetime(["2020-01-01 10:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 11:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - columns=["person_id", "interval_start", "interval_end", "interval_type"] - ) - - by = ["person_id"] - result = self.process.intersect_intervals( - [ - df_to_person_interval_tuple(df1, by=by), - df_to_person_interval_tuple(df2, by=by), - ] - ) - result = self.intervals_to_df(result, by=by) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_intervals_group_by_multiple_columns(self): - df1 = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-01 08:00:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - df2 = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-01 08:30:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 09:30:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - expected_df = pd.DataFrame( - { - "group1": ["A"], - "group2": ["B"], - "interval_start": pd.to_datetime(["2020-01-01 08:30:00+00:00"]), - "interval_end": pd.to_datetime(["2020-01-01 09:00:00+00:00"]), - "interval_type": [T.POSITIVE], - } - ) - - by = ["group1", "group2"] - result = self.process.intersect_intervals( - [ - df_to_person_interval_tuple(df1, by=by), - 
df_to_person_interval_tuple(df2, by=by), - ] - ) - result = self.intervals_to_df(result, by=by) - - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_interval_with_timezone(self): - # Prepare test data with datetime64[ns, UTC] dtype - data1 = { - "person_id": [1, 1], - "concept_id": ["A", "A"], - "interval_start": pd.to_datetime( - ["2023-01-01T00:00:00Z", "2023-01-02T00:00:00Z"], utc=True - ), - "interval_end": pd.to_datetime( - ["2023-01-01T12:00:00Z", "2023-01-02T12:00:00Z"], utc=True - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - data2 = { - "person_id": [1, 1], - "concept_id": ["A", "B"], - "interval_start": pd.to_datetime( - ["2023-01-01T06:00:00Z", "2023-01-03T00:00:00Z"], utc=True - ), - "interval_end": pd.to_datetime( - ["2023-01-01T18:00:00Z", "2023-01-03T12:00:00Z"], utc=True - ), - "interval_type": [T.POSITIVE, T.POSITIVE], - } - df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2) - - # Call the function - by = ["person_id", "concept_id"] - result = self.process.intersect_intervals( - [ - df_to_person_interval_tuple(df1, by=by), - df_to_person_interval_tuple(df2, by=by), - ] - ) - result = self.intervals_to_df(result, by=by) - - # Define expected output - expected_data = { - "person_id": [1], - "concept_id": ["A"], - "interval_start": pd.to_datetime(["2023-01-01T06:00:00Z"], utc=True), - "interval_end": pd.to_datetime(["2023-01-01T12:00:00Z"], utc=True), - "interval_type": [T.POSITIVE], - } - expected_df = pd.DataFrame(expected_data) - - # Assert - pd.testing.assert_frame_equal(result, expected_df) - - def test_intersect_intervals_group_by_multiple_columns_complex_data(self): - data1 = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-01 12:00:00 2023-01-02 12:00:00 POSITIVE - A 1 2023-01-02 05:00:00 2023-01-03 05:00:00 POSITIVE - A 1 2023-01-03 06:00:00 2023-01-03 12:30:00 POSITIVE - A 1 2023-01-03 13:00:00 2023-01-04 12:00:00 POSITIVE - A 1 2023-01-04 18:00:00 2023-01-04 20:00:00 POSITIVE - 
A 1 2023-01-05 06:00:00 2023-01-05 23:59:00 POSITIVE - A 2 2023-02-01 12:59:00 2023-02-01 12:59:00 POSITIVE - A 2 2023-02-01 12:59:01 2023-02-01 12:59:01 POSITIVE - B 1 2024-01-01 13:00:00 2024-01-01 14:00:00 POSITIVE - B 1 2024-01-01 15:00:00 2024-01-01 16:00:00 POSITIVE - B 2 2023-02-01 12:00:00 2023-02-01 12:59:00 POSITIVE - B 3 2024-03-01 16:00:00 2024-03-01 20:00:00 POSITIVE - B 4 2024-04-01 12:00:00 2024-04-02 12:00:00 POSITIVE - """ - df1 = df_from_str(data1) - - data2 = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-03 12:00:01 2023-01-03 12:59:59 POSITIVE - A 2 2023-02-01 06:00:00 2023-02-02 12:00:00 POSITIVE - B 2 2024-02-01 13:00:00 2024-02-01 14:00:00 POSITIVE - B 2 2024-02-01 15:00:00 2024-02-01 16:00:00 POSITIVE - B 1 2024-01-01 18:00:00 2024-01-01 19:00:00 POSITIVE - B 1 2024-01-01 13:10:00 2024-01-01 13:20:00 POSITIVE - B 1 2024-01-01 13:30:00 2024-01-01 13:40:00 POSITIVE - B 1 2024-01-01 13:50:00 2024-01-01 14:00:00 POSITIVE - B 2 2023-02-01 12:59:00 2023-02-01 14:00:00 POSITIVE - B 3 2024-03-01 16:00:00 2024-03-01 18:00:00 POSITIVE - B 4 2024-04-01 12:00:00 2024-04-02 12:00:00 POSITIVE - B 5 2024-05-01 16:00:00 2024-05-01 18:00:00 POSITIVE - """ - df2 = df_from_str(data2) - - data3 = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-03 11:00:01 2023-01-03 12:59:59 POSITIVE - A 2 2023-02-01 06:00:02 2023-02-01 12:59:00 POSITIVE - A 1 2023-01-04 22:00:00 2023-01-05 02:00:00 POSITIVE - B 3 2023-03-04 22:00:00 2023-03-05 02:00:00 POSITIVE - B 1 2024-01-01 00:00:00 2024-01-01 13:25:00 POSITIVE - B 1 2024-01-01 13:45:00 2024-01-01 13:55:00 POSITIVE - B 1 2024-01-01 13:51:00 2024-01-01 13:53:00 POSITIVE - B 2 2023-02-01 12:00:00 2023-02-01 14:00:00 POSITIVE - B 3 2024-03-01 18:00:00 2024-03-01 20:00:00 POSITIVE - B 5 2024-05-01 18:00:00 2024-05-01 20:00:00 POSITIVE - """ - df3 = df_from_str(data3) - - expected_data = """ - group1 group2 interval_start interval_end interval_type - A 1 2023-01-03 
12:00:01 2023-01-03 12:30:00 POSITIVE - A 2 2023-02-01 12:59:00 2023-02-01 12:59:00 POSITIVE - B 1 2024-01-01 13:10:00 2024-01-01 13:20:00 POSITIVE - B 1 2024-01-01 13:50:00 2024-01-01 13:55:00 POSITIVE - B 2 2023-02-01 12:59:00 2023-02-01 12:59:00 POSITIVE - B 3 2024-03-01 18:00:00 2024-03-01 18:00:00 POSITIVE - """ - expected_df = df_from_str(expected_data) - - by = ["group1", "group2"] - result = self.process.intersect_intervals( - [ - df_to_person_interval_tuple(df1, by=by), - df_to_person_interval_tuple(df2, by=by), - df_to_person_interval_tuple(df3, by=by), - ] - ) - result = ( - self.intervals_to_df(result, by=by) - .sort_values(by=by) - .reset_index(drop=True) - ) - - pd.testing.assert_frame_equal(result, expected_df) - - -class TestIntervalFilling(ProcessTest): - def assert_equal(self, data, expected, observation_window=None): - def to_df(data): - df = pd.DataFrame( - data, - columns=[ - "person_id", - "interval_start", - "interval_end", - "interval_type", - ], - ) - df["interval_start"] = pd.to_datetime(df["interval_start"]) - df["interval_end"] = pd.to_datetime(df["interval_end"], utc=True) - - return df - - result = self.process.forward_fill( - df_to_person_interval_tuple(to_df(data), by=["person_id"]), - observation_window, - ) - df_result = self.intervals_to_df(result, ["person_id"]) - df_expected = to_df(expected) - - pd.testing.assert_frame_equal(df_result, df_expected, check_dtype=False) - - def test_single_row(self): - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - ] - - self.assert_equal(data, expected) - - def test_empty(self): - data = [] - expected = [] - - self.assert_equal(data, expected) - - def test_single_type_per_person(self): - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+00:00", "2023-03-01 09:00:00+00:00", "POSITIVE"), - (1, 
"2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 11:00:00+00:00", "2023-03-01 11:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 12:00:00+00:00", "2023-03-01 13:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 14:00:00+00:00", "2023-03-01 15:00:00+00:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 10:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 11:00:00+00:00", "2023-03-01 15:00:00+00:00", "POSITIVE"), - ] - - self.assert_equal(data, expected) - - def test_last_row_different(self): - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+00:00", "2023-03-01 09:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "NEGATIVE"), - (2, "2023-03-01 11:00:00+00:00", "2023-03-01 11:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 12:00:00+00:00", "2023-03-01 13:00:00+00:00", "POSITIVE"), - (2, "2023-03-01 14:00:00+00:00", "2023-03-01 15:00:00+00:00", "NEGATIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 09:09:59+00:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "NEGATIVE"), - (2, "2023-03-01 11:00:00+00:00", "2023-03-01 13:59:59+00:00", "POSITIVE"), - (2, "2023-03-01 14:00:00+00:00", "2023-03-01 15:00:00+00:00", "NEGATIVE"), - ] - - self.assert_equal(data, expected) - - def test_forward_fill(self): - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+00:00", "2023-03-01 09:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+00:00", "2023-03-01 11:00:00+00:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+00:00", "2023-03-01 13:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 14:00:00+00:00", "2023-03-01 15:00:00+00:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 10:59:59+00:00", 
"POSITIVE"), - (1, "2023-03-01 11:00:00+00:00", "2023-03-01 11:59:59+00:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+00:00", "2023-03-01 15:00:00+00:00", "POSITIVE"), - ] - self.assert_equal(data, expected) - - data = [ - (1, "2021-01-01 08:00:00+00:00", "2021-01-01 09:00:00+00:00", "POSITIVE"), - (1, "2021-01-01 09:00:00+00:00", "2021-01-01 10:00:00+00:00", "POSITIVE"), - (2, "2021-01-02 10:00:00+00:00", "2021-01-02 10:15:00+00:00", "NEGATIVE"), - (2, "2021-01-02 10:30:00+00:00", "2021-01-02 11:00:00+00:00", "POSITIVE"), - (2, "2021-01-02 11:30:00+00:00", "2021-01-02 12:00:00+00:00", "NEGATIVE"), - (3, "2021-01-03 12:00:00+00:00", "2021-01-03 12:30:00+00:00", "POSITIVE"), - (3, "2021-01-03 12:45:00+00:00", "2021-01-03 13:00:00+00:00", "NEGATIVE"), - ] - - expected = [ - (1, "2021-01-01 08:00:00+00:00", "2021-01-01 10:00:00+00:00", "POSITIVE"), - (2, "2021-01-02 10:00:00+00:00", "2021-01-02 10:29:59+00:00", "NEGATIVE"), - (2, "2021-01-02 10:30:00+00:00", "2021-01-02 11:29:59+00:00", "POSITIVE"), - (2, "2021-01-02 11:30:00+00:00", "2021-01-02 12:00:00+00:00", "NEGATIVE"), - (3, "2021-01-03 12:00:00+00:00", "2021-01-03 12:44:59+00:00", "POSITIVE"), - (3, "2021-01-03 12:45:00+00:00", "2021-01-03 13:00:00+00:00", "NEGATIVE"), - ] - - self.assert_equal(data, expected) - - def test_forward_fill_with_observation_window(self): - observation_window = TimeRange.from_tuple( - ("2023-03-01 08:00:00+00:00", "2023-03-15 15:00:00+00:00") - ) - data = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 08:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+00:00", "2023-03-01 09:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+00:00", "2023-03-01 10:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+00:00", "2023-03-01 11:00:00+00:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+00:00", "2023-03-01 13:00:00+00:00", "POSITIVE"), - (1, "2023-03-01 14:00:00+00:00", "2023-03-01 15:00:00+00:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+00:00", "2023-03-01 
10:59:59+00:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+00:00", "2023-03-01 11:59:59+00:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+00:00", "2023-03-15 15:00:00+00:00", "POSITIVE"), - ] - self.assert_equal(data, expected, observation_window) - - # with timezone - observation_window = TimeRange.from_tuple( - ("2023-03-01 08:00:00+01:00", "2023-04-15 15:00:00+02:00") - ) - data = [ - (1, "2023-03-01 08:00:00+01:00", "2023-03-01 08:00:00+01:00", "POSITIVE"), - (1, "2023-03-01 09:00:00+01:00", "2023-03-01 09:00:00+01:00", "POSITIVE"), - (1, "2023-03-01 09:10:00+01:00", "2023-03-01 10:00:00+01:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+01:00", "2023-03-01 11:00:00+01:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+01:00", "2023-03-01 13:00:00+01:00", "POSITIVE"), - (1, "2023-03-01 14:00:00+01:00", "2023-03-01 15:00:00+01:00", "POSITIVE"), - ] - - expected = [ - (1, "2023-03-01 08:00:00+01:00", "2023-03-01 10:59:59+01:00", "POSITIVE"), - (1, "2023-03-01 11:00:00+01:00", "2023-03-01 11:59:59+01:00", "NEGATIVE"), - (1, "2023-03-01 12:00:00+01:00", "2023-04-15 15:00:00+02:00", "POSITIVE"), - ] - self.assert_equal(data, expected, observation_window) - - -class TestCreateTimeIntervals(ProcessTest): - # Helper to create timezone-aware datetime objects using pendulum - def tz_aware_datetime(self, date_str, timezone): - return pendulum.parse(date_str, tz=timezone) - - @pytest.mark.parametrize("timezone", ["America/New_York", "Europe/Berlin", "UTC"]) - def test_naive_datetimes(self, timezone): - start_datetime = pendulum.parse("2023-07-01 12:00:00").naive() - end_datetime = pendulum.parse("2023-07-03 12:00:00").naive() - start_time = time(9, 0) - end_time = time(17, 0) - intervals = self.process.create_time_intervals( - start_datetime, - end_datetime, - start_time, - end_time, - interval_type=T.POSITIVE, - timezone=timezone, - ) - # Ignore intervals of type NOT_APPLICABLE at the boundary of the period - intervals = [i for i in intervals if i.type == T.POSITIVE] - assert 
len(intervals) == 3 # Expecting intervals for July 1st and 2nd - assert ( - intervals[0].lower - == pendulum.parse("2023-07-01 12:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[0].upper - == pendulum.parse("2023-07-01 17:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[1].lower - == pendulum.parse("2023-07-02 09:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[1].upper - == pendulum.parse("2023-07-02 17:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[2].lower - == pendulum.parse("2023-07-03 09:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[2].upper - == pendulum.parse("2023-07-03 12:00:00", tz=timezone).timestamp() - ) - - def test_aware_datetimes(self): - tz = "America/New_York" - start_datetime = self.tz_aware_datetime("2023-07-01 12:00:00", tz) - end_datetime = self.tz_aware_datetime("2023-07-03 12:00:00", tz) - start_time = time(22, 0) - end_time = time(6, 0) - intervals = self.process.create_time_intervals( - start_datetime, - end_datetime, - start_time, - end_time, - interval_type=T.POSITIVE, - timezone=tz, - ) - # Ignore intervals of type NOT_APPLICABLE at the boundary of the period - intervals = [ i for i in intervals if i.type == T.POSITIVE ] - assert ( - len(intervals) == 2 - ) # Expecting intervals for the nights of July 1st and 2nd - assert ( - intervals[0].lower - == self.tz_aware_datetime("2023-07-01 22:00:00", tz).timestamp() - ) - assert ( - intervals[0].upper - == self.tz_aware_datetime("2023-07-02 06:00:00", tz).timestamp() - ) - assert ( - intervals[1].lower - == self.tz_aware_datetime("2023-07-02 22:00:00", tz).timestamp() - ) - assert ( - intervals[1].upper - == self.tz_aware_datetime("2023-07-03 06:00:00", tz).timestamp() - ) - - @pytest.mark.parametrize("timezone", ["America/New_York", "Europe/Berlin", "UTC"]) - def test_spanning_midnight_naive(self, timezone): - start_datetime = pendulum.parse("2023-07-01 12:00:00", tz=timezone) - end_datetime = pendulum.parse("2023-07-03 04:00:00", 
tz=timezone) - start_time = time(22, 0) - end_time = time(6, 0) - intervals = self.process.create_time_intervals( - start_datetime, - end_datetime, - start_time, - end_time, - interval_type=T.POSITIVE, - timezone=timezone, - ) - # Ignore intervals of type NOT_APPLICABLE at the boundary of the period - intervals = [i for i in intervals if i.type == T.POSITIVE] - assert ( - len(intervals) == 2 - ) # Expecting intervals for the nights of July 1st and 2nd - assert ( - intervals[0].lower - == pendulum.parse("2023-07-01 22:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[0].upper - == pendulum.parse("2023-07-02 06:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[1].lower - == pendulum.parse("2023-07-02 22:00:00", tz=timezone).timestamp() - ) - assert ( - intervals[1].upper - == pendulum.parse("2023-07-03 04:00:00", tz=timezone).timestamp() - ) - - -class TestFindOverlappingWindows(ProcessTest): - def test_no_intersection(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(1, 5, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [] - - def test_full_overlap(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(10, 20, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def test_partial_overlap_start(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(5, 15, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def test_partial_overlap_end(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(15, 25, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def test_multiple_overlapping_intervals(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(5, 15, T.POSITIVE), 
Interval(15, 25, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def test_interval_exactly_on_window_boundaries(self): - windows = [Interval(10, 20, T.POSITIVE)] - intervals = {1: [Interval(20, 30, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - def test_point_interval(self): - windows = [Interval(10, 20, T.POSITIVE)] - - intervals = {1: [Interval(9, 9, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [] - - intervals = {1: [Interval(10, 10, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(10, 10, T.POSITIVE)] * 10} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(11, 11, T.POSITIVE)] * 10} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(20, 20, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(20, 20, T.POSITIVE)] * 10} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(21, 21, T.POSITIVE)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [] - - def test_multiple_opening_closing(self): - windows = [Interval(10, 20, T.POSITIVE), Interval(25, 35, T.POSITIVE)] - - intervals = {1: [Interval(10, 10 + i, T.POSITIVE) for i in range(5)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(20, 20 + i, T.POSITIVE) for i in range(4)]} - assert self.process.find_overlapping_windows(windows, 
intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(10 - i, 10, T.POSITIVE) for i in range(5)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = {1: [Interval(20 - i, 20, T.POSITIVE) for i in range(5)]} - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = { - 1: [Interval(20 - i, 20, T.POSITIVE) for i in range(5)] - + [Interval(21, 24, T.POSITIVE)] - } - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE) - ] - - intervals = { - 1: [Interval(20 - i, 20 + i, T.POSITIVE) for i in range(5)] - + [Interval(21, 25, T.POSITIVE)] - } - assert self.process.find_overlapping_windows(windows, intervals)[1] == [ - Interval(10, 20, T.POSITIVE), - Interval(25, 35, T.POSITIVE), - ] - - @pytest.mark.parametrize( - "windows,intervals,expected", - [ - ([], [], []), - ([], [(10, 20)], []), - ([(10, 20)], [], []), - ([(10, 20)], [(5, 10)], [(10, 20)]), - ([(10, 20)], [(5, 15)], [(10, 20)]), - ([(10, 20)], [(5, 25)], [(10, 20)]), - ([(10, 20)], [(15, 25)], [(10, 20)]), - ([(10, 20)], [(15, 25), (5, 10)], [(10, 20)]), - ([(10, 20)], [(15, 25), (5, 10), (20, 30)], [(10, 20)]), - ([(10, 20), (30, 40)], [(15, 25), (5, 10), (20, 30)], [(10, 20), (30, 40)]), - ( - [(10, 20), (30, 40)], - [(15, 25), (5, 10), (20, 30), (35, 45)], - [(10, 20), (30, 40)], - ), - ([(10, 20), (21, 30)], [(5, 30)], [(10, 20), (21, 30)]), - ([(10, 20), (21, 30), (31, 40), (41, 50)], [(5, 30)], [(10, 20), (21, 30)]), - ], - ) - def test_intervals(self, windows, intervals, expected): - - windows = [Interval(w[0], w[1], T.POSITIVE) for w in windows] - intervals = {1: [Interval(w[0], w[1], T.POSITIVE) for w in intervals]} - expected = [Interval(w[0], w[1], T.POSITIVE) for w in expected] - - assert self.process.find_overlapping_windows(windows, intervals)[1] == expected - - 
@pytest.mark.parametrize("seed", [None, 42, 99, 123]) - def test_multiple_windows(self, seed): - - windows = [Interval(i + 5, i + 8, "positive") for i in range(0, 150, 10)] - intervals = [ - Interval(2, 20, "positive"), - Interval(26, 30, "positive"), - Interval(32, 37, "positive"), - Interval(46, 47, "positive"), - Interval(65, 69, "positive"), - Interval(75, 80, "positive"), - Interval(85, 86, "positive"), - Interval(95, 95, "positive"), - Interval(106, 106, "positive"), - Interval(116, 118, "positive"), - Interval(122, 128, "positive"), - Interval(138, 138, "positive"), - ] - - expected = [ - Interval(lower=5, upper=8, type="positive"), - Interval(lower=15, upper=18, type="positive"), - Interval(lower=25, upper=28, type="positive"), - Interval(lower=35, upper=38, type="positive"), - Interval(lower=45, upper=48, type="positive"), - Interval(lower=65, upper=68, type="positive"), - Interval(lower=75, upper=78, type="positive"), - Interval(lower=85, upper=88, type="positive"), - Interval(lower=95, upper=98, type="positive"), - Interval(lower=105, upper=108, type="positive"), - Interval(lower=115, upper=118, type="positive"), - Interval(lower=125, upper=128, type="positive"), - Interval(lower=135, upper=138, type="positive"), - ] - - # Shuffle the windows and intervals list - if seed is not None: - random.seed(seed) # Seed the random number generator for reproducibility - random.shuffle(windows) - random.shuffle(intervals) - - assert ( - self.process.find_overlapping_windows(windows, {1: intervals})[1] - == expected - ) - - @pytest.mark.parametrize("seed", [None, 42, 99, 123]) - def test_multiple_windows_with_multiple_patients(self, seed): - # Define windows - windows = [Interval(i + 5, i + 8, T.POSITIVE) for i in range(0, 150, 10)] - - # Define intervals for multiple patients - patient_intervals = { - 1: [Interval(2, 20, T.POSITIVE), Interval(46, 47, T.POSITIVE)], - 2: [Interval(75, 80, T.POSITIVE), Interval(85, 86, T.POSITIVE)], - 3: [Interval(95, 95, T.POSITIVE), 
Interval(106, 106, T.POSITIVE)], - } - - # Expected results for each patient - expected = { - 1: [ - Interval(5, 8, T.POSITIVE), - Interval(15, 18, T.POSITIVE), - Interval(45, 48, T.POSITIVE), - ], - 2: [Interval(75, 78, T.POSITIVE), Interval(85, 88, T.POSITIVE)], - 3: [Interval(95, 98, T.POSITIVE), Interval(105, 108, T.POSITIVE)], } - - # Optionally shuffle the windows and intervals list - if seed is not None: - random.seed(seed) - random.shuffle(windows) - for intervals in patient_intervals.values(): - random.shuffle(intervals) - - # Run the test - result = self.process.find_overlapping_windows(windows, patient_intervals) - - # Sort results for comparison - for pid in expected: - result[pid].sort(key=lambda x: x.lower) - expected[pid].sort(key=lambda x: x.lower) - - assert result == expected diff --git a/tests/execution_engine/util/test_interval.py b/tests/execution_engine/util/test_interval.py index 4a80ef15..945b2ab9 100644 --- a/tests/execution_engine/util/test_interval.py +++ b/tests/execution_engine/util/test_interval.py @@ -198,14 +198,17 @@ def test_union_interval_same_type(self, type_): [ (T.POSITIVE, T.NEGATIVE, T.POSITIVE), (T.NO_DATA, T.NEGATIVE, T.NO_DATA), - (T.NOT_APPLICABLE, T.NEGATIVE, T.NOT_APPLICABLE), + (T.NEGATIVE, T.NOT_APPLICABLE, T.NEGATIVE), (T.POSITIVE, T.NO_DATA, T.POSITIVE), (T.POSITIVE, T.NOT_APPLICABLE, T.POSITIVE), (T.NO_DATA, T.NOT_APPLICABLE, T.NO_DATA), ], ) def test_union_interval_different_type(self, type1, type2, overlapping_type): - # type1 is the higher priority type + + with IntervalType.union_order(): + # type1 is the higher priority type + assert type1 > type2 assert interval_int(1, 2, type1) | interval_int(2, 3, type2) == IntInterval( interval_int(1, 2, type1), @@ -472,8 +475,8 @@ def test_union_priority(self): assert IntervalType.union_priority() == [ IntervalType.POSITIVE, IntervalType.NO_DATA, - IntervalType.NOT_APPLICABLE, IntervalType.NEGATIVE, + IntervalType.NOT_APPLICABLE, ] def test_intersection_priority(self): @@ 
-505,8 +508,8 @@ def test_custom_union_priority_order(self): assert IntervalType.union_priority() == [ IntervalType.POSITIVE, IntervalType.NO_DATA, - IntervalType.NOT_APPLICABLE, IntervalType.NEGATIVE, + IntervalType.NOT_APPLICABLE, ] def test_custom_intersection_priority_order(self): @@ -596,12 +599,10 @@ def test_or(self): # NOT_APPLICABLE | others assert ( - IntervalType.NOT_APPLICABLE | IntervalType.NEGATIVE - == IntervalType.NOT_APPLICABLE + IntervalType.NOT_APPLICABLE | IntervalType.NEGATIVE == IntervalType.NEGATIVE ) assert ( - IntervalType.NEGATIVE | IntervalType.NOT_APPLICABLE - == IntervalType.NOT_APPLICABLE + IntervalType.NEGATIVE | IntervalType.NOT_APPLICABLE == IntervalType.NEGATIVE ) # same types diff --git a/tests/execution_engine/util/test_types.py b/tests/execution_engine/util/test_types.py index 350fd66a..bcff9e02 100644 --- a/tests/execution_engine/util/test_types.py +++ b/tests/execution_engine/util/test_types.py @@ -6,7 +6,8 @@ from pydantic import ValidationError from execution_engine.util.enum import TimeUnit -from execution_engine.util.types import Dosage, TimeRange, Timing +from execution_engine.util.types import Dosage, Timing +from execution_engine.util.types.timerange import TimeRange from execution_engine.util.value import ValueNumber from execution_engine.util.value.time import ValueCount, ValueDuration, ValuePeriod from tests._fixtures.concept import concept_unit_mg diff --git a/tests/execution_engine/util/test_value.py b/tests/execution_engine/util/test_value.py index e339be60..5a101f10 100644 --- a/tests/execution_engine/util/test_value.py +++ b/tests/execution_engine/util/test_value.py @@ -307,10 +307,7 @@ def test_to_sql(self, test_concept, test_table): def test_str(self, test_concept): value_concept = ValueConcept(value=test_concept) - assert ( - str(value_concept) - == "Value == Concept(concept_id=1, concept_name='Test Concept', concept_code='unit', domain_id='units', vocabulary_id='test', concept_class_id='test', 
standard_concept=None, invalid_reason=None)" - ) + assert str(value_concept) == "Value == Test Concept" def test_repr(self, test_concept): value_concept = ValueConcept(value=test_concept) diff --git a/tests/recommendation/test_recommendation_base.py b/tests/recommendation/test_recommendation_base.py index a78af845..54580628 100644 --- a/tests/recommendation/test_recommendation_base.py +++ b/tests/recommendation/test_recommendation_base.py @@ -18,7 +18,7 @@ ) from execution_engine.omop.db.omop.tables import Person from execution_engine.util.interval import IntervalType -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange from tests._testdata import concepts, parameter from tests.functions import ( create_condition, diff --git a/tests/recommendation/test_recommendation_base_v2.py b/tests/recommendation/test_recommendation_base_v2.py index 5bc1a584..ec322d32 100644 --- a/tests/recommendation/test_recommendation_base_v2.py +++ b/tests/recommendation/test_recommendation_base_v2.py @@ -8,7 +8,7 @@ from execution_engine.omop.db.omop.tables import Person from execution_engine.util.interval import IntervalType -from execution_engine.util.types import TimeRange +from execution_engine.util.types.timerange import TimeRange from tests._testdata import concepts from tests._testdata.generator.composite import ( AndGenerator,