Skip to content

Commit

Permalink
added doctest
Browse files Browse the repository at this point in the history
  • Loading branch information
tcoroller committed Jan 6, 2025
1 parent 48cdcc0 commit c38c6b3
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 81 deletions.
52 changes: 45 additions & 7 deletions docs/notebooks/loss_time_covariates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,40 @@


def neg_partial_time_log_likelihood(
log_hz: torch.Tensor, # Txnxp torch tensor, n is batch size, T number of time points, p is number of different covariates over time
time: torch.Tensor, # n length vector, time at which someone experiences event
events: torch.Tensor, # n length vector, boolean, true or false to determine if someone had an event
log_hz: torch.Tensor,
time: torch.Tensor,
events: torch.Tensor,
reduction: str = "mean",
) -> torch.Tensor:
"""
needs further work
Compute the negative partial log-likelihood for time-dependent covariates in a Cox proportional hazards model.
Args:
log_hz (torch.Tensor): A tensor of shape (T, n, p) where T is the number of time points, n is the batch size,
and p is the number of different covariates over time.
time (torch.Tensor): A tensor of length n representing the time at which an event occurs for each individual.
events (torch.Tensor): A boolean tensor of length n indicating whether an event occurred (True) or not (False) for each individual.
reduction (str, optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Default is 'mean'.
Returns:
torch.Tensor: The computed negative partial log-likelihood. If reduction is 'mean', returns the mean value.
If reduction is 'sum', returns the sum of the values.
Raises:
ValueError: If the specified reduction method is not 'mean' or 'sum'.
Examples:
>>> _ = torch.manual_seed(52)
>>> n = 10 # number of samples
>>> t = 5 # time steps
>>> time = torch.randint(low=5, high=250, size=(n,)).float()
>>> event = torch.randint(low=0, high=2, size=(n,)).bool()
>>> log_hz = torch.rand((t, n, 1))
>>> neg_partial_time_log_likelihood(log_hz, time, event)
tensor(0.9456)
>>> neg_partial_time_log_likelihood(log_hz.squeeze(), time, event) # Also works with 2D tensor
tensor(0.9456)
>>> neg_partial_time_log_likelihood(log_hz, time, event, reduction='sum')
tensor(37.8241)
"""

# only consider theta at tiem of
pll = _partial_likelihood_time_cox(log_hz, time, events)

Expand Down Expand Up @@ -86,7 +112,15 @@ def _partial_likelihood_time_cox(
we want to identify the index of the covariate upon failure. We could either consider the last covariate before a series of zeros
(requires special data formatting but could reduce issues as it automatically contains event time information).
Examples:
>>> _ = torch.manual_seed(52)
>>> n = 3 # number of samples
>>> t = 3 # time steps
>>> time = torch.randint(low=5, high=250, size=(n,)).float()
>>> event = torch.randint(low=0, high=2, size=(n,)).bool()
>>> log_hz = torch.rand((t, n, 1))
>>> _partial_likelihood_time_cox(log_hz, time, event)
tensor([-1.3772, -1.0683, -0.7879, -0.8220, 0.0000, 0.0000])
"""
# Last dimension must be equal to 1 if shape == 3
if len(log_hz.shape) == 3:
Expand Down Expand Up @@ -114,13 +148,13 @@ def _partial_likelihood_time_cox(
log_hz_sorted_tj = torch.gather(log_hz_sorted, 1, idx.expand(log_hz_sorted.size()))

# same step as in normal cox loss, just again, we consider Z(tau_j) where tau_j denotes event time to subject j
log_denominator_tj = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)
log_cumulative_hazard = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)

# Keep only patients with events
include = events_sorted.expand(log_hz_sorted.size())

# return the partial log likelihood
return (log_hz_sorted_tj - log_denominator_tj)[include]
return (log_hz_sorted_tj - log_cumulative_hazard)[include]


def _time_varying_covariance(
Expand Down Expand Up @@ -168,6 +202,10 @@ def _time_varying_covariance(
if __name__ == "__main__":
import torch
from torchsurv.metrics.cindex import ConcordanceIndex
import doctest

# Run doctest
results = doctest.testmod()

# set seed
torch.manual_seed(123)
Expand Down
Loading

0 comments on commit c38c6b3

Please sign in to comment.