diff --git a/examples/01_data_preparation_for_training.ipynb b/examples/01_data_preparation_for_training.ipynb
index d42c8f7..e74876d 100644
--- a/examples/01_data_preparation_for_training.ipynb
+++ b/examples/01_data_preparation_for_training.ipynb
@@ -423,7 +423,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": ".venv_dev_gpu",
+   "display_name": ".venv_dev",
    "language": "python",
    "name": "python3"
   },
diff --git a/pyproject.toml b/pyproject.toml
index 00aa66f..3f929f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ build-backend = "setuptools.build_meta" # Standard entry point for setuptools bu
 
 [project]
 name = "twinweaver"
-version = "0.3.1"
+version = "0.3.3"
 description = "Converting longitudinal patient data into text for LLM-based event prediction and forecasting."
 
 # --- NEW/UPDATED FIELDS ---
diff --git a/tests/test_converter.py b/tests/test_converter.py
index d5d713a..a19d1c0 100644
--- a/tests/test_converter.py
+++ b/tests/test_converter.py
@@ -248,6 +248,64 @@ def test_event_categories_to_exclude_empty_list(setup_components):
     assert "drug pemetrexed is administered" in instruction
 
 
+def test_event_categories_excluded_from_input_but_kept_in_targets(mock_config, sample_data):
+    """Exclusion config must affect only input history, not forecasting targets."""
+    df_events, df_constant, df_constant_desc = sample_data
+
+    mock_config.split_event_category = "lot"
+    # Forecast drug events so target/answer is expected to include drug content
+    mock_config.event_category_forecast = ["drug"]
+    mock_config.event_category_events_prediction_with_naming = {"death": "death", "progression": "next progression"}
+    mock_config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
+    mock_config.constant_birthdate_column = "birthyear"
+    mock_config.event_categories_to_exclude_from_input = ["drug"]
+
+    dm = DataManager(config=mock_config)
+    dm.load_indication_data(df_events, df_constant, df_constant_desc)
+    dm.process_indication_data()
+    dm.setup_unique_mapping_of_events()
+    dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
+    dm.infer_var_types()
+
+    splitter_forecast = DataSplitterForecasting(
+        data_manager=dm,
+        config=mock_config,
+        max_forecasted_trajectory_length=pd.Timedelta(days=90),
+        max_split_length_after_split_event=pd.Timedelta(days=90),
+    )
+    splitter_forecast.setup_statistics()
+
+    data_splitter = DataSplitter(data_splitter_forecasting=splitter_forecast)
+
+    converter = ConverterInstruction(
+        nr_tokens_budget_total=4096, config=mock_config, dm=dm, variable_stats=splitter_forecast.variable_stats
+    )
+
+    patient_data = dm.get_patient_data("p0")
+    f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data)
+
+    assert e_splits is None
+    assert f_splits is not None and len(f_splits) > 0
+
+    result = converter.forward_conversion(
+        forecasting_splits=f_splits[0],
+        event_splits=None,
+        override_mode_to_select_forecasting="forecasting",
+    )
+
+    instruction = result["instruction"]
+    answer = result["answer"]
+
+    # Input context should not include excluded drug history lines
+    input_context = instruction.split("Task 1 is", 1)[0]
+    assert "drug pemetrexed is administered" not in input_context
+
+    # Target/answer must still contain the forecasted drug values
+    assert "Task 1 is forecasting:" in answer
+    assert "drug " in answer.lower()
+    assert "is administered" in answer.lower()
+
+
 def test_forward_conversion_inference(setup_components):
     """Test conversion for inference (no target string)."""
     dm, data_splitter, converter = setup_components
diff --git a/twinweaver/common/converter_base.py b/twinweaver/common/converter_base.py
index 3f76c07..35dd983 100644
--- a/twinweaver/common/converter_base.py
+++ b/twinweaver/common/converter_base.py
@@ -264,7 +264,7 @@ def _get_constant_string(self, constant: pd.DataFrame, constant_description: pd.
 
         return constant_string
 
-    def _preprocess_events(self, events: pd.DataFrame) -> pd.DataFrame:
+    def _preprocess_events(self, events: pd.DataFrame, is_input: bool = True) -> pd.DataFrame:
         """
         Performs initial preprocessing on the time-series event data.
 
@@ -299,7 +299,8 @@ def _preprocess_events(self, events: pd.DataFrame) -> pd.DataFrame:
         )
 
         # Exclude specified event categories from the input if configured
-        if self.config.event_categories_to_exclude_from_input:
+        # Targets call this with is_input=False so forecasted categories are kept
+        if is_input and self.config.event_categories_to_exclude_from_input:
             events = events[
                 ~events[self.config.event_category_col].isin(self.config.event_categories_to_exclude_from_input)
             ]
diff --git a/twinweaver/instruction/converter_forecasting.py b/twinweaver/instruction/converter_forecasting.py
index b93e49a..e94dd38 100644
--- a/twinweaver/instruction/converter_forecasting.py
+++ b/twinweaver/instruction/converter_forecasting.py
@@ -127,7 +127,7 @@ def _generate_target_string(self, patient_split: DataSplitterForecastingOption)
 
         #: preprocess:
         target_data = patient_split.target_events_after_split.copy()
-        target_cleaned = self._preprocess_events(target_data.copy())
+        target_cleaned = self._preprocess_events(target_data.copy(), is_input=False)
 
         #: get delta between split and first target
         target_first_day = target_cleaned[self.config.date_col].min()