Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/01_data_preparation_for_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv_dev_gpu",
"display_name": ".venv_dev",
"language": "python",
"name": "python3"
},
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ build-backend = "setuptools.build_meta" # Standard entry point for setuptools bu

[project]
name = "twinweaver"
version = "0.3.1"
version = "0.3.3"
description = "Converting longitudinal patient data into text for LLM-based event prediction and forecasting."

# --- NEW/UPDATED FIELDS ---
Expand Down
58 changes: 58 additions & 0 deletions tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,64 @@ def test_event_categories_to_exclude_empty_list(setup_components):
assert "drug pemetrexed is administered" in instruction


def test_event_categories_excluded_from_input_but_kept_in_targets(mock_config, sample_data):
"""Exclusion config must affect only input history, not forecasting targets."""
df_events, df_constant, df_constant_desc = sample_data

mock_config.split_event_category = "lot"
# Forecast drug events so target/answer is expected to include drug content
mock_config.event_category_forecast = ["drug"]
mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"}
mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
mock_config.constant_birthdate_column = "birthyear"
mock_config.event_categories_to_exclude_from_input = ["drug"]

dm = DataManager(config=mock_config)
dm.load_indication_data(df_events, df_constant, df_constant_desc)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
dm.infer_var_types()

splitter_forecast = DataSplitterForecasting(
data_manager=dm,
config=mock_config,
max_forecasted_trajectory_length=pd.Timedelta(days=90),
max_split_length_after_split_event=pd.Timedelta(days=90),
)
splitter_forecast.setup_statistics()

data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast)

converter = ConverterInstruction(
nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=splitter_forecast.variable_stats
)

patient_data = dm.get_patient_data("p0")
f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data)

assert e_splits is None
assert f_splits is not None and len(f_splits) > 0

result = converter.forward_conversion(
forecasting_splits=f_splits[0],
event_splits=None,
override_mode_to_select_forecasting="forecasting",
)

instruction = result["instruction"]
answer = result["answer"]

# Input context should not include excluded drug history lines
input_context = instruction.split("Task 1 is", 1)[0]
assert "drug pemetrexed is administered" not in input_context

# Target/answer must still contain the forecasted drug values
assert "Task 1 is forecasting:" in answer
assert "drug " in answer.lower()
assert "is administered" in answer.lower()


def test_forward_conversion_inference(setup_components):
"""Test conversion for inference (no target string)."""
dm, data_splitter, converter = setup_components
Expand Down
12 changes: 6 additions & 6 deletions twinweaver/common/converter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def _get_constant_string(self, constant: pd.DataFrame, constant_description: pd.

return constant_string

def _preprocess_events(self, events: pd.DataFrame) -> pd.DataFrame:
def _preprocess_events(self, events: pd.DataFrame, is_input: bool = True) -> pd.DataFrame:
"""
Performs initial preprocessing on the time-series event data.

Expand Down Expand Up @@ -299,11 +299,11 @@ def _preprocess_events(self, events: pd.DataFrame) -> pd.DataFrame:
)

# Exclude specified event categories from the input if configured
if self.config.event_categories_to_exclude_from_input:
events = events[
~events[self.config.event_category_col].isin(self.config.event_categories_to_exclude_from_input)
]

if is_input:
if self.config.event_categories_to_exclude_from_input:
events = events[
~events[self.config.event_category_col].isin(self.config.event_categories_to_exclude_from_input)
]
return events

def _get_event_string(
Expand Down
2 changes: 1 addition & 1 deletion twinweaver/instruction/converter_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _generate_target_string(self, patient_split: DataSplitterForecastingOption)

#: preprocess:
target_data = patient_split.target_events_after_split.copy()
target_cleaned = self._preprocess_events(target_data.copy())
target_cleaned = self._preprocess_events(target_data.copy(), is_input=False)

#: get delta between split and first target
target_first_day = target_cleaned[self.config.date_col].min()
Expand Down
Loading