improve tex
melodiemonod committed Jul 19, 2024
1 parent ec6b55d commit a449e74
Showing 1 changed file with 17 additions and 16 deletions: paper/paper.md

# Summary

`TorchSurv` (available on GitHub and PyPI) is a Python package that serves as a companion tool to perform deep survival modeling within the `PyTorch` environment [@paszke2019pytorch]. With its lightweight design, minimal input requirements, full `PyTorch` backend, and freedom from restrictive parameterizations, `TorchSurv` facilitates efficient deep survival model implementation and is particularly beneficial for high-dimensional and complex data scenarios.
`TorchSurv` has been rigorously tested on both open-source and synthetically generated survival data. The package is thoroughly documented and includes illustrative examples. The latest documentation for `TorchSurv` can be found on the [`TorchSurv` website](https://opensource.nibr.com/torchsurv/).

`TorchSurv` provides a user-friendly workflow for training and evaluating `PyTorch`-based deep survival models.
At its core, `TorchSurv` features `PyTorch`-based calculations of log-likelihoods for prominent survival models, including the Cox proportional hazards model [@Cox1972] and the Weibull Accelerated Failure Time (AFT) model [@Carroll2003].
In survival analysis, each observation is associated with a survival response, denoted by $y$ (comprising the event indicator and the time-to-event or censoring), and covariates, denoted by $x$. A survival model is parametrized by parameters denoted by $\theta$. Within the `TorchSurv` framework, a `PyTorch`-based neural network is defined to act as a flexible function that takes the covariates $x$ as input and outputs the parameters $\theta$. Estimation of the parameters $\theta$ is achieved via maximum likelihood estimation, facilitated by backpropagation.
Additionally, `TorchSurv` offers evaluation metrics, including the time-dependent Area Under the Receiver Operating Characteristic (ROC) Curve (AUC), the Concordance index (C-index), and the Brier score, to characterize the predictive performance of survival models.
Below is an overview of the workflow for model inference and evaluation with `TorchSurv`:

1. Initialize a `PyTorch`-based neural network that defines the function from the covariates $x$ to the parameters $\theta$. In the Cox proportional hazards model, for example, the parameters are the log relative hazards.
2. Initiate training: For each epoch on the training set,
- Draw a survival response $y^{\text{train}}$ (i.e., event indicator and time-to-event or censoring) and covariates $x^{\text{train}}$ from the training set.
- Obtain parameters $\theta^{\text{train}}$ given the covariates $x^{\text{train}}$ using the neural network.
- Calculate the loss as the negative log-likelihood of the survival response $y^{\text{train}}$ given the parameters $\theta^{\text{train}}$, using `TorchSurv`'s loss function. In the Cox proportional hazards model, for example, the loss is the negative of the log partial likelihood.
- Use backpropagation to update the neural network's parameters.
3. Obtain parameters $\theta^{\text{test}}$ given the covariates from the test set $x^{\text{test}}$ using the trained neural network.
4. Evaluate the predictive performance of the model using `TorchSurv`'s evaluation metric functions (e.g., the C-index) given the parameters $\theta^{\text{test}}$ and the survival response from the test set $y^{\text{test}}$, as sketched in the example below.
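
As an illustration, below is a minimal end-to-end sketch of this workflow using the Cox loss and the C-index. The network architecture, training settings, and synthetic data are placeholders, and the import paths follow the code snippets later in this paper; this is a sketch under those assumptions, not a canonical implementation.

```python
import torch
from torchsurv.loss import cox
from torchsurv.metrics import ConcordanceIndex

torch.manual_seed(0)

# Synthetic survival data standing in for a real dataset: covariates x,
# event indicator (True = event observed), and time-to-event or censoring.
n, p = 128, 16
x = torch.randn(n, p)
event = torch.rand(n) > 0.3
time = torch.rand(n) * 10

# Step 1: neural network mapping covariates x to the parameter theta
# (here, the log relative hazard of a Cox proportional hazards model).
model = torch.nn.Sequential(torch.nn.Linear(p, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Step 2: training loop (full batch for brevity; a DataLoader would be used in practice).
for epoch in range(50):
    optimizer.zero_grad()
    log_hzs = model(x)                                            # theta^train
    loss = cox.neg_partial_log_likelihood(log_hzs, event, time)   # TorchSurv Cox loss
    loss.backward()                                               # backpropagation
    optimizer.step()

# Steps 3-4: estimate theta on held-out covariates and evaluate predictive
# performance (the training tensors are reused here purely for illustration).
with torch.no_grad():
    log_hzs_test = model(x)
cindex = ConcordanceIndex()
print(cindex(log_hzs_test, event, time))                          # C-index
```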



With respect to the evaluation metrics, `scikit-survival` stands out as a comprehensive library. However, it lacks certain desirable features, including confidence intervals and the statistical comparison of an evaluation metric between two models, and it is implemented with `NumPy`.
Our package, `TorchSurv`, is specifically designed for use in Python, but we also provide a comparative analysis of its functionalities against popular `R` packages for survival analysis in \autoref{tab:bibliography_R}. These `R` packages likewise restrict users to specific functional forms for defining the parameters and do not make the log-likelihood functions readily accessible. However, `R` has extensive libraries for evaluation metrics, such as the `riskRegression` library [@riskRegressionpackage]. `TorchSurv` offers a comparable range of evaluation metrics, ensuring comprehensive model evaluation regardless of the chosen programming environment.

`TorchSurv`'s log-likelihood and evaluation metrics functions have undergone thorough comparison with benchmarks generated with Python and R packages on open-source and synthetic data. High agreement between the outputs is consistently observed, providing users with confidence in the accuracy and reliability of `TorchSurv`'s functionalities. The comparison is presented on the [`TorchSurv` website](https://opensource.nibr.com/torchsurv/benchmarks.html).

![**Survival analysis libraries in Python.** $^1$[@nagpal2022auton], $^{2}$[@Kvamme2019pycox], $^{3}$[@torchlifeAbeywardana], $^{4}$[@polsterl2020scikit], $^{5}$[@davidson2019lifelines], $^{6}$[@katzman2018deepsurv]. A green tick indicates a fully supported feature, a red cross indicates an unsupported feature, and a blue crossed tick indicates a partially supported feature. For computing the concordance index, `pycox` requires the use of the estimated survival function as the risk score and does not support other types of time-dependent risk scores. `scikit-survival` does not support time-dependent risk scores in either the concordance index or the AUC computation. Additionally, both `pycox` and `scikit-survival` impose the use of inverse probability of censoring weighting (IPCW) for subject-specific weights. `scikit-survival` only offers the Breslow approximation of the Cox partial log-likelihood in the case of ties in the event time and lacks the Efron approximation.\label{tab:bibliography}](table_1.png)


## Loss functions

**Cox loss function.** The Cox loss function is defined as the negative of the Cox proportional hazards model's partial log-likelihood [@Cox1972]. The function requires the subject-specific log relative hazards and the survival response (i.e., event indicator and time-to-event or censoring). The log relative hazards are obtained from a `PyTorch`-based model pre-specified by the user. In case of ties in the event times, the user can choose between the Breslow method [@Breslow1975] and the Efron method [@Efron1977] to approximate the Cox partial log-likelihood. We illustrate the use of the Cox loss function for a pseudo training loop in the code snippet below.

```python
from torchsurv.loss import cox
my_cox_model = MyPyTorchCoxModel()      # user-defined PyTorch network (hypothetical name)
for x, event, time in dataloader:       # PyTorch DataLoader yielding survival data (hypothetical)
    log_hzs = my_cox_model(x)           # subject-specific log relative hazards
    loss = cox.neg_partial_log_likelihood(log_hzs, event, time)
    loss.backward()                     # native PyTorch backpropagation
```
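
When the event times contain ties, the approximation method can be selected when computing the loss. The `ties_method` argument and its accepted values below are taken as an assumption from the `TorchSurv` documentation; treat this as a sketch rather than the definitive signature.

```python
# Assumed keyword argument for the ties approximation (see the TorchSurv documentation).
loss_efron = cox.neg_partial_log_likelihood(log_hzs, event, time, ties_method="efron")
loss_breslow = cox.neg_partial_log_likelihood(log_hzs, event, time, ties_method="breslow")
```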

## Evaluation Metrics Functions

The `TorchSurv` package offers a comprehensive set of metrics to evaluate the predictive performance of survival models, including the AUC, the C-index, and the Brier score. The evaluation metrics functions take as input the individual risk scores estimated on the test set and the survival response of the test set. The risk score measures the risk (or a proxy thereof) that a subject has an event. We provide definitions for each metric and demonstrate their use through illustrative code snippets.
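
For concreteness, the snippets below reuse the tensor names from the training example; the hypothetical test tensors below (made-up values, plain `PyTorch`) show the expected form of the inputs.

```python
import torch

# Hypothetical test-set inputs for the metric snippets below (values are illustrative only).
event = torch.tensor([True, False, True, True, False])    # event indicator
time = torch.tensor([12.0, 7.5, 3.2, 9.9, 15.1])          # time-to-event or censoring
log_hzs = torch.randn(5, 1)                                # e.g., log relative hazards from a trained network
```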

**AUC.** The AUC measures the discriminatory capacity of a model at a given time $t$, i.e., the model’s ability to provide a reliable ranking of times-to-event based on estimated individual risk scores [@Heagerty2005;@Uno2007;@Blanche2013].

```python
from torchsurv.metrics import Auc
auc = Auc()
auc(log_hzs, event, time)                               # AUC at each time
auc(log_hzs, event, time, new_time=torch.tensor(10.))   # AUC at time 10
```

**C-index.** The C-index is a generalization of the AUC that assesses the discriminatory capacity of the model over the entire time period [@Harrell1996;@Uno_2011].

```python
from torchsurv.metrics import ConcordanceIndex
cindex = ConcordanceIndex()
cindex(log_hzs, event, time) # C-index
```
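
Beyond the point estimate, the C-index object also provides inference utilities such as a confidence interval, a p-value, and a statistical comparison of two models. The method names below are taken as assumptions from the `TorchSurv` documentation, and `log_hzs_other` is a hypothetical risk score from a competing model; this is a sketch rather than the definitive API.

```python
# Inference utilities (method names assumed from the TorchSurv documentation).
cindex.confidence_interval()            # confidence interval of the C-index
cindex.p_value()                        # test H0: no discriminatory capacity (C-index = 0.5)

# Compare against a competing model's C-index (hypothetical risk scores).
log_hzs_other = torch.randn(len(time), 1)
cindex_other = ConcordanceIndex()
cindex_other(log_hzs_other, event, time)
cindex.compare(cindex_other)            # test whether the first model discriminates better
```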

**Brier Score.** The Brier score evaluates the accuracy of a model at a given time $t$. It represents the average squared distance between the observed survival status and the predicted survival probability [@Graf_1999]. The Brier score cannot be obtained for the Cox proportional hazards model because the survival function is not available, but it can be obtained for the Weibull AFT model.

```python
from torchsurv.metrics import Brier
brier = Brier()
# `surv`: predicted survival probabilities, e.g., obtained from the Weibull AFT model
brier(surv, event, time)   # Brier score at each time
```
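
For intuition only, here is a plain `PyTorch` restatement of the squared-distance idea behind the Brier score at a single time $t$, ignoring the censoring weighting that a proper estimator (and the `Brier` class) accounts for; the predicted probabilities are made up.

```python
import torch

# Naive (unweighted) Brier score at a single time t: the mean squared distance
# between the observed survival status 1(time > t) and the predicted survival
# probability at t. Censoring weighting (IPCW) is deliberately omitted here.
t = torch.tensor(10.0)
surv_at_t = torch.rand(len(time))      # hypothetical predicted P(T > t) for each subject
status_at_t = (time > t).float()       # observed survival status at t
naive_brier_at_t = ((status_at_t - surv_at_t) ** 2).mean()
```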
