18 changes: 18 additions & 0 deletions README.md
@@ -218,6 +218,24 @@ We obtain the concordance index for this batch with:
```python
>>> from torchsurv.metrics.cindex import ConcordanceIndex
>>> cindex = ConcordanceIndex()
>>> print(cindex(log_hz, event, time))
tensor(0.4062)
```

### Competing risks with cause-specific Cox

For competing risks, the model outputs one log relative hazard per cause. Event labels are integer-coded with `0` for censoring and `1..K` for the observed cause.

```python
>>> import torch
>>> from torch import nn
>>> from torchsurv.loss import competing_risks
>>> n_causes = 2
>>> event_cr = torch.randint(low=0, high=n_causes + 1, size=(n,), dtype=torch.long)
>>> model_cr = nn.Sequential(nn.Linear(16, n_causes))
>>> log_hz_cr = model_cr(x)
>>> loss = competing_risks.neg_partial_log_likelihood(log_hz_cr, event_cr, time)
>>> baseline = competing_risks.baseline_cumulative_incidence_function(log_hz_cr.detach(), event_cr, time)
>>> cif = competing_risks.cumulative_incidence_function(baseline, log_hz_cr.detach(), torch.tensor([25.0, 50.0]))
>>> print(cif.shape)
torch.Size([64, 2, 2])
```
5 changes: 5 additions & 0 deletions docs/CHANGELOG.md
@@ -1,6 +1,11 @@
Changelog
=========

Version 0.2.1
-------------

* Added a first competing-risks implementation based on cause-specific Cox models, including CIF and event-free survival helpers.

Version 0.1.6
-------------

1 change: 1 addition & 0 deletions docs/loss.rst
@@ -6,6 +6,7 @@ Loss
:template: autosummary/module.rst

torchsurv.loss.cox
torchsurv.loss.competing_risks
torchsurv.loss.weibull
torchsurv.loss.survival
torchsurv.loss.momentum
33 changes: 33 additions & 0 deletions docs/notebooks/loss.md
@@ -158,6 +158,38 @@ This model is particularly powerful when the true hazard does not follow a standard parametric form.
- Observations are conditionally independent given the covariates.
- Numerical integration (trapezoidal rule) is sufficiently accurate given the time discretization.

### 5. Competing Risks with Cause-Specific Cox

The first competing-risks loss is available through:

```python
from torchsurv.loss.competing_risks import neg_partial_log_likelihood
```

In this setting, the event variable is integer-coded with `0` for censoring and $1, \ldots, K$ for the observed cause. The model outputs one log relative hazard per cause,

$$
\log \lambda_{ik} = f_{\theta, k}(\mathbf{x}_i), \quad k \in \{1, \ldots, K\}.
$$
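To make the coding convention concrete, here is a minimal sketch of the expected inputs and a network whose final layer has width $K$. The dimensions, hidden width, and tensor names are illustrative, not part of the API:

```python
import torch
from torch import nn

torch.manual_seed(0)
n, p, K = 64, 16, 2  # subjects, covariates, competing causes

x = torch.randn(n, p)                  # covariate matrix
event = torch.randint(0, K + 1, (n,))  # 0 = censored, 1..K = observed cause
time = torch.rand(n) * 100.0           # follow-up times

# One log relative hazard per cause: the final layer has width K.
model = nn.Sequential(nn.Linear(p, 32), nn.ReLU(), nn.Linear(32, K))
log_hz = model(x)
print(log_hz.shape)  # torch.Size([64, 2])
```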

For each cause $k$, `TorchSurv` fits a cause-specific Cox model by treating cause $k$ as the event of interest and all other outcomes as censored at their observed times. The total loss is the sum of the per-cause Cox partial log-likelihoods:

$$
\text{npll}_{CR} = \sum_{k=1}^{K} \text{npll}_{k}.
$$
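The summed loss can be sketched in plain PyTorch. This is not `TorchSurv`'s internal implementation: it uses Breslow's approximation and ignores tied event times, but it shows the recoding trick, treating cause $k$ as the event and everything else as censored:

```python
import torch

def cox_npll(log_hz, is_event, time):
    """Breslow negative partial log-likelihood for one cause (no tie correction)."""
    order = torch.argsort(time, descending=True)   # risk set = prefix after sorting
    log_hz, is_event = log_hz[order], is_event[order]
    log_denom = torch.logcumsumexp(log_hz, dim=0)  # log of the sum over the risk set
    return -(log_hz[is_event] - log_denom[is_event]).sum()

def npll_cr(log_hz, event, time):
    """Sum over causes; cause k is the event, all other outcomes are censored."""
    K = log_hz.shape[1]
    return sum(cox_npll(log_hz[:, k - 1], event == k, time) for k in range(1, K + 1))

# With zero log hazards and events at risk-set sizes 1 and 2, the loss is log(2):
loss = npll_cr(torch.zeros(3, 1), torch.tensor([1, 0, 1]), torch.tensor([3.0, 1.0, 2.0]))
print(loss)  # tensor(0.6931)
```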

This parameterization supports neural networks with a final layer of width $K$, keeps the same Cox-style training loop, and enables prediction of:

- cause-specific baseline hazard increments,
- event-free survival,
- cumulative incidence functions (CIFs) for each cause.
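To illustrate how these quantities relate, the following discrete-time sketch (again not `TorchSurv`'s API) turns per-cause baseline hazard increments into event-free survival and CIFs for one subject; the increments `d_lambda0` are assumed given:

```python
import torch

def survival_and_cif(d_lambda0, log_hz_i):
    """Event-free survival and CIFs for one subject.

    d_lambda0: (T, K) baseline cause-specific hazard increments at sorted event times.
    log_hz_i:  (K,) the subject's log relative hazards.
    """
    d_lam = d_lambda0 * torch.exp(log_hz_i)            # subject-specific increments
    total = d_lam.sum(dim=1)                           # all-cause increments
    surv = torch.cumprod(1.0 - total, dim=0)           # event-free survival S(t)
    surv_prev = torch.cat([torch.ones(1), surv[:-1]])  # S(t-) just before each jump
    cif = torch.cumsum(surv_prev.unsqueeze(1) * d_lam, dim=0)  # (T, K)
    return surv, cif

d_lambda0 = torch.rand(10, 2) * 0.05                   # toy increments, 10 times, 2 causes
surv, cif = survival_and_cif(d_lambda0, torch.zeros(2))
# In this discrete construction, S(t) plus the CIFs over all causes sums to 1.
print(torch.allclose(surv + cif.sum(dim=1), torch.ones(10)))  # True
```

The check at the end reflects the accounting identity behind competing risks: probability mass leaving the event-free state must land in exactly one of the $K$ CIFs.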

**Assumptions.**
- Cause-specific proportional hazards within each cause.
- Independent right censoring.
- Correct specification of the cause-specific log-risk functions.
- Competing events are modeled through separate cause-specific hazards rather than a subdistribution (Fine–Gray) hazard model.


### FAQ: Choosing the Right Survival Model

@@ -214,3 +246,4 @@ Use the **Flexible Survival model** when you do not want to impose any parametric form.
| **Weibull** | $h_i(t) = \frac{\exp(f_{\theta_1}(\mathbf{x}_i))}{\exp(f_{\theta_2}(\mathbf{x}_i))} \left(\frac{t}{\exp(f_{\theta_2}(\mathbf{x}_i))}\right)^{\exp(f_{\theta_1}(\mathbf{x}_i)) - 1}$ | ✗ | ✓ | You expect monotonic hazard shape |
| **Exponential** | $h_i(t) = \frac{1}{\exp(f_\theta(\mathbf{x}_i))}$ | ✗ | ✓ | You expect constant risk over time |
| **Flexible Survival** | $h_i(t) = \exp(f_{\theta}(\mathbf{x}_i, t))$ | ✓ | ✗ (numerical approximation) | You need full flexibility, no parametric form |
| **Competing Risks (Cause-Specific Cox)** | $h_{ik}(t) = \lambda_{0k}(t)\exp(f_{\theta,k}(\mathbf{x}_i))$ | ✗ | ✓ (per-cause partial likelihood) | You need CIFs with multiple mutually exclusive event types |
5 changes: 5 additions & 0 deletions docs/package_overview.md
@@ -10,6 +10,7 @@ graph LR

%% LOSS
LOSS --> COX["Cox"]:::sub
LOSS --> CR["Competing\nRisks"]:::sub
LOSS --> WEIBULL["Weibull"]:::sub
LOSS --> SURVIVAL["Survival\n(discrete-time)"]:::sub
LOSS --> MOMENTUM["Momentum"]:::sub
@@ -18,6 +19,10 @@ graph LR
COX --> C2["baseline_survival_function"]:::fn
COX --> C3["survival_function"]:::fn

CR --> CR1["neg_partial_log_likelihood"]:::fn
CR --> CR2["baseline_cumulative_incidence_function"]:::fn
CR --> CR3["cumulative_incidence_function\n· survival_function"]:::fn

WEIBULL --> W1["neg_log_likelihood"]:::fn
WEIBULL --> W2["log_hazard"]:::fn
WEIBULL --> W3["survival_function"]:::fn
2 changes: 2 additions & 0 deletions src/torchsurv/loss/__init__.py
@@ -2,13 +2,15 @@

from __future__ import annotations

from torchsurv.loss import competing_risks
from torchsurv.loss.cox import baseline_survival_function, neg_partial_log_likelihood, survival_function_cox
from torchsurv.loss.momentum import Momentum
from torchsurv.loss.survival import neg_log_likelihood, survival_function
from torchsurv.loss.weibull import log_hazard, neg_log_likelihood_weibull, survival_function_weibull

__all__ = [
"baseline_survival_function",
"competing_risks",
"log_hazard",
"Momentum",
"neg_log_likelihood",