From 9896c84e06190f3a11b3c2dd9a42b7f3b08e0fff Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Tue, 17 Mar 2026 15:54:54 +0000 Subject: [PATCH 01/36] Added config.event_categories_to_exclude_from_input --- tests/test_converter.py | 124 ++++++++++++++++++++++++++++ twinweaver/common/config.py | 3 + twinweaver/common/converter_base.py | 6 ++ 3 files changed, 133 insertions(+) diff --git a/tests/test_converter.py b/tests/test_converter.py index 3a98e42..e50e8a9 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -91,6 +91,130 @@ def test_forward_conversion_training(setup_components): assert "hemoglobin - 718-7 is 14.01." in answer # Known value for p0 +def test_event_categories_to_exclude_from_input(mock_config, sample_data): + """Test that event_categories_to_exclude_from_input removes specified categories from generated text.""" + df_events, df_constant, df_constant_desc = sample_data + + # Configure with drug events excluded + mock_config.split_event_category = "lot" + mock_config.event_category_forecast = ["lab"] + mock_config.data_splitter_events_variables_category_mapping = {"death": "death", "progression": "next progression"} + mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] + mock_config.constant_birthdate_column = "birthyear" + mock_config.event_categories_to_exclude_from_input = ["drug"] + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + dm.setup_dataset_splits() + dm.infer_var_types() + + splitter_events = DataSplitterEvents(dm, config=mock_config) + splitter_events.setup_variables() + + splitter_forecast = DataSplitterForecasting(data_manager=dm, config=mock_config) + splitter_forecast.setup_statistics() + + data_splitter = DataSplitter(splitter_events, splitter_forecast) + + converter = ConverterInstruction( + nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=splitter_forecast.variable_stats + ) + + patient_data = dm.get_patient_data("p0") + f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data) + + result = converter.forward_conversion( + forecasting_splits=f_splits[0], + event_splits=e_splits[0], + override_mode_to_select_forecasting="both", + ) + + instruction = result["instruction"] + + # Drug events (e.g. "drug pemetrexed is administered") should be excluded + assert "drug pemetrexed is administered" not in instruction + assert ( + "drug" not in instruction.lower().split("demographic")[0].split("\n")[-1] + if "demographic" in instruction + else True + ) + + # Other event categories should still be present + assert "Starting with demographic data:" in instruction + assert "hemoglobin" in instruction # lab events still present + + +def test_event_categories_to_exclude_multiple(mock_config, sample_data): + """Test excluding multiple event categories from input.""" + df_events, df_constant, df_constant_desc = sample_data + + mock_config.split_event_category = "lot" + mock_config.event_category_forecast = ["lab"] + mock_config.data_splitter_events_variables_category_mapping = {"death": "death", "progression": "next progression"} + mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] + mock_config.constant_birthdate_column = "birthyear" + mock_config.event_categories_to_exclude_from_input = ["drug", "ecog"] + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + dm.setup_dataset_splits() + dm.infer_var_types() + + splitter_events = DataSplitterEvents(dm, config=mock_config) + splitter_events.setup_variables() + + splitter_forecast = DataSplitterForecasting(data_manager=dm, config=mock_config) + splitter_forecast.setup_statistics() + + data_splitter = DataSplitter(splitter_events, splitter_forecast) + + converter = ConverterInstruction( + nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=splitter_forecast.variable_stats + ) + + patient_data = dm.get_patient_data("p0") + f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data) + + result = converter.forward_conversion( + forecasting_splits=f_splits[0], + event_splits=e_splits[0], + override_mode_to_select_forecasting="both", + ) + + instruction = result["instruction"] + + # Both drug and ecog events should be absent + assert "drug pemetrexed is administered" not in instruction + assert "ECOG" not in instruction + + # Lab events should still be present + assert "hemoglobin" in instruction + + +def test_event_categories_to_exclude_empty_list(setup_components): + """Test that an empty exclude list leaves all events intact (default behavior).""" + dm, data_splitter, converter = setup_components + assert dm.config.event_categories_to_exclude_from_input == [] + + patient_data = dm.get_patient_data("p0") + f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data) + + result = converter.forward_conversion( + forecasting_splits=f_splits[0], + event_splits=e_splits[0], + override_mode_to_select_forecasting="both", + ) + + instruction = result["instruction"] + + # Drug events should be present when nothing is excluded + assert "drug pemetrexed is administered" in instruction + + def test_forward_conversion_inference(setup_components): """Test conversion for inference (no target string).""" dm, data_splitter, converter = setup_components diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index a103b49..593239c 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -296,6 +296,9 @@ def __init__(self): True # Whether to warn if a patient has no LoT events in DataSplitterEvents ) + # List of event categories to exclude from the input data (e.g., ["lot"]) + self.event_categories_to_exclude_from_input: list = [] + # --- Specific Event Categories / Values / Sources --- self.event_category_lot: str = "lot" self.event_category_death: str = "death" diff --git a/twinweaver/common/converter_base.py b/twinweaver/common/converter_base.py index a4e57d3..d90e04f 100644 --- a/twinweaver/common/converter_base.py +++ b/twinweaver/common/converter_base.py @@ -287,6 +287,12 @@ def _preprocess_events(self, events: pd.DataFrame) -> pd.DataFrame: round_and_strip, args=(self.decimal_precision,) ) + # Exclude specified event categories from the input if configured + if self.config.event_categories_to_exclude_from_input: + events = events[ + ~events[self.config.event_category_col].isin(self.config.event_categories_to_exclude_from_input) + ] + return events def _get_event_string( From 34d128eaba6e53bedfebc38cb507018193952f1a Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Tue, 17 Mar 2026 16:21:15 +0000 Subject: [PATCH 02/36] Added that unified DataSplitter API can now handle individual data splitters --- pyproject.toml | 2 +- tests/test_splitter.py | 164 ++++++++++++++++++++++++ twinweaver/instruction/data_splitter.py | 142 +++++++++++++------- 3 files changed, 259 insertions(+), 49 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index acf48a2..4edd460 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ build-backend = "setuptools.build_meta" # Standard entry point for setuptools bu [project] name = "twinweaver" -version = "0.2.1" +version = "0.3.0" description = "Converting longitudinal patient data into text for LLM-based event prediction and forecasting." # --- NEW/UPDATED FIELDS --- diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 4ae7189..bfb4685 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -124,3 +124,167 @@ def test_inference_split(initialized_dm, mock_config): assert e_split.observation_end_date == pd.Timestamp("2018-02-23 00:00:00") assert e_split.event_censored == "end_of_data" assert not e_split.event_occurred + + +# ──────────────────────────────────────────────────────────────────────────── +# Tests for DataSplitter with individual (single) splitters +# ──────────────────────────────────────────────────────────────────────────── + + +def test_data_splitter_requires_at_least_one_splitter(): + """Test that DataSplitter raises if neither splitter is provided.""" + with pytest.raises(ValueError, match="At least one"): + DataSplitter() + + +def test_training_forecasting_only(initialized_dm, mock_config): + """Test training splits when only the forecasting splitter is provided.""" + splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config) + splitter_forecast.setup_statistics() + + data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast) + + patient_data = initialized_dm.get_patient_data("p0") + forecasting_splits, events_splits, ref_dates = data_splitter.get_splits_from_patient_with_target( + patient_data, max_num_splits_per_split_event=1 + ) + + # Forecasting should be populated + assert forecasting_splits is not None + assert len(forecasting_splits) == 1 + + # Events should be None since no events splitter was provided + assert events_splits is None + + # Reference dates should still be available from the forecasting splitter + assert ref_dates is not None + assert not ref_dates.empty + + # Validate forecasting split structure + f_split = forecasting_splits[0][0] + assert f_split.events_until_split is not None + assert f_split.constant_data["patientid"].iloc[0] == "p0" + + +def test_training_events_only(initialized_dm, mock_config): + """Test training splits when only the events splitter is provided.""" + splitter_events = DataSplitterEvents(initialized_dm, config=mock_config) + splitter_events.setup_variables() + + data_splitter = DataSplitter(data_splitter_events=splitter_events) + + patient_data = initialized_dm.get_patient_data("p0") + forecasting_splits, events_splits, ref_dates = data_splitter.get_splits_from_patient_with_target( + patient_data, max_num_splits_per_split_event=1 + ) + + # Forecasting should be None since no forecasting splitter was provided + assert forecasting_splits is None + + # Events should be populated + assert events_splits is not None + assert len(events_splits) >= 1 + + # Reference dates should be reconstructed from events splits + assert ref_dates is not None + assert not ref_dates.empty + + # Validate events split structure + e_split = events_splits[0][0] + assert e_split.events_until_split is not None + assert e_split.constant_data["patientid"].iloc[0] == "p0" + assert e_split.sampled_category in ["death", "progression"] + + +def test_inference_forecasting_only(initialized_dm, mock_config): + """Test inference split when only the forecasting splitter is provided.""" + splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config) + + data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast) + patient_data = initialized_dm.get_patient_data("p0") + + f_split, e_split = data_splitter.get_splits_from_patient_inference( + patient_data, + inference_type="forecasting", + forecasting_override_variables_to_predict=["hemoglobin_-_718-7"], + ) + + last_date = patient_data["events"]["date"].max() + + # Forecasting split should be populated + assert f_split is not None + assert f_split.split_date_included_in_input == last_date + assert f_split.sampled_variables == ["hemoglobin_-_718-7"] + + # Events split should be None + assert e_split is None + + +def test_inference_events_only(initialized_dm, mock_config): + """Test inference split when only the events splitter is provided.""" + splitter_events = DataSplitterEvents(initialized_dm, config=mock_config) + splitter_events.setup_variables() + + data_splitter = DataSplitter(data_splitter_events=splitter_events) + patient_data = initialized_dm.get_patient_data("p0") + + f_split, e_split = data_splitter.get_splits_from_patient_inference( + patient_data, + inference_type="events", + events_override_category="death", + events_override_observation_time_delta=pd.Timedelta(weeks=52), + ) + + last_date = patient_data["events"]["date"].max() + + # Forecasting split should be None + assert f_split is None + + # Events split should be populated + assert e_split is not None + assert e_split.split_date_included_in_input == last_date + assert e_split.sampled_category == "death" + + +def test_inference_both_type_with_only_forecasting(initialized_dm, mock_config): + """Test that inference_type='both' gracefully returns None for the missing splitter.""" + splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config) + + data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast) + patient_data = initialized_dm.get_patient_data("p0") + + f_split, e_split = data_splitter.get_splits_from_patient_inference( + patient_data, + inference_type="both", + forecasting_override_variables_to_predict=["hemoglobin_-_718-7"], + ) + + # Forecasting should work + assert f_split is not None + assert f_split.sampled_variables == ["hemoglobin_-_718-7"] + + # Events should be None because no events splitter is set + assert e_split is None + + +def test_inference_both_type_with_only_events(initialized_dm, mock_config): + """Test that inference_type='both' gracefully returns None for the missing splitter.""" + splitter_events = DataSplitterEvents(initialized_dm, config=mock_config) + splitter_events.setup_variables() + + data_splitter = DataSplitter(data_splitter_events=splitter_events) + patient_data = initialized_dm.get_patient_data("p0") + + f_split, e_split = data_splitter.get_splits_from_patient_inference( + patient_data, + inference_type="both", + events_override_category="death", + events_override_observation_time_delta=pd.Timedelta(weeks=52), + ) + + # Forecasting should be None because no forecasting splitter is set + assert f_split is None + + # Events should work + assert e_split is not None + assert e_split.sampled_category == "death" diff --git a/twinweaver/instruction/data_splitter.py b/twinweaver/instruction/data_splitter.py index a992568..74208b9 100644 --- a/twinweaver/instruction/data_splitter.py +++ b/twinweaver/instruction/data_splitter.py @@ -9,9 +9,19 @@ class DataSplitter: """ Combines both data splitters into one interface for easier usage. For more advanced use cases, the individual data splitters can still be used directly. + + At least one of ``data_splitter_events`` or ``data_splitter_forecasting`` must be + provided. When only one splitter is supplied, the methods will return ``None`` / + empty results for the missing task type. """ - def __init__(self, data_splitter_events: DataSplitterEvents, data_splitter_forecasting: DataSplitterForecasting): + def __init__( + self, + data_splitter_events: DataSplitterEvents = None, + data_splitter_forecasting: DataSplitterForecasting = None, + ): + if data_splitter_events is None and data_splitter_forecasting is None: + raise ValueError("At least one of data_splitter_events or data_splitter_forecasting must be provided.") self.data_splitter_events = data_splitter_events self.data_splitter_forecasting = data_splitter_forecasting @@ -64,33 +74,56 @@ def get_splits_from_patient_with_target( ------- tuple A tuple containing three elements: - 1. forecasting_splits: list[DataSplitterForecastingGroup] - List of generated forecasting split groups. - 2. events_splits: list[DataSplitterEventsGroup] - List of generated event prediction split groups, corresponding to the forecasting splits. + 1. forecasting_splits: list[DataSplitterForecastingGroup] or None + List of generated forecasting split groups, or None if no forecasting splitter is set. + 2. events_splits: list[DataSplitterEventsGroup] or None + List of generated event prediction split groups, or None if no events splitter is set. 3. reference_dates: pd.DataFrame DataFrame containing the split dates and LoT dates used. """ - # Process forecasting splits - forecasting_splits, reference_dates = self.data_splitter_forecasting.get_splits_from_patient( - patient_data, - nr_samples_per_split=forecasting_nr_samples_per_split, - include_metadata=True, - max_num_splits_per_split_event=max_num_splits_per_split_event, - filter_outliers=forecasting_filter_outliers, - override_categories_to_predict=forecasting_override_categories_to_predict, - override_variables_to_predict=forecasting_override_variables_to_predict, - override_split_dates=forecasting_override_split_dates, - ) + forecasting_splits = None + events_splits = None + reference_dates = None - # Process event prediction splits - events_splits = self.data_splitter_events.get_splits_from_patient( - patient_data, - reference_split_dates=reference_dates, - max_nr_samples_per_split=events_max_nr_samples_per_split, - override_category=events_override_category, - override_observation_time_delta=events_override_observation_time_delta, - ) + # Process forecasting splits (if available) + if self.data_splitter_forecasting is not None: + forecasting_splits, reference_dates = self.data_splitter_forecasting.get_splits_from_patient( + patient_data, + nr_samples_per_split=forecasting_nr_samples_per_split, + include_metadata=True, + max_num_splits_per_split_event=max_num_splits_per_split_event, + filter_outliers=forecasting_filter_outliers, + override_categories_to_predict=forecasting_override_categories_to_predict, + override_variables_to_predict=forecasting_override_variables_to_predict, + override_split_dates=forecasting_override_split_dates, + ) + + # Process event prediction splits (if available) + if self.data_splitter_events is not None: + events_splits = self.data_splitter_events.get_splits_from_patient( + patient_data, + reference_split_dates=reference_dates, + max_nr_samples_per_split=events_max_nr_samples_per_split, + max_num_splits_per_split_event=max_num_splits_per_split_event, + override_category=events_override_category, + override_observation_time_delta=events_override_observation_time_delta, + ) + + # When only events splitter is used, extract reference_dates from events_splits + if reference_dates is None and events_splits is not None: + config = self.data_splitter_events.config + ref_rows = [] + for group in events_splits: + if len(group) > 0: + opt = group[0] + ref_rows.append( + { + config.date_col: opt.split_date_included_in_input, + config.split_date_col: opt.lot_date, + } + ) + if ref_rows: + reference_dates = pd.DataFrame(ref_rows) #: return both, since we want to be able to still have the flexibility to use both splitters directly return forecasting_splits, events_splits, reference_dates @@ -135,36 +168,49 @@ def get_splits_from_patient_inference( 2. events_split: DataSplitterEventsOption or None The generated event prediction option, or None if inference_type is 'forecasting'. """ + # Resolve the config from whichever splitter is available + _config = ( + self.data_splitter_events.config + if self.data_splitter_events is not None + else self.data_splitter_forecasting.config + ) + # assume last date in events is the split date that we're interested in - patient_data["events"] = patient_data["events"].sort_values(by=self.data_splitter_events.config.date_col) - split_date = patient_data["events"][self.data_splitter_events.config.date_col].iloc[-1] + patient_data["events"] = patient_data["events"].sort_values(by=_config.date_col) + split_date = patient_data["events"][_config.date_col].iloc[-1] #: generate forecasting split + forecast_split = None if inference_type == "both" or inference_type == "forecasting": - forecast_splits = self.data_splitter_forecasting.get_splits_from_patient( - patient_data, - nr_samples_per_split=1, - filter_outliers=False, # Since no filtering needed, since no target exists - override_split_dates=[split_date], - override_variables_to_predict=forecasting_override_variables_to_predict, - ) - # The first one is the only one - forecast_split = forecast_splits[0][0] - else: - forecast_split = None + if self.data_splitter_forecasting is None: + if inference_type != "both": + raise ValueError("DataSplitterForecasting must be set to generate forecasting splits.") + else: + forecast_splits = self.data_splitter_forecasting.get_splits_from_patient( + patient_data, + nr_samples_per_split=1, + filter_outliers=False, # Since no filtering needed, since no target exists + override_split_dates=[split_date], + override_variables_to_predict=forecasting_override_variables_to_predict, + ) + # The first one is the only one + forecast_split = forecast_splits[0][0] #: generate event split + events_split = None if inference_type == "both" or inference_type == "events": - events_splits = self.data_splitter_events.get_splits_from_patient( - patient_data, - max_nr_samples_per_split=1, - override_split_dates=[split_date], - override_category=events_override_category, - override_observation_time_delta=events_override_observation_time_delta, - ) - # The first one is the only one - events_split = events_splits[0][0] - else: - events_split = None + if self.data_splitter_events is None: + if inference_type != "both": + raise ValueError("DataSplitterEvents must be set to generate event prediction splits.") + else: + events_splits = self.data_splitter_events.get_splits_from_patient( + patient_data, + max_nr_samples_per_split=1, + override_split_dates=[split_date], + override_category=events_override_category, + override_observation_time_delta=events_override_observation_time_delta, + ) + # The first one is the only one + events_split = events_splits[0][0] return forecast_split, events_split From 2a5983d3cc7ccc65365bafe94137ecd7c818205c Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Tue, 17 Mar 2026 16:29:54 +0000 Subject: [PATCH 03/36] Added more extensive unit tests for DataSplitter --- tests/test_splitter.py | 134 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 130 insertions(+), 4 deletions(-) diff --git a/tests/test_splitter.py b/tests/test_splitter.py index bfb4685..266da19 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -159,11 +159,35 @@ def test_training_forecasting_only(initialized_dm, mock_config): # Reference dates should still be available from the forecasting splitter assert ref_dates is not None assert not ref_dates.empty + assert "date" in ref_dates.columns + assert "split_date" in ref_dates.columns + assert ref_dates.shape == (1, 2) + assert ref_dates["date"].iloc[0] == pd.Timestamp("2015-07-08") + assert ref_dates["split_date"].iloc[0] == pd.Timestamp("2015-05-06") - # Validate forecasting split structure + # Validate forecasting split structure and content f_split = forecasting_splits[0][0] assert f_split.events_until_split is not None assert f_split.constant_data["patientid"].iloc[0] == "p0" + assert f_split.events_until_split.shape == (32, 8) + assert f_split.split_date_included_in_input == pd.Timestamp("2015-07-08") + assert f_split.lot_date == pd.Timestamp("2015-05-06") + assert f_split.sampled_variables == ["hemoglobin_-_718-7"] + + # Target events should exist and be after the split date + assert not f_split.target_events_after_split.empty + assert f_split.target_events_after_split.shape[0] == 4 + assert f_split.target_events_after_split["date"].min() == pd.Timestamp("2015-07-29") + assert f_split.target_events_after_split["date"].max() == pd.Timestamp("2015-09-30") + + # Events in input must be <= split date + assert (f_split.events_until_split["date"] <= f_split.split_date_included_in_input).all() + assert f_split.events_until_split["date"].min() == pd.Timestamp("2015-04-19") + + # Constant data should contain expected columns + assert "birthyear" in f_split.constant_data.columns + assert "gender" in f_split.constant_data.columns + assert f_split.constant_data.shape[0] == 1 def test_training_events_only(initialized_dm, mock_config): @@ -183,17 +207,50 @@ def test_training_events_only(initialized_dm, mock_config): # Events should be populated assert events_splits is not None - assert len(events_splits) >= 1 + assert len(events_splits) == 1 # Reference dates should be reconstructed from events splits assert ref_dates is not None assert not ref_dates.empty + assert "date" in ref_dates.columns + assert "split_date" in ref_dates.columns + assert ref_dates.shape == (1, 2) + assert ref_dates["date"].iloc[0] == pd.Timestamp("2015-07-08") + assert ref_dates["split_date"].iloc[0] == pd.Timestamp("2015-05-06") - # Validate events split structure + # Validate events split structure and content e_split = events_splits[0][0] assert e_split.events_until_split is not None assert e_split.constant_data["patientid"].iloc[0] == "p0" - assert e_split.sampled_category in ["death", "progression"] + assert e_split.events_until_split.shape == (32, 8) + assert e_split.split_date_included_in_input == pd.Timestamp("2015-07-08") + assert e_split.lot_date == pd.Timestamp("2015-05-06") + # Category must be one of the configured event categories (or a backup thereof) + expected_mapping = {"death": "death", "progression": "next progression"} + assert e_split.sampled_category in list(expected_mapping.keys()) + list( + mock_config.data_splitter_events_backup_category_mapping.values() + ) + # Category name must match one of the configured descriptive names + assert e_split.sampled_category_name in expected_mapping.values() + + # Event outcome must be boolean + assert isinstance(e_split.event_occurred, bool) + # Censoring should be None or one of the known censoring types + assert e_split.event_censored in [None, "new_therapy_start", "end_of_data", "data_cutoff"] + # Observation end date must be after the split date + assert e_split.observation_end_date >= e_split.split_date_included_in_input + + # Events in input must be <= split date + assert (e_split.events_until_split["date"] <= e_split.split_date_included_in_input).all() + assert e_split.events_until_split["date"].min() == pd.Timestamp("2015-04-19") + assert e_split.events_until_split["date"].max() == pd.Timestamp("2015-07-08") + + # Constant data integrity + assert "birthyear" in e_split.constant_data.columns + assert "gender" in e_split.constant_data.columns + assert "histology" in e_split.constant_data.columns + assert "smoking_history" in e_split.constant_data.columns + assert e_split.constant_data.shape[0] == 1 def test_inference_forecasting_only(initialized_dm, mock_config): @@ -214,7 +271,24 @@ def test_inference_forecasting_only(initialized_dm, mock_config): # Forecasting split should be populated assert f_split is not None assert f_split.split_date_included_in_input == last_date + assert f_split.split_date_included_in_input == pd.Timestamp("2016-05-13") assert f_split.sampled_variables == ["hemoglobin_-_718-7"] + assert f_split.lot_date == "override" + + # Inference has no target data + assert f_split.target_events_after_split.empty + + # Input events must cover the full patient history up to last date + assert f_split.events_until_split.shape == (78, 8) + assert (f_split.events_until_split["date"] <= f_split.split_date_included_in_input).all() + assert f_split.events_until_split["date"].min() == pd.Timestamp("2015-04-19") + assert f_split.events_until_split["date"].max() == pd.Timestamp("2016-05-13") + + # Constant data integrity + assert f_split.constant_data["patientid"].iloc[0] == "p0" + assert "birthyear" in f_split.constant_data.columns + assert "gender" in f_split.constant_data.columns + assert f_split.constant_data.shape[0] == 1 # Events split should be None assert e_split is None @@ -243,7 +317,25 @@ def test_inference_events_only(initialized_dm, mock_config): # Events split should be populated assert e_split is not None assert e_split.split_date_included_in_input == last_date + assert e_split.split_date_included_in_input == pd.Timestamp("2016-05-13") assert e_split.sampled_category == "death" + assert e_split.sampled_category_name == "death" + + # p0's last event is death itself; predicting death from last date with 52-week window: + # death already occurred at last_date so looking forward finds nothing → censored end_of_data + assert e_split.event_occurred is False + assert e_split.event_censored == "end_of_data" + assert e_split.observation_end_date == pd.Timestamp("2017-05-12") + + # Input events must cover the full patient history up to last date + assert e_split.events_until_split.shape == (78, 8) + assert (e_split.events_until_split["date"] <= e_split.split_date_included_in_input).all() + assert e_split.events_until_split["date"].max() == pd.Timestamp("2016-05-13") + + # Constant data integrity + assert e_split.constant_data["patientid"].iloc[0] == "p0" + assert "birthyear" in e_split.constant_data.columns + assert e_split.constant_data.shape[0] == 1 def test_inference_both_type_with_only_forecasting(initialized_dm, mock_config): @@ -259,9 +351,25 @@ def test_inference_both_type_with_only_forecasting(initialized_dm, mock_config): forecasting_override_variables_to_predict=["hemoglobin_-_718-7"], ) + last_date = patient_data["events"]["date"].max() + # Forecasting should work assert f_split is not None assert f_split.sampled_variables == ["hemoglobin_-_718-7"] + assert f_split.split_date_included_in_input == last_date + assert f_split.split_date_included_in_input == pd.Timestamp("2016-05-13") + assert f_split.lot_date == "override" + + # Inference: no target + assert f_split.target_events_after_split.empty + + # Full patient history should be used as input + assert f_split.events_until_split.shape == (78, 8) + assert (f_split.events_until_split["date"] <= f_split.split_date_included_in_input).all() + + # Constant data integrity + assert f_split.constant_data["patientid"].iloc[0] == "p0" + assert f_split.constant_data.shape[0] == 1 # Events should be None because no events splitter is set assert e_split is None @@ -282,9 +390,27 @@ def test_inference_both_type_with_only_events(initialized_dm, mock_config): events_override_observation_time_delta=pd.Timedelta(weeks=52), ) + last_date = patient_data["events"]["date"].max() + # Forecasting should be None because no forecasting splitter is set assert f_split is None # Events should work assert e_split is not None + assert e_split.split_date_included_in_input == last_date + assert e_split.split_date_included_in_input == pd.Timestamp("2016-05-13") assert e_split.sampled_category == "death" + assert e_split.sampled_category_name == "death" + + # Observation window and outcome + assert e_split.event_occurred is False + assert e_split.event_censored == "end_of_data" + assert e_split.observation_end_date == pd.Timestamp("2017-05-12") + + # Full patient history should be used as input + assert e_split.events_until_split.shape == (78, 8) + assert (e_split.events_until_split["date"] <= e_split.split_date_included_in_input).all() + + # Constant data integrity + assert e_split.constant_data["patientid"].iloc[0] == "p0" + assert e_split.constant_data.shape[0] == 1 From 71b490c73fff2f0359d5417352def065e107b643 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Tue, 17 Mar 2026 16:42:25 +0000 Subject: [PATCH 04/36] Updated custom splitting notebooks to use updated unified DataSplitter API --- .../inference_individual_splitters.py | 70 ++++++++----------- .../training_forecasting_splitter_only.ipynb | 23 +++--- .../training_individual_splitters.ipynb | 37 +++++----- 3 files changed, 60 insertions(+), 70 deletions(-) diff --git a/examples/advanced/custom_splitting/inference_individual_splitters.py b/examples/advanced/custom_splitting/inference_individual_splitters.py index abf7273..7dfda94 100644 --- a/examples/advanced/custom_splitting/inference_individual_splitters.py +++ b/examples/advanced/custom_splitting/inference_individual_splitters.py @@ -2,6 +2,7 @@ DataSplitterForecasting, DataManager, DataSplitterEvents, + DataSplitter, ConverterInstruction, Config, ) @@ -45,10 +46,17 @@ def __init__( self.dm.setup_dataset_splits() self.dm.infer_var_types() - self.data_splitter_events = DataSplitterEvents(self.dm, config=self.config) - self.data_splitter_events.setup_variables() - self.data_splitter_forecasting = DataSplitterForecasting(data_manager=self.dm, config=self.config) - self.data_splitter_forecasting.setup_statistics() + data_splitter_events = DataSplitterEvents(self.dm, config=self.config) + data_splitter_events.setup_variables() + data_splitter_forecasting = DataSplitterForecasting(data_manager=self.dm, config=self.config) + data_splitter_forecasting.setup_statistics() + + # Use the unified DataSplitter API that combines both splitters + self.data_splitter = DataSplitter( + data_splitter_events=data_splitter_events, + data_splitter_forecasting=data_splitter_forecasting, + ) + self.converter = ConverterInstruction( nr_tokens_budget_total=8192, config=self.config, @@ -62,48 +70,26 @@ def convert_full_to_string_for_one_patient(self, patientid, override_events_or_f # To simulate that we only have input, half the events patient_data["events"] = patient_data["events"].iloc[: int(len(patient_data["events"]) / 2)] - # Here then split date - split_date = patient_data["events"]["date"].iloc[-1] - - #: generate event split - NOTE: this if statement is only to exemplify both cases! - if override_events_or_forecasting == "events": - ####### Example if we want to override for events - - events_splits = self.data_splitter_events.get_splits_from_patient( - patient_data, - max_nr_samples=1, - override_split_dates=[split_date], - override_category="death", - override_end_week_delta=52, - ) - # We just pick the first one - events_split = events_splits[0][0] - - #: no forecasting split - forecast_split = None - forecasting_times_to_predict = None - else: - ####### Example if we want to override for forecasting - - #: generate forecasting split - forecast_splits = self.data_splitter_forecasting.get_splits_from_patient( - patient_data, - nr_samples_per_split=1, - filter_outliers=False, - override_split_dates=[split_date], - override_variables_to_predict=["Neutrophils"], - ) - # We just pick the first one - forecast_split = forecast_splits[0][0] - - # We set which weeks to predict + # Use the unified DataSplitter API for inference + forecast_split, events_split = self.data_splitter.get_splits_from_patient_inference( + patient_data, + inference_type=override_events_or_forecasting, + forecasting_override_variables_to_predict=["Neutrophils"] + if override_events_or_forecasting != "events" + else None, + events_override_category="death" if override_events_or_forecasting != "forecasting" else None, + events_override_observation_time_delta=pd.Timedelta(weeks=52) + if override_events_or_forecasting != "forecasting" + else None, + ) + + # Set which weeks to predict for forecasting (if applicable) + forecasting_times_to_predict = None + if forecast_split is not None: forecasting_times_to_predict = { "Neutrophils": [1, 2, 8, 11], } - #: no events split - events_split = None - # Convert to text converted = self.converter.forward_conversion_inference( forecasting_split=forecast_split, diff --git a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb index 373ea42..2a6e528 100644 --- a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb +++ b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb @@ -5,7 +5,7 @@ "id": "0", "metadata": {}, "source": [ - "# Forecasting-Only Example: Training Data Generation with Custom Dataset" + "# Forecasting-Only Example: Training Data Generation with the Unified DataSplitter API" ] }, { @@ -27,6 +27,7 @@ "\n", "from twinweaver import (\n", " DataSplitterForecasting,\n", + " DataSplitter,\n", " DataManager,\n", " ConverterInstruction,\n", " Config,\n", @@ -66,7 +67,7 @@ "id": "6", "metadata": {}, "source": [ - "Set up the data manager and the forecasting-only pipeline." + "Set up the data manager and the forecasting-only pipeline using the unified `DataSplitter` API. By passing only `data_splitter_forecasting`, the unified interface handles the forecasting-only case automatically." ] }, { @@ -116,6 +117,9 @@ "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", "\n", + "# Use the unified DataSplitter API with only the forecasting splitter\n", + "data_splitter = DataSplitter(data_splitter_forecasting=data_splitter_forecasting)\n", + "\n", "converter = ConverterInstruction(\n", " nr_tokens_budget_total=8192,\n", " config=config,\n", @@ -198,9 +202,9 @@ "id": "16", "metadata": {}, "source": [ - "We start by generating random \"splits\" in the patient trajectory. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting).\n", + "We start by generating random \"splits\" in the patient trajectory using the unified `DataSplitter.get_splits_from_patient_with_target` method. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting).\n", "\n", - "Here we generate these random splits. We can also manually override them (see other examples on inference)." + "Since we only provided a forecasting splitter, `events_splits` will be `None`." ] }, { @@ -210,11 +214,10 @@ "metadata": {}, "outputs": [], "source": [ - "processed_splits_fc, split_dates = data_splitter_forecasting.get_splits_from_patient(\n", + "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n", " patient_data,\n", - " nr_samples_per_split=4,\n", - " filter_outliers=False,\n", - " include_metadata=True,\n", + " forecasting_nr_samples_per_split=4,\n", + " forecasting_filter_outliers=False,\n", " max_num_splits_per_split_event=2,\n", ")" ] @@ -224,7 +227,7 @@ "id": "18", "metadata": {}, "source": [ - "Now for each split, we can generate the formatted strings. Note that `event_splits` is left empty since this example only uses the forecasting splitter." + "Now for each split, we can generate the formatted strings. Note that `events_splits` is `None` since we only provided a forecasting splitter, so we pass an empty list for `event_splits`." ] }, { @@ -236,7 +239,7 @@ "source": [ "split_idx = 0\n", "p_converted = converter.forward_conversion(\n", - " forecasting_splits=processed_splits_fc[split_idx],\n", + " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=[], # Not needed for forecasting-only splitter\n", " override_mode_to_select_forecasting=\"forecasting\",\n", ")" diff --git a/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 446142c..8bd7bb6 100644 --- a/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -5,7 +5,7 @@ "id": "0", "metadata": {}, "source": [ - "# Example for single patient to convert using the instruction setup with custom dataset" + "# Example for single patient to convert using the unified DataSplitter API with custom dataset" ] }, { @@ -27,8 +27,9 @@ "\n", "from twinweaver import (\n", " DataSplitterForecasting,\n", - " DataManager,\n", " DataSplitterEvents,\n", + " DataSplitter,\n", + " DataManager,\n", " ConverterInstruction,\n", " Config,\n", ")" @@ -67,7 +68,7 @@ "id": "6", "metadata": {}, "source": [ - "Set up the data managers which hold the patient data." + "Set up the data managers and the unified `DataSplitter` which combines both event and forecasting splitters." ] }, { @@ -124,6 +125,12 @@ "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", "\n", + "# Use the unified DataSplitter API that combines both splitters\n", + "data_splitter = DataSplitter(\n", + " data_splitter_events=data_splitter_events,\n", + " data_splitter_forecasting=data_splitter_forecasting,\n", + ")\n", + "\n", "converter = ConverterInstruction(\n", " nr_tokens_budget_total=8192,\n", " config=config,\n", @@ -206,9 +213,9 @@ "id": "16", "metadata": {}, "source": [ - "We start by generating random \"splits\" in the patient trajectory. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting, death/progression/metastases/next treatment for event).\n", + "We start by generating random \"splits\" in the patient trajectory using the unified `DataSplitter.get_splits_from_patient_with_target` method. This ensures that both forecasting and event splits use the same anchor points in time. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting, death/progression/metastases/next treatment for event).\n", "\n", - "Here we generate these random splits. We can also manually override them (see other examples on inference)." + "We can also manually override them (see other examples on inference)." ] }, { @@ -218,18 +225,12 @@ "metadata": {}, "outputs": [], "source": [ - "processed_splits_fc, split_dates = data_splitter_forecasting.get_splits_from_patient(\n", + "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n", " patient_data,\n", - " nr_samples_per_split=4,\n", - " filter_outliers=False,\n", - " include_metadata=True,\n", + " forecasting_nr_samples_per_split=4,\n", + " forecasting_filter_outliers=False,\n", " max_num_splits_per_split_event=2,\n", - ")\n", - "\n", - "processed_splits_ev = data_splitter_events.get_splits_from_patient(\n", - " patient_data,\n", - " reference_split_dates=split_dates,\n", - " max_nr_samples_per_split=3,\n", + " events_max_nr_samples_per_split=3,\n", ")" ] }, @@ -250,8 +251,8 @@ "source": [ "split_idx = 0\n", "p_converted = converter.forward_conversion(\n", - " forecasting_splits=processed_splits_fc[split_idx],\n", - " event_splits=processed_splits_ev[split_idx],\n", + " forecasting_splits=forecasting_splits[split_idx],\n", + " event_splits=events_splits[split_idx],\n", " override_mode_to_select_forecasting=\"forecasting_qa\",\n", ")" ] @@ -296,7 +297,7 @@ "metadata": {}, "outputs": [], "source": [ - "date = split_dates[\"date\"][0]\n", + "date = reference_dates[\"date\"][0]\n", "return_list = converter.reverse_conversion(p_converted[\"answer\"], dm, date)\n", "return_list[2][\"result\"]" ] From 2ace8de023edf318e9b2d6c66417e8008f1e42a6 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 12:25:35 +0000 Subject: [PATCH 05/36] Adjusted data splitting docs --- docs/data-splitting.md | 67 ++++++++++++++++--- docs/quickstart.md | 5 +- docs/tte-inference.md | 2 +- .../training_custom_split_events.ipynb | 5 +- .../training_forecasting_splitter_only.ipynb | 5 +- .../instruction/converter_instruction.py | 4 ++ 6 files changed, 73 insertions(+), 15 deletions(-) diff --git a/docs/data-splitting.md b/docs/data-splitting.md index 459260c..1b0a410 100644 --- a/docs/data-splitting.md +++ b/docs/data-splitting.md @@ -9,7 +9,7 @@ TwinWeaver provides specialized splitters for two complementary clinical predict | `DataSplitterForecasting` | Forecasting continuous or categorical variables | Predict hemoglobin values over the next 90 days | | `DataSplitterEvents` | Landmark event prediction (time-to-event) | Did the patient progress within 52 weeks? | -A unified `DataSplitter` interface combines both, ensuring they share the same split dates for multi-task training. +A unified `DataSplitter` interface combines one or both splitters into a single entry point. When both are supplied, it ensures they share the same split dates for multi-task training. Either splitter can also be used individually. --- @@ -171,14 +171,20 @@ config.data_splitter_events_variables_category_mapping = { ## Combined Splitting with `DataSplitter` -The `DataSplitter` class provides a unified interface that coordinates both splitters. This is the **recommended approach** for generating multi-task training data, as it ensures forecasting and event prediction tasks share the same split dates. +The `DataSplitter` class provides a unified interface that coordinates one or both splitters. At least one of `data_splitter_events` or `data_splitter_forecasting` must be provided. When both are supplied, it ensures they share the same split dates for multi-task training. When only one is supplied, the methods return `None` for the missing task type. -### Training Workflow +!!! tip "Single-task usage" + You don't need both splitters. Pass only `data_splitter_forecasting` for forecasting-only pipelines, or only `data_splitter_events` for event-prediction-only pipelines. See [Forecasting-Only](#forecasting-only) and [Events-Only](#events-only) below. + +### Training Workflow (Both Tasks) ```python from twinweaver import DataSplitter -data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting) +data_splitter = DataSplitter( + data_splitter_events=data_splitter_events, + data_splitter_forecasting=data_splitter_forecasting, +) # Generate aligned splits for both tasks forecasting_splits, events_splits, reference_dates = \ @@ -187,14 +193,49 @@ forecasting_splits, events_splits, reference_dates = \ Internally, `get_splits_from_patient_with_target`: -1. Calls `DataSplitterForecasting.get_splits_from_patient()` to determine split dates and generate forecasting tasks. -2. Passes those same split dates (`reference_dates`) to `DataSplitterEvents.get_splits_from_patient()` to generate aligned event prediction tasks. +1. Calls `DataSplitterForecasting.get_splits_from_patient()` (if available) to determine split dates and generate forecasting tasks. +2. Passes those same split dates (`reference_dates`) to `DataSplitterEvents.get_splits_from_patient()` (if available) to generate aligned event prediction tasks. +3. If only one splitter is provided, the other returns `None`. When only the events splitter is used, `reference_dates` are extracted from the generated event splits. + +This alignment is critical: when both task types are active, they see the same patient history up to the same point in time, enabling consistent multi-task learning. + +### Forecasting-Only + +```python +# Only forecasting — no event prediction splitter needed +data_splitter = DataSplitter(data_splitter_forecasting=data_splitter_forecasting) + +forecasting_splits, events_splits, reference_dates = \ + data_splitter.get_splits_from_patient_with_target(patient_data) +# events_splits is None + +converter.forward_conversion( + forecasting_splits=forecasting_splits[0], + event_splits=None, # No event splits available + override_mode_to_select_forecasting="forecasting", +) +``` -This alignment is critical: both task types see the same patient history up to the same point in time, enabling consistent multi-task learning. +### Events-Only + +```python +# Only event prediction — no forecasting splitter needed +data_splitter = DataSplitter(data_splitter_events=data_splitter_events) + +forecasting_splits, events_splits, reference_dates = \ + data_splitter.get_splits_from_patient_with_target(patient_data) +# forecasting_splits is None + +converter.forward_conversion( + forecasting_splits=None, # No forecasting splits available + event_splits=events_splits[0], + override_mode_to_select_forecasting="both", +) +``` ### Inference Workflow -For inference, use `get_splits_from_patient_inference`, which anchors the split at the **last available date** in the patient's record: +For inference, use `get_splits_from_patient_inference`, which anchors the split at the **last available date** in the patient's record. The `inference_type` parameter controls which tasks to generate — it defaults to `"both"` but gracefully handles the case when only one splitter is available: ```python forecast_split, events_split = data_splitter.get_splits_from_patient_inference( @@ -206,6 +247,9 @@ forecast_split, events_split = data_splitter.get_splits_from_patient_inference( ) ``` +!!! note + When `inference_type="both"` and only one splitter is provided, the missing task simply returns `None` without raising an error. If you request a specific `inference_type` (e.g., `"forecasting"`) but the corresponding splitter was not provided, a `ValueError` is raised. + --- ## How Multiple Training Examples Are Generated @@ -259,7 +303,10 @@ data_splitter_events.setup_variables() data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config) data_splitter_forecasting.setup_statistics() # Compute variable scores -data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting) +data_splitter = DataSplitter( + data_splitter_events=data_splitter_events, + data_splitter_forecasting=data_splitter_forecasting, +) # 4. Generate splits for a patient patient_data = dm.get_patient_data(dm.all_patientids[0]) @@ -290,4 +337,6 @@ print(result["answer"]) - **[Framework Overview](framework.md)**: Learn about TwinWeaver's architecture and task types - **[Data Preparation Tutorial](examples/01_data_preparation_for_training.ipynb)**: Step-by-step notebook walkthrough - **[Custom Splitting (Training)](examples/advanced/custom_splitting/training_individual_splitters.ipynb)**: Advanced splitting with individual splitters +- **[Forecasting-Only Splitting](examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb)**: Using `DataSplitter` with only the forecasting splitter +- **[Custom Split Events](examples/advanced/custom_splitting/training_custom_split_events.ipynb)**: Using `DataSplitter` with custom split events - **[API Reference — Data Splitters](reference/instruction/data_splitters.md)**: Full API documentation diff --git a/docs/quickstart.md b/docs/quickstart.md index 4b55545..a5ae1dd 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -63,7 +63,10 @@ data_splitter_forecasting = DataSplitterForecasting( ) # Combined interface for both task types -data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting) +data_splitter = DataSplitter( + data_splitter_events=data_splitter_events, + data_splitter_forecasting=data_splitter_forecasting, +) # Set up the text converter converter = ConverterInstruction( diff --git a/docs/tte-inference.md b/docs/tte-inference.md index d0428e1..f71a759 100644 --- a/docs/tte-inference.md +++ b/docs/tte-inference.md @@ -41,7 +41,7 @@ Patient data ──► DataSplitter (events) ──► ConverterInstruction ▼ compute_length_normalized_probabilities() │ - calibrated probabilities + probabilities + hard predictions (DataFrame) ``` diff --git a/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/examples/advanced/custom_splitting/training_custom_split_events.ipynb index acd176a..274ccb3 100644 --- a/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -213,7 +213,8 @@ "source": [ "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n", " patient_data,\n", - ")" + ")\n", + "# Note, forecasting_splits will be none here" ] }, { @@ -233,7 +234,7 @@ "source": [ "split_idx = 0\n", "p_converted = converter.forward_conversion(\n", - " forecasting_splits=forecasting_splits[split_idx],\n", + " forecasting_splits=None, # Set to None since we don't want to generate forecasting tasks\n", " event_splits=events_splits[split_idx],\n", " override_mode_to_select_forecasting=\"both\",\n", ")" diff --git a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb index 2a6e528..8abb9cd 100644 --- a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb +++ b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb @@ -219,7 +219,8 @@ " forecasting_nr_samples_per_split=4,\n", " forecasting_filter_outliers=False,\n", " max_num_splits_per_split_event=2,\n", - ")" + ")\n", + "# Note, events_splits will be None here since we don't have any split events for this patient" ] }, { @@ -240,7 +241,7 @@ "split_idx = 0\n", "p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", - " event_splits=[], # Not needed for forecasting-only splitter\n", + " event_splits=None, # Not needed for forecasting-only splitter\n", " override_mode_to_select_forecasting=\"forecasting\",\n", ")" ] diff --git a/twinweaver/instruction/converter_instruction.py b/twinweaver/instruction/converter_instruction.py index bc17130..4d20d1e 100644 --- a/twinweaver/instruction/converter_instruction.py +++ b/twinweaver/instruction/converter_instruction.py @@ -287,6 +287,10 @@ def forward_conversion( or if no tasks are generated. """ + # If events is None, set to empty list for easier processing + if event_splits is None: + event_splits = [] + #: make assertions that data has same split and lot date all_lot_dates_events = [x.lot_date for x in event_splits] all_lot_dates_forecasting = [x.lot_date for x in forecasting_splits] From 41d7cbd30026e322586e54c26f1f66b31f4d496b Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 12:29:29 +0000 Subject: [PATCH 06/36] Renamed allow_forecasting_beyond_next_split_date --- twinweaver/common/config.py | 10 ++++++---- twinweaver/instruction/data_splitter_forecasting.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index 593239c..9fe0606 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -59,9 +59,9 @@ class Config: Column name for the name or identifier of the line of therapy (e.g., "First Line"). Default: "lot". event_value_lot_start : str Specific string value used in `event_value_col` to denote the start of a line of therapy. Default: "LoT Start". - skip_future_lot_filtering : bool - Flag indicating whether to skip filtering out future line of therapy events. Default: False. - Useful in case you accidentially overlap LoTs which are actually the same, use with caution. + allow_forecasting_beyond_next_split_date : bool + Flag indicating whether to allow forecasting of events that occur beyond the next split date + (e.g., next LoT event). Default: False. lot_concatenate_descriptive_and_value : bool Flag indicating whether to concatenate the descriptive name and value for line of therapy events. Default: False. @@ -283,7 +283,9 @@ def __init__(self): self.split_date_col: str = "split_date" self.lot_event_name: str = "lot" self.event_value_lot_start: str = "LoT Start" - self.skip_future_lot_filtering: bool = False # Whether to skip filtering future LoT events, by default False. + self.allow_forecasting_beyond_next_split_date: bool = ( + False # Whether to skip filtering future LoT events, by default False. + ) self.lot_concatenate_descriptive_and_value: bool = ( False # If true, concatenate descriptive name and value for LoT events, by default False (only event_vale.) ) diff --git a/twinweaver/instruction/data_splitter_forecasting.py b/twinweaver/instruction/data_splitter_forecasting.py index 0059092..44f124a 100644 --- a/twinweaver/instruction/data_splitter_forecasting.py +++ b/twinweaver/instruction/data_splitter_forecasting.py @@ -864,7 +864,7 @@ def _generate_variable_splits_for_date( lots = events[events[self.config.event_category_col] == self.config.event_category_lot] lots = lots[lots[self.config.date_col] > curr_date] lots = lots.sort_values(self.config.date_col) - if lots.shape[0] > 0 and not self.config.skip_future_lot_filtering: + if lots.shape[0] > 0 and not self.config.allow_forecasting_beyond_next_split_date: date_of_next_lot = lots[self.config.date_col].iloc[0] events_after_split = events_after_split[events_after_split[self.config.date_col] < date_of_next_lot] From 8a28591737af9a744c1cc587dcfdbe1d35b5abb4 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 12:39:17 +0000 Subject: [PATCH 07/36] Fixed hard-coded LoT for splitting --- tests/test_splitter.py | 271 ++++++++++++++++++ .../instruction/data_splitter_forecasting.py | 16 +- 2 files changed, 280 insertions(+), 7 deletions(-) diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 266da19..3869554 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -1,4 +1,5 @@ import pytest +import numpy as np import pandas as pd from twinweaver.common.data_manager import DataManager from twinweaver.instruction.data_splitter_events import DataSplitterEvents @@ -414,3 +415,273 @@ def test_inference_both_type_with_only_events(initialized_dm, mock_config): # Constant data integrity assert e_split.constant_data["patientid"].iloc[0] == "p0" assert e_split.constant_data.shape[0] == 1 + + +# ──────────────────────────────────────────────────────────────────────────── +# Test that forecasting truncation uses split_event_category (not just LoT) +# ──────────────────────────────────────────────────────────────────────────── + + +def test_forecasting_truncates_at_next_split_event_not_just_lot(): + """ + Verify that _generate_variable_splits_for_date truncates target events + at the next *split event* (config.split_event_category), not only at + the next LoT event (config.event_category_lot). + + Scenario (split_event_category = "custom_split"): + Timeline for a single patient: + Day 0 - custom_split event (the split event that anchors the window) + Day 5 - lab measurement (before split date, input) + Day 10 - split date (curr_date) + Day 15 - lab measurement (target - should be kept) + Day 20 - next custom_split (next split event - target boundary) + Day 25 - lab measurement (target - should be EXCLUDED) + Day 30 - lot event (LoT - should NOT be the boundary) + Day 35 - lab measurement (target - should be EXCLUDED) + + With the old code (filtering by event_category_lot), the target would + include days 15, 25, and 35 (cutting only at day 30 LoT). + With the fix (filtering by split_event_category), the target should + include only day 15 (cutting at day 20 custom_split). + """ + from twinweaver.common.config import Config + + cfg = Config() + cfg.split_event_category = "custom_split" + cfg.event_category_lot = "lot" + cfg.event_category_forecast = ["lab"] + cfg.allow_forecasting_beyond_next_split_date = False + + base_date = pd.Timestamp("2020-01-01") + + events = pd.DataFrame( + { + cfg.date_col: [base_date + pd.Timedelta(days=d) for d in [0, 5, 10, 15, 20, 25, 30, 35]], + cfg.event_category_col: [ + "custom_split", # Day 0: split event + "lab", # Day 5: lab (input) + "lab", # Day 10: lab at split date (input) + "lab", # Day 15: lab (target, before next split event) + "custom_split", # Day 20: next split event + "lab", # Day 25: lab (target, after next split event - exclude) + "lot", # Day 30: lot event (should NOT be the boundary) + "lab", # Day 35: lab (target, after lot - exclude) + ], + cfg.event_name_col: [ + "split_marker", + "hemoglobin", + "hemoglobin", + "hemoglobin", + "split_marker", + "hemoglobin", + "lot_marker", + "hemoglobin", + ], + cfg.event_value_col: [ + "start", + "13.0", + "13.1", + "13.2", + "start", + "13.3", + "LoT Start", + "13.4", + ], + cfg.event_descriptive_name_col: [ + "split marker", + "hemoglobin", + "hemoglobin", + "hemoglobin", + "split marker", + "hemoglobin", + "LoT", + "hemoglobin", + ], + cfg.source_col: ["events"] * 8, + cfg.meta_data_col: [pd.NA] * 8, + } + ) + + constant = pd.DataFrame( + { + cfg.patient_id_col: ["p_test"], + cfg.constant_split_col: ["train"], + } + ) + + patient_data = {"events": events, "constant": constant} + curr_date = base_date + pd.Timedelta(days=10) + lot_date = base_date # The split event that anchors this window + + # Build a minimal all_possible_split_dates with hemoglobin valid at curr_date + all_possible_split_dates = pd.DataFrame( + { + cfg.date_col: [curr_date], + cfg.event_name_col: ["hemoglobin"], + cfg.event_category_col: ["lab"], + "lot_date": [lot_date], + } + ) + + # Create a DataManager stub - we only need dm.variable_types for the splitter + dm = DataManager.__new__(DataManager) + dm.config = cfg + dm.variable_types = {"hemoglobin": "numeric"} + dm.data_frames = {} + dm.all_patientids = ["p_test"] + + splitter = DataSplitterForecasting( + config=cfg, + data_manager=dm, + max_forecast_time_for_value=pd.Timedelta(days=90), + max_lookback_time_for_value=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + sampling_strategy="uniform", + ) + + np.random.seed(42) + + (date_splits, valid_sample_date, date_splits_meta, _) = splitter._generate_variable_splits_for_date( + curr_date=curr_date, + nr_samples=1, + override_variables_to_predict=["hemoglobin"], + events=events, + all_possible_split_dates=all_possible_split_dates, + apply_filtering=False, + override_split_dates=None, + patient_data=patient_data, + lot_date=lot_date, + ) + + assert valid_sample_date is True + assert len(date_splits) == 1 + + target = date_splits[0].target_events_after_split + + # The target must only include lab events BEFORE the next custom_split (day 20). + # So only the measurement at day 15 should survive. + assert target.shape[0] == 1, ( + f"Expected 1 target event (day 15 only), got {target.shape[0]}. " + f"Dates in target: {target[cfg.date_col].tolist()}" + ) + assert target[cfg.date_col].iloc[0] == base_date + pd.Timedelta(days=15) + + # Also verify input events include everything up to and including curr_date + input_events = date_splits[0].events_until_split + assert (input_events[cfg.date_col] <= curr_date).all() + assert input_events.shape[0] == 3 # Day 0, 5, 10 + + +def test_forecasting_truncation_allow_beyond_next_split_date(): + """ + Verify that when allow_forecasting_beyond_next_split_date=True, + target events are NOT truncated at the next split event. + """ + from twinweaver.common.config import Config + + cfg = Config() + cfg.split_event_category = "custom_split" + cfg.event_category_lot = "lot" + cfg.event_category_forecast = ["lab"] + cfg.allow_forecasting_beyond_next_split_date = True + + base_date = pd.Timestamp("2020-01-01") + + events = pd.DataFrame( + { + cfg.date_col: [base_date + pd.Timedelta(days=d) for d in [0, 5, 10, 15, 20, 25]], + cfg.event_category_col: [ + "custom_split", + "lab", + "lab", + "lab", + "custom_split", + "lab", + ], + cfg.event_name_col: [ + "split_marker", + "hemoglobin", + "hemoglobin", + "hemoglobin", + "split_marker", + "hemoglobin", + ], + cfg.event_value_col: ["start", "13.0", "13.1", "13.2", "start", "13.3"], + cfg.event_descriptive_name_col: [ + "split marker", + "hemoglobin", + "hemoglobin", + "hemoglobin", + "split marker", + "hemoglobin", + ], + cfg.source_col: ["events"] * 6, + cfg.meta_data_col: [pd.NA] * 6, + } + ) + + constant = pd.DataFrame( + { + cfg.patient_id_col: ["p_test"], + cfg.constant_split_col: ["train"], + } + ) + + patient_data = {"events": events, "constant": constant} + curr_date = base_date + pd.Timedelta(days=10) + lot_date = base_date + + all_possible_split_dates = pd.DataFrame( + { + cfg.date_col: [curr_date], + cfg.event_name_col: ["hemoglobin"], + cfg.event_category_col: ["lab"], + "lot_date": [lot_date], + } + ) + + dm = DataManager.__new__(DataManager) + dm.config = cfg + dm.variable_types = {"hemoglobin": "numeric"} + dm.data_frames = {} + dm.all_patientids = ["p_test"] + + splitter = DataSplitterForecasting( + config=cfg, + data_manager=dm, + max_forecast_time_for_value=pd.Timedelta(days=90), + max_lookback_time_for_value=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + sampling_strategy="uniform", + ) + + np.random.seed(42) + + (date_splits, valid_sample_date, _, _) = splitter._generate_variable_splits_for_date( + curr_date=curr_date, + nr_samples=1, + override_variables_to_predict=["hemoglobin"], + events=events, + all_possible_split_dates=all_possible_split_dates, + apply_filtering=False, + override_split_dates=None, + patient_data=patient_data, + lot_date=lot_date, + ) + + assert valid_sample_date is True + assert len(date_splits) == 1 + + target = date_splits[0].target_events_after_split + + # With allow_forecasting_beyond_next_split_date=True, no truncation at + # the next split event, so both day 15 and day 25 labs should be in target. + assert target.shape[0] == 2, ( + f"Expected 2 target events (days 15 and 25), got {target.shape[0]}. " + f"Dates in target: {target[cfg.date_col].tolist()}" + ) + expected_dates = [ + base_date + pd.Timedelta(days=15), + base_date + pd.Timedelta(days=25), + ] + assert target[cfg.date_col].tolist() == expected_dates diff --git a/twinweaver/instruction/data_splitter_forecasting.py b/twinweaver/instruction/data_splitter_forecasting.py index 44f124a..2eceb06 100644 --- a/twinweaver/instruction/data_splitter_forecasting.py +++ b/twinweaver/instruction/data_splitter_forecasting.py @@ -860,13 +860,15 @@ def _generate_variable_splits_for_date( events_after_split[self.config.event_name_col].isin(sampled_variables) ] - #: filter so that we do not overlap with next LoT, since that will invalidate the results - lots = events[events[self.config.event_category_col] == self.config.event_category_lot] - lots = lots[lots[self.config.date_col] > curr_date] - lots = lots.sort_values(self.config.date_col) - if lots.shape[0] > 0 and not self.config.allow_forecasting_beyond_next_split_date: - date_of_next_lot = lots[self.config.date_col].iloc[0] - events_after_split = events_after_split[events_after_split[self.config.date_col] < date_of_next_lot] + #: filter so that we do not overlap with next split event, since that will invalidate the results + next_split_events = events[events[self.config.event_category_col] == self.config.split_event_category] + next_split_events = next_split_events[next_split_events[self.config.date_col] > curr_date] + next_split_events = next_split_events.sort_values(self.config.date_col) + if next_split_events.shape[0] > 0 and not self.config.allow_forecasting_beyond_next_split_date: + date_of_next_split_event = next_split_events[self.config.date_col].iloc[0] + events_after_split = events_after_split[ + events_after_split[self.config.date_col] < date_of_next_split_event + ] #: if apply_filtering, apply 3-sigma filtering (only to target) and drop any bad rows if apply_filtering: From 36141b1df2436366b57071fd0dd1c5bf15c284ad Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 12:43:37 +0000 Subject: [PATCH 08/36] allow_forecasting_beyond_next_split_date to init of DSF --- tests/test_splitter.py | 3 +-- twinweaver/common/config.py | 7 +------ twinweaver/instruction/data_splitter_forecasting.py | 7 ++++++- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 3869554..6cac6a9 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -450,7 +450,6 @@ def test_forecasting_truncates_at_next_split_event_not_just_lot(): cfg.split_event_category = "custom_split" cfg.event_category_lot = "lot" cfg.event_category_forecast = ["lab"] - cfg.allow_forecasting_beyond_next_split_date = False base_date = pd.Timestamp("2020-01-01") @@ -583,7 +582,6 @@ def test_forecasting_truncation_allow_beyond_next_split_date(): cfg.split_event_category = "custom_split" cfg.event_category_lot = "lot" cfg.event_category_forecast = ["lab"] - cfg.allow_forecasting_beyond_next_split_date = True base_date = pd.Timestamp("2020-01-01") @@ -653,6 +651,7 @@ def test_forecasting_truncation_allow_beyond_next_split_date(): max_lookback_time_for_value=pd.Timedelta(days=90), max_split_length_after_split_event=pd.Timedelta(days=90), sampling_strategy="uniform", + allow_forecasting_beyond_next_split_date=True, ) np.random.seed(42) diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index 9fe0606..87a0e7d 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -59,9 +59,6 @@ class Config: Column name for the name or identifier of the line of therapy (e.g., "First Line"). Default: "lot". event_value_lot_start : str Specific string value used in `event_value_col` to denote the start of a line of therapy. Default: "LoT Start". - allow_forecasting_beyond_next_split_date : bool - Flag indicating whether to allow forecasting of events that occur beyond the next split date - (e.g., next LoT event). Default: False. lot_concatenate_descriptive_and_value : bool Flag indicating whether to concatenate the descriptive name and value for line of therapy events. Default: False. @@ -283,9 +280,7 @@ def __init__(self): self.split_date_col: str = "split_date" self.lot_event_name: str = "lot" self.event_value_lot_start: str = "LoT Start" - self.allow_forecasting_beyond_next_split_date: bool = ( - False # Whether to skip filtering future LoT events, by default False. - ) + self.lot_concatenate_descriptive_and_value: bool = ( False # If true, concatenate descriptive name and value for LoT events, by default False (only event_vale.) ) diff --git a/twinweaver/instruction/data_splitter_forecasting.py b/twinweaver/instruction/data_splitter_forecasting.py index 2eceb06..d827901 100644 --- a/twinweaver/instruction/data_splitter_forecasting.py +++ b/twinweaver/instruction/data_splitter_forecasting.py @@ -126,6 +126,7 @@ def __init__( max_nr_variables_to_sample: int = 3, filtering_strategy: str = "3-sigma", sampling_strategy: str = "proportional", + allow_forecasting_beyond_next_split_date: bool = False, ): """ Initializes the DataSplitterForecasting instance. @@ -174,6 +175,9 @@ def __init__( sampling_strategy : str The strategy for sampling variables ('proportional' or 'uniform'). Defaults to 'proportional'. + allow_forecasting_beyond_next_split_date : bool + Flag indicating whether to allow forecasting of events that occur beyond the next split date + (e.g., next LoT event). Default: False. """ super().__init__( data_manager, @@ -204,6 +208,7 @@ def __init__( self.max_nr_variables_to_sample = max_nr_variables_to_sample self.filtering_strategy = filtering_strategy self.sampling_strategy = sampling_strategy + self.allow_forecasting_beyond_next_split_date = allow_forecasting_beyond_next_split_date self._filtering_methods = {"3-sigma": self._filter_3_sigma} @@ -864,7 +869,7 @@ def _generate_variable_splits_for_date( next_split_events = events[events[self.config.event_category_col] == self.config.split_event_category] next_split_events = next_split_events[next_split_events[self.config.date_col] > curr_date] next_split_events = next_split_events.sort_values(self.config.date_col) - if next_split_events.shape[0] > 0 and not self.config.allow_forecasting_beyond_next_split_date: + if next_split_events.shape[0] > 0 and not self.allow_forecasting_beyond_next_split_date: date_of_next_split_event = next_split_events[self.config.date_col].iloc[0] events_after_split = events_after_split[ events_after_split[self.config.date_col] < date_of_next_split_event From d7da5dc53f493dd3466952aa68043248f48145a0 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 12:47:35 +0000 Subject: [PATCH 09/36] Removed lot_event_name and event_value_lot_start dependencies --- .../examples/integrations/meds_data_import.ipynb | 1 - examples/integrations/meds_data_import.ipynb | 1 - twinweaver/common/config.py | 6 ------ twinweaver/common/converter_base.py | 16 ++++++++++------ 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/docs/examples/integrations/meds_data_import.ipynb b/docs/examples/integrations/meds_data_import.ipynb index c7a2057..8bedbda 100644 --- a/docs/examples/integrations/meds_data_import.ipynb +++ b/docs/examples/integrations/meds_data_import.ipynb @@ -475,7 +475,6 @@ "config = Config() # Override values here to customize pipeline\n", "config.constant_columns_to_use = constant_columns\n", "config.constant_birthdate_column = None # Not using in demo\n", - "config.lot_event_name = None # Setting for LoTs\n", "config.event_value_lot_start = None\n", "config.split_event_category = \"lot\"\n", "config.data_splitter_events_variables_category_mapping = {\n", diff --git a/examples/integrations/meds_data_import.ipynb b/examples/integrations/meds_data_import.ipynb index a93ae38..076e5eb 100644 --- a/examples/integrations/meds_data_import.ipynb +++ b/examples/integrations/meds_data_import.ipynb @@ -475,7 +475,6 @@ "config = Config() # Override values here to customize pipeline\n", "config.constant_columns_to_use = constant_columns\n", "config.constant_birthdate_column = None # Not using in demo\n", - "config.lot_event_name = None # Setting for LoTs\n", "config.event_value_lot_start = None\n", "config.split_event_category = \"lot\"\n", "config.data_splitter_events_variables_category_mapping = {\n", diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index 87a0e7d..78b9cb8 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -55,10 +55,6 @@ class Config: Default value to assign to `source_col` if it is missing. Default: "events". split_date_col : str Column name specifically used for dates related to line of therapy (LoT) events. Default: "lot_date". - lot_event_name : str - Column name for the name or identifier of the line of therapy (e.g., "First Line"). Default: "lot". - event_value_lot_start : str - Specific string value used in `event_value_col` to denote the start of a line of therapy. Default: "LoT Start". lot_concatenate_descriptive_and_value : bool Flag indicating whether to concatenate the descriptive name and value for line of therapy events. Default: False. @@ -278,8 +274,6 @@ def __init__(self): self.event_meta_default_value = pd.NA # Default value for event meta data if not present self.source_col_default_value: str = "events" # Default value for source column if not present self.split_date_col: str = "split_date" - self.lot_event_name: str = "lot" - self.event_value_lot_start: str = "LoT Start" self.lot_concatenate_descriptive_and_value: bool = ( False # If true, concatenate descriptive name and value for LoT events, by default False (only event_vale.) diff --git a/twinweaver/common/converter_base.py b/twinweaver/common/converter_base.py index d90e04f..03096bf 100644 --- a/twinweaver/common/converter_base.py +++ b/twinweaver/common/converter_base.py @@ -1049,7 +1049,13 @@ def _get_all_most_recent_events_within_budget(self, events: pd.DataFrame, budget #: return events return events_final.reset_index(drop=True) # Reset index for clean output - def _generate_summarized_row_string(self, input_event_data, combined_target_meta: dict) -> str: + def _generate_summarized_row_string( + self, + input_event_data, + combined_target_meta: dict, + lot_event_name: str = "lot", + event_value_lot_start: str = "LoT Start", + ) -> str: """ Creates a summary string containing the most recent genetic, LoT, and target variable values. @@ -1131,12 +1137,10 @@ def _generate_summarized_row_string(self, input_event_data, combined_target_meta lot_info = lot_info.sort_values(self.config.date_col) # Create selections based on event name and event value using config constants - if self.config.lot_event_name is not None and self.config.event_value_lot_start is not None: - lot_selection_1 = lot_info[ - lot_info[self.config.event_name_col] == self.config.lot_event_name - ] # Using config attribute + if lot_event_name is not None and event_value_lot_start is not None: + lot_selection_1 = lot_info[lot_info[self.config.event_name_col] == lot_event_name] # Using config attribute lot_selection_2 = lot_info[ - lot_info[self.config.event_value_col] == self.config.event_value_lot_start + lot_info[self.config.event_value_col] == event_value_lot_start ] # Using config attribute else: # Just use all lot_info if no specific columns are defined From 8dfa4cc4e6ca107468d5bc067a953d29f12385d8 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 13:02:54 +0000 Subject: [PATCH 10/36] Removed dependencies on event_category_lots --- .pre-commit-config.yaml | 6 ++-- docs/data-splitting.md | 2 +- .../03_end_to_end_llm_finetuning.ipynb | 4 +-- ...nd_to_end_llm_training_with_pretrain.ipynb | 4 +-- examples/03_end_to_end_llm_finetuning.ipynb | 4 +-- ...nd_to_end_llm_training_with_pretrain.ipynb | 4 +-- .../02_llm_finetuning_challenge.ipynb | 4 +-- tests/test_common.py | 1 - tests/test_splitter.py | 9 ++---- twinweaver/common/config.py | 4 --- twinweaver/common/converter_base.py | 3 +- twinweaver/common/data_manager.py | 21 ++----------- .../instruction/data_splitter_events.py | 30 +++++++++++-------- .../instruction/data_splitter_forecasting.py | 4 +-- 14 files changed, 38 insertions(+), 62 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a57108f..ac4e937 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,5 @@ +exclude: ^examples/hackathon + repos: # 1. Standard "Cleanup" Hooks - repo: https://github.com/pre-commit/pre-commit-hooks @@ -11,7 +13,7 @@ repos: # 2. Ruff (Linting + Formatting) - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.3 + rev: v0.9.3 hooks: - id: ruff args: [ --fix ] @@ -21,4 +23,4 @@ repos: - repo: https://github.com/kynan/nbstripout rev: 0.8.1 hooks: - - id: nbstripout \ No newline at end of file + - id: nbstripout diff --git a/docs/data-splitting.md b/docs/data-splitting.md index 1b0a410..a207f79 100644 --- a/docs/data-splitting.md +++ b/docs/data-splitting.md @@ -124,7 +124,7 @@ flowchart TD D --> E{Event occurred
within window
and before censoring event?} E -->|Yes| F[occurred = True] E -->|No| G{Censored by
next LoT or data end?} - G -->|Next LoT| H[censored = new_therapy_start] + G -->|Next LoT| H[censored = new_split_date_start] G -->|End of data| I[censored = end_of_data] G -->|No censoring| J[censored = None
Event truly did not occur] F --> K[Create DataSplitterEventsOption] diff --git a/docs/examples/03_end_to_end_llm_finetuning.ipynb b/docs/examples/03_end_to_end_llm_finetuning.ipynb index 6a2901f..3e91776 100644 --- a/docs/examples/03_end_to_end_llm_finetuning.ipynb +++ b/docs/examples/03_end_to_end_llm_finetuning.ipynb @@ -490,9 +490,7 @@ "# Lets simulate forecasts for after the first line of therapy\n", "df_constant_patient = patient_data[\"constant\"].copy()\n", "df_events_patient = patient_data[\"events\"].copy()\n", - "date_of_first_lot = df_events_patient.loc[\n", - " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n", - "].min()\n", + "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n", "\n", "# Only keep data until (and including) first line of therapy\n", "df_events_patient = df_events_patient.loc[df_events_patient[\"date\"] <= date_of_first_lot]" diff --git a/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb b/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb index 9c7fd54..0ac8652 100644 --- a/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb +++ b/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb @@ -401,9 +401,7 @@ "# Lets simulate forecasts for after the first line of therapy\n", "df_constant_patient = patient_data[\"constant\"].copy()\n", "df_events_patient = patient_data[\"events\"].copy()\n", - "date_of_first_lot = df_events_patient.loc[\n", - " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n", - "].min()\n", + "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n", "date_of_first_event = df_events_patient[\"date\"].min()\n", "\n", "# Only keep data until (and including) first line of therapy\n", diff --git a/examples/03_end_to_end_llm_finetuning.ipynb b/examples/03_end_to_end_llm_finetuning.ipynb index 6a2901f..3e91776 100644 --- a/examples/03_end_to_end_llm_finetuning.ipynb +++ b/examples/03_end_to_end_llm_finetuning.ipynb @@ -490,9 +490,7 @@ "# Lets simulate forecasts for after the first line of therapy\n", "df_constant_patient = patient_data[\"constant\"].copy()\n", "df_events_patient = patient_data[\"events\"].copy()\n", - "date_of_first_lot = df_events_patient.loc[\n", - " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n", - "].min()\n", + "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n", "\n", "# Only keep data until (and including) first line of therapy\n", "df_events_patient = df_events_patient.loc[df_events_patient[\"date\"] <= date_of_first_lot]" diff --git a/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb b/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb index 9c7fd54..0ac8652 100644 --- a/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb +++ b/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb @@ -401,9 +401,7 @@ "# Lets simulate forecasts for after the first line of therapy\n", "df_constant_patient = patient_data[\"constant\"].copy()\n", "df_events_patient = patient_data[\"events\"].copy()\n", - "date_of_first_lot = df_events_patient.loc[\n", - " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n", - "].min()\n", + "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n", "date_of_first_event = df_events_patient[\"date\"].min()\n", "\n", "# Only keep data until (and including) first line of therapy\n", diff --git a/examples/hackathon/02_llm_finetuning_challenge.ipynb b/examples/hackathon/02_llm_finetuning_challenge.ipynb index e45d570..28f1df1 100644 --- a/examples/hackathon/02_llm_finetuning_challenge.ipynb +++ b/examples/hackathon/02_llm_finetuning_challenge.ipynb @@ -776,9 +776,7 @@ "\n", "# Get the date of first line of therapy\n", "df_events_patient = patient_data[\"events\"].copy()\n", - "date_of_first_lot = df_events_patient.loc[\n", - " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n", - "].min()\n", + "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n", "\n", "print(f\"Test patient: {test_patientid}\")\n", "print(f\"First LoT date: {date_of_first_lot}\")" diff --git a/tests/test_common.py b/tests/test_common.py index ada745b..93d75e6 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -7,7 +7,6 @@ def test_config_initialization(mock_config): assert mock_config.seed == 42 assert mock_config.patient_id_col == "patientid" # Verify defaults used in the library - assert mock_config.event_category_lot == "lot" def test_data_manager_loading(mock_config, sample_data): diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 6cac6a9..548a1df 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -237,7 +237,7 @@ def test_training_events_only(initialized_dm, mock_config): # Event outcome must be boolean assert isinstance(e_split.event_occurred, bool) # Censoring should be None or one of the known censoring types - assert e_split.event_censored in [None, "new_therapy_start", "end_of_data", "data_cutoff"] + assert e_split.event_censored in [None, "new_split_date_start", "end_of_data", "data_cutoff"] # Observation end date must be after the split date assert e_split.observation_end_date >= e_split.split_date_included_in_input @@ -425,8 +425,7 @@ def test_inference_both_type_with_only_events(initialized_dm, mock_config): def test_forecasting_truncates_at_next_split_event_not_just_lot(): """ Verify that _generate_variable_splits_for_date truncates target events - at the next *split event* (config.split_event_category), not only at - the next LoT event (config.event_category_lot). + at the next *split event* (config.split_event_category). Scenario (split_event_category = "custom_split"): Timeline for a single patient: @@ -439,7 +438,7 @@ def test_forecasting_truncates_at_next_split_event_not_just_lot(): Day 30 - lot event (LoT - should NOT be the boundary) Day 35 - lab measurement (target - should be EXCLUDED) - With the old code (filtering by event_category_lot), the target would + With the old code, the target would include days 15, 25, and 35 (cutting only at day 30 LoT). With the fix (filtering by split_event_category), the target should include only day 15 (cutting at day 20 custom_split). @@ -448,7 +447,6 @@ def test_forecasting_truncates_at_next_split_event_not_just_lot(): cfg = Config() cfg.split_event_category = "custom_split" - cfg.event_category_lot = "lot" cfg.event_category_forecast = ["lab"] base_date = pd.Timestamp("2020-01-01") @@ -580,7 +578,6 @@ def test_forecasting_truncation_allow_beyond_next_split_date(): cfg = Config() cfg.split_event_category = "custom_split" - cfg.event_category_lot = "lot" cfg.event_category_forecast = ["lab"] base_date = pd.Timestamp("2020-01-01") diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index 78b9cb8..7bd9aa4 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -63,8 +63,6 @@ class Config: `lot_concatenate_descriptive_and_value` is True. Default: " - ". warning_for_splitters_patient_without_splits : bool Whether to warn if a patient has no split events. Default: True. - event_category_lot : str - Specific string value used in `event_category_col` to identify 'line of therapy' events. Default: "lot". event_category_death : str Specific string value used in `event_category_col` to identify 'death' events. Default: "death". event_category_labs : str @@ -291,9 +289,7 @@ def __init__(self): self.event_categories_to_exclude_from_input: list = [] # --- Specific Event Categories / Values / Sources --- - self.event_category_lot: str = "lot" self.event_category_death: str = "death" - self.event_category_labs: str = "lab" self.source_genetic: str = "genetic" self.genetic_skip_text_value: str = "present" diff --git a/twinweaver/common/converter_base.py b/twinweaver/common/converter_base.py index 03096bf..c3eb494 100644 --- a/twinweaver/common/converter_base.py +++ b/twinweaver/common/converter_base.py @@ -1055,6 +1055,7 @@ def _generate_summarized_row_string( combined_target_meta: dict, lot_event_name: str = "lot", event_value_lot_start: str = "LoT Start", + event_category_lot: str = "lot", ) -> str: """ Creates a summary string containing the most recent genetic, LoT, and target variable values. @@ -1131,7 +1132,7 @@ def _generate_summarized_row_string( #: add most recent LoT info using config constants ret_prompt += self.config.forecasting_prompt_summarized_lot # Using config attribute - lot_info = input_event_data[input_event_data[self.config.event_category_col] == self.config.event_category_lot] + lot_info = input_event_data[input_event_data[self.config.event_category_col] == event_category_lot] # Ensure lot_info is sorted by date to correctly find the last one lot_info = lot_info.sort_values(self.config.date_col) diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py index 05ee274..d1bca65 100644 --- a/twinweaver/common/data_manager.py +++ b/twinweaver/common/data_manager.py @@ -26,7 +26,7 @@ def __init__( validation_split_max: float = 0.1, test_split_max: float = 0.1, max_val_test_nr_patients: int = 500, - replace_special_symbols_override: list = None, + replace_special_symbols: list = None, ) -> None: """ Initializes the DataManager for a specific indication. @@ -54,7 +54,7 @@ def __init__( max_val_test_nr_patients : int, optional The absolute maximum number of patients to include in the validation and test sets combined. Defaults to 500. - replace_special_symbols_override : list, optional + replace_special_symbols : list, optional A list of tuples to override the default special character replacements in event descriptive names. Each tuple should be in the format `(event_category, (string_to_replace, replacement_string))`. If None, @@ -70,22 +70,7 @@ def __init__( self.variable_types = {} # event_name -> "numeric" / "categorical" # Setup replacing of special symbol, format is event_category : (, ) - if replace_special_symbols_override is not None: - self.replace_special_symbols = replace_special_symbols_override - else: - # Use config constants for event categories where available - self.replace_special_symbols = [ - (self.config.event_category_labs, ("/", " per ")), - (self.config.event_category_labs, (".", " ")), - ( - "drug", - ("/", " "), - ), # "drug" category not explicitly in Config constants provided - ( - self.config.event_category_lot, - ("/", " "), - ), # Use config for 'lot' category - ] + self.replace_special_symbols = replace_special_symbols if replace_special_symbols is not None else [] # Setup indication self.data_frames = None diff --git a/twinweaver/instruction/data_splitter_events.py b/twinweaver/instruction/data_splitter_events.py index 523238e..beeb96d 100644 --- a/twinweaver/instruction/data_splitter_events.py +++ b/twinweaver/instruction/data_splitter_events.py @@ -330,8 +330,8 @@ def get_splits_from_patient( # Do some quick sanity checks if self.config.warning_for_splitters_patient_without_splits: - lot_events = events[events[self.config.event_category_col] == self.config.event_category_lot] - if lot_events.shape[0] == 0: + split_events = events[events[self.config.event_category_col] == self.config.split_event_category] + if split_events.shape[0] == 0: logging.warning( "Patient " + str(patient_data["constant"][self.config.patient_id_col].iloc[0]) @@ -427,15 +427,18 @@ def get_splits_from_patient( diagnosis_after_split = events_limited_after_split[ events_limited_after_split[self.config.event_category_col] == sampled_cateogry ] - lot_after_split = events_limited_after_split[ - events_limited_after_split[self.config.event_category_col] == self.config.event_category_lot + next_split_date_after_split = events_limited_after_split[ + events_limited_after_split[self.config.event_category_col] == self.config.split_event_category ] death_after_split = events_limited_after_split[ events_limited_after_split[self.config.event_name_col] == self.config.event_category_death ] - #: apply censoring using next_lot_date - next_lot_date = lot_after_split[self.config.date_col].min() if len(lot_after_split) > 0 else None + #: apply censoring using next_split_date + if len(next_split_date_after_split) > 0: + next_split_date = next_split_date_after_split[self.config.date_col].min() + else: + next_split_date = None next_death_date = death_after_split[self.config.date_col].min() if len(death_after_split) > 0 else None #: determine whether occurred, censored & if so, which date @@ -447,18 +450,21 @@ def get_splits_from_patient( # Event occurred within end date occurred = True - # If an lot occurred first though, then we're censored - if next_lot_date is not None and diagnosis_after_split[self.config.date_col].min() > next_lot_date: - censored = "new_therapy_start" + # If a split occurred first though, then we're censored + if ( + next_split_date is not None + and diagnosis_after_split[self.config.date_col].min() > next_split_date + ): + censored = "new_split_date_start" occurred = False else: # Event did not occur occurred = False - if next_lot_date is not None: - # If we were censored by the next lot date - censored = "new_therapy_start" + if next_split_date is not None: + # If we were censored by the next split date + censored = "new_split_date_start" elif next_death_date is not None: # If death occurred then not censored, since this is the only time we diff --git a/twinweaver/instruction/data_splitter_forecasting.py b/twinweaver/instruction/data_splitter_forecasting.py index d827901..808ea2e 100644 --- a/twinweaver/instruction/data_splitter_forecasting.py +++ b/twinweaver/instruction/data_splitter_forecasting.py @@ -980,8 +980,8 @@ def get_splits_from_patient( # Do some quick sanity checks if self.config.warning_for_splitters_patient_without_splits: - lot_events = events[events[self.config.event_category_col] == self.config.event_category_lot] - if lot_events.shape[0] == 0: + split_events = events[events[self.config.event_category_col] == self.config.split_event_category] + if split_events.shape[0] == 0: logging.warning( "Patient " + str(patient_data["constant"][self.config.patient_id_col].iloc[0]) From bc61f78d4701f3aa99b78dcf3c3c983b285251ef Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 13:04:04 +0000 Subject: [PATCH 11/36] Removed dependency on event_category_labs --- twinweaver/common/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index 7bd9aa4..eebd84f 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -65,8 +65,6 @@ class Config: Whether to warn if a patient has no split events. Default: True. event_category_death : str Specific string value used in `event_category_col` to identify 'death' events. Default: "death". - event_category_labs : str - Specific string value used in `event_category_col` to identify 'lab result' events. Default: "lab". event_category_forecast : list[str] | None List of event categories to be considered for forecasting tasks. Default: None. split_event_category : str | None From 159bc90443c03e342ef0969c3515b462514e2aa0 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 13:05:21 +0000 Subject: [PATCH 12/36] Better docs on death --- twinweaver/common/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index eebd84f..e871931 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -64,6 +64,8 @@ class Config: warning_for_splitters_patient_without_splits : bool Whether to warn if a patient has no split events. Default: True. event_category_death : str + Important for censoring in TTE tasks and for identifying death events in general, since in medicine + they are common and critical events. Specific string value used in `event_category_col` to identify 'death' events. Default: "death". event_category_forecast : list[str] | None List of event categories to be considered for forecasting tasks. Default: None. From a42bf3468eb1a7ce001188a68e70a4bb35dc7292 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 13:40:23 +0000 Subject: [PATCH 13/36] Adjusted default values for data splitter forecasting and events --- docs/data-splitting.md | 16 ++--- .../01_data_preparation_for_training.ipynb | 8 ++- .../02_inference_prompt_preparation.ipynb | 8 ++- .../03_end_to_end_llm_finetuning.ipynb | 8 ++- .../custom_output/custom_summarized_row.ipynb | 13 +++- .../customizing_text_generation.ipynb | 26 +++++-- .../inference_individual_splitters.py | 13 +++- .../training_custom_split_events.ipynb | 8 ++- .../training_forecasting_splitter_only.ipynb | 1 + .../training_individual_splitters.ipynb | 8 ++- .../tte_probability_inference.ipynb | 13 +++- .../raw_data_preprocessing.ipynb | 8 ++- .../integrations/meds_data_import.ipynb | 7 +- .../01_data_preparation_for_training.ipynb | 8 ++- .../02_inference_prompt_preparation.ipynb | 8 ++- examples/03_end_to_end_llm_finetuning.ipynb | 8 ++- .../custom_output/custom_summarized_row.ipynb | 13 +++- .../customizing_text_generation.ipynb | 26 +++++-- .../inference_individual_splitters.py | 13 +++- .../training_custom_split_events.ipynb | 8 ++- .../training_forecasting_splitter_only.ipynb | 1 + .../training_individual_splitters.ipynb | 8 ++- .../tte_probability_inference.ipynb | 13 +++- .../raw_data_preprocessing.ipynb | 8 ++- examples/integrations/meds_data_import.ipynb | 7 +- tests/test_converter.py | 45 ++++++++++-- tests/test_splitter.py | 71 +++++++++++++++---- twinweaver/instruction/data_splitter_base.py | 10 --- .../instruction/data_splitter_events.py | 21 +++--- .../instruction/data_splitter_forecasting.py | 37 +++++----- 30 files changed, 339 insertions(+), 103 deletions(-) diff --git a/docs/data-splitting.md b/docs/data-splitting.md index a207f79..61b971e 100644 --- a/docs/data-splitting.md +++ b/docs/data-splitting.md @@ -33,7 +33,7 @@ Patient timeline Split dates are anchored to **split events** — a configurable event category (typically Line of Therapy, `"lot"`). The framework: 1. **Finds all split-event start dates** in the patient's history (e.g., every LoT start). -2. **Identifies candidate dates** within a window around each split event (controlled by `max_split_length_after_split_event`, default 90 days). +2. **Identifies candidate dates** within a window around each split event (controlled by `max_split_length_after_split_event`, default 0 days). 3. **Randomly samples** one or more candidate dates per split event (`max_num_splits_per_split_event`). This anchoring ensures that training examples are centered on clinically meaningful time points rather than arbitrary dates. @@ -67,7 +67,7 @@ For each candidate split date, the forecasting splitter: 1. **Checks variable eligibility**: A variable is valid at a given date only if it has at least `min_nr_variable_seen_previously` occurrences in the lookback window and `min_nr_variable_seen_after` occurrences in the forecast window. 2. **Samples variables**: Between `min_nr_variables_to_sample` and `max_nr_variables_to_sample` variables are selected per task, using weighted proportional sampling based on pre-computed statistics (optionally uniform sampling). -3. **Creates the split**: Events before the split date form the input; future values of the sampled variables (within `max_forecast_time_for_value`) form the target. +3. **Creates the split**: Events before the split date form the input; future values of the sampled variables (within `max_forecasted_trajectory_length`) form the target. 4. **Filters future LoT overlap**: Target events occurring after the next Line of Therapy start are excluded to avoid data leakage. ### Variable Statistics & Sampling @@ -96,13 +96,13 @@ When `filter_outliers=True`, the **3-sigma strategy** clips target values to the data_splitter_forecasting = DataSplitterForecasting( data_manager=dm, config=config, - max_split_length_after_split_event=pd.Timedelta(days=90), # Window after split event + max_forecasted_trajectory_length=pd.Timedelta(days=90), # Forecast horizon (required) + max_split_length_after_split_event=pd.Timedelta(days=90), # Window after split event max_lookback_time_for_value=pd.Timedelta(days=90), # Lookback for variable history - max_forecast_time_for_value=pd.Timedelta(days=90), # Forecast horizon min_nr_variable_seen_previously=1, # Min past occurrences min_nr_variable_seen_after=1, # Min future occurrences min_nr_variables_to_sample=1, # Min variables per task - max_nr_variables_to_sample=3, # Max variables per task + max_nr_variables_to_sample=1, # Max variables per task filtering_strategy="3-sigma", # Outlier handling sampling_strategy="proportional", # Weighted or uniform sampling ) @@ -136,7 +136,7 @@ flowchart TD For each candidate split date, the event splitter: 1. **Samples an event category** from the configured mapping (e.g., `"death"` or `"progression"`), avoiding duplicate categories per split. -2. **Samples a prediction window** of random duration between `min_length_to_sample` (default: 1 week) and `max_length_to_sample` (default: 104 weeks). This trains the model to handle variable-length horizons. +2. **Samples a prediction window** of random duration between `min_length_to_sample` and `max_length_to_sample` (both required, no defaults). This trains the model to handle variable-length horizons. 3. **Determines the outcome**: - **Occurred**: The event was observed within the window before any censoring events. - **Censored**: The observation was cut short by a new therapy start, end of data, or a data cutoff date. @@ -149,8 +149,8 @@ For each candidate split date, the event splitter: data_splitter_events = DataSplitterEvents( data_manager=dm, config=config, - max_length_to_sample=pd.Timedelta(weeks=104), # Max prediction window - min_length_to_sample=pd.Timedelta(weeks=1), # Min prediction window + max_length_to_sample=pd.Timedelta(weeks=104), # Max prediction window (required) + min_length_to_sample=pd.Timedelta(weeks=1), # Min prediction window (required) unit_length_to_sample="weeks", # Window sampling unit max_split_length_after_split_event=pd.Timedelta(days=90), # Window after split event ) diff --git a/docs/examples/01_data_preparation_for_training.ipynb b/docs/examples/01_data_preparation_for_training.ipynb index 60fe0bd..6264cd6 100644 --- a/docs/examples/01_data_preparation_for_training.ipynb +++ b/docs/examples/01_data_preparation_for_training.ipynb @@ -202,13 +202,19 @@ "outputs": [], "source": [ "# This data splitter handles event prediction tasks\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "# This data splitter handles forecasting tasks\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/docs/examples/02_inference_prompt_preparation.ipynb b/docs/examples/02_inference_prompt_preparation.ipynb index 299e9cd..fca3954 100644 --- a/docs/examples/02_inference_prompt_preparation.ipynb +++ b/docs/examples/02_inference_prompt_preparation.ipynb @@ -91,12 +91,18 @@ "dm.setup_dataset_splits()\n", "dm.infer_var_types()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/docs/examples/03_end_to_end_llm_finetuning.ipynb b/docs/examples/03_end_to_end_llm_finetuning.ipynb index 3e91776..a73099d 100644 --- a/docs/examples/03_end_to_end_llm_finetuning.ipynb +++ b/docs/examples/03_end_to_end_llm_finetuning.ipynb @@ -147,12 +147,18 @@ "dm.setup_dataset_splits()\n", "dm.infer_var_types()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "\n", "# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n", diff --git a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb index d17f3db..9a88f68 100644 --- a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -134,10 +134,19 @@ "outputs": [], "source": [ "# Setup splitters\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", - "data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)\n", + "data_splitter_forecasting = DataSplitterForecasting(\n", + " data_manager=dm,\n", + " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", "data_splitter_forecasting.setup_statistics()\n", "\n", "data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)" diff --git a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb index 6f15659..c1bad97 100644 --- a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb +++ b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb @@ -121,10 +121,19 @@ "outputs": [], "source": [ "# Setup splitters and converter\n", - "data_splitter_events_default = DataSplitterEvents(dm_default, config=config_default)\n", + "data_splitter_events_default = DataSplitterEvents(\n", + " dm_default,\n", + " config=config_default,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events_default.setup_variables()\n", "\n", - "data_splitter_forecasting_default = DataSplitterForecasting(data_manager=dm_default, config=config_default)\n", + "data_splitter_forecasting_default = DataSplitterForecasting(\n", + " data_manager=dm_default,\n", + " config=config_default,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", "data_splitter_forecasting_default.setup_statistics()\n", "\n", "data_splitter_default = DataSplitter(data_splitter_events_default, data_splitter_forecasting_default)\n", @@ -623,10 +632,19 @@ "outputs": [], "source": [ "# Setup splitters and converter with custom config\n", - "data_splitter_events_custom = DataSplitterEvents(dm_custom, config=config_custom)\n", + "data_splitter_events_custom = DataSplitterEvents(\n", + " dm_custom,\n", + " config=config_custom,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events_custom.setup_variables()\n", "\n", - "data_splitter_forecasting_custom = DataSplitterForecasting(data_manager=dm_custom, config=config_custom)\n", + "data_splitter_forecasting_custom = DataSplitterForecasting(\n", + " data_manager=dm_custom,\n", + " config=config_custom,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", "data_splitter_forecasting_custom.setup_statistics()\n", "\n", "data_splitter_custom = DataSplitter(data_splitter_events_custom, data_splitter_forecasting_custom)\n", diff --git a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py index abf7273..377ac8e 100644 --- a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py +++ b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py @@ -45,9 +45,18 @@ def __init__( self.dm.setup_dataset_splits() self.dm.infer_var_types() - self.data_splitter_events = DataSplitterEvents(self.dm, config=self.config) + self.data_splitter_events = DataSplitterEvents( + self.dm, + config=self.config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + ) self.data_splitter_events.setup_variables() - self.data_splitter_forecasting = DataSplitterForecasting(data_manager=self.dm, config=self.config) + self.data_splitter_forecasting = DataSplitterForecasting( + data_manager=self.dm, + config=self.config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + ) self.data_splitter_forecasting.setup_statistics() self.converter = ConverterInstruction( nr_tokens_budget_total=8192, diff --git a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb index acd176a..2c8df56 100644 --- a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -138,13 +138,19 @@ "outputs": [], "source": [ "# This data splitter handles event prediction tasks\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "# This data splitter handles forecasting tasks\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb index 373ea42..e4a5ae5 100644 --- a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb +++ b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb @@ -112,6 +112,7 @@ "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 446142c..8389bcb 100644 --- a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -114,12 +114,18 @@ "dm.setup_dataset_splits()\n", "dm.infer_var_types()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb index 9e7a2b2..64babc3 100644 --- a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb +++ b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb @@ -191,10 +191,19 @@ "dm.setup_dataset_splits()\n", "dm.infer_var_types()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", - "data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)\n", + "data_splitter_forecasting = DataSplitterForecasting(\n", + " data_manager=dm,\n", + " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", "data_splitter_forecasting.setup_statistics()\n", "\n", "# Combined interface\n", diff --git a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb index efc5472..58bf5d1 100644 --- a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb +++ b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb @@ -1050,12 +1050,18 @@ "outputs": [], "source": [ "# Initialize data splitters and converter\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "data_splitter_forecasting.setup_statistics()\n", "\n", diff --git a/docs/examples/integrations/meds_data_import.ipynb b/docs/examples/integrations/meds_data_import.ipynb index 8bedbda..d7412d2 100644 --- a/docs/examples/integrations/meds_data_import.ipynb +++ b/docs/examples/integrations/meds_data_import.ipynb @@ -501,7 +501,12 @@ "dm.setup_unique_mapping_of_events()\n", "dm.setup_dataset_splits()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "converter = ConverterInstruction(\n", " nr_tokens_budget_total=8192,\n", diff --git a/examples/01_data_preparation_for_training.ipynb b/examples/01_data_preparation_for_training.ipynb index 60fe0bd..6264cd6 100644 --- a/examples/01_data_preparation_for_training.ipynb +++ b/examples/01_data_preparation_for_training.ipynb @@ -202,13 +202,19 @@ "outputs": [], "source": [ "# This data splitter handles event prediction tasks\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "# This data splitter handles forecasting tasks\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/examples/02_inference_prompt_preparation.ipynb b/examples/02_inference_prompt_preparation.ipynb index 299e9cd..fca3954 100644 --- a/examples/02_inference_prompt_preparation.ipynb +++ b/examples/02_inference_prompt_preparation.ipynb @@ -91,12 +91,18 @@ "dm.setup_dataset_splits()\n", "dm.infer_var_types()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/examples/03_end_to_end_llm_finetuning.ipynb b/examples/03_end_to_end_llm_finetuning.ipynb index 3e91776..a73099d 100644 --- a/examples/03_end_to_end_llm_finetuning.ipynb +++ b/examples/03_end_to_end_llm_finetuning.ipynb @@ -147,12 +147,18 @@ "dm.setup_dataset_splits()\n", "dm.infer_var_types()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "\n", "# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n", diff --git a/examples/advanced/custom_output/custom_summarized_row.ipynb b/examples/advanced/custom_output/custom_summarized_row.ipynb index d17f3db..9a88f68 100644 --- a/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -134,10 +134,19 @@ "outputs": [], "source": [ "# Setup splitters\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", - "data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)\n", + "data_splitter_forecasting = DataSplitterForecasting(\n", + " data_manager=dm,\n", + " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", "data_splitter_forecasting.setup_statistics()\n", "\n", "data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)" diff --git a/examples/advanced/custom_output/customizing_text_generation.ipynb b/examples/advanced/custom_output/customizing_text_generation.ipynb index 6f15659..c1bad97 100644 --- a/examples/advanced/custom_output/customizing_text_generation.ipynb +++ b/examples/advanced/custom_output/customizing_text_generation.ipynb @@ -121,10 +121,19 @@ "outputs": [], "source": [ "# Setup splitters and converter\n", - "data_splitter_events_default = DataSplitterEvents(dm_default, config=config_default)\n", + "data_splitter_events_default = DataSplitterEvents(\n", + " dm_default,\n", + " config=config_default,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events_default.setup_variables()\n", "\n", - "data_splitter_forecasting_default = DataSplitterForecasting(data_manager=dm_default, config=config_default)\n", + "data_splitter_forecasting_default = DataSplitterForecasting(\n", + " data_manager=dm_default,\n", + " config=config_default,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", "data_splitter_forecasting_default.setup_statistics()\n", "\n", "data_splitter_default = DataSplitter(data_splitter_events_default, data_splitter_forecasting_default)\n", @@ -623,10 +632,19 @@ "outputs": [], "source": [ "# Setup splitters and converter with custom config\n", - "data_splitter_events_custom = DataSplitterEvents(dm_custom, config=config_custom)\n", + "data_splitter_events_custom = DataSplitterEvents(\n", + " dm_custom,\n", + " config=config_custom,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events_custom.setup_variables()\n", "\n", - "data_splitter_forecasting_custom = DataSplitterForecasting(data_manager=dm_custom, config=config_custom)\n", + "data_splitter_forecasting_custom = DataSplitterForecasting(\n", + " data_manager=dm_custom,\n", + " config=config_custom,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", "data_splitter_forecasting_custom.setup_statistics()\n", "\n", "data_splitter_custom = DataSplitter(data_splitter_events_custom, data_splitter_forecasting_custom)\n", diff --git a/examples/advanced/custom_splitting/inference_individual_splitters.py b/examples/advanced/custom_splitting/inference_individual_splitters.py index 7dfda94..0fb06f8 100644 --- a/examples/advanced/custom_splitting/inference_individual_splitters.py +++ b/examples/advanced/custom_splitting/inference_individual_splitters.py @@ -46,9 +46,18 @@ def __init__( self.dm.setup_dataset_splits() self.dm.infer_var_types() - data_splitter_events = DataSplitterEvents(self.dm, config=self.config) + data_splitter_events = DataSplitterEvents( + self.dm, + config=self.config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + ) data_splitter_events.setup_variables() - data_splitter_forecasting = DataSplitterForecasting(data_manager=self.dm, config=self.config) + data_splitter_forecasting = DataSplitterForecasting( + data_manager=self.dm, + config=self.config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + ) data_splitter_forecasting.setup_statistics() # Use the unified DataSplitter API that combines both splitters diff --git a/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/examples/advanced/custom_splitting/training_custom_split_events.ipynb index 274ccb3..46eae55 100644 --- a/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -138,13 +138,19 @@ "outputs": [], "source": [ "# This data splitter handles event prediction tasks\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "# This data splitter handles forecasting tasks\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb index 8abb9cd..fd6a63e 100644 --- a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb +++ b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb @@ -113,6 +113,7 @@ "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 8bd7bb6..2e84ac2 100644 --- a/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -115,12 +115,18 @@ "dm.setup_dataset_splits()\n", "dm.infer_var_types()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", diff --git a/examples/advanced/tte_inference/tte_probability_inference.ipynb b/examples/advanced/tte_inference/tte_probability_inference.ipynb index 9e7a2b2..64babc3 100644 --- a/examples/advanced/tte_inference/tte_probability_inference.ipynb +++ b/examples/advanced/tte_inference/tte_probability_inference.ipynb @@ -191,10 +191,19 @@ "dm.setup_dataset_splits()\n", "dm.infer_var_types()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", - "data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)\n", + "data_splitter_forecasting = DataSplitterForecasting(\n", + " data_manager=dm,\n", + " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", "data_splitter_forecasting.setup_statistics()\n", "\n", "# Combined interface\n", diff --git a/examples/data_preprocessing/raw_data_preprocessing.ipynb b/examples/data_preprocessing/raw_data_preprocessing.ipynb index efc5472..58bf5d1 100644 --- a/examples/data_preprocessing/raw_data_preprocessing.ipynb +++ b/examples/data_preprocessing/raw_data_preprocessing.ipynb @@ -1050,12 +1050,18 @@ "outputs": [], "source": [ "# Initialize data splitters and converter\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "\n", "data_splitter_forecasting = DataSplitterForecasting(\n", " data_manager=dm,\n", " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", ")\n", "data_splitter_forecasting.setup_statistics()\n", "\n", diff --git a/examples/integrations/meds_data_import.ipynb b/examples/integrations/meds_data_import.ipynb index 076e5eb..b75b3bc 100644 --- a/examples/integrations/meds_data_import.ipynb +++ b/examples/integrations/meds_data_import.ipynb @@ -501,7 +501,12 @@ "dm.setup_unique_mapping_of_events()\n", "dm.setup_dataset_splits()\n", "\n", - "data_splitter_events = DataSplitterEvents(dm, config=config)\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", "data_splitter_events.setup_variables()\n", "converter = ConverterInstruction(\n", " nr_tokens_budget_total=8192,\n", diff --git a/tests/test_converter.py b/tests/test_converter.py index e50e8a9..32f1503 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -24,10 +24,21 @@ def setup_components(mock_config, sample_data): dm.setup_dataset_splits() dm.infer_var_types() - splitter_events = DataSplitterEvents(dm, config=mock_config) + splitter_events = DataSplitterEvents( + dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_events.setup_variables() - splitter_forecast = DataSplitterForecasting(data_manager=dm, config=mock_config) + splitter_forecast = DataSplitterForecasting( + data_manager=dm, + config=mock_config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_forecast.setup_statistics() data_splitter = DataSplitter(splitter_events, splitter_forecast) @@ -110,10 +121,21 @@ def test_event_categories_to_exclude_from_input(mock_config, sample_data): dm.setup_dataset_splits() dm.infer_var_types() - splitter_events = DataSplitterEvents(dm, config=mock_config) + splitter_events = DataSplitterEvents( + dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_events.setup_variables() - splitter_forecast = DataSplitterForecasting(data_manager=dm, config=mock_config) + splitter_forecast = DataSplitterForecasting( + data_manager=dm, + config=mock_config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_forecast.setup_statistics() data_splitter = DataSplitter(splitter_events, splitter_forecast) @@ -164,10 +186,21 @@ def test_event_categories_to_exclude_multiple(mock_config, sample_data): dm.setup_dataset_splits() dm.infer_var_types() - splitter_events = DataSplitterEvents(dm, config=mock_config) + splitter_events = DataSplitterEvents( + dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_events.setup_variables() - splitter_forecast = DataSplitterForecasting(data_manager=dm, config=mock_config) + splitter_forecast = DataSplitterForecasting( + data_manager=dm, + config=mock_config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_forecast.setup_statistics() data_splitter = DataSplitter(splitter_events, splitter_forecast) diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 548a1df..7f1e42f 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -27,7 +27,9 @@ def initialized_dm(mock_config, sample_data): def test_splitter_forecasting_statistics(initialized_dm, mock_config): """Test that forecasting splitter can calculate statistics.""" - splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config) + splitter_forecast = DataSplitterForecasting( + data_manager=initialized_dm, config=mock_config, max_forecasted_trajectory_length=pd.Timedelta(days=90) + ) # This calculates R2, NRMSE etc. for the variables splitter_forecast.setup_statistics() @@ -50,10 +52,21 @@ def test_splitter_forecasting_statistics(initialized_dm, mock_config): def test_get_splits_from_patient(initialized_dm, mock_config): """Test generating splits for a single patient.""" # Setup Splitters - splitter_events = DataSplitterEvents(initialized_dm, config=mock_config) + splitter_events = DataSplitterEvents( + initialized_dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_events.setup_variables() - splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config) + splitter_forecast = DataSplitterForecasting( + data_manager=initialized_dm, + config=mock_config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_forecast.setup_statistics() data_splitter = DataSplitter(splitter_events, splitter_forecast) @@ -99,9 +112,16 @@ def test_get_splits_from_patient(initialized_dm, mock_config): def test_inference_split(initialized_dm, mock_config): """Test generating an inference split (last date).""" - splitter_events = DataSplitterEvents(initialized_dm, config=mock_config) + splitter_events = DataSplitterEvents( + initialized_dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + ) splitter_events.setup_variables() - splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config) + splitter_forecast = DataSplitterForecasting( + data_manager=initialized_dm, config=mock_config, max_forecasted_trajectory_length=pd.Timedelta(days=90) + ) data_splitter = DataSplitter(splitter_events, splitter_forecast) patient_data = initialized_dm.get_patient_data("p0") @@ -140,7 +160,12 @@ def test_data_splitter_requires_at_least_one_splitter(): def test_training_forecasting_only(initialized_dm, mock_config): """Test training splits when only the forecasting splitter is provided.""" - splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config) + splitter_forecast = DataSplitterForecasting( + data_manager=initialized_dm, + config=mock_config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_forecast.setup_statistics() data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast) @@ -193,7 +218,13 @@ def test_training_forecasting_only(initialized_dm, mock_config): def test_training_events_only(initialized_dm, mock_config): """Test training splits when only the events splitter is provided.""" - splitter_events = DataSplitterEvents(initialized_dm, config=mock_config) + splitter_events = DataSplitterEvents( + initialized_dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) splitter_events.setup_variables() data_splitter = DataSplitter(data_splitter_events=splitter_events) @@ -256,7 +287,9 @@ def test_training_events_only(initialized_dm, mock_config): def test_inference_forecasting_only(initialized_dm, mock_config): """Test inference split when only the forecasting splitter is provided.""" - splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config) + splitter_forecast = DataSplitterForecasting( + data_manager=initialized_dm, config=mock_config, max_forecasted_trajectory_length=pd.Timedelta(days=90) + ) data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast) patient_data = initialized_dm.get_patient_data("p0") @@ -297,7 +330,12 @@ def test_inference_forecasting_only(initialized_dm, mock_config): def test_inference_events_only(initialized_dm, mock_config): """Test inference split when only the events splitter is provided.""" - splitter_events = DataSplitterEvents(initialized_dm, config=mock_config) + splitter_events = DataSplitterEvents( + initialized_dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + ) splitter_events.setup_variables() data_splitter = DataSplitter(data_splitter_events=splitter_events) @@ -341,7 +379,9 @@ def test_inference_events_only(initialized_dm, mock_config): def test_inference_both_type_with_only_forecasting(initialized_dm, mock_config): """Test that inference_type='both' gracefully returns None for the missing splitter.""" - splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config) + splitter_forecast = DataSplitterForecasting( + data_manager=initialized_dm, config=mock_config, max_forecasted_trajectory_length=pd.Timedelta(days=90) + ) data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast) patient_data = initialized_dm.get_patient_data("p0") @@ -378,7 +418,12 @@ def test_inference_both_type_with_only_forecasting(initialized_dm, mock_config): def test_inference_both_type_with_only_events(initialized_dm, mock_config): """Test that inference_type='both' gracefully returns None for the missing splitter.""" - splitter_events = DataSplitterEvents(initialized_dm, config=mock_config) + splitter_events = DataSplitterEvents( + initialized_dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + ) splitter_events.setup_variables() data_splitter = DataSplitter(data_splitter_events=splitter_events) @@ -530,7 +575,7 @@ def test_forecasting_truncates_at_next_split_event_not_just_lot(): splitter = DataSplitterForecasting( config=cfg, data_manager=dm, - max_forecast_time_for_value=pd.Timedelta(days=90), + max_forecasted_trajectory_length=pd.Timedelta(days=90), max_lookback_time_for_value=pd.Timedelta(days=90), max_split_length_after_split_event=pd.Timedelta(days=90), sampling_strategy="uniform", @@ -644,7 +689,7 @@ def test_forecasting_truncation_allow_beyond_next_split_date(): splitter = DataSplitterForecasting( config=cfg, data_manager=dm, - max_forecast_time_for_value=pd.Timedelta(days=90), + max_forecasted_trajectory_length=pd.Timedelta(days=90), max_lookback_time_for_value=pd.Timedelta(days=90), max_split_length_after_split_event=pd.Timedelta(days=90), sampling_strategy="uniform", diff --git a/twinweaver/instruction/data_splitter_base.py b/twinweaver/instruction/data_splitter_base.py index f489d6c..5e2bb4a 100644 --- a/twinweaver/instruction/data_splitter_base.py +++ b/twinweaver/instruction/data_splitter_base.py @@ -15,8 +15,6 @@ def __init__( data_manager: DataManager, config: Config, max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=90), - max_lookback_time_for_value: pd.Timedelta = pd.Timedelta(days=90), - max_forecast_time_for_value: pd.Timedelta = pd.Timedelta(days=90), ): """ Constructor for the BaseDataSplitter class. @@ -30,12 +28,6 @@ def __init__( max_split_length_after_split_event: pd.Timedelta the maximum number of days after a LoT event that we want to consider as a starting point. - max_lookback_time_for_value: pd.Timedelta - the maximum number of days before a certain split date where we need to see - the value of the target variable. - max_forecast_time_for_value : pd.Timedelta - the maximum number of days after a certain split date where we need to see - the value of the target variable when filtering. """ assert config.split_event_category is not None, "config.split_event_category must be set (e.g. ['lab'])." @@ -43,8 +35,6 @@ def __init__( self.dm = data_manager self.config = config self.max_split_length_after_split_event = max_split_length_after_split_event - self.max_lookback_time_for_value = max_lookback_time_for_value - self.max_forecast_time_for_value = max_forecast_time_for_value def _get_all_dates_within_range_of_split_event( self, diff --git a/twinweaver/instruction/data_splitter_events.py b/twinweaver/instruction/data_splitter_events.py index beeb96d..46a27a9 100644 --- a/twinweaver/instruction/data_splitter_events.py +++ b/twinweaver/instruction/data_splitter_events.py @@ -88,12 +88,10 @@ def __init__( self, data_manager: DataManager, config: Config, - max_length_to_sample: pd.Timedelta = pd.Timedelta(weeks=104), - min_length_to_sample: pd.Timedelta = pd.Timedelta(weeks=1), + max_length_to_sample: pd.Timedelta, + min_length_to_sample: pd.Timedelta, unit_length_to_sample: str = "weeks", - max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=90), - max_lookback_time_for_value: pd.Timedelta = pd.Timedelta(days=90), - max_forecast_time_for_value: pd.Timedelta = pd.Timedelta(days=90), + max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=0), ): """ Initialize the DataSplitterEvents class. @@ -105,24 +103,21 @@ def __init__( config : Config Configuration object holding constants. max_length_to_sample : pd.Timedelta - The maximum number of weeks into the future to sample for event prediction. + The maximum length of time into the future to sample for event prediction. + Required, no default. min_length_to_sample : pd.Timedelta - The minimum number of weeks into the future to sample for event prediction. + The minimum length of time into the future to sample for event prediction. + Required, no default. unit_length_to_sample : str The unit of time for the length to sample (e.g. "weeks"). max_split_length_after_split_event : pd.Timedelta, optional The maximum number of days after the split event (e.g. line of therapy) to consider for split points. - max_lookback_time_for_value : pd.Timedelta, optional - The maximum number of days to look back for a value (inherited but not directly used here). - max_forecast_time_for_value : pd.Timedelta, optional - The maximum number of days to forecast a value (inherited but not directly used here). + Defaults to 0 days. """ super().__init__( data_manager, config, max_split_length_after_split_event, - max_lookback_time_for_value, - max_forecast_time_for_value, ) self.max_length_to_sample = max_length_to_sample self.min_length_to_sample = min_length_to_sample diff --git a/twinweaver/instruction/data_splitter_forecasting.py b/twinweaver/instruction/data_splitter_forecasting.py index 808ea2e..998ef07 100644 --- a/twinweaver/instruction/data_splitter_forecasting.py +++ b/twinweaver/instruction/data_splitter_forecasting.py @@ -113,9 +113,9 @@ def __init__( self, config: Config, data_manager: DataManager, - max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=90), - max_lookback_time_for_value: pd.Timedelta = pd.Timedelta(days=90), - max_forecast_time_for_value: pd.Timedelta = pd.Timedelta(days=90), + max_forecasted_trajectory_length: pd.Timedelta, + max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=0), + max_lookback_time_for_value: pd.Timedelta = pd.Timedelta(days=100000), min_num_samples_for_statistics: int = 10, sampling_score_to_use: str = "score_log_nrmse_n_samples", min_nr_variable_seen_previously: int = 1, @@ -123,7 +123,7 @@ def __init__( list_of_valid_categories: list = None, save_path_for_variable_stats: str = None, min_nr_variables_to_sample: int = 1, - max_nr_variables_to_sample: int = 3, + max_nr_variables_to_sample: int = 1, filtering_strategy: str = "3-sigma", sampling_strategy: str = "proportional", allow_forecasting_beyond_next_split_date: bool = False, @@ -137,17 +137,17 @@ def __init__( Configuration object containing shared settings like column names. data_manager : DataManager Provides access to patient data for a single indication. + max_forecasted_trajectory_length : pd.Timedelta + Max time after a split date to look for future variable occurrences (target + data). Required, no default. max_split_length_after_split_event : pd.Timedelta - Max days after LoT start to consider for split dates. Defaults to 90. + Max time after LoT start to consider for split dates. Defaults to 0 days. max_lookback_time_for_value : pd.Timedelta - Max days before a split date to look for past variable occurrences. - Defaults to 90. - max_forecast_time_for_value : pd.Timedelta - Max days after a split date to look for future variable occurrences (target - data). Defaults to 90. + Max time before a split date to look for past variable occurrences. + Defaults to 100000 days (effectively no limit). min_num_samples_for_statistics : int Minimum total occurrences of a variable across the training set - needed to calculate statistics. Defaults to 50. + needed to calculate statistics. Defaults to 10. sampling_score_to_use : str Column name in the computed statistics table used for weighted sampling of variables. Defaults to 'score_log_nrmse_n_samples'. @@ -165,10 +165,10 @@ def __init__( None. min_nr_variables_to_sample : int The minimum number of distinct variables to attempt to sample for each - forecasting task. Defaults to 3. + forecasting task. Defaults to 1. max_nr_variables_to_sample : int The maximum number of distinct variables to attempt to sample for each - forecasting task. Defaults to 3. + forecasting task. Defaults to 1. filtering_strategy : str The strategy for handling outliers in target variable values ('3-sigma'). Defaults to '3-sigma'. @@ -183,8 +183,6 @@ def __init__( data_manager, config, max_split_length_after_split_event, - max_lookback_time_for_value, - max_forecast_time_for_value, ) assert self.config.event_category_forecast is not None or list_of_valid_categories is not None, ( @@ -192,7 +190,8 @@ def __init__( "For example: ['lab']" " Alternatively, provide list_of_valid_categories directly." ) - + self.max_lookback_time_for_value = max_lookback_time_for_value + self.max_forecasted_trajectory_length = max_forecasted_trajectory_length self.variable_stats = None self.variable_type = {} # event_name -> "numeric" / "categorical" self.min_num_samples_for_statistics = min_num_samples_for_statistics @@ -259,7 +258,7 @@ def setup_statistics(self, train_patientids: list = None): temp_splits = self._get_all_dates_within_range_of_split_event( temp_patient_data, time_before_lot_start=self.max_lookback_time_for_value, - max_split_length_after_split_event=self.max_forecast_time_for_value, + max_split_length_after_split_event=self.max_forecasted_trajectory_length, ) temp_splits[self.config.patient_id_col] = patientid temp_splits = temp_splits[[self.config.date_col, self.config.patient_id_col]] @@ -657,7 +656,7 @@ def _get_all_possible_splits( # Pre-compute date ranges for lookback and forecast lookback_range = self.max_lookback_time_for_value - forecast_range = self.max_forecast_time_for_value + forecast_range = self.max_forecasted_trajectory_length # Initialize the return_splits list return_splits = [] @@ -859,7 +858,7 @@ def _generate_variable_splits_for_date( events_before_split = events[events[self.config.date_col] <= curr_date] events_after_split = events[events[self.config.date_col] > curr_date] events_after_split = events_after_split[ - events_after_split[self.config.date_col] <= curr_date + self.max_forecast_time_for_value + events_after_split[self.config.date_col] <= curr_date + self.max_forecasted_trajectory_length ] events_after_split = events_after_split[ events_after_split[self.config.event_name_col].isin(sampled_variables) From 0ee924e10e3b41353f1e045fd562f6bc11e00817 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 13:46:51 +0000 Subject: [PATCH 14/36] Renamed to event mapping to event_category_events_prediction_with_naming --- .pre-commit-config.yaml | 8 +++++++- README.md | 2 +- docs/data-splitting.md | 6 +++--- docs/dataset-format.md | 4 ++-- .../01_data_preparation_for_training.ipynb | 4 ++-- .../02_inference_prompt_preparation.ipynb | 2 +- docs/examples/03_end_to_end_llm_finetuning.ipynb | 2 +- .../custom_output/custom_summarized_row.ipynb | 2 +- .../customizing_text_generation.ipynb | 4 ++-- .../inference_individual_splitters.py | 2 +- .../training_custom_split_events.ipynb | 2 +- .../training_individual_splitters.ipynb | 2 +- .../tte_probability_inference.ipynb | 2 +- .../raw_data_preprocessing.ipynb | 2 +- .../01_data_preparation_challenge.ipynb | 16 ++++++++-------- .../hackathon/02_llm_finetuning_challenge.ipynb | 2 +- .../examples/integrations/meds_data_import.ipynb | 2 +- docs/framework.md | 2 +- docs/quickstart.md | 2 +- examples/01_data_preparation_for_training.ipynb | 4 ++-- examples/02_inference_prompt_preparation.ipynb | 2 +- examples/03_end_to_end_llm_finetuning.ipynb | 2 +- .../custom_output/custom_summarized_row.ipynb | 2 +- .../customizing_text_generation.ipynb | 4 ++-- .../inference_individual_splitters.py | 2 +- .../training_custom_split_events.ipynb | 2 +- .../training_individual_splitters.ipynb | 2 +- .../tte_probability_inference.ipynb | 2 +- .../raw_data_preprocessing.ipynb | 2 +- .../01_data_preparation_challenge.ipynb | 16 ++++++++-------- .../hackathon/02_llm_finetuning_challenge.ipynb | 2 +- examples/integrations/meds_data_import.ipynb | 2 +- tests/test_common.py | 2 +- tests/test_converter.py | 6 +++--- tests/test_splitter.py | 2 +- twinweaver/common/config.py | 4 ++-- twinweaver/instruction/data_splitter_events.py | 8 ++++---- 37 files changed, 70 insertions(+), 64 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ac4e937..0c43503 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,10 @@ -exclude: ^examples/hackathon +# Exclude hackathon examples +exclude: | + (?x)^( + docs/examples/hackathon/| + examples/hackathon/| + \^examples/hackathon + ) repos: # 1. Standard "Cleanup" Hooks diff --git a/README.md b/README.md index 023de9a..59e77c0 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,7 @@ config.event_category_forecast = ["lab"] # 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression') # Only needs to be set if you want to do time to event prediction -config.data_splitter_events_variables_category_mapping = { +config.event_category_events_prediction_with_naming = { "death": "death", "progression": "next progression", # Custom name in prompt: "next progression" instead of "progression" } diff --git a/docs/data-splitting.md b/docs/data-splitting.md index 61b971e..621fce3 100644 --- a/docs/data-splitting.md +++ b/docs/data-splitting.md @@ -161,7 +161,7 @@ data_splitter_events = DataSplitterEvents( The event-to-prediction mapping is configured via: ```python -config.data_splitter_events_variables_category_mapping = { +config.event_category_events_prediction_with_naming = { "death": "death", # event_category → descriptive name in prompt "progression": "next progression", # custom prompt label } @@ -261,7 +261,7 @@ A single patient can yield many training examples through several sources of var | Multiple split events (e.g., LoTs) | Patient history | One split per LoT by default | | Multiple dates per split event | `max_num_splits_per_split_event` | Random dates within the LoT window | | Different variable subsets | `min/max_nr_variables_to_sample` | Different forecasting questions per date | -| Different event categories | `data_splitter_events_variables_category_mapping` | Death vs. progression predictions | +| Different event categories | `event_category_events_prediction_with_naming` | Death vs. progression predictions | | Different prediction windows | `min/max_length_to_sample` | 1-week to 104-week horizons | This diversity encourages the model to generalize across time points, variables, and prediction tasks. @@ -282,7 +282,7 @@ from twinweaver import ( config = Config() config.split_event_category = "lot" config.event_category_forecast = ["lab"] -config.data_splitter_events_variables_category_mapping = { +config.event_category_events_prediction_with_naming = { "death": "death", "progression": "next progression", } diff --git a/docs/dataset-format.md b/docs/dataset-format.md index ace6107..80b5967 100644 --- a/docs/dataset-format.md +++ b/docs/dataset-format.md @@ -284,7 +284,7 @@ config.event_category_forecast = ["lab"] # 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression') # Only needs to be set if you want to do time to event prediction -config.data_splitter_events_variables_category_mapping = { +config.event_category_events_prediction_with_naming = { "death": "death", "progression": "next progression", # Custom name in prompt } @@ -301,6 +301,6 @@ dm.load_indication_data( !!! tip "Configuration Parameters" - **`split_event_category`**: The event category used to anchor split points for generating training samples (required for instruction tuning) - **`event_category_forecast`**: Which event categories to forecast as time-series values - - **`data_splitter_events_variables_category_mapping`**: Maps event names to prediction tasks (e.g., survival, progression) + - **`event_category_events_prediction_with_naming`**: Maps event names to prediction tasks (e.g., survival, progression) See the [Raw Data Preprocessing Tutorial](examples/data_preprocessing/raw_data_preprocessing.ipynb) for transforming raw clinical data into TwinWeaver format, or the [Data Preparation Tutorial](examples/01_data_preparation_for_training.ipynb) for a complete walkthrough of instruction-tuning data generation. diff --git a/docs/examples/01_data_preparation_for_training.ipynb b/docs/examples/01_data_preparation_for_training.ipynb index 6264cd6..1b242d8 100644 --- a/docs/examples/01_data_preparation_for_training.ipynb +++ b/docs/examples/01_data_preparation_for_training.ipynb @@ -95,7 +95,7 @@ "\n", "- **`config.split_event_category`**: Determines the event category used to split patient histories into input and output segments. In this example, we split data around \"Line of Therapy\" (`lot`) start dates.\n", "- **`config.event_category_forecast`**: Identifies which event categories should be forecasted as time-series values. Here, we target `lab` values.\n", - "- **`config.data_splitter_events_variables_category_mapping`**: Maps specific events to survival analysis or classification tasks. We configure the model to predict outcomes for `death` and `progression`." + "- **`config.event_category_events_prediction_with_naming`**: Maps specific events to survival analysis or classification tasks. We configure the model to predict outcomes for `death` and `progression`." ] }, { @@ -118,7 +118,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}" diff --git a/docs/examples/02_inference_prompt_preparation.ipynb b/docs/examples/02_inference_prompt_preparation.ipynb index fca3954..4facb67 100644 --- a/docs/examples/02_inference_prompt_preparation.ipynb +++ b/docs/examples/02_inference_prompt_preparation.ipynb @@ -67,7 +67,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}\n", diff --git a/docs/examples/03_end_to_end_llm_finetuning.ipynb b/docs/examples/03_end_to_end_llm_finetuning.ipynb index a73099d..8485533 100644 --- a/docs/examples/03_end_to_end_llm_finetuning.ipynb +++ b/docs/examples/03_end_to_end_llm_finetuning.ipynb @@ -109,7 +109,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}\n", diff --git a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb index 9a88f68..43a3f1c 100644 --- a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -95,7 +95,7 @@ "# Required settings\n", "config.split_event_category = \"lot\"\n", "config.event_category_forecast = [\"lab\"]\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\",\n", "}\n", diff --git a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb index c1bad97..aeba1ee 100644 --- a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb +++ b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb @@ -87,7 +87,7 @@ "# Required settings for instruction mode\n", "config_default.split_event_category = \"lot\"\n", "config_default.event_category_forecast = [\"lab\"]\n", - "config_default.data_splitter_events_variables_category_mapping = {\n", + "config_default.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\",\n", "}\n", @@ -202,7 +202,7 @@ "# Required settings\n", "config_custom.split_event_category = \"lot\"\n", "config_custom.event_category_forecast = [\"lab\"]\n", - "config_custom.data_splitter_events_variables_category_mapping = {\n", + "config_custom.event_category_events_prediction_with_naming = {\n", " \"death\": \"mortality\", # Custom name for death event\n", " \"progression\": \"disease progression\", # Custom name for progression\n", "}\n", diff --git a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py index 377ac8e..883d7a6 100644 --- a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py +++ b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py @@ -16,7 +16,7 @@ def __init__( self.config = Config() self.config.split_event_category = "lot" self.config.event_category_forecast = ["lab"] - self.config.data_splitter_events_variables_category_mapping = { + self.config.event_category_events_prediction_with_naming = { "death": "death", "progression": "next progression", # Custom name in prompt: "next progression" instead of "progression" } diff --git a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb index 2c8df56..abd7ab3 100644 --- a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -101,7 +101,7 @@ "config.event_category_forecast = [\"vitals\"]\n", "\n", "# To predict different variables for the event categories, we set up a mapping here\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"lot\": \"time to next lot\", # Custom name in prompt: \"time to next lot\" instead of \"lot\"\n", "}" ] diff --git a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 8389bcb..afe4d6f 100644 --- a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -90,7 +90,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}\n", diff --git a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb index 64babc3..056c219 100644 --- a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb +++ b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb @@ -136,7 +136,7 @@ "# 3. Time-to-event variables to predict\n", "# Keys = event_category values in the events DataFrame\n", "# Values = human-readable name used in the prompt\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\",\n", "}\n", diff --git a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb index 58bf5d1..ea8fab0 100644 --- a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb +++ b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb @@ -1007,7 +1007,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}\n", diff --git a/docs/examples/hackathon/01_data_preparation_challenge.ipynb b/docs/examples/hackathon/01_data_preparation_challenge.ipynb index 3cfd551..2461d78 100644 --- a/docs/examples/hackathon/01_data_preparation_challenge.ipynb +++ b/docs/examples/hackathon/01_data_preparation_challenge.ipynb @@ -180,7 +180,7 @@ "# TODO: Configure time-to-event prediction targets\n", "# HINT: This should be a dictionary mapping event names to display names\n", "# Example: {\"original_name\": \"display name in prompt\"}\n", - "config.data_splitter_events_variables_category_mapping = None # Replace with dict" + "config.event_category_events_prediction_with_naming = None # Replace with dict" ] }, { @@ -218,19 +218,19 @@ " else:\n", " print(f\"✅ event_category_forecast: {config.event_category_forecast}\")\n", "\n", - " if config.data_splitter_events_variables_category_mapping is None:\n", - " errors.append(\"❌ data_splitter_events_variables_category_mapping is not set\")\n", - " elif not isinstance(config.data_splitter_events_variables_category_mapping, dict):\n", - " errors.append(\"❌ data_splitter_events_variables_category_mapping should be a dict\")\n", + " if config.event_category_events_prediction_with_naming is None:\n", + " errors.append(\"❌ event_category_events_prediction_with_naming is not set\")\n", + " elif not isinstance(config.event_category_events_prediction_with_naming, dict):\n", + " errors.append(\"❌ event_category_events_prediction_with_naming should be a dict\")\n", " elif any(\n", " [\n", " cat not in df_events[\"event_category\"].unique()\n", - " for cat in config.data_splitter_events_variables_category_mapping.keys()\n", + " for cat in config.event_category_events_prediction_with_naming.keys()\n", " ]\n", " ):\n", - " errors.append(\"❌ At least one key in data_splitter_events_variables_category_mapping not found in data\")\n", + " errors.append(\"❌ At least one key in event_category_events_prediction_with_naming not found in data\")\n", " else:\n", - " print(f\"✅ Event mapping: {config.data_splitter_events_variables_category_mapping}\")\n", + " print(f\"✅ Event mapping: {config.event_category_events_prediction_with_naming}\")\n", "\n", " if errors:\n", " print(\"\\n\" + \"\\n\".join(errors))\n", diff --git a/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb b/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb index e45d570..07d0908 100644 --- a/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb +++ b/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb @@ -150,7 +150,7 @@ "# Set up:\n", "# - split_event_category\n", "# - event_category_forecast\n", - "# - data_splitter_events_variables_category_mapping\n", + "# - event_category_events_prediction_with_naming\n", "# - constant_columns_to_use\n", "# - constant_birthdate_column\n", "\n", diff --git a/docs/examples/integrations/meds_data_import.ipynb b/docs/examples/integrations/meds_data_import.ipynb index d7412d2..ed7af92 100644 --- a/docs/examples/integrations/meds_data_import.ipynb +++ b/docs/examples/integrations/meds_data_import.ipynb @@ -477,7 +477,7 @@ "config.constant_birthdate_column = None # Not using in demo\n", "config.event_value_lot_start = None\n", "config.split_event_category = \"lot\"\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", "}" ] diff --git a/docs/framework.md b/docs/framework.md index 82ef8bc..5f836e0 100644 --- a/docs/framework.md +++ b/docs/framework.md @@ -57,7 +57,7 @@ TwinWeaver supports two primary data formats, each serving a distinct stage in t - `config.split_event_category`: Event category used to anchor split points (e.g., `"lot"` for line of therapy) - `config.event_category_forecast`: List of event categories to forecast (e.g., `["lab"]`) - - `config.data_splitter_events_variables_category_mapping`: Mapping of events to prediction tasks (e.g., death, progression) + - `config.event_category_events_prediction_with_naming`: Mapping of events to prediction tasks (e.g., death, progression) See the [Data Splitting](data-splitting.md) page for a detailed explanation, or the [Quick Start](quickstart.md) and [Data Preparation Tutorial](examples/01_data_preparation_for_training.ipynb) for examples. diff --git a/docs/quickstart.md b/docs/quickstart.md index a5ae1dd..0a9012b 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -33,7 +33,7 @@ config.event_category_forecast = ["lab"] # 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression') # Only needs to be set if you want to do time to event prediction -config.data_splitter_events_variables_category_mapping = { +config.event_category_events_prediction_with_naming = { "death": "death", "progression": "next progression", # Custom name in prompt } diff --git a/examples/01_data_preparation_for_training.ipynb b/examples/01_data_preparation_for_training.ipynb index 6264cd6..1b242d8 100644 --- a/examples/01_data_preparation_for_training.ipynb +++ b/examples/01_data_preparation_for_training.ipynb @@ -95,7 +95,7 @@ "\n", "- **`config.split_event_category`**: Determines the event category used to split patient histories into input and output segments. In this example, we split data around \"Line of Therapy\" (`lot`) start dates.\n", "- **`config.event_category_forecast`**: Identifies which event categories should be forecasted as time-series values. Here, we target `lab` values.\n", - "- **`config.data_splitter_events_variables_category_mapping`**: Maps specific events to survival analysis or classification tasks. We configure the model to predict outcomes for `death` and `progression`." + "- **`config.event_category_events_prediction_with_naming`**: Maps specific events to survival analysis or classification tasks. We configure the model to predict outcomes for `death` and `progression`." ] }, { @@ -118,7 +118,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}" diff --git a/examples/02_inference_prompt_preparation.ipynb b/examples/02_inference_prompt_preparation.ipynb index fca3954..4facb67 100644 --- a/examples/02_inference_prompt_preparation.ipynb +++ b/examples/02_inference_prompt_preparation.ipynb @@ -67,7 +67,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}\n", diff --git a/examples/03_end_to_end_llm_finetuning.ipynb b/examples/03_end_to_end_llm_finetuning.ipynb index a73099d..8485533 100644 --- a/examples/03_end_to_end_llm_finetuning.ipynb +++ b/examples/03_end_to_end_llm_finetuning.ipynb @@ -109,7 +109,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}\n", diff --git a/examples/advanced/custom_output/custom_summarized_row.ipynb b/examples/advanced/custom_output/custom_summarized_row.ipynb index 9a88f68..43a3f1c 100644 --- a/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -95,7 +95,7 @@ "# Required settings\n", "config.split_event_category = \"lot\"\n", "config.event_category_forecast = [\"lab\"]\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\",\n", "}\n", diff --git a/examples/advanced/custom_output/customizing_text_generation.ipynb b/examples/advanced/custom_output/customizing_text_generation.ipynb index c1bad97..aeba1ee 100644 --- a/examples/advanced/custom_output/customizing_text_generation.ipynb +++ b/examples/advanced/custom_output/customizing_text_generation.ipynb @@ -87,7 +87,7 @@ "# Required settings for instruction mode\n", "config_default.split_event_category = \"lot\"\n", "config_default.event_category_forecast = [\"lab\"]\n", - "config_default.data_splitter_events_variables_category_mapping = {\n", + "config_default.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\",\n", "}\n", @@ -202,7 +202,7 @@ "# Required settings\n", "config_custom.split_event_category = \"lot\"\n", "config_custom.event_category_forecast = [\"lab\"]\n", - "config_custom.data_splitter_events_variables_category_mapping = {\n", + "config_custom.event_category_events_prediction_with_naming = {\n", " \"death\": \"mortality\", # Custom name for death event\n", " \"progression\": \"disease progression\", # Custom name for progression\n", "}\n", diff --git a/examples/advanced/custom_splitting/inference_individual_splitters.py b/examples/advanced/custom_splitting/inference_individual_splitters.py index 0fb06f8..0adacaf 100644 --- a/examples/advanced/custom_splitting/inference_individual_splitters.py +++ b/examples/advanced/custom_splitting/inference_individual_splitters.py @@ -17,7 +17,7 @@ def __init__( self.config = Config() self.config.split_event_category = "lot" self.config.event_category_forecast = ["lab"] - self.config.data_splitter_events_variables_category_mapping = { + self.config.event_category_events_prediction_with_naming = { "death": "death", "progression": "next progression", # Custom name in prompt: "next progression" instead of "progression" } diff --git a/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/examples/advanced/custom_splitting/training_custom_split_events.ipynb index 46eae55..b0f6673 100644 --- a/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -101,7 +101,7 @@ "config.event_category_forecast = [\"vitals\"]\n", "\n", "# To predict different variables for the event categories, we set up a mapping here\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"lot\": \"time to next lot\", # Custom name in prompt: \"time to next lot\" instead of \"lot\"\n", "}" ] diff --git a/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 2e84ac2..63b9674 100644 --- a/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -91,7 +91,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}\n", diff --git a/examples/advanced/tte_inference/tte_probability_inference.ipynb b/examples/advanced/tte_inference/tte_probability_inference.ipynb index 64babc3..056c219 100644 --- a/examples/advanced/tte_inference/tte_probability_inference.ipynb +++ b/examples/advanced/tte_inference/tte_probability_inference.ipynb @@ -136,7 +136,7 @@ "# 3. Time-to-event variables to predict\n", "# Keys = event_category values in the events DataFrame\n", "# Values = human-readable name used in the prompt\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\",\n", "}\n", diff --git a/examples/data_preprocessing/raw_data_preprocessing.ipynb b/examples/data_preprocessing/raw_data_preprocessing.ipynb index 58bf5d1..ea8fab0 100644 --- a/examples/data_preprocessing/raw_data_preprocessing.ipynb +++ b/examples/data_preprocessing/raw_data_preprocessing.ipynb @@ -1007,7 +1007,7 @@ "\n", "# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n", "# Only needs to be set if you want to do time to event prediction\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", " \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n", "}\n", diff --git a/examples/hackathon/01_data_preparation_challenge.ipynb b/examples/hackathon/01_data_preparation_challenge.ipynb index 3cfd551..2461d78 100644 --- a/examples/hackathon/01_data_preparation_challenge.ipynb +++ b/examples/hackathon/01_data_preparation_challenge.ipynb @@ -180,7 +180,7 @@ "# TODO: Configure time-to-event prediction targets\n", "# HINT: This should be a dictionary mapping event names to display names\n", "# Example: {\"original_name\": \"display name in prompt\"}\n", - "config.data_splitter_events_variables_category_mapping = None # Replace with dict" + "config.event_category_events_prediction_with_naming = None # Replace with dict" ] }, { @@ -218,19 +218,19 @@ " else:\n", " print(f\"✅ event_category_forecast: {config.event_category_forecast}\")\n", "\n", - " if config.data_splitter_events_variables_category_mapping is None:\n", - " errors.append(\"❌ data_splitter_events_variables_category_mapping is not set\")\n", - " elif not isinstance(config.data_splitter_events_variables_category_mapping, dict):\n", - " errors.append(\"❌ data_splitter_events_variables_category_mapping should be a dict\")\n", + " if config.event_category_events_prediction_with_naming is None:\n", + " errors.append(\"❌ event_category_events_prediction_with_naming is not set\")\n", + " elif not isinstance(config.event_category_events_prediction_with_naming, dict):\n", + " errors.append(\"❌ event_category_events_prediction_with_naming should be a dict\")\n", " elif any(\n", " [\n", " cat not in df_events[\"event_category\"].unique()\n", - " for cat in config.data_splitter_events_variables_category_mapping.keys()\n", + " for cat in config.event_category_events_prediction_with_naming.keys()\n", " ]\n", " ):\n", - " errors.append(\"❌ At least one key in data_splitter_events_variables_category_mapping not found in data\")\n", + " errors.append(\"❌ At least one key in event_category_events_prediction_with_naming not found in data\")\n", " else:\n", - " print(f\"✅ Event mapping: {config.data_splitter_events_variables_category_mapping}\")\n", + " print(f\"✅ Event mapping: {config.event_category_events_prediction_with_naming}\")\n", "\n", " if errors:\n", " print(\"\\n\" + \"\\n\".join(errors))\n", diff --git a/examples/hackathon/02_llm_finetuning_challenge.ipynb b/examples/hackathon/02_llm_finetuning_challenge.ipynb index 28f1df1..9405caf 100644 --- a/examples/hackathon/02_llm_finetuning_challenge.ipynb +++ b/examples/hackathon/02_llm_finetuning_challenge.ipynb @@ -150,7 +150,7 @@ "# Set up:\n", "# - split_event_category\n", "# - event_category_forecast\n", - "# - data_splitter_events_variables_category_mapping\n", + "# - event_category_events_prediction_with_naming\n", "# - constant_columns_to_use\n", "# - constant_birthdate_column\n", "\n", diff --git a/examples/integrations/meds_data_import.ipynb b/examples/integrations/meds_data_import.ipynb index b75b3bc..5467c57 100644 --- a/examples/integrations/meds_data_import.ipynb +++ b/examples/integrations/meds_data_import.ipynb @@ -477,7 +477,7 @@ "config.constant_birthdate_column = None # Not using in demo\n", "config.event_value_lot_start = None\n", "config.split_event_category = \"lot\"\n", - "config.data_splitter_events_variables_category_mapping = {\n", + "config.event_category_events_prediction_with_naming = {\n", " \"death\": \"death\",\n", "}" ] diff --git a/tests/test_common.py b/tests/test_common.py index 93d75e6..0c37b3d 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -44,7 +44,7 @@ def test_data_manager_processing(mock_config, sample_data): # Override config mock_config.split_event_category = "lot" mock_config.event_category_forecast = ["lab"] - mock_config.data_splitter_events_variables_category_mapping = None + mock_config.event_category_events_prediction_with_naming = None mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] dm = DataManager(config=mock_config) diff --git a/tests/test_converter.py b/tests/test_converter.py index 32f1503..4182775 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -13,7 +13,7 @@ def setup_components(mock_config, sample_data): df_events, df_constant, df_constant_desc = sample_data mock_config.split_event_category = "lot" mock_config.event_category_forecast = ["lab"] - mock_config.data_splitter_events_variables_category_mapping = {"death": "death", "progression": "next progression"} + mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"} mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] mock_config.constant_birthdate_column = "birthyear" @@ -109,7 +109,7 @@ def test_event_categories_to_exclude_from_input(mock_config, sample_data): # Configure with drug events excluded mock_config.split_event_category = "lot" mock_config.event_category_forecast = ["lab"] - mock_config.data_splitter_events_variables_category_mapping = {"death": "death", "progression": "next progression"} + mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"} mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] mock_config.constant_birthdate_column = "birthyear" mock_config.event_categories_to_exclude_from_input = ["drug"] @@ -174,7 +174,7 @@ def test_event_categories_to_exclude_multiple(mock_config, sample_data): mock_config.split_event_category = "lot" mock_config.event_category_forecast = ["lab"] - mock_config.data_splitter_events_variables_category_mapping = {"death": "death", "progression": "next progression"} + mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"} mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] mock_config.constant_birthdate_column = "birthyear" mock_config.event_categories_to_exclude_from_input = ["drug", "ecog"] diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 7f1e42f..d7a1913 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -13,7 +13,7 @@ def initialized_dm(mock_config, sample_data): df_events, df_constant, df_constant_desc = sample_data mock_config.split_event_category = "lot" mock_config.event_category_forecast = ["lab"] - mock_config.data_splitter_events_variables_category_mapping = {"death": "death", "progression": "next progression"} + mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"} mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] dm = DataManager(config=mock_config) diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index e871931..9955710 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -223,7 +223,7 @@ class Config: Default: None. constant_birthdate_column_format : str Format of the birthdate column, either "date" or "age". Default: "date". - data_splitter_events_variables_category_mapping : dict | None + event_category_events_prediction_with_naming : dict | None Mapping defining which event categories correspond to specific prediction types in DataSplitterEvents. Keys are event categories (e.g., 'death', 'progression'), values are descriptive names for the target variable. Default: None. @@ -245,7 +245,7 @@ def __init__(self): # different event types as well as how they should be written down (since based on categories), # for example, based on GDT: { "death": "death", "progression": "next progression", "lot": # "next line of therapy", "metastasis": "next metastasis"} - self.data_splitter_events_variables_category_mapping = None + self.event_category_events_prediction_with_naming = None # --- Import data parameters --- self.date_cutoff = None # If set, only use data before this date (format: "YYYY-MM-DD"), censored after diff --git a/twinweaver/instruction/data_splitter_events.py b/twinweaver/instruction/data_splitter_events.py index 46a27a9..fce1823 100644 --- a/twinweaver/instruction/data_splitter_events.py +++ b/twinweaver/instruction/data_splitter_events.py @@ -123,12 +123,12 @@ def __init__( self.min_length_to_sample = min_length_to_sample self.unit_length_to_sample = unit_length_to_sample - assert self.config.data_splitter_events_variables_category_mapping is not None, ( - "data_splitter_events_variables_category_mapping must be set in Config for DataSplitterEvents." + assert self.config.event_category_events_prediction_with_naming is not None, ( + "event_category_events_prediction_with_naming must be set in Config for DataSplitterEvents." "For example: { 'death': 'death', 'progression': 'next progression'}" ) - self.manual_variables_category_mapping = self.config.data_splitter_events_variables_category_mapping + self.manual_variables_category_mapping = self.config.event_category_events_prediction_with_naming def setup_variables(self): """ @@ -149,7 +149,7 @@ def setup_variables(self): if len(self.manual_variables_category_mapping) == 0: raise ValueError( "No valid event categories found in the data for event prediction splitting. " - "Check the data or adjust data_splitter_events_variables_category_mapping in Config." + "Check the data or adjust event_category_events_prediction_with_naming in Config." ) def _sample_manual_variables(self, events_after_split: pd.DataFrame, override_category: str) -> tuple: From 13e0e1515d0501ee81e2fa70c909a20801914c3f Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 13:55:54 +0000 Subject: [PATCH 15/36] Renamed DM splitting to setup_hold_out_sets --- README.md | 2 +- docs/api-index.md | 2 +- docs/data-splitting.md | 2 +- .../01_data_preparation_for_training.ipynb | 2 +- .../02_inference_prompt_preparation.ipynb | 2 +- .../03_end_to_end_llm_finetuning.ipynb | 2 +- .../custom_output/custom_summarized_row.ipynb | 2 +- .../customizing_text_generation.ipynb | 4 +- .../inference_individual_splitters.py | 2 +- .../training_custom_split_events.ipynb | 2 +- .../training_forecasting_splitter_only.ipynb | 2 +- .../training_individual_splitters.ipynb | 2 +- ...nd_to_end_llm_training_with_pretrain.ipynb | 2 +- .../pretraining/prepare_pretraining_data.py | 2 +- .../tte_probability_inference.ipynb | 2 +- .../raw_data_preprocessing.ipynb | 2 +- .../integrations/meds_data_import.ipynb | 2 +- docs/quickstart.md | 2 +- .../01_data_preparation_for_training.ipynb | 2 +- .../02_inference_prompt_preparation.ipynb | 2 +- examples/03_end_to_end_llm_finetuning.ipynb | 2 +- .../custom_output/custom_summarized_row.ipynb | 2 +- .../customizing_text_generation.ipynb | 4 +- .../inference_individual_splitters.py | 2 +- .../training_custom_split_events.ipynb | 2 +- .../training_forecasting_splitter_only.ipynb | 2 +- .../training_individual_splitters.ipynb | 2 +- ...nd_to_end_llm_training_with_pretrain.ipynb | 2 +- .../pretraining/prepare_pretraining_data.py | 2 +- .../tte_probability_inference.ipynb | 2 +- .../raw_data_preprocessing.ipynb | 2 +- examples/integrations/meds_data_import.ipynb | 2 +- tests/test_common.py | 2 +- tests/test_converter.py | 6 +-- tests/test_converter_pretrain.py | 2 +- tests/test_splitter.py | 2 +- twinweaver/common/data_manager.py | 54 +++++++++---------- 37 files changed, 66 insertions(+), 68 deletions(-) diff --git a/README.md b/README.md index 59e77c0..2122905 100644 --- a/README.md +++ b/README.md @@ -143,7 +143,7 @@ dm = DataManager(config=config) dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description) dm.process_indication_data() dm.setup_unique_mapping_of_events() -dm.setup_dataset_splits() +dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) dm.infer_var_types() # This data splitter handles event prediction tasks diff --git a/docs/api-index.md b/docs/api-index.md index d071451..d184bde 100644 --- a/docs/api-index.md +++ b/docs/api-index.md @@ -56,7 +56,7 @@ Handles data loading and management. | [`DataManager.load_indication_data`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.load_indication_data) | Method | Load data tables for a specific indication | | [`DataManager.process_indication_data`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.process_indication_data) | Method | Process loaded indication data | | [`DataManager.setup_unique_mapping_of_events`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.setup_unique_mapping_of_events) | Method | Create unique mapping for all events | -| [`DataManager.setup_dataset_splits`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.setup_dataset_splits) | Method | Split data into train/val/test sets | +| [`DataManager.setup_hold_out_sets(validation_split=0.1, test_split=0.1)`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.setup_hold_out_sets(validation_split=0.1, test_split=0.1)) | Method | Split data into train/val/test sets | | [`DataManager.get_all_patientids_in_split`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.get_all_patientids_in_split) | Method | Get all patient IDs in a specific split | | [`DataManager.get_patient_split`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.get_patient_split) | Method | Get the split assignment for a patient | | [`DataManager.get_patient_data`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.get_patient_data) | Method | Retrieve all data for a specific patient | diff --git a/docs/data-splitting.md b/docs/data-splitting.md index 621fce3..4ae264e 100644 --- a/docs/data-splitting.md +++ b/docs/data-splitting.md @@ -293,7 +293,7 @@ dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description) dm.process_indication_data() dm.setup_unique_mapping_of_events() -dm.setup_dataset_splits() +dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) dm.infer_var_types() # 3. Initialize splitters diff --git a/docs/examples/01_data_preparation_for_training.ipynb b/docs/examples/01_data_preparation_for_training.ipynb index 1b242d8..24be08f 100644 --- a/docs/examples/01_data_preparation_for_training.ipynb +++ b/docs/examples/01_data_preparation_for_training.ipynb @@ -175,7 +175,7 @@ "# Setup unique mapping of events, to understand which events correspond to which categories\n", "dm.setup_unique_mapping_of_events()\n", "# (Optional) assign each patient to train/validation/test splits\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "# (Optional - needed for forecasting) infer variable types\n", "dm.infer_var_types()" ] diff --git a/docs/examples/02_inference_prompt_preparation.ipynb b/docs/examples/02_inference_prompt_preparation.ipynb index 4facb67..9075f16 100644 --- a/docs/examples/02_inference_prompt_preparation.ipynb +++ b/docs/examples/02_inference_prompt_preparation.ipynb @@ -88,7 +88,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "data_splitter_events = DataSplitterEvents(\n", diff --git a/docs/examples/03_end_to_end_llm_finetuning.ipynb b/docs/examples/03_end_to_end_llm_finetuning.ipynb index 8485533..33020d8 100644 --- a/docs/examples/03_end_to_end_llm_finetuning.ipynb +++ b/docs/examples/03_end_to_end_llm_finetuning.ipynb @@ -144,7 +144,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "data_splitter_events = DataSplitterEvents(\n", diff --git a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb index 43a3f1c..df47c0a 100644 --- a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -122,7 +122,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()" ] }, diff --git a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb index aeba1ee..1e552d1 100644 --- a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb +++ b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb @@ -109,7 +109,7 @@ ")\n", "dm_default.process_indication_data()\n", "dm_default.setup_unique_mapping_of_events()\n", - "dm_default.setup_dataset_splits()\n", + "dm_default.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm_default.infer_var_types()" ] }, @@ -620,7 +620,7 @@ ")\n", "dm_custom.process_indication_data()\n", "dm_custom.setup_unique_mapping_of_events()\n", - "dm_custom.setup_dataset_splits()\n", + "dm_custom.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm_custom.infer_var_types()" ] }, diff --git a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py index 883d7a6..8b773db 100644 --- a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py +++ b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py @@ -42,7 +42,7 @@ def __init__( ) self.dm.process_indication_data() self.dm.setup_unique_mapping_of_events() - self.dm.setup_dataset_splits() + self.dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) self.dm.infer_var_types() self.data_splitter_events = DataSplitterEvents( diff --git a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb index abd7ab3..23b50cd 100644 --- a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -118,7 +118,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()" ] }, diff --git a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb index e4a5ae5..7c07e0b 100644 --- a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb +++ b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb @@ -105,7 +105,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "\n", diff --git a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb index afe4d6f..71969f4 100644 --- a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -111,7 +111,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "data_splitter_events = DataSplitterEvents(\n", diff --git a/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb b/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb index 0ac8652..3e60fa5 100644 --- a/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb +++ b/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb @@ -88,7 +88,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "converter = ConverterPretrain(config=config, dm=dm)" diff --git a/docs/examples/advanced/pretraining/prepare_pretraining_data.py b/docs/examples/advanced/pretraining/prepare_pretraining_data.py index 05c5d3b..80a5402 100644 --- a/docs/examples/advanced/pretraining/prepare_pretraining_data.py +++ b/docs/examples/advanced/pretraining/prepare_pretraining_data.py @@ -28,7 +28,7 @@ def __init__(self): ) self.dm.process_indication_data() self.dm.setup_unique_mapping_of_events() - self.dm.setup_dataset_splits() + self.dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) #: set up converter self.converter = ConverterPretrain(config=self.config, dm=self.dm) diff --git a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb index 056c219..9b6a861 100644 --- a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb +++ b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb @@ -188,7 +188,7 @@ ")\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "data_splitter_events = DataSplitterEvents(\n", diff --git a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb index ea8fab0..836c465 100644 --- a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb +++ b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb @@ -1036,7 +1036,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "print(f\"Loaded {len(dm.all_patientids)} patients into DataManager\")" diff --git a/docs/examples/integrations/meds_data_import.ipynb b/docs/examples/integrations/meds_data_import.ipynb index ed7af92..c593888 100644 --- a/docs/examples/integrations/meds_data_import.ipynb +++ b/docs/examples/integrations/meds_data_import.ipynb @@ -499,7 +499,7 @@ ")\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "\n", "data_splitter_events = DataSplitterEvents(\n", " dm,\n", diff --git a/docs/quickstart.md b/docs/quickstart.md index 0a9012b..207b919 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -48,7 +48,7 @@ dm.load_indication_data( ) dm.process_indication_data() dm.setup_unique_mapping_of_events() -dm.setup_dataset_splits() +dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) dm.infer_var_types() # Set up data splitters for different task types diff --git a/examples/01_data_preparation_for_training.ipynb b/examples/01_data_preparation_for_training.ipynb index 1b242d8..24be08f 100644 --- a/examples/01_data_preparation_for_training.ipynb +++ b/examples/01_data_preparation_for_training.ipynb @@ -175,7 +175,7 @@ "# Setup unique mapping of events, to understand which events correspond to which categories\n", "dm.setup_unique_mapping_of_events()\n", "# (Optional) assign each patient to train/validation/test splits\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "# (Optional - needed for forecasting) infer variable types\n", "dm.infer_var_types()" ] diff --git a/examples/02_inference_prompt_preparation.ipynb b/examples/02_inference_prompt_preparation.ipynb index 4facb67..9075f16 100644 --- a/examples/02_inference_prompt_preparation.ipynb +++ b/examples/02_inference_prompt_preparation.ipynb @@ -88,7 +88,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "data_splitter_events = DataSplitterEvents(\n", diff --git a/examples/03_end_to_end_llm_finetuning.ipynb b/examples/03_end_to_end_llm_finetuning.ipynb index 8485533..33020d8 100644 --- a/examples/03_end_to_end_llm_finetuning.ipynb +++ b/examples/03_end_to_end_llm_finetuning.ipynb @@ -144,7 +144,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "data_splitter_events = DataSplitterEvents(\n", diff --git a/examples/advanced/custom_output/custom_summarized_row.ipynb b/examples/advanced/custom_output/custom_summarized_row.ipynb index 43a3f1c..df47c0a 100644 --- a/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -122,7 +122,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()" ] }, diff --git a/examples/advanced/custom_output/customizing_text_generation.ipynb b/examples/advanced/custom_output/customizing_text_generation.ipynb index aeba1ee..1e552d1 100644 --- a/examples/advanced/custom_output/customizing_text_generation.ipynb +++ b/examples/advanced/custom_output/customizing_text_generation.ipynb @@ -109,7 +109,7 @@ ")\n", "dm_default.process_indication_data()\n", "dm_default.setup_unique_mapping_of_events()\n", - "dm_default.setup_dataset_splits()\n", + "dm_default.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm_default.infer_var_types()" ] }, @@ -620,7 +620,7 @@ ")\n", "dm_custom.process_indication_data()\n", "dm_custom.setup_unique_mapping_of_events()\n", - "dm_custom.setup_dataset_splits()\n", + "dm_custom.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm_custom.infer_var_types()" ] }, diff --git a/examples/advanced/custom_splitting/inference_individual_splitters.py b/examples/advanced/custom_splitting/inference_individual_splitters.py index 0adacaf..3840087 100644 --- a/examples/advanced/custom_splitting/inference_individual_splitters.py +++ b/examples/advanced/custom_splitting/inference_individual_splitters.py @@ -43,7 +43,7 @@ def __init__( ) self.dm.process_indication_data() self.dm.setup_unique_mapping_of_events() - self.dm.setup_dataset_splits() + self.dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) self.dm.infer_var_types() data_splitter_events = DataSplitterEvents( diff --git a/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/examples/advanced/custom_splitting/training_custom_split_events.ipynb index b0f6673..387bb91 100644 --- a/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -118,7 +118,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()" ] }, diff --git a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb index fd6a63e..91af9bc 100644 --- a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb +++ b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb @@ -106,7 +106,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "\n", diff --git a/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 63b9674..70f2f01 100644 --- a/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -112,7 +112,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "data_splitter_events = DataSplitterEvents(\n", diff --git a/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb b/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb index 0ac8652..3e60fa5 100644 --- a/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb +++ b/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb @@ -88,7 +88,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "converter = ConverterPretrain(config=config, dm=dm)" diff --git a/examples/advanced/pretraining/prepare_pretraining_data.py b/examples/advanced/pretraining/prepare_pretraining_data.py index 05c5d3b..80a5402 100644 --- a/examples/advanced/pretraining/prepare_pretraining_data.py +++ b/examples/advanced/pretraining/prepare_pretraining_data.py @@ -28,7 +28,7 @@ def __init__(self): ) self.dm.process_indication_data() self.dm.setup_unique_mapping_of_events() - self.dm.setup_dataset_splits() + self.dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) #: set up converter self.converter = ConverterPretrain(config=self.config, dm=self.dm) diff --git a/examples/advanced/tte_inference/tte_probability_inference.ipynb b/examples/advanced/tte_inference/tte_probability_inference.ipynb index 056c219..9b6a861 100644 --- a/examples/advanced/tte_inference/tte_probability_inference.ipynb +++ b/examples/advanced/tte_inference/tte_probability_inference.ipynb @@ -188,7 +188,7 @@ ")\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "data_splitter_events = DataSplitterEvents(\n", diff --git a/examples/data_preprocessing/raw_data_preprocessing.ipynb b/examples/data_preprocessing/raw_data_preprocessing.ipynb index ea8fab0..836c465 100644 --- a/examples/data_preprocessing/raw_data_preprocessing.ipynb +++ b/examples/data_preprocessing/raw_data_preprocessing.ipynb @@ -1036,7 +1036,7 @@ "dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "dm.infer_var_types()\n", "\n", "print(f\"Loaded {len(dm.all_patientids)} patients into DataManager\")" diff --git a/examples/integrations/meds_data_import.ipynb b/examples/integrations/meds_data_import.ipynb index 5467c57..5183219 100644 --- a/examples/integrations/meds_data_import.ipynb +++ b/examples/integrations/meds_data_import.ipynb @@ -499,7 +499,7 @@ ")\n", "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", - "dm.setup_dataset_splits()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", "\n", "data_splitter_events = DataSplitterEvents(\n", " dm,\n", diff --git a/tests/test_common.py b/tests/test_common.py index 0c37b3d..cf3d60c 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -53,7 +53,7 @@ def test_data_manager_processing(mock_config, sample_data): # Run pipeline dm.process_indication_data() dm.setup_unique_mapping_of_events() - dm.setup_dataset_splits() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) dm.infer_var_types() # 1. Check Date Processing diff --git a/tests/test_converter.py b/tests/test_converter.py index 4182775..6633bbe 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -21,7 +21,7 @@ def setup_components(mock_config, sample_data): dm.load_indication_data(df_events, df_constant, df_constant_desc) dm.process_indication_data() dm.setup_unique_mapping_of_events() - dm.setup_dataset_splits() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) dm.infer_var_types() splitter_events = DataSplitterEvents( @@ -118,7 +118,7 @@ def test_event_categories_to_exclude_from_input(mock_config, sample_data): dm.load_indication_data(df_events, df_constant, df_constant_desc) dm.process_indication_data() dm.setup_unique_mapping_of_events() - dm.setup_dataset_splits() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) dm.infer_var_types() splitter_events = DataSplitterEvents( @@ -183,7 +183,7 @@ def test_event_categories_to_exclude_multiple(mock_config, sample_data): dm.load_indication_data(df_events, df_constant, df_constant_desc) dm.process_indication_data() dm.setup_unique_mapping_of_events() - dm.setup_dataset_splits() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) dm.infer_var_types() splitter_events = DataSplitterEvents( diff --git a/tests/test_converter_pretrain.py b/tests/test_converter_pretrain.py index 9c182d5..56804ba 100644 --- a/tests/test_converter_pretrain.py +++ b/tests/test_converter_pretrain.py @@ -15,7 +15,7 @@ def setup_pretrain_components(mock_config, sample_data): dm.load_indication_data(df_events, df_constant, df_constant_desc) dm.process_indication_data() dm.setup_unique_mapping_of_events() - dm.setup_dataset_splits() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) converter = ConverterPretrain(config=mock_config, dm=dm) diff --git a/tests/test_splitter.py b/tests/test_splitter.py index d7a1913..7a6b60e 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -20,7 +20,7 @@ def initialized_dm(mock_config, sample_data): dm.load_indication_data(df_events, df_constant, df_constant_desc) dm.process_indication_data() dm.setup_unique_mapping_of_events() - dm.setup_dataset_splits() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) dm.infer_var_types() return dm diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py index d1bca65..30129fa 100644 --- a/twinweaver/common/data_manager.py +++ b/twinweaver/common/data_manager.py @@ -22,10 +22,6 @@ class DataManager: def __init__( self, config: Config, # Added config parameter - train_split_min: float = 0.8, - validation_split_max: float = 0.1, - test_split_max: float = 0.1, - max_val_test_nr_patients: int = 500, replace_special_symbols: list = None, ) -> None: """ @@ -39,21 +35,6 @@ def __init__( config : Config A configuration object containing paths, column names, category names, and other constants used throughout the data management process. - train_split_min : float, optional - The minimum proportion of patients to allocate to the training set. - Defaults to 0.8. The actual number will be the remainder after - allocating validation and test sets. - validation_split_max : float, optional - The maximum proportion of the total patients to allocate to the - validation set. The actual number is capped by - `max_val_test_nr_patients`. Defaults to 0.1. - test_split_max : float, optional - The maximum proportion of the total patients to allocate to the - test set. The actual number is capped by `max_val_test_nr_patients`. - Defaults to 0.1. - max_val_test_nr_patients : int, optional - The absolute maximum number of patients to include in the validation - and test sets combined. Defaults to 500. replace_special_symbols : list, optional A list of tuples to override the default special character replacements in event descriptive names. Each tuple should be in the format @@ -63,10 +44,6 @@ def __init__( #: initialize the data manager self.config = config # Store config object - self.train_split = train_split_min - self.validation_split = validation_split_max - self.test_split = test_split_max - self.max_val_test_nr_patients = max_val_test_nr_patients self.variable_types = {} # event_name -> "numeric" / "categorical" # Setup replacing of special symbol, format is event_category : (, ) @@ -333,8 +310,11 @@ def setup_unique_mapping_of_events(self) -> None: # Use config constant assert len(self.unique_events) == len(self.data_frames[events_table_key][event_desc_name_col].unique()) - def setup_dataset_splits( + def setup_hold_out_sets( self, + validation_split: float, + test_split: float, + max_val_test_nr_patients: int = None, ) -> None: """ Assigns each patient to a data split (train, validation, or test). @@ -353,6 +333,20 @@ def setup_dataset_splits( `self.all_patientids`. Asserts are performed to ensure the mapping covers all patients without overlap and that the split sizes match calculations. + Parameters + ---------- + validation_split : float + The proportion of the total patients to allocate to the + validation set. The actual number is capped by + `max_val_test_nr_patients` if provided. + test_split : float + The proportion of the total patients to allocate to the + test set. The actual number is capped by + `max_val_test_nr_patients` if provided. + max_val_test_nr_patients : int, optional + The absolute maximum number of patients to include in the validation + and test sets individually. Defaults to None. + Raises ------ ValueError @@ -383,13 +377,17 @@ def setup_dataset_splits( #: get min(self.validation_split * num_patients, self.max_val_test_nr_patients) validation_nr_patients = min( - int(self.validation_split * len(all_patients)), - self.max_val_test_nr_patients, + int(validation_split * len(all_patients)), + max_val_test_nr_patients + if max_val_test_nr_patients is not None + else int(validation_split * len(all_patients)), ) #: then the same for test - test_nr_patients = min(int(self.test_split * len(all_patients)), self.max_val_test_nr_patients) - + test_nr_patients = min( + int(test_split * len(all_patients)), + max_val_test_nr_patients if max_val_test_nr_patients is not None else int(test_split * len(all_patients)), + ) #: randomly shuffle with seed and split into train/val/test, using df.sample np.random.seed(self.config.seed) all_patients = np.random.permutation(all_patients) From fc5e6c4c6e9fa23b5622e5f872d1d9ef204291d9 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 13:57:35 +0000 Subject: [PATCH 16/36] DM automatically converts event category, name and descriptive name to string --- twinweaver/common/data_manager.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py index 30129fa..7840cd0 100644 --- a/twinweaver/common/data_manager.py +++ b/twinweaver/common/data_manager.py @@ -207,6 +207,17 @@ def handle_missing_dates(df_key, missing_count, total_count, col_date): self.config.event_value_col ].astype(str) + # Convert event_descriptive_name, event_name, event_category to string as well to avoid issues later on + self.data_frames[events_table_key][self.config.event_descriptive_name_col] = self.data_frames[events_table_key][ + self.config.event_descriptive_name_col + ].astype(str) + self.data_frames[events_table_key][self.config.event_name_col] = self.data_frames[events_table_key][ + self.config.event_name_col + ].astype(str) + self.data_frames[events_table_key][self.config.event_category_col] = self.data_frames[events_table_key][ + self.config.event_category_col + ].astype(str) + logging.info("Data processed") def setup_unique_mapping_of_events(self) -> None: From 28bb03507fbf7401d3ae1a33bb4f7442f53dd743 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 13:59:09 +0000 Subject: [PATCH 17/36] Fixed DataSplitterForecastingOption.events_until_split type in docstring --- twinweaver/instruction/data_splitter_forecasting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/twinweaver/instruction/data_splitter_forecasting.py b/twinweaver/instruction/data_splitter_forecasting.py index 998ef07..95590e5 100644 --- a/twinweaver/instruction/data_splitter_forecasting.py +++ b/twinweaver/instruction/data_splitter_forecasting.py @@ -15,9 +15,9 @@ class DataSplitterForecastingOption: Attributes ---------- - events_until_split : list + events_until_split : pd.DataFrame Events occurring until the split point. - target_events_after_split : list + target_events_after_split : pd.DataFrame Target events occurring after the split point. constant_data : dict Constant data related to the patient or context. From c37751442d80e4c42403677bcaea395957068168 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 14:21:57 +0000 Subject: [PATCH 18/36] Now made forecasting qa an explicit decision --- README.md | 2 +- docs/data-splitting.md | 3 - .../01_data_preparation_for_training.ipynb | 1 - .../03_end_to_end_llm_finetuning.ipynb | 1 - .../custom_output/custom_summarized_row.ipynb | 4 - .../customizing_text_generation.ipynb | 2 - .../training_custom_split_events.ipynb | 1 - .../training_forecasting_qa.ipynb | 420 +++++++++++++++++ .../training_forecasting_splitter_only.ipynb | 1 - .../training_individual_splitters.ipynb | 1 - .../raw_data_preprocessing.ipynb | 1 - .../01_data_preparation_challenge.ipynb | 1 - docs/quickstart.md | 1 - docs/tutorials.md | 1 + .../01_data_preparation_for_training.ipynb | 1 - examples/03_end_to_end_llm_finetuning.ipynb | 1 - .../custom_output/custom_summarized_row.ipynb | 4 - .../customizing_text_generation.ipynb | 2 - .../training_custom_split_events.ipynb | 1 - .../training_forecasting_qa.ipynb | 425 ++++++++++++++++++ .../training_forecasting_splitter_only.ipynb | 1 - .../training_individual_splitters.ipynb | 1 - .../raw_data_preprocessing.ipynb | 1 - .../01_data_preparation_challenge.ipynb | 1 - .../instruction/converter_instruction.py | 4 +- 25 files changed, 849 insertions(+), 33 deletions(-) create mode 100644 docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb create mode 100644 examples/advanced/custom_splitting/training_forecasting_qa.ipynb diff --git a/README.md b/README.md index 2122905..fdea890 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ For users needing custom behavior or specific integrations: * [`examples/advanced/custom_splitting/training_individual_splitters.ipynb`](examples/advanced/custom_splitting/training_individual_splitters.ipynb): Notebook demonstrating training data generation with individual splitters. * [`examples/advanced/custom_splitting/training_custom_split_events.ipynb`](examples/advanced/custom_splitting/training_custom_split_events.ipynb): Notebook showing how to customize split events and forecast different event categories. * [`examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb`](examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb): Forecasting-only example showing training data generation using only the `DataSplitterForecasting` (no event splitter). + * [`examples/advanced/custom_splitting/training_forecasting_qa.ipynb`](examples/advanced/custom_splitting/training_forecasting_qa.ipynb): Demonstrates the **Forecasting QA** mode, which bins continuous target values into discrete categories for classification-style prediction, and compares all three forecasting modes (`"forecasting"`, `"forecasting_qa"`, `"both"`). * **Custom Text Generation**: [`examples/advanced/custom_output/customizing_text_generation.ipynb`](examples/advanced/custom_output/customizing_text_generation.ipynb) * A comprehensive tutorial on customizing every textual component of the instruction generation pipeline. Learn how to modify preambles, event formatting, time units, genetic data tags, forecasting prompts, and more to adapt outputs for different LLMs, languages, or institutional requirements. * **Custom Summarized Row**: [`examples/advanced/custom_output/custom_summarized_row.ipynb`](examples/advanced/custom_output/custom_summarized_row.ipynb) @@ -175,7 +176,6 @@ split_idx = 0 training_data = converter.forward_conversion( forecasting_splits=forecasting_splits[split_idx], event_splits=events_splits[split_idx], - override_mode_to_select_forecasting="both", ) # training_data now contains (Input, Target) pairs ready for LLM fine-tuning diff --git a/docs/data-splitting.md b/docs/data-splitting.md index 4ae264e..526a07c 100644 --- a/docs/data-splitting.md +++ b/docs/data-splitting.md @@ -212,7 +212,6 @@ forecasting_splits, events_splits, reference_dates = \ converter.forward_conversion( forecasting_splits=forecasting_splits[0], event_splits=None, # No event splits available - override_mode_to_select_forecasting="forecasting", ) ``` @@ -229,7 +228,6 @@ forecasting_splits, events_splits, reference_dates = \ converter.forward_conversion( forecasting_splits=None, # No forecasting splits available event_splits=events_splits[0], - override_mode_to_select_forecasting="both", ) ``` @@ -322,7 +320,6 @@ converter = ConverterInstruction( result = converter.forward_conversion( forecasting_splits=forecasting_splits[0], event_splits=events_splits[0], - override_mode_to_select_forecasting="both", ) print(result["instruction"][:500]) diff --git a/docs/examples/01_data_preparation_for_training.ipynb b/docs/examples/01_data_preparation_for_training.ipynb index 24be08f..e74876d 100644 --- a/docs/examples/01_data_preparation_for_training.ipynb +++ b/docs/examples/01_data_preparation_for_training.ipynb @@ -345,7 +345,6 @@ "p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=events_splits[split_idx],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")" ] }, diff --git a/docs/examples/03_end_to_end_llm_finetuning.ipynb b/docs/examples/03_end_to_end_llm_finetuning.ipynb index 33020d8..3cfda1d 100644 --- a/docs/examples/03_end_to_end_llm_finetuning.ipynb +++ b/docs/examples/03_end_to_end_llm_finetuning.ipynb @@ -219,7 +219,6 @@ " p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=events_splits[split_idx],\n", - " override_mode_to_select_forecasting=\"both\",\n", " )\n", " new_data = {\n", " \"prompt\": p_converted[\"instruction\"],\n", diff --git a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb index df47c0a..e790228 100644 --- a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -209,7 +209,6 @@ "p_default = converter_default.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", @@ -314,7 +313,6 @@ "p_custom = converter_custom.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", @@ -417,7 +415,6 @@ "p_advanced = converter_advanced.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", @@ -490,7 +487,6 @@ " converter_broken.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", " )\n", "except TypeError as e:\n", " print(f\"Caught runtime error: {e}\")" diff --git a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb index 1e552d1..03ea854 100644 --- a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb +++ b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb @@ -164,7 +164,6 @@ "p_converted_default = converter_default.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", @@ -675,7 +674,6 @@ "p_converted_custom = converter_custom.forward_conversion(\n", " forecasting_splits=forecasting_splits_custom[0],\n", " event_splits=events_splits_custom[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", diff --git a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb index 23b50cd..8a0bc01 100644 --- a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -241,7 +241,6 @@ "p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=events_splits[split_idx],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")" ] }, diff --git a/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb b/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb new file mode 100644 index 0000000..c79a9a1 --- /dev/null +++ b/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb @@ -0,0 +1,420 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Forecasting QA: Question-Answering Style Value Prediction\n", + "\n", + "This notebook demonstrates the **Forecasting QA** task mode in TwinWeaver.\n", + "\n", + "While the default `forward_conversion` mode (`\"forecasting\"`) asks the model to predict exact future values,\n", + "the **Forecasting QA** mode bins continuous target values into discrete categories (e.g., `A`, `B`, `C`)\n", + "and asks the model to predict the correct bin — turning regression into classification.\n", + "\n", + "This can be useful when:\n", + "- Exact numeric predictions are noisy or unreliable\n", + "- You want the model to reason about value ranges instead of point estimates\n", + "- You want to combine both forecasting and QA tasks (mode `\"both\"`)\n", + "\n", + "### How it works\n", + "\n", + "The `ConverterForecastingQA` sub-converter:\n", + "1. Computes bin edges from the variable statistics (`setup_statistics()` on the forecasting splitter)\n", + "2. Maps each target value to a lettered bin (A, B, C, …)\n", + "3. Generates a prompt listing the bin definitions and asks the model to choose the right bin\n", + "4. The target answer uses the bin letter instead of the raw numeric value\n", + "\n", + "### Selecting the mode\n", + "\n", + "The `forward_conversion` method accepts an `override_mode_to_select_forecasting` parameter:\n", + "- `\"forecasting\"` (default): numeric value prediction\n", + "- `\"forecasting_qa\"`: bin-based QA prediction\n", + "- `\"both\"`: includes both a forecasting and a forecasting QA task\n", + "- `None`: randomly selects one of the above at each call" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Load libraries and example data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from twinweaver import (\n", + " DataSplitterForecasting,\n", + " DataSplitterEvents,\n", + " DataSplitter,\n", + " DataManager,\n", + " ConverterInstruction,\n", + " Config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "df_events = pd.read_csv(\"../../example_data/events.csv\")\n", + "df_constant = pd.read_csv(\"../../example_data/constant.csv\")\n", + "df_constant_description = pd.read_csv(\"../../example_data/constant_description.csv\")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Configuration\n", + "\n", + "Set up the config, data manager, splitters, and converter.\n", + "\n", + "> **Important:** `variable_stats` from `DataSplitterForecasting.setup_statistics()` must be\n", + "> passed to `ConverterInstruction` for the QA mode to work — it provides the bin edges." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "config = Config()\n", + "\n", + "config.split_event_category = \"lot\"\n", + "config.event_category_forecast = [\"lab\"]\n", + "config.event_category_events_prediction_with_naming = {\n", + " \"death\": \"death\",\n", + " \"progression\": \"next progression\",\n", + "}\n", + "config.constant_columns_to_use = [\"birthyear\", \"gender\", \"histology\", \"smoking_history\"]\n", + "config.constant_birthdate_column = \"birthyear\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "dm = DataManager(config=config)\n", + "dm.load_indication_data(\n", + " df_events=df_events,\n", + " df_constant=df_constant,\n", + " df_constant_description=df_constant_description,\n", + ")\n", + "dm.process_indication_data()\n", + "dm.setup_unique_mapping_of_events()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", + "dm.infer_var_types()\n", + "\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", + "data_splitter_events.setup_variables()\n", + "\n", + "data_splitter_forecasting = DataSplitterForecasting(\n", + " data_manager=dm,\n", + " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", + "# setup_statistics() computes per-variable quantile bin edges used by the QA mode\n", + "data_splitter_forecasting.setup_statistics()\n", + "\n", + "data_splitter = DataSplitter(\n", + " data_splitter_events=data_splitter_events,\n", + " data_splitter_forecasting=data_splitter_forecasting,\n", + ")\n", + "\n", + "converter = ConverterInstruction(\n", + " nr_tokens_budget_total=8192,\n", + " config=config,\n", + " dm=dm,\n", + " variable_stats=data_splitter_forecasting.variable_stats, # Required for QA mode\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## Generate Splits\n", + "\n", + "Get training splits for a patient." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "\n", + "patientid = dm.all_patientids[2]\n", + "patient_data = dm.get_patient_data(patientid)\n", + "\n", + "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n", + " patient_data,\n", + " forecasting_nr_samples_per_split=4,\n", + " forecasting_filter_outliers=False,\n", + " max_num_splits_per_split_event=2,\n", + " events_max_nr_samples_per_split=3,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Mode 1: Default Forecasting (numeric predictions)\n", + "\n", + "By default, `forward_conversion` uses `override_mode_to_select_forecasting=\"forecasting\"`,\n", + "which generates tasks that ask the model to predict exact numeric values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "split_idx = 0\n", + "\n", + "p_forecasting = converter.forward_conversion(\n", + " forecasting_splits=forecasting_splits[split_idx],\n", + " event_splits=events_splits[split_idx],\n", + " # override_mode_to_select_forecasting=\"forecasting\" # this is the default\n", + ")\n", + "\n", + "print(\"=\" * 80)\n", + "print(\"INSTRUCTION (last 1500 chars):\")\n", + "print(\"=\" * 80)\n", + "print(p_forecasting[\"instruction\"][-1500:])\n", + "print()\n", + "print(\"=\" * 80)\n", + "print(\"ANSWER:\")\n", + "print(\"=\" * 80)\n", + "print(p_forecasting[\"answer\"])" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "## Mode 2: Forecasting QA (bin-based predictions)\n", + "\n", + "Setting `override_mode_to_select_forecasting=\"forecasting_qa\"` activates the QA mode.\n", + "The prompt now lists bin definitions (e.g., `a = Bin (-inf, 5.2]`) and asks the model\n", + "to output the bin letter instead of a raw number." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "split_idx = 0\n", + "\n", + "p_qa = converter.forward_conversion(\n", + " forecasting_splits=forecasting_splits[split_idx],\n", + " event_splits=events_splits[split_idx],\n", + " override_mode_to_select_forecasting=\"forecasting_qa\",\n", + ")\n", + "\n", + "print(\"=\" * 80)\n", + "print(\"INSTRUCTION (last 2000 chars):\")\n", + "print(\"=\" * 80)\n", + "print(p_qa[\"instruction\"][-2000:])\n", + "print()\n", + "print(\"=\" * 80)\n", + "print(\"ANSWER:\")\n", + "print(\"=\" * 80)\n", + "print(p_qa[\"answer\"])" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "### Inspect the bin definitions\n", + "\n", + "The metadata returned by the conversion includes the bin mapping for each variable.\n", + "These are derived from quantiles computed by `setup_statistics()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# The target meta contains the category splits for each variable\n", + "for task_meta in p_qa[\"meta\"][\"target_meta_detailed\"]:\n", + " if \"category_splits\" in task_meta:\n", + " print(\"Variable bin definitions:\")\n", + " for var, bins in task_meta[\"category_splits\"].items():\n", + " print(f\" {var}:\")\n", + " for letter, bin_range in bins.items():\n", + " print(f\" {letter} = {bin_range}\")" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Mode 3: Both (forecasting + QA in one prompt)\n", + "\n", + "Setting `override_mode_to_select_forecasting=\"both\"` includes **both** a numeric forecasting\n", + "task and a QA-style bin prediction task in the same multi-task prompt. This encourages the\n", + "model to learn complementary representations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "split_idx = 0\n", + "\n", + "p_both = converter.forward_conversion(\n", + " forecasting_splits=forecasting_splits[split_idx],\n", + " event_splits=events_splits[split_idx],\n", + " override_mode_to_select_forecasting=\"both\",\n", + ")\n", + "\n", + "print(\"=\" * 80)\n", + "print(\"INSTRUCTION (last 3000 chars):\")\n", + "print(\"=\" * 80)\n", + "print(p_both[\"instruction\"][-3000:])\n", + "print()\n", + "print(\"=\" * 80)\n", + "print(\"ANSWER:\")\n", + "print(\"=\" * 80)\n", + "print(p_both[\"answer\"])" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "### Compare the task types generated\n", + "\n", + "Let's inspect which task types were included in each mode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "for label, result in [(\"forecasting\", p_forecasting), (\"forecasting_qa\", p_qa), (\"both\", p_both)]:\n", + " task_types = [m[\"task_type\"] for m in result[\"meta\"][\"target_meta_detailed\"]]\n", + " print(f\"Mode '{label}': {task_types}\")" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "## Reverse Conversion\n", + "\n", + "The reverse conversion works for all modes. It parses the model output back into\n", + "structured data, regardless of whether the target was numeric or bin-based." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "date = reference_dates[\"date\"][0]\n", + "\n", + "# Reverse convert the QA mode output\n", + "return_list = converter.reverse_conversion(p_qa[\"answer\"], dm, date)\n", + "\n", + "for task in return_list:\n", + " print(f\"Task type: {task['task_type']}\")\n", + " print(f\"Result: {task['result']}\")\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "## Variable Statistics\n", + "\n", + "The bin edges come from the `variable_stats` DataFrame computed during `setup_statistics()`.\n", + "You can inspect them directly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "data_splitter_forecasting.variable_stats" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb index 7c07e0b..6f6a1cb 100644 --- a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb +++ b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb @@ -239,7 +239,6 @@ "p_converted = converter.forward_conversion(\n", " forecasting_splits=processed_splits_fc[split_idx],\n", " event_splits=[], # Not needed for forecasting-only splitter\n", - " override_mode_to_select_forecasting=\"forecasting\",\n", ")" ] }, diff --git a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 71969f4..5b42911 100644 --- a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -258,7 +258,6 @@ "p_converted = converter.forward_conversion(\n", " forecasting_splits=processed_splits_fc[split_idx],\n", " event_splits=processed_splits_ev[split_idx],\n", - " override_mode_to_select_forecasting=\"forecasting_qa\",\n", ")" ] }, diff --git a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb index 836c465..d7eeccd 100644 --- a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb +++ b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb @@ -1120,7 +1120,6 @@ " p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=events_splits[split_idx],\n", - " override_mode_to_select_forecasting=\"both\",\n", " )\n", "\n", " print(\"=\" * 80)\n", diff --git a/docs/examples/hackathon/01_data_preparation_challenge.ipynb b/docs/examples/hackathon/01_data_preparation_challenge.ipynb index 2461d78..a4984ff 100644 --- a/docs/examples/hackathon/01_data_preparation_challenge.ipynb +++ b/docs/examples/hackathon/01_data_preparation_challenge.ipynb @@ -565,7 +565,6 @@ "# Parameters:\n", "# - forecasting_splits: the forecasting split for one time point\n", "# - event_splits: the event split for one time point\n", - "# - override_mode_to_select_forecasting: set to \"both\"\n", "\n", "split_idx = 0\n", "p_converted = None # Replace with actual conversion call" diff --git a/docs/quickstart.md b/docs/quickstart.md index 207b919..1eabc19 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -88,7 +88,6 @@ split_idx = 0 # Use first split training_data = converter.forward_conversion( forecasting_splits=forecasting_splits[split_idx], event_splits=events_splits[split_idx], - override_mode_to_select_forecasting="both", ) # training_data now contains (Input, Target) pairs ready for LLM fine-tuning diff --git a/docs/tutorials.md b/docs/tutorials.md index 30924b9..d3befe6 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -95,6 +95,7 @@ A complete notebook demonstrating how to train LLMs on full patient histories wi - **Inference**: [`examples/advanced/custom_splitting/inference_individual_splitters.py`](examples/advanced/custom_splitting/inference_individual_splitters.py) — Example script for inference using individual splitters. - **Training**: [`examples/advanced/custom_splitting/training_individual_splitters.ipynb`](examples/advanced/custom_splitting/training_individual_splitters.ipynb) — Notebook demonstrating training data generation with individual splitters. - **Custom Split Events**: [`examples/advanced/custom_splitting/training_custom_split_events.ipynb`](examples/advanced/custom_splitting/training_custom_split_events.ipynb) — Notebook showing how to customize split events and forecast different event categories (e.g., using genetic events as split points and forecasting vitals). +- **Forecasting QA**: [`examples/advanced/custom_splitting/training_forecasting_qa.ipynb`](examples/advanced/custom_splitting/training_forecasting_qa.ipynb) — Demonstrates the **Forecasting QA** mode, which bins continuous target values into discrete categories for classification-style prediction. Compares all three forecasting modes (`"forecasting"`, `"forecasting_qa"`, `"both"`). ### TTE Probability Inference diff --git a/examples/01_data_preparation_for_training.ipynb b/examples/01_data_preparation_for_training.ipynb index 24be08f..e74876d 100644 --- a/examples/01_data_preparation_for_training.ipynb +++ b/examples/01_data_preparation_for_training.ipynb @@ -345,7 +345,6 @@ "p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=events_splits[split_idx],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")" ] }, diff --git a/examples/03_end_to_end_llm_finetuning.ipynb b/examples/03_end_to_end_llm_finetuning.ipynb index 33020d8..3cfda1d 100644 --- a/examples/03_end_to_end_llm_finetuning.ipynb +++ b/examples/03_end_to_end_llm_finetuning.ipynb @@ -219,7 +219,6 @@ " p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=events_splits[split_idx],\n", - " override_mode_to_select_forecasting=\"both\",\n", " )\n", " new_data = {\n", " \"prompt\": p_converted[\"instruction\"],\n", diff --git a/examples/advanced/custom_output/custom_summarized_row.ipynb b/examples/advanced/custom_output/custom_summarized_row.ipynb index df47c0a..e790228 100644 --- a/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -209,7 +209,6 @@ "p_default = converter_default.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", @@ -314,7 +313,6 @@ "p_custom = converter_custom.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", @@ -417,7 +415,6 @@ "p_advanced = converter_advanced.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", @@ -490,7 +487,6 @@ " converter_broken.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", " )\n", "except TypeError as e:\n", " print(f\"Caught runtime error: {e}\")" diff --git a/examples/advanced/custom_output/customizing_text_generation.ipynb b/examples/advanced/custom_output/customizing_text_generation.ipynb index 1e552d1..03ea854 100644 --- a/examples/advanced/custom_output/customizing_text_generation.ipynb +++ b/examples/advanced/custom_output/customizing_text_generation.ipynb @@ -164,7 +164,6 @@ "p_converted_default = converter_default.forward_conversion(\n", " forecasting_splits=forecasting_splits[0],\n", " event_splits=events_splits[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", @@ -675,7 +674,6 @@ "p_converted_custom = converter_custom.forward_conversion(\n", " forecasting_splits=forecasting_splits_custom[0],\n", " event_splits=events_splits_custom[0],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")\n", "\n", "print(\"=\" * 80)\n", diff --git a/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/examples/advanced/custom_splitting/training_custom_split_events.ipynb index 387bb91..2d2c48a 100644 --- a/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -242,7 +242,6 @@ "p_converted = converter.forward_conversion(\n", " forecasting_splits=None, # Set to None since we don't want to generate forecasting tasks\n", " event_splits=events_splits[split_idx],\n", - " override_mode_to_select_forecasting=\"both\",\n", ")" ] }, diff --git a/examples/advanced/custom_splitting/training_forecasting_qa.ipynb b/examples/advanced/custom_splitting/training_forecasting_qa.ipynb new file mode 100644 index 0000000..4d0fd0c --- /dev/null +++ b/examples/advanced/custom_splitting/training_forecasting_qa.ipynb @@ -0,0 +1,425 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Forecasting QA: Question-Answering Style Value Prediction\n", + "\n", + "This notebook demonstrates the **Forecasting QA** task mode in TwinWeaver.\n", + "\n", + "While the default `forward_conversion` mode (`\"forecasting\"`) asks the model to predict exact future values,\n", + "the **Forecasting QA** mode bins continuous target values into discrete categories (e.g., `A`, `B`, `C`)\n", + "and asks the model to predict the correct bin — turning regression into classification.\n", + "\n", + "This can be useful when:\n", + "- Exact numeric predictions are noisy or unreliable\n", + "- You want the model to reason about value ranges instead of point estimates\n", + "- You want to combine both forecasting and QA tasks (mode `\"both\"`)\n", + "\n", + "### How it works\n", + "\n", + "The `ConverterForecastingQA` sub-converter:\n", + "1. Computes bin edges from the variable statistics (`setup_statistics()` on the forecasting splitter)\n", + "2. Maps each target value to a lettered bin (A, B, C, …)\n", + "3. Generates a prompt listing the bin definitions and asks the model to choose the right bin\n", + "4. The target answer uses the bin letter instead of the raw numeric value\n", + "\n", + "### Selecting the mode\n", + "\n", + "The `forward_conversion` method accepts an `override_mode_to_select_forecasting` parameter:\n", + "- `\"forecasting\"` (default): numeric value prediction\n", + "- `\"forecasting_qa\"`: bin-based QA prediction\n", + "- `\"both\"`: includes both a forecasting and a forecasting QA task\n", + "- `None`: randomly selects one of the above at each call" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Load libraries and example data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from twinweaver import (\n", + " DataSplitterForecasting,\n", + " DataSplitterEvents,\n", + " DataSplitter,\n", + " DataManager,\n", + " ConverterInstruction,\n", + " Config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "df_events = pd.read_csv(\"../../example_data/events.csv\")\n", + "df_constant = pd.read_csv(\"../../example_data/constant.csv\")\n", + "df_constant_description = pd.read_csv(\"../../example_data/constant_description.csv\")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Configuration\n", + "\n", + "Set up the config, data manager, splitters, and converter.\n", + "\n", + "> **Important:** `variable_stats` from `DataSplitterForecasting.setup_statistics()` must be\n", + "> passed to `ConverterInstruction` for the QA mode to work — it provides the bin edges." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "config = Config()\n", + "\n", + "config.split_event_category = \"lot\"\n", + "config.event_category_forecast = [\"lab\"]\n", + "config.event_category_events_prediction_with_naming = {\n", + " \"death\": \"death\",\n", + " \"progression\": \"next progression\",\n", + "}\n", + "config.constant_columns_to_use = [\"birthyear\", \"gender\", \"histology\", \"smoking_history\"]\n", + "config.constant_birthdate_column = \"birthyear\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "dm = DataManager(config=config)\n", + "dm.load_indication_data(\n", + " df_events=df_events,\n", + " df_constant=df_constant,\n", + " df_constant_description=df_constant_description,\n", + ")\n", + "dm.process_indication_data()\n", + "dm.setup_unique_mapping_of_events()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", + "dm.infer_var_types()\n", + "\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", + "data_splitter_events.setup_variables()\n", + "\n", + "data_splitter_forecasting = DataSplitterForecasting(\n", + " data_manager=dm,\n", + " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", + "# setup_statistics() computes per-variable quantile bin edges used by the QA mode\n", + "data_splitter_forecasting.setup_statistics()\n", + "\n", + "data_splitter = DataSplitter(\n", + " data_splitter_events=data_splitter_events,\n", + " data_splitter_forecasting=data_splitter_forecasting,\n", + ")\n", + "\n", + "converter = ConverterInstruction(\n", + " nr_tokens_budget_total=8192,\n", + " config=config,\n", + " dm=dm,\n", + " variable_stats=data_splitter_forecasting.variable_stats, # Required for QA mode\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## Generate Splits\n", + "\n", + "Get training splits for a patient." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "\n", + "patientid = dm.all_patientids[2]\n", + "patient_data = dm.get_patient_data(patientid)\n", + "\n", + "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n", + " patient_data,\n", + " forecasting_nr_samples_per_split=4,\n", + " forecasting_filter_outliers=False,\n", + " max_num_splits_per_split_event=2,\n", + " events_max_nr_samples_per_split=3,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Mode 1: Default Forecasting (numeric predictions)\n", + "\n", + "By default, `forward_conversion` uses `override_mode_to_select_forecasting=\"forecasting\"`,\n", + "which generates tasks that ask the model to predict exact numeric values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "split_idx = 0\n", + "\n", + "p_forecasting = converter.forward_conversion(\n", + " forecasting_splits=forecasting_splits[split_idx],\n", + " event_splits=events_splits[split_idx],\n", + " # override_mode_to_select_forecasting=\"forecasting\" # this is the default\n", + ")\n", + "\n", + "print(\"=\" * 80)\n", + "print(\"INSTRUCTION (last 1500 chars):\")\n", + "print(\"=\" * 80)\n", + "print(p_forecasting[\"instruction\"][-1500:])\n", + "print()\n", + "print(\"=\" * 80)\n", + "print(\"ANSWER:\")\n", + "print(\"=\" * 80)\n", + "print(p_forecasting[\"answer\"])" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "## Mode 2: Forecasting QA (bin-based predictions)\n", + "\n", + "Setting `override_mode_to_select_forecasting=\"forecasting_qa\"` activates the QA mode.\n", + "The prompt now lists bin definitions (e.g., `a = Bin (-inf, 5.2]`) and asks the model\n", + "to output the bin letter instead of a raw number." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "split_idx = 0\n", + "\n", + "p_qa = converter.forward_conversion(\n", + " forecasting_splits=forecasting_splits[split_idx],\n", + " event_splits=events_splits[split_idx],\n", + " override_mode_to_select_forecasting=\"forecasting_qa\",\n", + ")\n", + "\n", + "print(\"=\" * 80)\n", + "print(\"INSTRUCTION (last 2000 chars):\")\n", + "print(\"=\" * 80)\n", + "print(p_qa[\"instruction\"][-2000:])\n", + "print()\n", + "print(\"=\" * 80)\n", + "print(\"ANSWER:\")\n", + "print(\"=\" * 80)\n", + "print(p_qa[\"answer\"])" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "### Inspect the bin definitions\n", + "\n", + "The metadata returned by the conversion includes the bin mapping for each variable.\n", + "These are derived from quantiles computed by `setup_statistics()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# The target meta contains the category splits for each variable\n", + "for task_meta in p_qa[\"meta\"][\"target_meta_detailed\"]:\n", + " if \"category_splits\" in task_meta:\n", + " print(\"Variable bin definitions:\")\n", + " for var, bins in task_meta[\"category_splits\"].items():\n", + " print(f\" {var}:\")\n", + " for letter, bin_range in bins.items():\n", + " print(f\" {letter} = {bin_range}\")" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Mode 3: Both (forecasting + QA in one prompt)\n", + "\n", + "Setting `override_mode_to_select_forecasting=\"both\"` includes **both** a numeric forecasting\n", + "task and a QA-style bin prediction task in the same multi-task prompt. This encourages the\n", + "model to learn complementary representations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "split_idx = 0\n", + "\n", + "p_both = converter.forward_conversion(\n", + " forecasting_splits=forecasting_splits[split_idx],\n", + " event_splits=events_splits[split_idx],\n", + " override_mode_to_select_forecasting=\"both\",\n", + ")\n", + "\n", + "print(\"=\" * 80)\n", + "print(\"INSTRUCTION (last 3000 chars):\")\n", + "print(\"=\" * 80)\n", + "print(p_both[\"instruction\"][-3000:])\n", + "print()\n", + "print(\"=\" * 80)\n", + "print(\"ANSWER:\")\n", + "print(\"=\" * 80)\n", + "print(p_both[\"answer\"])" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "### Compare the task types generated\n", + "\n", + "Let's inspect which task types were included in each mode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "for label, result in [(\"forecasting\", p_forecasting), (\"forecasting_qa\", p_qa), (\"both\", p_both)]:\n", + " task_types = [m[\"task_type\"] for m in result[\"meta\"][\"target_meta_detailed\"]]\n", + " print(f\"Mode '{label}': {task_types}\")" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "## Reverse Conversion\n", + "\n", + "The reverse conversion works for all modes. It parses the model output back into\n", + "structured data, regardless of whether the target was numeric or bin-based." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "date = reference_dates[\"date\"][0]\n", + "\n", + "# Reverse convert the QA mode output\n", + "return_list = converter.reverse_conversion(p_qa[\"answer\"], dm, date)\n", + "\n", + "return_list[2][\"result\"]" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "## Variable Statistics\n", + "\n", + "The bin edges come from the `variable_stats` DataFrame computed during `setup_statistics()`.\n", + "You can inspect them directly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "data_splitter_forecasting.variable_stats" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv_dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb index 91af9bc..409abe2 100644 --- a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb +++ b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb @@ -243,7 +243,6 @@ "p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=None, # Not needed for forecasting-only splitter\n", - " override_mode_to_select_forecasting=\"forecasting\",\n", ")" ] }, diff --git a/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 70f2f01..56718f3 100644 --- a/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -259,7 +259,6 @@ "p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=events_splits[split_idx],\n", - " override_mode_to_select_forecasting=\"forecasting_qa\",\n", ")" ] }, diff --git a/examples/data_preprocessing/raw_data_preprocessing.ipynb b/examples/data_preprocessing/raw_data_preprocessing.ipynb index 836c465..d7eeccd 100644 --- a/examples/data_preprocessing/raw_data_preprocessing.ipynb +++ b/examples/data_preprocessing/raw_data_preprocessing.ipynb @@ -1120,7 +1120,6 @@ " p_converted = converter.forward_conversion(\n", " forecasting_splits=forecasting_splits[split_idx],\n", " event_splits=events_splits[split_idx],\n", - " override_mode_to_select_forecasting=\"both\",\n", " )\n", "\n", " print(\"=\" * 80)\n", diff --git a/examples/hackathon/01_data_preparation_challenge.ipynb b/examples/hackathon/01_data_preparation_challenge.ipynb index 2461d78..a4984ff 100644 --- a/examples/hackathon/01_data_preparation_challenge.ipynb +++ b/examples/hackathon/01_data_preparation_challenge.ipynb @@ -565,7 +565,6 @@ "# Parameters:\n", "# - forecasting_splits: the forecasting split for one time point\n", "# - event_splits: the event split for one time point\n", - "# - override_mode_to_select_forecasting: set to \"both\"\n", "\n", "split_idx = 0\n", "p_converted = None # Replace with actual conversion call" diff --git a/twinweaver/instruction/converter_instruction.py b/twinweaver/instruction/converter_instruction.py index 4d20d1e..bf75ca5 100644 --- a/twinweaver/instruction/converter_instruction.py +++ b/twinweaver/instruction/converter_instruction.py @@ -245,7 +245,7 @@ def forward_conversion( self, forecasting_splits: list[DataSplitterForecastingGroup], event_splits: list[DataSplitterEventsGroup], - override_mode_to_select_forecasting: str = None, + override_mode_to_select_forecasting: str = "forecasting", ) -> dict: """ Generates a multi-task instruction prompt and target answer from patient data splits. @@ -268,7 +268,7 @@ def forward_conversion( (containing patient history up to a split date and event outcome/censoring info). override_mode_to_select_forecasting : str, optional If provided, forces the selection mode for forecasting tasks ('forecasting', - 'forecasting_qa', or 'both'). If None, the mode is chosen randomly. Defaults to None. + 'forecasting_qa', or 'both'). If None, the mode is chosen randomly. Defaults to "forecasting". Returns ------- From 105e92c97378070b1d17f0562857c1af7609de78 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 14:25:19 +0000 Subject: [PATCH 19/36] DM now checks for missing values in event name, descriptive name and category --- twinweaver/common/data_manager.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py index 7840cd0..b743f8c 100644 --- a/twinweaver/common/data_manager.py +++ b/twinweaver/common/data_manager.py @@ -202,6 +202,20 @@ def handle_missing_dates(df_key, missing_count, total_count, col_date): "please fix the data or set drop_missing_event_values=True" ) + # Check for missing values in event_descriptive_name, event_name, and event_category columns + for col in [ + self.config.event_descriptive_name_col, + self.config.event_name_col, + self.config.event_category_col, + ]: + missing_count = self.data_frames[events_table_key][col].isnull().sum() + if missing_count > 0: + total = len(self.data_frames[events_table_key]) + raise ValueError( + f"Found {missing_count} out of {total} missing values in '{col}' column " + f"in events table - please fix the data before proceeding" + ) + # Convert all event values to string self.data_frames[events_table_key][self.config.event_value_col] = self.data_frames[events_table_key][ self.config.event_value_col From 2e6cbaf34ce5e588022b8b69b684a3d712d1464c Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 14:45:12 +0000 Subject: [PATCH 20/36] Added renormalization, excluding censoring --- twinweaver/utils/tte_inference.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/twinweaver/utils/tte_inference.py b/twinweaver/utils/tte_inference.py index 003a651..7ddcd47 100644 --- a/twinweaver/utils/tte_inference.py +++ b/twinweaver/utils/tte_inference.py @@ -645,4 +645,14 @@ def _hard_prediction(row): df["probability_no_occurrence"] = df[f"softmax_{LABEL_NOT_OCCURRED}"] df["probability_censored"] = df[f"softmax_{LABEL_CENSORED}"] + # 5. Add in renormalized probabilities that exclude the censored class + # (for some analyses it may be useful to look at the relative probabilities of occurrence vs no occurrence, + # ignoring censoring). + df["probability_occurrence_renormalized"] = df["probability_occurrence"] / ( + df["probability_occurrence"] + df["probability_no_occurrence"] + ) + df["probability_no_occurrence_renormalized"] = df["probability_no_occurrence"] / ( + df["probability_occurrence"] + df["probability_no_occurrence"] + ) + return df From 053141f7f8dc8ee2c83a0b90138607f039bd760b Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 14:55:38 +0000 Subject: [PATCH 21/36] Added function for VLLM based forecsting inference and notebook --- .../forecasting_vllm_inference.ipynb | 658 ++++++++++++++++++ docs/reference/utils/forecasting_inference.md | 3 + .../forecasting_vllm_inference.ipynb | 658 ++++++++++++++++++ mkdocs.yml | 2 + twinweaver/__init__.py | 5 + twinweaver/utils/__init__.py | 5 + twinweaver/utils/forecasting_inference.py | 474 +++++++++++++ 7 files changed, 1805 insertions(+) create mode 100644 docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb create mode 100644 docs/reference/utils/forecasting_inference.md create mode 100644 examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb create mode 100644 twinweaver/utils/forecasting_inference.py diff --git a/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb b/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb new file mode 100644 index 0000000..c2b2826 --- /dev/null +++ b/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb @@ -0,0 +1,658 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Forecasting Inference with TwinWeaver and vLLM\n", + "\n", + "This notebook demonstrates how to use TwinWeaver's **forecasting inference pipeline**\n", + "to predict future clinical values (e.g., lab results like hemoglobin) using a\n", + "fine-tuned LLM served via [vLLM](https://github.com/vllm-project/vllm).\n", + "\n", + "Unlike the TTE (time-to-event) probability pipeline, which *scores* fixed\n", + "completions, this pipeline uses **free-text generation**: the model produces\n", + "an answer string that is then **reverse-converted** back into a structured\n", + "DataFrame with predicted dates and values.\n", + "\n", + "### Pipeline overview\n", + "\n", + "```\n", + "Patient data ──► DataSplitter (forecasting) ──► ConverterInstruction\n", + " │\n", + " instruction text per patient\n", + " │\n", + " ┌──────────────────────────────┘\n", + " ▼\n", + " vLLM server (OpenAI-compatible API)\n", + " │\n", + " generated text completions\n", + " │\n", + " ▼\n", + " parse_forecasting_results()\n", + " (calls converter.reverse_conversion internally)\n", + " │\n", + " structured DataFrame\n", + " with predicted dates & values\n", + "```\n", + "\n", + "> **⚠️ Important:** The quality of the forecasts depends critically\n", + "> on having a **fine-tuned model**. An off-the-shelf instruction model\n", + "> (like the default `microsoft/Phi-4-mini-instruct` used here for demonstration)\n", + "> will produce **meaningless predictions**. Always fine-tune on your\n", + "> clinical dataset first (see `03_end_to_end_llm_finetuning.ipynb`).\n", + "\n", + "> **Requirements:**\n", + "> - A GPU with enough memory to serve the model via vLLM (≥16 GB for a 4-bit 8B model)\n", + "> - `pip install twinweaver[fine-tuning-example] vllm openai`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "import time\n", + "import sys\n", + "import os\n", + "\n", + "import pandas as pd\n", + "\n", + "from twinweaver import (\n", + " DataManager,\n", + " Config,\n", + " DataSplitterForecasting,\n", + " DataSplitterEvents,\n", + " DataSplitter,\n", + " ConverterInstruction,\n", + " run_forecasting_inference_notebook,\n", + " parse_forecasting_results,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## 1. Configuration\n", + "\n", + "We define all key settings up front so they are easy to change in one place.\n", + "\n", + "> **Note:** Replace `MODEL_PATH` with the path to your **fine-tuned** model for\n", + "> meaningful results. The default `microsoft/Phi-4-mini-instruct` is only used\n", + "> here so that the notebook is self-contained." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# ---------------------------------------------------------------------------\n", + "# Model & server settings\n", + "# ---------------------------------------------------------------------------\n", + "MODEL_PATH = \"microsoft/Phi-4-mini-instruct\" # ⚠️ Replace with your fine-tuned model path\n", + "TOKENIZER_PATH = MODEL_PATH # Usually the same as the model path\n", + "\n", + "VLLM_PORT = 8000\n", + "PREDICTION_URL = f\"http://0.0.0.0:{VLLM_PORT}/v1/\"\n", + "\n", + "MAX_CONTEXT_LENGTH = 4096 # Must match what the model was trained with\n", + "\n", + "# Generation settings\n", + "MAX_NEW_TOKENS = 256 # Max tokens for the generated forecast answer\n", + "TEMPERATURE = 0.9 # Sampling temperature (0 = greedy)\n", + "TOP_P = 0.9 # Nucleus sampling\n", + "N_SAMPLES = 3 # Number of independent samples per patient (>1 enables aggregation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "# ---------------------------------------------------------------------------\n", + "# TwinWeaver data settings (same as used during training)\n", + "# ---------------------------------------------------------------------------\n", + "config = Config()\n", + "\n", + "# 1. Event category used for data splitting\n", + "config.split_event_category = \"lot\"\n", + "\n", + "# 2. Forecasting categories\n", + "config.event_category_forecast = [\"lab\"]\n", + "\n", + "# 3. Time-to-event variables (needed to initialise splitters, even if we only forecast)\n", + "config.event_category_events_prediction_with_naming = {\n", + " \"death\": \"death\",\n", + " \"progression\": \"next progression\",\n", + "}\n", + "\n", + "# 4. Constant (static) columns\n", + "config.constant_columns_to_use = [\"birthyear\", \"gender\", \"histology\", \"smoking_history\"]\n", + "config.constant_birthdate_column = \"birthyear\"" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## 2. Load and prepare data\n", + "\n", + "We use the same example data shipped with TwinWeaver. In a real scenario you\n", + "would load your own clinical dataset here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "# Load example data (adjust paths if running from a different directory)\n", + "df_events = pd.read_csv(\"../../example_data/events.csv\")\n", + "df_constant = pd.read_csv(\"../../example_data/constant.csv\")\n", + "df_constant_description = pd.read_csv(\"../../example_data/constant_description.csv\")\n", + "\n", + "print(f\"Loaded {len(df_events)} events for {df_events['patientid'].nunique()} patients\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialise DataManager and all splitters\n", + "dm = DataManager(config=config)\n", + "dm.load_indication_data(\n", + " df_events=df_events,\n", + " df_constant=df_constant,\n", + " df_constant_description=df_constant_description,\n", + ")\n", + "dm.process_indication_data()\n", + "dm.setup_unique_mapping_of_events()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", + "dm.infer_var_types()\n", + "\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", + "data_splitter_events.setup_variables()\n", + "\n", + "data_splitter_forecasting = DataSplitterForecasting(\n", + " data_manager=dm,\n", + " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", + "data_splitter_forecasting.setup_statistics()\n", + "\n", + "# Combined interface\n", + "data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)\n", + "\n", + "converter = ConverterInstruction(\n", + " nr_tokens_budget_total=MAX_CONTEXT_LENGTH,\n", + " config=config,\n", + " dm=dm,\n", + " variable_stats=data_splitter_forecasting.variable_stats,\n", + ")\n", + "\n", + "print(\"✅ Data pipeline ready\")" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "## 3. Generate forecasting instruction prompts\n", + "\n", + "For each test patient we generate a **forecasting-only** instruction prompt.\n", + "The key parameters are:\n", + "\n", + "- `forecasting_override_variables_to_predict` – which variables to forecast\n", + " (e.g. hemoglobin)\n", + "- `forecasting_future_weeks_per_variable` – at which future time points\n", + " (in weeks) to request predictions\n", + "\n", + "The converter produces a text instruction asking the model to predict the\n", + "specified variable(s) at the given future time points." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Choose patients to evaluate (here: all test-set patients)\n", + "test_patientids = dm.get_all_patientids_in_split(config.test_split_name)\n", + "print(f\"Number of test patients: {len(test_patientids)}\")\n", + "\n", + "# Define the prediction task:\n", + "# Which variables to predict and at which future week offsets\n", + "VARIABLES_TO_PREDICT = [\"hemoglobin_-_718-7\"]\n", + "FUTURE_WEEKS = {\n", + " \"hemoglobin_-_718-7\": [4, 8, 12], # Predict hemoglobin at 4, 8, and 12 weeks\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Build the list of prompt payloads for the forecasting API\n", + "# Each payload is a dict with \"patientid\", \"instruction\", and \"split_date\"\n", + "prompts_with_meta: list[dict] = []\n", + "\n", + "for pid in test_patientids:\n", + " patient_data = dm.get_patient_data(pid)\n", + "\n", + " # Use only data up to first lot event (simulates baseline information)\n", + " patient_data[\"events\"] = patient_data[\"events\"].sort_values(\"date\")\n", + " first_lot_date = patient_data[\"events\"][patient_data[\"events\"][\"event_category\"] == \"lot\"][\"date\"].min()\n", + " assert pd.notna(first_lot_date), f\"Patient {pid} has no lot event\"\n", + " patient_data[\"events\"] = patient_data[\"events\"][patient_data[\"events\"][\"date\"] <= first_lot_date].copy()\n", + "\n", + " # Generate the forecasting-only split for inference\n", + " forecast_split, _ = data_splitter.get_splits_from_patient_inference(\n", + " patient_data,\n", + " inference_type=\"forecasting\",\n", + " forecasting_override_variables_to_predict=VARIABLES_TO_PREDICT,\n", + " )\n", + "\n", + " # Convert to instruction text (no target answer)\n", + " converted = converter.forward_conversion_inference(\n", + " forecasting_split=forecast_split,\n", + " forecasting_future_weeks_per_variable=FUTURE_WEEKS,\n", + " )\n", + "\n", + " # Collect the prompt payload\n", + " prompts_with_meta.append(\n", + " {\n", + " \"patientid\": pid,\n", + " \"instruction\": converted[\"instruction\"],\n", + " \"split_date\": forecast_split.split_date_included_in_input,\n", + " }\n", + " )\n", + "\n", + "print(f\"Generated {len(prompts_with_meta)} instruction prompts\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's inspect one instruction to see what the model will receive\n", + "sample = prompts_with_meta[0]\n", + "print(f\"=== Patient: {sample['patientid']} ===\")\n", + "print(f\"Split date: {sample['split_date']}\\n\")\n", + "print(sample[\"instruction\"])" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## 4. Launch the vLLM server\n", + "\n", + "We launch a vLLM OpenAI-compatible server as a background process.\n", + "\n", + "> **If you already have a vLLM server running**, skip this cell and just update\n", + "> `PREDICTION_URL` and `MODEL_PATH` in the configuration section above.\n", + "\n", + "> **Tip:** For production use, launch the server in a separate terminal." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "# Launch vLLM server as a background subprocess\n", + "# Set to True to skip launching (if you already have a server running)\n", + "SKIP_VLLM_LAUNCH = False\n", + "\n", + "vllm_process = None\n", + "\n", + "if not SKIP_VLLM_LAUNCH:\n", + " env = os.environ.copy()\n", + " env[\"VLLM_ATTENTION_BACKEND\"] = \"FLASH_ATTN\"\n", + "\n", + " vllm_command = [\n", + " sys.executable,\n", + " \"-m\",\n", + " \"vllm.entrypoints.openai.api_server\",\n", + " \"--port\",\n", + " str(VLLM_PORT),\n", + " \"--model\",\n", + " MODEL_PATH,\n", + " \"--tokenizer\",\n", + " TOKENIZER_PATH,\n", + " \"--enable-prefix-caching\",\n", + " ]\n", + "\n", + " print(f\"🚀 Launching vLLM server:\\n {' '.join(vllm_command)}\\n\")\n", + "\n", + " vllm_process = subprocess.Popen(\n", + " vllm_command,\n", + " env=env,\n", + " stdout=subprocess.PIPE,\n", + " stderr=subprocess.STDOUT,\n", + " text=True,\n", + " )\n", + "\n", + " # Wait for the server to be ready\n", + " WAIT_SECONDS = 240\n", + " print(f\"⏳ Waiting up to {WAIT_SECONDS}s for the server to start...\")\n", + " import urllib.request\n", + "\n", + " for i in range(WAIT_SECONDS):\n", + " time.sleep(1)\n", + " try:\n", + " urllib.request.urlopen(f\"http://localhost:{VLLM_PORT}/health\")\n", + " print(f\"✅ vLLM server is ready after {i + 1}s\")\n", + " break\n", + " except Exception:\n", + " pass\n", + " else:\n", + " print(\"⚠️ Server did not respond in time. Check GPU memory and logs.\")\n", + " print(\" You can read server output with: vllm_process.stdout.read()\")\n", + "else:\n", + " print(\"Skipping vLLM launch – using existing server.\")" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "## 5. Run forecasting inference\n", + "\n", + "This is the core step. `run_forecasting_inference_notebook` sends each patient's\n", + "instruction to the vLLM server and collects the generated text completions.\n", + "\n", + "Under the hood it:\n", + "1. Wraps each instruction in a chat message (with optional system prompt).\n", + "2. Calls the OpenAI-compatible `/v1/chat/completions` endpoint.\n", + "3. Returns the generated text(s) alongside patient metadata.\n", + "\n", + "When `n_samples > 1`, multiple independent completions are generated per\n", + "patient, which can later be aggregated into a mean trajectory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "# Run the forecasting inference\n", + "# This calls the vLLM server asynchronously for all patients\n", + "raw_results = await run_forecasting_inference_notebook(\n", + " prompts_with_meta,\n", + " prediction_url=PREDICTION_URL,\n", + " prediction_model=MODEL_PATH,\n", + " max_concurrent_requests=40,\n", + " max_new_tokens=MAX_NEW_TOKENS,\n", + " temperature=TEMPERATURE,\n", + " top_p=TOP_P,\n", + " n_samples=N_SAMPLES,\n", + " api_key=\"EMPTY\",\n", + " timeout=600.0,\n", + ")\n", + "\n", + "# Check for failures\n", + "n_success = sum(1 for r in raw_results if r is not None)\n", + "n_fail = sum(1 for r in raw_results if r is None)\n", + "print(f\"✅ Generated forecasts for {n_success} patients, {n_fail} failures\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's inspect the raw generated text for one patient\n", + "for r in raw_results:\n", + " if r is not None:\n", + " print(f\"=== Patient: {r['patientid']} ===\")\n", + " for i, text in enumerate(r[\"generated_texts\"]):\n", + " print(f\"\\n--- Sample {i} ---\")\n", + " print(text)\n", + " break" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "## 6. Parse results with reverse conversion\n", + "\n", + "`parse_forecasting_results` takes the raw generated texts and:\n", + "\n", + "1. Calls `converter.reverse_conversion` on each generated text to parse it\n", + " back into a structured DataFrame with dates and predicted values.\n", + "2. When `n_samples > 1` and `aggregate_samples=True`, aggregates multiple\n", + " trajectories using `converter.aggregate_multiple_responses` (e.g. averaging\n", + " numeric predictions).\n", + "3. Returns a single long-format DataFrame with all patients' predictions.\n", + "\n", + "> **Note:** Reverse conversion is robust to slightly malformed model output\n", + "> thanks to `inference_override=True`, but a fine-tuned model will produce\n", + "> much more parseable results than a generic instruction model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "# Parse the generated texts into structured DataFrames\n", + "df_results = parse_forecasting_results(\n", + " raw_results,\n", + " converter,\n", + " dm,\n", + " drop_failures=True,\n", + " aggregate_samples=(N_SAMPLES > 1), # Only aggregate if we have multiple samples\n", + ")\n", + "\n", + "print(f\"Parsed {len(df_results)} prediction rows for {df_results['patientid'].nunique()} patients\")\n", + "df_results.head(20)" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "### Understanding the output\n", + "\n", + "The returned DataFrame has the standard TwinWeaver event format:\n", + "\n", + "| Column | Description |\n", + "|---|---|\n", + "| `date` | Predicted date (computed from split_date + week offset) |\n", + "| `event_name` | The variable being predicted (e.g. `hemoglobin_-_718-7`) |\n", + "| `event_value` | The predicted value |\n", + "| `event_category` | Category of the event (e.g. `lab`) |\n", + "| `patientid` | Patient identifier |\n", + "| `task_type` | Which task type produced this prediction |\n", + "| `sample_idx` | Sample index (when `aggregate_samples=False` and `n_samples > 1`) |" + ] + }, + { + "cell_type": "markdown", + "id": "20", + "metadata": {}, + "source": [ + "## 7. Multi-sample aggregation (optional)\n", + "\n", + "When using `n_samples > 1`, each patient gets multiple independent forecast\n", + "trajectories. Aggregation (e.g. averaging numeric predictions) can reduce\n", + "variance and give more robust estimates.\n", + "\n", + "Below is an example of running with multiple samples and then aggregating." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "# Example: generate 3 samples per patient and aggregate\n", + "N_SAMPLES_AGG = 3\n", + "\n", + "raw_results_multi = await run_forecasting_inference_notebook(\n", + " prompts_with_meta,\n", + " prediction_url=PREDICTION_URL,\n", + " prediction_model=MODEL_PATH,\n", + " max_concurrent_requests=40,\n", + " max_new_tokens=MAX_NEW_TOKENS,\n", + " temperature=TEMPERATURE,\n", + " top_p=TOP_P,\n", + " n_samples=N_SAMPLES_AGG,\n", + " api_key=\"EMPTY\",\n", + ")\n", + "\n", + "# Parse with aggregation enabled\n", + "df_aggregated = parse_forecasting_results(\n", + " raw_results_multi,\n", + " converter,\n", + " dm,\n", + " drop_failures=True,\n", + " aggregate_samples=True, # Average numeric values across samples\n", + ")\n", + "\n", + "print(f\"Aggregated results: {len(df_aggregated)} rows for {df_aggregated['patientid'].nunique()} patients\")\n", + "df_aggregated.head(20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "# You can also get the individual (non-aggregated) samples for deeper analysis\n", + "df_individual = parse_forecasting_results(\n", + " raw_results_multi,\n", + " converter,\n", + " dm,\n", + " drop_failures=True,\n", + " aggregate_samples=False, # Keep individual samples\n", + ")\n", + "\n", + "print(f\"Individual results: {len(df_individual)} rows\")\n", + "print(f\"Samples per patient: {df_individual.groupby('patientid')['sample_idx'].nunique().describe()}\")\n", + "df_individual.head(20)" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "## 8. Clean up\n", + "\n", + "Shut down the vLLM server if we launched it from this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "if vllm_process is not None:\n", + " print(\"Terminating vLLM server...\")\n", + " vllm_process.terminate()\n", + " vllm_process.wait(timeout=10)\n", + " print(\"✅ Server stopped.\")" + ] + }, + { + "cell_type": "markdown", + "id": "25", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated the full forecasting inference workflow:\n", + "\n", + "1. **Data preparation** – Load clinical data and generate forecasting instruction prompts\n", + "2. **Model serving** – Launch (or connect to) a vLLM OpenAI-compatible server\n", + "3. **Text generation** – Use `run_forecasting_inference_notebook` to generate completions\n", + "4. **Reverse conversion** – Use `parse_forecasting_results` to convert text → structured DataFrame\n", + "5. **Aggregation** – Optionally average multiple samples for more robust predictions\n", + "\n", + "### Key functions\n", + "\n", + "| Function | Purpose |\n", + "|---|---|\n", + "| `run_forecasting_inference()` | Generate completions for all patients (sync wrapper) |\n", + "| `run_forecasting_inference_notebook()` | Same but async – for notebooks |\n", + "| `parse_forecasting_results()` | Reverse-convert generated text → structured DataFrame |\n", + "\n", + "### Comparison with TTE probability inference\n", + "\n", + "| Aspect | TTE Inference | Forecasting Inference |\n", + "|---|---|---|\n", + "| **Method** | Log-prob scoring of fixed completions | Free-text generation |\n", + "| **Output** | Probabilities (censored/occurred/not occurred) | Predicted values at future time points |\n", + "| **API endpoint** | `/v1/completions` (scoring) | `/v1/chat/completions` (generation) |\n", + "| **Post-processing** | Softmax over log-probs | Reverse conversion (text → DataFrame) |\n", + "| **Multi-sample** | Not applicable | Average trajectories via `aggregate_multiple_responses` |\n", + "\n", + "### Next steps\n", + "\n", + "- **Fine-tune a model** on your dataset using `03_end_to_end_llm_finetuning.ipynb`\n", + "- **Combine with TTE** by using `inference_type=\"both\"` in the data splitter\n", + "- **Evaluate** predictions against ground truth to assess model performance\n", + "- **Experiment** with different variables, time horizons, and sample counts" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/reference/utils/forecasting_inference.md b/docs/reference/utils/forecasting_inference.md new file mode 100644 index 0000000..03750b5 --- /dev/null +++ b/docs/reference/utils/forecasting_inference.md @@ -0,0 +1,3 @@ +# Forecasting Inference + +::: twinweaver.utils.forecasting_inference diff --git a/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb b/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb new file mode 100644 index 0000000..c2b2826 --- /dev/null +++ b/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb @@ -0,0 +1,658 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Forecasting Inference with TwinWeaver and vLLM\n", + "\n", + "This notebook demonstrates how to use TwinWeaver's **forecasting inference pipeline**\n", + "to predict future clinical values (e.g., lab results like hemoglobin) using a\n", + "fine-tuned LLM served via [vLLM](https://github.com/vllm-project/vllm).\n", + "\n", + "Unlike the TTE (time-to-event) probability pipeline, which *scores* fixed\n", + "completions, this pipeline uses **free-text generation**: the model produces\n", + "an answer string that is then **reverse-converted** back into a structured\n", + "DataFrame with predicted dates and values.\n", + "\n", + "### Pipeline overview\n", + "\n", + "```\n", + "Patient data ──► DataSplitter (forecasting) ──► ConverterInstruction\n", + " │\n", + " instruction text per patient\n", + " │\n", + " ┌──────────────────────────────┘\n", + " ▼\n", + " vLLM server (OpenAI-compatible API)\n", + " │\n", + " generated text completions\n", + " │\n", + " ▼\n", + " parse_forecasting_results()\n", + " (calls converter.reverse_conversion internally)\n", + " │\n", + " structured DataFrame\n", + " with predicted dates & values\n", + "```\n", + "\n", + "> **⚠️ Important:** The quality of the forecasts depends critically\n", + "> on having a **fine-tuned model**. An off-the-shelf instruction model\n", + "> (like the default `microsoft/Phi-4-mini-instruct` used here for demonstration)\n", + "> will produce **meaningless predictions**. Always fine-tune on your\n", + "> clinical dataset first (see `03_end_to_end_llm_finetuning.ipynb`).\n", + "\n", + "> **Requirements:**\n", + "> - A GPU with enough memory to serve the model via vLLM (≥16 GB for a 4-bit 8B model)\n", + "> - `pip install twinweaver[fine-tuning-example] vllm openai`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "import time\n", + "import sys\n", + "import os\n", + "\n", + "import pandas as pd\n", + "\n", + "from twinweaver import (\n", + " DataManager,\n", + " Config,\n", + " DataSplitterForecasting,\n", + " DataSplitterEvents,\n", + " DataSplitter,\n", + " ConverterInstruction,\n", + " run_forecasting_inference_notebook,\n", + " parse_forecasting_results,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## 1. Configuration\n", + "\n", + "We define all key settings up front so they are easy to change in one place.\n", + "\n", + "> **Note:** Replace `MODEL_PATH` with the path to your **fine-tuned** model for\n", + "> meaningful results. The default `microsoft/Phi-4-mini-instruct` is only used\n", + "> here so that the notebook is self-contained." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# ---------------------------------------------------------------------------\n", + "# Model & server settings\n", + "# ---------------------------------------------------------------------------\n", + "MODEL_PATH = \"microsoft/Phi-4-mini-instruct\" # ⚠️ Replace with your fine-tuned model path\n", + "TOKENIZER_PATH = MODEL_PATH # Usually the same as the model path\n", + "\n", + "VLLM_PORT = 8000\n", + "PREDICTION_URL = f\"http://0.0.0.0:{VLLM_PORT}/v1/\"\n", + "\n", + "MAX_CONTEXT_LENGTH = 4096 # Must match what the model was trained with\n", + "\n", + "# Generation settings\n", + "MAX_NEW_TOKENS = 256 # Max tokens for the generated forecast answer\n", + "TEMPERATURE = 0.9 # Sampling temperature (0 = greedy)\n", + "TOP_P = 0.9 # Nucleus sampling\n", + "N_SAMPLES = 3 # Number of independent samples per patient (>1 enables aggregation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "# ---------------------------------------------------------------------------\n", + "# TwinWeaver data settings (same as used during training)\n", + "# ---------------------------------------------------------------------------\n", + "config = Config()\n", + "\n", + "# 1. Event category used for data splitting\n", + "config.split_event_category = \"lot\"\n", + "\n", + "# 2. Forecasting categories\n", + "config.event_category_forecast = [\"lab\"]\n", + "\n", + "# 3. Time-to-event variables (needed to initialise splitters, even if we only forecast)\n", + "config.event_category_events_prediction_with_naming = {\n", + " \"death\": \"death\",\n", + " \"progression\": \"next progression\",\n", + "}\n", + "\n", + "# 4. Constant (static) columns\n", + "config.constant_columns_to_use = [\"birthyear\", \"gender\", \"histology\", \"smoking_history\"]\n", + "config.constant_birthdate_column = \"birthyear\"" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## 2. Load and prepare data\n", + "\n", + "We use the same example data shipped with TwinWeaver. In a real scenario you\n", + "would load your own clinical dataset here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "# Load example data (adjust paths if running from a different directory)\n", + "df_events = pd.read_csv(\"../../example_data/events.csv\")\n", + "df_constant = pd.read_csv(\"../../example_data/constant.csv\")\n", + "df_constant_description = pd.read_csv(\"../../example_data/constant_description.csv\")\n", + "\n", + "print(f\"Loaded {len(df_events)} events for {df_events['patientid'].nunique()} patients\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialise DataManager and all splitters\n", + "dm = DataManager(config=config)\n", + "dm.load_indication_data(\n", + " df_events=df_events,\n", + " df_constant=df_constant,\n", + " df_constant_description=df_constant_description,\n", + ")\n", + "dm.process_indication_data()\n", + "dm.setup_unique_mapping_of_events()\n", + "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", + "dm.infer_var_types()\n", + "\n", + "data_splitter_events = DataSplitterEvents(\n", + " dm,\n", + " config=config,\n", + " max_length_to_sample=pd.Timedelta(weeks=104),\n", + " min_length_to_sample=pd.Timedelta(weeks=1),\n", + ")\n", + "data_splitter_events.setup_variables()\n", + "\n", + "data_splitter_forecasting = DataSplitterForecasting(\n", + " data_manager=dm,\n", + " config=config,\n", + " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n", + ")\n", + "data_splitter_forecasting.setup_statistics()\n", + "\n", + "# Combined interface\n", + "data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)\n", + "\n", + "converter = ConverterInstruction(\n", + " nr_tokens_budget_total=MAX_CONTEXT_LENGTH,\n", + " config=config,\n", + " dm=dm,\n", + " variable_stats=data_splitter_forecasting.variable_stats,\n", + ")\n", + "\n", + "print(\"✅ Data pipeline ready\")" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "## 3. Generate forecasting instruction prompts\n", + "\n", + "For each test patient we generate a **forecasting-only** instruction prompt.\n", + "The key parameters are:\n", + "\n", + "- `forecasting_override_variables_to_predict` – which variables to forecast\n", + " (e.g. hemoglobin)\n", + "- `forecasting_future_weeks_per_variable` – at which future time points\n", + " (in weeks) to request predictions\n", + "\n", + "The converter produces a text instruction asking the model to predict the\n", + "specified variable(s) at the given future time points." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Choose patients to evaluate (here: all test-set patients)\n", + "test_patientids = dm.get_all_patientids_in_split(config.test_split_name)\n", + "print(f\"Number of test patients: {len(test_patientids)}\")\n", + "\n", + "# Define the prediction task:\n", + "# Which variables to predict and at which future week offsets\n", + "VARIABLES_TO_PREDICT = [\"hemoglobin_-_718-7\"]\n", + "FUTURE_WEEKS = {\n", + " \"hemoglobin_-_718-7\": [4, 8, 12], # Predict hemoglobin at 4, 8, and 12 weeks\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Build the list of prompt payloads for the forecasting API\n", + "# Each payload is a dict with \"patientid\", \"instruction\", and \"split_date\"\n", + "prompts_with_meta: list[dict] = []\n", + "\n", + "for pid in test_patientids:\n", + " patient_data = dm.get_patient_data(pid)\n", + "\n", + " # Use only data up to first lot event (simulates baseline information)\n", + " patient_data[\"events\"] = patient_data[\"events\"].sort_values(\"date\")\n", + " first_lot_date = patient_data[\"events\"][patient_data[\"events\"][\"event_category\"] == \"lot\"][\"date\"].min()\n", + " assert pd.notna(first_lot_date), f\"Patient {pid} has no lot event\"\n", + " patient_data[\"events\"] = patient_data[\"events\"][patient_data[\"events\"][\"date\"] <= first_lot_date].copy()\n", + "\n", + " # Generate the forecasting-only split for inference\n", + " forecast_split, _ = data_splitter.get_splits_from_patient_inference(\n", + " patient_data,\n", + " inference_type=\"forecasting\",\n", + " forecasting_override_variables_to_predict=VARIABLES_TO_PREDICT,\n", + " )\n", + "\n", + " # Convert to instruction text (no target answer)\n", + " converted = converter.forward_conversion_inference(\n", + " forecasting_split=forecast_split,\n", + " forecasting_future_weeks_per_variable=FUTURE_WEEKS,\n", + " )\n", + "\n", + " # Collect the prompt payload\n", + " prompts_with_meta.append(\n", + " {\n", + " \"patientid\": pid,\n", + " \"instruction\": converted[\"instruction\"],\n", + " \"split_date\": forecast_split.split_date_included_in_input,\n", + " }\n", + " )\n", + "\n", + "print(f\"Generated {len(prompts_with_meta)} instruction prompts\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's inspect one instruction to see what the model will receive\n", + "sample = prompts_with_meta[0]\n", + "print(f\"=== Patient: {sample['patientid']} ===\")\n", + "print(f\"Split date: {sample['split_date']}\\n\")\n", + "print(sample[\"instruction\"])" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## 4. Launch the vLLM server\n", + "\n", + "We launch a vLLM OpenAI-compatible server as a background process.\n", + "\n", + "> **If you already have a vLLM server running**, skip this cell and just update\n", + "> `PREDICTION_URL` and `MODEL_PATH` in the configuration section above.\n", + "\n", + "> **Tip:** For production use, launch the server in a separate terminal." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "# Launch vLLM server as a background subprocess\n", + "# Set to True to skip launching (if you already have a server running)\n", + "SKIP_VLLM_LAUNCH = False\n", + "\n", + "vllm_process = None\n", + "\n", + "if not SKIP_VLLM_LAUNCH:\n", + " env = os.environ.copy()\n", + " env[\"VLLM_ATTENTION_BACKEND\"] = \"FLASH_ATTN\"\n", + "\n", + " vllm_command = [\n", + " sys.executable,\n", + " \"-m\",\n", + " \"vllm.entrypoints.openai.api_server\",\n", + " \"--port\",\n", + " str(VLLM_PORT),\n", + " \"--model\",\n", + " MODEL_PATH,\n", + " \"--tokenizer\",\n", + " TOKENIZER_PATH,\n", + " \"--enable-prefix-caching\",\n", + " ]\n", + "\n", + " print(f\"🚀 Launching vLLM server:\\n {' '.join(vllm_command)}\\n\")\n", + "\n", + " vllm_process = subprocess.Popen(\n", + " vllm_command,\n", + " env=env,\n", + " stdout=subprocess.PIPE,\n", + " stderr=subprocess.STDOUT,\n", + " text=True,\n", + " )\n", + "\n", + " # Wait for the server to be ready\n", + " WAIT_SECONDS = 240\n", + " print(f\"⏳ Waiting up to {WAIT_SECONDS}s for the server to start...\")\n", + " import urllib.request\n", + "\n", + " for i in range(WAIT_SECONDS):\n", + " time.sleep(1)\n", + " try:\n", + " urllib.request.urlopen(f\"http://localhost:{VLLM_PORT}/health\")\n", + " print(f\"✅ vLLM server is ready after {i + 1}s\")\n", + " break\n", + " except Exception:\n", + " pass\n", + " else:\n", + " print(\"⚠️ Server did not respond in time. Check GPU memory and logs.\")\n", + " print(\" You can read server output with: vllm_process.stdout.read()\")\n", + "else:\n", + " print(\"Skipping vLLM launch – using existing server.\")" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "## 5. Run forecasting inference\n", + "\n", + "This is the core step. `run_forecasting_inference_notebook` sends each patient's\n", + "instruction to the vLLM server and collects the generated text completions.\n", + "\n", + "Under the hood it:\n", + "1. Wraps each instruction in a chat message (with optional system prompt).\n", + "2. Calls the OpenAI-compatible `/v1/chat/completions` endpoint.\n", + "3. Returns the generated text(s) alongside patient metadata.\n", + "\n", + "When `n_samples > 1`, multiple independent completions are generated per\n", + "patient, which can later be aggregated into a mean trajectory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "# Run the forecasting inference\n", + "# This calls the vLLM server asynchronously for all patients\n", + "raw_results = await run_forecasting_inference_notebook(\n", + " prompts_with_meta,\n", + " prediction_url=PREDICTION_URL,\n", + " prediction_model=MODEL_PATH,\n", + " max_concurrent_requests=40,\n", + " max_new_tokens=MAX_NEW_TOKENS,\n", + " temperature=TEMPERATURE,\n", + " top_p=TOP_P,\n", + " n_samples=N_SAMPLES,\n", + " api_key=\"EMPTY\",\n", + " timeout=600.0,\n", + ")\n", + "\n", + "# Check for failures\n", + "n_success = sum(1 for r in raw_results if r is not None)\n", + "n_fail = sum(1 for r in raw_results if r is None)\n", + "print(f\"✅ Generated forecasts for {n_success} patients, {n_fail} failures\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's inspect the raw generated text for one patient\n", + "for r in raw_results:\n", + " if r is not None:\n", + " print(f\"=== Patient: {r['patientid']} ===\")\n", + " for i, text in enumerate(r[\"generated_texts\"]):\n", + " print(f\"\\n--- Sample {i} ---\")\n", + " print(text)\n", + " break" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "## 6. Parse results with reverse conversion\n", + "\n", + "`parse_forecasting_results` takes the raw generated texts and:\n", + "\n", + "1. Calls `converter.reverse_conversion` on each generated text to parse it\n", + " back into a structured DataFrame with dates and predicted values.\n", + "2. When `n_samples > 1` and `aggregate_samples=True`, aggregates multiple\n", + " trajectories using `converter.aggregate_multiple_responses` (e.g. averaging\n", + " numeric predictions).\n", + "3. Returns a single long-format DataFrame with all patients' predictions.\n", + "\n", + "> **Note:** Reverse conversion is robust to slightly malformed model output\n", + "> thanks to `inference_override=True`, but a fine-tuned model will produce\n", + "> much more parseable results than a generic instruction model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "# Parse the generated texts into structured DataFrames\n", + "df_results = parse_forecasting_results(\n", + " raw_results,\n", + " converter,\n", + " dm,\n", + " drop_failures=True,\n", + " aggregate_samples=(N_SAMPLES > 1), # Only aggregate if we have multiple samples\n", + ")\n", + "\n", + "print(f\"Parsed {len(df_results)} prediction rows for {df_results['patientid'].nunique()} patients\")\n", + "df_results.head(20)" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "### Understanding the output\n", + "\n", + "The returned DataFrame has the standard TwinWeaver event format:\n", + "\n", + "| Column | Description |\n", + "|---|---|\n", + "| `date` | Predicted date (computed from split_date + week offset) |\n", + "| `event_name` | The variable being predicted (e.g. `hemoglobin_-_718-7`) |\n", + "| `event_value` | The predicted value |\n", + "| `event_category` | Category of the event (e.g. `lab`) |\n", + "| `patientid` | Patient identifier |\n", + "| `task_type` | Which task type produced this prediction |\n", + "| `sample_idx` | Sample index (when `aggregate_samples=False` and `n_samples > 1`) |" + ] + }, + { + "cell_type": "markdown", + "id": "20", + "metadata": {}, + "source": [ + "## 7. Multi-sample aggregation (optional)\n", + "\n", + "When using `n_samples > 1`, each patient gets multiple independent forecast\n", + "trajectories. Aggregation (e.g. averaging numeric predictions) can reduce\n", + "variance and give more robust estimates.\n", + "\n", + "Below is an example of running with multiple samples and then aggregating." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "# Example: generate 3 samples per patient and aggregate\n", + "N_SAMPLES_AGG = 3\n", + "\n", + "raw_results_multi = await run_forecasting_inference_notebook(\n", + " prompts_with_meta,\n", + " prediction_url=PREDICTION_URL,\n", + " prediction_model=MODEL_PATH,\n", + " max_concurrent_requests=40,\n", + " max_new_tokens=MAX_NEW_TOKENS,\n", + " temperature=TEMPERATURE,\n", + " top_p=TOP_P,\n", + " n_samples=N_SAMPLES_AGG,\n", + " api_key=\"EMPTY\",\n", + ")\n", + "\n", + "# Parse with aggregation enabled\n", + "df_aggregated = parse_forecasting_results(\n", + " raw_results_multi,\n", + " converter,\n", + " dm,\n", + " drop_failures=True,\n", + " aggregate_samples=True, # Average numeric values across samples\n", + ")\n", + "\n", + "print(f\"Aggregated results: {len(df_aggregated)} rows for {df_aggregated['patientid'].nunique()} patients\")\n", + "df_aggregated.head(20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "# You can also get the individual (non-aggregated) samples for deeper analysis\n", + "df_individual = parse_forecasting_results(\n", + " raw_results_multi,\n", + " converter,\n", + " dm,\n", + " drop_failures=True,\n", + " aggregate_samples=False, # Keep individual samples\n", + ")\n", + "\n", + "print(f\"Individual results: {len(df_individual)} rows\")\n", + "print(f\"Samples per patient: {df_individual.groupby('patientid')['sample_idx'].nunique().describe()}\")\n", + "df_individual.head(20)" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "## 8. Clean up\n", + "\n", + "Shut down the vLLM server if we launched it from this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "if vllm_process is not None:\n", + " print(\"Terminating vLLM server...\")\n", + " vllm_process.terminate()\n", + " vllm_process.wait(timeout=10)\n", + " print(\"✅ Server stopped.\")" + ] + }, + { + "cell_type": "markdown", + "id": "25", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated the full forecasting inference workflow:\n", + "\n", + "1. **Data preparation** – Load clinical data and generate forecasting instruction prompts\n", + "2. **Model serving** – Launch (or connect to) a vLLM OpenAI-compatible server\n", + "3. **Text generation** – Use `run_forecasting_inference_notebook` to generate completions\n", + "4. **Reverse conversion** – Use `parse_forecasting_results` to convert text → structured DataFrame\n", + "5. **Aggregation** – Optionally average multiple samples for more robust predictions\n", + "\n", + "### Key functions\n", + "\n", + "| Function | Purpose |\n", + "|---|---|\n", + "| `run_forecasting_inference()` | Generate completions for all patients (sync wrapper) |\n", + "| `run_forecasting_inference_notebook()` | Same but async – for notebooks |\n", + "| `parse_forecasting_results()` | Reverse-convert generated text → structured DataFrame |\n", + "\n", + "### Comparison with TTE probability inference\n", + "\n", + "| Aspect | TTE Inference | Forecasting Inference |\n", + "|---|---|---|\n", + "| **Method** | Log-prob scoring of fixed completions | Free-text generation |\n", + "| **Output** | Probabilities (censored/occurred/not occurred) | Predicted values at future time points |\n", + "| **API endpoint** | `/v1/completions` (scoring) | `/v1/chat/completions` (generation) |\n", + "| **Post-processing** | Softmax over log-probs | Reverse conversion (text → DataFrame) |\n", + "| **Multi-sample** | Not applicable | Average trajectories via `aggregate_multiple_responses` |\n", + "\n", + "### Next steps\n", + "\n", + "- **Fine-tune a model** on your dataset using `03_end_to_end_llm_finetuning.ipynb`\n", + "- **Combine with TTE** by using `inference_type=\"both\"` in the data splitter\n", + "- **Evaluate** predictions against ground truth to assess model performance\n", + "- **Experiment** with different variables, time horizons, and sample counts" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mkdocs.yml b/mkdocs.yml index 0e2b9ec..3dfe42a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -140,6 +140,7 @@ nav: - Custom Summarized Row: examples/advanced/custom_output/custom_summarized_row.ipynb - Pretraining: examples/advanced/pretraining/prepare_pretraining_data.md - TTE Probability Inference: examples/advanced/tte_inference/tte_probability_inference.ipynb + - Forecasting Inference: examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb - Integrations: - MEDS Import: examples/integrations/meds_data_import.ipynb - API Reference: @@ -160,3 +161,4 @@ nav: - MEDS Importer: reference/utils/meds_importer.md - Preprocessing Helpers: reference/utils/preprocessing_helpers.md - TTE Inference: reference/utils/tte_inference.md + - Forecasting Inference: reference/utils/forecasting_inference.md diff --git a/twinweaver/__init__.py b/twinweaver/__init__.py index cef9112..a457dc0 100644 --- a/twinweaver/__init__.py +++ b/twinweaver/__init__.py @@ -20,3 +20,8 @@ run_tte_probability_estimation, run_tte_probability_estimation_notebook, ) +from twinweaver.utils.forecasting_inference import ( + parse_forecasting_results, + run_forecasting_inference, + run_forecasting_inference_notebook, +) diff --git a/twinweaver/utils/__init__.py b/twinweaver/utils/__init__.py index 96d4156..ce81882 100644 --- a/twinweaver/utils/__init__.py +++ b/twinweaver/utils/__init__.py @@ -7,3 +7,8 @@ compute_length_normalized_probabilities, run_tte_probability_estimation, ) +from twinweaver.utils.forecasting_inference import ( + parse_forecasting_results, + run_forecasting_inference, + run_forecasting_inference_notebook, +) diff --git a/twinweaver/utils/forecasting_inference.py b/twinweaver/utils/forecasting_inference.py new file mode 100644 index 0000000..f49199d --- /dev/null +++ b/twinweaver/utils/forecasting_inference.py @@ -0,0 +1,474 @@ +""" +Forecasting inference helpers for vLLM-based text generation. + +Provides functions to generate forecasting predictions (future lab values, +vitals, etc.) by sending instruction prompts to an OpenAI-compatible API +(e.g. a local vLLM server) and parsing the model's text output back into +structured DataFrames via :class:`~twinweaver.instruction.converter_instruction.ConverterInstruction`. + +The prompt construction is driven by +:class:`~twinweaver.instruction.converter_instruction.ConverterInstruction` +(specifically its ``forward_conversion_inference`` method), so the same +code works for any dataset / prompt template. + +Typical usage +------------- +>>> import asyncio +>>> from twinweaver.common.config import Config +>>> from twinweaver.utils.forecasting_inference import ( +... run_forecasting_inference, +... parse_forecasting_results, +... ) +>>> config = Config() +>>> # prompts_with_meta is a list of dicts with keys: +>>> # "patientid", "instruction", "split_date" +>>> results = asyncio.run(run_forecasting_inference( +... prompts_with_meta, +... prediction_url="http://localhost:8000/v1/", +... prediction_model="my-model", +... )) +>>> df = parse_forecasting_results(results, converter, dm) +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pandas as pd +from openai import ( + APIConnectionError, + AsyncOpenAI, + AuthenticationError, + OpenAIError, + RateLimitError, +) + + +# --------------------------------------------------------------------------- +# Type alias for a single prompt payload +# --------------------------------------------------------------------------- +# Each element is a dict with at least: +# "patientid" : str – unique patient identifier +# "instruction" : str – full instruction text (from converter) +# "split_date" : datetime – the reference date used for reverse conversion +# Additional keys are preserved and passed through to the results. +PromptPayload = dict[str, Any] + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _build_chat_messages( + instruction: str, + *, + system_prompt: str | None = None, +) -> list[dict[str, str]]: + """Build the list of chat messages for the API call. + + Parameters + ---------- + instruction : str + The user-facing instruction text. + system_prompt : str or None + Optional system prompt. + + Returns + ------- + list[dict[str, str]] + """ + messages: list[dict[str, str]] = [] + if system_prompt is not None: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": instruction}) + return messages + + +# --------------------------------------------------------------------------- +# Async LLM call (single patient) +# --------------------------------------------------------------------------- + + +async def _call_llm_for_generation_async( + client: AsyncOpenAI, + model_to_use: str, + payload: PromptPayload, + semaphore: asyncio.Semaphore, + *, + system_prompt: str | None = None, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + n_samples: int = 1, +) -> dict | None: + """Generate forecasting text for a single patient via the API. + + Parameters + ---------- + client : AsyncOpenAI + An ``openai.AsyncOpenAI`` client pointing at the inference server. + model_to_use : str + Model identifier / path accepted by the server. + payload : PromptPayload + Dict with ``"patientid"``, ``"instruction"``, ``"split_date"`` and + any other metadata the caller wants to pass through. + semaphore : asyncio.Semaphore + Concurrency limiter. + system_prompt : str or None, optional + Optional system prompt prepended to every request. + max_new_tokens : int + Maximum number of tokens to generate per completion. + temperature : float + Sampling temperature. + top_p : float + Nucleus-sampling probability mass. + n_samples : int + Number of independent completions to generate per prompt. When > 1 + the caller can later aggregate multiple trajectories. + + Returns + ------- + dict or None + A copy of the input *payload* augmented with: + + * ``"generated_texts"`` – list[str] of generated completion texts + (length = *n_samples*). + + Returns *None* on unrecoverable API errors. + """ + messages = _build_chat_messages(payload["instruction"], system_prompt=system_prompt) + + async with semaphore: + try: + response = await client.chat.completions.create( + model=model_to_use, + messages=messages, + max_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + n=n_samples, + ) + + generated_texts = [choice.message.content for choice in response.choices] + + result = {k: v for k, v in payload.items() if k != "instruction"} + result["generated_texts"] = generated_texts + return result + + except AuthenticationError as exc: + print(f"Authentication error for patient {payload.get('patientid')}: {exc}") + except RateLimitError as exc: + print(f"Rate limit exceeded for patient {payload.get('patientid')}: {exc}") + except APIConnectionError as exc: + print(f"Network error for patient {payload.get('patientid')}: {exc}") + raise + except OpenAIError as exc: + print(f"An OpenAI error occurred for patient {payload.get('patientid')}: {exc}") + + return None + + +# --------------------------------------------------------------------------- +# Async orchestrator (all patients) +# --------------------------------------------------------------------------- + + +async def _run_forecasting_inference_async( + prompts_with_meta: list[PromptPayload], + *, + prediction_url: str = "http://0.0.0.0:8000/v1/", + prediction_model: str = "default-model", + max_concurrent_requests: int = 40, + system_prompt: str | None = None, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + n_samples: int = 1, + api_key: str = "EMPTY", + timeout: float = 600.0, +) -> list[dict | None]: + """Async implementation of :func:`run_forecasting_inference`.""" + client = AsyncOpenAI( + base_url=prediction_url, + api_key=api_key, + timeout=timeout, + ) + + semaphore = asyncio.Semaphore(max_concurrent_requests) + + tasks = [ + _call_llm_for_generation_async( + client, + prediction_model, + payload, + semaphore, + system_prompt=system_prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + n_samples=n_samples, + ) + for payload in prompts_with_meta + ] + + return await asyncio.gather(*tasks) + + +def run_forecasting_inference( + prompts_with_meta: list[PromptPayload], + *, + prediction_url: str = "http://0.0.0.0:8000/v1/", + prediction_model: str = "default-model", + max_concurrent_requests: int = 40, + system_prompt: str | None = None, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + n_samples: int = 1, + api_key: str = "EMPTY", + timeout: float = 600.0, +) -> list[dict | None]: + """Generate forecasting predictions for all patients via an OpenAI-compatible API. + + This is the main synchronous entry-point. It calls ``asyncio.run`` + internally so it can be used from plain scripts. + + Parameters + ---------- + prompts_with_meta : list[PromptPayload] + Each element is a dict with **at least** the following keys: + + * ``"patientid"`` – unique identifier (str) + * ``"instruction"`` – full instruction text produced by + ``ConverterInstruction.forward_conversion_inference`` + * ``"split_date"`` – the reference date (datetime) used when + building the split; needed later for ``reverse_conversion`` + + Any extra keys are passed through unchanged to the results. + + prediction_url : str + Base URL of the OpenAI-compatible inference server. + prediction_model : str + Model name / path served by the inference server. + max_concurrent_requests : int + Maximum number of concurrent API requests. + system_prompt : str or None + Optional system prompt. + max_new_tokens : int + Maximum number of tokens to generate per completion. + temperature : float + Sampling temperature (0 = greedy). + top_p : float + Nucleus-sampling probability mass. + n_samples : int + Number of independent completions per prompt. Useful for + trajectory aggregation (see + :meth:`ConverterInstruction.aggregate_multiple_responses`). + api_key : str + API key (``"EMPTY"`` for local vLLM servers). + timeout : float + Per-request timeout in seconds. + + Returns + ------- + list[dict or None] + One dict per patient. Each dict contains all keys from the input + payload (except ``"instruction"``) plus ``"generated_texts"`` – a + list of *n_samples* generated completion strings. + ``None`` entries indicate API failures. + """ + return asyncio.run( + _run_forecasting_inference_async( + prompts_with_meta, + prediction_url=prediction_url, + prediction_model=prediction_model, + max_concurrent_requests=max_concurrent_requests, + system_prompt=system_prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + n_samples=n_samples, + api_key=api_key, + timeout=timeout, + ) + ) + + +def run_forecasting_inference_notebook( + prompts_with_meta: list[PromptPayload], + *, + prediction_url: str = "http://0.0.0.0:8000/v1/", + prediction_model: str = "default-model", + max_concurrent_requests: int = 40, + system_prompt: str | None = None, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + n_samples: int = 1, + api_key: str = "EMPTY", + timeout: float = 600.0, +) -> list[dict | None]: + """Generate forecasting predictions – async variant for Jupyter notebooks. + + Identical to :func:`run_forecasting_inference` but returns a *coroutine* + that can be ``await``-ed directly in a notebook cell (which already has + a running event loop). + + Returns + ------- + Coroutine[..., list[dict or None]] + """ + return _run_forecasting_inference_async( + prompts_with_meta, + prediction_url=prediction_url, + prediction_model=prediction_model, + max_concurrent_requests=max_concurrent_requests, + system_prompt=system_prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + n_samples=n_samples, + api_key=api_key, + timeout=timeout, + ) + + +# --------------------------------------------------------------------------- +# Post-processing: parse & reverse-convert generated text +# --------------------------------------------------------------------------- + + +def parse_forecasting_results( + raw_results: list[dict | None], + converter: Any, + data_manager: Any, + *, + drop_failures: bool = False, + aggregate_samples: bool = True, +) -> pd.DataFrame: + """Parse raw generated texts into structured DataFrames via reverse conversion. + + For each patient the function: + + 1. Calls ``converter.reverse_conversion`` on every generated text to obtain + structured forecasting DataFrames. + 2. When *n_samples > 1* and ``aggregate_samples`` is ``True``, aggregates the + multiple trajectories using ``converter.aggregate_multiple_responses``. + 3. Returns a single long-format DataFrame with all patients' predictions. + + Parameters + ---------- + raw_results : list[dict or None] + Output of :func:`run_forecasting_inference`. + converter : ConverterInstruction + The same converter instance used to generate the instruction prompts. + Must expose ``reverse_conversion`` and ``aggregate_multiple_responses``. + data_manager : DataManager + The data manager instance (passed to ``reverse_conversion``). + drop_failures : bool + If *True*, silently drop ``None`` entries (API failures). + If *False*, raise a ``ValueError`` when any entry is ``None``. + aggregate_samples : bool + If *True* (default) and multiple samples were generated per patient, + aggregate them via ``converter.aggregate_multiple_responses``. + If *False*, each sample is returned as a separate row block with + a ``"sample_idx"`` column. + + Returns + ------- + pd.DataFrame + A long-format DataFrame with columns from the reverse-converted + forecasting data plus ``"patientid"`` and optionally ``"sample_idx"``. + + Raises + ------ + ValueError + If *drop_failures* is *False* and any result is ``None``. + """ + if drop_failures: + valid = [r for r in raw_results if r is not None] + else: + if any(r is None for r in raw_results): + raise ValueError("Some results are None (API failures). Set drop_failures=True to silently ignore them.") + valid = raw_results # type: ignore[assignment] + + if not valid: + raise ValueError("No valid results to process.") + + all_rows: list[pd.DataFrame] = [] + + for result in valid: + patientid = result["patientid"] + split_date = result["split_date"] + generated_texts = result["generated_texts"] + + # Reverse-convert each sample + sample_dfs: list[pd.DataFrame] = [] + for sample_idx, text in enumerate(generated_texts): + try: + parsed_tasks = converter.reverse_conversion( + text, + data_manager, + split_date, + patientid=patientid, + inference_override=True, + ) + except Exception as exc: + print(f"Warning: reverse_conversion failed for patient {patientid} sample {sample_idx}: {exc}") + continue + + # Collect forecasting task results + for task_result in parsed_tasks: + task_type = task_result.get("task_type", "") + result_data = task_result.get("result") + + if isinstance(result_data, pd.DataFrame) and not result_data.empty: + df_task = result_data.copy() + df_task["patientid"] = patientid + df_task["sample_idx"] = sample_idx + df_task["task_type"] = task_type + sample_dfs.append(df_task) + + if not sample_dfs: + continue + + if aggregate_samples and len(generated_texts) > 1 and len(sample_dfs) > 1: + # Use the converter's built-in aggregation (groups by task type) + # Separate by task type, aggregate each, then combine + combined = pd.concat(sample_dfs, ignore_index=True) + task_types = combined["task_type"].unique() + agg_parts = [] + for tt in task_types: + task_subset = combined[combined["task_type"] == tt] + # Group into per-sample DataFrames for the aggregator + per_sample = [ + task_subset[task_subset["sample_idx"] == si].drop( + columns=["sample_idx", "task_type"], errors="ignore" + ) + for si in task_subset["sample_idx"].unique() + ] + try: + agg_df, _meta = converter.aggregate_multiple_responses(per_sample) + agg_df["task_type"] = tt + agg_df["patientid"] = patientid + agg_parts.append(agg_df) + except Exception as exc: + print(f"Warning: aggregation failed for patient {patientid}: {exc}") + # Fallback: just keep first sample + fallback = per_sample[0].copy() + fallback["task_type"] = tt + fallback["patientid"] = patientid + agg_parts.append(fallback) + + if agg_parts: + all_rows.append(pd.concat(agg_parts, ignore_index=True)) + else: + all_rows.extend(sample_dfs) + + if not all_rows: + return pd.DataFrame() + + df_out = pd.concat(all_rows, ignore_index=True) + return df_out From c0f78d5eb38f1853e6f4b580e2d3d1c02478d7b4 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:12:35 +0000 Subject: [PATCH 22/36] Minor fixes --- .../forecasting_vllm_inference.ipynb | 72 ++++++++++++++----- .../instruction/converter_instruction.py | 3 + 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb b/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb index c2b2826..f57bd1c 100644 --- a/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb +++ b/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb @@ -98,7 +98,7 @@ "# ---------------------------------------------------------------------------\n", "# Model & server settings\n", "# ---------------------------------------------------------------------------\n", - "MODEL_PATH = \"microsoft/Phi-4-mini-instruct\" # ⚠️ Replace with your fine-tuned model path\n", + "MODEL_PATH = \"microsoft/Phi-4-mini-instruct\" # ⚠️ Replace with your fine-tuned model path!!!!!!!!!!!\n", "TOKENIZER_PATH = MODEL_PATH # Usually the same as the model path\n", "\n", "VLLM_PORT = 8000\n", @@ -478,6 +478,22 @@ "id": "18", "metadata": {}, "outputs": [], + "source": [ + "converter = ConverterInstruction(\n", + " nr_tokens_budget_total=MAX_CONTEXT_LENGTH,\n", + " config=config,\n", + " dm=dm,\n", + " variable_stats=data_splitter_forecasting.variable_stats,\n", + ")\n", + "print(\"✅ Converter reloaded with latest code\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], "source": [ "# Parse the generated texts into structured DataFrames\n", "df_results = parse_forecasting_results(\n", @@ -488,13 +504,17 @@ " aggregate_samples=(N_SAMPLES > 1), # Only aggregate if we have multiple samples\n", ")\n", "\n", - "print(f\"Parsed {len(df_results)} prediction rows for {df_results['patientid'].nunique()} patients\")\n", - "df_results.head(20)" + "if df_results.empty:\n", + " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n", + " print(\" Fine-tune a model first (see 03_end_to_end_llm_finetuning.ipynb) for meaningful results.\")\n", + "else:\n", + " print(f\"Parsed {len(df_results)} prediction rows for {df_results['patientid'].nunique()} patients\")\n", + " df_results.head(20)" ] }, { "cell_type": "markdown", - "id": "19", + "id": "20", "metadata": {}, "source": [ "### Understanding the output\n", @@ -514,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "20", + "id": "21", "metadata": {}, "source": [ "## 7. Multi-sample aggregation (optional)\n", @@ -529,7 +549,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -557,14 +577,17 @@ " aggregate_samples=True, # Average numeric values across samples\n", ")\n", "\n", - "print(f\"Aggregated results: {len(df_aggregated)} rows for {df_aggregated['patientid'].nunique()} patients\")\n", - "df_aggregated.head(20)" + "if df_aggregated.empty:\n", + " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n", + "else:\n", + " print(f\"Aggregated results: {len(df_aggregated)} rows for {df_aggregated['patientid'].nunique()} patients\")\n", + " df_aggregated.head(20)" ] }, { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -577,14 +600,17 @@ " aggregate_samples=False, # Keep individual samples\n", ")\n", "\n", - "print(f\"Individual results: {len(df_individual)} rows\")\n", - "print(f\"Samples per patient: {df_individual.groupby('patientid')['sample_idx'].nunique().describe()}\")\n", - "df_individual.head(20)" + "if df_individual.empty:\n", + " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n", + "else:\n", + " print(f\"Individual results: {len(df_individual)} rows\")\n", + " print(f\"Samples per patient: {df_individual.groupby('patientid')['sample_idx'].nunique().describe()}\")\n", + " df_individual.head(20)" ] }, { "cell_type": "markdown", - "id": "23", + "id": "24", "metadata": {}, "source": [ "## 8. Clean up\n", @@ -595,7 +621,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -608,7 +634,7 @@ }, { "cell_type": "markdown", - "id": "25", + "id": "26", "metadata": {}, "source": [ "## Summary\n", @@ -649,8 +675,22 @@ } ], "metadata": { + "kernelspec": { + "display_name": ".venv_dev_gpu", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" } }, "nbformat": 4, diff --git a/twinweaver/instruction/converter_instruction.py b/twinweaver/instruction/converter_instruction.py index bf75ca5..dd21c70 100644 --- a/twinweaver/instruction/converter_instruction.py +++ b/twinweaver/instruction/converter_instruction.py @@ -812,6 +812,9 @@ def reverse_conversion( standard_extraction = False elif inference_override is False: raise ValueError("Could not determine task type") + else: + # inference_override is True but task type is unknown – skip + continue #: extract the relevant parts if standard_extraction: From 368f4a06467c08358a5b34dc3d605b49f8b55b8f Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:15:28 +0000 Subject: [PATCH 23/36] Adjusted TTE example --- .../tte_inference/tte_probability_inference.ipynb | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/advanced/tte_inference/tte_probability_inference.ipynb b/examples/advanced/tte_inference/tte_probability_inference.ipynb index 9b6a861..c442b0b 100644 --- a/examples/advanced/tte_inference/tte_probability_inference.ipynb +++ b/examples/advanced/tte_inference/tte_probability_inference.ipynb @@ -590,7 +590,15 @@ "print(f\"\\n✅ Total rows: {len(df_all_horizons)}\")\n", "df_all_horizons = df_all_horizons.sort_values([\"patientid\", \"week_horizon\"])\n", "df_all_horizons[\n", - " [\"patientid\", \"week_horizon\", \"probability_occurrence\", \"probability_no_occurrence\", \"probability_censored\"]\n", + " [\n", + " \"patientid\",\n", + " \"week_horizon\",\n", + " \"probability_occurrence\",\n", + " \"probability_no_occurrence\",\n", + " \"probability_censored\",\n", + " \"probability_occurrence_renormalized\",\n", + " \"probability_no_occurrence_renormalized\",\n", + " ]\n", "]" ] }, From 0b55f4e6f4704aa1b80670fc38fe5667730f3243 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:19:26 +0000 Subject: [PATCH 24/36] Removed some death hard coded values --- twinweaver/common/converter_base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/twinweaver/common/converter_base.py b/twinweaver/common/converter_base.py index c3eb494..70dba94 100644 --- a/twinweaver/common/converter_base.py +++ b/twinweaver/common/converter_base.py @@ -113,10 +113,11 @@ def __init__(self, config: Config) -> None: self.config.event_category_and_name_replace_override if self.config.event_category_and_name_replace_override is not None else { + # Death specific overrides, using config constants for category and event name self.config.event_category_death: { # Use config constant - "death": { # Assuming 'death' is the event_name associated with this category - "full_replacement_string": "death", - "reverse_string_value": "death", + self.config.event_category_death: { + "full_replacement_string": self.config.event_category_death, + "reverse_string_value": self.config.event_category_death, } } } From 8d4bd4471fede62a9d2f9509de60195cd15db2d7 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:22:56 +0000 Subject: [PATCH 25/36] Added constant_birthdate_columns_silence_print option to config to silence birthdate messages --- examples/01_data_preparation_for_training.ipynb | 2 +- twinweaver/common/config.py | 1 + twinweaver/common/converter_base.py | 14 ++++++++++++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/01_data_preparation_for_training.ipynb b/examples/01_data_preparation_for_training.ipynb index e74876d..d42c8f7 100644 --- a/examples/01_data_preparation_for_training.ipynb +++ b/examples/01_data_preparation_for_training.ipynb @@ -423,7 +423,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv_dev", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index 9955710..28bf7d4 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -425,6 +425,7 @@ def __init__(self): ] # Which columns to use from the constant data self.constant_birthdate_column: str = None # If set, use this column for age calculation self.constant_birthdate_column_format: str = "date" # Either "date" or "age" + self.constant_birthdate_columns_silence_print: bool = False # To silence print statements related to birthdate # Used to backup event categories for event types if no variables are found # e.g. progression -> death diff --git a/twinweaver/common/converter_base.py b/twinweaver/common/converter_base.py index 70dba94..56d0c7a 100644 --- a/twinweaver/common/converter_base.py +++ b/twinweaver/common/converter_base.py @@ -176,7 +176,12 @@ def _preprocess_constant_date( constant[self.config.constant_birthdate_column] = ( constant[self.config.constant_birthdate_column].astype(int).astype(str) + " years" ) - print(f"Using provided ages in {self.config.constant_birthdate_column} as age format") + if not self.config.constant_birthdate_columns_silence_print: + print( + f"Using provided ages in {self.config.constant_birthdate_column} as age format." + "To silence this print statement, set constant_birthdate_columns_silence_print to True" + "in the config." + ) else: # Check if the column contains integer ages (not birthdates) try: @@ -185,7 +190,12 @@ def _preprocess_constant_date( constant[self.config.constant_birthdate_column] = pd.to_datetime( constant[self.config.constant_birthdate_column].astype(int).astype(str) + "-01-01" ) - print(f"Converted integer ages in {self.config.constant_birthdate_column} to age format") + if not self.config.constant_birthdate_columns_silence_print: + print( + f"Converted integer ages in {self.config.constant_birthdate_column} to age format." + "To silence this print statement set constant_birthdate_columns_silence_print to True" + "in the config." + ) # Try converting the column to datetime if it is not already, if doesn't work then just keep it elif not pd.api.types.is_datetime64_any_dtype(constant[self.config.constant_birthdate_column]): From 2204cbeaf9ff3670c9a4cb56c939bb08ef823db4 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:24:34 +0000 Subject: [PATCH 26/36] Added more descriptive assert statements --- twinweaver/common/data_manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py index b743f8c..6144d2d 100644 --- a/twinweaver/common/data_manager.py +++ b/twinweaver/common/data_manager.py @@ -333,7 +333,11 @@ def setup_unique_mapping_of_events(self) -> None: # Assert that all unique now # Use config constant - assert len(self.unique_events) == len(self.data_frames[events_table_key][event_desc_name_col].unique()) + assert len(self.unique_events) == len(self.data_frames[events_table_key][event_desc_name_col].unique()), ( + "Each descriptive name needs a unique mapping to an event name - please check the data and the " + "whether there are any duplicates in the event_descriptive_name column after processing. " + "If there are, consider adding more unique identifiers or modifying the existing ones." + ) def setup_hold_out_sets( self, From fbf65623160c9bea817f269399bd32a929c40d6d Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:34:20 +0000 Subject: [PATCH 27/36] Added better warning for non unique descriptive names --- twinweaver/common/data_manager.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py index 6144d2d..544c529 100644 --- a/twinweaver/common/data_manager.py +++ b/twinweaver/common/data_manager.py @@ -285,6 +285,7 @@ def setup_unique_mapping_of_events(self) -> None: # Extract corresponding event_name and event_category filtered_events = self.unique_events[event_desc_name_col] non_unique_events = self.unique_events[filtered_events.isin(non_unique_events.index)].copy() + non_unique_descriptive_names = non_unique_events[event_desc_name_col].unique().tolist() # create mapping for all non-unique descriptive names, and # then add event_name to those, and apply across entire dataset @@ -307,6 +308,15 @@ def setup_unique_mapping_of_events(self) -> None: events_df[event_desc_name_col] = events_df[new_desc_name].fillna(events_df[event_desc_name_col]) self.data_frames[events_table_key] = self.data_frames[events_table_key].drop(columns=["new_descriptive_name"]) + # Warn if there were any non-unique descriptive names + num_renamed = len(non_unique_events) + if num_renamed > 0: + logging.warning( + f"Found {len(non_unique_descriptive_names)} non-unique descriptive name(s) " + f"({num_renamed} event mappings total) that were disambiguated " + f"by appending the event_name: {non_unique_descriptive_names}" + ) + #: first convert special symbols in event_descriptive_name to alternatives, using self.replace_special_symbols for event_category, ( string_to_replace, @@ -334,9 +344,12 @@ def setup_unique_mapping_of_events(self) -> None: # Assert that all unique now # Use config constant assert len(self.unique_events) == len(self.data_frames[events_table_key][event_desc_name_col].unique()), ( - "Each descriptive name needs a unique mapping to an event name - please check the data and the " - "whether there are any duplicates in the event_descriptive_name column after processing. " - "If there are, consider adding more unique identifiers or modifying the existing ones." + "Each descriptive name needs a unique mapping to an event name/category" + f" Found this many unique descriptive names: " + f"{len(self.data_frames[events_table_key][event_desc_name_col].unique())} " + f"but this many unique combinations of event name, descriptive name, and " + f"category: {len(self.unique_events)}. Please ensure that after processing, each descriptive name maps " + f"to exactly one event name within its category." ) def setup_hold_out_sets( From 39771b7996dd7664bb1c296f11050e2f41da7748 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:37:41 +0000 Subject: [PATCH 28/36] Added in DataSplitterForecasting checks so that event_category_forecast and split_event_category do not overlap --- twinweaver/instruction/data_splitter_forecasting.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/twinweaver/instruction/data_splitter_forecasting.py b/twinweaver/instruction/data_splitter_forecasting.py index 95590e5..6fbc279 100644 --- a/twinweaver/instruction/data_splitter_forecasting.py +++ b/twinweaver/instruction/data_splitter_forecasting.py @@ -211,6 +211,14 @@ def __init__( self._filtering_methods = {"3-sigma": self._filter_3_sigma} + # Check that the forecasting and split event categories do not overlap, as this could cause data leakage + if self.config.event_category_forecast is not None and self.config.split_event_category is not None: + overlap = set(self.config.event_category_forecast).intersection(set(self.config.split_event_category)) + if overlap: + raise ValueError( + f"Forecasting and split event categories overlap: {overlap}. This could cause data leakage." + ) + def setup_statistics(self, train_patientids: list = None): """ Calculates baseline performance statistics for variables. From 6d4222d4bc6bf35a856a11294d544a27741e7a26 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:43:05 +0000 Subject: [PATCH 29/36] Added automatic checks for constant description DF --- tests/conftest.py | 2 ++ twinweaver/common/data_manager.py | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 4612096..5f3c3b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ def mock_config(): cfg = Config() # Ensure the random seed is fixed for reproducible tests cfg.seed = 42 + # Set constant_columns_to_use to match the test data + cfg.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] return cfg diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py index 544c529..5261fa6 100644 --- a/twinweaver/common/data_manager.py +++ b/twinweaver/common/data_manager.py @@ -232,6 +232,19 @@ def handle_missing_dates(df_key, missing_count, total_count, col_date): self.config.event_category_col ].astype(str) + # Check that every column selected in config for constant is also in constant descriptive df + constant_desc_variables = set(self.data_frames["constant_description"]["variable"].unique()) + missing_in_description = [ + col for col in self.config.constant_columns_to_use if col not in constant_desc_variables + ] + if missing_in_description: + raise ValueError( + f"The following columns are listed in config.constant_columns_to_use but are not " + f"present in the constant_description 'variable' column: {missing_in_description}. " + f"Please add them to the constant_description dataframe or remove them from " + f"config.constant_columns_to_use." + ) + logging.info("Data processed") def setup_unique_mapping_of_events(self) -> None: From 1e58f939e1ee93dd260bc51a1ada42b7c58a8959 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:49:24 +0000 Subject: [PATCH 30/36] Now data manager allows all patient from same train/val/test set --- twinweaver/common/data_manager.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py index 5261fa6..620adc8 100644 --- a/twinweaver/common/data_manager.py +++ b/twinweaver/common/data_manager.py @@ -245,6 +245,11 @@ def handle_missing_dates(df_key, missing_count, total_count, col_date): f"config.constant_columns_to_use." ) + # Add all unique patientids overlapping constant and events to self.all_patientids + constant_patientids = set(self.data_frames["constant"]["patientid"].unique()) + event_patientids = set(self.data_frames["events"]["patientid"].unique()) + self.all_patientids = list(constant_patientids.intersection(event_patientids)) + logging.info("Data processed") def setup_unique_mapping_of_events(self) -> None: From c736fd647186e4cc6267d85e0ec94918f756f679 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:52:46 +0000 Subject: [PATCH 31/36] Improved docs on relative dating --- docs/dataset-format.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/dataset-format.md b/docs/dataset-format.md index 80b5967..564cced 100644 --- a/docs/dataset-format.md +++ b/docs/dataset-format.md @@ -104,6 +104,10 @@ On the first visit, the patient experienced the following: Hemoglobin is 11.8. ``` +#### Relative Dating + +TwinWeaver uses **relative dating** instead of absolute dates. All calendar dates from the input data are converted into time deltas relative to the previous visit (e.g., *"2 weeks later"*) rather than being included as raw dates (e.g., *"2024-01-29"*). This serves two important purposes: first, it **anonymizes** the patient data by removing identifiable calendar dates from the training text; second, it provides the model with **clinically meaningful temporal context** — the time elapsed between visits — rather than arbitrary date strings. By default, time intervals are expressed in weeks, but this can be changed to days using `Config.set_delta_time_unit("days")`. Accumulative deltas (time since the very first visit rather than since the previous visit) are also supported. + ### Final Output Structure For training, TwinWeaver produces input-target pairs: From 7846f22b596ce3210cf840754c8065399b5d8689 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 15:58:10 +0000 Subject: [PATCH 32/36] Fixed minor issues with docs --- docs/images/favicon.png | Bin 0 -> 958 bytes mkdocs.yml | 8 ++++---- 2 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 docs/images/favicon.png diff --git a/docs/images/favicon.png b/docs/images/favicon.png new file mode 100644 index 0000000000000000000000000000000000000000..1fa93207d632b7b2408450901c98b59d75218fda GIT binary patch literal 958 zcmV;v13~RE1+X3TSh*Z-+Cl4i_Ts_a#) z?viHAl4i{O*MxYpR#mL-l4i_RtnU7)H9m~U{n&&5t29-t?x=0JbWWyZ#_?|%FnO=FVTy##Vh+e@$lFtjZ6?gyu04Q`)PE!CuI?oVd&k%ooVGy6j zAfNBK@2Oy7&y+A*&w$Th-~530_rKrjkC#6Np#lH^0*gsRK~#8NrPtYVn=llB;RBWp zN{qtd#C8@p&MLrL_N{&Y*Q&NwN;`uk&VRw>H%MrJ1L3hMt7d?2l5@^;yfe$WEW$eE zMmhL9y5p*Qti!Z(wK_mTm3BT?Sjp`4$qFe&%BkQGD^c_4(Nd-$04rg(C<7xsZ;$&z zT=p_H(ssS>DG?2|T1w8>eWl1iD9J*j7A09I8n^A!kt>n79|Ys4_r{Tmx-P@0=%*|K zIBFUzJ=UTk4EmB?f-B916AUo}{guePsCCo50Gm&XL2cTI9*R`9TrOX4>pzvOTOsM9 z_Nk2bFlFu4pxcQTeoUD{XLG;XS)`=fi5=keb(zYF+s(#B21Odfo>*y+LCWl*x&D#? z#3BM-73RicduS{a(A=!)z=kr?INcZWzcp_EpYxLiNTn2|m{?lW%68FV+h%eY7Rtecrg z1@ACVhF7H1iE=+Ci=f-_Qz-NJ2iH|}_}(<9iNM+0gXo%?qM85Cd`{*1FB&7g$9eq>W5KXvus;0MHUMLjhV!&0b@rt#V2Ltn_GBhLz07!>|x>=Px|8Ab_sW g^A|ocEs7q_Z(sk%(P)%Xh5!Hn07*qoM6N<$f?|-^x&QzG literal 0 HcmV?d00001 diff --git a/mkdocs.yml b/mkdocs.yml index 3dfe42a..131f47d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -4,7 +4,7 @@ site_url: https://example.com/twinweaver/ # 1. Source Code Link (Adds a nice GitHub icon to the header) repo_name: TwinWeaver -repo_url: https://github.com/your-org/twinweaver # Replace with actual +repo_url: https://github.com/MendenLab/TwinWeaver edit_uri: edit/main/docs/ extra_css: @@ -13,10 +13,10 @@ extra_css: theme: name: material custom_dir: docs/overrides - # 2. Custom Logo/Favicon (Use an emoji if you don't have an SVG yet) + # 2. Custom Logo/Favicon icon: logo: material/dna - favicon: material/dna + favicon: images/favicon.png features: - navigation.tabs # Top-level tabs @@ -54,7 +54,7 @@ theme: extra: social: - icon: fontawesome/brands/github - link: https://github.com/your-org/twinweaver + link: https://github.com/MendenLab/TwinWeaver - icon: fontawesome/brands/python link: https://pypi.org/project/twinweaver/ analytics: From 1269b6b5b50d5fae2f64b4ac925520e5378d70b4 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 16:03:10 +0000 Subject: [PATCH 33/36] Improved docstrings --- twinweaver/common/config.py | 12 +++++--- twinweaver/common/data_manager.py | 47 ++++++++++++++++++++++++------- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index 28bf7d4..a1d3f37 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -54,7 +54,7 @@ class Config: source_col_default_value : str Default value to assign to `source_col` if it is missing. Default: "events". split_date_col : str - Column name specifically used for dates related to line of therapy (LoT) events. Default: "lot_date". + Column name used for dates related to data splitting events (e.g., line of therapy). Default: "split_date". lot_concatenate_descriptive_and_value : bool Flag indicating whether to concatenate the descriptive name and value for line of therapy events. Default: False. @@ -71,6 +71,8 @@ class Config: List of event categories to be considered for forecasting tasks. Default: None. split_event_category : str | None Event category used for data splitting (e.g., LoT). Default: None. + event_categories_to_exclude_from_input : list[str] + List of event categories to exclude from the input data (e.g., ["lot"]). Default: []. source_genetic : str Specific string value used in `source_col` to identify data originating from genetic testing. Default: "genetic". @@ -106,7 +108,7 @@ class Config: Text inserted before the description of events for visits subsequent to the first one. Default: "\\n". event_day_text : str Template text used to introduce events on subsequent visit days, indicating the time elapsed since the previous - visit. Default: " self.delta_time_unit : later, the patient visited and experienced the following: \\n". + visit. Default: " weeks later, the patient visited and experienced the following: \\n". post_event_text : str Text appended after listing all events for a specific visit day. Default: ".\\n". forecasting_fval_prompt_start : str @@ -223,6 +225,8 @@ class Config: Default: None. constant_birthdate_column_format : str Format of the birthdate column, either "date" or "age". Default: "date". + constant_birthdate_columns_silence_print : bool + Whether to silence print statements related to birthdate column processing. Default: False. event_category_events_prediction_with_naming : dict | None Mapping defining which event categories correspond to specific prediction types in DataSplitterEvents. Keys are event categories (e.g., 'death', 'progression'), values are descriptive names for the target variable. @@ -458,11 +462,11 @@ def seed(self) -> int: @seed.setter def seed(self, value: int): - """Set the seed value and update all random seeds (numpy, pandas, random).""" + """Set the seed value and update all random seeds (numpy and random).""" self._seed = value self._set_all_seeds(value) def _set_all_seeds(self, seed: int): - """Set seeds for numpy, pandas, and random modules.""" + """Set seeds for numpy and random modules.""" np.random.seed(seed) random.seed(seed) diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py index 620adc8..3c4bda1 100644 --- a/twinweaver/common/data_manager.py +++ b/twinweaver/common/data_manager.py @@ -39,7 +39,7 @@ def __init__( A list of tuples to override the default special character replacements in event descriptive names. Each tuple should be in the format `(event_category, (string_to_replace, replacement_string))`. If None, - default replacements specified in the method are used. Defaults to None. + an empty list is used (no replacements). Defaults to None. """ #: initialize the data manager @@ -146,16 +146,43 @@ def process_indication_data( Performs initial processing on the loaded indication data. Requires `load_indication_data` to be called first. - This method converts the date columns (specified by `config.date_col`) - in the 'events' DataFrame to datetime objects. - It also checks for and removes rows with missing dates in these tables, - logging an error if any are found, unless `skip_missing_dates` is True. + This method performs the following steps: + + 1. Converts the date column (specified by `config.date_col`) in the + 'events' DataFrame to datetime objects. + 2. Checks for and removes rows with missing dates, raising a + ``ValueError`` unless `skip_missing_dates` is True. + 3. Checks for missing event values and either drops them (if + `drop_missing_event_values` is True) or raises a ``ValueError``. + 4. Validates that there are no missing values in + ``event_descriptive_name``, ``event_name``, and ``event_category`` + columns, raising a ``ValueError`` if any are found. + 5. Converts event values, descriptive names, event names, and event + categories to string type. + 6. Validates that all columns listed in + ``config.constant_columns_to_use`` are present in the + ``constant_description`` dataframe. + 7. Computes ``self.all_patientids`` as the intersection of patient IDs + appearing in both the constant and events tables. + + Parameters + ---------- + skip_missing_dates : bool, optional + If True, rows with missing dates are silently dropped instead of + raising an error. Defaults to False. + drop_missing_event_values : bool, optional + If True, rows with missing event values are dropped with a warning + instead of raising an error. Defaults to False. Raises ------ ValueError If `load_indication_data` has not been successfully called before - this method, or if missing dates are found and `skip_missing_dates` is False. + this method, if missing dates are found and `skip_missing_dates` + is False, if missing event values are found and + `drop_missing_event_values` is False, if missing values are found + in event name columns, or if constant columns are missing from + the constant_description dataframe. """ # Check that we already have self.data_frames @@ -383,8 +410,8 @@ def setup_hold_out_sets( The method determines the split assignment for each patient. It retrieves all unique patient IDs from the 'constant' data table. It calculates the number of patients for validation and test sets based on - the `validation_split_max`, `test_split_max`, and `max_val_test_nr_patients` - parameters set during initialization. The remaining patients are assigned to the training set # + the `validation_split`, `test_split`, and `max_val_test_nr_patients` + parameters. The remaining patients are assigned to the training set (calculated as the remainder after validation and test sets are allocated). Patients are randomly shuffled (with a fixed seed for reproducibility) before assignment. @@ -502,7 +529,7 @@ def setup_hold_out_sets( patient_id_col ].map(patient_to_split_mapping) - def get_all_patientids_in_split(self, split: str) -> str: + def get_all_patientids_in_split(self, split: str) -> list: """ Retrieves all patient IDs belonging to a specific data split. @@ -617,7 +644,7 @@ def get_patient_data(self, patientid: str) -> dict: def infer_var_types(self): """ - Fills self.dm.variable_types for every candidate forecasting variable. + Fills self.variable_types for every candidate forecasting variable. Classifies as "numeric" if at least `self.config.numeric_detect_min_fraction` of values can be parsed as numeric, otherwise "categorical". """ From 3fd56fd6d2996f9c00b5b7ac538f9c5d3ed21f3a Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 16:15:33 +0000 Subject: [PATCH 34/36] Minor fixes for notebooks --- examples/advanced/custom_output/custom_summarized_row.ipynb | 2 +- .../custom_splitting/training_custom_split_events.ipynb | 2 +- .../custom_splitting/training_individual_splitters.ipynb | 4 ++-- .../pretraining/end_to_end_llm_training_with_pretrain.ipynb | 5 ++--- twinweaver/common/converter_base.py | 4 ++-- twinweaver/instruction/converter_instruction.py | 2 ++ 6 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/advanced/custom_output/custom_summarized_row.ipynb b/examples/advanced/custom_output/custom_summarized_row.ipynb index e790228..58bf69f 100644 --- a/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -520,7 +520,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv_dev", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/examples/advanced/custom_splitting/training_custom_split_events.ipynb index 2d2c48a..68976a7 100644 --- a/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -306,7 +306,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv_dev", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 56718f3..54fc0ad 100644 --- a/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -304,13 +304,13 @@ "source": [ "date = reference_dates[\"date\"][0]\n", "return_list = converter.reverse_conversion(p_converted[\"answer\"], dm, date)\n", - "return_list[2][\"result\"]" + "return_list[0][\"result\"]" ] } ], "metadata": { "kernelspec": { - "display_name": ".venv_dev", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb b/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb index 3e60fa5..fde4635 100644 --- a/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb +++ b/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb @@ -89,7 +89,6 @@ "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", - "dm.infer_var_types()\n", "\n", "converter = ConverterPretrain(config=config, dm=dm)" ] @@ -505,7 +504,7 @@ "outputs": [], "source": [ "# Show the generated answer\n", - "generated_answer" + "print(generated_answer)" ] }, { @@ -541,7 +540,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv_test", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/twinweaver/common/converter_base.py b/twinweaver/common/converter_base.py index 56d0c7a..3f76c07 100644 --- a/twinweaver/common/converter_base.py +++ b/twinweaver/common/converter_base.py @@ -193,8 +193,8 @@ def _preprocess_constant_date( if not self.config.constant_birthdate_columns_silence_print: print( f"Converted integer ages in {self.config.constant_birthdate_column} to age format." - "To silence this print statement set constant_birthdate_columns_silence_print to True" - "in the config." + " To silence this print statement set constant_birthdate_columns_silence_print to True" + " in the config." ) # Try converting the column to datetime if it is not already, if doesn't work then just keep it diff --git a/twinweaver/instruction/converter_instruction.py b/twinweaver/instruction/converter_instruction.py index dd21c70..42c88b6 100644 --- a/twinweaver/instruction/converter_instruction.py +++ b/twinweaver/instruction/converter_instruction.py @@ -290,6 +290,8 @@ def forward_conversion( # If events is None, set to empty list for easier processing if event_splits is None: event_splits = [] + if forecasting_splits is None: + forecasting_splits = [] #: make assertions that data has same split and lot date all_lot_dates_events = [x.lot_date for x in event_splits] From b374ffb06cd3c96241a15f75fe9183b84fcee75d Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 16:19:05 +0000 Subject: [PATCH 35/36] Added more individual converter unit tests --- tests/test_converter.py | 182 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) diff --git a/tests/test_converter.py b/tests/test_converter.py index 6633bbe..d5d713a 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -319,3 +319,185 @@ def test_reverse_conversion(setup_components): assert e_res["task_type"] == task_events assert e_res["result"]["censoring"].iloc[0].item() is False assert e_res["result"]["occurred"].iloc[0].item() is False + + +# ── Tests for forecasting-only and events-only conversion ────────────────── + + +@pytest.fixture +def setup_forecasting_only(mock_config, sample_data): + """Helper to set up components with only a forecasting splitter (no events splitter).""" + df_events, df_constant, df_constant_desc = sample_data + mock_config.split_event_category = "lot" + mock_config.event_category_forecast = ["lab"] + mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"} + mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] + mock_config.constant_birthdate_column = "birthyear" + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) + dm.infer_var_types() + + splitter_forecast = DataSplitterForecasting( + data_manager=dm, + config=mock_config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) + splitter_forecast.setup_statistics() + + data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast) + + converter = ConverterInstruction( + nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=splitter_forecast.variable_stats + ) + + return dm, data_splitter, converter, splitter_forecast + + +@pytest.fixture +def setup_events_only(mock_config, sample_data): + """Helper to set up components with only an events splitter (no forecasting splitter).""" + df_events, df_constant, df_constant_desc = sample_data + mock_config.split_event_category = "lot" + mock_config.event_category_forecast = ["lab"] + mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"} + mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"] + mock_config.constant_birthdate_column = "birthyear" + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) + dm.infer_var_types() + + splitter_events = DataSplitterEvents( + dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) + splitter_events.setup_variables() + + data_splitter = DataSplitter(data_splitter_events=splitter_events) + + converter = ConverterInstruction(nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=None) + + return dm, data_splitter, converter + + +def test_forward_conversion_forecasting_only_training(setup_forecasting_only): + """Test training conversion with only forecasting splits (no events).""" + dm, data_splitter, converter, _ = setup_forecasting_only + patient_data = dm.get_patient_data("p0") + + f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data) + + assert e_splits is None # No events splitter configured + assert f_splits is not None + + result = converter.forward_conversion( + forecasting_splits=f_splits[0], + event_splits=None, + override_mode_to_select_forecasting="forecasting", + ) + + instruction = result["instruction"] + answer = result["answer"] + + # Should contain forecasting task but no events task + assert "Task 1 is forecasting:" in instruction + assert "time to event prediction:" not in instruction + assert "Starting with demographic data:" in instruction + + # Answer should contain exactly one task + assert len(answer) > 0 + assert "Task 1 is forecasting:" in answer + assert "time to event prediction:" not in answer + + +def test_forward_conversion_events_only_training(setup_events_only): + """Test training conversion with only event splits (no forecasting).""" + dm, data_splitter, converter = setup_events_only + patient_data = dm.get_patient_data("p0") + + f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data) + + assert f_splits is None # No forecasting splitter configured + assert e_splits is not None + + result = converter.forward_conversion( + forecasting_splits=None, + event_splits=e_splits[0], + ) + + instruction = result["instruction"] + answer = result["answer"] + + # Should contain events task but no forecasting task + assert "time to event prediction:" in instruction + assert "forecasting:" not in instruction + assert "Starting with demographic data:" in instruction + + # Answer should contain events task only + assert len(answer) > 0 + assert "time to event prediction:" in answer + assert "forecasting:" not in answer + + +def test_forward_conversion_inference_forecasting_only(setup_forecasting_only): + """Test inference conversion with only a forecasting split (no events).""" + dm, data_splitter, converter, _ = setup_forecasting_only + patient_data = dm.get_patient_data("p0") + + f_split, e_split = data_splitter.get_splits_from_patient_inference( + patient_data, + inference_type="forecasting", + forecasting_override_variables_to_predict=["hemoglobin_-_718-7"], + ) + + assert e_split is None + assert f_split is not None + + result = converter.forward_conversion_inference( + forecasting_split=f_split, + forecasting_future_weeks_per_variable={"hemoglobin_-_718-7": [4, 8, 12]}, + event_split=None, + ) + + instruction = result["instruction"] + assert result["answer"] is None + assert "hemoglobin" in instruction + assert "future weeks 4, 8, 12" in instruction + assert "Task 1 is forecasting:" in instruction + assert "time to event prediction:" not in instruction + + +def test_forward_conversion_inference_events_only(setup_events_only): + """Test inference conversion with only an event split (no forecasting).""" + dm, data_splitter, converter = setup_events_only + patient_data = dm.get_patient_data("p0") + + f_split, e_split = data_splitter.get_splits_from_patient_inference( + patient_data, + inference_type="events", + ) + + assert f_split is None + assert e_split is not None + + result = converter.forward_conversion_inference( + forecasting_split=None, + event_split=e_split, + ) + + instruction = result["instruction"] + assert result["answer"] is None + assert "Task 1 is time to event prediction:" in instruction + assert "forecasting:" not in instruction + assert "Starting with demographic data:" in instruction From e708247d10349c2d11f914706b41e61e51886338 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Wed, 18 Mar 2026 16:25:08 +0000 Subject: [PATCH 36/36] Updated docs --- .../01_data_preparation_for_training.ipynb | 2 +- .../custom_output/custom_summarized_row.ipynb | 2 +- .../inference_individual_splitters.md | 17 +++++ .../inference_individual_splitters.py | 70 ++++++++---------- .../training_custom_split_events.ipynb | 7 +- .../training_forecasting_qa.ipynb | 17 +++-- .../training_forecasting_splitter_only.ipynb | 28 ++++---- .../training_individual_splitters.ipynb | 41 +++++------ .../forecasting_vllm_inference.ipynb | 72 ++++++++++++++----- ...nd_to_end_llm_training_with_pretrain.ipynb | 5 +- .../pretraining/prepare_pretraining_data.md | 17 +++++ .../tte_probability_inference.ipynb | 10 ++- .../02_llm_finetuning_challenge.ipynb | 4 +- .../integrations/meds_data_import.ipynb | 1 + mkdocs.yml | 25 +++++-- 15 files changed, 203 insertions(+), 115 deletions(-) create mode 100644 docs/examples/advanced/custom_splitting/inference_individual_splitters.md create mode 100644 docs/examples/advanced/pretraining/prepare_pretraining_data.md diff --git a/docs/examples/01_data_preparation_for_training.ipynb b/docs/examples/01_data_preparation_for_training.ipynb index e74876d..d42c8f7 100644 --- a/docs/examples/01_data_preparation_for_training.ipynb +++ b/docs/examples/01_data_preparation_for_training.ipynb @@ -423,7 +423,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv_dev", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb index e790228..58bf69f 100644 --- a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb +++ b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb @@ -520,7 +520,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv_dev", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/docs/examples/advanced/custom_splitting/inference_individual_splitters.md b/docs/examples/advanced/custom_splitting/inference_individual_splitters.md new file mode 100644 index 0000000..cd5cb1c --- /dev/null +++ b/docs/examples/advanced/custom_splitting/inference_individual_splitters.md @@ -0,0 +1,17 @@ +# Custom Splitting for Inference + +This script demonstrates how to use individual splitters (`DataSplitterForecasting` and `DataSplitterEvents`) combined via the unified `DataSplitter` API for **inference** — i.e., when you only have input data and no target labels. + +Key concepts: + +- Configuring split events and forecasting categories +- Using `DataSplitter.get_splits_from_patient_inference()` to generate splits at inference time +- Converting splits to text prompts with `ConverterInstruction.forward_conversion_inference()` +- Overriding which variables/events to predict + +!!! note "Run from project root" + This script should be run from the root folder of the TwinWeaver repository. + +```python title="inference_individual_splitters.py" +--8<-- "examples/advanced/custom_splitting/inference_individual_splitters.py" +``` diff --git a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py index 8b773db..3840087 100644 --- a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py +++ b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py @@ -2,6 +2,7 @@ DataSplitterForecasting, DataManager, DataSplitterEvents, + DataSplitter, ConverterInstruction, Config, ) @@ -45,19 +46,26 @@ def __init__( self.dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) self.dm.infer_var_types() - self.data_splitter_events = DataSplitterEvents( + data_splitter_events = DataSplitterEvents( self.dm, config=self.config, max_length_to_sample=pd.Timedelta(weeks=104), min_length_to_sample=pd.Timedelta(weeks=1), ) - self.data_splitter_events.setup_variables() - self.data_splitter_forecasting = DataSplitterForecasting( + data_splitter_events.setup_variables() + data_splitter_forecasting = DataSplitterForecasting( data_manager=self.dm, config=self.config, max_forecasted_trajectory_length=pd.Timedelta(days=90), ) - self.data_splitter_forecasting.setup_statistics() + data_splitter_forecasting.setup_statistics() + + # Use the unified DataSplitter API that combines both splitters + self.data_splitter = DataSplitter( + data_splitter_events=data_splitter_events, + data_splitter_forecasting=data_splitter_forecasting, + ) + self.converter = ConverterInstruction( nr_tokens_budget_total=8192, config=self.config, @@ -71,48 +79,26 @@ def convert_full_to_string_for_one_patient(self, patientid, override_events_or_f # To simulate that we only have input, half the events patient_data["events"] = patient_data["events"].iloc[: int(len(patient_data["events"]) / 2)] - # Here then split date - split_date = patient_data["events"]["date"].iloc[-1] - - #: generate event split - NOTE: this if statement is only to exemplify both cases! - if override_events_or_forecasting == "events": - ####### Example if we want to override for events - - events_splits = self.data_splitter_events.get_splits_from_patient( - patient_data, - max_nr_samples=1, - override_split_dates=[split_date], - override_category="death", - override_end_week_delta=52, - ) - # We just pick the first one - events_split = events_splits[0][0] - - #: no forecasting split - forecast_split = None - forecasting_times_to_predict = None - else: - ####### Example if we want to override for forecasting - - #: generate forecasting split - forecast_splits = self.data_splitter_forecasting.get_splits_from_patient( - patient_data, - nr_samples_per_split=1, - filter_outliers=False, - override_split_dates=[split_date], - override_variables_to_predict=["Neutrophils"], - ) - # We just pick the first one - forecast_split = forecast_splits[0][0] - - # We set which weeks to predict + # Use the unified DataSplitter API for inference + forecast_split, events_split = self.data_splitter.get_splits_from_patient_inference( + patient_data, + inference_type=override_events_or_forecasting, + forecasting_override_variables_to_predict=["Neutrophils"] + if override_events_or_forecasting != "events" + else None, + events_override_category="death" if override_events_or_forecasting != "forecasting" else None, + events_override_observation_time_delta=pd.Timedelta(weeks=52) + if override_events_or_forecasting != "forecasting" + else None, + ) + + # Set which weeks to predict for forecasting (if applicable) + forecasting_times_to_predict = None + if forecast_split is not None: forecasting_times_to_predict = { "Neutrophils": [1, 2, 8, 11], } - #: no events split - events_split = None - # Convert to text converted = self.converter.forward_conversion_inference( forecasting_split=forecast_split, diff --git a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb index 8a0bc01..68976a7 100644 --- a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb +++ b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb @@ -219,7 +219,8 @@ "source": [ "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n", " patient_data,\n", - ")" + ")\n", + "# Note, forecasting_splits will be none here" ] }, { @@ -239,7 +240,7 @@ "source": [ "split_idx = 0\n", "p_converted = converter.forward_conversion(\n", - " forecasting_splits=forecasting_splits[split_idx],\n", + " forecasting_splits=None, # Set to None since we don't want to generate forecasting tasks\n", " event_splits=events_splits[split_idx],\n", ")" ] @@ -305,7 +306,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv_dev", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb b/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb index c79a9a1..4d0fd0c 100644 --- a/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb +++ b/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb @@ -376,10 +376,7 @@ "# Reverse convert the QA mode output\n", "return_list = converter.reverse_conversion(p_qa[\"answer\"], dm, date)\n", "\n", - "for task in return_list:\n", - " print(f\"Task type: {task['task_type']}\")\n", - " print(f\"Result: {task['result']}\")\n", - " print()" + "return_list[2][\"result\"]" ] }, { @@ -406,13 +403,21 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": ".venv_dev", "language": "python", "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.10.0" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" } }, "nbformat": 4, diff --git a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb index 6f6a1cb..409abe2 100644 --- a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb +++ b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb @@ -5,7 +5,7 @@ "id": "0", "metadata": {}, "source": [ - "# Forecasting-Only Example: Training Data Generation with Custom Dataset" + "# Forecasting-Only Example: Training Data Generation with the Unified DataSplitter API" ] }, { @@ -27,6 +27,7 @@ "\n", "from twinweaver import (\n", " DataSplitterForecasting,\n", + " DataSplitter,\n", " DataManager,\n", " ConverterInstruction,\n", " Config,\n", @@ -66,7 +67,7 @@ "id": "6", "metadata": {}, "source": [ - "Set up the data manager and the forecasting-only pipeline." + "Set up the data manager and the forecasting-only pipeline using the unified `DataSplitter` API. By passing only `data_splitter_forecasting`, the unified interface handles the forecasting-only case automatically." ] }, { @@ -117,6 +118,9 @@ "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", "\n", + "# Use the unified DataSplitter API with only the forecasting splitter\n", + "data_splitter = DataSplitter(data_splitter_forecasting=data_splitter_forecasting)\n", + "\n", "converter = ConverterInstruction(\n", " nr_tokens_budget_total=8192,\n", " config=config,\n", @@ -199,9 +203,9 @@ "id": "16", "metadata": {}, "source": [ - "We start by generating random \"splits\" in the patient trajectory. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting).\n", + "We start by generating random \"splits\" in the patient trajectory using the unified `DataSplitter.get_splits_from_patient_with_target` method. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting).\n", "\n", - "Here we generate these random splits. We can also manually override them (see other examples on inference)." + "Since we only provided a forecasting splitter, `events_splits` will be `None`." ] }, { @@ -211,13 +215,13 @@ "metadata": {}, "outputs": [], "source": [ - "processed_splits_fc, split_dates = data_splitter_forecasting.get_splits_from_patient(\n", + "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n", " patient_data,\n", - " nr_samples_per_split=4,\n", - " filter_outliers=False,\n", - " include_metadata=True,\n", + " forecasting_nr_samples_per_split=4,\n", + " forecasting_filter_outliers=False,\n", " max_num_splits_per_split_event=2,\n", - ")" + ")\n", + "# Note, events_splits will be None here since we don't have any split events for this patient" ] }, { @@ -225,7 +229,7 @@ "id": "18", "metadata": {}, "source": [ - "Now for each split, we can generate the formatted strings. Note that `event_splits` is left empty since this example only uses the forecasting splitter." + "Now for each split, we can generate the formatted strings. Note that `events_splits` is `None` since we only provided a forecasting splitter, so we pass an empty list for `event_splits`." ] }, { @@ -237,8 +241,8 @@ "source": [ "split_idx = 0\n", "p_converted = converter.forward_conversion(\n", - " forecasting_splits=processed_splits_fc[split_idx],\n", - " event_splits=[], # Not needed for forecasting-only splitter\n", + " forecasting_splits=forecasting_splits[split_idx],\n", + " event_splits=None, # Not needed for forecasting-only splitter\n", ")" ] }, diff --git a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb index 5b42911..54fc0ad 100644 --- a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb +++ b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb @@ -5,7 +5,7 @@ "id": "0", "metadata": {}, "source": [ - "# Example for single patient to convert using the instruction setup with custom dataset" + "# Example for single patient to convert using the unified DataSplitter API with custom dataset" ] }, { @@ -27,8 +27,9 @@ "\n", "from twinweaver import (\n", " DataSplitterForecasting,\n", - " DataManager,\n", " DataSplitterEvents,\n", + " DataSplitter,\n", + " DataManager,\n", " ConverterInstruction,\n", " Config,\n", ")" @@ -67,7 +68,7 @@ "id": "6", "metadata": {}, "source": [ - "Set up the data managers which hold the patient data." + "Set up the data managers and the unified `DataSplitter` which combines both event and forecasting splitters." ] }, { @@ -130,6 +131,12 @@ "# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n", "data_splitter_forecasting.setup_statistics()\n", "\n", + "# Use the unified DataSplitter API that combines both splitters\n", + "data_splitter = DataSplitter(\n", + " data_splitter_events=data_splitter_events,\n", + " data_splitter_forecasting=data_splitter_forecasting,\n", + ")\n", + "\n", "converter = ConverterInstruction(\n", " nr_tokens_budget_total=8192,\n", " config=config,\n", @@ -212,9 +219,9 @@ "id": "16", "metadata": {}, "source": [ - "We start by generating random \"splits\" in the patient trajectory. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting, death/progression/metastases/next treatment for event).\n", + "We start by generating random \"splits\" in the patient trajectory using the unified `DataSplitter.get_splits_from_patient_with_target` method. This ensures that both forecasting and event splits use the same anchor points in time. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting, death/progression/metastases/next treatment for event).\n", "\n", - "Here we generate these random splits. We can also manually override them (see other examples on inference)." + "We can also manually override them (see other examples on inference)." ] }, { @@ -224,18 +231,12 @@ "metadata": {}, "outputs": [], "source": [ - "processed_splits_fc, split_dates = data_splitter_forecasting.get_splits_from_patient(\n", + "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n", " patient_data,\n", - " nr_samples_per_split=4,\n", - " filter_outliers=False,\n", - " include_metadata=True,\n", + " forecasting_nr_samples_per_split=4,\n", + " forecasting_filter_outliers=False,\n", " max_num_splits_per_split_event=2,\n", - ")\n", - "\n", - "processed_splits_ev = data_splitter_events.get_splits_from_patient(\n", - " patient_data,\n", - " reference_split_dates=split_dates,\n", - " max_nr_samples_per_split=3,\n", + " events_max_nr_samples_per_split=3,\n", ")" ] }, @@ -256,8 +257,8 @@ "source": [ "split_idx = 0\n", "p_converted = converter.forward_conversion(\n", - " forecasting_splits=processed_splits_fc[split_idx],\n", - " event_splits=processed_splits_ev[split_idx],\n", + " forecasting_splits=forecasting_splits[split_idx],\n", + " event_splits=events_splits[split_idx],\n", ")" ] }, @@ -301,15 +302,15 @@ "metadata": {}, "outputs": [], "source": [ - "date = split_dates[\"date\"][0]\n", + "date = reference_dates[\"date\"][0]\n", "return_list = converter.reverse_conversion(p_converted[\"answer\"], dm, date)\n", - "return_list[2][\"result\"]" + "return_list[0][\"result\"]" ] } ], "metadata": { "kernelspec": { - "display_name": ".venv_dev", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb b/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb index c2b2826..f57bd1c 100644 --- a/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb +++ b/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb @@ -98,7 +98,7 @@ "# ---------------------------------------------------------------------------\n", "# Model & server settings\n", "# ---------------------------------------------------------------------------\n", - "MODEL_PATH = \"microsoft/Phi-4-mini-instruct\" # ⚠️ Replace with your fine-tuned model path\n", + "MODEL_PATH = \"microsoft/Phi-4-mini-instruct\" # ⚠️ Replace with your fine-tuned model path!!!!!!!!!!!\n", "TOKENIZER_PATH = MODEL_PATH # Usually the same as the model path\n", "\n", "VLLM_PORT = 8000\n", @@ -478,6 +478,22 @@ "id": "18", "metadata": {}, "outputs": [], + "source": [ + "converter = ConverterInstruction(\n", + " nr_tokens_budget_total=MAX_CONTEXT_LENGTH,\n", + " config=config,\n", + " dm=dm,\n", + " variable_stats=data_splitter_forecasting.variable_stats,\n", + ")\n", + "print(\"✅ Converter reloaded with latest code\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], "source": [ "# Parse the generated texts into structured DataFrames\n", "df_results = parse_forecasting_results(\n", @@ -488,13 +504,17 @@ " aggregate_samples=(N_SAMPLES > 1), # Only aggregate if we have multiple samples\n", ")\n", "\n", - "print(f\"Parsed {len(df_results)} prediction rows for {df_results['patientid'].nunique()} patients\")\n", - "df_results.head(20)" + "if df_results.empty:\n", + " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n", + " print(\" Fine-tune a model first (see 03_end_to_end_llm_finetuning.ipynb) for meaningful results.\")\n", + "else:\n", + " print(f\"Parsed {len(df_results)} prediction rows for {df_results['patientid'].nunique()} patients\")\n", + " df_results.head(20)" ] }, { "cell_type": "markdown", - "id": "19", + "id": "20", "metadata": {}, "source": [ "### Understanding the output\n", @@ -514,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "20", + "id": "21", "metadata": {}, "source": [ "## 7. Multi-sample aggregation (optional)\n", @@ -529,7 +549,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -557,14 +577,17 @@ " aggregate_samples=True, # Average numeric values across samples\n", ")\n", "\n", - "print(f\"Aggregated results: {len(df_aggregated)} rows for {df_aggregated['patientid'].nunique()} patients\")\n", - "df_aggregated.head(20)" + "if df_aggregated.empty:\n", + " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n", + "else:\n", + " print(f\"Aggregated results: {len(df_aggregated)} rows for {df_aggregated['patientid'].nunique()} patients\")\n", + " df_aggregated.head(20)" ] }, { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -577,14 +600,17 @@ " aggregate_samples=False, # Keep individual samples\n", ")\n", "\n", - "print(f\"Individual results: {len(df_individual)} rows\")\n", - "print(f\"Samples per patient: {df_individual.groupby('patientid')['sample_idx'].nunique().describe()}\")\n", - "df_individual.head(20)" + "if df_individual.empty:\n", + " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n", + "else:\n", + " print(f\"Individual results: {len(df_individual)} rows\")\n", + " print(f\"Samples per patient: {df_individual.groupby('patientid')['sample_idx'].nunique().describe()}\")\n", + " df_individual.head(20)" ] }, { "cell_type": "markdown", - "id": "23", + "id": "24", "metadata": {}, "source": [ "## 8. Clean up\n", @@ -595,7 +621,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -608,7 +634,7 @@ }, { "cell_type": "markdown", - "id": "25", + "id": "26", "metadata": {}, "source": [ "## Summary\n", @@ -649,8 +675,22 @@ } ], "metadata": { + "kernelspec": { + "display_name": ".venv_dev_gpu", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" } }, "nbformat": 4, diff --git a/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb b/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb index 3e60fa5..fde4635 100644 --- a/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb +++ b/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb @@ -89,7 +89,6 @@ "dm.process_indication_data()\n", "dm.setup_unique_mapping_of_events()\n", "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n", - "dm.infer_var_types()\n", "\n", "converter = ConverterPretrain(config=config, dm=dm)" ] @@ -505,7 +504,7 @@ "outputs": [], "source": [ "# Show the generated answer\n", - "generated_answer" + "print(generated_answer)" ] }, { @@ -541,7 +540,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv_test", + "display_name": ".venv_dev_gpu", "language": "python", "name": "python3" }, diff --git a/docs/examples/advanced/pretraining/prepare_pretraining_data.md b/docs/examples/advanced/pretraining/prepare_pretraining_data.md new file mode 100644 index 0000000..cf03bad --- /dev/null +++ b/docs/examples/advanced/pretraining/prepare_pretraining_data.md @@ -0,0 +1,17 @@ +# Prepare Pretraining Data + +This script demonstrates how to prepare pretraining data using the `ConverterPretrain` class. Unlike instruction-tuning, pretraining converts full patient timelines into continuous text without explicit question/answer formatting. + +Key concepts: + +- Setting up `Config` and `DataManager` for pretraining +- Using `ConverterPretrain.forward_conversion()` to convert patient data to text +- Verifying data integrity with `ConverterPretrain.reverse_conversion()` +- Checking round-trip consistency with `get_difference_in_event_dataframes()` + +!!! note "Run from project root" + This script should be run from the root folder of the TwinWeaver repository. + +```python title="prepare_pretraining_data.py" +--8<-- "examples/advanced/pretraining/prepare_pretraining_data.py" +``` diff --git a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb index 9b6a861..c442b0b 100644 --- a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb +++ b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb @@ -590,7 +590,15 @@ "print(f\"\\n✅ Total rows: {len(df_all_horizons)}\")\n", "df_all_horizons = df_all_horizons.sort_values([\"patientid\", \"week_horizon\"])\n", "df_all_horizons[\n", - " [\"patientid\", \"week_horizon\", \"probability_occurrence\", \"probability_no_occurrence\", \"probability_censored\"]\n", + " [\n", + " \"patientid\",\n", + " \"week_horizon\",\n", + " \"probability_occurrence\",\n", + " \"probability_no_occurrence\",\n", + " \"probability_censored\",\n", + " \"probability_occurrence_renormalized\",\n", + " \"probability_no_occurrence_renormalized\",\n", + " ]\n", "]" ] }, diff --git a/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb b/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb index 07d0908..9405caf 100644 --- a/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb +++ b/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb @@ -776,9 +776,7 @@ "\n", "# Get the date of first line of therapy\n", "df_events_patient = patient_data[\"events\"].copy()\n", - "date_of_first_lot = df_events_patient.loc[\n", - " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n", - "].min()\n", + "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n", "\n", "print(f\"Test patient: {test_patientid}\")\n", "print(f\"First LoT date: {date_of_first_lot}\")" diff --git a/docs/examples/integrations/meds_data_import.ipynb b/docs/examples/integrations/meds_data_import.ipynb index c593888..5183219 100644 --- a/docs/examples/integrations/meds_data_import.ipynb +++ b/docs/examples/integrations/meds_data_import.ipynb @@ -557,6 +557,7 @@ " custom_tasks=None,\n", ")\n", "\n", + "\n", "print(converted[\"instruction\"])" ] } diff --git a/mkdocs.yml b/mkdocs.yml index 131f47d..e6f6c2e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -131,18 +131,29 @@ nav: - Data Preparation: examples/01_data_preparation_for_training.ipynb - Inference Prompt Prep: examples/02_inference_prompt_preparation.ipynb - End-to-End Finetuning: examples/03_end_to_end_llm_finetuning.ipynb + - Data Preprocessing: + - Raw Data Preprocessing: examples/data_preprocessing/raw_data_preprocessing.ipynb - Advanced: - - Custom Splitting (Training): examples/advanced/custom_splitting/training_individual_splitters.ipynb - - Custom Split Events: examples/advanced/custom_splitting/training_custom_split_events.ipynb - - Forecasting-Only Splitter: examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb - - Custom Splitting (Inference): examples/advanced/custom_splitting/inference_individual_splitters.md - - Custom Text Generation: examples/advanced/custom_output/customizing_text_generation.ipynb - - Custom Summarized Row: examples/advanced/custom_output/custom_summarized_row.ipynb - - Pretraining: examples/advanced/pretraining/prepare_pretraining_data.md + - Custom Splitting: + - Training (Individual Splitters): examples/advanced/custom_splitting/training_individual_splitters.ipynb + - Training (Custom Split Events): examples/advanced/custom_splitting/training_custom_split_events.ipynb + - Training (Forecasting QA): examples/advanced/custom_splitting/training_forecasting_qa.ipynb + - Training (Forecasting-Only Splitter): examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb + - Inference (Individual Splitters): examples/advanced/custom_splitting/inference_individual_splitters.md + - Custom Output: + - Custom Text Generation: examples/advanced/custom_output/customizing_text_generation.ipynb + - Custom Summarized Row: examples/advanced/custom_output/custom_summarized_row.ipynb + - Pretraining: + - Prepare Pretraining Data: examples/advanced/pretraining/prepare_pretraining_data.md + - End-to-End Training with Pretrain: examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb - TTE Probability Inference: examples/advanced/tte_inference/tte_probability_inference.ipynb - Forecasting Inference: examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb - Integrations: - MEDS Import: examples/integrations/meds_data_import.ipynb + - Hackathon: + - Overview: examples/hackathon/README.md + - Data Preparation Challenge: examples/hackathon/01_data_preparation_challenge.ipynb + - LLM Finetuning Challenge: examples/hackathon/02_llm_finetuning_challenge.ipynb - API Reference: - Common: - Config: reference/common/config.md