From eefb49f09e76c39b4c3687f23db8af4286dfa3c1 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Thu, 2 Apr 2026 16:23:20 +0000 Subject: [PATCH 1/2] Fixed bug that DSE didn't respect unit_length_to_sample --- pyproject.toml | 2 +- tests/test_splitter.py | 223 ++++++++++++++++++ .../instruction/data_splitter_events.py | 1 - 3 files changed, 224 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7985fb0..94bed0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ build-backend = "setuptools.build_meta" # Standard entry point for setuptools bu [project] name = "twinweaver" -version = "0.3.5" +version = "0.3.6" description = "Converting longitudinal patient data into text for LLM-based event prediction and forecasting." # --- NEW/UPDATED FIELDS --- diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 7a6b60e..78fe2e4 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -726,3 +726,226 @@ def test_forecasting_truncation_allow_beyond_next_split_date(): base_date + pd.Timedelta(days=25), ] assert target[cfg.date_col].tolist() == expected_dates + + +# ──────────────────────────────────────────────────────────────────────────── +# Test that DataSplitterEvents respects unit_length_to_sample +# ──────────────────────────────────────────────────────────────────────────── + + +def test_events_splitter_respects_unit_length_to_sample(): + """ + Verify that the observation_end_date produced by DataSplitterEvents is + bounded by ``max_length_to_sample`` expressed in the correct + ``unit_length_to_sample``, and is NOT pushed forward to the next event + when that event lies beyond the sampled prediction window. + + Scenario (unit_length_to_sample = "days", max = 7 days, min = 1 day): + Timeline for patient p_test: + Day 0 - lot event (split event) + Day 0 - lab measurement (input – at the split date) + Day 100 - death event (far in the future) + + We split at Day 0 and predict "death" with a window of at most 7 days. + The observation_end_date must be <= Day 0 + 7 days. Before the fix + it would jump to Day 100 (the next event), violating the window. + """ + from twinweaver.common.config import Config + + cfg = Config() + cfg.seed = 42 + cfg.split_event_category = "lot" + cfg.event_category_events_prediction_with_naming = {"death": "death"} + cfg.constant_columns_to_use = [] + + base_date = pd.Timestamp("2020-01-01") + + events = pd.DataFrame( + { + cfg.date_col: [ + base_date, + base_date, + base_date + pd.Timedelta(days=100), + ], + cfg.event_category_col: ["lot", "lab", "death"], + cfg.event_name_col: ["line_number", "hemoglobin", "death"], + cfg.event_value_col: ["1", "13.0", "deceased"], + cfg.event_descriptive_name_col: ["line number", "hemoglobin", "death"], + cfg.source_col: ["events"] * 3, + cfg.meta_data_col: [pd.NA] * 3, + } + ) + + constant = pd.DataFrame( + { + cfg.patient_id_col: ["p_test"], + cfg.constant_split_col: ["train"], + } + ) + + patient_data = {"events": events, "constant": constant} + + # Create a minimal DataManager stub + dm = DataManager.__new__(DataManager) + dm.config = cfg + dm.data_frames = {"events": events} + dm.all_patientids = ["p_test"] + + max_length = pd.Timedelta(days=7) + min_length = pd.Timedelta(days=1) + + splitter = DataSplitterEvents( + data_manager=dm, + config=cfg, + max_length_to_sample=max_length, + min_length_to_sample=min_length, + unit_length_to_sample="days", + max_split_length_after_split_event=pd.Timedelta(days=0), + ) + splitter.setup_variables() + + np.random.seed(cfg.seed) + + # Use override split dates and category to make the test deterministic + splits = splitter.get_splits_from_patient( + patient_data, + max_nr_samples_per_split=1, + override_split_dates=[base_date], + override_category="death", + ) + + assert len(splits) == 1 + assert len(splits[0]) == 1 + + option = splits[0][0] + + # The critical assertion: observation_end_date must respect the sampled + # window which is at most base_date + 7 days. Before the fix it was + # pushed to base_date + 100 days (the death event). + assert option.observation_end_date <= base_date + max_length, ( + f"observation_end_date ({option.observation_end_date}) exceeded the " + f"maximum prediction window ({base_date + max_length}). " + f"unit_length_to_sample is not being respected." + ) + + # The event (death) is at Day 100, well outside the 7-day window, + # so it should NOT have occurred. + assert option.event_occurred is False + + # The end_date is within the data range (data goes to Day 100), so the + # event simply did not occur within the prediction window — not censored. + assert option.event_censored is None + + +def test_events_splitter_unit_days_vs_weeks(): + """ + Verify that changing ``unit_length_to_sample`` between 'days' and 'weeks' + actually produces different observation windows when + ``max_length_to_sample`` is the same Timedelta. + + With max_length_to_sample = 14 days: + - unit='days' → random window in [1 … 14] days + - unit='weeks' → random window in [1 … 2] weeks (= 7 or 14 days) + + By running many samples we can verify the units produce the expected + granularity. + """ + from twinweaver.common.config import Config + + cfg = Config() + cfg.seed = 123 + cfg.split_event_category = "lot" + cfg.event_category_events_prediction_with_naming = {"death": "death"} + cfg.constant_columns_to_use = [] + + base_date = pd.Timestamp("2020-01-01") + far_future = base_date + pd.Timedelta(days=365) + + events = pd.DataFrame( + { + cfg.date_col: [base_date, base_date, far_future], + cfg.event_category_col: ["lot", "lab", "death"], + cfg.event_name_col: ["line_number", "hemoglobin", "death"], + cfg.event_value_col: ["1", "13.0", "deceased"], + cfg.event_descriptive_name_col: ["line number", "hemoglobin", "death"], + cfg.source_col: ["events"] * 3, + cfg.meta_data_col: [pd.NA] * 3, + } + ) + + constant = pd.DataFrame( + { + cfg.patient_id_col: ["p_test"], + cfg.constant_split_col: ["train"], + } + ) + patient_data = {"events": events, "constant": constant} + + dm = DataManager.__new__(DataManager) + dm.config = cfg + dm.data_frames = {"events": events} + dm.all_patientids = ["p_test"] + + max_td = pd.Timedelta(days=14) + min_td = pd.Timedelta(days=1) + + # --- unit = "days" --- + splitter_days = DataSplitterEvents( + data_manager=dm, + config=cfg, + max_length_to_sample=max_td, + min_length_to_sample=min_td, + unit_length_to_sample="days", + max_split_length_after_split_event=pd.Timedelta(days=0), + ) + splitter_days.setup_variables() + + day_deltas = set() + for i in range(200): + np.random.seed(i) + splits = splitter_days.get_splits_from_patient( + patient_data, + max_nr_samples_per_split=1, + override_split_dates=[base_date], + override_category="death", + ) + delta = (splits[0][0].observation_end_date - base_date).days + day_deltas.add(delta) + + # With unit="days" and range [1..14], we expect many distinct day values + assert len(day_deltas) > 2, f"With unit='days', expected many distinct day offsets but got {day_deltas}" + # All values should be within [1, 14] + assert min(day_deltas) >= 1 + assert max(day_deltas) <= 14 + + # --- unit = "weeks" --- + splitter_weeks = DataSplitterEvents( + data_manager=dm, + config=cfg, + max_length_to_sample=max_td, + min_length_to_sample=min_td, + unit_length_to_sample="weeks", + max_split_length_after_split_event=pd.Timedelta(days=0), + ) + splitter_weeks.setup_variables() + + week_deltas = set() + for i in range(200): + np.random.seed(i) + splits = splitter_weeks.get_splits_from_patient( + patient_data, + max_nr_samples_per_split=1, + override_split_dates=[base_date], + override_category="death", + ) + delta = (splits[0][0].observation_end_date - base_date).days + week_deltas.add(delta) + + # With unit="weeks", min_units=0 (1 day // 7 = 0), max_units=2 (14 days // 7 = 2) + # So we get 0, 1, or 2 weeks → 0, 7, or 14 days + # All offsets must be multiples of 7 + for d in week_deltas: + assert d % 7 == 0, ( + f"With unit='weeks', got a non-week-multiple offset of {d} days. " + f"unit_length_to_sample is not being respected." + ) diff --git a/twinweaver/instruction/data_splitter_events.py b/twinweaver/instruction/data_splitter_events.py index 5e7531f..5819ec3 100644 --- a/twinweaver/instruction/data_splitter_events.py +++ b/twinweaver/instruction/data_splitter_events.py @@ -421,7 +421,6 @@ def get_splits_from_patient( # Process the actual end date end_date = curr_date + end_time_delta - end_date = max(end_date, events_after_split[self.config.date_col].min()) end_date_within_data = end_date <= events[self.config.date_col].max() events_limited_after_split = events_after_split[events_after_split[self.config.date_col] <= end_date] From c9a529f234f32cfff169a2f7c24fc4b188012625 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Thu, 2 Apr 2026 16:36:25 +0000 Subject: [PATCH 2/2] Fixed edge case --- tests/test_converter_events_unit.py | 313 ++++++++++++++++++ twinweaver/instruction/converter_events.py | 48 ++- .../instruction/data_splitter_events.py | 8 + 3 files changed, 362 insertions(+), 7 deletions(-) create mode 100644 tests/test_converter_events_unit.py diff --git a/tests/test_converter_events_unit.py b/tests/test_converter_events_unit.py new file mode 100644 index 0000000..a1ebbf8 --- /dev/null +++ b/tests/test_converter_events_unit.py @@ -0,0 +1,313 @@ +"""Tests that ConverterEvents respects DataSplitterEventsOption.unit_length_to_sample. + +When the DataSplitterEvents is configured with a ``unit_length_to_sample`` that +differs from ``config.delta_time_unit``, the TTE prompt must: + +1. Express the delta-time value in the *splitter's* unit, **not** the config's. +2. Use the *splitter's* unit name in the prompt text (e.g. "days" instead of + "weeks"). +""" + +import pandas as pd + +from twinweaver.common.config import Config +from twinweaver.common.data_manager import DataManager +from twinweaver.instruction.converter_events import ConverterEvents +from twinweaver.instruction.converter_instruction import ConverterInstruction +from twinweaver.instruction.data_splitter_events import ( + DataSplitterEvents, + DataSplitterEventsOption, +) +from twinweaver.instruction.data_splitter import DataSplitter + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_option( + delta: pd.Timedelta, + unit: str = None, + category_name: str = "death", + category: str = "death", + occurred: bool = True, + censored=None, +) -> DataSplitterEventsOption: + """Create a minimal DataSplitterEventsOption for unit-testing the converter.""" + split_date = pd.Timestamp("2020-01-01") + return DataSplitterEventsOption( + events_until_split=pd.DataFrame(), + constant_data=pd.DataFrame(), + event_occurred=occurred, + event_censored=censored, + observation_end_date=split_date + delta, + split_date_included_in_input=split_date, + sampled_category=category, + sampled_category_name=category_name, + lot_date=split_date, + unit_length_to_sample=unit, + ) + + +def _make_converter(config: Config) -> ConverterEvents: + """Create a ConverterEvents instance with the given config.""" + return ConverterEvents( + config=config, + constant_description=pd.DataFrame(), + nr_tokens_budget_total=4096, + ) + + +# --------------------------------------------------------------------------- +# Unit-level tests on ConverterEvents._generate_prompt +# --------------------------------------------------------------------------- + + +class TestConverterEventsRespectsUnit: + """ConverterEvents._generate_prompt must honour unit_length_to_sample.""" + + def test_days_unit_overrides_config_weeks(self): + """Splitter uses days while config uses weeks → prompt should say 'days'.""" + cfg = Config() + cfg.seed = 42 + assert cfg.delta_time_unit == "weeks" # default is weeks + + converter = _make_converter(cfg) + option = _make_option(delta=pd.Timedelta(days=14), unit="days") + + prompt, delta_numeric = converter._generate_prompt(option) + + # 14 days expressed in *days* → 14 + assert "14" in prompt + assert "days" in prompt + # Make sure the config unit is NOT used for the numeric value + # 14 days / 7 = 2 weeks — "2" could appear coincidentally, so check + # the prompt text says "days" not "weeks" in the mid-section + assert "days from the last clinical visit" in prompt + + def test_hours_unit_overrides_config_weeks(self): + """Splitter uses hours while config uses weeks → prompt should say 'hours'.""" + cfg = Config() + cfg.seed = 42 + converter = _make_converter(cfg) + option = _make_option(delta=pd.Timedelta(hours=48), unit="hours") + + prompt, delta_numeric = converter._generate_prompt(option) + + assert "48" in prompt + assert "hours" in prompt + assert "hours from the last clinical visit" in prompt + + def test_minutes_unit_overrides_config_weeks(self): + """Splitter uses minutes while config uses weeks.""" + cfg = Config() + cfg.seed = 42 + converter = _make_converter(cfg) + option = _make_option(delta=pd.Timedelta(minutes=120), unit="minutes") + + prompt, delta_numeric = converter._generate_prompt(option) + + assert "120" in prompt + assert "minutes" in prompt + + def test_weeks_unit_matches_config_weeks(self): + """Splitter uses weeks and so does config — should work identically.""" + cfg = Config() + cfg.seed = 42 + converter = _make_converter(cfg) + option = _make_option(delta=pd.Timedelta(weeks=4), unit="weeks") + + prompt, delta_numeric = converter._generate_prompt(option) + + assert "4" in prompt + assert "weeks" in prompt + assert abs(delta_numeric - 4.0) < 0.01 + + def test_none_unit_falls_back_to_config(self): + """When unit_length_to_sample is None, the config unit is used (legacy behaviour).""" + cfg = Config() + cfg.seed = 42 + cfg.set_delta_time_unit("days", unit_sing="day") + + converter = _make_converter(cfg) + option = _make_option(delta=pd.Timedelta(days=14), unit=None) + + prompt, delta_numeric = converter._generate_prompt(option) + + assert "14" in prompt + assert "days" in prompt + + def test_numeric_value_is_correct_days(self): + """Delta numeric should reflect the splitter's unit, not the config's.""" + cfg = Config() + cfg.seed = 42 + converter = _make_converter(cfg) + + # 21 days = 3 weeks + option = _make_option(delta=pd.Timedelta(days=21), unit="days") + _, delta_numeric = converter._generate_prompt(option) + assert abs(delta_numeric - 21.0) < 0.01 + + def test_numeric_value_is_correct_weeks(self): + """When splitter unit is weeks, 21 days → 3 weeks.""" + cfg = Config() + cfg.seed = 42 + converter = _make_converter(cfg) + + option = _make_option(delta=pd.Timedelta(days=21), unit="weeks") + _, delta_numeric = converter._generate_prompt(option) + assert abs(delta_numeric - 3.0) < 0.01 + + def test_forward_conversion_uses_split_unit(self): + """forward_conversion should propagate the unit through _generate_prompt.""" + cfg = Config() + cfg.seed = 42 + converter = _make_converter(cfg) + + option = _make_option(delta=pd.Timedelta(days=7), unit="days") + prompt, target, meta = converter.forward_conversion(option) + + # Prompt should express time in days (7), not weeks (1) + assert "7" in prompt + assert "days" in prompt + # Meta should carry the numeric delta in days + assert abs(meta["delta_time_numeric"] - 7.0) < 0.01 + + +# --------------------------------------------------------------------------- +# Integration test: DataSplitterEvents propagates unit to option +# --------------------------------------------------------------------------- + + +class TestDataSplitterEventsPopulatesUnit: + """DataSplitterEvents.get_splits_from_patient must set unit_length_to_sample.""" + + def test_unit_propagated_to_options(self, mock_config, sample_data): + """Options created by the splitter carry the correct unit.""" + df_events, df_constant, df_constant_desc = sample_data + mock_config.split_event_category = "lot" + mock_config.event_category_forecast = ["lab"] + mock_config.event_category_events_prediction_with_naming = { + "death": "death", + "progression": "next progression", + } + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + mock_config.constant_birthdate_column = "birthyear" + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) + dm.infer_var_types() + + # Create splitter with unit_length_to_sample="days" + splitter = DataSplitterEvents( + dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(days=90), + min_length_to_sample=pd.Timedelta(days=1), + max_split_length_after_split_event=pd.Timedelta(days=90), + unit_length_to_sample="days", + ) + splitter.setup_variables() + + patient_data = dm.get_patient_data("p0") + groups = splitter.get_splits_from_patient( + patient_data, + max_nr_samples_per_split=2, + ) + + # At least one group should be generated + assert len(groups) > 0 + + for group in groups: + for option in group.events_options: + assert option.unit_length_to_sample == "days" + + +# --------------------------------------------------------------------------- +# End-to-end: config says "minutes", splitter says "days" → prompt says "days" +# --------------------------------------------------------------------------- + + +class TestEndToEndUnitOverride: + """Full pipeline: splitter unit overrides config unit in generated prompt.""" + + def test_events_prompt_uses_splitter_unit_not_config(self, mock_config, sample_data): + df_events, df_constant, df_constant_desc = sample_data + mock_config.split_event_category = "lot" + mock_config.event_category_forecast = ["lab"] + mock_config.event_category_events_prediction_with_naming = { + "death": "death", + "progression": "next progression", + } + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + mock_config.constant_birthdate_column = "birthyear" + + # Set config to "minutes" — this should NOT appear in events prompt + mock_config.set_delta_time_unit("minutes", unit_sing="minute") + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) + dm.infer_var_types() + + # Splitter uses "days" + splitter_events = DataSplitterEvents( + dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(days=90), + min_length_to_sample=pd.Timedelta(days=1), + max_split_length_after_split_event=pd.Timedelta(days=90), + unit_length_to_sample="days", + ) + splitter_events.setup_variables() + + data_splitter = DataSplitter(data_splitter_events=splitter_events) + + converter = ConverterInstruction( + nr_tokens_budget_total=4096, + config=mock_config, + dm=dm, + variable_stats=None, + ) + + patient_data = dm.get_patient_data("p0") + f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data) + + assert f_splits is None # No forecasting splitter configured + assert e_splits is not None and len(e_splits) > 0 + + # Pick the first events group that has options + event_group = None + for eg in e_splits: + if len(eg) > 0: + event_group = eg + break + assert event_group is not None, "No event splits generated" + + result = converter.forward_conversion( + forecasting_splits=None, + event_splits=event_group, + ) + + instruction = result["instruction"] + + # The events task prompt should say "days from the last clinical visit", + # NOT "minutes from the last clinical visit". + assert "days from the last clinical visit" in instruction + assert "minutes from the last clinical visit" not in instruction diff --git a/twinweaver/instruction/converter_events.py b/twinweaver/instruction/converter_events.py index 572723c..190ff85 100644 --- a/twinweaver/instruction/converter_events.py +++ b/twinweaver/instruction/converter_events.py @@ -139,9 +139,13 @@ def _generate_prompt(self, patient_split: DataSplitterEventsOption) -> tuple: Constructs a prompt asking the language model to predict the time until a specific event occurs. It calculates the time difference between the patient's - split date (last date included in input) and the actual event date, converts - it to weeks if config.delta_time_unit is "weeks", rounds it, and formats it into the prompt string using - templates from the config (e.g., `self.forecasting_prompt_start`). + split date (last date included in input) and the actual event date. + + When the *patient_split* carries a ``unit_length_to_sample`` (propagated from + :class:`DataSplitterEvents`), the delta time is expressed in that unit and + the prompt text is adjusted accordingly, regardless of + ``config.delta_time_unit``. If the attribute is ``None`` the behaviour + falls back to the global ``config.delta_time_unit``. Parameters ---------- @@ -154,22 +158,52 @@ def _generate_prompt(self, patient_split: DataSplitterEventsOption) -> tuple: The formatted prompt string, e.g.: "Predict the time in weeks until event Event A occurs: 12.3 weeks. Input data:\n" delta_time_numeric : float - The calculated time difference in config.delta_time_unit (numeric, before rounding/formatting). + The calculated time difference (numeric, before rounding/formatting) + expressed in the effective unit. """ #: Get event name descriptive curr_event_name = patient_split.sampled_category_name - #: get delta in time in config.delta_time_unit, rounded using round_and_strip + #: Determine the effective unit and time divisor. + # If the split carries unit_length_to_sample we honour it; otherwise + # fall back to the config-level delta_time_unit via self._time_divisor. + split_unit = getattr(patient_split, "unit_length_to_sample", None) + if split_unit is not None: + _unit_to_divisor = { + "weeks": 7.0, + "week(s)": 7.0, + "days": 1.0, + "day(s)": 1.0, + "hours": 1.0 / 24.0, + "hour(s)": 1.0 / 24.0, + "minutes": 1.0 / (24.0 * 60.0), + "minute(s)": 1.0 / (24.0 * 60.0), + "seconds": 1.0 / (24.0 * 60.0 * 60.0), + "second(s)": 1.0 / (24.0 * 60.0 * 60.0), + } + if split_unit not in _unit_to_divisor: + raise ValueError( + f"Unsupported unit_length_to_sample on split: '{split_unit}'. " + f"Supported values: {list(_unit_to_divisor.keys())}" + ) + effective_divisor = _unit_to_divisor[split_unit] + # Re-render the mid prompt template with the split's unit + effective_prompt_mid = self.config._forecasting_tte_prompt_mid_template.format(unit=split_unit) + else: + effective_divisor = self._time_divisor + effective_prompt_mid = self.forecasting_prompt_mid + + #: get delta in time, rounded using round_and_strip delta_time_numeric = patient_split.observation_end_date - patient_split.split_date_included_in_input - delta_time_numeric = delta_time_numeric.total_seconds() / (86400.0 * self._time_divisor) + delta_time_numeric = delta_time_numeric.total_seconds() / (86400.0 * effective_divisor) delta_time = round_and_strip(delta_time_numeric, self.decimal_precision) #: construct prompt using config attributes accessed via self ret_prompt = self.forecasting_prompt_start + str(delta_time) - ret_prompt += self.forecasting_prompt_mid + curr_event_name + ret_prompt += effective_prompt_mid + curr_event_name ret_prompt += self.forecasting_prompt_end #: return diff --git a/twinweaver/instruction/data_splitter_events.py b/twinweaver/instruction/data_splitter_events.py index 5819ec3..edd0f55 100644 --- a/twinweaver/instruction/data_splitter_events.py +++ b/twinweaver/instruction/data_splitter_events.py @@ -33,6 +33,11 @@ class DataSplitterEventsOption: The end of the prediction window. lot_date : pd.Timestamp The Line of Therapy (LoT) start date associated with this split point. + unit_length_to_sample : str or None + The time unit used by the DataSplitterEvents when sampling the observation + window (e.g. ``"weeks"``, ``"days"``, ``"hours"``). Propagated so that + downstream converters can express the delta time in the same unit, + regardless of ``config.delta_time_unit``. """ def __init__( @@ -46,6 +51,7 @@ def __init__( sampled_category: str, sampled_category_name: str, lot_date: pd.Timestamp, + unit_length_to_sample: str = None, ): self.events_until_split = events_until_split self.constant_data = constant_data @@ -56,6 +62,7 @@ def __init__( self.sampled_category = sampled_category self.sampled_category_name = sampled_category_name self.lot_date = lot_date + self.unit_length_to_sample = unit_length_to_sample class DataSplitterEventsGroup: @@ -501,6 +508,7 @@ def get_splits_from_patient( sampled_category=str(sampled_cateogry), sampled_category_name=sampled_var_name, lot_date=lot_date, + unit_length_to_sample=self.unit_length_to_sample, ) )