diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a57108f..0c43503 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,3 +1,11 @@
+# Exclude hackathon examples
+exclude: |
+ (?x)^(
+ docs/examples/hackathon/|
+ examples/hackathon/|
+ \^examples/hackathon
+ )
+
repos:
# 1. Standard "Cleanup" Hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
@@ -11,7 +19,7 @@ repos:
# 2. Ruff (Linting + Formatting)
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.9.3
+ rev: v0.9.3
hooks:
- id: ruff
args: [ --fix ]
@@ -21,4 +29,4 @@ repos:
- repo: https://github.com/kynan/nbstripout
rev: 0.8.1
hooks:
- - id: nbstripout
\ No newline at end of file
+ - id: nbstripout
diff --git a/README.md b/README.md
index 023de9a..fdea890 100644
--- a/README.md
+++ b/README.md
@@ -75,6 +75,7 @@ For users needing custom behavior or specific integrations:
* [`examples/advanced/custom_splitting/training_individual_splitters.ipynb`](examples/advanced/custom_splitting/training_individual_splitters.ipynb): Notebook demonstrating training data generation with individual splitters.
* [`examples/advanced/custom_splitting/training_custom_split_events.ipynb`](examples/advanced/custom_splitting/training_custom_split_events.ipynb): Notebook showing how to customize split events and forecast different event categories.
* [`examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb`](examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb): Forecasting-only example showing training data generation using only the `DataSplitterForecasting` (no event splitter).
+ * [`examples/advanced/custom_splitting/training_forecasting_qa.ipynb`](examples/advanced/custom_splitting/training_forecasting_qa.ipynb): Demonstrates the **Forecasting QA** mode, which bins continuous target values into discrete categories for classification-style prediction, and compares all three forecasting modes (`"forecasting"`, `"forecasting_qa"`, `"both"`).
* **Custom Text Generation**: [`examples/advanced/custom_output/customizing_text_generation.ipynb`](examples/advanced/custom_output/customizing_text_generation.ipynb)
* A comprehensive tutorial on customizing every textual component of the instruction generation pipeline. Learn how to modify preambles, event formatting, time units, genetic data tags, forecasting prompts, and more to adapt outputs for different LLMs, languages, or institutional requirements.
* **Custom Summarized Row**: [`examples/advanced/custom_output/custom_summarized_row.ipynb`](examples/advanced/custom_output/custom_summarized_row.ipynb)
@@ -133,7 +134,7 @@ config.event_category_forecast = ["lab"]
# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')
# Only needs to be set if you want to do time to event prediction
-config.data_splitter_events_variables_category_mapping = {
+config.event_category_events_prediction_with_naming = {
"death": "death",
"progression": "next progression", # Custom name in prompt: "next progression" instead of "progression"
}
@@ -143,7 +144,7 @@ dm = DataManager(config=config)
dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
-dm.setup_dataset_splits()
+dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
dm.infer_var_types()
# This data splitter handles event prediction tasks
@@ -175,7 +176,6 @@ split_idx = 0
training_data = converter.forward_conversion(
forecasting_splits=forecasting_splits[split_idx],
event_splits=events_splits[split_idx],
- override_mode_to_select_forecasting="both",
)
# training_data now contains (Input, Target) pairs ready for LLM fine-tuning
diff --git a/docs/api-index.md b/docs/api-index.md
index d071451..d184bde 100644
--- a/docs/api-index.md
+++ b/docs/api-index.md
@@ -56,7 +56,7 @@ Handles data loading and management.
| [`DataManager.load_indication_data`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.load_indication_data) | Method | Load data tables for a specific indication |
| [`DataManager.process_indication_data`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.process_indication_data) | Method | Process loaded indication data |
| [`DataManager.setup_unique_mapping_of_events`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.setup_unique_mapping_of_events) | Method | Create unique mapping for all events |
-| [`DataManager.setup_dataset_splits`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.setup_dataset_splits) | Method | Split data into train/val/test sets |
+| [`DataManager.setup_hold_out_sets`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.setup_hold_out_sets) | Method | Split data into train/val/test sets |
| [`DataManager.get_all_patientids_in_split`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.get_all_patientids_in_split) | Method | Get all patient IDs in a specific split |
| [`DataManager.get_patient_split`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.get_patient_split) | Method | Get the split assignment for a patient |
| [`DataManager.get_patient_data`](reference/common/data_manager.md#twinweaver.common.data_manager.DataManager.get_patient_data) | Method | Retrieve all data for a specific patient |
diff --git a/docs/data-splitting.md b/docs/data-splitting.md
index 459260c..526a07c 100644
--- a/docs/data-splitting.md
+++ b/docs/data-splitting.md
@@ -9,7 +9,7 @@ TwinWeaver provides specialized splitters for two complementary clinical predict
| `DataSplitterForecasting` | Forecasting continuous or categorical variables | Predict hemoglobin values over the next 90 days |
| `DataSplitterEvents` | Landmark event prediction (time-to-event) | Did the patient progress within 52 weeks? |
-A unified `DataSplitter` interface combines both, ensuring they share the same split dates for multi-task training.
+A unified `DataSplitter` interface combines one or both splitters into a single entry point. When both are supplied, it ensures they share the same split dates for multi-task training. Either splitter can also be used individually.
---
@@ -33,7 +33,7 @@ Patient timeline
Split dates are anchored to **split events** — a configurable event category (typically Line of Therapy, `"lot"`). The framework:
1. **Finds all split-event start dates** in the patient's history (e.g., every LoT start).
-2. **Identifies candidate dates** within a window around each split event (controlled by `max_split_length_after_split_event`, default 90 days).
+2. **Identifies candidate dates** within a window around each split event (controlled by `max_split_length_after_split_event`, default 0 days).
3. **Randomly samples** one or more candidate dates per split event (`max_num_splits_per_split_event`).
This anchoring ensures that training examples are centered on clinically meaningful time points rather than arbitrary dates.
@@ -67,7 +67,7 @@ For each candidate split date, the forecasting splitter:
1. **Checks variable eligibility**: A variable is valid at a given date only if it has at least `min_nr_variable_seen_previously` occurrences in the lookback window and `min_nr_variable_seen_after` occurrences in the forecast window.
2. **Samples variables**: Between `min_nr_variables_to_sample` and `max_nr_variables_to_sample` variables are selected per task, using weighted proportional sampling based on pre-computed statistics (optionally uniform sampling).
-3. **Creates the split**: Events before the split date form the input; future values of the sampled variables (within `max_forecast_time_for_value`) form the target.
+3. **Creates the split**: Events before the split date form the input; future values of the sampled variables (within `max_forecasted_trajectory_length`) form the target.
4. **Filters future LoT overlap**: Target events occurring after the next Line of Therapy start are excluded to avoid data leakage.
### Variable Statistics & Sampling
@@ -96,13 +96,13 @@ When `filter_outliers=True`, the **3-sigma strategy** clips target values to the
data_splitter_forecasting = DataSplitterForecasting(
data_manager=dm,
config=config,
- max_split_length_after_split_event=pd.Timedelta(days=90), # Window after split event
+ max_forecasted_trajectory_length=pd.Timedelta(days=90), # Forecast horizon (required)
+ max_split_length_after_split_event=pd.Timedelta(days=90), # Window after split event
max_lookback_time_for_value=pd.Timedelta(days=90), # Lookback for variable history
- max_forecast_time_for_value=pd.Timedelta(days=90), # Forecast horizon
min_nr_variable_seen_previously=1, # Min past occurrences
min_nr_variable_seen_after=1, # Min future occurrences
min_nr_variables_to_sample=1, # Min variables per task
- max_nr_variables_to_sample=3, # Max variables per task
+ max_nr_variables_to_sample=1, # Max variables per task
filtering_strategy="3-sigma", # Outlier handling
sampling_strategy="proportional", # Weighted or uniform sampling
)
@@ -124,7 +124,7 @@ flowchart TD
D --> E{Event occurred
within window
and before censoring event?}
E -->|Yes| F[occurred = True]
E -->|No| G{Censored by
next LoT or data end?}
- G -->|Next LoT| H[censored = new_therapy_start]
+ G -->|Next LoT| H[censored = new_split_date_start]
G -->|End of data| I[censored = end_of_data]
G -->|No censoring| J[censored = None
Event truly did not occur]
F --> K[Create DataSplitterEventsOption]
@@ -136,7 +136,7 @@ flowchart TD
For each candidate split date, the event splitter:
1. **Samples an event category** from the configured mapping (e.g., `"death"` or `"progression"`), avoiding duplicate categories per split.
-2. **Samples a prediction window** of random duration between `min_length_to_sample` (default: 1 week) and `max_length_to_sample` (default: 104 weeks). This trains the model to handle variable-length horizons.
+2. **Samples a prediction window** of random duration between `min_length_to_sample` and `max_length_to_sample` (both required, no defaults). This trains the model to handle variable-length horizons.
3. **Determines the outcome**:
- **Occurred**: The event was observed within the window before any censoring events.
- **Censored**: The observation was cut short by a new therapy start, end of data, or a data cutoff date.
@@ -149,8 +149,8 @@ For each candidate split date, the event splitter:
data_splitter_events = DataSplitterEvents(
data_manager=dm,
config=config,
- max_length_to_sample=pd.Timedelta(weeks=104), # Max prediction window
- min_length_to_sample=pd.Timedelta(weeks=1), # Min prediction window
+ max_length_to_sample=pd.Timedelta(weeks=104), # Max prediction window (required)
+ min_length_to_sample=pd.Timedelta(weeks=1), # Min prediction window (required)
unit_length_to_sample="weeks", # Window sampling unit
max_split_length_after_split_event=pd.Timedelta(days=90), # Window after split event
)
@@ -161,7 +161,7 @@ data_splitter_events = DataSplitterEvents(
The event-to-prediction mapping is configured via:
```python
-config.data_splitter_events_variables_category_mapping = {
+config.event_category_events_prediction_with_naming = {
"death": "death", # event_category → descriptive name in prompt
"progression": "next progression", # custom prompt label
}
@@ -171,14 +171,20 @@ config.data_splitter_events_variables_category_mapping = {
## Combined Splitting with `DataSplitter`
-The `DataSplitter` class provides a unified interface that coordinates both splitters. This is the **recommended approach** for generating multi-task training data, as it ensures forecasting and event prediction tasks share the same split dates.
+The `DataSplitter` class provides a unified interface that coordinates one or both splitters. At least one of `data_splitter_events` or `data_splitter_forecasting` must be provided. When both are supplied, it ensures they share the same split dates for multi-task training. When only one is supplied, the methods return `None` for the missing task type.
-### Training Workflow
+!!! tip "Single-task usage"
+ You don't need both splitters. Pass only `data_splitter_forecasting` for forecasting-only pipelines, or only `data_splitter_events` for event-prediction-only pipelines. See [Forecasting-Only](#forecasting-only) and [Events-Only](#events-only) below.
+
+### Training Workflow (Both Tasks)
```python
from twinweaver import DataSplitter
-data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)
+data_splitter = DataSplitter(
+ data_splitter_events=data_splitter_events,
+ data_splitter_forecasting=data_splitter_forecasting,
+)
# Generate aligned splits for both tasks
forecasting_splits, events_splits, reference_dates = \
@@ -187,14 +193,47 @@ forecasting_splits, events_splits, reference_dates = \
Internally, `get_splits_from_patient_with_target`:
-1. Calls `DataSplitterForecasting.get_splits_from_patient()` to determine split dates and generate forecasting tasks.
-2. Passes those same split dates (`reference_dates`) to `DataSplitterEvents.get_splits_from_patient()` to generate aligned event prediction tasks.
+1. Calls `DataSplitterForecasting.get_splits_from_patient()` (if available) to determine split dates and generate forecasting tasks.
+2. Passes those same split dates (`reference_dates`) to `DataSplitterEvents.get_splits_from_patient()` (if available) to generate aligned event prediction tasks.
+3. If only one splitter is provided, the output for the missing task type is `None`. When only the events splitter is used, `reference_dates` are extracted from the generated event splits.
+
+This alignment is critical: when both task types are active, they see the same patient history up to the same point in time, enabling consistent multi-task learning.
+
+### Forecasting-Only
+
+```python
+# Only forecasting — no event prediction splitter needed
+data_splitter = DataSplitter(data_splitter_forecasting=data_splitter_forecasting)
+
+forecasting_splits, events_splits, reference_dates = \
+ data_splitter.get_splits_from_patient_with_target(patient_data)
+# events_splits is None
+
+converter.forward_conversion(
+ forecasting_splits=forecasting_splits[0],
+ event_splits=None, # No event splits available
+)
+```
+
+### Events-Only
+
+```python
+# Only event prediction — no forecasting splitter needed
+data_splitter = DataSplitter(data_splitter_events=data_splitter_events)
+
+forecasting_splits, events_splits, reference_dates = \
+ data_splitter.get_splits_from_patient_with_target(patient_data)
+# forecasting_splits is None
-This alignment is critical: both task types see the same patient history up to the same point in time, enabling consistent multi-task learning.
+converter.forward_conversion(
+ forecasting_splits=None, # No forecasting splits available
+ event_splits=events_splits[0],
+)
+```
### Inference Workflow
-For inference, use `get_splits_from_patient_inference`, which anchors the split at the **last available date** in the patient's record:
+For inference, use `get_splits_from_patient_inference`, which anchors the split at the **last available date** in the patient's record. The `inference_type` parameter controls which tasks to generate — it defaults to `"both"` but gracefully handles the case when only one splitter is available:
```python
forecast_split, events_split = data_splitter.get_splits_from_patient_inference(
@@ -206,6 +245,9 @@ forecast_split, events_split = data_splitter.get_splits_from_patient_inference(
)
```
+!!! note
+ When `inference_type="both"` and only one splitter is provided, the missing task simply returns `None` without raising an error. If you request a specific `inference_type` (e.g., `"forecasting"`) but the corresponding splitter was not provided, a `ValueError` is raised.
+
---
## How Multiple Training Examples Are Generated
@@ -217,7 +259,7 @@ A single patient can yield many training examples through several sources of var
| Multiple split events (e.g., LoTs) | Patient history | One split per LoT by default |
| Multiple dates per split event | `max_num_splits_per_split_event` | Random dates within the LoT window |
| Different variable subsets | `min/max_nr_variables_to_sample` | Different forecasting questions per date |
-| Different event categories | `data_splitter_events_variables_category_mapping` | Death vs. progression predictions |
+| Different event categories | `event_category_events_prediction_with_naming` | Death vs. progression predictions |
| Different prediction windows | `min/max_length_to_sample` | 1-week to 104-week horizons |
This diversity encourages the model to generalize across time points, variables, and prediction tasks.
@@ -238,7 +280,7 @@ from twinweaver import (
config = Config()
config.split_event_category = "lot"
config.event_category_forecast = ["lab"]
-config.data_splitter_events_variables_category_mapping = {
+config.event_category_events_prediction_with_naming = {
"death": "death",
"progression": "next progression",
}
@@ -249,7 +291,7 @@ dm.load_indication_data(df_events=df_events, df_constant=df_constant,
df_constant_description=df_constant_description)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
-dm.setup_dataset_splits()
+dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
dm.infer_var_types()
# 3. Initialize splitters
@@ -259,7 +301,10 @@ data_splitter_events.setup_variables()
data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)
data_splitter_forecasting.setup_statistics() # Compute variable scores
-data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)
+data_splitter = DataSplitter(
+ data_splitter_events=data_splitter_events,
+ data_splitter_forecasting=data_splitter_forecasting,
+)
# 4. Generate splits for a patient
patient_data = dm.get_patient_data(dm.all_patientids[0])
@@ -275,7 +320,6 @@ converter = ConverterInstruction(
result = converter.forward_conversion(
forecasting_splits=forecasting_splits[0],
event_splits=events_splits[0],
- override_mode_to_select_forecasting="both",
)
print(result["instruction"][:500])
@@ -290,4 +334,6 @@ print(result["answer"])
- **[Framework Overview](framework.md)**: Learn about TwinWeaver's architecture and task types
- **[Data Preparation Tutorial](examples/01_data_preparation_for_training.ipynb)**: Step-by-step notebook walkthrough
- **[Custom Splitting (Training)](examples/advanced/custom_splitting/training_individual_splitters.ipynb)**: Advanced splitting with individual splitters
+- **[Forecasting-Only Splitting](examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb)**: Using `DataSplitter` with only the forecasting splitter
+- **[Custom Split Events](examples/advanced/custom_splitting/training_custom_split_events.ipynb)**: Using `DataSplitter` with custom split events
- **[API Reference — Data Splitters](reference/instruction/data_splitters.md)**: Full API documentation
diff --git a/docs/dataset-format.md b/docs/dataset-format.md
index ace6107..564cced 100644
--- a/docs/dataset-format.md
+++ b/docs/dataset-format.md
@@ -104,6 +104,10 @@ On the first visit, the patient experienced the following:
Hemoglobin is 11.8.
```
+#### Relative Dating
+
+TwinWeaver uses **relative dating** instead of absolute dates. All calendar dates from the input data are converted into time deltas relative to the previous visit (e.g., *"2 weeks later"*) rather than being included as raw dates (e.g., *"2024-01-29"*). This serves two important purposes: first, it **anonymizes** the patient data by removing identifiable calendar dates from the training text; second, it provides the model with **clinically meaningful temporal context** — the time elapsed between visits — rather than arbitrary date strings. By default, time intervals are expressed in weeks, but this can be changed to days using `Config.set_delta_time_unit("days")`. Accumulative deltas (time since the very first visit rather than since the previous visit) are also supported.
+
### Final Output Structure
For training, TwinWeaver produces input-target pairs:
@@ -284,7 +288,7 @@ config.event_category_forecast = ["lab"]
# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')
# Only needs to be set if you want to do time to event prediction
-config.data_splitter_events_variables_category_mapping = {
+config.event_category_events_prediction_with_naming = {
"death": "death",
"progression": "next progression", # Custom name in prompt
}
@@ -301,6 +305,6 @@ dm.load_indication_data(
!!! tip "Configuration Parameters"
- **`split_event_category`**: The event category used to anchor split points for generating training samples (required for instruction tuning)
- **`event_category_forecast`**: Which event categories to forecast as time-series values
- - **`data_splitter_events_variables_category_mapping`**: Maps event names to prediction tasks (e.g., survival, progression)
+ - **`event_category_events_prediction_with_naming`**: Maps event names to prediction tasks (e.g., survival, progression)
See the [Raw Data Preprocessing Tutorial](examples/data_preprocessing/raw_data_preprocessing.ipynb) for transforming raw clinical data into TwinWeaver format, or the [Data Preparation Tutorial](examples/01_data_preparation_for_training.ipynb) for a complete walkthrough of instruction-tuning data generation.
diff --git a/docs/examples/01_data_preparation_for_training.ipynb b/docs/examples/01_data_preparation_for_training.ipynb
index 60fe0bd..d42c8f7 100644
--- a/docs/examples/01_data_preparation_for_training.ipynb
+++ b/docs/examples/01_data_preparation_for_training.ipynb
@@ -95,7 +95,7 @@
"\n",
"- **`config.split_event_category`**: Determines the event category used to split patient histories into input and output segments. In this example, we split data around \"Line of Therapy\" (`lot`) start dates.\n",
"- **`config.event_category_forecast`**: Identifies which event categories should be forecasted as time-series values. Here, we target `lab` values.\n",
- "- **`config.data_splitter_events_variables_category_mapping`**: Maps specific events to survival analysis or classification tasks. We configure the model to predict outcomes for `death` and `progression`."
+ "- **`config.event_category_events_prediction_with_naming`**: Maps specific events to survival analysis or classification tasks. We configure the model to predict outcomes for `death` and `progression`."
]
},
{
@@ -118,7 +118,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}"
@@ -175,7 +175,7 @@
"# Setup unique mapping of events, to understand which events correspond to which categories\n",
"dm.setup_unique_mapping_of_events()\n",
"# (Optional) assign each patient to train/validation/test splits\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"# (Optional - needed for forecasting) infer variable types\n",
"dm.infer_var_types()"
]
@@ -202,13 +202,19 @@
"outputs": [],
"source": [
"# This data splitter handles event prediction tasks\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"# This data splitter handles forecasting tasks\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n",
"data_splitter_forecasting.setup_statistics()\n",
@@ -339,7 +345,6 @@
"p_converted = converter.forward_conversion(\n",
" forecasting_splits=forecasting_splits[split_idx],\n",
" event_splits=events_splits[split_idx],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")"
]
},
@@ -418,7 +423,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_dev",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/docs/examples/02_inference_prompt_preparation.ipynb b/docs/examples/02_inference_prompt_preparation.ipynb
index 299e9cd..9075f16 100644
--- a/docs/examples/02_inference_prompt_preparation.ipynb
+++ b/docs/examples/02_inference_prompt_preparation.ipynb
@@ -67,7 +67,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}\n",
@@ -88,15 +88,21 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n",
"data_splitter_forecasting.setup_statistics()\n",
diff --git a/docs/examples/03_end_to_end_llm_finetuning.ipynb b/docs/examples/03_end_to_end_llm_finetuning.ipynb
index 6a2901f..3cfda1d 100644
--- a/docs/examples/03_end_to_end_llm_finetuning.ipynb
+++ b/docs/examples/03_end_to_end_llm_finetuning.ipynb
@@ -109,7 +109,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}\n",
@@ -144,15 +144,21 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"\n",
"# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n",
@@ -213,7 +219,6 @@
" p_converted = converter.forward_conversion(\n",
" forecasting_splits=forecasting_splits[split_idx],\n",
" event_splits=events_splits[split_idx],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
" )\n",
" new_data = {\n",
" \"prompt\": p_converted[\"instruction\"],\n",
@@ -490,9 +495,7 @@
"# Lets simulate forecasts for after the first line of therapy\n",
"df_constant_patient = patient_data[\"constant\"].copy()\n",
"df_events_patient = patient_data[\"events\"].copy()\n",
- "date_of_first_lot = df_events_patient.loc[\n",
- " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n",
- "].min()\n",
+ "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n",
"\n",
"# Only keep data until (and including) first line of therapy\n",
"df_events_patient = df_events_patient.loc[df_events_patient[\"date\"] <= date_of_first_lot]"
diff --git a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb
index d17f3db..58bf69f 100644
--- a/docs/examples/advanced/custom_output/custom_summarized_row.ipynb
+++ b/docs/examples/advanced/custom_output/custom_summarized_row.ipynb
@@ -95,7 +95,7 @@
"# Required settings\n",
"config.split_event_category = \"lot\"\n",
"config.event_category_forecast = [\"lab\"]\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\",\n",
"}\n",
@@ -122,7 +122,7 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()"
]
},
@@ -134,10 +134,19 @@
"outputs": [],
"source": [
"# Setup splitters\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
- "data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)\n",
+ "data_splitter_forecasting = DataSplitterForecasting(\n",
+ " data_manager=dm,\n",
+ " config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
"data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)"
@@ -200,7 +209,6 @@
"p_default = converter_default.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
@@ -305,7 +313,6 @@
"p_custom = converter_custom.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
@@ -408,7 +415,6 @@
"p_advanced = converter_advanced.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
@@ -481,7 +487,6 @@
" converter_broken.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
" )\n",
"except TypeError as e:\n",
" print(f\"Caught runtime error: {e}\")"
@@ -515,7 +520,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_dev",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb
index 6f15659..03ea854 100644
--- a/docs/examples/advanced/custom_output/customizing_text_generation.ipynb
+++ b/docs/examples/advanced/custom_output/customizing_text_generation.ipynb
@@ -87,7 +87,7 @@
"# Required settings for instruction mode\n",
"config_default.split_event_category = \"lot\"\n",
"config_default.event_category_forecast = [\"lab\"]\n",
- "config_default.data_splitter_events_variables_category_mapping = {\n",
+ "config_default.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\",\n",
"}\n",
@@ -109,7 +109,7 @@
")\n",
"dm_default.process_indication_data()\n",
"dm_default.setup_unique_mapping_of_events()\n",
- "dm_default.setup_dataset_splits()\n",
+ "dm_default.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm_default.infer_var_types()"
]
},
@@ -121,10 +121,19 @@
"outputs": [],
"source": [
"# Setup splitters and converter\n",
- "data_splitter_events_default = DataSplitterEvents(dm_default, config=config_default)\n",
+ "data_splitter_events_default = DataSplitterEvents(\n",
+ " dm_default,\n",
+ " config=config_default,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events_default.setup_variables()\n",
"\n",
- "data_splitter_forecasting_default = DataSplitterForecasting(data_manager=dm_default, config=config_default)\n",
+ "data_splitter_forecasting_default = DataSplitterForecasting(\n",
+ " data_manager=dm_default,\n",
+ " config=config_default,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
"data_splitter_forecasting_default.setup_statistics()\n",
"\n",
"data_splitter_default = DataSplitter(data_splitter_events_default, data_splitter_forecasting_default)\n",
@@ -155,7 +164,6 @@
"p_converted_default = converter_default.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
@@ -193,7 +201,7 @@
"# Required settings\n",
"config_custom.split_event_category = \"lot\"\n",
"config_custom.event_category_forecast = [\"lab\"]\n",
- "config_custom.data_splitter_events_variables_category_mapping = {\n",
+ "config_custom.event_category_events_prediction_with_naming = {\n",
" \"death\": \"mortality\", # Custom name for death event\n",
" \"progression\": \"disease progression\", # Custom name for progression\n",
"}\n",
@@ -611,7 +619,7 @@
")\n",
"dm_custom.process_indication_data()\n",
"dm_custom.setup_unique_mapping_of_events()\n",
- "dm_custom.setup_dataset_splits()\n",
+ "dm_custom.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm_custom.infer_var_types()"
]
},
@@ -623,10 +631,19 @@
"outputs": [],
"source": [
"# Setup splitters and converter with custom config\n",
- "data_splitter_events_custom = DataSplitterEvents(dm_custom, config=config_custom)\n",
+ "data_splitter_events_custom = DataSplitterEvents(\n",
+ " dm_custom,\n",
+ " config=config_custom,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events_custom.setup_variables()\n",
"\n",
- "data_splitter_forecasting_custom = DataSplitterForecasting(data_manager=dm_custom, config=config_custom)\n",
+ "data_splitter_forecasting_custom = DataSplitterForecasting(\n",
+ " data_manager=dm_custom,\n",
+ " config=config_custom,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
"data_splitter_forecasting_custom.setup_statistics()\n",
"\n",
"data_splitter_custom = DataSplitter(data_splitter_events_custom, data_splitter_forecasting_custom)\n",
@@ -657,7 +674,6 @@
"p_converted_custom = converter_custom.forward_conversion(\n",
" forecasting_splits=forecasting_splits_custom[0],\n",
" event_splits=events_splits_custom[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
diff --git a/docs/examples/advanced/custom_splitting/inference_individual_splitters.md b/docs/examples/advanced/custom_splitting/inference_individual_splitters.md
new file mode 100644
index 0000000..cd5cb1c
--- /dev/null
+++ b/docs/examples/advanced/custom_splitting/inference_individual_splitters.md
@@ -0,0 +1,17 @@
+# Custom Splitting for Inference
+
+This script demonstrates how to use individual splitters (`DataSplitterForecasting` and `DataSplitterEvents`) combined via the unified `DataSplitter` API for **inference** — i.e., when you only have input data and no target labels.
+
+Key concepts:
+
+- Configuring split events and forecasting categories
+- Using `DataSplitter.get_splits_from_patient_inference()` to generate splits at inference time
+- Converting splits to text prompts with `ConverterInstruction.forward_conversion_inference()`
+- Overriding which variables/events to predict
+
+!!! note "Run from project root"
+ This script should be run from the root folder of the TwinWeaver repository.
+
+```python title="inference_individual_splitters.py"
+--8<-- "examples/advanced/custom_splitting/inference_individual_splitters.py"
+```
diff --git a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py
index abf7273..3840087 100644
--- a/docs/examples/advanced/custom_splitting/inference_individual_splitters.py
+++ b/docs/examples/advanced/custom_splitting/inference_individual_splitters.py
@@ -2,6 +2,7 @@
DataSplitterForecasting,
DataManager,
DataSplitterEvents,
+ DataSplitter,
ConverterInstruction,
Config,
)
@@ -16,7 +17,7 @@ def __init__(
self.config = Config()
self.config.split_event_category = "lot"
self.config.event_category_forecast = ["lab"]
- self.config.data_splitter_events_variables_category_mapping = {
+ self.config.event_category_events_prediction_with_naming = {
"death": "death",
"progression": "next progression", # Custom name in prompt: "next progression" instead of "progression"
}
@@ -42,13 +43,29 @@ def __init__(
)
self.dm.process_indication_data()
self.dm.setup_unique_mapping_of_events()
- self.dm.setup_dataset_splits()
+ self.dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
self.dm.infer_var_types()
- self.data_splitter_events = DataSplitterEvents(self.dm, config=self.config)
- self.data_splitter_events.setup_variables()
- self.data_splitter_forecasting = DataSplitterForecasting(data_manager=self.dm, config=self.config)
- self.data_splitter_forecasting.setup_statistics()
+ data_splitter_events = DataSplitterEvents(
+ self.dm,
+ config=self.config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ )
+ data_splitter_events.setup_variables()
+ data_splitter_forecasting = DataSplitterForecasting(
+ data_manager=self.dm,
+ config=self.config,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ )
+ data_splitter_forecasting.setup_statistics()
+
+ # Use the unified DataSplitter API that combines both splitters
+ self.data_splitter = DataSplitter(
+ data_splitter_events=data_splitter_events,
+ data_splitter_forecasting=data_splitter_forecasting,
+ )
+
self.converter = ConverterInstruction(
nr_tokens_budget_total=8192,
config=self.config,
@@ -62,48 +79,26 @@ def convert_full_to_string_for_one_patient(self, patientid, override_events_or_f
# To simulate that we only have input, half the events
patient_data["events"] = patient_data["events"].iloc[: int(len(patient_data["events"]) / 2)]
- # Here then split date
- split_date = patient_data["events"]["date"].iloc[-1]
-
- #: generate event split - NOTE: this if statement is only to exemplify both cases!
- if override_events_or_forecasting == "events":
- ####### Example if we want to override for events
-
- events_splits = self.data_splitter_events.get_splits_from_patient(
- patient_data,
- max_nr_samples=1,
- override_split_dates=[split_date],
- override_category="death",
- override_end_week_delta=52,
- )
- # We just pick the first one
- events_split = events_splits[0][0]
-
- #: no forecasting split
- forecast_split = None
- forecasting_times_to_predict = None
- else:
- ####### Example if we want to override for forecasting
-
- #: generate forecasting split
- forecast_splits = self.data_splitter_forecasting.get_splits_from_patient(
- patient_data,
- nr_samples_per_split=1,
- filter_outliers=False,
- override_split_dates=[split_date],
- override_variables_to_predict=["Neutrophils"],
- )
- # We just pick the first one
- forecast_split = forecast_splits[0][0]
-
- # We set which weeks to predict
+ # Use the unified DataSplitter API for inference
+ forecast_split, events_split = self.data_splitter.get_splits_from_patient_inference(
+ patient_data,
+ inference_type=override_events_or_forecasting,
+ forecasting_override_variables_to_predict=["Neutrophils"]
+ if override_events_or_forecasting != "events"
+ else None,
+ events_override_category="death" if override_events_or_forecasting != "forecasting" else None,
+ events_override_observation_time_delta=pd.Timedelta(weeks=52)
+ if override_events_or_forecasting != "forecasting"
+ else None,
+ )
+
+ # Set which weeks to predict for forecasting (if applicable)
+ forecasting_times_to_predict = None
+ if forecast_split is not None:
forecasting_times_to_predict = {
"Neutrophils": [1, 2, 8, 11],
}
- #: no events split
- events_split = None
-
# Convert to text
converted = self.converter.forward_conversion_inference(
forecasting_split=forecast_split,
diff --git a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb
index acd176a..68976a7 100644
--- a/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb
+++ b/docs/examples/advanced/custom_splitting/training_custom_split_events.ipynb
@@ -101,7 +101,7 @@
"config.event_category_forecast = [\"vitals\"]\n",
"\n",
"# To predict different variables for the event categories, we set up a mapping here\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"lot\": \"time to next lot\", # Custom name in prompt: \"time to next lot\" instead of \"lot\"\n",
"}"
]
@@ -118,7 +118,7 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()"
]
},
@@ -138,13 +138,19 @@
"outputs": [],
"source": [
"# This data splitter handles event prediction tasks\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"# This data splitter handles forecasting tasks\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n",
"data_splitter_forecasting.setup_statistics()\n",
@@ -213,7 +219,8 @@
"source": [
"forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n",
" patient_data,\n",
- ")"
+ ")\n",
+ "# Note, forecasting_splits will be None here"
]
},
{
@@ -233,9 +240,8 @@
"source": [
"split_idx = 0\n",
"p_converted = converter.forward_conversion(\n",
- " forecasting_splits=forecasting_splits[split_idx],\n",
+ " forecasting_splits=None, # Set to None since we don't want to generate forecasting tasks\n",
" event_splits=events_splits[split_idx],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")"
]
},
@@ -300,7 +306,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_dev",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb b/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb
new file mode 100644
index 0000000..4d0fd0c
--- /dev/null
+++ b/docs/examples/advanced/custom_splitting/training_forecasting_qa.ipynb
@@ -0,0 +1,425 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0",
+ "metadata": {},
+ "source": [
+ "# Forecasting QA: Question-Answering Style Value Prediction\n",
+ "\n",
+ "This notebook demonstrates the **Forecasting QA** task mode in TwinWeaver.\n",
+ "\n",
+ "While the default `forward_conversion` mode (`\"forecasting\"`) asks the model to predict exact future values,\n",
+ "the **Forecasting QA** mode bins continuous target values into discrete categories (e.g., `A`, `B`, `C`)\n",
+ "and asks the model to predict the correct bin — turning regression into classification.\n",
+ "\n",
+ "This can be useful when:\n",
+ "- Exact numeric predictions are noisy or unreliable\n",
+ "- You want the model to reason about value ranges instead of point estimates\n",
+ "- You want to combine both forecasting and QA tasks (mode `\"both\"`)\n",
+ "\n",
+ "### How it works\n",
+ "\n",
+ "The `ConverterForecastingQA` sub-converter:\n",
+ "1. Computes bin edges from the variable statistics (`setup_statistics()` on the forecasting splitter)\n",
+ "2. Maps each target value to a lettered bin (A, B, C, …)\n",
+ "3. Generates a prompt listing the bin definitions and asks the model to choose the right bin\n",
+ "4. The target answer uses the bin letter instead of the raw numeric value\n",
+ "\n",
+ "### Selecting the mode\n",
+ "\n",
+ "The `forward_conversion` method accepts an `override_mode_to_select_forecasting` parameter:\n",
+ "- `\"forecasting\"` (default): numeric value prediction\n",
+ "- `\"forecasting_qa\"`: bin-based QA prediction\n",
+ "- `\"both\"`: includes both a forecasting and a forecasting QA task\n",
+ "- `None`: randomly selects one of the above at each call"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1",
+ "metadata": {},
+ "source": [
+ "## Setup\n",
+ "\n",
+ "Load libraries and example data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "\n",
+ "from twinweaver import (\n",
+ " DataSplitterForecasting,\n",
+ " DataSplitterEvents,\n",
+ " DataSplitter,\n",
+ " DataManager,\n",
+ " ConverterInstruction,\n",
+ " Config,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_events = pd.read_csv(\"../../example_data/events.csv\")\n",
+ "df_constant = pd.read_csv(\"../../example_data/constant.csv\")\n",
+ "df_constant_description = pd.read_csv(\"../../example_data/constant_description.csv\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4",
+ "metadata": {},
+ "source": [
+ "## Configuration\n",
+ "\n",
+ "Set up the config, data manager, splitters, and converter.\n",
+ "\n",
+ "> **Important:** `variable_stats` from `DataSplitterForecasting.setup_statistics()` must be\n",
+ "> passed to `ConverterInstruction` for the QA mode to work — it provides the bin edges."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = Config()\n",
+ "\n",
+ "config.split_event_category = \"lot\"\n",
+ "config.event_category_forecast = [\"lab\"]\n",
+ "config.event_category_events_prediction_with_naming = {\n",
+ " \"death\": \"death\",\n",
+ " \"progression\": \"next progression\",\n",
+ "}\n",
+ "config.constant_columns_to_use = [\"birthyear\", \"gender\", \"histology\", \"smoking_history\"]\n",
+ "config.constant_birthdate_column = \"birthyear\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dm = DataManager(config=config)\n",
+ "dm.load_indication_data(\n",
+ " df_events=df_events,\n",
+ " df_constant=df_constant,\n",
+ " df_constant_description=df_constant_description,\n",
+ ")\n",
+ "dm.process_indication_data()\n",
+ "dm.setup_unique_mapping_of_events()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
+ "dm.infer_var_types()\n",
+ "\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
+ "data_splitter_events.setup_variables()\n",
+ "\n",
+ "data_splitter_forecasting = DataSplitterForecasting(\n",
+ " data_manager=dm,\n",
+ " config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
+ "# setup_statistics() computes per-variable quantile bin edges used by the QA mode\n",
+ "data_splitter_forecasting.setup_statistics()\n",
+ "\n",
+ "data_splitter = DataSplitter(\n",
+ " data_splitter_events=data_splitter_events,\n",
+ " data_splitter_forecasting=data_splitter_forecasting,\n",
+ ")\n",
+ "\n",
+ "converter = ConverterInstruction(\n",
+ " nr_tokens_budget_total=8192,\n",
+ " config=config,\n",
+ " dm=dm,\n",
+ " variable_stats=data_splitter_forecasting.variable_stats, # Required for QA mode\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7",
+ "metadata": {},
+ "source": [
+ "## Generate Splits\n",
+ "\n",
+ "Get training splits for a patient."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.random.seed(42)\n",
+ "\n",
+ "patientid = dm.all_patientids[2]\n",
+ "patient_data = dm.get_patient_data(patientid)\n",
+ "\n",
+ "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n",
+ " patient_data,\n",
+ " forecasting_nr_samples_per_split=4,\n",
+ " forecasting_filter_outliers=False,\n",
+ " max_num_splits_per_split_event=2,\n",
+ " events_max_nr_samples_per_split=3,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9",
+ "metadata": {},
+ "source": [
+ "## Mode 1: Default Forecasting (numeric predictions)\n",
+ "\n",
+ "By default, `forward_conversion` uses `override_mode_to_select_forecasting=\"forecasting\"`,\n",
+ "which generates tasks that ask the model to predict exact numeric values."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "10",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.random.seed(42)\n",
+ "split_idx = 0\n",
+ "\n",
+ "p_forecasting = converter.forward_conversion(\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=events_splits[split_idx],\n",
+ " # override_mode_to_select_forecasting=\"forecasting\" # this is the default\n",
+ ")\n",
+ "\n",
+ "print(\"=\" * 80)\n",
+ "print(\"INSTRUCTION (last 1500 chars):\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_forecasting[\"instruction\"][-1500:])\n",
+ "print()\n",
+ "print(\"=\" * 80)\n",
+ "print(\"ANSWER:\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_forecasting[\"answer\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "11",
+ "metadata": {},
+ "source": [
+ "## Mode 2: Forecasting QA (bin-based predictions)\n",
+ "\n",
+ "Setting `override_mode_to_select_forecasting=\"forecasting_qa\"` activates the QA mode.\n",
+ "The prompt now lists bin definitions (e.g., `a = Bin (-inf, 5.2]`) and asks the model\n",
+ "to output the bin letter instead of a raw number."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "12",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.random.seed(42)\n",
+ "split_idx = 0\n",
+ "\n",
+ "p_qa = converter.forward_conversion(\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=events_splits[split_idx],\n",
+ " override_mode_to_select_forecasting=\"forecasting_qa\",\n",
+ ")\n",
+ "\n",
+ "print(\"=\" * 80)\n",
+ "print(\"INSTRUCTION (last 2000 chars):\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_qa[\"instruction\"][-2000:])\n",
+ "print()\n",
+ "print(\"=\" * 80)\n",
+ "print(\"ANSWER:\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_qa[\"answer\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "13",
+ "metadata": {},
+ "source": [
+ "### Inspect the bin definitions\n",
+ "\n",
+ "The metadata returned by the conversion includes the bin mapping for each variable.\n",
+ "These are derived from quantiles computed by `setup_statistics()`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "14",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The target meta contains the category splits for each variable\n",
+ "for task_meta in p_qa[\"meta\"][\"target_meta_detailed\"]:\n",
+ " if \"category_splits\" in task_meta:\n",
+ " print(\"Variable bin definitions:\")\n",
+ " for var, bins in task_meta[\"category_splits\"].items():\n",
+ " print(f\" {var}:\")\n",
+ " for letter, bin_range in bins.items():\n",
+ " print(f\" {letter} = {bin_range}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15",
+ "metadata": {},
+ "source": [
+ "## Mode 3: Both (forecasting + QA in one prompt)\n",
+ "\n",
+ "Setting `override_mode_to_select_forecasting=\"both\"` includes **both** a numeric forecasting\n",
+ "task and a QA-style bin prediction task in the same multi-task prompt. This encourages the\n",
+ "model to learn complementary representations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "16",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.random.seed(42)\n",
+ "split_idx = 0\n",
+ "\n",
+ "p_both = converter.forward_conversion(\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=events_splits[split_idx],\n",
+ " override_mode_to_select_forecasting=\"both\",\n",
+ ")\n",
+ "\n",
+ "print(\"=\" * 80)\n",
+ "print(\"INSTRUCTION (last 3000 chars):\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_both[\"instruction\"][-3000:])\n",
+ "print()\n",
+ "print(\"=\" * 80)\n",
+ "print(\"ANSWER:\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_both[\"answer\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "17",
+ "metadata": {},
+ "source": [
+ "### Compare the task types generated\n",
+ "\n",
+ "Let's inspect which task types were included in each mode."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "18",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for label, result in [(\"forecasting\", p_forecasting), (\"forecasting_qa\", p_qa), (\"both\", p_both)]:\n",
+ " task_types = [m[\"task_type\"] for m in result[\"meta\"][\"target_meta_detailed\"]]\n",
+ " print(f\"Mode '{label}': {task_types}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "19",
+ "metadata": {},
+ "source": [
+ "## Reverse Conversion\n",
+ "\n",
+ "The reverse conversion works for all modes. It parses the model output back into\n",
+ "structured data, regardless of whether the target was numeric or bin-based."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "20",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "date = reference_dates[\"date\"][0]\n",
+ "\n",
+ "# Reverse convert the QA mode output\n",
+ "return_list = converter.reverse_conversion(p_qa[\"answer\"], dm, date)\n",
+ "\n",
+ "return_list[2][\"result\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "21",
+ "metadata": {},
+ "source": [
+ "## Variable Statistics\n",
+ "\n",
+ "The bin edges come from the `variable_stats` DataFrame computed during `setup_statistics()`.\n",
+ "You can inspect them directly:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "22",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data_splitter_forecasting.variable_stats"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv_dev",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb
index 373ea42..409abe2 100644
--- a/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb
+++ b/docs/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb
@@ -5,7 +5,7 @@
"id": "0",
"metadata": {},
"source": [
- "# Forecasting-Only Example: Training Data Generation with Custom Dataset"
+ "# Forecasting-Only Example: Training Data Generation with the Unified DataSplitter API"
]
},
{
@@ -27,6 +27,7 @@
"\n",
"from twinweaver import (\n",
" DataSplitterForecasting,\n",
+ " DataSplitter,\n",
" DataManager,\n",
" ConverterInstruction,\n",
" Config,\n",
@@ -66,7 +67,7 @@
"id": "6",
"metadata": {},
"source": [
- "Set up the data manager and the forecasting-only pipeline."
+ "Set up the data manager and the forecasting-only pipeline using the unified `DataSplitter` API. By passing only `data_splitter_forecasting`, the unified interface handles the forecasting-only case automatically."
]
},
{
@@ -105,17 +106,21 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
+ "# Use the unified DataSplitter API with only the forecasting splitter\n",
+ "data_splitter = DataSplitter(data_splitter_forecasting=data_splitter_forecasting)\n",
+ "\n",
"converter = ConverterInstruction(\n",
" nr_tokens_budget_total=8192,\n",
" config=config,\n",
@@ -198,9 +203,9 @@
"id": "16",
"metadata": {},
"source": [
- "We start by generating random \"splits\" in the patient trajectory. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting).\n",
+ "We start by generating random \"splits\" in the patient trajectory using the unified `DataSplitter.get_splits_from_patient_with_target` method. We can draw multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started) and predict different variables (e.g. neutrophils/hemoglobin/... for forecasting).\n",
"\n",
- "Here we generate these random splits. We can also manually override them (see other examples on inference)."
+ "Since we only provided a forecasting splitter, `events_splits` will be `None`."
]
},
{
@@ -210,13 +215,13 @@
"metadata": {},
"outputs": [],
"source": [
- "processed_splits_fc, split_dates = data_splitter_forecasting.get_splits_from_patient(\n",
+ "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n",
" patient_data,\n",
- " nr_samples_per_split=4,\n",
- " filter_outliers=False,\n",
- " include_metadata=True,\n",
+ " forecasting_nr_samples_per_split=4,\n",
+ " forecasting_filter_outliers=False,\n",
" max_num_splits_per_split_event=2,\n",
- ")"
+ ")\n",
+ "# Note, events_splits will be None here since we only provided a forecasting splitter"
]
},
{
@@ -224,7 +229,7 @@
"id": "18",
"metadata": {},
"source": [
- "Now for each split, we can generate the formatted strings. Note that `event_splits` is left empty since this example only uses the forecasting splitter."
+ "Now for each split, we can generate the formatted strings. Note that `events_splits` is `None` since we only provided a forecasting splitter, so we pass `None` for `event_splits`."
]
},
{
@@ -236,9 +241,8 @@
"source": [
"split_idx = 0\n",
"p_converted = converter.forward_conversion(\n",
- " forecasting_splits=processed_splits_fc[split_idx],\n",
- " event_splits=[], # Not needed for forecasting-only splitter\n",
- " override_mode_to_select_forecasting=\"forecasting\",\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=None, # Not needed for forecasting-only splitter\n",
")"
]
},
diff --git a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb
index 446142c..54fc0ad 100644
--- a/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb
+++ b/docs/examples/advanced/custom_splitting/training_individual_splitters.ipynb
@@ -5,7 +5,7 @@
"id": "0",
"metadata": {},
"source": [
- "# Example for single patient to convert using the instruction setup with custom dataset"
+ "# Example for single patient to convert using the unified DataSplitter API with custom dataset"
]
},
{
@@ -27,8 +27,9 @@
"\n",
"from twinweaver import (\n",
" DataSplitterForecasting,\n",
- " DataManager,\n",
" DataSplitterEvents,\n",
+ " DataSplitter,\n",
+ " DataManager,\n",
" ConverterInstruction,\n",
" Config,\n",
")"
@@ -67,7 +68,7 @@
"id": "6",
"metadata": {},
"source": [
- "Set up the data managers which hold the patient data."
+ "Set up the data managers and the unified `DataSplitter` which combines both event and forecasting splitters."
]
},
{
@@ -90,7 +91,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}\n",
@@ -111,19 +112,31 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
+ "# Use the unified DataSplitter API that combines both splitters\n",
+ "data_splitter = DataSplitter(\n",
+ " data_splitter_events=data_splitter_events,\n",
+ " data_splitter_forecasting=data_splitter_forecasting,\n",
+ ")\n",
+ "\n",
"converter = ConverterInstruction(\n",
" nr_tokens_budget_total=8192,\n",
" config=config,\n",
@@ -206,9 +219,9 @@
"id": "16",
"metadata": {},
"source": [
- "We start by generating random \"splits\" in the patient trajectory. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting, death/progression/metastases/next treatment for event).\n",
+ "We start by generating random \"splits\" in the patient trajectory using the unified `DataSplitter.get_splits_from_patient_with_target` method. This ensures that both forecasting and event splits use the same anchor points in time. We can generate multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started) and predict different variables (e.g. neutrophils/hemoglobin/... for forecasting, death/progression/metastases/next treatment for events).\n",
"\n",
- "Here we generate these random splits. We can also manually override them (see other examples on inference)."
+ "The generated splits can also be overridden manually (see the other examples on inference)."
]
},
{
@@ -218,18 +231,12 @@
"metadata": {},
"outputs": [],
"source": [
- "processed_splits_fc, split_dates = data_splitter_forecasting.get_splits_from_patient(\n",
+ "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n",
" patient_data,\n",
- " nr_samples_per_split=4,\n",
- " filter_outliers=False,\n",
- " include_metadata=True,\n",
+ " forecasting_nr_samples_per_split=4,\n",
+ " forecasting_filter_outliers=False,\n",
" max_num_splits_per_split_event=2,\n",
- ")\n",
- "\n",
- "processed_splits_ev = data_splitter_events.get_splits_from_patient(\n",
- " patient_data,\n",
- " reference_split_dates=split_dates,\n",
- " max_nr_samples_per_split=3,\n",
+ " events_max_nr_samples_per_split=3,\n",
")"
]
},
@@ -250,9 +257,8 @@
"source": [
"split_idx = 0\n",
"p_converted = converter.forward_conversion(\n",
- " forecasting_splits=processed_splits_fc[split_idx],\n",
- " event_splits=processed_splits_ev[split_idx],\n",
- " override_mode_to_select_forecasting=\"forecasting_qa\",\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=events_splits[split_idx],\n",
")"
]
},
@@ -296,15 +302,15 @@
"metadata": {},
"outputs": [],
"source": [
- "date = split_dates[\"date\"][0]\n",
+ "date = reference_dates[\"date\"][0]\n",
"return_list = converter.reverse_conversion(p_converted[\"answer\"], dm, date)\n",
- "return_list[2][\"result\"]"
+ "return_list[0][\"result\"]"
]
}
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_dev",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb b/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb
new file mode 100644
index 0000000..f57bd1c
--- /dev/null
+++ b/docs/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb
@@ -0,0 +1,698 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0",
+ "metadata": {},
+ "source": [
+ "# Forecasting Inference with TwinWeaver and vLLM\n",
+ "\n",
+ "This notebook demonstrates how to use TwinWeaver's **forecasting inference pipeline**\n",
+ "to predict future clinical values (e.g., lab results like hemoglobin) using a\n",
+ "fine-tuned LLM served via [vLLM](https://github.com/vllm-project/vllm).\n",
+ "\n",
+ "Unlike the TTE (time-to-event) probability pipeline, which *scores* fixed\n",
+ "completions, this pipeline uses **free-text generation**: the model produces\n",
+ "an answer string that is then **reverse-converted** back into a structured\n",
+ "DataFrame with predicted dates and values.\n",
+ "\n",
+ "### Pipeline overview\n",
+ "\n",
+ "```\n",
+ "Patient data ──► DataSplitter (forecasting) ──► ConverterInstruction\n",
+ " │\n",
+ " instruction text per patient\n",
+ " │\n",
+ " ┌──────────────────────────────┘\n",
+ " ▼\n",
+ " vLLM server (OpenAI-compatible API)\n",
+ " │\n",
+ " generated text completions\n",
+ " │\n",
+ " ▼\n",
+ " parse_forecasting_results()\n",
+ " (calls converter.reverse_conversion internally)\n",
+ " │\n",
+ " structured DataFrame\n",
+ " with predicted dates & values\n",
+ "```\n",
+ "\n",
+ "> **⚠️ Important:** The quality of the forecasts depends critically\n",
+ "> on having a **fine-tuned model**. An off-the-shelf instruction model\n",
+ "> (like the default `microsoft/Phi-4-mini-instruct` used here for demonstration)\n",
+ "> will produce **meaningless predictions**. Always fine-tune on your\n",
+ "> clinical dataset first (see `03_end_to_end_llm_finetuning.ipynb`).\n",
+ "\n",
+ "> **Requirements:**\n",
+ "> - A GPU with enough memory to serve the model via vLLM (≥16 GB for a 4-bit 8B model)\n",
+ "> - `pip install twinweaver[fine-tuning-example] vllm openai`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import subprocess\n",
+ "import time\n",
+ "import sys\n",
+ "import os\n",
+ "\n",
+ "import pandas as pd\n",
+ "\n",
+ "from twinweaver import (\n",
+ " DataManager,\n",
+ " Config,\n",
+ " DataSplitterForecasting,\n",
+ " DataSplitterEvents,\n",
+ " DataSplitter,\n",
+ " ConverterInstruction,\n",
+ " run_forecasting_inference_notebook,\n",
+ " parse_forecasting_results,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2",
+ "metadata": {},
+ "source": [
+ "## 1. Configuration\n",
+ "\n",
+ "We define all key settings up front so they are easy to change in one place.\n",
+ "\n",
+ "> **Note:** Replace `MODEL_PATH` with the path to your **fine-tuned** model for\n",
+ "> meaningful results. The default `microsoft/Phi-4-mini-instruct` is only used\n",
+ "> here so that the notebook is self-contained."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ---------------------------------------------------------------------------\n",
+ "# Model & server settings\n",
+ "# ---------------------------------------------------------------------------\n",
+ "MODEL_PATH = \"microsoft/Phi-4-mini-instruct\" # ⚠️ Replace with your fine-tuned model path!\n",
+ "TOKENIZER_PATH = MODEL_PATH # Usually the same as the model path\n",
+ "\n",
+ "VLLM_PORT = 8000\n",
+ "PREDICTION_URL = f\"http://0.0.0.0:{VLLM_PORT}/v1/\"\n",
+ "\n",
+ "MAX_CONTEXT_LENGTH = 4096 # Must match what the model was trained with\n",
+ "\n",
+ "# Generation settings\n",
+ "MAX_NEW_TOKENS = 256 # Max tokens for the generated forecast answer\n",
+ "TEMPERATURE = 0.9 # Sampling temperature (0 = greedy)\n",
+ "TOP_P = 0.9 # Nucleus sampling\n",
+ "N_SAMPLES = 3 # Number of independent samples per patient (>1 enables aggregation)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ---------------------------------------------------------------------------\n",
+ "# TwinWeaver data settings (same as used during training)\n",
+ "# ---------------------------------------------------------------------------\n",
+ "config = Config()\n",
+ "\n",
+ "# 1. Event category used for data splitting\n",
+ "config.split_event_category = \"lot\"\n",
+ "\n",
+ "# 2. Forecasting categories\n",
+ "config.event_category_forecast = [\"lab\"]\n",
+ "\n",
+ "# 3. Time-to-event variables (needed to initialise splitters, even if we only forecast)\n",
+ "config.event_category_events_prediction_with_naming = {\n",
+ " \"death\": \"death\",\n",
+ " \"progression\": \"next progression\",\n",
+ "}\n",
+ "\n",
+ "# 4. Constant (static) columns\n",
+ "config.constant_columns_to_use = [\"birthyear\", \"gender\", \"histology\", \"smoking_history\"]\n",
+ "config.constant_birthdate_column = \"birthyear\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5",
+ "metadata": {},
+ "source": [
+ "## 2. Load and prepare data\n",
+ "\n",
+ "We use the same example data shipped with TwinWeaver. In a real scenario you\n",
+ "would load your own clinical dataset here."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load example data (adjust paths if running from a different directory)\n",
+ "df_events = pd.read_csv(\"../../example_data/events.csv\")\n",
+ "df_constant = pd.read_csv(\"../../example_data/constant.csv\")\n",
+ "df_constant_description = pd.read_csv(\"../../example_data/constant_description.csv\")\n",
+ "\n",
+ "print(f\"Loaded {len(df_events)} events for {df_events['patientid'].nunique()} patients\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialise DataManager and all splitters\n",
+ "dm = DataManager(config=config)\n",
+ "dm.load_indication_data(\n",
+ " df_events=df_events,\n",
+ " df_constant=df_constant,\n",
+ " df_constant_description=df_constant_description,\n",
+ ")\n",
+ "dm.process_indication_data()\n",
+ "dm.setup_unique_mapping_of_events()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
+ "dm.infer_var_types()\n",
+ "\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
+ "data_splitter_events.setup_variables()\n",
+ "\n",
+ "data_splitter_forecasting = DataSplitterForecasting(\n",
+ " data_manager=dm,\n",
+ " config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
+ "data_splitter_forecasting.setup_statistics()\n",
+ "\n",
+ "# Combined interface\n",
+ "data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)\n",
+ "\n",
+ "converter = ConverterInstruction(\n",
+ " nr_tokens_budget_total=MAX_CONTEXT_LENGTH,\n",
+ " config=config,\n",
+ " dm=dm,\n",
+ " variable_stats=data_splitter_forecasting.variable_stats,\n",
+ ")\n",
+ "\n",
+ "print(\"✅ Data pipeline ready\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8",
+ "metadata": {},
+ "source": [
+ "## 3. Generate forecasting instruction prompts\n",
+ "\n",
+ "For each test patient we generate a **forecasting-only** instruction prompt.\n",
+ "The key parameters are:\n",
+ "\n",
+ "- `forecasting_override_variables_to_predict` – which variables to forecast\n",
+ " (e.g. hemoglobin)\n",
+ "- `forecasting_future_weeks_per_variable` – at which future time points\n",
+ " (in weeks) to request predictions\n",
+ "\n",
+ "The converter produces a text instruction asking the model to predict the\n",
+ "specified variable(s) at the given future time points."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Choose patients to evaluate (here: all test-set patients)\n",
+ "test_patientids = dm.get_all_patientids_in_split(config.test_split_name)\n",
+ "print(f\"Number of test patients: {len(test_patientids)}\")\n",
+ "\n",
+ "# Define the prediction task:\n",
+ "# Which variables to predict and at which future week offsets\n",
+ "VARIABLES_TO_PREDICT = [\"hemoglobin_-_718-7\"]\n",
+ "FUTURE_WEEKS = {\n",
+ " \"hemoglobin_-_718-7\": [4, 8, 12], # Predict hemoglobin at 4, 8, and 12 weeks\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "10",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build the list of prompt payloads for the forecasting API\n",
+ "# Each payload is a dict with \"patientid\", \"instruction\", and \"split_date\"\n",
+ "prompts_with_meta: list[dict] = []\n",
+ "\n",
+ "for pid in test_patientids:\n",
+ " patient_data = dm.get_patient_data(pid)\n",
+ "\n",
+ " # Use only data up to and including the first lot event (simulates baseline information)\n",
+ " patient_data[\"events\"] = patient_data[\"events\"].sort_values(\"date\")\n",
+ " first_lot_date = patient_data[\"events\"][patient_data[\"events\"][\"event_category\"] == \"lot\"][\"date\"].min()\n",
+ " assert pd.notna(first_lot_date), f\"Patient {pid} has no lot event\"\n",
+ " patient_data[\"events\"] = patient_data[\"events\"][patient_data[\"events\"][\"date\"] <= first_lot_date].copy()\n",
+ "\n",
+ " # Generate the forecasting-only split for inference\n",
+ " forecast_split, _ = data_splitter.get_splits_from_patient_inference(\n",
+ " patient_data,\n",
+ " inference_type=\"forecasting\",\n",
+ " forecasting_override_variables_to_predict=VARIABLES_TO_PREDICT,\n",
+ " )\n",
+ "\n",
+ " # Convert to instruction text (no target answer)\n",
+ " converted = converter.forward_conversion_inference(\n",
+ " forecasting_split=forecast_split,\n",
+ " forecasting_future_weeks_per_variable=FUTURE_WEEKS,\n",
+ " )\n",
+ "\n",
+ " # Collect the prompt payload\n",
+ " prompts_with_meta.append(\n",
+ " {\n",
+ " \"patientid\": pid,\n",
+ " \"instruction\": converted[\"instruction\"],\n",
+ " \"split_date\": forecast_split.split_date_included_in_input,\n",
+ " }\n",
+ " )\n",
+ "\n",
+ "print(f\"Generated {len(prompts_with_meta)} instruction prompts\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "11",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Let's inspect one instruction to see what the model will receive\n",
+ "sample = prompts_with_meta[0]\n",
+ "print(f\"=== Patient: {sample['patientid']} ===\")\n",
+ "print(f\"Split date: {sample['split_date']}\\n\")\n",
+ "print(sample[\"instruction\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "12",
+ "metadata": {},
+ "source": [
+ "## 4. Launch the vLLM server\n",
+ "\n",
+ "We launch a vLLM OpenAI-compatible server as a background process.\n",
+ "\n",
+ "> **If you already have a vLLM server running**, set `SKIP_VLLM_LAUNCH = True` in the next\n",
+ "> cell and update `PREDICTION_URL` and `MODEL_PATH` in the configuration section above.\n",
+ "\n",
+ "> **Tip:** For production use, launch the server in a separate terminal."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "13",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Launch vLLM server as a background subprocess\n",
+ "# Set to True to skip launching (if you already have a server running)\n",
+ "SKIP_VLLM_LAUNCH = False\n",
+ "\n",
+ "vllm_process = None\n",
+ "\n",
+ "if not SKIP_VLLM_LAUNCH:\n",
+ " env = os.environ.copy()\n",
+ " env[\"VLLM_ATTENTION_BACKEND\"] = \"FLASH_ATTN\"\n",
+ "\n",
+ " vllm_command = [\n",
+ " sys.executable,\n",
+ " \"-m\",\n",
+ " \"vllm.entrypoints.openai.api_server\",\n",
+ " \"--port\",\n",
+ " str(VLLM_PORT),\n",
+ " \"--model\",\n",
+ " MODEL_PATH,\n",
+ " \"--tokenizer\",\n",
+ " TOKENIZER_PATH,\n",
+ " \"--enable-prefix-caching\",\n",
+ " ]\n",
+ "\n",
+ " print(f\"🚀 Launching vLLM server:\\n {' '.join(vllm_command)}\\n\")\n",
+ "\n",
+ " vllm_process = subprocess.Popen(\n",
+ " vllm_command,\n",
+ " env=env,\n",
+ " stdout=subprocess.PIPE,\n",
+ " stderr=subprocess.STDOUT,\n",
+ " text=True,\n",
+ " )\n",
+ "\n",
+ " # Wait for the server to be ready\n",
+ " WAIT_SECONDS = 240\n",
+ " print(f\"⏳ Waiting up to {WAIT_SECONDS}s for the server to start...\")\n",
+ " import urllib.request\n",
+ "\n",
+ " for i in range(WAIT_SECONDS):\n",
+ " time.sleep(1)\n",
+ " try:\n",
+ " urllib.request.urlopen(f\"http://localhost:{VLLM_PORT}/health\")\n",
+ " print(f\"✅ vLLM server is ready after {i + 1}s\")\n",
+ " break\n",
+ " except Exception:\n",
+ " pass\n",
+ " else:\n",
+ " print(\"⚠️ Server did not respond in time. Check GPU memory and logs.\")\n",
+ " print(\" You can read server output with: vllm_process.stdout.read()\")\n",
+ "else:\n",
+ " print(\"Skipping vLLM launch – using existing server.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "14",
+ "metadata": {},
+ "source": [
+ "## 5. Run forecasting inference\n",
+ "\n",
+ "This is the core step. `run_forecasting_inference_notebook` sends each patient's\n",
+ "instruction to the vLLM server and collects the generated text completions.\n",
+ "\n",
+ "Under the hood it:\n",
+ "1. Wraps each instruction in a chat message (with optional system prompt).\n",
+ "2. Calls the OpenAI-compatible `/v1/chat/completions` endpoint.\n",
+ "3. Returns the generated text(s) alongside patient metadata.\n",
+ "\n",
+ "When `n_samples > 1`, multiple independent completions are generated per\n",
+ "patient, which can later be aggregated into a mean trajectory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "15",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Run the forecasting inference\n",
+ "# This calls the vLLM server asynchronously for all patients\n",
+ "raw_results = await run_forecasting_inference_notebook(\n",
+ " prompts_with_meta,\n",
+ " prediction_url=PREDICTION_URL,\n",
+ " prediction_model=MODEL_PATH,\n",
+ " max_concurrent_requests=40,\n",
+ " max_new_tokens=MAX_NEW_TOKENS,\n",
+ " temperature=TEMPERATURE,\n",
+ " top_p=TOP_P,\n",
+ " n_samples=N_SAMPLES,\n",
+ " api_key=\"EMPTY\",\n",
+ " timeout=600.0,\n",
+ ")\n",
+ "\n",
+ "# Check for failures\n",
+ "n_success = sum(1 for r in raw_results if r is not None)\n",
+ "n_fail = sum(1 for r in raw_results if r is None)\n",
+ "print(f\"✅ Generated forecasts for {n_success} patients, {n_fail} failures\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "16",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Let's inspect the raw generated text for one patient\n",
+ "for r in raw_results:\n",
+ " if r is not None:\n",
+ " print(f\"=== Patient: {r['patientid']} ===\")\n",
+ " for i, text in enumerate(r[\"generated_texts\"]):\n",
+ " print(f\"\\n--- Sample {i} ---\")\n",
+ " print(text)\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "17",
+ "metadata": {},
+ "source": [
+ "## 6. Parse results with reverse conversion\n",
+ "\n",
+ "`parse_forecasting_results` takes the raw generated texts and:\n",
+ "\n",
+ "1. Calls `converter.reverse_conversion` on each generated text to parse it\n",
+ " back into a structured DataFrame with dates and predicted values.\n",
+ "2. When `n_samples > 1` and `aggregate_samples=True`, aggregates multiple\n",
+ " trajectories using `converter.aggregate_multiple_responses` (e.g. averaging\n",
+ " numeric predictions).\n",
+ "3. Returns a single long-format DataFrame with all patients' predictions.\n",
+ "\n",
+ "> **Note:** Reverse conversion is robust to slightly malformed model output\n",
+ "> thanks to `inference_override=True`, but a fine-tuned model will produce\n",
+ "> much more parseable results than a generic instruction model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "18",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "converter = ConverterInstruction(\n",
+ " nr_tokens_budget_total=MAX_CONTEXT_LENGTH,\n",
+ " config=config,\n",
+ " dm=dm,\n",
+ " variable_stats=data_splitter_forecasting.variable_stats,\n",
+ ")\n",
+ "print(\"✅ Converter reloaded with latest code\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "19",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Parse the generated texts into structured DataFrames\n",
+ "df_results = parse_forecasting_results(\n",
+ " raw_results,\n",
+ " converter,\n",
+ " dm,\n",
+ " drop_failures=True,\n",
+ " aggregate_samples=(N_SAMPLES > 1), # Only aggregate if we have multiple samples\n",
+ ")\n",
+ "\n",
+ "if df_results.empty:\n",
+ " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n",
+ " print(\" Fine-tune a model first (see 03_end_to_end_llm_finetuning.ipynb) for meaningful results.\")\n",
+ "else:\n",
+ " print(f\"Parsed {len(df_results)} prediction rows for {df_results['patientid'].nunique()} patients\")\n",
+ " df_results.head(20)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "20",
+ "metadata": {},
+ "source": [
+ "### Understanding the output\n",
+ "\n",
+ "The returned DataFrame has the standard TwinWeaver event format:\n",
+ "\n",
+ "| Column | Description |\n",
+ "|---|---|\n",
+ "| `date` | Predicted date (computed from split_date + week offset) |\n",
+ "| `event_name` | The variable being predicted (e.g. `hemoglobin_-_718-7`) |\n",
+ "| `event_value` | The predicted value |\n",
+ "| `event_category` | Category of the event (e.g. `lab`) |\n",
+ "| `patientid` | Patient identifier |\n",
+ "| `task_type` | Which task type produced this prediction |\n",
+ "| `sample_idx` | Sample index (when `aggregate_samples=False` and `n_samples > 1`) |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "21",
+ "metadata": {},
+ "source": [
+ "## 7. Multi-sample aggregation (optional)\n",
+ "\n",
+ "When using `n_samples > 1`, each patient gets multiple independent forecast\n",
+ "trajectories. Aggregation (e.g. averaging numeric predictions) can reduce\n",
+ "variance and give more robust estimates.\n",
+ "\n",
+ "Below is an example of running with multiple samples and then aggregating."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "22",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Example: generate 3 samples per patient and aggregate\n",
+ "N_SAMPLES_AGG = 3\n",
+ "\n",
+ "raw_results_multi = await run_forecasting_inference_notebook(\n",
+ " prompts_with_meta,\n",
+ " prediction_url=PREDICTION_URL,\n",
+ " prediction_model=MODEL_PATH,\n",
+ " max_concurrent_requests=40,\n",
+ " max_new_tokens=MAX_NEW_TOKENS,\n",
+ " temperature=TEMPERATURE,\n",
+ " top_p=TOP_P,\n",
+ " n_samples=N_SAMPLES_AGG,\n",
+ " api_key=\"EMPTY\",\n",
+ ")\n",
+ "\n",
+ "# Parse with aggregation enabled\n",
+ "df_aggregated = parse_forecasting_results(\n",
+ " raw_results_multi,\n",
+ " converter,\n",
+ " dm,\n",
+ " drop_failures=True,\n",
+ " aggregate_samples=True, # Average numeric values across samples\n",
+ ")\n",
+ "\n",
+ "if df_aggregated.empty:\n",
+ " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n",
+ "else:\n",
+ " print(f\"Aggregated results: {len(df_aggregated)} rows for {df_aggregated['patientid'].nunique()} patients\")\n",
+ " df_aggregated.head(20)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "23",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# You can also get the individual (non-aggregated) samples for deeper analysis\n",
+ "df_individual = parse_forecasting_results(\n",
+ " raw_results_multi,\n",
+ " converter,\n",
+ " dm,\n",
+ " drop_failures=True,\n",
+ " aggregate_samples=False, # Keep individual samples\n",
+ ")\n",
+ "\n",
+ "if df_individual.empty:\n",
+ " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n",
+ "else:\n",
+ " print(f\"Individual results: {len(df_individual)} rows\")\n",
+ " print(f\"Samples per patient: {df_individual.groupby('patientid')['sample_idx'].nunique().describe()}\")\n",
+ " df_individual.head(20)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "24",
+ "metadata": {},
+ "source": [
+ "## 8. Clean up\n",
+ "\n",
+ "Shut down the vLLM server if we launched it from this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "25",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if vllm_process is not None:\n",
+ " print(\"Terminating vLLM server...\")\n",
+ " vllm_process.terminate()\n",
+ " vllm_process.wait(timeout=10)\n",
+ " print(\"✅ Server stopped.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "26",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "This notebook demonstrated the full forecasting inference workflow:\n",
+ "\n",
+ "1. **Data preparation** – Load clinical data and generate forecasting instruction prompts\n",
+ "2. **Model serving** – Launch (or connect to) a vLLM OpenAI-compatible server\n",
+ "3. **Text generation** – Use `run_forecasting_inference_notebook` to generate completions\n",
+ "4. **Reverse conversion** – Use `parse_forecasting_results` to convert text → structured DataFrame\n",
+ "5. **Aggregation** – Optionally average multiple samples for more robust predictions\n",
+ "\n",
+ "### Key functions\n",
+ "\n",
+ "| Function | Purpose |\n",
+ "|---|---|\n",
+ "| `run_forecasting_inference()` | Generate completions for all patients (sync wrapper) |\n",
+ "| `run_forecasting_inference_notebook()` | Same but async – for notebooks |\n",
+ "| `parse_forecasting_results()` | Reverse-convert generated text → structured DataFrame |\n",
+ "\n",
+ "### Comparison with TTE probability inference\n",
+ "\n",
+ "| Aspect | TTE Inference | Forecasting Inference |\n",
+ "|---|---|---|\n",
+ "| **Method** | Log-prob scoring of fixed completions | Free-text generation |\n",
+ "| **Output** | Probabilities (censored/occurred/not occurred) | Predicted values at future time points |\n",
+ "| **API endpoint** | `/v1/completions` (scoring) | `/v1/chat/completions` (generation) |\n",
+ "| **Post-processing** | Softmax over log-probs | Reverse conversion (text → DataFrame) |\n",
+ "| **Multi-sample** | Not applicable | Average trajectories via `aggregate_multiple_responses` |\n",
+ "\n",
+ "### Next steps\n",
+ "\n",
+ "- **Fine-tune a model** on your dataset using `03_end_to_end_llm_finetuning.ipynb`\n",
+ "- **Combine with TTE** by using `inference_type=\"both\"` in the data splitter\n",
+ "- **Evaluate** predictions against ground truth to assess model performance\n",
+ "- **Experiment** with different variables, time horizons, and sample counts"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv_dev_gpu",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb b/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb
index 9c7fd54..fde4635 100644
--- a/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb
+++ b/docs/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb
@@ -88,8 +88,7 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
- "dm.infer_var_types()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"\n",
"converter = ConverterPretrain(config=config, dm=dm)"
]
@@ -401,9 +400,7 @@
"# Lets simulate forecasts for after the first line of therapy\n",
"df_constant_patient = patient_data[\"constant\"].copy()\n",
"df_events_patient = patient_data[\"events\"].copy()\n",
- "date_of_first_lot = df_events_patient.loc[\n",
- " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n",
- "].min()\n",
+ "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n",
"date_of_first_event = df_events_patient[\"date\"].min()\n",
"\n",
"# Only keep data until (and including) first line of therapy\n",
@@ -507,7 +504,7 @@
"outputs": [],
"source": [
"# Show the generated answer\n",
- "generated_answer"
+ "print(generated_answer)"
]
},
{
@@ -543,7 +540,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_test",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/docs/examples/advanced/pretraining/prepare_pretraining_data.md b/docs/examples/advanced/pretraining/prepare_pretraining_data.md
new file mode 100644
index 0000000..cf03bad
--- /dev/null
+++ b/docs/examples/advanced/pretraining/prepare_pretraining_data.md
@@ -0,0 +1,17 @@
+# Prepare Pretraining Data
+
+This script demonstrates how to prepare pretraining data using the `ConverterPretrain` class. Unlike instruction-tuning, pretraining converts full patient timelines into continuous text without explicit question/answer formatting.
+
+Key concepts:
+
+- Setting up `Config` and `DataManager` for pretraining
+- Using `ConverterPretrain.forward_conversion()` to convert patient data to text
+- Verifying data integrity with `ConverterPretrain.reverse_conversion()`
+- Checking round-trip consistency with `get_difference_in_event_dataframes()`
+
+!!! note "Run from project root"
+ This script should be run from the root folder of the TwinWeaver repository.
+
+```python title="prepare_pretraining_data.py"
+--8<-- "examples/advanced/pretraining/prepare_pretraining_data.py"
+```
diff --git a/docs/examples/advanced/pretraining/prepare_pretraining_data.py b/docs/examples/advanced/pretraining/prepare_pretraining_data.py
index 05c5d3b..80a5402 100644
--- a/docs/examples/advanced/pretraining/prepare_pretraining_data.py
+++ b/docs/examples/advanced/pretraining/prepare_pretraining_data.py
@@ -28,7 +28,7 @@ def __init__(self):
)
self.dm.process_indication_data()
self.dm.setup_unique_mapping_of_events()
- self.dm.setup_dataset_splits()
+ self.dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
#: set up converter
self.converter = ConverterPretrain(config=self.config, dm=self.dm)
diff --git a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb
index 9e7a2b2..c442b0b 100644
--- a/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb
+++ b/docs/examples/advanced/tte_inference/tte_probability_inference.ipynb
@@ -136,7 +136,7 @@
"# 3. Time-to-event variables to predict\n",
"# Keys = event_category values in the events DataFrame\n",
"# Values = human-readable name used in the prompt\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\",\n",
"}\n",
@@ -188,13 +188,22 @@
")\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
- "data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)\n",
+ "data_splitter_forecasting = DataSplitterForecasting(\n",
+ " data_manager=dm,\n",
+ " config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
"# Combined interface\n",
@@ -581,7 +590,15 @@
"print(f\"\\n✅ Total rows: {len(df_all_horizons)}\")\n",
"df_all_horizons = df_all_horizons.sort_values([\"patientid\", \"week_horizon\"])\n",
"df_all_horizons[\n",
- " [\"patientid\", \"week_horizon\", \"probability_occurrence\", \"probability_no_occurrence\", \"probability_censored\"]\n",
+ " [\n",
+ " \"patientid\",\n",
+ " \"week_horizon\",\n",
+ " \"probability_occurrence\",\n",
+ " \"probability_no_occurrence\",\n",
+ " \"probability_censored\",\n",
+ " \"probability_occurrence_renormalized\",\n",
+ " \"probability_no_occurrence_renormalized\",\n",
+ " ]\n",
"]"
]
},
diff --git a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb
index efc5472..d7eeccd 100644
--- a/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb
+++ b/docs/examples/data_preprocessing/raw_data_preprocessing.ipynb
@@ -1007,7 +1007,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}\n",
@@ -1036,7 +1036,7 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
"print(f\"Loaded {len(dm.all_patientids)} patients into DataManager\")"
@@ -1050,12 +1050,18 @@
"outputs": [],
"source": [
"# Initialize data splitters and converter\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
@@ -1114,7 +1120,6 @@
" p_converted = converter.forward_conversion(\n",
" forecasting_splits=forecasting_splits[split_idx],\n",
" event_splits=events_splits[split_idx],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
" )\n",
"\n",
" print(\"=\" * 80)\n",
diff --git a/docs/examples/hackathon/01_data_preparation_challenge.ipynb b/docs/examples/hackathon/01_data_preparation_challenge.ipynb
index 3cfd551..a4984ff 100644
--- a/docs/examples/hackathon/01_data_preparation_challenge.ipynb
+++ b/docs/examples/hackathon/01_data_preparation_challenge.ipynb
@@ -180,7 +180,7 @@
"# TODO: Configure time-to-event prediction targets\n",
"# HINT: This should be a dictionary mapping event names to display names\n",
"# Example: {\"original_name\": \"display name in prompt\"}\n",
- "config.data_splitter_events_variables_category_mapping = None # Replace with dict"
+ "config.event_category_events_prediction_with_naming = None # Replace with dict"
]
},
{
@@ -218,19 +218,19 @@
" else:\n",
" print(f\"✅ event_category_forecast: {config.event_category_forecast}\")\n",
"\n",
- " if config.data_splitter_events_variables_category_mapping is None:\n",
- " errors.append(\"❌ data_splitter_events_variables_category_mapping is not set\")\n",
- " elif not isinstance(config.data_splitter_events_variables_category_mapping, dict):\n",
- " errors.append(\"❌ data_splitter_events_variables_category_mapping should be a dict\")\n",
+ " if config.event_category_events_prediction_with_naming is None:\n",
+ " errors.append(\"❌ event_category_events_prediction_with_naming is not set\")\n",
+ " elif not isinstance(config.event_category_events_prediction_with_naming, dict):\n",
+ " errors.append(\"❌ event_category_events_prediction_with_naming should be a dict\")\n",
" elif any(\n",
" [\n",
" cat not in df_events[\"event_category\"].unique()\n",
- " for cat in config.data_splitter_events_variables_category_mapping.keys()\n",
+ " for cat in config.event_category_events_prediction_with_naming.keys()\n",
" ]\n",
" ):\n",
- " errors.append(\"❌ At least one key in data_splitter_events_variables_category_mapping not found in data\")\n",
+ " errors.append(\"❌ At least one key in event_category_events_prediction_with_naming not found in data\")\n",
" else:\n",
- " print(f\"✅ Event mapping: {config.data_splitter_events_variables_category_mapping}\")\n",
+ " print(f\"✅ Event mapping: {config.event_category_events_prediction_with_naming}\")\n",
"\n",
" if errors:\n",
" print(\"\\n\" + \"\\n\".join(errors))\n",
@@ -565,7 +565,6 @@
"# Parameters:\n",
"# - forecasting_splits: the forecasting split for one time point\n",
"# - event_splits: the event split for one time point\n",
- "# - override_mode_to_select_forecasting: set to \"both\"\n",
"\n",
"split_idx = 0\n",
"p_converted = None # Replace with actual conversion call"
diff --git a/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb b/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb
index e45d570..9405caf 100644
--- a/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb
+++ b/docs/examples/hackathon/02_llm_finetuning_challenge.ipynb
@@ -150,7 +150,7 @@
"# Set up:\n",
"# - split_event_category\n",
"# - event_category_forecast\n",
- "# - data_splitter_events_variables_category_mapping\n",
+ "# - event_category_events_prediction_with_naming\n",
"# - constant_columns_to_use\n",
"# - constant_birthdate_column\n",
"\n",
@@ -776,9 +776,7 @@
"\n",
"# Get the date of first line of therapy\n",
"df_events_patient = patient_data[\"events\"].copy()\n",
- "date_of_first_lot = df_events_patient.loc[\n",
- " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n",
- "].min()\n",
+ "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n",
"\n",
"print(f\"Test patient: {test_patientid}\")\n",
"print(f\"First LoT date: {date_of_first_lot}\")"
diff --git a/docs/examples/integrations/meds_data_import.ipynb b/docs/examples/integrations/meds_data_import.ipynb
index c7a2057..5183219 100644
--- a/docs/examples/integrations/meds_data_import.ipynb
+++ b/docs/examples/integrations/meds_data_import.ipynb
@@ -475,10 +475,9 @@
"config = Config() # Override values here to customize pipeline\n",
"config.constant_columns_to_use = constant_columns\n",
"config.constant_birthdate_column = None # Not using in demo\n",
- "config.lot_event_name = None # Setting for LoTs\n",
"config.event_value_lot_start = None\n",
"config.split_event_category = \"lot\"\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
"}"
]
@@ -500,9 +499,14 @@
")\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"converter = ConverterInstruction(\n",
" nr_tokens_budget_total=8192,\n",
@@ -553,6 +557,7 @@
" custom_tasks=None,\n",
")\n",
"\n",
+ "\n",
"print(converted[\"instruction\"])"
]
}
diff --git a/docs/framework.md b/docs/framework.md
index 82ef8bc..5f836e0 100644
--- a/docs/framework.md
+++ b/docs/framework.md
@@ -57,7 +57,7 @@ TwinWeaver supports two primary data formats, each serving a distinct stage in t
- `config.split_event_category`: Event category used to anchor split points (e.g., `"lot"` for line of therapy)
- `config.event_category_forecast`: List of event categories to forecast (e.g., `["lab"]`)
- - `config.data_splitter_events_variables_category_mapping`: Mapping of events to prediction tasks (e.g., death, progression)
+ - `config.event_category_events_prediction_with_naming`: Mapping of events to prediction tasks (e.g., death, progression)
See the [Data Splitting](data-splitting.md) page for a detailed explanation, or the [Quick Start](quickstart.md) and [Data Preparation Tutorial](examples/01_data_preparation_for_training.ipynb) for examples.
diff --git a/docs/images/favicon.png b/docs/images/favicon.png
new file mode 100644
index 0000000..1fa9320
Binary files /dev/null and b/docs/images/favicon.png differ
diff --git a/docs/quickstart.md b/docs/quickstart.md
index 4b55545..1eabc19 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -33,7 +33,7 @@ config.event_category_forecast = ["lab"]
# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')
# Only needs to be set if you want to do time to event prediction
-config.data_splitter_events_variables_category_mapping = {
+config.event_category_events_prediction_with_naming = {
"death": "death",
"progression": "next progression", # Custom name in prompt
}
@@ -48,7 +48,7 @@ dm.load_indication_data(
)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
-dm.setup_dataset_splits()
+dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
dm.infer_var_types()
# Set up data splitters for different task types
@@ -63,7 +63,10 @@ data_splitter_forecasting = DataSplitterForecasting(
)
# Combined interface for both task types
-data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)
+data_splitter = DataSplitter(
+ data_splitter_events=data_splitter_events,
+ data_splitter_forecasting=data_splitter_forecasting,
+)
# Set up the text converter
converter = ConverterInstruction(
@@ -85,7 +88,6 @@ split_idx = 0 # Use first split
training_data = converter.forward_conversion(
forecasting_splits=forecasting_splits[split_idx],
event_splits=events_splits[split_idx],
- override_mode_to_select_forecasting="both",
)
# training_data now contains (Input, Target) pairs ready for LLM fine-tuning
diff --git a/docs/reference/utils/forecasting_inference.md b/docs/reference/utils/forecasting_inference.md
new file mode 100644
index 0000000..03750b5
--- /dev/null
+++ b/docs/reference/utils/forecasting_inference.md
@@ -0,0 +1,3 @@
+# Forecasting Inference
+
+::: twinweaver.utils.forecasting_inference
diff --git a/docs/tte-inference.md b/docs/tte-inference.md
index d0428e1..f71a759 100644
--- a/docs/tte-inference.md
+++ b/docs/tte-inference.md
@@ -41,7 +41,7 @@ Patient data ──► DataSplitter (events) ──► ConverterInstruction
▼
compute_length_normalized_probabilities()
│
- calibrated probabilities
+ probabilities
+ hard predictions (DataFrame)
```
diff --git a/docs/tutorials.md b/docs/tutorials.md
index 30924b9..d3befe6 100644
--- a/docs/tutorials.md
+++ b/docs/tutorials.md
@@ -95,6 +95,7 @@ A complete notebook demonstrating how to train LLMs on full patient histories wi
- **Inference**: [`examples/advanced/custom_splitting/inference_individual_splitters.py`](examples/advanced/custom_splitting/inference_individual_splitters.py) — Example script for inference using individual splitters.
- **Training**: [`examples/advanced/custom_splitting/training_individual_splitters.ipynb`](examples/advanced/custom_splitting/training_individual_splitters.ipynb) — Notebook demonstrating training data generation with individual splitters.
- **Custom Split Events**: [`examples/advanced/custom_splitting/training_custom_split_events.ipynb`](examples/advanced/custom_splitting/training_custom_split_events.ipynb) — Notebook showing how to customize split events and forecast different event categories (e.g., using genetic events as split points and forecasting vitals).
+- **Forecasting QA**: [`examples/advanced/custom_splitting/training_forecasting_qa.ipynb`](examples/advanced/custom_splitting/training_forecasting_qa.ipynb) — Demonstrates the **Forecasting QA** mode, which bins continuous target values into discrete categories for classification-style prediction. Compares all three forecasting modes (`"forecasting"`, `"forecasting_qa"`, `"both"`).
### TTE Probability Inference
diff --git a/examples/01_data_preparation_for_training.ipynb b/examples/01_data_preparation_for_training.ipynb
index 60fe0bd..d42c8f7 100644
--- a/examples/01_data_preparation_for_training.ipynb
+++ b/examples/01_data_preparation_for_training.ipynb
@@ -95,7 +95,7 @@
"\n",
"- **`config.split_event_category`**: Determines the event category used to split patient histories into input and output segments. In this example, we split data around \"Line of Therapy\" (`lot`) start dates.\n",
"- **`config.event_category_forecast`**: Identifies which event categories should be forecasted as time-series values. Here, we target `lab` values.\n",
- "- **`config.data_splitter_events_variables_category_mapping`**: Maps specific events to survival analysis or classification tasks. We configure the model to predict outcomes for `death` and `progression`."
+ "- **`config.event_category_events_prediction_with_naming`**: Maps specific events to survival analysis or classification tasks. We configure the model to predict outcomes for `death` and `progression`."
]
},
{
@@ -118,7 +118,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}"
@@ -175,7 +175,7 @@
"# Setup unique mapping of events, to understand which events correspond to which categories\n",
"dm.setup_unique_mapping_of_events()\n",
"# (Optional) assign each patient to train/validation/test splits\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"# (Optional - needed for forecasting) infer variable types\n",
"dm.infer_var_types()"
]
@@ -202,13 +202,19 @@
"outputs": [],
"source": [
"# This data splitter handles event prediction tasks\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"# This data splitter handles forecasting tasks\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n",
"data_splitter_forecasting.setup_statistics()\n",
@@ -339,7 +345,6 @@
"p_converted = converter.forward_conversion(\n",
" forecasting_splits=forecasting_splits[split_idx],\n",
" event_splits=events_splits[split_idx],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")"
]
},
@@ -418,7 +423,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_dev",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/examples/02_inference_prompt_preparation.ipynb b/examples/02_inference_prompt_preparation.ipynb
index 299e9cd..9075f16 100644
--- a/examples/02_inference_prompt_preparation.ipynb
+++ b/examples/02_inference_prompt_preparation.ipynb
@@ -67,7 +67,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}\n",
@@ -88,15 +88,21 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n",
"data_splitter_forecasting.setup_statistics()\n",
diff --git a/examples/03_end_to_end_llm_finetuning.ipynb b/examples/03_end_to_end_llm_finetuning.ipynb
index 6a2901f..3cfda1d 100644
--- a/examples/03_end_to_end_llm_finetuning.ipynb
+++ b/examples/03_end_to_end_llm_finetuning.ipynb
@@ -109,7 +109,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}\n",
@@ -144,15 +144,21 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"\n",
"# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n",
@@ -213,7 +219,6 @@
" p_converted = converter.forward_conversion(\n",
" forecasting_splits=forecasting_splits[split_idx],\n",
" event_splits=events_splits[split_idx],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
" )\n",
" new_data = {\n",
" \"prompt\": p_converted[\"instruction\"],\n",
@@ -490,9 +495,7 @@
"# Lets simulate forecasts for after the first line of therapy\n",
"df_constant_patient = patient_data[\"constant\"].copy()\n",
"df_events_patient = patient_data[\"events\"].copy()\n",
- "date_of_first_lot = df_events_patient.loc[\n",
- " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n",
- "].min()\n",
+ "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n",
"\n",
"# Only keep data until (and including) first line of therapy\n",
"df_events_patient = df_events_patient.loc[df_events_patient[\"date\"] <= date_of_first_lot]"
diff --git a/examples/advanced/custom_output/custom_summarized_row.ipynb b/examples/advanced/custom_output/custom_summarized_row.ipynb
index d17f3db..58bf69f 100644
--- a/examples/advanced/custom_output/custom_summarized_row.ipynb
+++ b/examples/advanced/custom_output/custom_summarized_row.ipynb
@@ -95,7 +95,7 @@
"# Required settings\n",
"config.split_event_category = \"lot\"\n",
"config.event_category_forecast = [\"lab\"]\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\",\n",
"}\n",
@@ -122,7 +122,7 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()"
]
},
@@ -134,10 +134,19 @@
"outputs": [],
"source": [
"# Setup splitters\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
- "data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)\n",
+ "data_splitter_forecasting = DataSplitterForecasting(\n",
+ " data_manager=dm,\n",
+ " config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
"data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)"
@@ -200,7 +209,6 @@
"p_default = converter_default.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
@@ -305,7 +313,6 @@
"p_custom = converter_custom.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
@@ -408,7 +415,6 @@
"p_advanced = converter_advanced.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
@@ -481,7 +487,6 @@
" converter_broken.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
" )\n",
"except TypeError as e:\n",
" print(f\"Caught runtime error: {e}\")"
@@ -515,7 +520,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_dev",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/examples/advanced/custom_output/customizing_text_generation.ipynb b/examples/advanced/custom_output/customizing_text_generation.ipynb
index 6f15659..03ea854 100644
--- a/examples/advanced/custom_output/customizing_text_generation.ipynb
+++ b/examples/advanced/custom_output/customizing_text_generation.ipynb
@@ -87,7 +87,7 @@
"# Required settings for instruction mode\n",
"config_default.split_event_category = \"lot\"\n",
"config_default.event_category_forecast = [\"lab\"]\n",
- "config_default.data_splitter_events_variables_category_mapping = {\n",
+ "config_default.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\",\n",
"}\n",
@@ -109,7 +109,7 @@
")\n",
"dm_default.process_indication_data()\n",
"dm_default.setup_unique_mapping_of_events()\n",
- "dm_default.setup_dataset_splits()\n",
+ "dm_default.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm_default.infer_var_types()"
]
},
@@ -121,10 +121,19 @@
"outputs": [],
"source": [
"# Setup splitters and converter\n",
- "data_splitter_events_default = DataSplitterEvents(dm_default, config=config_default)\n",
+ "data_splitter_events_default = DataSplitterEvents(\n",
+ " dm_default,\n",
+ " config=config_default,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events_default.setup_variables()\n",
"\n",
- "data_splitter_forecasting_default = DataSplitterForecasting(data_manager=dm_default, config=config_default)\n",
+ "data_splitter_forecasting_default = DataSplitterForecasting(\n",
+ " data_manager=dm_default,\n",
+ " config=config_default,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
"data_splitter_forecasting_default.setup_statistics()\n",
"\n",
"data_splitter_default = DataSplitter(data_splitter_events_default, data_splitter_forecasting_default)\n",
@@ -155,7 +164,6 @@
"p_converted_default = converter_default.forward_conversion(\n",
" forecasting_splits=forecasting_splits[0],\n",
" event_splits=events_splits[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
@@ -193,7 +201,7 @@
"# Required settings\n",
"config_custom.split_event_category = \"lot\"\n",
"config_custom.event_category_forecast = [\"lab\"]\n",
- "config_custom.data_splitter_events_variables_category_mapping = {\n",
+ "config_custom.event_category_events_prediction_with_naming = {\n",
" \"death\": \"mortality\", # Custom name for death event\n",
" \"progression\": \"disease progression\", # Custom name for progression\n",
"}\n",
@@ -611,7 +619,7 @@
")\n",
"dm_custom.process_indication_data()\n",
"dm_custom.setup_unique_mapping_of_events()\n",
- "dm_custom.setup_dataset_splits()\n",
+ "dm_custom.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm_custom.infer_var_types()"
]
},
@@ -623,10 +631,19 @@
"outputs": [],
"source": [
"# Setup splitters and converter with custom config\n",
- "data_splitter_events_custom = DataSplitterEvents(dm_custom, config=config_custom)\n",
+ "data_splitter_events_custom = DataSplitterEvents(\n",
+ " dm_custom,\n",
+ " config=config_custom,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events_custom.setup_variables()\n",
"\n",
- "data_splitter_forecasting_custom = DataSplitterForecasting(data_manager=dm_custom, config=config_custom)\n",
+ "data_splitter_forecasting_custom = DataSplitterForecasting(\n",
+ " data_manager=dm_custom,\n",
+ " config=config_custom,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
"data_splitter_forecasting_custom.setup_statistics()\n",
"\n",
"data_splitter_custom = DataSplitter(data_splitter_events_custom, data_splitter_forecasting_custom)\n",
@@ -657,7 +674,6 @@
"p_converted_custom = converter_custom.forward_conversion(\n",
" forecasting_splits=forecasting_splits_custom[0],\n",
" event_splits=events_splits_custom[0],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")\n",
"\n",
"print(\"=\" * 80)\n",
diff --git a/examples/advanced/custom_splitting/inference_individual_splitters.py b/examples/advanced/custom_splitting/inference_individual_splitters.py
index abf7273..3840087 100644
--- a/examples/advanced/custom_splitting/inference_individual_splitters.py
+++ b/examples/advanced/custom_splitting/inference_individual_splitters.py
@@ -2,6 +2,7 @@
DataSplitterForecasting,
DataManager,
DataSplitterEvents,
+ DataSplitter,
ConverterInstruction,
Config,
)
@@ -16,7 +17,7 @@ def __init__(
self.config = Config()
self.config.split_event_category = "lot"
self.config.event_category_forecast = ["lab"]
- self.config.data_splitter_events_variables_category_mapping = {
+ self.config.event_category_events_prediction_with_naming = {
"death": "death",
"progression": "next progression", # Custom name in prompt: "next progression" instead of "progression"
}
@@ -42,13 +43,29 @@ def __init__(
)
self.dm.process_indication_data()
self.dm.setup_unique_mapping_of_events()
- self.dm.setup_dataset_splits()
+ self.dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
self.dm.infer_var_types()
- self.data_splitter_events = DataSplitterEvents(self.dm, config=self.config)
- self.data_splitter_events.setup_variables()
- self.data_splitter_forecasting = DataSplitterForecasting(data_manager=self.dm, config=self.config)
- self.data_splitter_forecasting.setup_statistics()
+ data_splitter_events = DataSplitterEvents(
+ self.dm,
+ config=self.config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ )
+ data_splitter_events.setup_variables()
+ data_splitter_forecasting = DataSplitterForecasting(
+ data_manager=self.dm,
+ config=self.config,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ )
+ data_splitter_forecasting.setup_statistics()
+
+ # Use the unified DataSplitter API that combines both splitters
+ self.data_splitter = DataSplitter(
+ data_splitter_events=data_splitter_events,
+ data_splitter_forecasting=data_splitter_forecasting,
+ )
+
self.converter = ConverterInstruction(
nr_tokens_budget_total=8192,
config=self.config,
@@ -62,48 +79,26 @@ def convert_full_to_string_for_one_patient(self, patientid, override_events_or_f
# To simulate that we only have input, half the events
patient_data["events"] = patient_data["events"].iloc[: int(len(patient_data["events"]) / 2)]
- # Here then split date
- split_date = patient_data["events"]["date"].iloc[-1]
-
- #: generate event split - NOTE: this if statement is only to exemplify both cases!
- if override_events_or_forecasting == "events":
- ####### Example if we want to override for events
-
- events_splits = self.data_splitter_events.get_splits_from_patient(
- patient_data,
- max_nr_samples=1,
- override_split_dates=[split_date],
- override_category="death",
- override_end_week_delta=52,
- )
- # We just pick the first one
- events_split = events_splits[0][0]
-
- #: no forecasting split
- forecast_split = None
- forecasting_times_to_predict = None
- else:
- ####### Example if we want to override for forecasting
-
- #: generate forecasting split
- forecast_splits = self.data_splitter_forecasting.get_splits_from_patient(
- patient_data,
- nr_samples_per_split=1,
- filter_outliers=False,
- override_split_dates=[split_date],
- override_variables_to_predict=["Neutrophils"],
- )
- # We just pick the first one
- forecast_split = forecast_splits[0][0]
-
- # We set which weeks to predict
+ # Use the unified DataSplitter API for inference
+ forecast_split, events_split = self.data_splitter.get_splits_from_patient_inference(
+ patient_data,
+ inference_type=override_events_or_forecasting,
+ forecasting_override_variables_to_predict=["Neutrophils"]
+ if override_events_or_forecasting != "events"
+ else None,
+ events_override_category="death" if override_events_or_forecasting != "forecasting" else None,
+ events_override_observation_time_delta=pd.Timedelta(weeks=52)
+ if override_events_or_forecasting != "forecasting"
+ else None,
+ )
+
+ # Set which weeks to predict for forecasting (if applicable)
+ forecasting_times_to_predict = None
+ if forecast_split is not None:
forecasting_times_to_predict = {
"Neutrophils": [1, 2, 8, 11],
}
- #: no events split
- events_split = None
-
# Convert to text
converted = self.converter.forward_conversion_inference(
forecasting_split=forecast_split,
diff --git a/examples/advanced/custom_splitting/training_custom_split_events.ipynb b/examples/advanced/custom_splitting/training_custom_split_events.ipynb
index acd176a..68976a7 100644
--- a/examples/advanced/custom_splitting/training_custom_split_events.ipynb
+++ b/examples/advanced/custom_splitting/training_custom_split_events.ipynb
@@ -101,7 +101,7 @@
"config.event_category_forecast = [\"vitals\"]\n",
"\n",
"# To predict different variables for the event categories, we set up a mapping here\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"lot\": \"time to next lot\", # Custom name in prompt: \"time to next lot\" instead of \"lot\"\n",
"}"
]
@@ -118,7 +118,7 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()"
]
},
@@ -138,13 +138,19 @@
"outputs": [],
"source": [
"# This data splitter handles event prediction tasks\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"# This data splitter handles forecasting tasks\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step\n",
"data_splitter_forecasting.setup_statistics()\n",
@@ -213,7 +219,8 @@
"source": [
"forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n",
" patient_data,\n",
- ")"
+ ")\n",
+ "# Note: forecasting_splits will be None here"
]
},
{
@@ -233,9 +240,8 @@
"source": [
"split_idx = 0\n",
"p_converted = converter.forward_conversion(\n",
- " forecasting_splits=forecasting_splits[split_idx],\n",
+ " forecasting_splits=None, # Set to None since we don't want to generate forecasting tasks\n",
" event_splits=events_splits[split_idx],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
")"
]
},
@@ -300,7 +306,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_dev",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/examples/advanced/custom_splitting/training_forecasting_qa.ipynb b/examples/advanced/custom_splitting/training_forecasting_qa.ipynb
new file mode 100644
index 0000000..4d0fd0c
--- /dev/null
+++ b/examples/advanced/custom_splitting/training_forecasting_qa.ipynb
@@ -0,0 +1,425 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0",
+ "metadata": {},
+ "source": [
+ "# Forecasting QA: Question-Answering Style Value Prediction\n",
+ "\n",
+ "This notebook demonstrates the **Forecasting QA** task mode in TwinWeaver.\n",
+ "\n",
+ "While the default `forward_conversion` mode (`\"forecasting\"`) asks the model to predict exact future values,\n",
+ "the **Forecasting QA** mode bins continuous target values into discrete categories (e.g., `A`, `B`, `C`)\n",
+ "and asks the model to predict the correct bin — turning regression into classification.\n",
+ "\n",
+ "This can be useful when:\n",
+ "- Exact numeric predictions are noisy or unreliable\n",
+ "- You want the model to reason about value ranges instead of point estimates\n",
+ "- You want to combine both forecasting and QA tasks (mode `\"both\"`)\n",
+ "\n",
+ "### How it works\n",
+ "\n",
+ "The `ConverterForecastingQA` sub-converter:\n",
+ "1. Computes bin edges from the variable statistics (`setup_statistics()` on the forecasting splitter)\n",
+ "2. Maps each target value to a lettered bin (A, B, C, …)\n",
+ "3. Generates a prompt listing the bin definitions and asks the model to choose the right bin\n",
+ "4. The target answer uses the bin letter instead of the raw numeric value\n",
+ "\n",
+ "### Selecting the mode\n",
+ "\n",
+ "The `forward_conversion` method accepts an `override_mode_to_select_forecasting` parameter:\n",
+ "- `\"forecasting\"` (default): numeric value prediction\n",
+ "- `\"forecasting_qa\"`: bin-based QA prediction\n",
+ "- `\"both\"`: includes both a forecasting and a forecasting QA task\n",
+ "- `None`: randomly selects one of the above at each call"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1",
+ "metadata": {},
+ "source": [
+ "## Setup\n",
+ "\n",
+ "Load libraries and example data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "\n",
+ "from twinweaver import (\n",
+ " DataSplitterForecasting,\n",
+ " DataSplitterEvents,\n",
+ " DataSplitter,\n",
+ " DataManager,\n",
+ " ConverterInstruction,\n",
+ " Config,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_events = pd.read_csv(\"../../example_data/events.csv\")\n",
+ "df_constant = pd.read_csv(\"../../example_data/constant.csv\")\n",
+ "df_constant_description = pd.read_csv(\"../../example_data/constant_description.csv\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4",
+ "metadata": {},
+ "source": [
+ "## Configuration\n",
+ "\n",
+ "Set up the config, data manager, splitters, and converter.\n",
+ "\n",
+ "> **Important:** `variable_stats` from `DataSplitterForecasting.setup_statistics()` must be\n",
+ "> passed to `ConverterInstruction` for the QA mode to work — it provides the bin edges."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = Config()\n",
+ "\n",
+ "config.split_event_category = \"lot\"\n",
+ "config.event_category_forecast = [\"lab\"]\n",
+ "config.event_category_events_prediction_with_naming = {\n",
+ " \"death\": \"death\",\n",
+ " \"progression\": \"next progression\",\n",
+ "}\n",
+ "config.constant_columns_to_use = [\"birthyear\", \"gender\", \"histology\", \"smoking_history\"]\n",
+ "config.constant_birthdate_column = \"birthyear\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dm = DataManager(config=config)\n",
+ "dm.load_indication_data(\n",
+ " df_events=df_events,\n",
+ " df_constant=df_constant,\n",
+ " df_constant_description=df_constant_description,\n",
+ ")\n",
+ "dm.process_indication_data()\n",
+ "dm.setup_unique_mapping_of_events()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
+ "dm.infer_var_types()\n",
+ "\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
+ "data_splitter_events.setup_variables()\n",
+ "\n",
+ "data_splitter_forecasting = DataSplitterForecasting(\n",
+ " data_manager=dm,\n",
+ " config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
+ "# setup_statistics() computes per-variable quantile bin edges used by the QA mode\n",
+ "data_splitter_forecasting.setup_statistics()\n",
+ "\n",
+ "data_splitter = DataSplitter(\n",
+ " data_splitter_events=data_splitter_events,\n",
+ " data_splitter_forecasting=data_splitter_forecasting,\n",
+ ")\n",
+ "\n",
+ "converter = ConverterInstruction(\n",
+ " nr_tokens_budget_total=8192,\n",
+ " config=config,\n",
+ " dm=dm,\n",
+ " variable_stats=data_splitter_forecasting.variable_stats, # Required for QA mode\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7",
+ "metadata": {},
+ "source": [
+ "## Generate Splits\n",
+ "\n",
+ "Get training splits for a patient."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.random.seed(42)\n",
+ "\n",
+ "patientid = dm.all_patientids[2]\n",
+ "patient_data = dm.get_patient_data(patientid)\n",
+ "\n",
+ "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n",
+ " patient_data,\n",
+ " forecasting_nr_samples_per_split=4,\n",
+ " forecasting_filter_outliers=False,\n",
+ " max_num_splits_per_split_event=2,\n",
+ " events_max_nr_samples_per_split=3,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9",
+ "metadata": {},
+ "source": [
+ "## Mode 1: Default Forecasting (numeric predictions)\n",
+ "\n",
+ "By default, `forward_conversion` uses `override_mode_to_select_forecasting=\"forecasting\"`,\n",
+ "which generates tasks that ask the model to predict exact numeric values."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "10",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.random.seed(42)\n",
+ "split_idx = 0\n",
+ "\n",
+ "p_forecasting = converter.forward_conversion(\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=events_splits[split_idx],\n",
+ " # override_mode_to_select_forecasting=\"forecasting\" # this is the default\n",
+ ")\n",
+ "\n",
+ "print(\"=\" * 80)\n",
+ "print(\"INSTRUCTION (last 1500 chars):\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_forecasting[\"instruction\"][-1500:])\n",
+ "print()\n",
+ "print(\"=\" * 80)\n",
+ "print(\"ANSWER:\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_forecasting[\"answer\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "11",
+ "metadata": {},
+ "source": [
+ "## Mode 2: Forecasting QA (bin-based predictions)\n",
+ "\n",
+ "Setting `override_mode_to_select_forecasting=\"forecasting_qa\"` activates the QA mode.\n",
+ "The prompt now lists bin definitions (e.g., `a = Bin (-inf, 5.2]`) and asks the model\n",
+ "to output the bin letter instead of a raw number."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "12",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.random.seed(42)\n",
+ "split_idx = 0\n",
+ "\n",
+ "p_qa = converter.forward_conversion(\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=events_splits[split_idx],\n",
+ " override_mode_to_select_forecasting=\"forecasting_qa\",\n",
+ ")\n",
+ "\n",
+ "print(\"=\" * 80)\n",
+ "print(\"INSTRUCTION (last 2000 chars):\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_qa[\"instruction\"][-2000:])\n",
+ "print()\n",
+ "print(\"=\" * 80)\n",
+ "print(\"ANSWER:\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_qa[\"answer\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "13",
+ "metadata": {},
+ "source": [
+ "### Inspect the bin definitions\n",
+ "\n",
+ "The metadata returned by the conversion includes the bin mapping for each variable.\n",
+ "These are derived from quantiles computed by `setup_statistics()`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "14",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The target meta contains the category splits for each variable\n",
+ "for task_meta in p_qa[\"meta\"][\"target_meta_detailed\"]:\n",
+ " if \"category_splits\" in task_meta:\n",
+ " print(\"Variable bin definitions:\")\n",
+ " for var, bins in task_meta[\"category_splits\"].items():\n",
+ " print(f\" {var}:\")\n",
+ " for letter, bin_range in bins.items():\n",
+ " print(f\" {letter} = {bin_range}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15",
+ "metadata": {},
+ "source": [
+ "## Mode 3: Both (forecasting + QA in one prompt)\n",
+ "\n",
+ "Setting `override_mode_to_select_forecasting=\"both\"` includes **both** a numeric forecasting\n",
+ "task and a QA-style bin prediction task in the same multi-task prompt. This encourages the\n",
+ "model to learn complementary representations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "16",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.random.seed(42)\n",
+ "split_idx = 0\n",
+ "\n",
+ "p_both = converter.forward_conversion(\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=events_splits[split_idx],\n",
+ " override_mode_to_select_forecasting=\"both\",\n",
+ ")\n",
+ "\n",
+ "print(\"=\" * 80)\n",
+ "print(\"INSTRUCTION (last 3000 chars):\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_both[\"instruction\"][-3000:])\n",
+ "print()\n",
+ "print(\"=\" * 80)\n",
+ "print(\"ANSWER:\")\n",
+ "print(\"=\" * 80)\n",
+ "print(p_both[\"answer\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "17",
+ "metadata": {},
+ "source": [
+ "### Compare the task types generated\n",
+ "\n",
+ "Let's inspect which task types were included in each mode."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "18",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for label, result in [(\"forecasting\", p_forecasting), (\"forecasting_qa\", p_qa), (\"both\", p_both)]:\n",
+ " task_types = [m[\"task_type\"] for m in result[\"meta\"][\"target_meta_detailed\"]]\n",
+ " print(f\"Mode '{label}': {task_types}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "19",
+ "metadata": {},
+ "source": [
+ "## Reverse Conversion\n",
+ "\n",
+ "The reverse conversion works for all modes. It parses the model output back into\n",
+ "structured data, regardless of whether the target was numeric or bin-based."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "20",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "date = reference_dates[\"date\"][0]\n",
+ "\n",
+ "# Reverse convert the QA mode output\n",
+ "return_list = converter.reverse_conversion(p_qa[\"answer\"], dm, date)\n",
+ "\n",
+ "return_list[2][\"result\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "21",
+ "metadata": {},
+ "source": [
+ "## Variable Statistics\n",
+ "\n",
+ "The bin edges come from the `variable_stats` DataFrame computed during `setup_statistics()`.\n",
+ "You can inspect them directly:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "22",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data_splitter_forecasting.variable_stats"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv_dev",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb
index 373ea42..409abe2 100644
--- a/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb
+++ b/examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb
@@ -5,7 +5,7 @@
"id": "0",
"metadata": {},
"source": [
- "# Forecasting-Only Example: Training Data Generation with Custom Dataset"
+ "# Forecasting-Only Example: Training Data Generation with the Unified DataSplitter API"
]
},
{
@@ -27,6 +27,7 @@
"\n",
"from twinweaver import (\n",
" DataSplitterForecasting,\n",
+ " DataSplitter,\n",
" DataManager,\n",
" ConverterInstruction,\n",
" Config,\n",
@@ -66,7 +67,7 @@
"id": "6",
"metadata": {},
"source": [
- "Set up the data manager and the forecasting-only pipeline."
+ "Set up the data manager and the forecasting-only pipeline using the unified `DataSplitter` API. By passing only `data_splitter_forecasting`, the unified interface handles the forecasting-only case automatically."
]
},
{
@@ -105,17 +106,21 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
+ "# Use the unified DataSplitter API with only the forecasting splitter\n",
+ "data_splitter = DataSplitter(data_splitter_forecasting=data_splitter_forecasting)\n",
+ "\n",
"converter = ConverterInstruction(\n",
" nr_tokens_budget_total=8192,\n",
" config=config,\n",
@@ -198,9 +203,9 @@
"id": "16",
"metadata": {},
"source": [
- "We start by generating random \"splits\" in the patient trajectory. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting).\n",
+ "We start by generating random \"splits\" in the patient trajectory using the unified `DataSplitter.get_splits_from_patient_with_target` method. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting).\n",
"\n",
- "Here we generate these random splits. We can also manually override them (see other examples on inference)."
+ "Since we only provided a forecasting splitter, `events_splits` will be `None`."
]
},
{
@@ -210,13 +215,13 @@
"metadata": {},
"outputs": [],
"source": [
- "processed_splits_fc, split_dates = data_splitter_forecasting.get_splits_from_patient(\n",
+ "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n",
" patient_data,\n",
- " nr_samples_per_split=4,\n",
- " filter_outliers=False,\n",
- " include_metadata=True,\n",
+ " forecasting_nr_samples_per_split=4,\n",
+ " forecasting_filter_outliers=False,\n",
" max_num_splits_per_split_event=2,\n",
- ")"
+ ")\n",
+ "# Note: events_splits will be None here since only a forecasting splitter was provided"
]
},
{
@@ -224,7 +229,7 @@
"id": "18",
"metadata": {},
"source": [
- "Now for each split, we can generate the formatted strings. Note that `event_splits` is left empty since this example only uses the forecasting splitter."
+ "Now for each split, we can generate the formatted strings. Note that `events_splits` is `None` since we only provided a forecasting splitter, so we pass `None` for `event_splits`."
]
},
{
@@ -236,9 +241,8 @@
"source": [
"split_idx = 0\n",
"p_converted = converter.forward_conversion(\n",
- " forecasting_splits=processed_splits_fc[split_idx],\n",
- " event_splits=[], # Not needed for forecasting-only splitter\n",
- " override_mode_to_select_forecasting=\"forecasting\",\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=None, # Not needed for forecasting-only splitter\n",
")"
]
},
diff --git a/examples/advanced/custom_splitting/training_individual_splitters.ipynb b/examples/advanced/custom_splitting/training_individual_splitters.ipynb
index 446142c..54fc0ad 100644
--- a/examples/advanced/custom_splitting/training_individual_splitters.ipynb
+++ b/examples/advanced/custom_splitting/training_individual_splitters.ipynb
@@ -5,7 +5,7 @@
"id": "0",
"metadata": {},
"source": [
- "# Example for single patient to convert using the instruction setup with custom dataset"
+ "# Example for single patient to convert using the unified DataSplitter API with custom dataset"
]
},
{
@@ -27,8 +27,9 @@
"\n",
"from twinweaver import (\n",
" DataSplitterForecasting,\n",
- " DataManager,\n",
" DataSplitterEvents,\n",
+ " DataSplitter,\n",
+ " DataManager,\n",
" ConverterInstruction,\n",
" Config,\n",
")"
@@ -67,7 +68,7 @@
"id": "6",
"metadata": {},
"source": [
- "Set up the data managers which hold the patient data."
+ "Set up the data managers and the unified `DataSplitter` which combines both event and forecasting splitters."
]
},
{
@@ -90,7 +91,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}\n",
@@ -111,19 +112,31 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"# In case you manually want to override the variables for forecasting selectiong, you can skip this next line.\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
+ "# Use the unified DataSplitter API that combines both splitters\n",
+ "data_splitter = DataSplitter(\n",
+ " data_splitter_events=data_splitter_events,\n",
+ " data_splitter_forecasting=data_splitter_forecasting,\n",
+ ")\n",
+ "\n",
"converter = ConverterInstruction(\n",
" nr_tokens_budget_total=8192,\n",
" config=config,\n",
@@ -206,9 +219,9 @@
"id": "16",
"metadata": {},
"source": [
- "We start by generating random \"splits\" in the patient trajectory. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting, death/progression/metastases/next treatment for event).\n",
+ "We start by generating random \"splits\" in the patient trajectory using the unified `DataSplitter.get_splits_from_patient_with_target` method. This ensures that both forecasting and event splits use the same anchor points in time. We can make multiple relevant samples from each patient trajectory (e.g. depending on when the therapy started), and also to predict different variables (e.g. neutrophils/hemoglobin/... for forecasting, death/progression/metastases/next treatment for event).\n",
"\n",
- "Here we generate these random splits. We can also manually override them (see other examples on inference)."
+ "We can also manually override them (see other examples on inference)."
]
},
{
@@ -218,18 +231,12 @@
"metadata": {},
"outputs": [],
"source": [
- "processed_splits_fc, split_dates = data_splitter_forecasting.get_splits_from_patient(\n",
+ "forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(\n",
" patient_data,\n",
- " nr_samples_per_split=4,\n",
- " filter_outliers=False,\n",
- " include_metadata=True,\n",
+ " forecasting_nr_samples_per_split=4,\n",
+ " forecasting_filter_outliers=False,\n",
" max_num_splits_per_split_event=2,\n",
- ")\n",
- "\n",
- "processed_splits_ev = data_splitter_events.get_splits_from_patient(\n",
- " patient_data,\n",
- " reference_split_dates=split_dates,\n",
- " max_nr_samples_per_split=3,\n",
+ " events_max_nr_samples_per_split=3,\n",
")"
]
},
@@ -250,9 +257,8 @@
"source": [
"split_idx = 0\n",
"p_converted = converter.forward_conversion(\n",
- " forecasting_splits=processed_splits_fc[split_idx],\n",
- " event_splits=processed_splits_ev[split_idx],\n",
- " override_mode_to_select_forecasting=\"forecasting_qa\",\n",
+ " forecasting_splits=forecasting_splits[split_idx],\n",
+ " event_splits=events_splits[split_idx],\n",
")"
]
},
@@ -296,15 +302,15 @@
"metadata": {},
"outputs": [],
"source": [
- "date = split_dates[\"date\"][0]\n",
+ "date = reference_dates[\"date\"][0]\n",
"return_list = converter.reverse_conversion(p_converted[\"answer\"], dm, date)\n",
- "return_list[2][\"result\"]"
+ "return_list[0][\"result\"]"
]
}
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_dev",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb b/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb
new file mode 100644
index 0000000..f57bd1c
--- /dev/null
+++ b/examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb
@@ -0,0 +1,698 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0",
+ "metadata": {},
+ "source": [
+ "# Forecasting Inference with TwinWeaver and vLLM\n",
+ "\n",
+ "This notebook demonstrates how to use TwinWeaver's **forecasting inference pipeline**\n",
+ "to predict future clinical values (e.g., lab results like hemoglobin) using a\n",
+ "fine-tuned LLM served via [vLLM](https://github.com/vllm-project/vllm).\n",
+ "\n",
+ "Unlike the TTE (time-to-event) probability pipeline, which *scores* fixed\n",
+ "completions, this pipeline uses **free-text generation**: the model produces\n",
+ "an answer string that is then **reverse-converted** back into a structured\n",
+ "DataFrame with predicted dates and values.\n",
+ "\n",
+ "### Pipeline overview\n",
+ "\n",
+ "```\n",
+ "Patient data ──► DataSplitter (forecasting) ──► ConverterInstruction\n",
+ " │\n",
+ " instruction text per patient\n",
+ " │\n",
+ " ┌──────────────────────────────┘\n",
+ " ▼\n",
+ " vLLM server (OpenAI-compatible API)\n",
+ " │\n",
+ " generated text completions\n",
+ " │\n",
+ " ▼\n",
+ " parse_forecasting_results()\n",
+ " (calls converter.reverse_conversion internally)\n",
+ " │\n",
+ " structured DataFrame\n",
+ " with predicted dates & values\n",
+ "```\n",
+ "\n",
+ "> **⚠️ Important:** The quality of the forecasts depends critically\n",
+ "> on having a **fine-tuned model**. An off-the-shelf instruction model\n",
+ "> (like the default `microsoft/Phi-4-mini-instruct` used here for demonstration)\n",
+ "> will produce **meaningless predictions**. Always fine-tune on your\n",
+ "> clinical dataset first (see `03_end_to_end_llm_finetuning.ipynb`).\n",
+ "\n",
+ "> **Requirements:**\n",
+ "> - A GPU with enough memory to serve the model via vLLM (≥16 GB for a 4-bit 8B model)\n",
+ "> - `pip install twinweaver[fine-tuning-example] vllm openai`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import subprocess\n",
+ "import time\n",
+ "import sys\n",
+ "import os\n",
+ "\n",
+ "import pandas as pd\n",
+ "\n",
+ "from twinweaver import (\n",
+ " DataManager,\n",
+ " Config,\n",
+ " DataSplitterForecasting,\n",
+ " DataSplitterEvents,\n",
+ " DataSplitter,\n",
+ " ConverterInstruction,\n",
+ " run_forecasting_inference_notebook,\n",
+ " parse_forecasting_results,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2",
+ "metadata": {},
+ "source": [
+ "## 1. Configuration\n",
+ "\n",
+ "We define all key settings up front so they are easy to change in one place.\n",
+ "\n",
+ "> **Note:** Replace `MODEL_PATH` with the path to your **fine-tuned** model for\n",
+ "> meaningful results. The default `microsoft/Phi-4-mini-instruct` is only used\n",
+ "> here so that the notebook is self-contained."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ---------------------------------------------------------------------------\n",
+ "# Model & server settings\n",
+ "# ---------------------------------------------------------------------------\n",
+ "MODEL_PATH = \"microsoft/Phi-4-mini-instruct\" # ⚠️ Replace with your fine-tuned model path!\n",
+ "TOKENIZER_PATH = MODEL_PATH # Usually the same as the model path\n",
+ "\n",
+ "VLLM_PORT = 8000\n",
+ "PREDICTION_URL = f\"http://0.0.0.0:{VLLM_PORT}/v1/\"\n",
+ "\n",
+ "MAX_CONTEXT_LENGTH = 4096 # Must match what the model was trained with\n",
+ "\n",
+ "# Generation settings\n",
+ "MAX_NEW_TOKENS = 256 # Max tokens for the generated forecast answer\n",
+ "TEMPERATURE = 0.9 # Sampling temperature (0 = greedy)\n",
+ "TOP_P = 0.9 # Nucleus sampling\n",
+ "N_SAMPLES = 3 # Number of independent samples per patient (>1 enables aggregation)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ---------------------------------------------------------------------------\n",
+ "# TwinWeaver data settings (same as used during training)\n",
+ "# ---------------------------------------------------------------------------\n",
+ "config = Config()\n",
+ "\n",
+ "# 1. Event category used for data splitting\n",
+ "config.split_event_category = \"lot\"\n",
+ "\n",
+ "# 2. Forecasting categories\n",
+ "config.event_category_forecast = [\"lab\"]\n",
+ "\n",
+ "# 3. Time-to-event variables (needed to initialise splitters, even if we only forecast)\n",
+ "config.event_category_events_prediction_with_naming = {\n",
+ " \"death\": \"death\",\n",
+ " \"progression\": \"next progression\",\n",
+ "}\n",
+ "\n",
+ "# 4. Constant (static) columns\n",
+ "config.constant_columns_to_use = [\"birthyear\", \"gender\", \"histology\", \"smoking_history\"]\n",
+ "config.constant_birthdate_column = \"birthyear\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5",
+ "metadata": {},
+ "source": [
+ "## 2. Load and prepare data\n",
+ "\n",
+ "We use the same example data shipped with TwinWeaver. In a real scenario you\n",
+ "would load your own clinical dataset here."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load example data (adjust paths if running from a different directory)\n",
+ "df_events = pd.read_csv(\"../../example_data/events.csv\")\n",
+ "df_constant = pd.read_csv(\"../../example_data/constant.csv\")\n",
+ "df_constant_description = pd.read_csv(\"../../example_data/constant_description.csv\")\n",
+ "\n",
+ "print(f\"Loaded {len(df_events)} events for {df_events['patientid'].nunique()} patients\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialise DataManager and all splitters\n",
+ "dm = DataManager(config=config)\n",
+ "dm.load_indication_data(\n",
+ " df_events=df_events,\n",
+ " df_constant=df_constant,\n",
+ " df_constant_description=df_constant_description,\n",
+ ")\n",
+ "dm.process_indication_data()\n",
+ "dm.setup_unique_mapping_of_events()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
+ "dm.infer_var_types()\n",
+ "\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
+ "data_splitter_events.setup_variables()\n",
+ "\n",
+ "data_splitter_forecasting = DataSplitterForecasting(\n",
+ " data_manager=dm,\n",
+ " config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
+ "data_splitter_forecasting.setup_statistics()\n",
+ "\n",
+ "# Combined interface\n",
+ "data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)\n",
+ "\n",
+ "converter = ConverterInstruction(\n",
+ " nr_tokens_budget_total=MAX_CONTEXT_LENGTH,\n",
+ " config=config,\n",
+ " dm=dm,\n",
+ " variable_stats=data_splitter_forecasting.variable_stats,\n",
+ ")\n",
+ "\n",
+ "print(\"✅ Data pipeline ready\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8",
+ "metadata": {},
+ "source": [
+ "## 3. Generate forecasting instruction prompts\n",
+ "\n",
+ "For each test patient we generate a **forecasting-only** instruction prompt.\n",
+ "The key parameters are:\n",
+ "\n",
+ "- `forecasting_override_variables_to_predict` – which variables to forecast\n",
+ " (e.g. hemoglobin)\n",
+ "- `forecasting_future_weeks_per_variable` – at which future time points\n",
+ " (in weeks) to request predictions\n",
+ "\n",
+ "The converter produces a text instruction asking the model to predict the\n",
+ "specified variable(s) at the given future time points."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Choose patients to evaluate (here: all test-set patients)\n",
+ "test_patientids = dm.get_all_patientids_in_split(config.test_split_name)\n",
+ "print(f\"Number of test patients: {len(test_patientids)}\")\n",
+ "\n",
+ "# Define the prediction task:\n",
+ "# Which variables to predict and at which future week offsets\n",
+ "VARIABLES_TO_PREDICT = [\"hemoglobin_-_718-7\"]\n",
+ "FUTURE_WEEKS = {\n",
+ " \"hemoglobin_-_718-7\": [4, 8, 12], # Predict hemoglobin at 4, 8, and 12 weeks\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "10",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build the list of prompt payloads for the forecasting API\n",
+ "# Each payload is a dict with \"patientid\", \"instruction\", and \"split_date\"\n",
+ "prompts_with_meta: list[dict] = []\n",
+ "\n",
+ "for pid in test_patientids:\n",
+ " patient_data = dm.get_patient_data(pid)\n",
+ "\n",
+ " # Use only data up to first lot event (simulates baseline information)\n",
+ " patient_data[\"events\"] = patient_data[\"events\"].sort_values(\"date\")\n",
+ " first_lot_date = patient_data[\"events\"][patient_data[\"events\"][\"event_category\"] == \"lot\"][\"date\"].min()\n",
+ " assert pd.notna(first_lot_date), f\"Patient {pid} has no lot event\"\n",
+ " patient_data[\"events\"] = patient_data[\"events\"][patient_data[\"events\"][\"date\"] <= first_lot_date].copy()\n",
+ "\n",
+ " # Generate the forecasting-only split for inference\n",
+ " forecast_split, _ = data_splitter.get_splits_from_patient_inference(\n",
+ " patient_data,\n",
+ " inference_type=\"forecasting\",\n",
+ " forecasting_override_variables_to_predict=VARIABLES_TO_PREDICT,\n",
+ " )\n",
+ "\n",
+ " # Convert to instruction text (no target answer)\n",
+ " converted = converter.forward_conversion_inference(\n",
+ " forecasting_split=forecast_split,\n",
+ " forecasting_future_weeks_per_variable=FUTURE_WEEKS,\n",
+ " )\n",
+ "\n",
+ " # Collect the prompt payload\n",
+ " prompts_with_meta.append(\n",
+ " {\n",
+ " \"patientid\": pid,\n",
+ " \"instruction\": converted[\"instruction\"],\n",
+ " \"split_date\": forecast_split.split_date_included_in_input,\n",
+ " }\n",
+ " )\n",
+ "\n",
+ "print(f\"Generated {len(prompts_with_meta)} instruction prompts\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "11",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Let's inspect one instruction to see what the model will receive\n",
+ "sample = prompts_with_meta[0]\n",
+ "print(f\"=== Patient: {sample['patientid']} ===\")\n",
+ "print(f\"Split date: {sample['split_date']}\\n\")\n",
+ "print(sample[\"instruction\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "12",
+ "metadata": {},
+ "source": [
+ "## 4. Launch the vLLM server\n",
+ "\n",
+ "We launch a vLLM OpenAI-compatible server as a background process.\n",
+ "\n",
+ "> **If you already have a vLLM server running**, set `SKIP_VLLM_LAUNCH = True` in the next cell\n",
+ "> and update `PREDICTION_URL` and `MODEL_PATH` in the configuration section above.\n",
+ "\n",
+ "> **Tip:** For production use, launch the server in a separate terminal."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "13",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Launch vLLM server as a background subprocess\n",
+ "# Set to True to skip launching (if you already have a server running)\n",
+ "SKIP_VLLM_LAUNCH = False\n",
+ "\n",
+ "vllm_process = None\n",
+ "\n",
+ "if not SKIP_VLLM_LAUNCH:\n",
+ " env = os.environ.copy()\n",
+ " env[\"VLLM_ATTENTION_BACKEND\"] = \"FLASH_ATTN\"\n",
+ "\n",
+ " vllm_command = [\n",
+ " sys.executable,\n",
+ " \"-m\",\n",
+ " \"vllm.entrypoints.openai.api_server\",\n",
+ " \"--port\",\n",
+ " str(VLLM_PORT),\n",
+ " \"--model\",\n",
+ " MODEL_PATH,\n",
+ " \"--tokenizer\",\n",
+ " TOKENIZER_PATH,\n",
+ " \"--enable-prefix-caching\",\n",
+ " ]\n",
+ "\n",
+ " print(f\"🚀 Launching vLLM server:\\n {' '.join(vllm_command)}\\n\")\n",
+ "\n",
+ " vllm_process = subprocess.Popen(\n",
+ " vllm_command,\n",
+ " env=env,\n",
+ " stdout=subprocess.PIPE,\n",
+ " stderr=subprocess.STDOUT,\n",
+ " text=True,\n",
+ " )\n",
+ "\n",
+ " # Wait for the server to be ready\n",
+ " WAIT_SECONDS = 240\n",
+ " print(f\"⏳ Waiting up to {WAIT_SECONDS}s for the server to start...\")\n",
+ " import urllib.request\n",
+ "\n",
+ " for i in range(WAIT_SECONDS):\n",
+ " time.sleep(1)\n",
+ " try:\n",
+ " urllib.request.urlopen(f\"http://localhost:{VLLM_PORT}/health\")\n",
+ " print(f\"✅ vLLM server is ready after {i + 1}s\")\n",
+ " break\n",
+ " except Exception:\n",
+ " pass\n",
+ " else:\n",
+ " print(\"⚠️ Server did not respond in time. Check GPU memory and logs.\")\n",
+ " print(\" You can read server output with: vllm_process.stdout.read()\")\n",
+ "else:\n",
+ " print(\"Skipping vLLM launch – using existing server.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "14",
+ "metadata": {},
+ "source": [
+ "## 5. Run forecasting inference\n",
+ "\n",
+ "This is the core step. `run_forecasting_inference_notebook` sends each patient's\n",
+ "instruction to the vLLM server and collects the generated text completions.\n",
+ "\n",
+ "Under the hood it:\n",
+ "1. Wraps each instruction in a chat message (with optional system prompt).\n",
+ "2. Calls the OpenAI-compatible `/v1/chat/completions` endpoint.\n",
+ "3. Returns the generated text(s) alongside patient metadata.\n",
+ "\n",
+ "When `n_samples > 1`, multiple independent completions are generated per\n",
+ "patient, which can later be aggregated into a mean trajectory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "15",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Run the forecasting inference\n",
+ "# This calls the vLLM server asynchronously for all patients\n",
+ "raw_results = await run_forecasting_inference_notebook(\n",
+ " prompts_with_meta,\n",
+ " prediction_url=PREDICTION_URL,\n",
+ " prediction_model=MODEL_PATH,\n",
+ " max_concurrent_requests=40,\n",
+ " max_new_tokens=MAX_NEW_TOKENS,\n",
+ " temperature=TEMPERATURE,\n",
+ " top_p=TOP_P,\n",
+ " n_samples=N_SAMPLES,\n",
+ " api_key=\"EMPTY\",\n",
+ " timeout=600.0,\n",
+ ")\n",
+ "\n",
+ "# Check for failures\n",
+ "n_success = sum(1 for r in raw_results if r is not None)\n",
+ "n_fail = sum(1 for r in raw_results if r is None)\n",
+ "print(f\"✅ Generated forecasts for {n_success} patients, {n_fail} failures\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "16",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Let's inspect the raw generated text for one patient\n",
+ "for r in raw_results:\n",
+ " if r is not None:\n",
+ " print(f\"=== Patient: {r['patientid']} ===\")\n",
+ " for i, text in enumerate(r[\"generated_texts\"]):\n",
+ " print(f\"\\n--- Sample {i} ---\")\n",
+ " print(text)\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "17",
+ "metadata": {},
+ "source": [
+ "## 6. Parse results with reverse conversion\n",
+ "\n",
+ "`parse_forecasting_results` takes the raw generated texts and:\n",
+ "\n",
+ "1. Calls `converter.reverse_conversion` on each generated text to parse it\n",
+ " back into a structured DataFrame with dates and predicted values.\n",
+ "2. When `n_samples > 1` and `aggregate_samples=True`, aggregates multiple\n",
+ " trajectories using `converter.aggregate_multiple_responses` (e.g. averaging\n",
+ " numeric predictions).\n",
+ "3. Returns a single long-format DataFrame with all patients' predictions.\n",
+ "\n",
+ "> **Note:** Reverse conversion is robust to slightly malformed model output\n",
+ "> thanks to `inference_override=True`, but a fine-tuned model will produce\n",
+ "> much more parseable results than a generic instruction model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "18",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "converter = ConverterInstruction(\n",
+ " nr_tokens_budget_total=MAX_CONTEXT_LENGTH,\n",
+ " config=config,\n",
+ " dm=dm,\n",
+ " variable_stats=data_splitter_forecasting.variable_stats,\n",
+ ")\n",
+ "print(\"✅ Converter reloaded with latest code\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "19",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Parse the generated texts into structured DataFrames\n",
+ "df_results = parse_forecasting_results(\n",
+ " raw_results,\n",
+ " converter,\n",
+ " dm,\n",
+ " drop_failures=True,\n",
+ " aggregate_samples=(N_SAMPLES > 1), # Only aggregate if we have multiple samples\n",
+ ")\n",
+ "\n",
+ "if df_results.empty:\n",
+ " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n",
+ " print(\" Fine-tune a model first (see 03_end_to_end_llm_finetuning.ipynb) for meaningful results.\")\n",
+ "else:\n",
+ " print(f\"Parsed {len(df_results)} prediction rows for {df_results['patientid'].nunique()} patients\")\n",
+ " df_results.head(20)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "20",
+ "metadata": {},
+ "source": [
+ "### Understanding the output\n",
+ "\n",
+ "The returned DataFrame has the standard TwinWeaver event format:\n",
+ "\n",
+ "| Column | Description |\n",
+ "|---|---|\n",
+ "| `date` | Predicted date (computed from split_date + week offset) |\n",
+ "| `event_name` | The variable being predicted (e.g. `hemoglobin_-_718-7`) |\n",
+ "| `event_value` | The predicted value |\n",
+ "| `event_category` | Category of the event (e.g. `lab`) |\n",
+ "| `patientid` | Patient identifier |\n",
+ "| `task_type` | Which task type produced this prediction |\n",
+ "| `sample_idx` | Sample index (when `aggregate_samples=False` and `n_samples > 1`) |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "21",
+ "metadata": {},
+ "source": [
+ "## 7. Multi-sample aggregation (optional)\n",
+ "\n",
+ "When using `n_samples > 1`, each patient gets multiple independent forecast\n",
+ "trajectories. Aggregation (e.g. averaging numeric predictions) can reduce\n",
+ "variance and give more robust estimates.\n",
+ "\n",
+ "Below is an example of running with multiple samples and then aggregating."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "22",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Example: generate 3 samples per patient and aggregate\n",
+ "N_SAMPLES_AGG = 3\n",
+ "\n",
+ "raw_results_multi = await run_forecasting_inference_notebook(\n",
+ " prompts_with_meta,\n",
+ " prediction_url=PREDICTION_URL,\n",
+ " prediction_model=MODEL_PATH,\n",
+ " max_concurrent_requests=40,\n",
+ " max_new_tokens=MAX_NEW_TOKENS,\n",
+ " temperature=TEMPERATURE,\n",
+ " top_p=TOP_P,\n",
+ " n_samples=N_SAMPLES_AGG,\n",
+ " api_key=\"EMPTY\",\n",
+ ")\n",
+ "\n",
+ "# Parse with aggregation enabled\n",
+ "df_aggregated = parse_forecasting_results(\n",
+ " raw_results_multi,\n",
+ " converter,\n",
+ " dm,\n",
+ " drop_failures=True,\n",
+ " aggregate_samples=True, # Average numeric values across samples\n",
+ ")\n",
+ "\n",
+ "if df_aggregated.empty:\n",
+ " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n",
+ "else:\n",
+ " print(f\"Aggregated results: {len(df_aggregated)} rows for {df_aggregated['patientid'].nunique()} patients\")\n",
+ " df_aggregated.head(20)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "23",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# You can also get the individual (non-aggregated) samples for deeper analysis\n",
+ "df_individual = parse_forecasting_results(\n",
+ " raw_results_multi,\n",
+ " converter,\n",
+ " dm,\n",
+ " drop_failures=True,\n",
+ " aggregate_samples=False, # Keep individual samples\n",
+ ")\n",
+ "\n",
+ "if df_individual.empty:\n",
+ " print(\"⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.\")\n",
+ "else:\n",
+ " print(f\"Individual results: {len(df_individual)} rows\")\n",
+ " print(f\"Samples per patient: {df_individual.groupby('patientid')['sample_idx'].nunique().describe()}\")\n",
+ " df_individual.head(20)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "24",
+ "metadata": {},
+ "source": [
+ "## 8. Clean up\n",
+ "\n",
+ "Shut down the vLLM server if we launched it from this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "25",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if vllm_process is not None:\n",
+ " print(\"Terminating vLLM server...\")\n",
+ " vllm_process.terminate()\n",
+ " vllm_process.wait(timeout=10)\n",
+ " print(\"✅ Server stopped.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "26",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "This notebook demonstrated the full forecasting inference workflow:\n",
+ "\n",
+ "1. **Data preparation** – Load clinical data and generate forecasting instruction prompts\n",
+ "2. **Model serving** – Launch (or connect to) a vLLM OpenAI-compatible server\n",
+ "3. **Text generation** – Use `run_forecasting_inference_notebook` to generate completions\n",
+ "4. **Reverse conversion** – Use `parse_forecasting_results` to convert text → structured DataFrame\n",
+ "5. **Aggregation** – Optionally average multiple samples for more robust predictions\n",
+ "\n",
+ "### Key functions\n",
+ "\n",
+ "| Function | Purpose |\n",
+ "|---|---|\n",
+ "| `run_forecasting_inference()` | Generate completions for all patients (sync wrapper) |\n",
+ "| `run_forecasting_inference_notebook()` | Same but async – for notebooks |\n",
+ "| `parse_forecasting_results()` | Reverse-convert generated text → structured DataFrame |\n",
+ "\n",
+ "### Comparison with TTE probability inference\n",
+ "\n",
+ "| Aspect | TTE Inference | Forecasting Inference |\n",
+ "|---|---|---|\n",
+ "| **Method** | Log-prob scoring of fixed completions | Free-text generation |\n",
+ "| **Output** | Probabilities (censored/occurred/not occurred) | Predicted values at future time points |\n",
+ "| **API endpoint** | `/v1/completions` (scoring) | `/v1/chat/completions` (generation) |\n",
+ "| **Post-processing** | Softmax over log-probs | Reverse conversion (text → DataFrame) |\n",
+ "| **Multi-sample** | Not applicable | Average trajectories via `aggregate_multiple_responses` |\n",
+ "\n",
+ "### Next steps\n",
+ "\n",
+ "- **Fine-tune a model** on your dataset using `03_end_to_end_llm_finetuning.ipynb`\n",
+ "- **Combine with TTE** by using `inference_type=\"both\"` in the data splitter\n",
+ "- **Evaluate** predictions against ground truth to assess model performance\n",
+ "- **Experiment** with different variables, time horizons, and sample counts"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv_dev_gpu",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb b/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb
index 9c7fd54..fde4635 100644
--- a/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb
+++ b/examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb
@@ -88,8 +88,7 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
- "dm.infer_var_types()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"\n",
"converter = ConverterPretrain(config=config, dm=dm)"
]
@@ -401,9 +400,7 @@
"# Lets simulate forecasts for after the first line of therapy\n",
"df_constant_patient = patient_data[\"constant\"].copy()\n",
"df_events_patient = patient_data[\"events\"].copy()\n",
- "date_of_first_lot = df_events_patient.loc[\n",
- " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n",
- "].min()\n",
+ "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n",
"date_of_first_event = df_events_patient[\"date\"].min()\n",
"\n",
"# Only keep data until (and including) first line of therapy\n",
@@ -507,7 +504,7 @@
"outputs": [],
"source": [
"# Show the generated answer\n",
- "generated_answer"
+ "print(generated_answer)"
]
},
{
@@ -543,7 +540,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv_test",
+ "display_name": ".venv_dev_gpu",
"language": "python",
"name": "python3"
},
diff --git a/examples/advanced/pretraining/prepare_pretraining_data.py b/examples/advanced/pretraining/prepare_pretraining_data.py
index 05c5d3b..80a5402 100644
--- a/examples/advanced/pretraining/prepare_pretraining_data.py
+++ b/examples/advanced/pretraining/prepare_pretraining_data.py
@@ -28,7 +28,7 @@ def __init__(self):
)
self.dm.process_indication_data()
self.dm.setup_unique_mapping_of_events()
- self.dm.setup_dataset_splits()
+ self.dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
#: set up converter
self.converter = ConverterPretrain(config=self.config, dm=self.dm)
diff --git a/examples/advanced/tte_inference/tte_probability_inference.ipynb b/examples/advanced/tte_inference/tte_probability_inference.ipynb
index 9e7a2b2..c442b0b 100644
--- a/examples/advanced/tte_inference/tte_probability_inference.ipynb
+++ b/examples/advanced/tte_inference/tte_probability_inference.ipynb
@@ -136,7 +136,7 @@
"# 3. Time-to-event variables to predict\n",
"# Keys = event_category values in the events DataFrame\n",
"# Values = human-readable name used in the prompt\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\",\n",
"}\n",
@@ -188,13 +188,22 @@
")\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
- "data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)\n",
+ "data_splitter_forecasting = DataSplitterForecasting(\n",
+ " data_manager=dm,\n",
+ " config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
+ ")\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
"# Combined interface\n",
@@ -581,7 +590,15 @@
"print(f\"\\n✅ Total rows: {len(df_all_horizons)}\")\n",
"df_all_horizons = df_all_horizons.sort_values([\"patientid\", \"week_horizon\"])\n",
"df_all_horizons[\n",
- " [\"patientid\", \"week_horizon\", \"probability_occurrence\", \"probability_no_occurrence\", \"probability_censored\"]\n",
+ " [\n",
+ " \"patientid\",\n",
+ " \"week_horizon\",\n",
+ " \"probability_occurrence\",\n",
+ " \"probability_no_occurrence\",\n",
+ " \"probability_censored\",\n",
+ " \"probability_occurrence_renormalized\",\n",
+ " \"probability_no_occurrence_renormalized\",\n",
+ " ]\n",
"]"
]
},
diff --git a/examples/data_preprocessing/raw_data_preprocessing.ipynb b/examples/data_preprocessing/raw_data_preprocessing.ipynb
index efc5472..d7eeccd 100644
--- a/examples/data_preprocessing/raw_data_preprocessing.ipynb
+++ b/examples/data_preprocessing/raw_data_preprocessing.ipynb
@@ -1007,7 +1007,7 @@
"\n",
"# 3. Mapping of specific time to events to predict (e.g., we want to predict 'death' and 'progression')\n",
"# Only needs to be set if you want to do time to event prediction\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
" \"progression\": \"next progression\", # Custom name in prompt: \"next progression\" instead of \"progression\"\n",
"}\n",
@@ -1036,7 +1036,7 @@
"dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"dm.infer_var_types()\n",
"\n",
"print(f\"Loaded {len(dm.all_patientids)} patients into DataManager\")"
@@ -1050,12 +1050,18 @@
"outputs": [],
"source": [
"# Initialize data splitters and converter\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"\n",
"data_splitter_forecasting = DataSplitterForecasting(\n",
" data_manager=dm,\n",
" config=config,\n",
+ " max_forecasted_trajectory_length=pd.Timedelta(days=90),\n",
")\n",
"data_splitter_forecasting.setup_statistics()\n",
"\n",
@@ -1114,7 +1120,6 @@
" p_converted = converter.forward_conversion(\n",
" forecasting_splits=forecasting_splits[split_idx],\n",
" event_splits=events_splits[split_idx],\n",
- " override_mode_to_select_forecasting=\"both\",\n",
" )\n",
"\n",
" print(\"=\" * 80)\n",
diff --git a/examples/hackathon/01_data_preparation_challenge.ipynb b/examples/hackathon/01_data_preparation_challenge.ipynb
index 3cfd551..a4984ff 100644
--- a/examples/hackathon/01_data_preparation_challenge.ipynb
+++ b/examples/hackathon/01_data_preparation_challenge.ipynb
@@ -180,7 +180,7 @@
"# TODO: Configure time-to-event prediction targets\n",
"# HINT: This should be a dictionary mapping event names to display names\n",
"# Example: {\"original_name\": \"display name in prompt\"}\n",
- "config.data_splitter_events_variables_category_mapping = None # Replace with dict"
+ "config.event_category_events_prediction_with_naming = None # Replace with dict"
]
},
{
@@ -218,19 +218,19 @@
" else:\n",
" print(f\"✅ event_category_forecast: {config.event_category_forecast}\")\n",
"\n",
- " if config.data_splitter_events_variables_category_mapping is None:\n",
- " errors.append(\"❌ data_splitter_events_variables_category_mapping is not set\")\n",
- " elif not isinstance(config.data_splitter_events_variables_category_mapping, dict):\n",
- " errors.append(\"❌ data_splitter_events_variables_category_mapping should be a dict\")\n",
+ " if config.event_category_events_prediction_with_naming is None:\n",
+ " errors.append(\"❌ event_category_events_prediction_with_naming is not set\")\n",
+ " elif not isinstance(config.event_category_events_prediction_with_naming, dict):\n",
+ " errors.append(\"❌ event_category_events_prediction_with_naming should be a dict\")\n",
" elif any(\n",
" [\n",
" cat not in df_events[\"event_category\"].unique()\n",
- " for cat in config.data_splitter_events_variables_category_mapping.keys()\n",
+ " for cat in config.event_category_events_prediction_with_naming.keys()\n",
" ]\n",
" ):\n",
- " errors.append(\"❌ At least one key in data_splitter_events_variables_category_mapping not found in data\")\n",
+ " errors.append(\"❌ At least one key in event_category_events_prediction_with_naming not found in data\")\n",
" else:\n",
- " print(f\"✅ Event mapping: {config.data_splitter_events_variables_category_mapping}\")\n",
+ " print(f\"✅ Event mapping: {config.event_category_events_prediction_with_naming}\")\n",
"\n",
" if errors:\n",
" print(\"\\n\" + \"\\n\".join(errors))\n",
@@ -565,7 +565,6 @@
"# Parameters:\n",
"# - forecasting_splits: the forecasting split for one time point\n",
"# - event_splits: the event split for one time point\n",
- "# - override_mode_to_select_forecasting: set to \"both\"\n",
"\n",
"split_idx = 0\n",
"p_converted = None # Replace with actual conversion call"
diff --git a/examples/hackathon/02_llm_finetuning_challenge.ipynb b/examples/hackathon/02_llm_finetuning_challenge.ipynb
index e45d570..9405caf 100644
--- a/examples/hackathon/02_llm_finetuning_challenge.ipynb
+++ b/examples/hackathon/02_llm_finetuning_challenge.ipynb
@@ -150,7 +150,7 @@
"# Set up:\n",
"# - split_event_category\n",
"# - event_category_forecast\n",
- "# - data_splitter_events_variables_category_mapping\n",
+ "# - event_category_events_prediction_with_naming\n",
"# - constant_columns_to_use\n",
"# - constant_birthdate_column\n",
"\n",
@@ -776,9 +776,7 @@
"\n",
"# Get the date of first line of therapy\n",
"df_events_patient = patient_data[\"events\"].copy()\n",
- "date_of_first_lot = df_events_patient.loc[\n",
- " df_events_patient[\"event_category\"] == config.event_category_lot, \"date\"\n",
- "].min()\n",
+ "date_of_first_lot = df_events_patient.loc[df_events_patient[\"event_category\"] == \"lot\", \"date\"].min()\n",
"\n",
"print(f\"Test patient: {test_patientid}\")\n",
"print(f\"First LoT date: {date_of_first_lot}\")"
diff --git a/examples/integrations/meds_data_import.ipynb b/examples/integrations/meds_data_import.ipynb
index a93ae38..5183219 100644
--- a/examples/integrations/meds_data_import.ipynb
+++ b/examples/integrations/meds_data_import.ipynb
@@ -475,10 +475,9 @@
"config = Config() # Override values here to customize pipeline\n",
"config.constant_columns_to_use = constant_columns\n",
"config.constant_birthdate_column = None # Not using in demo\n",
- "config.lot_event_name = None # Setting for LoTs\n",
"config.event_value_lot_start = None\n",
"config.split_event_category = \"lot\"\n",
- "config.data_splitter_events_variables_category_mapping = {\n",
+ "config.event_category_events_prediction_with_naming = {\n",
" \"death\": \"death\",\n",
"}"
]
@@ -500,9 +499,14 @@
")\n",
"dm.process_indication_data()\n",
"dm.setup_unique_mapping_of_events()\n",
- "dm.setup_dataset_splits()\n",
+ "dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)\n",
"\n",
- "data_splitter_events = DataSplitterEvents(dm, config=config)\n",
+ "data_splitter_events = DataSplitterEvents(\n",
+ " dm,\n",
+ " config=config,\n",
+ " max_length_to_sample=pd.Timedelta(weeks=104),\n",
+ " min_length_to_sample=pd.Timedelta(weeks=1),\n",
+ ")\n",
"data_splitter_events.setup_variables()\n",
"converter = ConverterInstruction(\n",
" nr_tokens_budget_total=8192,\n",
diff --git a/mkdocs.yml b/mkdocs.yml
index 0e2b9ec..e6f6c2e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -4,7 +4,7 @@ site_url: https://example.com/twinweaver/
# 1. Source Code Link (Adds a nice GitHub icon to the header)
repo_name: TwinWeaver
-repo_url: https://github.com/your-org/twinweaver # Replace with actual
+repo_url: https://github.com/MendenLab/TwinWeaver
edit_uri: edit/main/docs/
extra_css:
@@ -13,10 +13,10 @@ extra_css:
theme:
name: material
custom_dir: docs/overrides
- # 2. Custom Logo/Favicon (Use an emoji if you don't have an SVG yet)
+ # 2. Custom Logo/Favicon
icon:
logo: material/dna
- favicon: material/dna
+ favicon: images/favicon.png
features:
- navigation.tabs # Top-level tabs
@@ -54,7 +54,7 @@ theme:
extra:
social:
- icon: fontawesome/brands/github
- link: https://github.com/your-org/twinweaver
+ link: https://github.com/MendenLab/TwinWeaver
- icon: fontawesome/brands/python
link: https://pypi.org/project/twinweaver/
analytics:
@@ -131,17 +131,29 @@ nav:
- Data Preparation: examples/01_data_preparation_for_training.ipynb
- Inference Prompt Prep: examples/02_inference_prompt_preparation.ipynb
- End-to-End Finetuning: examples/03_end_to_end_llm_finetuning.ipynb
+ - Data Preprocessing:
+ - Raw Data Preprocessing: examples/data_preprocessing/raw_data_preprocessing.ipynb
- Advanced:
- - Custom Splitting (Training): examples/advanced/custom_splitting/training_individual_splitters.ipynb
- - Custom Split Events: examples/advanced/custom_splitting/training_custom_split_events.ipynb
- - Forecasting-Only Splitter: examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb
- - Custom Splitting (Inference): examples/advanced/custom_splitting/inference_individual_splitters.md
- - Custom Text Generation: examples/advanced/custom_output/customizing_text_generation.ipynb
- - Custom Summarized Row: examples/advanced/custom_output/custom_summarized_row.ipynb
- - Pretraining: examples/advanced/pretraining/prepare_pretraining_data.md
+ - Custom Splitting:
+ - Training (Individual Splitters): examples/advanced/custom_splitting/training_individual_splitters.ipynb
+ - Training (Custom Split Events): examples/advanced/custom_splitting/training_custom_split_events.ipynb
+ - Training (Forecasting QA): examples/advanced/custom_splitting/training_forecasting_qa.ipynb
+ - Training (Forecasting-Only Splitter): examples/advanced/custom_splitting/training_forecasting_splitter_only.ipynb
+ - Inference (Individual Splitters): examples/advanced/custom_splitting/inference_individual_splitters.md
+ - Custom Output:
+ - Custom Text Generation: examples/advanced/custom_output/customizing_text_generation.ipynb
+ - Custom Summarized Row: examples/advanced/custom_output/custom_summarized_row.ipynb
+ - Pretraining:
+ - Prepare Pretraining Data: examples/advanced/pretraining/prepare_pretraining_data.md
+ - End-to-End Training with Pretrain: examples/advanced/pretraining/end_to_end_llm_training_with_pretrain.ipynb
- TTE Probability Inference: examples/advanced/tte_inference/tte_probability_inference.ipynb
+ - Forecasting Inference: examples/advanced/forecasting_inference/forecasting_vllm_inference.ipynb
- Integrations:
- MEDS Import: examples/integrations/meds_data_import.ipynb
+ - Hackathon:
+ - Overview: examples/hackathon/README.md
+ - Data Preparation Challenge: examples/hackathon/01_data_preparation_challenge.ipynb
+ - LLM Finetuning Challenge: examples/hackathon/02_llm_finetuning_challenge.ipynb
- API Reference:
- Common:
- Config: reference/common/config.md
@@ -160,3 +172,4 @@ nav:
- MEDS Importer: reference/utils/meds_importer.md
- Preprocessing Helpers: reference/utils/preprocessing_helpers.md
- TTE Inference: reference/utils/tte_inference.md
+ - Forecasting Inference: reference/utils/forecasting_inference.md
diff --git a/pyproject.toml b/pyproject.toml
index acf48a2..4edd460 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ build-backend = "setuptools.build_meta" # Standard entry point for setuptools bu
[project]
name = "twinweaver"
-version = "0.2.1"
+version = "0.3.0"
description = "Converting longitudinal patient data into text for LLM-based event prediction and forecasting."
# --- NEW/UPDATED FIELDS ---
diff --git a/tests/conftest.py b/tests/conftest.py
index 4612096..5f3c3b6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,6 +9,8 @@ def mock_config():
cfg = Config()
# Ensure the random seed is fixed for reproducible tests
cfg.seed = 42
+ # Set constant_columns_to_use to match the test data
+ cfg.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
return cfg
diff --git a/tests/test_common.py b/tests/test_common.py
index ada745b..cf3d60c 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -7,7 +7,6 @@ def test_config_initialization(mock_config):
assert mock_config.seed == 42
assert mock_config.patient_id_col == "patientid"
# Verify defaults used in the library
- assert mock_config.event_category_lot == "lot"
def test_data_manager_loading(mock_config, sample_data):
@@ -45,7 +44,7 @@ def test_data_manager_processing(mock_config, sample_data):
# Override config
mock_config.split_event_category = "lot"
mock_config.event_category_forecast = ["lab"]
- mock_config.data_splitter_events_variables_category_mapping = None
+ mock_config.event_category_events_prediction_with_naming = None
mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
dm = DataManager(config=mock_config)
@@ -54,7 +53,7 @@ def test_data_manager_processing(mock_config, sample_data):
# Run pipeline
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
- dm.setup_dataset_splits()
+ dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
dm.infer_var_types()
# 1. Check Date Processing
diff --git a/tests/test_converter.py b/tests/test_converter.py
index 3a98e42..d5d713a 100644
--- a/tests/test_converter.py
+++ b/tests/test_converter.py
@@ -13,7 +13,7 @@ def setup_components(mock_config, sample_data):
df_events, df_constant, df_constant_desc = sample_data
mock_config.split_event_category = "lot"
mock_config.event_category_forecast = ["lab"]
- mock_config.data_splitter_events_variables_category_mapping = {"death": "death", "progression": "next progression"}
+ mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"}
mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
mock_config.constant_birthdate_column = "birthyear"
@@ -21,13 +21,24 @@ def setup_components(mock_config, sample_data):
dm.load_indication_data(df_events, df_constant, df_constant_desc)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
- dm.setup_dataset_splits()
+ dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
dm.infer_var_types()
- splitter_events = DataSplitterEvents(dm, config=mock_config)
+ splitter_events = DataSplitterEvents(
+ dm,
+ config=mock_config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
splitter_events.setup_variables()
- splitter_forecast = DataSplitterForecasting(data_manager=dm, config=mock_config)
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=dm,
+ config=mock_config,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
splitter_forecast.setup_statistics()
data_splitter = DataSplitter(splitter_events, splitter_forecast)
@@ -91,6 +102,152 @@ def test_forward_conversion_training(setup_components):
assert "hemoglobin - 718-7 is 14.01." in answer # Known value for p0
+def test_event_categories_to_exclude_from_input(mock_config, sample_data):
+ """Test that event_categories_to_exclude_from_input removes specified categories from generated text."""
+ df_events, df_constant, df_constant_desc = sample_data
+
+ # Configure with drug events excluded
+ mock_config.split_event_category = "lot"
+ mock_config.event_category_forecast = ["lab"]
+ mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"}
+ mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
+ mock_config.constant_birthdate_column = "birthyear"
+ mock_config.event_categories_to_exclude_from_input = ["drug"]
+
+ dm = DataManager(config=mock_config)
+ dm.load_indication_data(df_events, df_constant, df_constant_desc)
+ dm.process_indication_data()
+ dm.setup_unique_mapping_of_events()
+ dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
+ dm.infer_var_types()
+
+ splitter_events = DataSplitterEvents(
+ dm,
+ config=mock_config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
+ splitter_events.setup_variables()
+
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=dm,
+ config=mock_config,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
+ splitter_forecast.setup_statistics()
+
+ data_splitter = DataSplitter(splitter_events, splitter_forecast)
+
+ converter = ConverterInstruction(
+ nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=splitter_forecast.variable_stats
+ )
+
+ patient_data = dm.get_patient_data("p0")
+ f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data)
+
+ result = converter.forward_conversion(
+ forecasting_splits=f_splits[0],
+ event_splits=e_splits[0],
+ override_mode_to_select_forecasting="both",
+ )
+
+ instruction = result["instruction"]
+
+ # Drug events (e.g. "drug pemetrexed is administered") should be excluded
+ assert "drug pemetrexed is administered" not in instruction
+ assert (
+ "drug" not in instruction.lower().split("demographic")[0].split("\n")[-1]
+ if "demographic" in instruction
+ else True
+ )
+
+ # Other event categories should still be present
+ assert "Starting with demographic data:" in instruction
+ assert "hemoglobin" in instruction # lab events still present
+
+
+def test_event_categories_to_exclude_multiple(mock_config, sample_data):
+ """Test excluding multiple event categories from input."""
+ df_events, df_constant, df_constant_desc = sample_data
+
+ mock_config.split_event_category = "lot"
+ mock_config.event_category_forecast = ["lab"]
+ mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"}
+ mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
+ mock_config.constant_birthdate_column = "birthyear"
+ mock_config.event_categories_to_exclude_from_input = ["drug", "ecog"]
+
+ dm = DataManager(config=mock_config)
+ dm.load_indication_data(df_events, df_constant, df_constant_desc)
+ dm.process_indication_data()
+ dm.setup_unique_mapping_of_events()
+ dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
+ dm.infer_var_types()
+
+ splitter_events = DataSplitterEvents(
+ dm,
+ config=mock_config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
+ splitter_events.setup_variables()
+
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=dm,
+ config=mock_config,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
+ splitter_forecast.setup_statistics()
+
+ data_splitter = DataSplitter(splitter_events, splitter_forecast)
+
+ converter = ConverterInstruction(
+ nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=splitter_forecast.variable_stats
+ )
+
+ patient_data = dm.get_patient_data("p0")
+ f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data)
+
+ result = converter.forward_conversion(
+ forecasting_splits=f_splits[0],
+ event_splits=e_splits[0],
+ override_mode_to_select_forecasting="both",
+ )
+
+ instruction = result["instruction"]
+
+ # Both drug and ecog events should be absent
+ assert "drug pemetrexed is administered" not in instruction
+ assert "ECOG" not in instruction
+
+ # Lab events should still be present
+ assert "hemoglobin" in instruction
+
+
+def test_event_categories_to_exclude_empty_list(setup_components):
+ """Test that an empty exclude list leaves all events intact (default behavior)."""
+ dm, data_splitter, converter = setup_components
+ assert dm.config.event_categories_to_exclude_from_input == []
+
+ patient_data = dm.get_patient_data("p0")
+ f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data)
+
+ result = converter.forward_conversion(
+ forecasting_splits=f_splits[0],
+ event_splits=e_splits[0],
+ override_mode_to_select_forecasting="both",
+ )
+
+ instruction = result["instruction"]
+
+ # Drug events should be present when nothing is excluded
+ assert "drug pemetrexed is administered" in instruction
+
+
def test_forward_conversion_inference(setup_components):
"""Test conversion for inference (no target string)."""
dm, data_splitter, converter = setup_components
@@ -162,3 +319,185 @@ def test_reverse_conversion(setup_components):
assert e_res["task_type"] == task_events
assert e_res["result"]["censoring"].iloc[0].item() is False
assert e_res["result"]["occurred"].iloc[0].item() is False
+
+
+# ── Tests for forecasting-only and events-only conversion ──────────────────
+
+
+@pytest.fixture
+def setup_forecasting_only(mock_config, sample_data):
+ """Helper to set up components with only a forecasting splitter (no events splitter)."""
+ df_events, df_constant, df_constant_desc = sample_data
+ mock_config.split_event_category = "lot"
+ mock_config.event_category_forecast = ["lab"]
+ mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"}
+ mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
+ mock_config.constant_birthdate_column = "birthyear"
+
+ dm = DataManager(config=mock_config)
+ dm.load_indication_data(df_events, df_constant, df_constant_desc)
+ dm.process_indication_data()
+ dm.setup_unique_mapping_of_events()
+ dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
+ dm.infer_var_types()
+
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=dm,
+ config=mock_config,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
+ splitter_forecast.setup_statistics()
+
+ data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast)
+
+ converter = ConverterInstruction(
+ nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=splitter_forecast.variable_stats
+ )
+
+ return dm, data_splitter, converter, splitter_forecast
+
+
+@pytest.fixture
+def setup_events_only(mock_config, sample_data):
+ """Helper to set up components with only an events splitter (no forecasting splitter)."""
+ df_events, df_constant, df_constant_desc = sample_data
+ mock_config.split_event_category = "lot"
+ mock_config.event_category_forecast = ["lab"]
+ mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"}
+ mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
+ mock_config.constant_birthdate_column = "birthyear"
+
+ dm = DataManager(config=mock_config)
+ dm.load_indication_data(df_events, df_constant, df_constant_desc)
+ dm.process_indication_data()
+ dm.setup_unique_mapping_of_events()
+ dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
+ dm.infer_var_types()
+
+ splitter_events = DataSplitterEvents(
+ dm,
+ config=mock_config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
+ splitter_events.setup_variables()
+
+ data_splitter = DataSplitter(data_splitter_events=splitter_events)
+
+ converter = ConverterInstruction(nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=None)
+
+ return dm, data_splitter, converter
+
+
+def test_forward_conversion_forecasting_only_training(setup_forecasting_only):
+ """Test training conversion with only forecasting splits (no events)."""
+ dm, data_splitter, converter, _ = setup_forecasting_only
+ patient_data = dm.get_patient_data("p0")
+
+ f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data)
+
+ assert e_splits is None # No events splitter configured
+ assert f_splits is not None
+
+ result = converter.forward_conversion(
+ forecasting_splits=f_splits[0],
+ event_splits=None,
+ override_mode_to_select_forecasting="forecasting",
+ )
+
+ instruction = result["instruction"]
+ answer = result["answer"]
+
+ # Should contain forecasting task but no events task
+ assert "Task 1 is forecasting:" in instruction
+ assert "time to event prediction:" not in instruction
+ assert "Starting with demographic data:" in instruction
+
+ # Answer should contain the forecasting task and no events task
+ assert len(answer) > 0
+ assert "Task 1 is forecasting:" in answer
+ assert "time to event prediction:" not in answer
+
+
+def test_forward_conversion_events_only_training(setup_events_only):
+ """Test training conversion with only event splits (no forecasting)."""
+ dm, data_splitter, converter = setup_events_only
+ patient_data = dm.get_patient_data("p0")
+
+ f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data)
+
+ assert f_splits is None # No forecasting splitter configured
+ assert e_splits is not None
+
+ result = converter.forward_conversion(
+ forecasting_splits=None,
+ event_splits=e_splits[0],
+ )
+
+ instruction = result["instruction"]
+ answer = result["answer"]
+
+ # Should contain events task but no forecasting task
+ assert "time to event prediction:" in instruction
+ assert "forecasting:" not in instruction
+ assert "Starting with demographic data:" in instruction
+
+ # Answer should contain events task only
+ assert len(answer) > 0
+ assert "time to event prediction:" in answer
+ assert "forecasting:" not in answer
+
+
+def test_forward_conversion_inference_forecasting_only(setup_forecasting_only):
+ """Test inference conversion with only a forecasting split (no events)."""
+ dm, data_splitter, converter, _ = setup_forecasting_only
+ patient_data = dm.get_patient_data("p0")
+
+ f_split, e_split = data_splitter.get_splits_from_patient_inference(
+ patient_data,
+ inference_type="forecasting",
+ forecasting_override_variables_to_predict=["hemoglobin_-_718-7"],
+ )
+
+ assert e_split is None
+ assert f_split is not None
+
+ result = converter.forward_conversion_inference(
+ forecasting_split=f_split,
+ forecasting_future_weeks_per_variable={"hemoglobin_-_718-7": [4, 8, 12]},
+ event_split=None,
+ )
+
+ instruction = result["instruction"]
+ assert result["answer"] is None
+ assert "hemoglobin" in instruction
+ assert "future weeks 4, 8, 12" in instruction
+ assert "Task 1 is forecasting:" in instruction
+ assert "time to event prediction:" not in instruction
+
+
+def test_forward_conversion_inference_events_only(setup_events_only):
+ """Test inference conversion with only an event split (no forecasting)."""
+ dm, data_splitter, converter = setup_events_only
+ patient_data = dm.get_patient_data("p0")
+
+ f_split, e_split = data_splitter.get_splits_from_patient_inference(
+ patient_data,
+ inference_type="events",
+ )
+
+ assert f_split is None
+ assert e_split is not None
+
+ result = converter.forward_conversion_inference(
+ forecasting_split=None,
+ event_split=e_split,
+ )
+
+ instruction = result["instruction"]
+ assert result["answer"] is None
+ assert "Task 1 is time to event prediction:" in instruction
+ assert "forecasting:" not in instruction
+ assert "Starting with demographic data:" in instruction
diff --git a/tests/test_converter_pretrain.py b/tests/test_converter_pretrain.py
index 9c182d5..56804ba 100644
--- a/tests/test_converter_pretrain.py
+++ b/tests/test_converter_pretrain.py
@@ -15,7 +15,7 @@ def setup_pretrain_components(mock_config, sample_data):
dm.load_indication_data(df_events, df_constant, df_constant_desc)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
- dm.setup_dataset_splits()
+ dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
converter = ConverterPretrain(config=mock_config, dm=dm)
diff --git a/tests/test_splitter.py b/tests/test_splitter.py
index 4ae7189..7a6b60e 100644
--- a/tests/test_splitter.py
+++ b/tests/test_splitter.py
@@ -1,4 +1,5 @@
import pytest
+import numpy as np
import pandas as pd
from twinweaver.common.data_manager import DataManager
from twinweaver.instruction.data_splitter_events import DataSplitterEvents
@@ -12,21 +13,23 @@ def initialized_dm(mock_config, sample_data):
df_events, df_constant, df_constant_desc = sample_data
mock_config.split_event_category = "lot"
mock_config.event_category_forecast = ["lab"]
- mock_config.data_splitter_events_variables_category_mapping = {"death": "death", "progression": "next progression"}
+ mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"}
mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
dm = DataManager(config=mock_config)
dm.load_indication_data(df_events, df_constant, df_constant_desc)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
- dm.setup_dataset_splits()
+ dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
dm.infer_var_types()
return dm
def test_splitter_forecasting_statistics(initialized_dm, mock_config):
"""Test that forecasting splitter can calculate statistics."""
- splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config)
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=initialized_dm, config=mock_config, max_forecasted_trajectory_length=pd.Timedelta(days=90)
+ )
# This calculates R2, NRMSE etc. for the variables
splitter_forecast.setup_statistics()
@@ -49,10 +52,21 @@ def test_splitter_forecasting_statistics(initialized_dm, mock_config):
def test_get_splits_from_patient(initialized_dm, mock_config):
"""Test generating splits for a single patient."""
# Setup Splitters
- splitter_events = DataSplitterEvents(initialized_dm, config=mock_config)
+ splitter_events = DataSplitterEvents(
+ initialized_dm,
+ config=mock_config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
splitter_events.setup_variables()
- splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config)
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=initialized_dm,
+ config=mock_config,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
splitter_forecast.setup_statistics()
data_splitter = DataSplitter(splitter_events, splitter_forecast)
@@ -98,9 +112,16 @@ def test_get_splits_from_patient(initialized_dm, mock_config):
def test_inference_split(initialized_dm, mock_config):
"""Test generating an inference split (last date)."""
- splitter_events = DataSplitterEvents(initialized_dm, config=mock_config)
+ splitter_events = DataSplitterEvents(
+ initialized_dm,
+ config=mock_config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ )
splitter_events.setup_variables()
- splitter_forecast = DataSplitterForecasting(data_manager=initialized_dm, config=mock_config)
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=initialized_dm, config=mock_config, max_forecasted_trajectory_length=pd.Timedelta(days=90)
+ )
data_splitter = DataSplitter(splitter_events, splitter_forecast)
patient_data = initialized_dm.get_patient_data("p0")
@@ -124,3 +145,584 @@ def test_inference_split(initialized_dm, mock_config):
assert e_split.observation_end_date == pd.Timestamp("2018-02-23 00:00:00")
assert e_split.event_censored == "end_of_data"
assert not e_split.event_occurred
+
+
+# ────────────────────────────────────────────────────────────────────────────
+# Tests for DataSplitter with individual (single) splitters
+# ────────────────────────────────────────────────────────────────────────────
+
+
+def test_data_splitter_requires_at_least_one_splitter():
+ """Test that DataSplitter raises if neither splitter is provided."""
+ with pytest.raises(ValueError, match="At least one"):
+ DataSplitter()
+
+
+def test_training_forecasting_only(initialized_dm, mock_config):
+ """Test training splits when only the forecasting splitter is provided."""
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=initialized_dm,
+ config=mock_config,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
+ splitter_forecast.setup_statistics()
+
+ data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast)
+
+ patient_data = initialized_dm.get_patient_data("p0")
+ forecasting_splits, events_splits, ref_dates = data_splitter.get_splits_from_patient_with_target(
+ patient_data, max_num_splits_per_split_event=1
+ )
+
+ # Forecasting should be populated
+ assert forecasting_splits is not None
+ assert len(forecasting_splits) == 1
+
+ # Events should be None since no events splitter was provided
+ assert events_splits is None
+
+ # Reference dates should still be available from the forecasting splitter
+ assert ref_dates is not None
+ assert not ref_dates.empty
+ assert "date" in ref_dates.columns
+ assert "split_date" in ref_dates.columns
+ assert ref_dates.shape == (1, 2)
+ assert ref_dates["date"].iloc[0] == pd.Timestamp("2015-07-08")
+ assert ref_dates["split_date"].iloc[0] == pd.Timestamp("2015-05-06")
+
+ # Validate forecasting split structure and content
+ f_split = forecasting_splits[0][0]
+ assert f_split.events_until_split is not None
+ assert f_split.constant_data["patientid"].iloc[0] == "p0"
+ assert f_split.events_until_split.shape == (32, 8)
+ assert f_split.split_date_included_in_input == pd.Timestamp("2015-07-08")
+ assert f_split.lot_date == pd.Timestamp("2015-05-06")
+ assert f_split.sampled_variables == ["hemoglobin_-_718-7"]
+
+ # Target events should exist and be after the split date
+ assert not f_split.target_events_after_split.empty
+ assert f_split.target_events_after_split.shape[0] == 4
+ assert f_split.target_events_after_split["date"].min() == pd.Timestamp("2015-07-29")
+ assert f_split.target_events_after_split["date"].max() == pd.Timestamp("2015-09-30")
+
+ # Events in input must be <= split date
+ assert (f_split.events_until_split["date"] <= f_split.split_date_included_in_input).all()
+ assert f_split.events_until_split["date"].min() == pd.Timestamp("2015-04-19")
+
+ # Constant data should contain expected columns
+ assert "birthyear" in f_split.constant_data.columns
+ assert "gender" in f_split.constant_data.columns
+ assert f_split.constant_data.shape[0] == 1
+
+
+def test_training_events_only(initialized_dm, mock_config):
+ """Test training splits when only the events splitter is provided."""
+ splitter_events = DataSplitterEvents(
+ initialized_dm,
+ config=mock_config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ )
+ splitter_events.setup_variables()
+
+ data_splitter = DataSplitter(data_splitter_events=splitter_events)
+
+ patient_data = initialized_dm.get_patient_data("p0")
+ forecasting_splits, events_splits, ref_dates = data_splitter.get_splits_from_patient_with_target(
+ patient_data, max_num_splits_per_split_event=1
+ )
+
+ # Forecasting should be None since no forecasting splitter was provided
+ assert forecasting_splits is None
+
+ # Events should be populated
+ assert events_splits is not None
+ assert len(events_splits) == 1
+
+ # Reference dates should be reconstructed from events splits
+ assert ref_dates is not None
+ assert not ref_dates.empty
+ assert "date" in ref_dates.columns
+ assert "split_date" in ref_dates.columns
+ assert ref_dates.shape == (1, 2)
+ assert ref_dates["date"].iloc[0] == pd.Timestamp("2015-07-08")
+ assert ref_dates["split_date"].iloc[0] == pd.Timestamp("2015-05-06")
+
+ # Validate events split structure and content
+ e_split = events_splits[0][0]
+ assert e_split.events_until_split is not None
+ assert e_split.constant_data["patientid"].iloc[0] == "p0"
+ assert e_split.events_until_split.shape == (32, 8)
+ assert e_split.split_date_included_in_input == pd.Timestamp("2015-07-08")
+ assert e_split.lot_date == pd.Timestamp("2015-05-06")
+ # Category must be one of the configured event categories (or a backup thereof)
+ expected_mapping = {"death": "death", "progression": "next progression"}
+ assert e_split.sampled_category in list(expected_mapping.keys()) + list(
+ mock_config.data_splitter_events_backup_category_mapping.values()
+ )
+ # Category name must match one of the configured descriptive names
+ assert e_split.sampled_category_name in expected_mapping.values()
+
+ # Event outcome must be boolean
+ assert isinstance(e_split.event_occurred, bool)
+ # Censoring should be None or one of the known censoring types
+ assert e_split.event_censored in [None, "new_split_date_start", "end_of_data", "data_cutoff"]
+ # Observation end date must be on or after the split date
+ assert e_split.observation_end_date >= e_split.split_date_included_in_input
+
+ # Events in input must be <= split date
+ assert (e_split.events_until_split["date"] <= e_split.split_date_included_in_input).all()
+ assert e_split.events_until_split["date"].min() == pd.Timestamp("2015-04-19")
+ assert e_split.events_until_split["date"].max() == pd.Timestamp("2015-07-08")
+
+ # Constant data integrity
+ assert "birthyear" in e_split.constant_data.columns
+ assert "gender" in e_split.constant_data.columns
+ assert "histology" in e_split.constant_data.columns
+ assert "smoking_history" in e_split.constant_data.columns
+ assert e_split.constant_data.shape[0] == 1
+
+
+def test_inference_forecasting_only(initialized_dm, mock_config):
+ """Test inference split when only the forecasting splitter is provided."""
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=initialized_dm, config=mock_config, max_forecasted_trajectory_length=pd.Timedelta(days=90)
+ )
+
+ data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast)
+ patient_data = initialized_dm.get_patient_data("p0")
+
+ f_split, e_split = data_splitter.get_splits_from_patient_inference(
+ patient_data,
+ inference_type="forecasting",
+ forecasting_override_variables_to_predict=["hemoglobin_-_718-7"],
+ )
+
+ last_date = patient_data["events"]["date"].max()
+
+ # Forecasting split should be populated
+ assert f_split is not None
+ assert f_split.split_date_included_in_input == last_date
+ assert f_split.split_date_included_in_input == pd.Timestamp("2016-05-13")
+ assert f_split.sampled_variables == ["hemoglobin_-_718-7"]
+ assert f_split.lot_date == "override"
+
+ # Inference has no target data
+ assert f_split.target_events_after_split.empty
+
+ # Input events must cover the full patient history up to last date
+ assert f_split.events_until_split.shape == (78, 8)
+ assert (f_split.events_until_split["date"] <= f_split.split_date_included_in_input).all()
+ assert f_split.events_until_split["date"].min() == pd.Timestamp("2015-04-19")
+ assert f_split.events_until_split["date"].max() == pd.Timestamp("2016-05-13")
+
+ # Constant data integrity
+ assert f_split.constant_data["patientid"].iloc[0] == "p0"
+ assert "birthyear" in f_split.constant_data.columns
+ assert "gender" in f_split.constant_data.columns
+ assert f_split.constant_data.shape[0] == 1
+
+ # Events split should be None
+ assert e_split is None
+
+
+def test_inference_events_only(initialized_dm, mock_config):
+ """Test inference split when only the events splitter is provided."""
+ splitter_events = DataSplitterEvents(
+ initialized_dm,
+ config=mock_config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ )
+ splitter_events.setup_variables()
+
+ data_splitter = DataSplitter(data_splitter_events=splitter_events)
+ patient_data = initialized_dm.get_patient_data("p0")
+
+ f_split, e_split = data_splitter.get_splits_from_patient_inference(
+ patient_data,
+ inference_type="events",
+ events_override_category="death",
+ events_override_observation_time_delta=pd.Timedelta(weeks=52),
+ )
+
+ last_date = patient_data["events"]["date"].max()
+
+ # Forecasting split should be None
+ assert f_split is None
+
+ # Events split should be populated
+ assert e_split is not None
+ assert e_split.split_date_included_in_input == last_date
+ assert e_split.split_date_included_in_input == pd.Timestamp("2016-05-13")
+ assert e_split.sampled_category == "death"
+ assert e_split.sampled_category_name == "death"
+
+ # p0's last event is death itself; predicting death from last date with 52-week window:
+ # death already occurred at last_date so looking forward finds nothing → censored end_of_data
+ assert e_split.event_occurred is False
+ assert e_split.event_censored == "end_of_data"
+ assert e_split.observation_end_date == pd.Timestamp("2017-05-12")
+
+ # Input events must cover the full patient history up to last date
+ assert e_split.events_until_split.shape == (78, 8)
+ assert (e_split.events_until_split["date"] <= e_split.split_date_included_in_input).all()
+ assert e_split.events_until_split["date"].max() == pd.Timestamp("2016-05-13")
+
+ # Constant data integrity
+ assert e_split.constant_data["patientid"].iloc[0] == "p0"
+ assert "birthyear" in e_split.constant_data.columns
+ assert e_split.constant_data.shape[0] == 1
+
+
+def test_inference_both_type_with_only_forecasting(initialized_dm, mock_config):
+ """Test that inference_type='both' gracefully returns None for the missing events splitter."""
+ splitter_forecast = DataSplitterForecasting(
+ data_manager=initialized_dm, config=mock_config, max_forecasted_trajectory_length=pd.Timedelta(days=90)
+ )
+
+ data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast)
+ patient_data = initialized_dm.get_patient_data("p0")
+
+ f_split, e_split = data_splitter.get_splits_from_patient_inference(
+ patient_data,
+ inference_type="both",
+ forecasting_override_variables_to_predict=["hemoglobin_-_718-7"],
+ )
+
+ last_date = patient_data["events"]["date"].max()
+
+ # Forecasting should work
+ assert f_split is not None
+ assert f_split.sampled_variables == ["hemoglobin_-_718-7"]
+ assert f_split.split_date_included_in_input == last_date
+ assert f_split.split_date_included_in_input == pd.Timestamp("2016-05-13")
+ assert f_split.lot_date == "override"
+
+ # Inference: no target
+ assert f_split.target_events_after_split.empty
+
+ # Full patient history should be used as input
+ assert f_split.events_until_split.shape == (78, 8)
+ assert (f_split.events_until_split["date"] <= f_split.split_date_included_in_input).all()
+
+ # Constant data integrity
+ assert f_split.constant_data["patientid"].iloc[0] == "p0"
+ assert f_split.constant_data.shape[0] == 1
+
+ # Events should be None because no events splitter is set
+ assert e_split is None
+
+
+def test_inference_both_type_with_only_events(initialized_dm, mock_config):
+ """Test that inference_type='both' gracefully returns None for the missing forecasting splitter."""
+ splitter_events = DataSplitterEvents(
+ initialized_dm,
+ config=mock_config,
+ max_length_to_sample=pd.Timedelta(weeks=104),
+ min_length_to_sample=pd.Timedelta(weeks=1),
+ )
+ splitter_events.setup_variables()
+
+ data_splitter = DataSplitter(data_splitter_events=splitter_events)
+ patient_data = initialized_dm.get_patient_data("p0")
+
+ f_split, e_split = data_splitter.get_splits_from_patient_inference(
+ patient_data,
+ inference_type="both",
+ events_override_category="death",
+ events_override_observation_time_delta=pd.Timedelta(weeks=52),
+ )
+
+ last_date = patient_data["events"]["date"].max()
+
+ # Forecasting should be None because no forecasting splitter is set
+ assert f_split is None
+
+ # Events should work
+ assert e_split is not None
+ assert e_split.split_date_included_in_input == last_date
+ assert e_split.split_date_included_in_input == pd.Timestamp("2016-05-13")
+ assert e_split.sampled_category == "death"
+ assert e_split.sampled_category_name == "death"
+
+ # Observation window and outcome
+ assert e_split.event_occurred is False
+ assert e_split.event_censored == "end_of_data"
+ assert e_split.observation_end_date == pd.Timestamp("2017-05-12")
+
+ # Full patient history should be used as input
+ assert e_split.events_until_split.shape == (78, 8)
+ assert (e_split.events_until_split["date"] <= e_split.split_date_included_in_input).all()
+
+ # Constant data integrity
+ assert e_split.constant_data["patientid"].iloc[0] == "p0"
+ assert e_split.constant_data.shape[0] == 1
+
+
+# ────────────────────────────────────────────────────────────────────────────
+# Test that forecasting truncation uses split_event_category (not just LoT)
+# ────────────────────────────────────────────────────────────────────────────
+
+
+def test_forecasting_truncates_at_next_split_event_not_just_lot():
+ """
+ Verify that _generate_variable_splits_for_date truncates target events
+ at the next *split event* (config.split_event_category).
+
+ Scenario (split_event_category = "custom_split"):
+ Timeline for a single patient:
+ Day 0 - custom_split event (the split event that anchors the window)
+ Day 5 - lab measurement (before split date, input)
+ Day 10 - split date (curr_date)
+ Day 15 - lab measurement (target - should be kept)
+ Day 20 - next custom_split (next split event - target boundary)
+ Day 25 - lab measurement (target - should be EXCLUDED)
+ Day 30 - lot event (LoT - should NOT be the boundary)
+ Day 35 - lab measurement (target - should be EXCLUDED)
+
+ With the old code, the target would
+ include days 15 and 25 (cutting only at day 30 LoT).
+ With the fix (filtering by split_event_category), the target should
+ include only day 15 (cutting at day 20 custom_split).
+ """
+ from twinweaver.common.config import Config
+
+ cfg = Config()
+ cfg.split_event_category = "custom_split"
+ cfg.event_category_forecast = ["lab"]
+
+ base_date = pd.Timestamp("2020-01-01")
+
+ events = pd.DataFrame(
+ {
+ cfg.date_col: [base_date + pd.Timedelta(days=d) for d in [0, 5, 10, 15, 20, 25, 30, 35]],
+ cfg.event_category_col: [
+ "custom_split", # Day 0: split event
+ "lab", # Day 5: lab (input)
+ "lab", # Day 10: lab at split date (input)
+ "lab", # Day 15: lab (target, before next split event)
+ "custom_split", # Day 20: next split event
+ "lab", # Day 25: lab (target, after next split event - exclude)
+ "lot", # Day 30: lot event (should NOT be the boundary)
+ "lab", # Day 35: lab (target, after lot - exclude)
+ ],
+ cfg.event_name_col: [
+ "split_marker",
+ "hemoglobin",
+ "hemoglobin",
+ "hemoglobin",
+ "split_marker",
+ "hemoglobin",
+ "lot_marker",
+ "hemoglobin",
+ ],
+ cfg.event_value_col: [
+ "start",
+ "13.0",
+ "13.1",
+ "13.2",
+ "start",
+ "13.3",
+ "LoT Start",
+ "13.4",
+ ],
+ cfg.event_descriptive_name_col: [
+ "split marker",
+ "hemoglobin",
+ "hemoglobin",
+ "hemoglobin",
+ "split marker",
+ "hemoglobin",
+ "LoT",
+ "hemoglobin",
+ ],
+ cfg.source_col: ["events"] * 8,
+ cfg.meta_data_col: [pd.NA] * 8,
+ }
+ )
+
+ constant = pd.DataFrame(
+ {
+ cfg.patient_id_col: ["p_test"],
+ cfg.constant_split_col: ["train"],
+ }
+ )
+
+ patient_data = {"events": events, "constant": constant}
+ curr_date = base_date + pd.Timedelta(days=10)
+ lot_date = base_date # The split event that anchors this window
+
+ # Build a minimal all_possible_split_dates with hemoglobin valid at curr_date
+ all_possible_split_dates = pd.DataFrame(
+ {
+ cfg.date_col: [curr_date],
+ cfg.event_name_col: ["hemoglobin"],
+ cfg.event_category_col: ["lab"],
+ "lot_date": [lot_date],
+ }
+ )
+
+ # Create a DataManager stub - we only need dm.variable_types for the splitter
+ dm = DataManager.__new__(DataManager)
+ dm.config = cfg
+ dm.variable_types = {"hemoglobin": "numeric"}
+ dm.data_frames = {}
+ dm.all_patientids = ["p_test"]
+
+ splitter = DataSplitterForecasting(
+ config=cfg,
+ data_manager=dm,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ max_lookback_time_for_value=pd.Timedelta(days=90),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ sampling_strategy="uniform",
+ )
+
+ np.random.seed(42)
+
+ (date_splits, valid_sample_date, date_splits_meta, _) = splitter._generate_variable_splits_for_date(
+ curr_date=curr_date,
+ nr_samples=1,
+ override_variables_to_predict=["hemoglobin"],
+ events=events,
+ all_possible_split_dates=all_possible_split_dates,
+ apply_filtering=False,
+ override_split_dates=None,
+ patient_data=patient_data,
+ lot_date=lot_date,
+ )
+
+ assert valid_sample_date is True
+ assert len(date_splits) == 1
+
+ target = date_splits[0].target_events_after_split
+
+ # The target must only include lab events BEFORE the next custom_split (day 20).
+ # So only the measurement at day 15 should survive.
+ assert target.shape[0] == 1, (
+ f"Expected 1 target event (day 15 only), got {target.shape[0]}. "
+ f"Dates in target: {target[cfg.date_col].tolist()}"
+ )
+ assert target[cfg.date_col].iloc[0] == base_date + pd.Timedelta(days=15)
+
+ # Also verify input events include everything up to and including curr_date
+ input_events = date_splits[0].events_until_split
+ assert (input_events[cfg.date_col] <= curr_date).all()
+ assert input_events.shape[0] == 3 # Day 0, 5, 10
+
+
+def test_forecasting_truncation_allow_beyond_next_split_date():
+ """
+ Verify that when allow_forecasting_beyond_next_split_date=True,
+ target events are NOT truncated at the next split event.
+ """
+ from twinweaver.common.config import Config
+
+ cfg = Config()
+ cfg.split_event_category = "custom_split"
+ cfg.event_category_forecast = ["lab"]
+
+ base_date = pd.Timestamp("2020-01-01")
+
+ events = pd.DataFrame(
+ {
+ cfg.date_col: [base_date + pd.Timedelta(days=d) for d in [0, 5, 10, 15, 20, 25]],
+ cfg.event_category_col: [
+ "custom_split",
+ "lab",
+ "lab",
+ "lab",
+ "custom_split",
+ "lab",
+ ],
+ cfg.event_name_col: [
+ "split_marker",
+ "hemoglobin",
+ "hemoglobin",
+ "hemoglobin",
+ "split_marker",
+ "hemoglobin",
+ ],
+ cfg.event_value_col: ["start", "13.0", "13.1", "13.2", "start", "13.3"],
+ cfg.event_descriptive_name_col: [
+ "split marker",
+ "hemoglobin",
+ "hemoglobin",
+ "hemoglobin",
+ "split marker",
+ "hemoglobin",
+ ],
+ cfg.source_col: ["events"] * 6,
+ cfg.meta_data_col: [pd.NA] * 6,
+ }
+ )
+
+ constant = pd.DataFrame(
+ {
+ cfg.patient_id_col: ["p_test"],
+ cfg.constant_split_col: ["train"],
+ }
+ )
+
+ patient_data = {"events": events, "constant": constant}
+ curr_date = base_date + pd.Timedelta(days=10)
+ lot_date = base_date
+
+ all_possible_split_dates = pd.DataFrame(
+ {
+ cfg.date_col: [curr_date],
+ cfg.event_name_col: ["hemoglobin"],
+ cfg.event_category_col: ["lab"],
+ "lot_date": [lot_date],
+ }
+ )
+
+ dm = DataManager.__new__(DataManager)
+ dm.config = cfg
+ dm.variable_types = {"hemoglobin": "numeric"}
+ dm.data_frames = {}
+ dm.all_patientids = ["p_test"]
+
+ splitter = DataSplitterForecasting(
+ config=cfg,
+ data_manager=dm,
+ max_forecasted_trajectory_length=pd.Timedelta(days=90),
+ max_lookback_time_for_value=pd.Timedelta(days=90),
+ max_split_length_after_split_event=pd.Timedelta(days=90),
+ sampling_strategy="uniform",
+ allow_forecasting_beyond_next_split_date=True,
+ )
+
+ np.random.seed(42)
+
+ (date_splits, valid_sample_date, _, _) = splitter._generate_variable_splits_for_date(
+ curr_date=curr_date,
+ nr_samples=1,
+ override_variables_to_predict=["hemoglobin"],
+ events=events,
+ all_possible_split_dates=all_possible_split_dates,
+ apply_filtering=False,
+ override_split_dates=None,
+ patient_data=patient_data,
+ lot_date=lot_date,
+ )
+
+ assert valid_sample_date is True
+ assert len(date_splits) == 1
+
+ target = date_splits[0].target_events_after_split
+
+ # With allow_forecasting_beyond_next_split_date=True, no truncation at
+ # the next split event, so both day 15 and day 25 labs should be in target.
+ assert target.shape[0] == 2, (
+ f"Expected 2 target events (days 15 and 25), got {target.shape[0]}. "
+ f"Dates in target: {target[cfg.date_col].tolist()}"
+ )
+ expected_dates = [
+ base_date + pd.Timedelta(days=15),
+ base_date + pd.Timedelta(days=25),
+ ]
+ assert target[cfg.date_col].tolist() == expected_dates
diff --git a/twinweaver/__init__.py b/twinweaver/__init__.py
index cef9112..a457dc0 100644
--- a/twinweaver/__init__.py
+++ b/twinweaver/__init__.py
@@ -20,3 +20,8 @@
run_tte_probability_estimation,
run_tte_probability_estimation_notebook,
)
+from twinweaver.utils.forecasting_inference import (
+ parse_forecasting_results,
+ run_forecasting_inference,
+ run_forecasting_inference_notebook,
+)
diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py
index a103b49..a1d3f37 100644
--- a/twinweaver/common/config.py
+++ b/twinweaver/common/config.py
@@ -54,14 +54,7 @@ class Config:
source_col_default_value : str
Default value to assign to `source_col` if it is missing. Default: "events".
split_date_col : str
- Column name specifically used for dates related to line of therapy (LoT) events. Default: "lot_date".
- lot_event_name : str
- Column name for the name or identifier of the line of therapy (e.g., "First Line"). Default: "lot".
- event_value_lot_start : str
- Specific string value used in `event_value_col` to denote the start of a line of therapy. Default: "LoT Start".
- skip_future_lot_filtering : bool
- Flag indicating whether to skip filtering out future line of therapy events. Default: False.
- Useful in case you accidentially overlap LoTs which are actually the same, use with caution.
+ Column name used for dates related to data splitting events (e.g., line of therapy). Default: "split_date".
lot_concatenate_descriptive_and_value : bool
Flag indicating whether to concatenate the descriptive name and value for line of therapy events.
Default: False.
@@ -70,16 +63,16 @@ class Config:
`lot_concatenate_descriptive_and_value` is True. Default: " - ".
warning_for_splitters_patient_without_splits : bool
Whether to warn if a patient has no split events. Default: True.
- event_category_lot : str
- Specific string value used in `event_category_col` to identify 'line of therapy' events. Default: "lot".
event_category_death : str
+ Important for censoring in TTE tasks and for identifying death events in general, since in medicine
+ they are common and critical events.
Specific string value used in `event_category_col` to identify 'death' events. Default: "death".
- event_category_labs : str
- Specific string value used in `event_category_col` to identify 'lab result' events. Default: "lab".
event_category_forecast : list[str] | None
List of event categories to be considered for forecasting tasks. Default: None.
split_event_category : str | None
Event category used for data splitting (e.g., LoT). Default: None.
+ event_categories_to_exclude_from_input : list[str]
+ List of event categories to exclude from the input data (e.g., ["lot"]). Default: [].
source_genetic : str
Specific string value used in `source_col` to identify data originating from genetic testing.
Default: "genetic".
@@ -115,7 +108,7 @@ class Config:
Text inserted before the description of events for visits subsequent to the first one. Default: "\\n".
event_day_text : str
Template text used to introduce events on subsequent visit days, indicating the time elapsed since the previous
- visit. Default: " self.delta_time_unit : later, the patient visited and experienced the following: \\n".
+ visit. Default: " weeks later, the patient visited and experienced the following: \\n".
post_event_text : str
Text appended after listing all events for a specific visit day. Default: ".\\n".
forecasting_fval_prompt_start : str
@@ -232,7 +225,9 @@ class Config:
Default: None.
constant_birthdate_column_format : str
Format of the birthdate column, either "date" or "age". Default: "date".
- data_splitter_events_variables_category_mapping : dict | None
+ constant_birthdate_columns_silence_print : bool
+ Whether to silence print statements related to birthdate column processing. Default: False.
+ event_category_events_prediction_with_naming : dict | None
Mapping defining which event categories correspond to specific prediction types in DataSplitterEvents.
Keys are event categories (e.g., 'death', 'progression'), values are descriptive names for the target variable.
Default: None.
@@ -254,7 +249,7 @@ def __init__(self):
# different event types as well as how they should be written down (since based on categories),
# for example, based on GDT: { "death": "death", "progression": "next progression", "lot":
# "next line of therapy", "metastasis": "next metastasis"}
- self.data_splitter_events_variables_category_mapping = None
+ self.event_category_events_prediction_with_naming = None
# --- Import data parameters ---
self.date_cutoff = None # If set, only use data before this date (format: "YYYY-MM-DD"), censored after
@@ -281,9 +276,7 @@ def __init__(self):
self.event_meta_default_value = pd.NA # Default value for event meta data if not present
self.source_col_default_value: str = "events" # Default value for source column if not present
self.split_date_col: str = "split_date"
- self.lot_event_name: str = "lot"
- self.event_value_lot_start: str = "LoT Start"
- self.skip_future_lot_filtering: bool = False # Whether to skip filtering future LoT events, by default False.
+
self.lot_concatenate_descriptive_and_value: bool = (
False # If true, concatenate descriptive name and value for LoT events, by default False (only event_vale.)
)
@@ -296,10 +289,11 @@ def __init__(self):
True # Whether to warn if a patient has no LoT events in DataSplitterEvents
)
+ # List of event categories to exclude from the input data (e.g., ["lot"])
+ self.event_categories_to_exclude_from_input: list = []
+
# --- Specific Event Categories / Values / Sources ---
- self.event_category_lot: str = "lot"
self.event_category_death: str = "death"
- self.event_category_labs: str = "lab"
self.source_genetic: str = "genetic"
self.genetic_skip_text_value: str = "present"
@@ -435,6 +429,7 @@ def __init__(self):
] # Which columns to use from the constant data
self.constant_birthdate_column: str = None # If set, use this column for age calculation
self.constant_birthdate_column_format: str = "date" # Either "date" or "age"
+ self.constant_birthdate_columns_silence_print: bool = False # To silence print statements related to birthdate
# Used to backup event categories for event types if no variables are found
# e.g. progression -> death
@@ -467,11 +462,11 @@ def seed(self) -> int:
@seed.setter
def seed(self, value: int):
- """Set the seed value and update all random seeds (numpy, pandas, random)."""
+ """Set the seed value and update all random seeds (numpy and random)."""
self._seed = value
self._set_all_seeds(value)
def _set_all_seeds(self, seed: int):
- """Set seeds for numpy, pandas, and random modules."""
+ """Set seeds for numpy and random modules."""
np.random.seed(seed)
random.seed(seed)
diff --git a/twinweaver/common/converter_base.py b/twinweaver/common/converter_base.py
index a4e57d3..3f76c07 100644
--- a/twinweaver/common/converter_base.py
+++ b/twinweaver/common/converter_base.py
@@ -113,10 +113,11 @@ def __init__(self, config: Config) -> None:
self.config.event_category_and_name_replace_override
if self.config.event_category_and_name_replace_override is not None
else {
+ # Death specific overrides, using config constants for category and event name
self.config.event_category_death: { # Use config constant
- "death": { # Assuming 'death' is the event_name associated with this category
- "full_replacement_string": "death",
- "reverse_string_value": "death",
+ self.config.event_category_death: {
+ "full_replacement_string": self.config.event_category_death,
+ "reverse_string_value": self.config.event_category_death,
}
}
}
@@ -175,7 +176,12 @@ def _preprocess_constant_date(
constant[self.config.constant_birthdate_column] = (
constant[self.config.constant_birthdate_column].astype(int).astype(str) + " years"
)
- print(f"Using provided ages in {self.config.constant_birthdate_column} as age format")
+ if not self.config.constant_birthdate_columns_silence_print:
+ print(
+ f"Using provided ages in {self.config.constant_birthdate_column} as age format."
+ " To silence this print statement, set constant_birthdate_columns_silence_print to True"
+ " in the config."
+ )
else:
# Check if the column contains integer ages (not birthdates)
try:
@@ -184,7 +190,12 @@ def _preprocess_constant_date(
constant[self.config.constant_birthdate_column] = pd.to_datetime(
constant[self.config.constant_birthdate_column].astype(int).astype(str) + "-01-01"
)
- print(f"Converted integer ages in {self.config.constant_birthdate_column} to age format")
+ if not self.config.constant_birthdate_columns_silence_print:
+ print(
+ f"Converted integer ages in {self.config.constant_birthdate_column} to age format."
+ " To silence this print statement set constant_birthdate_columns_silence_print to True"
+ " in the config."
+ )
# Try converting the column to datetime if it is not already, if doesn't work then just keep it
elif not pd.api.types.is_datetime64_any_dtype(constant[self.config.constant_birthdate_column]):
@@ -287,6 +298,12 @@ def _preprocess_events(self, events: pd.DataFrame) -> pd.DataFrame:
round_and_strip, args=(self.decimal_precision,)
)
+ # Exclude specified event categories from the input if configured
+ if self.config.event_categories_to_exclude_from_input:
+ events = events[
+ ~events[self.config.event_category_col].isin(self.config.event_categories_to_exclude_from_input)
+ ]
+
return events
def _get_event_string(
@@ -1043,7 +1060,14 @@ def _get_all_most_recent_events_within_budget(self, events: pd.DataFrame, budget
#: return events
return events_final.reset_index(drop=True) # Reset index for clean output
- def _generate_summarized_row_string(self, input_event_data, combined_target_meta: dict) -> str:
+ def _generate_summarized_row_string(
+ self,
+ input_event_data,
+ combined_target_meta: dict,
+ lot_event_name: str = "lot",
+ event_value_lot_start: str = "LoT Start",
+ event_category_lot: str = "lot",
+ ) -> str:
"""
Creates a summary string containing the most recent genetic, LoT, and target variable values.
@@ -1119,18 +1143,16 @@ def _generate_summarized_row_string(self, input_event_data, combined_target_meta
#: add most recent LoT info using config constants
ret_prompt += self.config.forecasting_prompt_summarized_lot # Using config attribute
- lot_info = input_event_data[input_event_data[self.config.event_category_col] == self.config.event_category_lot]
+ lot_info = input_event_data[input_event_data[self.config.event_category_col] == event_category_lot]
# Ensure lot_info is sorted by date to correctly find the last one
lot_info = lot_info.sort_values(self.config.date_col)
# Create selections based on event name and event value using config constants
- if self.config.lot_event_name is not None and self.config.event_value_lot_start is not None:
- lot_selection_1 = lot_info[
- lot_info[self.config.event_name_col] == self.config.lot_event_name
- ] # Using config attribute
+ if lot_event_name is not None and event_value_lot_start is not None:
+ lot_selection_1 = lot_info[lot_info[self.config.event_name_col] == lot_event_name] # Using config attribute
lot_selection_2 = lot_info[
- lot_info[self.config.event_value_col] == self.config.event_value_lot_start
+ lot_info[self.config.event_value_col] == event_value_lot_start
] # Using config attribute
else:
# Just use all lot_info if no specific columns are defined
diff --git a/twinweaver/common/data_manager.py b/twinweaver/common/data_manager.py
index 05ee274..3c4bda1 100644
--- a/twinweaver/common/data_manager.py
+++ b/twinweaver/common/data_manager.py
@@ -22,11 +22,7 @@ class DataManager:
def __init__(
self,
config: Config, # Added config parameter
- train_split_min: float = 0.8,
- validation_split_max: float = 0.1,
- test_split_max: float = 0.1,
- max_val_test_nr_patients: int = 500,
- replace_special_symbols_override: list = None,
+ replace_special_symbols: list = None,
) -> None:
"""
Initializes the DataManager for a specific indication.
@@ -39,53 +35,19 @@ def __init__(
config : Config
A configuration object containing paths, column names, category names,
and other constants used throughout the data management process.
- train_split_min : float, optional
- The minimum proportion of patients to allocate to the training set.
- Defaults to 0.8. The actual number will be the remainder after
- allocating validation and test sets.
- validation_split_max : float, optional
- The maximum proportion of the total patients to allocate to the
- validation set. The actual number is capped by
- `max_val_test_nr_patients`. Defaults to 0.1.
- test_split_max : float, optional
- The maximum proportion of the total patients to allocate to the
- test set. The actual number is capped by `max_val_test_nr_patients`.
- Defaults to 0.1.
- max_val_test_nr_patients : int, optional
- The absolute maximum number of patients to include in the validation
- and test sets combined. Defaults to 500.
- replace_special_symbols_override : list, optional
+ replace_special_symbols : list, optional
A list of tuples to override the default special character replacements
in event descriptive names. Each tuple should be in the format
`(event_category, (string_to_replace, replacement_string))`. If None,
- default replacements specified in the method are used. Defaults to None.
+ an empty list is used (no replacements). Defaults to None.
"""
#: initialize the data manager
self.config = config # Store config object
- self.train_split = train_split_min
- self.validation_split = validation_split_max
- self.test_split = test_split_max
- self.max_val_test_nr_patients = max_val_test_nr_patients
self.variable_types = {} # event_name -> "numeric" / "categorical"
# Setup replacing of special symbol, format is event_category : (, )
- if replace_special_symbols_override is not None:
- self.replace_special_symbols = replace_special_symbols_override
- else:
- # Use config constants for event categories where available
- self.replace_special_symbols = [
- (self.config.event_category_labs, ("/", " per ")),
- (self.config.event_category_labs, (".", " ")),
- (
- "drug",
- ("/", " "),
- ), # "drug" category not explicitly in Config constants provided
- (
- self.config.event_category_lot,
- ("/", " "),
- ), # Use config for 'lot' category
- ]
+ self.replace_special_symbols = replace_special_symbols if replace_special_symbols is not None else []
# Setup indication
self.data_frames = None
@@ -184,16 +146,43 @@ def process_indication_data(
Performs initial processing on the loaded indication data.
Requires `load_indication_data` to be called first.
- This method converts the date columns (specified by `config.date_col`)
- in the 'events' DataFrame to datetime objects.
- It also checks for and removes rows with missing dates in these tables,
- logging an error if any are found, unless `skip_missing_dates` is True.
+ This method performs the following steps:
+
+ 1. Converts the date column (specified by `config.date_col`) in the
+ 'events' DataFrame to datetime objects.
+ 2. Checks for rows with missing dates: drops them if
+ `skip_missing_dates` is True, otherwise raises a ``ValueError``.
+ 3. Checks for missing event values and either drops them (if
+ `drop_missing_event_values` is True) or raises a ``ValueError``.
+ 4. Validates that there are no missing values in
+ ``event_descriptive_name``, ``event_name``, and ``event_category``
+ columns, raising a ``ValueError`` if any are found.
+ 5. Converts event values, descriptive names, event names, and event
+ categories to string type.
+ 6. Validates that all columns listed in
+ ``config.constant_columns_to_use`` are present in the
+ ``constant_description`` dataframe.
+ 7. Computes ``self.all_patientids`` as the intersection of patient IDs
+ appearing in both the constant and events tables.
+
+ Parameters
+ ----------
+ skip_missing_dates : bool, optional
+ If True, rows with missing dates are silently dropped instead of
+ raising an error. Defaults to False.
+ drop_missing_event_values : bool, optional
+ If True, rows with missing event values are dropped with a warning
+ instead of raising an error. Defaults to False.
Raises
------
ValueError
If `load_indication_data` has not been successfully called before
- this method, or if missing dates are found and `skip_missing_dates` is False.
+ this method, if missing dates are found and `skip_missing_dates`
+ is False, if missing event values are found and
+ `drop_missing_event_values` is False, if missing values are found
+ in event name columns, or if constant columns are missing from
+ the constant_description dataframe.
"""
# Check that we already have self.data_frames
@@ -240,11 +229,54 @@ def handle_missing_dates(df_key, missing_count, total_count, col_date):
"please fix the data or set drop_missing_event_values=True"
)
+ # Check for missing values in event_descriptive_name, event_name, and event_category columns
+ for col in [
+ self.config.event_descriptive_name_col,
+ self.config.event_name_col,
+ self.config.event_category_col,
+ ]:
+ missing_count = self.data_frames[events_table_key][col].isnull().sum()
+ if missing_count > 0:
+ total = len(self.data_frames[events_table_key])
+ raise ValueError(
+ f"Found {missing_count} out of {total} missing values in '{col}' column "
+ f"in events table - please fix the data before proceeding"
+ )
+
# Convert all event values to string
self.data_frames[events_table_key][self.config.event_value_col] = self.data_frames[events_table_key][
self.config.event_value_col
].astype(str)
+ # Convert event_descriptive_name, event_name, event_category to string as well to avoid issues later on
+ self.data_frames[events_table_key][self.config.event_descriptive_name_col] = self.data_frames[events_table_key][
+ self.config.event_descriptive_name_col
+ ].astype(str)
+ self.data_frames[events_table_key][self.config.event_name_col] = self.data_frames[events_table_key][
+ self.config.event_name_col
+ ].astype(str)
+ self.data_frames[events_table_key][self.config.event_category_col] = self.data_frames[events_table_key][
+ self.config.event_category_col
+ ].astype(str)
+
+ # Check that every column selected in config for constant is also in constant descriptive df
+ constant_desc_variables = set(self.data_frames["constant_description"]["variable"].unique())
+ missing_in_description = [
+ col for col in self.config.constant_columns_to_use if col not in constant_desc_variables
+ ]
+ if missing_in_description:
+ raise ValueError(
+ f"The following columns are listed in config.constant_columns_to_use but are not "
+ f"present in the constant_description 'variable' column: {missing_in_description}. "
+ f"Please add them to the constant_description dataframe or remove them from "
+ f"config.constant_columns_to_use."
+ )
+
+ # Add all unique patientids overlapping constant and events to self.all_patientids
+ constant_patientids = set(self.data_frames["constant"]["patientid"].unique())
+ event_patientids = set(self.data_frames["events"]["patientid"].unique())
+ self.all_patientids = list(constant_patientids.intersection(event_patientids))
+
logging.info("Data processed")
def setup_unique_mapping_of_events(self) -> None:
@@ -298,6 +330,7 @@ def setup_unique_mapping_of_events(self) -> None:
# Extract corresponding event_name and event_category
filtered_events = self.unique_events[event_desc_name_col]
non_unique_events = self.unique_events[filtered_events.isin(non_unique_events.index)].copy()
+ non_unique_descriptive_names = non_unique_events[event_desc_name_col].unique().tolist()
# create mapping for all non-unique descriptive names, and
# then add event_name to those, and apply across entire dataset
@@ -320,6 +353,15 @@ def setup_unique_mapping_of_events(self) -> None:
events_df[event_desc_name_col] = events_df[new_desc_name].fillna(events_df[event_desc_name_col])
self.data_frames[events_table_key] = self.data_frames[events_table_key].drop(columns=["new_descriptive_name"])
+ # Warn if there were any non-unique descriptive names
+ num_renamed = len(non_unique_events)
+ if num_renamed > 0:
+ logging.warning(
+ f"Found {len(non_unique_descriptive_names)} non-unique descriptive name(s) "
+ f"({num_renamed} event mappings total) that were disambiguated "
+ f"by appending the event_name: {non_unique_descriptive_names}"
+ )
+
#: first convert special symbols in event_descriptive_name to alternatives, using self.replace_special_symbols
for event_category, (
string_to_replace,
@@ -346,10 +388,20 @@ def setup_unique_mapping_of_events(self) -> None:
# Assert that all unique now
# Use config constant
- assert len(self.unique_events) == len(self.data_frames[events_table_key][event_desc_name_col].unique())
+ assert len(self.unique_events) == len(self.data_frames[events_table_key][event_desc_name_col].unique()), (
+ "Each descriptive name needs a unique mapping to an event name/category"
+ f". Found this many unique descriptive names: "
+ f"{len(self.data_frames[events_table_key][event_desc_name_col].unique())} "
+ f"but this many unique combinations of event name, descriptive name, and "
+ f"category: {len(self.unique_events)}. Please ensure that after processing, each descriptive name maps "
+ f"to exactly one event name within its category."
+ )
- def setup_dataset_splits(
+ def setup_hold_out_sets(
self,
+ validation_split: float,
+ test_split: float,
+ max_val_test_nr_patients: int = None,
) -> None:
"""
Assigns each patient to a data split (train, validation, or test).
@@ -358,8 +410,8 @@ def setup_dataset_splits(
The method determines the split assignment for each patient.
It retrieves all unique patient IDs from the 'constant' data table.
It calculates the number of patients for validation and test sets based on
- the `validation_split_max`, `test_split_max`, and `max_val_test_nr_patients`
- parameters set during initialization. The remaining patients are assigned to the training set #
+ the `validation_split`, `test_split`, and `max_val_test_nr_patients`
+ parameters. The remaining patients are assigned to the training set
(calculated as the remainder after validation and test sets are allocated). Patients are randomly
shuffled (with a fixed seed for reproducibility) before assignment.
@@ -368,6 +420,20 @@ def setup_dataset_splits(
`self.all_patientids`. Asserts are performed to ensure the mapping covers
all patients without overlap and that the split sizes match calculations.
+ Parameters
+ ----------
+ validation_split : float
+ The proportion of the total patients to allocate to the
+ validation set. The actual number is capped by
+ `max_val_test_nr_patients` if provided.
+ test_split : float
+ The proportion of the total patients to allocate to the
+ test set. The actual number is capped by
+ `max_val_test_nr_patients` if provided.
+ max_val_test_nr_patients : int, optional
+ The absolute maximum number of patients to include in each of the validation
+ and test sets (applied to each set individually). If None, no cap is applied. Defaults to None.
+
Raises
------
ValueError
@@ -398,13 +464,17 @@ def setup_dataset_splits(
#: get min(self.validation_split * num_patients, self.max_val_test_nr_patients)
validation_nr_patients = min(
- int(self.validation_split * len(all_patients)),
- self.max_val_test_nr_patients,
+ int(validation_split * len(all_patients)),
+ max_val_test_nr_patients
+ if max_val_test_nr_patients is not None
+ else int(validation_split * len(all_patients)),
)
#: then the same for test
- test_nr_patients = min(int(self.test_split * len(all_patients)), self.max_val_test_nr_patients)
-
+ test_nr_patients = min(
+ int(test_split * len(all_patients)),
+ max_val_test_nr_patients if max_val_test_nr_patients is not None else int(test_split * len(all_patients)),
+ )
#: randomly shuffle with seed and split into train/val/test, using df.sample
np.random.seed(self.config.seed)
all_patients = np.random.permutation(all_patients)
@@ -459,7 +529,7 @@ def setup_dataset_splits(
patient_id_col
].map(patient_to_split_mapping)
- def get_all_patientids_in_split(self, split: str) -> str:
+ def get_all_patientids_in_split(self, split: str) -> list:
"""
Retrieves all patient IDs belonging to a specific data split.
@@ -574,7 +644,7 @@ def get_patient_data(self, patientid: str) -> dict:
def infer_var_types(self):
"""
- Fills self.dm.variable_types for every candidate forecasting variable.
+ Fills self.variable_types for every candidate forecasting variable.
Classifies as "numeric" if at least `self.config.numeric_detect_min_fraction` of values
can be parsed as numeric, otherwise "categorical".
"""
diff --git a/twinweaver/instruction/converter_instruction.py b/twinweaver/instruction/converter_instruction.py
index bc17130..42c88b6 100644
--- a/twinweaver/instruction/converter_instruction.py
+++ b/twinweaver/instruction/converter_instruction.py
@@ -245,7 +245,7 @@ def forward_conversion(
self,
forecasting_splits: list[DataSplitterForecastingGroup],
event_splits: list[DataSplitterEventsGroup],
- override_mode_to_select_forecasting: str = None,
+ override_mode_to_select_forecasting: str = "forecasting",
) -> dict:
"""
Generates a multi-task instruction prompt and target answer from patient data splits.
@@ -268,7 +268,7 @@ def forward_conversion(
(containing patient history up to a split date and event outcome/censoring info).
override_mode_to_select_forecasting : str, optional
If provided, forces the selection mode for forecasting tasks ('forecasting',
- 'forecasting_qa', or 'both'). If None, the mode is chosen randomly. Defaults to None.
+ 'forecasting_qa', or 'both'). If None, the mode is chosen randomly. Defaults to "forecasting".
Returns
-------
@@ -287,6 +287,12 @@ def forward_conversion(
or if no tasks are generated.
"""
+ # Treat a missing (None) list of event or forecasting splits as an empty list for easier processing
+ if event_splits is None:
+ event_splits = []
+ if forecasting_splits is None:
+ forecasting_splits = []
+
#: make assertions that data has same split and lot date
all_lot_dates_events = [x.lot_date for x in event_splits]
all_lot_dates_forecasting = [x.lot_date for x in forecasting_splits]
@@ -808,6 +814,9 @@ def reverse_conversion(
standard_extraction = False
elif inference_override is False:
raise ValueError("Could not determine task type")
+ else:
+ # inference_override is True but task type is unknown – skip
+ continue
#: extract the relevant parts
if standard_extraction:
diff --git a/twinweaver/instruction/data_splitter.py b/twinweaver/instruction/data_splitter.py
index a992568..74208b9 100644
--- a/twinweaver/instruction/data_splitter.py
+++ b/twinweaver/instruction/data_splitter.py
@@ -9,9 +9,19 @@ class DataSplitter:
"""
Combines both data splitters into one interface for easier usage.
For more advanced use cases, the individual data splitters can still be used directly.
+
+ At least one of ``data_splitter_events`` or ``data_splitter_forecasting`` must be
+ provided. When only one splitter is supplied, the methods will return ``None`` /
+ empty results for the missing task type.
"""
- def __init__(self, data_splitter_events: DataSplitterEvents, data_splitter_forecasting: DataSplitterForecasting):
+ def __init__(
+ self,
+ data_splitter_events: DataSplitterEvents = None,
+ data_splitter_forecasting: DataSplitterForecasting = None,
+ ):
+ if data_splitter_events is None and data_splitter_forecasting is None:
+ raise ValueError("At least one of data_splitter_events or data_splitter_forecasting must be provided.")
self.data_splitter_events = data_splitter_events
self.data_splitter_forecasting = data_splitter_forecasting
@@ -64,33 +74,56 @@ def get_splits_from_patient_with_target(
-------
tuple
A tuple containing three elements:
- 1. forecasting_splits: list[DataSplitterForecastingGroup]
- List of generated forecasting split groups.
- 2. events_splits: list[DataSplitterEventsGroup]
- List of generated event prediction split groups, corresponding to the forecasting splits.
+ 1. forecasting_splits: list[DataSplitterForecastingGroup] or None
+ List of generated forecasting split groups, or None if no forecasting splitter is set.
+ 2. events_splits: list[DataSplitterEventsGroup] or None
+ List of generated event prediction split groups, or None if no events splitter is set.
3. reference_dates: pd.DataFrame
DataFrame containing the split dates and LoT dates used.
"""
- # Process forecasting splits
- forecasting_splits, reference_dates = self.data_splitter_forecasting.get_splits_from_patient(
- patient_data,
- nr_samples_per_split=forecasting_nr_samples_per_split,
- include_metadata=True,
- max_num_splits_per_split_event=max_num_splits_per_split_event,
- filter_outliers=forecasting_filter_outliers,
- override_categories_to_predict=forecasting_override_categories_to_predict,
- override_variables_to_predict=forecasting_override_variables_to_predict,
- override_split_dates=forecasting_override_split_dates,
- )
+ forecasting_splits = None
+ events_splits = None
+ reference_dates = None
- # Process event prediction splits
- events_splits = self.data_splitter_events.get_splits_from_patient(
- patient_data,
- reference_split_dates=reference_dates,
- max_nr_samples_per_split=events_max_nr_samples_per_split,
- override_category=events_override_category,
- override_observation_time_delta=events_override_observation_time_delta,
- )
+ # Process forecasting splits (if available)
+ if self.data_splitter_forecasting is not None:
+ forecasting_splits, reference_dates = self.data_splitter_forecasting.get_splits_from_patient(
+ patient_data,
+ nr_samples_per_split=forecasting_nr_samples_per_split,
+ include_metadata=True,
+ max_num_splits_per_split_event=max_num_splits_per_split_event,
+ filter_outliers=forecasting_filter_outliers,
+ override_categories_to_predict=forecasting_override_categories_to_predict,
+ override_variables_to_predict=forecasting_override_variables_to_predict,
+ override_split_dates=forecasting_override_split_dates,
+ )
+
+ # Process event prediction splits (if available)
+ if self.data_splitter_events is not None:
+ events_splits = self.data_splitter_events.get_splits_from_patient(
+ patient_data,
+ reference_split_dates=reference_dates,
+ max_nr_samples_per_split=events_max_nr_samples_per_split,
+ max_num_splits_per_split_event=max_num_splits_per_split_event,
+ override_category=events_override_category,
+ override_observation_time_delta=events_override_observation_time_delta,
+ )
+
+ # When only events splitter is used, extract reference_dates from events_splits
+ if reference_dates is None and events_splits is not None:
+ config = self.data_splitter_events.config
+ ref_rows = []
+ for group in events_splits:
+ if len(group) > 0:
+ opt = group[0]
+ ref_rows.append(
+ {
+ config.date_col: opt.split_date_included_in_input,
+ config.split_date_col: opt.lot_date,
+ }
+ )
+ if ref_rows:
+ reference_dates = pd.DataFrame(ref_rows)
#: return both, since we want to be able to still have the flexibility to use both splitters directly
return forecasting_splits, events_splits, reference_dates
@@ -135,36 +168,49 @@ def get_splits_from_patient_inference(
2. events_split: DataSplitterEventsOption or None
The generated event prediction option, or None if inference_type is 'forecasting'.
"""
+ # Resolve the config from whichever splitter is available
+ _config = (
+ self.data_splitter_events.config
+ if self.data_splitter_events is not None
+ else self.data_splitter_forecasting.config
+ )
+
# assume last date in events is the split date that we're interested in
- patient_data["events"] = patient_data["events"].sort_values(by=self.data_splitter_events.config.date_col)
- split_date = patient_data["events"][self.data_splitter_events.config.date_col].iloc[-1]
+ patient_data["events"] = patient_data["events"].sort_values(by=_config.date_col)
+ split_date = patient_data["events"][_config.date_col].iloc[-1]
#: generate forecasting split
+ forecast_split = None
if inference_type == "both" or inference_type == "forecasting":
- forecast_splits = self.data_splitter_forecasting.get_splits_from_patient(
- patient_data,
- nr_samples_per_split=1,
- filter_outliers=False, # Since no filtering needed, since no target exists
- override_split_dates=[split_date],
- override_variables_to_predict=forecasting_override_variables_to_predict,
- )
- # The first one is the only one
- forecast_split = forecast_splits[0][0]
- else:
- forecast_split = None
+ if self.data_splitter_forecasting is None:
+ if inference_type != "both":
+ raise ValueError("DataSplitterForecasting must be set to generate forecasting splits.")
+ else:
+ forecast_splits = self.data_splitter_forecasting.get_splits_from_patient(
+ patient_data,
+ nr_samples_per_split=1,
+ filter_outliers=False, # Since no filtering needed, since no target exists
+ override_split_dates=[split_date],
+ override_variables_to_predict=forecasting_override_variables_to_predict,
+ )
+ # The first one is the only one
+ forecast_split = forecast_splits[0][0]
#: generate event split
+ events_split = None
if inference_type == "both" or inference_type == "events":
- events_splits = self.data_splitter_events.get_splits_from_patient(
- patient_data,
- max_nr_samples_per_split=1,
- override_split_dates=[split_date],
- override_category=events_override_category,
- override_observation_time_delta=events_override_observation_time_delta,
- )
- # The first one is the only one
- events_split = events_splits[0][0]
- else:
- events_split = None
+ if self.data_splitter_events is None:
+ if inference_type != "both":
+ raise ValueError("DataSplitterEvents must be set to generate event prediction splits.")
+ else:
+ events_splits = self.data_splitter_events.get_splits_from_patient(
+ patient_data,
+ max_nr_samples_per_split=1,
+ override_split_dates=[split_date],
+ override_category=events_override_category,
+ override_observation_time_delta=events_override_observation_time_delta,
+ )
+ # The first one is the only one
+ events_split = events_splits[0][0]
return forecast_split, events_split
diff --git a/twinweaver/instruction/data_splitter_base.py b/twinweaver/instruction/data_splitter_base.py
index f489d6c..5e2bb4a 100644
--- a/twinweaver/instruction/data_splitter_base.py
+++ b/twinweaver/instruction/data_splitter_base.py
@@ -15,8 +15,6 @@ def __init__(
data_manager: DataManager,
config: Config,
max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=90),
- max_lookback_time_for_value: pd.Timedelta = pd.Timedelta(days=90),
- max_forecast_time_for_value: pd.Timedelta = pd.Timedelta(days=90),
):
"""
Constructor for the BaseDataSplitter class.
@@ -30,12 +28,6 @@ def __init__(
max_split_length_after_split_event: pd.Timedelta
the maximum number of days after a LoT event that we want to consider as
a starting point.
- max_lookback_time_for_value: pd.Timedelta
- the maximum number of days before a certain split date where we need to see
- the value of the target variable.
- max_forecast_time_for_value : pd.Timedelta
- the maximum number of days after a certain split date where we need to see
- the value of the target variable when filtering.
"""
assert config.split_event_category is not None, "config.split_event_category must be set (e.g. ['lab'])."
@@ -43,8 +35,6 @@ def __init__(
self.dm = data_manager
self.config = config
self.max_split_length_after_split_event = max_split_length_after_split_event
- self.max_lookback_time_for_value = max_lookback_time_for_value
- self.max_forecast_time_for_value = max_forecast_time_for_value
def _get_all_dates_within_range_of_split_event(
self,
diff --git a/twinweaver/instruction/data_splitter_events.py b/twinweaver/instruction/data_splitter_events.py
index 523238e..fce1823 100644
--- a/twinweaver/instruction/data_splitter_events.py
+++ b/twinweaver/instruction/data_splitter_events.py
@@ -88,12 +88,10 @@ def __init__(
self,
data_manager: DataManager,
config: Config,
- max_length_to_sample: pd.Timedelta = pd.Timedelta(weeks=104),
- min_length_to_sample: pd.Timedelta = pd.Timedelta(weeks=1),
+ max_length_to_sample: pd.Timedelta,
+ min_length_to_sample: pd.Timedelta,
unit_length_to_sample: str = "weeks",
- max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=90),
- max_lookback_time_for_value: pd.Timedelta = pd.Timedelta(days=90),
- max_forecast_time_for_value: pd.Timedelta = pd.Timedelta(days=90),
+ max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=0),
):
"""
Initialize the DataSplitterEvents class.
@@ -105,35 +103,32 @@ def __init__(
config : Config
Configuration object holding constants.
max_length_to_sample : pd.Timedelta
- The maximum number of weeks into the future to sample for event prediction.
+ The maximum length of time into the future to sample for event prediction.
+ Required, no default.
min_length_to_sample : pd.Timedelta
- The minimum number of weeks into the future to sample for event prediction.
+ The minimum length of time into the future to sample for event prediction.
+ Required, no default.
unit_length_to_sample : str
The unit of time for the length to sample (e.g. "weeks").
max_split_length_after_split_event : pd.Timedelta, optional
The maximum number of days after the split event (e.g. line of therapy) to consider for split points.
- max_lookback_time_for_value : pd.Timedelta, optional
- The maximum number of days to look back for a value (inherited but not directly used here).
- max_forecast_time_for_value : pd.Timedelta, optional
- The maximum number of days to forecast a value (inherited but not directly used here).
+ Defaults to 0 days.
"""
super().__init__(
data_manager,
config,
max_split_length_after_split_event,
- max_lookback_time_for_value,
- max_forecast_time_for_value,
)
self.max_length_to_sample = max_length_to_sample
self.min_length_to_sample = min_length_to_sample
self.unit_length_to_sample = unit_length_to_sample
- assert self.config.data_splitter_events_variables_category_mapping is not None, (
- "data_splitter_events_variables_category_mapping must be set in Config for DataSplitterEvents."
+ assert self.config.event_category_events_prediction_with_naming is not None, (
+ "event_category_events_prediction_with_naming must be set in Config for DataSplitterEvents."
"For example: { 'death': 'death', 'progression': 'next progression'}"
)
- self.manual_variables_category_mapping = self.config.data_splitter_events_variables_category_mapping
+ self.manual_variables_category_mapping = self.config.event_category_events_prediction_with_naming
def setup_variables(self):
"""
@@ -154,7 +149,7 @@ def setup_variables(self):
if len(self.manual_variables_category_mapping) == 0:
raise ValueError(
"No valid event categories found in the data for event prediction splitting. "
- "Check the data or adjust data_splitter_events_variables_category_mapping in Config."
+ "Check the data or adjust event_category_events_prediction_with_naming in Config."
)
def _sample_manual_variables(self, events_after_split: pd.DataFrame, override_category: str) -> tuple:
@@ -330,8 +325,8 @@ def get_splits_from_patient(
# Do some quick sanity checks
if self.config.warning_for_splitters_patient_without_splits:
- lot_events = events[events[self.config.event_category_col] == self.config.event_category_lot]
- if lot_events.shape[0] == 0:
+ split_events = events[events[self.config.event_category_col] == self.config.split_event_category]
+ if split_events.shape[0] == 0:
logging.warning(
"Patient "
+ str(patient_data["constant"][self.config.patient_id_col].iloc[0])
@@ -427,15 +422,18 @@ def get_splits_from_patient(
diagnosis_after_split = events_limited_after_split[
events_limited_after_split[self.config.event_category_col] == sampled_cateogry
]
- lot_after_split = events_limited_after_split[
- events_limited_after_split[self.config.event_category_col] == self.config.event_category_lot
+ next_split_date_after_split = events_limited_after_split[
+ events_limited_after_split[self.config.event_category_col] == self.config.split_event_category
]
death_after_split = events_limited_after_split[
events_limited_after_split[self.config.event_name_col] == self.config.event_category_death
]
- #: apply censoring using next_lot_date
- next_lot_date = lot_after_split[self.config.date_col].min() if len(lot_after_split) > 0 else None
+ #: apply censoring using next_split_date
+ if len(next_split_date_after_split) > 0:
+ next_split_date = next_split_date_after_split[self.config.date_col].min()
+ else:
+ next_split_date = None
next_death_date = death_after_split[self.config.date_col].min() if len(death_after_split) > 0 else None
#: determine whether occurred, censored & if so, which date
@@ -447,18 +445,21 @@ def get_splits_from_patient(
# Event occurred within end date
occurred = True
- # If an lot occurred first though, then we're censored
- if next_lot_date is not None and diagnosis_after_split[self.config.date_col].min() > next_lot_date:
- censored = "new_therapy_start"
+ # If a split occurred first though, then we're censored
+ if (
+ next_split_date is not None
+ and diagnosis_after_split[self.config.date_col].min() > next_split_date
+ ):
+ censored = "new_split_date_start"
occurred = False
else:
# Event did not occur
occurred = False
- if next_lot_date is not None:
- # If we were censored by the next lot date
- censored = "new_therapy_start"
+ if next_split_date is not None:
+ # If we were censored by the next split date
+ censored = "new_split_date_start"
elif next_death_date is not None:
# If death occurred then not censored, since this is the only time we
diff --git a/twinweaver/instruction/data_splitter_forecasting.py b/twinweaver/instruction/data_splitter_forecasting.py
index 0059092..6fbc279 100644
--- a/twinweaver/instruction/data_splitter_forecasting.py
+++ b/twinweaver/instruction/data_splitter_forecasting.py
@@ -15,9 +15,9 @@ class DataSplitterForecastingOption:
Attributes
----------
- events_until_split : list
+ events_until_split : pd.DataFrame
Events occurring until the split point.
- target_events_after_split : list
+ target_events_after_split : pd.DataFrame
Target events occurring after the split point.
constant_data : dict
Constant data related to the patient or context.
@@ -113,9 +113,9 @@ def __init__(
self,
config: Config,
data_manager: DataManager,
- max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=90),
- max_lookback_time_for_value: pd.Timedelta = pd.Timedelta(days=90),
- max_forecast_time_for_value: pd.Timedelta = pd.Timedelta(days=90),
+ max_forecasted_trajectory_length: pd.Timedelta,
+ max_split_length_after_split_event: pd.Timedelta = pd.Timedelta(days=0),
+ max_lookback_time_for_value: pd.Timedelta = pd.Timedelta(days=100000),
min_num_samples_for_statistics: int = 10,
sampling_score_to_use: str = "score_log_nrmse_n_samples",
min_nr_variable_seen_previously: int = 1,
@@ -123,9 +123,10 @@ def __init__(
list_of_valid_categories: list = None,
save_path_for_variable_stats: str = None,
min_nr_variables_to_sample: int = 1,
- max_nr_variables_to_sample: int = 3,
+ max_nr_variables_to_sample: int = 1,
filtering_strategy: str = "3-sigma",
sampling_strategy: str = "proportional",
+ allow_forecasting_beyond_next_split_date: bool = False,
):
"""
Initializes the DataSplitterForecasting instance.
@@ -136,17 +137,17 @@ def __init__(
Configuration object containing shared settings like column names.
data_manager : DataManager
Provides access to patient data for a single indication.
+ max_forecasted_trajectory_length : pd.Timedelta
+ Max time after a split date to look for future variable occurrences (target
+ data). Required, no default.
max_split_length_after_split_event : pd.Timedelta
- Max days after LoT start to consider for split dates. Defaults to 90.
+ Max time after LoT start to consider for split dates. Defaults to 0 days.
max_lookback_time_for_value : pd.Timedelta
- Max days before a split date to look for past variable occurrences.
- Defaults to 90.
- max_forecast_time_for_value : pd.Timedelta
- Max days after a split date to look for future variable occurrences (target
- data). Defaults to 90.
+ Max time before a split date to look for past variable occurrences.
+ Defaults to 100000 days (effectively no limit).
min_num_samples_for_statistics : int
Minimum total occurrences of a variable across the training set
- needed to calculate statistics. Defaults to 50.
+ needed to calculate statistics. Defaults to 10.
sampling_score_to_use : str
Column name in the computed statistics table used for weighted sampling of variables.
Defaults to 'score_log_nrmse_n_samples'.
@@ -164,23 +165,24 @@ def __init__(
None.
min_nr_variables_to_sample : int
The minimum number of distinct variables to attempt to sample for each
- forecasting task. Defaults to 3.
+ forecasting task. Defaults to 1.
max_nr_variables_to_sample : int
The maximum number of distinct variables to attempt to sample for each
- forecasting task. Defaults to 3.
+ forecasting task. Defaults to 1.
filtering_strategy : str
The strategy for handling outliers in target variable values ('3-sigma').
Defaults to '3-sigma'.
sampling_strategy : str
The strategy for sampling variables ('proportional' or 'uniform').
Defaults to 'proportional'.
+ allow_forecasting_beyond_next_split_date : bool
+ Flag indicating whether to allow forecasting of events that occur beyond the next split date
+ (e.g., next LoT event). Default: False.
"""
super().__init__(
data_manager,
config,
max_split_length_after_split_event,
- max_lookback_time_for_value,
- max_forecast_time_for_value,
)
assert self.config.event_category_forecast is not None or list_of_valid_categories is not None, (
@@ -188,7 +190,8 @@ def __init__(
"For example: ['lab']"
" Alternatively, provide list_of_valid_categories directly."
)
-
+ self.max_lookback_time_for_value = max_lookback_time_for_value
+ self.max_forecasted_trajectory_length = max_forecasted_trajectory_length
self.variable_stats = None
self.variable_type = {} # event_name -> "numeric" / "categorical"
self.min_num_samples_for_statistics = min_num_samples_for_statistics
@@ -204,9 +207,18 @@ def __init__(
self.max_nr_variables_to_sample = max_nr_variables_to_sample
self.filtering_strategy = filtering_strategy
self.sampling_strategy = sampling_strategy
+ self.allow_forecasting_beyond_next_split_date = allow_forecasting_beyond_next_split_date
self._filtering_methods = {"3-sigma": self._filter_3_sigma}
+ # Check that the forecasting and split event categories do not overlap, as this could cause data leakage
+ if self.config.event_category_forecast is not None and self.config.split_event_category is not None:
+ overlap = set(self.config.event_category_forecast).intersection(set(self.config.split_event_category))
+ if overlap:
+ raise ValueError(
+ f"Forecasting and split event categories overlap: {overlap}. This could cause data leakage."
+ )
+
def setup_statistics(self, train_patientids: list = None):
"""
Calculates baseline performance statistics for variables.
@@ -254,7 +266,7 @@ def setup_statistics(self, train_patientids: list = None):
temp_splits = self._get_all_dates_within_range_of_split_event(
temp_patient_data,
time_before_lot_start=self.max_lookback_time_for_value,
- max_split_length_after_split_event=self.max_forecast_time_for_value,
+ max_split_length_after_split_event=self.max_forecasted_trajectory_length,
)
temp_splits[self.config.patient_id_col] = patientid
temp_splits = temp_splits[[self.config.date_col, self.config.patient_id_col]]
@@ -652,7 +664,7 @@ def _get_all_possible_splits(
# Pre-compute date ranges for lookback and forecast
lookback_range = self.max_lookback_time_for_value
- forecast_range = self.max_forecast_time_for_value
+ forecast_range = self.max_forecasted_trajectory_length
# Initialize the return_splits list
return_splits = []
@@ -854,19 +866,21 @@ def _generate_variable_splits_for_date(
events_before_split = events[events[self.config.date_col] <= curr_date]
events_after_split = events[events[self.config.date_col] > curr_date]
events_after_split = events_after_split[
- events_after_split[self.config.date_col] <= curr_date + self.max_forecast_time_for_value
+ events_after_split[self.config.date_col] <= curr_date + self.max_forecasted_trajectory_length
]
events_after_split = events_after_split[
events_after_split[self.config.event_name_col].isin(sampled_variables)
]
- #: filter so that we do not overlap with next LoT, since that will invalidate the results
- lots = events[events[self.config.event_category_col] == self.config.event_category_lot]
- lots = lots[lots[self.config.date_col] > curr_date]
- lots = lots.sort_values(self.config.date_col)
- if lots.shape[0] > 0 and not self.config.skip_future_lot_filtering:
- date_of_next_lot = lots[self.config.date_col].iloc[0]
- events_after_split = events_after_split[events_after_split[self.config.date_col] < date_of_next_lot]
+ #: filter so that we do not overlap with next split event, since that will invalidate the results
+ next_split_events = events[events[self.config.event_category_col] == self.config.split_event_category]
+ next_split_events = next_split_events[next_split_events[self.config.date_col] > curr_date]
+ next_split_events = next_split_events.sort_values(self.config.date_col)
+ if next_split_events.shape[0] > 0 and not self.allow_forecasting_beyond_next_split_date:
+ date_of_next_split_event = next_split_events[self.config.date_col].iloc[0]
+ events_after_split = events_after_split[
+ events_after_split[self.config.date_col] < date_of_next_split_event
+ ]
#: if apply_filtering, apply 3-sigma filtering (only to target) and drop any bad rows
if apply_filtering:
@@ -973,8 +987,8 @@ def get_splits_from_patient(
# Do some quick sanity checks
if self.config.warning_for_splitters_patient_without_splits:
- lot_events = events[events[self.config.event_category_col] == self.config.event_category_lot]
- if lot_events.shape[0] == 0:
+ split_events = events[events[self.config.event_category_col] == self.config.split_event_category]
+ if split_events.shape[0] == 0:
logging.warning(
"Patient "
+ str(patient_data["constant"][self.config.patient_id_col].iloc[0])
diff --git a/twinweaver/utils/__init__.py b/twinweaver/utils/__init__.py
index 96d4156..ce81882 100644
--- a/twinweaver/utils/__init__.py
+++ b/twinweaver/utils/__init__.py
@@ -7,3 +7,8 @@
compute_length_normalized_probabilities,
run_tte_probability_estimation,
)
+from twinweaver.utils.forecasting_inference import (
+ parse_forecasting_results,
+ run_forecasting_inference,
+ run_forecasting_inference_notebook,
+)
diff --git a/twinweaver/utils/forecasting_inference.py b/twinweaver/utils/forecasting_inference.py
new file mode 100644
index 0000000..f49199d
--- /dev/null
+++ b/twinweaver/utils/forecasting_inference.py
@@ -0,0 +1,474 @@
+"""
+Forecasting inference helpers for vLLM-based text generation.
+
+Provides functions to generate forecasting predictions (future lab values,
+vitals, etc.) by sending instruction prompts to an OpenAI-compatible API
+(e.g. a local vLLM server) and parsing the model's text output back into
+structured DataFrames via :class:`~twinweaver.instruction.converter_instruction.ConverterInstruction`.
+
+The prompt construction is driven by
+:class:`~twinweaver.instruction.converter_instruction.ConverterInstruction`
+(specifically its ``forward_conversion_inference`` method), so the same
+code works for any dataset / prompt template.
+
+Typical usage
+-------------
+>>> import asyncio
+>>> from twinweaver.common.config import Config
+>>> from twinweaver.utils.forecasting_inference import (
+... run_forecasting_inference,
+... parse_forecasting_results,
+... )
+>>> config = Config()
+>>> # prompts_with_meta is a list of dicts with keys:
+>>> # "patientid", "instruction", "split_date"
+>>> results = run_forecasting_inference(
+...     prompts_with_meta,
+...     prediction_url="http://localhost:8000/v1/",
+...     prediction_model="my-model",
+... )
+>>> df = parse_forecasting_results(results, converter, dm)
+"""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+
+import pandas as pd
+from openai import (
+ APIConnectionError,
+ AsyncOpenAI,
+ AuthenticationError,
+ OpenAIError,
+ RateLimitError,
+)
+
+
+# ---------------------------------------------------------------------------
+# Type alias for a single prompt payload
+# ---------------------------------------------------------------------------
+# Each element is a dict with at least:
+# "patientid" : str – unique patient identifier
+# "instruction" : str – full instruction text (from converter)
+# "split_date" : datetime – the reference date used for reverse conversion
+# Additional keys are preserved and passed through to the results.
+PromptPayload = dict[str, Any]
+
+
+# ---------------------------------------------------------------------------
+# Internal helpers
+# ---------------------------------------------------------------------------
+
+
+def _build_chat_messages(
+    instruction: str,
+    *,
+    system_prompt: str | None = None,
+) -> list[dict[str, str]]:
+    """Assemble the chat-completion message list for a single request.
+
+    Parameters
+    ----------
+    instruction : str
+        Text sent as the ``user`` message.
+    system_prompt : str or None
+        When not ``None`` (an empty string still counts), it is placed first
+        as a ``system`` message.
+
+    Returns
+    -------
+    list[dict[str, str]]
+        The optional system message followed by the user message.
+    """
+    user_message = {"role": "user", "content": instruction}
+    if system_prompt is None:
+        return [user_message]
+    return [{"role": "system", "content": system_prompt}, user_message]
+
+
+# ---------------------------------------------------------------------------
+# Async LLM call (single patient)
+# ---------------------------------------------------------------------------
+
+
+async def _call_llm_for_generation_async(
+    client: AsyncOpenAI,
+    model_to_use: str,
+    payload: PromptPayload,
+    semaphore: asyncio.Semaphore,
+    *,
+    system_prompt: str | None = None,
+    max_new_tokens: int = 512,
+    temperature: float = 0.7,
+    top_p: float = 0.9,
+    n_samples: int = 1,
+) -> dict | None:
+    """Request forecasting completions for one prompt payload.
+
+    Parameters
+    ----------
+    client : AsyncOpenAI
+        Client used to call ``chat.completions.create``.
+    model_to_use : str
+        Value forwarded as the ``model`` argument of the request.
+    payload : PromptPayload
+        Must contain ``"instruction"``; ``"patientid"`` is read (if present)
+        for error messages. All other keys are passed through to the result.
+    semaphore : asyncio.Semaphore
+        Held for the duration of the request to limit concurrency.
+    system_prompt : str or None, optional
+        Optional system prompt placed before the instruction.
+    max_new_tokens : int
+        Forwarded as ``max_tokens``.
+    temperature : float
+        Forwarded as ``temperature``.
+    top_p : float
+        Forwarded as ``top_p``.
+    n_samples : int
+        Forwarded as ``n`` (number of completions requested for the prompt).
+
+    Returns
+    -------
+    dict or None
+        A copy of *payload* without the ``"instruction"`` key, plus
+        ``"generated_texts"``: the message content of every returned choice.
+        ``None`` when an ``AuthenticationError``, ``RateLimitError`` or any
+        other ``OpenAIError`` was caught (the error is printed).
+
+    Raises
+    ------
+    APIConnectionError
+        Connection-level failures are printed and then re-raised rather than
+        converted into ``None``.
+    """
+    chat_messages = _build_chat_messages(payload["instruction"], system_prompt=system_prompt)
+
+    async with semaphore:
+        try:
+            completion = await client.chat.completions.create(
+                model=model_to_use,
+                messages=chat_messages,
+                max_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                n=n_samples,
+            )
+        except AuthenticationError as exc:
+            print(f"Authentication error for patient {payload.get('patientid')}: {exc}")
+        except RateLimitError as exc:
+            print(f"Rate limit exceeded for patient {payload.get('patientid')}: {exc}")
+        except APIConnectionError as exc:
+            # Connection problems abort the caller instead of yielding None.
+            print(f"Network error for patient {payload.get('patientid')}: {exc}")
+            raise
+        except OpenAIError as exc:
+            print(f"An OpenAI error occurred for patient {payload.get('patientid')}: {exc}")
+        else:
+            # Pass every metadata key through, dropping only the (large) prompt text.
+            passthrough = {key: value for key, value in payload.items() if key != "instruction"}
+            passthrough["generated_texts"] = [choice.message.content for choice in completion.choices]
+            return passthrough
+
+    return None
+
+
+# ---------------------------------------------------------------------------
+# Async orchestrator (all patients)
+# ---------------------------------------------------------------------------
+
+
+async def _run_forecasting_inference_async(
+    prompts_with_meta: list[PromptPayload],
+    *,
+    prediction_url: str = "http://0.0.0.0:8000/v1/",
+    prediction_model: str = "default-model",
+    max_concurrent_requests: int = 40,
+    system_prompt: str | None = None,
+    max_new_tokens: int = 512,
+    temperature: float = 0.7,
+    top_p: float = 0.9,
+    n_samples: int = 1,
+    api_key: str = "EMPTY",
+    timeout: float = 600.0,
+) -> list[dict | None]:
+    """Async implementation of :func:`run_forecasting_inference`.
+
+    Opens one ``AsyncOpenAI`` client, issues one request per payload (at most
+    *max_concurrent_requests* at a time) and gathers the results in input
+    order. The client is closed before returning, and any request still
+    pending when an error propagates is cancelled.
+
+    Parameters
+    ----------
+    prompts_with_meta : list[PromptPayload]
+        Payload dicts; each needs at least an ``"instruction"`` key.
+    prediction_url : str
+        Passed to the client as ``base_url``.
+    prediction_model : str
+        Model identifier sent with every request.
+    max_concurrent_requests : int
+        Size of the semaphore limiting in-flight requests; must be >= 1.
+    system_prompt : str or None
+        Optional system prompt used for every request.
+    max_new_tokens, temperature, top_p, n_samples
+        Generation settings forwarded to every request.
+    api_key : str
+        Passed to the client as ``api_key``.
+    timeout : float
+        Passed to the client as ``timeout``.
+
+    Returns
+    -------
+    list[dict or None]
+        One entry per payload, in the same order as *prompts_with_meta*.
+
+    Raises
+    ------
+    ValueError
+        If *max_concurrent_requests* is smaller than 1.
+    Exception
+        Anything raised by a single request (e.g. a re-raised
+        ``APIConnectionError``) propagates to the caller.
+    """
+    if max_concurrent_requests < 1:
+        # asyncio.Semaphore(0) would make every request wait forever.
+        raise ValueError("max_concurrent_requests must be at least 1.")
+
+    semaphore = asyncio.Semaphore(max_concurrent_requests)
+
+    # The context manager closes the client's HTTP connections on exit, so
+    # they are not left open after the event loop has finished.
+    async with AsyncOpenAI(
+        base_url=prediction_url,
+        api_key=api_key,
+        timeout=timeout,
+    ) as client:
+        tasks = [
+            asyncio.ensure_future(
+                _call_llm_for_generation_async(
+                    client,
+                    prediction_model,
+                    payload,
+                    semaphore,
+                    system_prompt=system_prompt,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    n_samples=n_samples,
+                )
+            )
+            for payload in prompts_with_meta
+        ]
+        try:
+            return await asyncio.gather(*tasks)
+        finally:
+            # If one request failed, stop the remaining ones instead of letting
+            # them run on in the background; cancel() is a no-op on finished tasks.
+            for task in tasks:
+                task.cancel()
+
+
+def run_forecasting_inference(
+    prompts_with_meta: list[PromptPayload],
+    *,
+    prediction_url: str = "http://0.0.0.0:8000/v1/",
+    prediction_model: str = "default-model",
+    max_concurrent_requests: int = 40,
+    system_prompt: str | None = None,
+    max_new_tokens: int = 512,
+    temperature: float = 0.7,
+    top_p: float = 0.9,
+    n_samples: int = 1,
+    api_key: str = "EMPTY",
+    timeout: float = 600.0,
+) -> list[dict | None]:
+    """Run forecasting generation for every prompt against an OpenAI-compatible API.
+
+    Synchronous entry point: the async implementation is driven with
+    ``asyncio.run``, so this is meant for plain scripts. ``asyncio.run``
+    cannot be used while an event loop is already running in the same thread.
+
+    Parameters
+    ----------
+    prompts_with_meta : list[PromptPayload]
+        One dict per request, each holding **at least**:
+
+        * ``"patientid"`` – unique identifier (str)
+        * ``"instruction"`` – full instruction text produced by
+          ``ConverterInstruction.forward_conversion_inference``
+        * ``"split_date"`` – reference date (datetime) of the split, needed
+          later for ``reverse_conversion``
+
+        Extra keys are carried over to the results untouched.
+
+    prediction_url : str
+        Base URL of the OpenAI-compatible inference server.
+    prediction_model : str
+        Name / path of the model served by the inference server.
+    max_concurrent_requests : int
+        Upper bound on simultaneously running API requests.
+    system_prompt : str or None
+        Optional system prompt.
+    max_new_tokens : int
+        Token budget for each generated completion.
+    temperature : float
+        Sampling temperature (0 = greedy).
+    top_p : float
+        Nucleus-sampling probability mass.
+    n_samples : int
+        How many independent completions to request per prompt. Useful for
+        trajectory aggregation (see
+        :meth:`ConverterInstruction.aggregate_multiple_responses`).
+    api_key : str
+        API key (``"EMPTY"`` for local vLLM servers).
+    timeout : float
+        Per-request timeout in seconds.
+
+    Returns
+    -------
+    list[dict or None]
+        One entry per input payload. A dict holds every payload key except
+        ``"instruction"`` plus ``"generated_texts"``, the list of generated
+        completion strings. ``None`` marks a request whose API call failed.
+    """
+    inference_coroutine = _run_forecasting_inference_async(
+        prompts_with_meta,
+        prediction_url=prediction_url,
+        prediction_model=prediction_model,
+        max_concurrent_requests=max_concurrent_requests,
+        system_prompt=system_prompt,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        n_samples=n_samples,
+        api_key=api_key,
+        timeout=timeout,
+    )
+    return asyncio.run(inference_coroutine)
+
+
+async def run_forecasting_inference_notebook(
+    prompts_with_meta: list[PromptPayload],
+    *,
+    prediction_url: str = "http://0.0.0.0:8000/v1/",
+    prediction_model: str = "default-model",
+    max_concurrent_requests: int = 40,
+    system_prompt: str | None = None,
+    max_new_tokens: int = 512,
+    temperature: float = 0.7,
+    top_p: float = 0.9,
+    n_samples: int = 1,
+    api_key: str = "EMPTY",
+    timeout: float = 600.0,
+) -> list[dict | None]:
+    """Generate forecasting predictions – async variant for Jupyter notebooks.
+
+    Coroutine function with the same arguments as :func:`run_forecasting_inference`;
+    ``await`` it in a notebook cell, where an event loop is already running.
+
+    Returns
+    -------
+    list[dict or None]
+        Result of awaiting the call, as returned by :func:`run_forecasting_inference`.
+    """
+    return await _run_forecasting_inference_async(
+        prompts_with_meta,
+        prediction_url=prediction_url,
+        prediction_model=prediction_model,
+        max_concurrent_requests=max_concurrent_requests,
+        system_prompt=system_prompt,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        n_samples=n_samples,
+        api_key=api_key,
+        timeout=timeout,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Post-processing: parse & reverse-convert generated text
+# ---------------------------------------------------------------------------
+
+
+def parse_forecasting_results(
+    raw_results: list[dict | None],
+    converter: Any,
+    data_manager: Any,
+    *,
+    drop_failures: bool = False,
+    aggregate_samples: bool = True,
+) -> pd.DataFrame:
+    """Turn raw generated texts into one long-format DataFrame of predictions.
+
+    Processing per result entry (one per patient):
+
+    1. Every generated text is passed to ``converter.reverse_conversion``; each
+       non-empty DataFrame found under a task's ``"result"`` key is kept and
+       tagged with ``"patientid"``, ``"sample_idx"`` and ``"task_type"``.
+    2. If ``aggregate_samples`` is ``True`` and the entry has several generated
+       texts and several kept DataFrames, the samples are combined per task type
+       with ``converter.aggregate_multiple_responses``; if that call fails, the
+       first sample of the task type is used instead.
+    3. All per-patient blocks are concatenated into a single DataFrame.
+
+    A sample whose reverse conversion raises is skipped after printing a warning.
+
+    Parameters
+    ----------
+    raw_results : list[dict or None]
+        Output of :func:`run_forecasting_inference`. Each dict must provide
+        ``"patientid"``, ``"split_date"`` and ``"generated_texts"``.
+    converter : ConverterInstruction
+        Converter used to build the instruction prompts; must expose
+        ``reverse_conversion`` and ``aggregate_multiple_responses``.
+    data_manager : DataManager
+        Data manager instance, forwarded to ``reverse_conversion``.
+    drop_failures : bool
+        If *True*, ``None`` entries (API failures) are dropped silently.
+        If *False*, any ``None`` entry triggers a ``ValueError``.
+    aggregate_samples : bool
+        If *True* (default), multiple samples of a patient are aggregated via
+        ``converter.aggregate_multiple_responses``. If *False*, every sample is
+        kept as its own block of rows, identified by ``"sample_idx"``.
+
+    Returns
+    -------
+    pd.DataFrame
+        Long-format predictions with ``"patientid"`` and ``"task_type"`` columns
+        (plus ``"sample_idx"`` for rows that were not aggregated). Empty if nothing parsed.
+
+    Raises
+    ------
+    ValueError
+        If *drop_failures* is *False* and any result is ``None``, or if no
+        non-``None`` result is left to process.
+    """
+    if drop_failures:
+        valid = [entry for entry in raw_results if entry is not None]
+    elif any(entry is None for entry in raw_results):
+        raise ValueError("Some results are None (API failures). Set drop_failures=True to silently ignore them.")
+    else:
+        valid = raw_results  # type: ignore[assignment]
+
+    if not valid:
+        raise ValueError("No valid results to process.")
+
+    all_rows: list[pd.DataFrame] = []
+
+    for entry in valid:
+        patientid = entry["patientid"]
+        split_date = entry["split_date"]
+        generated_texts = entry["generated_texts"]
+
+        # Step 1: reverse-convert every sample and keep the non-empty task frames.
+        frames: list[pd.DataFrame] = []
+        for sample_idx, text in enumerate(generated_texts):
+            try:
+                parsed_tasks = converter.reverse_conversion(
+                    text,
+                    data_manager,
+                    split_date,
+                    patientid=patientid,
+                    inference_override=True,
+                )
+            except Exception as exc:
+                print(f"Warning: reverse_conversion failed for patient {patientid} sample {sample_idx}: {exc}")
+                continue
+
+            for task in parsed_tasks:
+                task_type = task.get("task_type", "")
+                payload = task.get("result")
+                if not isinstance(payload, pd.DataFrame) or payload.empty:
+                    continue
+                tagged = payload.copy()
+                tagged["patientid"] = patientid
+                tagged["sample_idx"] = sample_idx
+                tagged["task_type"] = task_type
+                frames.append(tagged)
+
+        if not frames:
+            continue
+
+        # Step 2a: nothing to aggregate -> keep the tagged frames as they are.
+        if not (aggregate_samples and len(generated_texts) > 1 and len(frames) > 1):
+            all_rows.extend(frames)
+            continue
+
+        # Step 2b: aggregate the samples separately for each task type.
+        stacked = pd.concat(frames, ignore_index=True)
+        patient_parts = []
+        for tt in stacked["task_type"].unique():
+            of_type = stacked[stacked["task_type"] == tt]
+            # One DataFrame per sample, without the bookkeeping columns.
+            per_sample = [
+                of_type[of_type["sample_idx"] == si].drop(columns=["sample_idx", "task_type"], errors="ignore")
+                for si in of_type["sample_idx"].unique()
+            ]
+            try:
+                part, _meta = converter.aggregate_multiple_responses(per_sample)
+                part["task_type"] = tt
+                part["patientid"] = patientid
+            except Exception as exc:
+                print(f"Warning: aggregation failed for patient {patientid}: {exc}")
+                # Fallback: keep only the first sample of this task type.
+                part = per_sample[0].copy()
+                part["task_type"] = tt
+                part["patientid"] = patientid
+            patient_parts.append(part)
+
+        if patient_parts:
+            all_rows.append(pd.concat(patient_parts, ignore_index=True))
+
+    if not all_rows:
+        return pd.DataFrame()
+    return pd.concat(all_rows, ignore_index=True)
diff --git a/twinweaver/utils/tte_inference.py b/twinweaver/utils/tte_inference.py
index 003a651..7ddcd47 100644
--- a/twinweaver/utils/tte_inference.py
+++ b/twinweaver/utils/tte_inference.py
@@ -645,4 +645,14 @@ def _hard_prediction(row):
df["probability_no_occurrence"] = df[f"softmax_{LABEL_NOT_OCCURRED}"]
df["probability_censored"] = df[f"softmax_{LABEL_CENSORED}"]
+ # 5. Add in renormalized probabilities that exclude the censored class
+ # (for some analyses it may be useful to look at the relative probabilities of occurrence vs no occurrence,
+ # ignoring censoring).
+ df["probability_occurrence_renormalized"] = df["probability_occurrence"] / (
+ df["probability_occurrence"] + df["probability_no_occurrence"]
+ )
+ df["probability_no_occurrence_renormalized"] = df["probability_no_occurrence"] / (
+ df["probability_occurrence"] + df["probability_no_occurrence"]
+ )
+
return df