From da22621681672b4c96b0c0fe79acfe7a6d4aaf91 Mon Sep 17 00:00:00 2001 From: Nikita Makarov Date: Thu, 2 Apr 2026 15:42:52 +0000 Subject: [PATCH] Added fractional units and hours/mins/seconds --- docs/dataset-format.md | 2 +- docs/tutorials.md | 2 +- .../customizing_text_generation.ipynb | 11 +- pyproject.toml | 2 +- tests/test_delta_units.py | 540 ++++++++++++++++++ twinweaver/common/config.py | 51 +- twinweaver/common/converter_base.py | 50 +- twinweaver/instruction/converter_events.py | 2 +- .../instruction/converter_forecasting.py | 12 +- .../instruction/data_splitter_events.py | 39 +- 10 files changed, 665 insertions(+), 46 deletions(-) create mode 100644 tests/test_delta_units.py diff --git a/docs/dataset-format.md b/docs/dataset-format.md index 564cced..632bff9 100644 --- a/docs/dataset-format.md +++ b/docs/dataset-format.md @@ -106,7 +106,7 @@ On the first visit, the patient experienced the following: #### Relative Dating -TwinWeaver uses **relative dating** instead of absolute dates. All calendar dates from the input data are converted into time deltas relative to the previous visit (e.g., *"2 weeks later"*) rather than being included as raw dates (e.g., *"2024-01-29"*). This serves two important purposes: first, it **anonymizes** the patient data by removing identifiable calendar dates from the training text; second, it provides the model with **clinically meaningful temporal context** — the time elapsed between visits — rather than arbitrary date strings. By default, time intervals are expressed in weeks, but this can be changed to days using `Config.set_delta_time_unit("days")`. Accumulative deltas (time since the very first visit rather than since the previous visit) are also supported. +TwinWeaver uses **relative dating** instead of absolute dates. All calendar dates from the input data are converted into time deltas relative to the previous visit (e.g., *"2 weeks later"*) rather than being included as raw dates (e.g., *"2024-01-29"*). This serves two important purposes: first, it anonymizes the patient data by removing identifiable calendar dates from the training text; second, it prevents model overfitting on specific timestamps. By default, time intervals are expressed in weeks, but this can be changed using `Config.set_delta_time_unit()`. Supported units are `"days"`, `"weeks"`, `"hours"`, `"minutes"`, and `"seconds"`. Fractional values are naturally supported for every unit (e.g. *"0.5 days later"* for events 12 hours apart). Accumulative deltas (time since the very first visit rather than since the previous visit) are also supported. ### Final Output Structure diff --git a/docs/tutorials.md b/docs/tutorials.md index d3befe6..a7e576e 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -129,7 +129,7 @@ A comprehensive tutorial on customizing **every textual component** of the instr - Customizing preamble and introduction text - Modifying demographics section formatting - Changing event day and time interval descriptions -- Switching time units between days and weeks +- Switching time units between days, weeks, hours, minutes, and seconds - Customizing genetic data tags and placeholder text - Modifying forecasting, time-to-event, and QA task prompts - Configuring multi-task instruction formatting diff --git a/examples/advanced/custom_output/customizing_text_generation.ipynb b/examples/advanced/custom_output/customizing_text_generation.ipynb index 03ea854..c327b39 100644 --- a/examples/advanced/custom_output/customizing_text_generation.ipynb +++ b/examples/advanced/custom_output/customizing_text_generation.ipynb @@ -13,7 +13,7 @@ "1. **Preamble & Introduction Text** - Customizing the opening text of patient records\n", "2. **Demographics Section** - Modifying how constant/static data is introduced\n", "3. **Event Day Formatting** - Changing how visit days and time intervals are described\n", - "4. **Time Units** - Switching between days and weeks\n", + "4. **Time Units** - Switching between days, weeks, hours, minutes, and seconds\n", "5. **Genetic Data Formatting** - Customizing genetic event tags and text\n", "6. **Forecasting Prompts** - Modifying value prediction task descriptions\n", "7. **Time-to-Event Prompts** - Customizing survival/event prediction text\n", @@ -286,7 +286,7 @@ "source": [ "### 2.4 Time Units\n", "\n", - "You can switch between `days` and `weeks` for time intervals. Use `set_delta_time_unit()` to update all related prompts automatically." + "You can switch between `days`, `weeks`, `hours`, `minutes`, and `seconds` for time intervals. Fractional values are naturally supported for every unit (e.g., `0.5 days` for events 12 hours apart). Use `set_delta_time_unit()` to update all related prompts automatically." ] }, { @@ -300,7 +300,10 @@ "# CUSTOMIZING TIME UNITS\n", "# ============================================================================\n", "# Option 1: Use the helper method (updates all time-related prompts)\n", + "# Supported units: \"days\", \"weeks\", \"hours\", \"minutes\", \"seconds\"\n", + "# (and parenthesised variants like \"day(s)\", \"hour(s)\", etc.)\n", "# config_custom.set_delta_time_unit(\"days\", unit_sing=\"day\")\n", + "# config_custom.set_delta_time_unit(\"hours\", unit_sing=\"hour\")\n", "\n", "# Option 2: Set directly (if you want different phrasing)\n", "config_custom.delta_time_unit = \"weeks\"\n", @@ -717,7 +720,7 @@ "| `event_day_text` | Text for subsequent visits with time delta | \" weeks later...\" |\n", "| `post_event_text` | Text after listing day's events | \".\\n\" |\n", "| **Time Units** | | |\n", - "| `delta_time_unit` | Time unit for intervals | \"weeks\" |\n", + "| `delta_time_unit` | Time unit for intervals (`\"days\"`, `\"weeks\"`, `\"hours\"`, `\"minutes\"`, `\"seconds\"`) | \"weeks\" |\n", "| `forecasting_prompt_var_time` | Time description in forecasting | \" the future weeks \" |\n", "| **Genetic Data** | | |\n", "| `genetic_tag_opening` | Opening tag for genetic data | \"\" |\n", @@ -772,7 +775,7 @@ "\n", "1. **Preamble and introduction text** sets the context for the patient record\n", "2. **Event day formatting** controls how clinical visits are described temporally\n", - "3. **Time units** can be switched between days and weeks using `set_delta_time_unit()`\n", + "3. **Time units** can be switched between days, weeks, hours, minutes, and seconds using `set_delta_time_unit()` — fractional values are supported for every unit\n", "4. **Genetic data formatting** uses customizable tags and placeholder text\n", "5. **Task prompts** (forecasting, TTE, QA) can be fully rewritten for different LLM styles\n", "6. **Multi-task formatting** allows structured output for complex prediction tasks\n", diff --git a/pyproject.toml b/pyproject.toml index a059c54..7985fb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ build-backend = "setuptools.build_meta" # Standard entry point for setuptools bu [project] name = "twinweaver" -version = "0.3.4" +version = "0.3.5" description = "Converting longitudinal patient data into text for LLM-based event prediction and forecasting." # --- NEW/UPDATED FIELDS --- diff --git a/tests/test_delta_units.py b/tests/test_delta_units.py new file mode 100644 index 0000000..d74804d --- /dev/null +++ b/tests/test_delta_units.py @@ -0,0 +1,540 @@ +"""Tests for extended delta_time_unit support. + +Covers: +- Config.set_delta_time_unit with hours, minutes, seconds +- Fractional days support +- ConverterBase _time_divisor mapping for all units +- _delta_to_timedelta round-trip precision +- DataSplitterEvents unit_length_to_sample with new units +- End-to-end forward conversion with non-default units +""" + +import pytest +import pandas as pd + +from twinweaver.common.config import Config +from twinweaver.common.data_manager import DataManager +from twinweaver.instruction.converter_instruction import ConverterInstruction +from twinweaver.instruction.data_splitter_events import DataSplitterEvents +from twinweaver.instruction.data_splitter_forecasting import DataSplitterForecasting +from twinweaver.instruction.data_splitter import DataSplitter +from twinweaver.pretrain.converter_pretrain import ConverterPretrain + + +# ── Config unit tests ────────────────────────────────────────────────────── + + +class TestConfigSetDeltaTimeUnit: + """Tests for Config.set_delta_time_unit with all supported units.""" + + def test_set_hours(self): + cfg = Config() + cfg.set_delta_time_unit("hours", unit_sing="hour") + assert cfg.delta_time_unit == "hours" + assert "hours" in cfg.event_day_text + assert "hour" in cfg.forecasting_fval_prompt_start + + def test_set_minutes(self): + cfg = Config() + cfg.set_delta_time_unit("minutes", unit_sing="minute") + assert cfg.delta_time_unit == "minutes" + assert "minutes" in cfg.event_day_text + assert "minute" in cfg.forecasting_fval_prompt_start + + def test_set_seconds(self): + cfg = Config() + cfg.set_delta_time_unit("seconds", unit_sing="second") + assert cfg.delta_time_unit == "seconds" + assert "seconds" in cfg.event_day_text + assert "second" in cfg.forecasting_fval_prompt_start + + def test_set_hours_parenthesised(self): + cfg = Config() + cfg.set_delta_time_unit("hour(s)") + assert cfg.delta_time_unit == "hour(s)" + assert "hour(s)" in cfg.event_day_text + + def test_set_minutes_parenthesised(self): + cfg = Config() + cfg.set_delta_time_unit("minute(s)") + assert cfg.delta_time_unit == "minute(s)" + + def test_set_seconds_parenthesised(self): + cfg = Config() + cfg.set_delta_time_unit("second(s)") + assert cfg.delta_time_unit == "second(s)" + + def test_set_days_still_works(self): + cfg = Config() + cfg.set_delta_time_unit("days", unit_sing="day") + assert cfg.delta_time_unit == "days" + assert "days" in cfg.event_day_text + + def test_set_weeks_still_works(self): + cfg = Config() + cfg.set_delta_time_unit("weeks", unit_sing="week") + assert cfg.delta_time_unit == "weeks" + assert "weeks" in cfg.event_day_text + + def test_invalid_unit_raises(self): + cfg = Config() + with pytest.raises(AssertionError): + cfg.set_delta_time_unit("fortnights") + + def test_invalid_singular_raises(self): + cfg = Config() + with pytest.raises(AssertionError): + cfg.set_delta_time_unit("hours", unit_sing="fortnight") + + def test_singular_defaults_to_plural(self): + """When unit_sing is None the plural form is used in prompts.""" + cfg = Config() + cfg.set_delta_time_unit("hours") + assert "hours" in cfg.forecasting_fval_prompt_start + + def test_class_constants_complete(self): + """The class-level constant tuples must contain all expected entries.""" + assert "hours" in Config.VALID_DELTA_UNITS_PLURAL + assert "minutes" in Config.VALID_DELTA_UNITS_PLURAL + assert "seconds" in Config.VALID_DELTA_UNITS_PLURAL + assert "hour" in Config.VALID_DELTA_UNITS_SINGULAR + assert "minute" in Config.VALID_DELTA_UNITS_SINGULAR + assert "second" in Config.VALID_DELTA_UNITS_SINGULAR + + +# ── ConverterBase _time_divisor tests ────────────────────────────────────── + + +class TestConverterBaseTimeDivisor: + """Test that _time_divisor is set correctly for all unit variants.""" + + @pytest.fixture() + def _make_converter(self, mock_config, sample_data): + """Factory that returns a ConverterPretrain for a given delta_time_unit.""" + df_events, df_constant, df_constant_desc = sample_data + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + + def _build(unit: str): + mock_config.set_delta_time_unit(unit) + return ConverterPretrain(config=mock_config, dm=dm) + + return _build + + @pytest.mark.parametrize( + "unit, expected_divisor", + [ + ("weeks", 7.0), + ("week(s)", 7.0), + ("days", 1.0), + ("day(s)", 1.0), + ("hours", 1.0 / 24.0), + ("hour(s)", 1.0 / 24.0), + ("minutes", 1.0 / (24.0 * 60.0)), + ("minute(s)", 1.0 / (24.0 * 60.0)), + ("seconds", 1.0 / (24.0 * 60.0 * 60.0)), + ("second(s)", 1.0 / (24.0 * 60.0 * 60.0)), + ], + ) + def test_time_divisor_values(self, _make_converter, unit, expected_divisor): + converter = _make_converter(unit) + assert converter._time_divisor == pytest.approx(expected_divisor) + + def test_unsupported_unit_raises(self, mock_config, sample_data): + df_events, df_constant, df_constant_desc = sample_data + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + + # Bypass the Config assertion to force an unsupported unit + mock_config.delta_time_unit = "fortnights" + with pytest.raises(ValueError, match="Unsupported delta_time_unit"): + ConverterPretrain(config=mock_config, dm=dm) + + +# ── _delta_to_timedelta precision tests ──────────────────────────────────── + + +class TestDeltaToTimedelta: + """Test that _delta_to_timedelta round-trips correctly.""" + + @pytest.fixture() + def converter_for_unit(self, mock_config, sample_data): + df_events, df_constant, df_constant_desc = sample_data + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + + def _build(unit: str): + mock_config.set_delta_time_unit(unit) + return ConverterPretrain(config=mock_config, dm=dm) + + return _build + + def test_weeks_rounds_to_days(self, converter_for_unit): + conv = converter_for_unit("weeks") + td = conv._delta_to_timedelta(3.0) # 3 weeks = 21 days + assert td == pd.Timedelta(days=21) + # Fractional weeks: 2.5 weeks = 17.5 days → rounded to 18 + td2 = conv._delta_to_timedelta(2.5) + assert td2 == pd.Timedelta(days=18) + + def test_days_rounds_to_days(self, converter_for_unit): + conv = converter_for_unit("days") + td = conv._delta_to_timedelta(5.0) + assert td == pd.Timedelta(days=5) + # Fractional day: 0.7 → round to 1 + td2 = conv._delta_to_timedelta(0.7) + assert td2 == pd.Timedelta(days=1) + + def test_hours_preserves_sub_day(self, converter_for_unit): + conv = converter_for_unit("hours") + td = conv._delta_to_timedelta(6.0) # 6 hours + assert td == pytest.approx(pd.Timedelta(hours=6), abs=pd.Timedelta(seconds=1)) + + def test_minutes_preserves_sub_day(self, converter_for_unit): + conv = converter_for_unit("minutes") + td = conv._delta_to_timedelta(90.0) # 90 min = 1.5 hours + assert td == pytest.approx(pd.Timedelta(minutes=90), abs=pd.Timedelta(seconds=1)) + + def test_seconds_preserves_sub_day(self, converter_for_unit): + conv = converter_for_unit("seconds") + td = conv._delta_to_timedelta(3600.0) # 3600 s = 1 hour + # Allow tiny floating-point rounding (nanosecond level) + assert abs(td - pd.Timedelta(hours=1)) < pd.Timedelta(microseconds=1) + + +# ── Fractional days in text conversion ───────────────────────────────────── + + +class TestFractionalDays: + """Fractional day deltas should appear in generated text.""" + + @pytest.fixture() + def converter_days(self, mock_config, sample_data): + df_events, df_constant, df_constant_desc = sample_data + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + + mock_config.set_delta_time_unit("days") + converter = ConverterPretrain(config=mock_config, dm=dm) + return dm, converter + + def test_fractional_day_text(self, converter_days): + """When events are <1 day apart the delta should be a fraction like '0.5'.""" + dm, converter = converter_days + + # Build a minimal 2-event DataFrame with 12 hours between events + events = pd.DataFrame( + { + "date": pd.to_datetime(["2024-01-01 00:00", "2024-01-01 12:00"]), + "event_category": ["lab", "lab"], + "event_name": ["glucose", "glucose"], + "event_descriptive_name": ["glucose", "glucose"], + "event_value": ["100", "110"], + "source": ["events", "events"], + "meta_data": [pd.NA, pd.NA], + } + ) + + text = converter._get_event_string(events) + # 12 h = 0.5 days + assert "0.5" in text + assert "days" in text + + +# ── DataSplitterEvents with new units ────────────────────────────────────── + + +class TestDataSplitterEventsUnits: + """DataSplitterEvents should accept hours, minutes, seconds.""" + + @pytest.fixture() + def initialized_dm(self, mock_config, sample_data): + df_events, df_constant, df_constant_desc = sample_data + mock_config.split_event_category = "lot" + mock_config.event_category_forecast = ["lab"] + mock_config.event_category_events_prediction_with_naming = { + "death": "death", + "progression": "next progression", + } + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) + dm.infer_var_types() + return dm + + @pytest.mark.parametrize("unit", ["days", "hours", "minutes", "seconds"]) + def test_splitter_accepts_unit(self, initialized_dm, mock_config, unit): + """DataSplitterEvents should instantiate without error for each unit.""" + splitter = DataSplitterEvents( + initialized_dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + unit_length_to_sample=unit, + max_split_length_after_split_event=pd.Timedelta(days=90), + ) + splitter.setup_variables() + assert splitter.unit_length_to_sample == unit + + def test_splitter_unsupported_unit(self, initialized_dm, mock_config): + splitter = DataSplitterEvents( + initialized_dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(weeks=1), + unit_length_to_sample="fortnights", + max_split_length_after_split_event=pd.Timedelta(days=90), + ) + splitter.setup_variables() + patient_data = initialized_dm.get_patient_data("p0") + with pytest.raises(NotImplementedError, match="not implemented"): + splitter.get_splits_from_patient(patient_data, max_nr_samples_per_split=1) + + +# ── End-to-end conversion with "days" unit ───────────────────────────────── + + +class TestEndToEndDaysUnit: + """Integration: full forward conversion with delta_time_unit='days'.""" + + @pytest.fixture() + def setup_days(self, mock_config, sample_data): + df_events, df_constant, df_constant_desc = sample_data + mock_config.split_event_category = "lot" + mock_config.event_category_forecast = ["lab"] + mock_config.event_category_events_prediction_with_naming = { + "death": "death", + "progression": "next progression", + } + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + mock_config.constant_birthdate_column = "birthyear" + mock_config.set_delta_time_unit("days", unit_sing="day") + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) + dm.infer_var_types() + + splitter_events = DataSplitterEvents( + dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(days=730), + min_length_to_sample=pd.Timedelta(days=1), + unit_length_to_sample="days", + max_split_length_after_split_event=pd.Timedelta(days=90), + ) + splitter_events.setup_variables() + + splitter_forecast = DataSplitterForecasting( + data_manager=dm, + config=mock_config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) + splitter_forecast.setup_statistics() + + data_splitter = DataSplitter(splitter_events, splitter_forecast) + + converter = ConverterInstruction( + nr_tokens_budget_total=4096, + config=mock_config, + dm=dm, + variable_stats=splitter_forecast.variable_stats, + ) + return dm, data_splitter, converter + + def test_forward_conversion_uses_days(self, setup_days): + dm, data_splitter, converter = setup_days + patient_data = dm.get_patient_data("p0") + f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data) + + result = converter.forward_conversion( + forecasting_splits=f_splits[0], + event_splits=e_splits[0], + override_mode_to_select_forecasting="both", + ) + + instruction = result["instruction"] + # "weeks" should NOT appear in the temporal text, "days" should + assert "days later" in instruction + assert "weeks later" not in instruction + # Demographic section should still be present + assert "Starting with demographic data:" in instruction + + +# ── End-to-end conversion with "hours" unit ──────────────────────────────── + + +class TestEndToEndHoursUnit: + """Integration: full forward conversion with delta_time_unit='hours'.""" + + @pytest.fixture() + def setup_hours(self, mock_config, sample_data): + df_events, df_constant, df_constant_desc = sample_data + mock_config.split_event_category = "lot" + mock_config.event_category_forecast = ["lab"] + mock_config.event_category_events_prediction_with_naming = { + "death": "death", + "progression": "next progression", + } + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + mock_config.constant_birthdate_column = "birthyear" + mock_config.set_delta_time_unit("hours", unit_sing="hour") + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1) + dm.infer_var_types() + + splitter_events = DataSplitterEvents( + dm, + config=mock_config, + max_length_to_sample=pd.Timedelta(weeks=104), + min_length_to_sample=pd.Timedelta(hours=1), + unit_length_to_sample="hours", + max_split_length_after_split_event=pd.Timedelta(days=90), + ) + splitter_events.setup_variables() + + splitter_forecast = DataSplitterForecasting( + data_manager=dm, + config=mock_config, + max_forecasted_trajectory_length=pd.Timedelta(days=90), + max_split_length_after_split_event=pd.Timedelta(days=90), + ) + splitter_forecast.setup_statistics() + + data_splitter = DataSplitter(splitter_events, splitter_forecast) + + converter = ConverterInstruction( + nr_tokens_budget_total=4096, + config=mock_config, + dm=dm, + variable_stats=splitter_forecast.variable_stats, + ) + return dm, data_splitter, converter + + def test_forward_conversion_uses_hours(self, setup_hours): + dm, data_splitter, converter = setup_hours + patient_data = dm.get_patient_data("p0") + f_splits, e_splits, _ = data_splitter.get_splits_from_patient_with_target(patient_data) + + result = converter.forward_conversion( + forecasting_splits=f_splits[0], + event_splits=e_splits[0], + override_mode_to_select_forecasting="both", + ) + + instruction = result["instruction"] + # The temporal text should use "hours", not "weeks" or "days" + assert "hours later" in instruction + assert "weeks later" not in instruction + assert "Starting with demographic data:" in instruction + + +# ── Pretrain conversion with "days" unit ─────────────────────────────────── + + +class TestPretrainDaysUnit: + """Integration: pretrain forward conversion with delta_time_unit='days'.""" + + @pytest.fixture() + def setup_pretrain_days(self, mock_config, sample_data): + df_events, df_constant, df_constant_desc = sample_data + mock_config.constant_columns_to_use = [ + "birthyear", + "gender", + "histology", + "smoking_history", + ] + mock_config.constant_birthdate_column = "birthyear" + mock_config.set_delta_time_unit("days", unit_sing="day") + + dm = DataManager(config=mock_config) + dm.load_indication_data(df_events, df_constant, df_constant_desc) + dm.process_indication_data() + dm.setup_unique_mapping_of_events() + + converter = ConverterPretrain(config=mock_config, dm=dm) + return dm, converter + + def test_pretrain_forward_uses_days(self, setup_pretrain_days): + dm, converter = setup_pretrain_days + patient_data = dm.get_patient_data("p0") + + result = converter.forward_conversion(patient_data["events"], patient_data["constant"]) + text = result["text"] + + assert "days later" in text + assert "weeks later" not in text + assert "Starting with demographic data:" in text + + def test_pretrain_roundtrip_days(self, setup_pretrain_days): + """Forward → reverse should preserve data at day-level precision for days unit.""" + dm, converter = setup_pretrain_days + patient_data = dm.get_patient_data("p0") + + forward_result = converter.forward_conversion(patient_data["events"], patient_data["constant"]) + text = forward_result["text"] + meta = forward_result["meta"] + + reverse_result = converter.reverse_conversion( + text=text, data_manager=dm, init_date=meta["events"]["date"].min() + ) + diff = converter.get_difference_in_event_dataframes(meta["events"], reverse_result["events"], skip_genetic=True) + assert diff.shape[0] == 0, f"Found differences in roundtrip conversion:\n{diff}" diff --git a/twinweaver/common/config.py b/twinweaver/common/config.py index a1d3f37..ff5d10b 100644 --- a/twinweaver/common/config.py +++ b/twinweaver/common/config.py @@ -20,8 +20,10 @@ class Config: date_cutoff : str | None If set, only use data before this date (format: "YYYY-MM-DD"), censored after. Default: None. delta_time_unit : str - Unit of time used to express intervals between patient visits in the generated text. Options are "days" or - "weeks". Default: "weeks". + Unit of time used to express intervals between patient visits in the generated text. Options are "days", + "weeks", "hours", "minutes", or "seconds" (and their parenthesised plurals like "day(s)"). + Fractional days are naturally supported when using "days" (e.g. 0.5 days). + Default: "weeks". numeric_detect_min_fraction: float Fraction of values that must be numeric to classify a variable as numeric. Defaults to 0.99. date_col : str @@ -254,7 +256,7 @@ def __init__(self): # --- Import data parameters --- self.date_cutoff = None # If set, only use data before this date (format: "YYYY-MM-DD"), censored after self.delta_time_unit: str = ( - "weeks" # Either "days" or "weeks" - if you change this, you need to call set_delta_time_unit + "weeks" # "days", "weeks", "hours", "minutes", or "seconds" - if you change this, call set_delta_time_unit ) self.numeric_detect_min_fraction: float = ( 0.99 # Fraction of numeric values required to consider an event as numeric @@ -437,14 +439,47 @@ def __init__(self): "progression": "death", } + # Valid plural and parenthesised-plural forms for delta_time_unit + VALID_DELTA_UNITS_PLURAL = ( + "days", + "weeks", + "hours", + "minutes", + "seconds", + "day(s)", + "week(s)", + "hour(s)", + "minute(s)", + "second(s)", + ) + # Valid singular forms (used for prompts where singular reads better) + VALID_DELTA_UNITS_SINGULAR = ("day", "week", "hour", "minute", "second") + def set_delta_time_unit(self, unit: str, unit_sing=None): """ - Set the time unit for delta time representation in text conversion. Possible to set either - "days" (and "day(s)") or "weeks" (and "week(s)"). Optionally, a singular form can be provided - for use in specific prompts. If not provided, the plural form will be used. + Set the time unit for delta time representation in text conversion. + + Supported plural forms: ``"days"``, ``"weeks"``, ``"hours"``, ``"minutes"``, + ``"seconds"`` (and their parenthesised variants like ``"day(s)"``). + Fractional values are naturally supported for every unit – e.g. ``0.5`` + days will appear as ``"0.5 days later …"`` in the generated text. + + Optionally, a singular form can be provided for use in specific prompts. + If not provided, the plural form will be used. + + Parameters + ---------- + unit : str + Plural (or parenthesised-plural) time unit. + unit_sing : str or None + Optional singular form of the unit (e.g. ``"hour"``). """ - assert unit in ("days", "weeks", "day(s)", "week(s)"), "unit must be either 'days' or 'weeks'" - assert unit_sing in (None, "day", "week"), "unit_sing must be either None, 'day' or 'week'" + assert unit in self.VALID_DELTA_UNITS_PLURAL, ( + f"unit must be one of {self.VALID_DELTA_UNITS_PLURAL}, got '{unit}'" + ) + assert unit_sing in (None, *self.VALID_DELTA_UNITS_SINGULAR), ( + f"unit_sing must be None or one of {self.VALID_DELTA_UNITS_SINGULAR}, got '{unit_sing}'" + ) self.delta_time_unit = unit if unit_sing is None: unit_sing = unit diff --git a/twinweaver/common/converter_base.py b/twinweaver/common/converter_base.py index 35dd983..2ff7ca6 100644 --- a/twinweaver/common/converter_base.py +++ b/twinweaver/common/converter_base.py @@ -129,13 +129,42 @@ def __init__(self, config: Config) -> None: self.always_keep_first_visit = None # Handles time division depending on config - if self.config.delta_time_unit in ["weeks", "week(s)"]: - self._time_divisor = 7.0 - elif self.config.delta_time_unit in ["days", "day(s)"]: - self._time_divisor = 1.0 + # _time_divisor converts a timedelta expressed in *days* into the + # desired unit. For sub-day units we use fractional-day arithmetic + # (e.g. 1 hour = 1/24 day, so divisor = 1/24). + _unit_to_divisor = { + "weeks": 7.0, + "week(s)": 7.0, + "days": 1.0, + "day(s)": 1.0, + "hours": 1.0 / 24.0, + "hour(s)": 1.0 / 24.0, + "minutes": 1.0 / (24.0 * 60.0), + "minute(s)": 1.0 / (24.0 * 60.0), + "seconds": 1.0 / (24.0 * 60.0 * 60.0), + "second(s)": 1.0 / (24.0 * 60.0 * 60.0), + } + unit = self.config.delta_time_unit + if unit in _unit_to_divisor: + self._time_divisor = _unit_to_divisor[unit] else: self._time_divisor = None - raise ValueError(f"Unsupported delta_time_unit: {self.config.delta_time_unit}") + raise ValueError(f"Unsupported delta_time_unit: {unit}. Supported values: {list(_unit_to_divisor.keys())}") + + # Whether the configured unit is sub-day (hours, minutes, seconds) + self._sub_day_unit = unit in ("hours", "hour(s)", "minutes", "minute(s)", "seconds", "second(s)") + + def _delta_to_timedelta(self, delta: float) -> pd.Timedelta: + """Convert a delta value (in the configured time unit) back to a :class:`pd.Timedelta`. + + For day-level units (days, weeks) the result is rounded to whole days so + that the original day-level data precision is preserved. For sub-day + units (hours, minutes, seconds) the full resolution is kept. + """ + days_float = delta * self._time_divisor + if self._sub_day_unit: + return pd.to_timedelta(days_float, unit="D") + return pd.to_timedelta(round(days_float), unit="D") def _preprocess_constant_date( self, @@ -347,8 +376,9 @@ def _get_event_string( #: sort by date using config constant events = events.sort_values(self.config.date_col).reset_index(drop=True) - #: for every visit get delta to previous in days or weeks, starting with 0 using config constant - events_delta = events[self.config.date_col].diff().dt.days / self._time_divisor + #: for every visit get delta to previous in the configured time unit, starting with 0 + # Use total_seconds() instead of .dt.days so fractional days and sub-day units are preserved. + events_delta = events[self.config.date_col].diff().dt.total_seconds() / (86400.0 * self._time_divisor) events_delta[0] = events_delta_0 events["delta"] = events_delta @@ -910,7 +940,11 @@ def _add_event_to_data( source_value = self.config.source_genetic if event_category == "unknown - genetic" else "events" new_event = { - self.config.date_col: prev_date + pd.to_timedelta(round(delta * self._time_divisor), unit="D"), + # Convert delta (in the configured time unit) back to a Timedelta. + # For day-level units (days, weeks) round to whole days to preserve + # the original day-level data precision. For sub-day units + # (hours, minutes, seconds) keep the full resolution. + self.config.date_col: prev_date + self._delta_to_timedelta(delta), self.config.event_category_col: event_category, self.config.event_name_col: event_name, self.config.event_descriptive_name_col: event_descriptive_name, diff --git a/twinweaver/instruction/converter_events.py b/twinweaver/instruction/converter_events.py index 40bb36d..572723c 100644 --- a/twinweaver/instruction/converter_events.py +++ b/twinweaver/instruction/converter_events.py @@ -163,7 +163,7 @@ def _generate_prompt(self, patient_split: DataSplitterEventsOption) -> tuple: #: get delta in time in config.delta_time_unit, rounded using round_and_strip delta_time_numeric = patient_split.observation_end_date - patient_split.split_date_included_in_input - delta_time_numeric = delta_time_numeric.days / self._time_divisor + delta_time_numeric = delta_time_numeric.total_seconds() / (86400.0 * self._time_divisor) delta_time = round_and_strip(delta_time_numeric, self.decimal_precision) diff --git a/twinweaver/instruction/converter_forecasting.py b/twinweaver/instruction/converter_forecasting.py index e94dd38..03f4392 100644 --- a/twinweaver/instruction/converter_forecasting.py +++ b/twinweaver/instruction/converter_forecasting.py @@ -132,7 +132,7 @@ def _generate_target_string(self, patient_split: DataSplitterForecastingOption) #: get delta between split and first target target_first_day = target_cleaned[self.config.date_col].min() split_date = patient_split.split_date_included_in_input - delta_days = (target_first_day - split_date).days / self._time_divisor + delta_days = (target_first_day - split_date).total_seconds() / (86400.0 * self._time_divisor) #: convert to string using default approach target_str = self._get_event_string( @@ -151,7 +151,9 @@ def _generate_target_string(self, patient_split: DataSplitterForecastingOption) # Ensure dates is a pd.Series or pd.DatetimeIndex for subtraction if not isinstance(dates, (pd.Series, pd.DatetimeIndex)): dates = pd.to_datetime(dates) - future_prediction_time_per_variable[variable] = (dates - split_date).days / self._time_divisor + future_prediction_time_per_variable[variable] = (dates - split_date).total_seconds() / ( + 86400.0 * self._time_divisor + ) # Get descriptive name curr_var = target_cleaned[target_cleaned[self.config.event_name_col] == variable] @@ -166,7 +168,7 @@ def _generate_target_string(self, patient_split: DataSplitterForecastingOption) # Ensure dates_to_forecast is pd.DatetimeIndex for subtraction if not isinstance(dates_to_forecast, pd.DatetimeIndex): dates_to_forecast = pd.to_datetime(dates_to_forecast) - future_weeks_to_forecast = (dates_to_forecast - split_date).days / self._time_divisor + future_weeks_to_forecast = (dates_to_forecast - split_date).total_seconds() / (86400.0 * self._time_divisor) #: add last observed values of each variable from input history input_history = patient_split.events_until_split @@ -374,9 +376,7 @@ def forward_conversion_inference( # Ensure weeks is iterable if not hasattr(weeks, "__iter__"): weeks = [weeks] - dates_per_variable[variable] = [ - split_date + pd.Timedelta(days=float(w) * self._time_divisor) for w in weeks - ] + dates_per_variable[variable] = [split_date + self._delta_to_timedelta(float(w)) for w in weeks] target_pseudo_meta["dates_per_variable"] = dates_per_variable #: make target_meta["variable_name_mapping"] by looking up in input events diff --git a/twinweaver/instruction/data_splitter_events.py b/twinweaver/instruction/data_splitter_events.py index fce1823..5e7531f 100644 --- a/twinweaver/instruction/data_splitter_events.py +++ b/twinweaver/instruction/data_splitter_events.py @@ -109,7 +109,8 @@ def __init__( The minimum length of time into the future to sample for event prediction. Required, no default. unit_length_to_sample : str - The unit of time for the length to sample (e.g. "weeks"). + The unit of time for the length to sample (e.g. "weeks", "days", "hours", + "minutes", or "seconds"). max_split_length_after_split_event : pd.Timedelta, optional The maximum number of days after the split event (e.g. line of therapy) to consider for split points. Defaults to 0 days. @@ -389,26 +390,32 @@ def get_splits_from_patient( continue prev_sampled_category.append(sampled_cateogry) - # Determine how many weeks to predict into the future + # Determine how many units to predict into the future if override_observation_time_delta is None: #: randomly sample end date -> so that we also get random values in between for consistency # This is so that the model can learn different time values for the same variable - #: To not bias the model, we select a random nr time as max end date`` - - if self.unit_length_to_sample == "days": - max_units = self.max_length_to_sample.days - min_units = self.min_length_to_sample.days - random_units = np.random.randint(min_units, max_units + 1) - end_time_delta = pd.Timedelta(days=random_units) - elif self.unit_length_to_sample == "weeks": - max_units = self.max_length_to_sample.days // 7 - min_units = self.min_length_to_sample.days // 7 - random_units = np.random.randint(min_units, max_units + 1) - end_time_delta = pd.Timedelta(weeks=random_units) - else: + #: To not bias the model, we select a random nr time as max end date + + # Mapping from unit name to the divisor that converts a Timedelta (in seconds) + # to integer units, plus the pd.Timedelta keyword for reconstruction. + _unit_info = { + "days": (86400, "days"), + "weeks": (86400 * 7, "weeks"), + "hours": (3600, "hours"), + "minutes": (60, "minutes"), + "seconds": (1, "seconds"), + } + unit = self.unit_length_to_sample + if unit not in _unit_info: raise NotImplementedError( - f"Unit length to sample {self.unit_length_to_sample} not implemented." + f"Unit length to sample '{unit}' not implemented. " + f"Supported units: {list(_unit_info.keys())}" ) + divisor, td_kwarg = _unit_info[unit] + max_units = int(self.max_length_to_sample.total_seconds() // divisor) + min_units = int(self.min_length_to_sample.total_seconds() // divisor) + random_units = np.random.randint(min_units, max_units + 1) + end_time_delta = pd.Timedelta(**{td_kwarg: random_units}) else: end_time_delta = override_observation_time_delta