Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Welcome to TorchSurv's documentation!
notebooks/introduction
notebooks/momentum
notebooks/regression_time_varying
notebooks/synthetic_data_signal
notebooks/non_medical_applications

.. toctree::
Expand Down
273 changes: 273 additions & 0 deletions docs/notebooks/synthetic_data_signal.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ IssueTracker = "https://github.com/Novartis/torchsurv/issues"
Changelog = "https://opensource.nibr.com/torchsurv/CHANGELOG.html"

[tool.codespell]
ignore-words-list = ["TE", "FPR", "tOI", "te", "FO", "MIs", "fO", "nd"] # Known false positives
ignore-words-list = ["TE", "FPR", "tOI", "te", "FO", "MIs", "fO", "nd", "ue"] # Known false positives
skip = [
"*.bib",
"*.toml",
Expand Down
4 changes: 3 additions & 1 deletion src/torchsurv/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""This module provides validation utilities for survival analysis inputs."""
"""Utilities for survival analysis inputs and synthetic benchmarking data."""

from __future__ import annotations

from torchsurv.tools.synthetic import make_synthetic_data
from torchsurv.tools.validators import (
EvalTimeInputs,
ModelInputs,
Expand All @@ -18,6 +19,7 @@
"NewTimeInputs",
"SurvivalInputs",
"TimeVaryingCoxInputs",
"make_synthetic_data",
"impute_missing_log_shape",
"validate_time_varying_log_hz",
]
152 changes: 152 additions & 0 deletions src/torchsurv/tools/synthetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from __future__ import annotations

import math

import torch

__all__ = ["make_synthetic_data"]


def _validate_positive_int(name: str, value: int) -> None:
if not isinstance(value, int):
raise ValueError(f"Input '{name}' must be an integer.")
if value <= 0:
raise ValueError(f"Input '{name}' must be strictly positive.")


def _validate_probability(name: str, value: float, *, allow_one: bool) -> float:
if not isinstance(value, (int, float)):
raise ValueError(f"Input '{name}' must be a number.")
value = float(value)
upper_bound = 1.0 if allow_one else 1.0 - 1e-12
if value < 0.0 or value > upper_bound:
comparator = "[0, 1]" if allow_one else "[0, 1)"
raise ValueError(f"Input '{name}' must be in {comparator}.")
return value


def _standardize(x: torch.Tensor) -> torch.Tensor:
return (x - x.mean()) / x.std().clamp_min(torch.finfo(x.dtype).eps)


def _calibrate_censoring(
event_time: torch.Tensor, raw_censor_time: torch.Tensor, censoring_rate: float
) -> torch.Tensor:
"""Scale raw censoring times to approximately match the requested censoring rate."""
if censoring_rate == 0.0:
return torch.full_like(event_time, torch.inf)

lower = torch.tensor(0.0, dtype=event_time.dtype, device=event_time.device)
upper = torch.tensor(1.0, dtype=event_time.dtype, device=event_time.device)

def observed_censoring(scale: torch.Tensor) -> torch.Tensor:
return (scale * raw_censor_time < event_time).float().mean()

while observed_censoring(upper) > censoring_rate:
upper = upper * 2.0

for _ in range(60):
midpoint = (lower + upper) / 2.0
if observed_censoring(midpoint) > censoring_rate:
lower = midpoint
else:
upper = midpoint

return upper * raw_censor_time


def make_synthetic_data(
    n: int,
    m: int,
    rho: float,
    *,
    censoring_rate: float = 0.3,
    seed: int | None = None,
) -> dict[str, torch.Tensor]:
    """Generate a synthetic survival dataset for Cox-model benchmarking.

    The construction is proportional-hazards throughout:

    - covariates are IID standard Gaussian,
    - the latent log-risk mixes a feature-derived signal with independent
      noise, weighted so that ``rho`` is the signal's share of the variance,
    - event times come from an exponential baseline hazard multiplied by
      ``exp(log_risk)``,
    - censoring is independent and rescaled to an approximate target rate.

    Args:
        n: Number of samples.
        m: Number of features.
        rho: Signal strength in ``[0, 1]``. ``rho=1`` means the latent log-risk
            is fully determined by the features, while ``rho=0`` means the
            latent log-risk is independent of the features.
        censoring_rate: Approximate fraction of censored observations in ``[0, 1)``.
        seed: Optional random seed for reproducibility.

    Returns:
        Dictionary containing:

        - ``x``: Covariate matrix with shape ``(n, m)``
        - ``event``: Event indicator with shape ``(n,)`` and dtype ``bool``
        - ``time``: Observed event/censoring time with shape ``(n,)``
        - ``log_risk``: Ground-truth latent Cox log-risk with shape ``(n,)``
        - ``beta``: Ground-truth normalized feature coefficients with shape ``(m,)``

    Examples:
        >>> batch = make_synthetic_data(n=64, m=8, rho=0.75, seed=7)
        >>> sorted(batch.keys())
        ['beta', 'event', 'log_risk', 'time', 'x']
        >>> batch["x"].shape
        torch.Size([64, 8])
    """
    _validate_positive_int("n", n)
    _validate_positive_int("m", m)
    rho = _validate_probability("rho", rho, allow_one=True)
    censoring_rate = _validate_probability("censoring_rate", censoring_rate, allow_one=False)

    # Dedicated generator so a given seed yields an identical dataset
    # regardless of global RNG state.
    rng = torch.Generator()
    if seed is not None:
        rng.manual_seed(seed)

    eps = torch.finfo(torch.float32).eps

    # IID Gaussian covariates.
    x = torch.randn((n, m), generator=rng, dtype=torch.float32)

    # Random unit direction acts as the ground-truth coefficient vector.
    beta = torch.randn((m,), generator=rng, dtype=torch.float32)
    beta = beta / beta.norm().clamp_min(eps)

    feature_signal = _standardize(x @ beta)
    independent_noise = _standardize(torch.randn((n,), generator=rng, dtype=torch.float32))

    # Variance-preserving mixture: rho of the latent risk variance is signal.
    log_risk = math.sqrt(rho) * feature_signal + math.sqrt(1.0 - rho) * independent_noise

    # Inverse-CDF sampling of Exponential(baseline_hazard * exp(log_risk)).
    baseline_hazard = torch.tensor(0.1, dtype=torch.float32)
    u_event = torch.rand((n,), generator=rng, dtype=torch.float32).clamp_min(eps)
    event_time = -torch.log(u_event) / (baseline_hazard * torch.exp(log_risk))

    # Independent unit-exponential censoring, rescaled to the target rate.
    u_censor = torch.rand((n,), generator=rng, dtype=torch.float32).clamp_min(eps)
    censor_time = _calibrate_censoring(event_time, -torch.log(u_censor), censoring_rate)

    event = event_time <= censor_time
    time = torch.minimum(event_time, censor_time)

    # Guarantee at least one observed event so downstream Cox losses and
    # concordance metrics stay well-defined.
    if not event.any():
        earliest = torch.argmin(event_time)
        event[earliest] = True
        time[earliest] = event_time[earliest]

    return {
        "x": x,
        "event": event,
        "time": time,
        "log_risk": log_risk,
        "beta": beta,
    }


if __name__ == "__main__":
    import doctest

    # Execute the module's embedded doctests when run as a script.
    outcome = doctest.testmod()
    if outcome.failed == 0:
        print("All tests passed.")
91 changes: 91 additions & 0 deletions tests/test_synthetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest
import torch
from torch import nn

from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.tools import make_synthetic_data


class TestSyntheticData:
    """Behavioral tests for :func:`make_synthetic_data`."""

    def test_shapes_and_dtypes(self):
        data = make_synthetic_data(n=128, m=6, rho=0.75, seed=12)

        assert set(data) == {"x", "event", "time", "log_risk", "beta"}
        expected_shapes = {
            "x": (128, 6),
            "event": (128,),
            "time": (128,),
            "log_risk": (128,),
            "beta": (6,),
        }
        for key, shape in expected_shapes.items():
            assert data[key].shape == shape
        assert data["event"].dtype == torch.bool
        for key in ("x", "time", "log_risk", "beta"):
            assert data[key].dtype == torch.float32
        assert torch.all(data["time"] >= 0.0)
        assert data["event"].any()

    def test_reproducible_with_seed(self):
        first = make_synthetic_data(n=64, m=4, rho=0.5, censoring_rate=0.2, seed=7)
        second = make_synthetic_data(n=64, m=4, rho=0.5, censoring_rate=0.2, seed=7)

        # Identical seeds must yield byte-identical tensors for every key.
        for key, value in first.items():
            assert torch.equal(value, second[key])

    @pytest.mark.parametrize(
        ("kwargs", "message"),
        [
            ({"n": 0, "m": 4, "rho": 0.5}, "n"),
            ({"n": 32, "m": 0, "rho": 0.5}, "m"),
            ({"n": 32, "m": 4, "rho": -0.1}, "rho"),
            ({"n": 32, "m": 4, "rho": 1.1}, "rho"),
            ({"n": 32, "m": 4, "rho": 0.5, "censoring_rate": -0.1}, "censoring_rate"),
            ({"n": 32, "m": 4, "rho": 0.5, "censoring_rate": 1.0}, "censoring_rate"),
        ],
    )
    def test_invalid_inputs_raise(self, kwargs, message):
        with pytest.raises(ValueError, match=message):
            make_synthetic_data(**kwargs)

    def test_signal_strength_controls_concordance(self):
        concordance = ConcordanceIndex()

        def score(rho, seed):
            # Concordance of the oracle linear predictor x @ beta.
            data = make_synthetic_data(n=256, m=8, rho=rho, seed=seed)
            oracle = data["x"] @ data["beta"]
            return concordance(oracle, data["event"], data["time"], instate=False).item()

        high_scores = []
        low_scores = []
        for seed in range(5):
            high_scores.append(score(1.0, seed))
            low_scores.append(score(0.0, seed))

        high_mean = sum(high_scores) / len(high_scores)
        low_mean = sum(low_scores) / len(low_scores)

        # Full signal should clearly separate from pure noise, which in turn
        # should hover around the chance level of 0.5.
        assert high_mean > low_mean + 0.2
        assert 0.4 < low_mean < 0.6

    def test_high_signal_dataset_is_trainable(self):
        data = make_synthetic_data(n=256, m=8, rho=1.0, seed=0)
        model = nn.Linear(8, 1)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
        concordance = ConcordanceIndex()

        def evaluate() -> float:
            with torch.no_grad():
                estimate = model(data["x"]).squeeze()
                return concordance(estimate, data["event"], data["time"], instate=False).item()

        initial_cindex = evaluate()

        losses = []
        for _ in range(80):
            optimizer.zero_grad()
            loss = neg_partial_log_likelihood(model(data["x"]), data["event"], data["time"])
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        final_cindex = evaluate()

        # Training must both reduce the loss and improve discrimination.
        assert losses[-1] < losses[0]
        assert final_cindex > initial_cindex + 0.1
Loading