Skip to content

Commit

Permalink
rm comprehensive example JOSS paper
Browse files Browse the repository at this point in the history
  • Loading branch information
melodiemonod committed Dec 20, 2024
1 parent 4e81a22 commit 305a544
Showing 1 changed file with 0 additions and 109 deletions.
109 changes: 0 additions & 109 deletions paper/paper.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,115 +153,6 @@ cindex.p_value(alternative='greater') # pvalue, H0:c=0.5, HA:c>0.5
cindex.compare(cindex_other) # pvalue, H0:c1=c2, HA:c1>c2
```

# Comprehensive Example: Fitting a Cox Proportional Hazards Model with TorchSurv

In this section, we provide a reproducible code example to demonstrate how to use `TorchSurv` for fitting a Cox proportional hazards model. We simulate data where each observations is associated with 10 features, a time-to-event that depends linearly on these features, and a time-to-censoring. The observable time is the minimum between the time-to-event and time-to-censoring, representing the first event that occurs. Subsequently, we fit a Cox proportional hazards model using maximum likelihood estimation and assess the model's predictive performance through the AUC and the C-index. To facilitate rapid execution, we use a simple linear backend model in PyTorch to define the log relative hazards. For more comprehensive examples using real data, we encourage readers to visit the `Torchsurv` website.



```python
import torch
from torch.utils.data import DataLoader, Dataset
from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc

torch.manual_seed(42)

# 1. Simulate response
n_features = 10 # int, number of features per observation
time_end = torch.tensor(
2000.0
) # float, end of observational period after which all observations are censored
weights = (
torch.randn(n_features) * 5
) # float, weights associated with the features ~ normal(0, 5^2)

# Define the survival response generator function
def tte_generator(batch_size: int):
while True:
x = torch.randn(batch_size, n_features) # features
mean_event_time, mean_censoring_time = 1000.0 + x @ weights, 1000.0

event_time = (
mean_event_time + torch.randn(batch_size) * 50
) # event time ~ normal(mean_event_time, 50^2)
censoring_time = torch.distributions.Exponential(
1 / mean_censoring_time
).sample(
(batch_size,)
) # censoring time ~ Exponential(mean = mean_censoring_time)
censoring_time = torch.minimum(
censoring_time, time_end
) # truncate censoring time to time_end

event = (event_time <= censoring_time).bool() # event indicator
time = torch.minimum(event_time, censoring_time) # observed time

yield x, event, time

# 2. Define the PyTorch dataset class
class TTE_dataset(Dataset):
def __init__(self, generator: callable, batch_size: int):
self.batch_size = batch_size
self.generatated_data = generator(batch_size=batch_size)

def __len__(self):
return self.batch_size

def __getitem__(self, index):
return next(self.generatated_data)

# 3. Define the backbone model on the log hazards.
class MyPyTorchCoxModel(torch.nn.Module):
def __init__(self):
super(MyPyTorchCoxModel, self).__init__()
self.fc = torch.nn.Linear(n_features, 1, bias=False) # Simple linear model

def forward(self, x):
return self.fc(x)

# 4. Instantiate the model, optimizer, dataset and dataloader
cox_model = MyPyTorchCoxModel()
optimizer = torch.optim.Adam(cox_model.parameters(), lr=0.01)
batch_size = 64 # int, batch size
dataset = TTE_dataset(tte_generator, batch_size=batch_size)
dataloader = DataLoader(
dataset, batch_size=1, shuffle=True
) # Batch size of 1 because dataset yields batches

# 5. Training loop
for epoch in range(100):
for i, batch in enumerate(dataloader):
x, event, time = [t.squeeze() for t in batch]
optimizer.zero_grad()
log_hzs = cox_model(x) # torch.Size([batch_size, 1])
loss = cox.neg_partial_log_likelihood(log_hzs, event, time)
loss.backward()
optimizer.step()

# 6. Evaluate the model
n_samples_test = 1000 # int, number of observations in test set
data_test = next(tte_generator(batch_size=n_samples_test))
x, event, time = [t.squeeze() for t in data_test] # test set
log_hzs = cox_model(x) # log relative hazards evaluated on test set

# AUC at time point 1000
auc = Auc()
print(
"AUC:", auc(log_hzs, event, time, new_time=torch.tensor(1000.0))
) # tensor([0.5902])
print("AUC Confidence Interval:", auc.confidence_interval()) # tensor([0.5623, 0.6180])
print("AUC p-value:", auc.p_value(alternative="greater")) # tensor([0.])

# C-index
cindex = ConcordanceIndex()
print("C-Index:", cindex(log_hzs, event, time)) # tensor(0.5774)
print(
"C-Index Confidence Interval:", cindex.confidence_interval()
) # tensor([0.5086, 0.6463])
print("C-Index p-value:", cindex.p_value(alternative="greater")) # tensor(0.0138)
```

# Conflicts of interest

Expand Down

0 comments on commit 305a544

Please sign in to comment.