From c38c6b38d5c14ba4761037fe5f58e8eb2c1289c6 Mon Sep 17 00:00:00 2001
From: corolth1
Date: Mon, 6 Jan 2025 14:14:30 -0500
Subject: [PATCH] added doctest

---
 docs/notebooks/loss_time_covariates.py |  52 +++++++-
 docs/notebooks/time_varying.ipynb      | 173 ++++++++++++++-----------
 2 files changed, 144 insertions(+), 81 deletions(-)

diff --git a/docs/notebooks/loss_time_covariates.py b/docs/notebooks/loss_time_covariates.py
index dfd5a20..0279de4 100644
--- a/docs/notebooks/loss_time_covariates.py
+++ b/docs/notebooks/loss_time_covariates.py
@@ -7,14 +7,40 @@
 def neg_partial_time_log_likelihood(
-    log_hz: torch.Tensor, # Txnxp torch tensor, n is batch size, T number of time points, p is number of different covariates over time
-    time: torch.Tensor, # n length vector, time at which someone experiences event
-    events: torch.Tensor, # n length vector, boolean, true or false to determine if someone had an event
+    log_hz: torch.Tensor,
+    time: torch.Tensor,
+    events: torch.Tensor,
     reduction: str = "mean",
 ) -> torch.Tensor:
     """
-    needs further work
+    Compute the negative partial log-likelihood for time-dependent covariates in a Cox proportional hazards model.
+
+    Args:
+        log_hz (torch.Tensor): A tensor of shape (T, n, p), where T is the number of time points, n is the batch size,
+            and p is the number of different covariates over time.
+        time (torch.Tensor): A tensor of length n representing the time at which an event occurs for each individual.
+        events (torch.Tensor): A boolean tensor of length n indicating whether an event occurred (True) or not (False) for each individual.
+        reduction (str, optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Default is 'mean'.
+
+    Returns:
+        torch.Tensor: The computed negative partial log-likelihood. If reduction is 'mean', returns the mean value.
+            If reduction is 'sum', returns the sum of the values.
+
+    Raises:
+        ValueError: If the specified reduction method is not 'mean' or 'sum'.
+
+    Examples:
+        >>> _ = torch.manual_seed(52)
+        >>> n = 10  # number of samples
+        >>> t = 5  # time steps
+        >>> time = torch.randint(low=5, high=250, size=(n,)).float()
+        >>> event = torch.randint(low=0, high=2, size=(n,)).bool()
+        >>> log_hz = torch.rand((t, n, 1))
+        >>> neg_partial_time_log_likelihood(log_hz, time, event)
+        tensor(0.9456)
+        >>> neg_partial_time_log_likelihood(log_hz.squeeze(), time, event)  # Also works with 2D tensor
+        tensor(0.9456)
+        >>> neg_partial_time_log_likelihood(log_hz, time, event, reduction='sum')
+        tensor(37.8241)
     """
+
     # only consider theta at time of event
     pll = _partial_likelihood_time_cox(log_hz, time, events)

@@ -86,7 +112,15 @@ def _partial_likelihood_time_cox(
     we want to identify the index of the covariate upon failure. One option is to consider the last covariate
     before a series of zeros (this requires special data formatting, but could reduce issues as it automatically
     contains event time information).
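+
+    A sketch of the per-event term computed below (assuming the standard Cox
+    partial likelihood with time-dependent covariates; tau_j denotes the event
+    time of subject j and R(tau_j) the risk set at tau_j):
+
+        pll_j = log_hz_j(tau_j) - log( sum_{k in R(tau_j)} exp(log_hz_k(tau_j)) )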
-
+    Examples:
+        >>> _ = torch.manual_seed(52)
+        >>> n = 3  # number of samples
+        >>> t = 3  # time steps
+        >>> time = torch.randint(low=5, high=250, size=(n,)).float()
+        >>> event = torch.randint(low=0, high=2, size=(n,)).bool()
+        >>> log_hz = torch.rand((t, n, 1))
+        >>> _partial_likelihood_time_cox(log_hz, time, event)
+        tensor([-1.3772, -1.0683, -0.7879, -0.8220, 0.0000, 0.0000])
     """
     # Last dimension must be equal to 1 if shape == 3
     if len(log_hz.shape) == 3:
@@ -114,13 +148,13 @@ def _partial_likelihood_time_cox(
     log_hz_sorted_tj = torch.gather(log_hz_sorted, 1, idx.expand(log_hz_sorted.size()))

     # same step as in the normal Cox loss; again, we consider Z(tau_j), where tau_j denotes the event time of subject j
-    log_denominator_tj = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)
+    log_cumulative_hazard = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)

     # Keep only patients with events
     include = events_sorted.expand(log_hz_sorted.size())

     # return the partial log likelihood
-    return (log_hz_sorted_tj - log_denominator_tj)[include]
+    return (log_hz_sorted_tj - log_cumulative_hazard)[include]


 def _time_varying_covariance(
@@ -168,6 +202,10 @@
 if __name__ == "__main__":
     import torch
     from torchsurv.metrics.cindex import ConcordanceIndex
+    import doctest
+
+    # Run doctest
+    results = doctest.testmod()

     # set seed
     torch.manual_seed(123)

diff --git a/docs/notebooks/time_varying.ipynb b/docs/notebooks/time_varying.ipynb
index da52850..8f87285 100644
--- a/docs/notebooks/time_varying.ipynb
+++ b/docs/notebooks/time_varying.ipynb
@@ -117,7 +117,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "import torch\n",
 "import torch.distributions as dist\n",
 "\n",
 "# Third party\n",
 "from sklearn.model_selection import train_test_split\n",
 "\n",
 "# Our package\n",
- "#from torchsurv.loss.time_varying import neg_partial_log_likelihood2\n",
+ "# from torchsurv.loss.time_varying import neg_partial_log_likelihood2\n",
 "\n",
 "# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py\n",
 "from helpers_introduction import Custom_dataset, plot_losses"
 ]
 },
@@ -170,7 +170,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
 "metadata": {},
 "outputs": [ {
 "torch.manual_seed(123)\n",
 "\n",
 "n = 100 # Number of subjects\n",
- "T = torch.tensor(6) # Number of time points\n",
+ "T = torch.tensor(6)  # Number of time points\n",
 "time_vec = torch.tensor([0, 1, 2, 3, 4, 5])\n",
 "\n",
 "# Simulation parameters\n",
 "age_mean = 35\n",
 "age_std = 5\n",
 "sex_prob = 0.54\n",
- "G = torch.tensor([[0.29, -0.00465],[-0.00465, 0.000320]])\n",
+ "G = torch.tensor([[0.29, -0.00465], [-0.00465, 0.000320]])\n",
 "Z = torch.tensor([[1, 1, 1, 1, 1, 1], time_vec], dtype=torch.float32).T\n",
 "sigma = torch.tensor([0.1161])\n",
 "alpha = 1\n",
@@ -220,7 +220,12 @@
 "\n",
 "# Generate expected longitudinal trajectories\n",
 "# frankly this is unnecessary now - it was based on a misunderstanding of the algorithm\n",
- "trajectories = random_effects[:, 0].unsqueeze(1) + random_effects[:, 1].unsqueeze(1) * Z[:,1] + alpha * age.unsqueeze(1) + error_sample\n",
+ "trajectories = (\n",
+ "    random_effects[:, 0].unsqueeze(1)\n",
+ "    + random_effects[:, 1].unsqueeze(1) * Z[:, 1]\n",
+ "    + alpha * age.unsqueeze(1)\n",
+ "    + error_sample\n",
+ ")\n",
 "\n",
 "print(trajectories[1:5, :])"
 ]
 },
@@ -280,11 +285,11 @@
 },
 {
 "cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
 "metadata": {},
"outputs": [], "source": [ - "#import lmbert W function\n", + "# import lmbert W function\n", "\n", "from scipy.special import lambertw" ] @@ -302,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -328,15 +333,19 @@ "source": [ "# Specify the values for parameters, generate the random variables and call on relevant variables defined previously\n", "\n", - "alpha = torch.tensor([0.05, -0.5]) # regression coefficient for time-invariant covariates\n", - "gamma = torch.tensor(0.3) # association strength between longitudinal measures and time-to-event\n", + "alpha = torch.tensor(\n", + " [0.05, -0.5]\n", + ") # regression coefficient for time-invariant covariates\n", + "gamma = torch.tensor(\n", + " 0.3\n", + ") # association strength between longitudinal measures and time-to-event\n", "lambda_0 = torch.tensor(0.1) # baseline hazard rate\n", "\n", "torch.manual_seed(456)\n", "\n", "# Generate the random variables for hazard of a subject and censoring\n", "Q = dist.Uniform(0, 1).sample((n,)) # Random variable for hazard (Q)\n", - "C = dist.Uniform(3,5.5).sample((n,)) # Random variable for censoring\n", + "C = dist.Uniform(3, 5.5).sample((n,)) # Random variable for censoring\n", "\n", "# age and sex are the names of variables corresponding to those covariates\n", "# create the X matrix of covariates\n", @@ -348,19 +357,19 @@ "\n", "# Generate time to event T using the equation above\n", "log_Q = torch.log(Q)\n", - "lambert_W_nominator = gamma*b2*log_Q\n", - "lambert_W_denominator = torch.exp(alpha@XX.T + gamma*b1)\n", - "# below should give a vector of length sample_size \n", - "lambert_W = lambertw(-lambert_W_nominator/(lambda_0*lambert_W_denominator))\n", - "time_to_event = lambert_W/(gamma*b2)\n", + "lambert_W_nominator = gamma * b2 * log_Q\n", + "lambert_W_denominator = torch.exp(alpha @ XX.T + gamma * b1)\n", + "# below should give a vector of length sample_size\n", + "lambert_W = lambertw(-lambert_W_nominator / (lambda_0 * lambert_W_denominator))\n", + "time_to_event = lambert_W / (gamma * b2)\n", "\n", - "#take the real part of the LBF, the complex part is =0\n", + "# take the real part of the LBF, the complex part is =0\n", "outcome_LWF = time_to_event.real\n", "outcome_LWF = torch.floor(outcome_LWF)\n", "outcome_LWF\n", "\n", "# implement censoring with some level of intensity below\n", - "events = C<5\n", + "events = C < 5\n", "events" ] }, @@ -415,7 +424,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -430,8 +439,8 @@ } ], "source": [ - "#from torchsurv.loss import time_covariates\n", - "#from torchsurv.metrics.cindex import ConcordanceIndex\n", + "# from torchsurv.loss import time_covariates\n", + "# from torchsurv.metrics.cindex import ConcordanceIndex\n", "\n", "# Parameters\n", "input_size = 1\n", @@ -447,7 +456,7 @@ "print(test.shape)\n", "print(inputs.shape)\n", "\n", - "#initializa hidden state\n", + "# initializa hidden state\n", "h0 = torch.randn(num_layers, batch_size, output_size)\n", "print(h0.shape)\n", "# Forward pass time series input\n", @@ -458,7 +467,7 @@ "# print(f\"Estimate shape for {batch_size} samples = {estimates.size()}\") # Estimate shape for 8 samples = torch.Size([8, 1])\n", "\n", "\n", - "#loss = neg_loss_function(outputs, events, time)\n", + "# loss = neg_loss_function(outputs, events, time)\n", "# print(f\"loss = {loss}, has gradient = {loss.requires_grad}\") # loss = 1.0389232635498047, has gradient = True\n", "\n", 
"# cindex = ConcordanceIndex()\n", @@ -476,51 +485,52 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", - "# as a reminder covars is the matrix of covariates where a row corresponds to a subject and a column corresponds to their observation at some time \n", + "# as a reminder covars is the matrix of covariates where a row corresponds to a subject and a column corresponds to their observation at some time\n", "# the columns are padded so if a subject experiences an event, the remaining of the column is zero\n", "\n", "# Generating example torch matrix\n", "torch_matrix = trajectories\n", "# Convert torch matrix to pandas dataframe\n", "\n", - "#set time to integer\n", + "# set time to integer\n", "max_time = max(time_vec.type(torch.int64))\n", "\n", - "vars = []\n", - "#times = []\n", + "variables = []\n", "start = []\n", "stop = []\n", "event = []\n", "subjs = []\n", + "\n", "for i in range(n):\n", " subj_counter = 0\n", - " for j in range(max_time):\n", - " if torch_matrix[i,j] == 0:\n", + " for j in range(1, max_time + 1):\n", + " if torch_matrix[i, j - 1] == 0:\n", " break\n", - " else:\n", - " vars.append(torch_matrix[i,j].item())\n", - " #times.append(j)\n", - " start.append(j-1)\n", - " stop.append(j)\n", - " event.append(False)\n", - " subj_counter += 1\n", + " variables.append(torch_matrix[i, j - 1].item())\n", + " start.append(j - 1)\n", + " stop.append(j)\n", + " event.append(False)\n", + " subj_counter += 1\n", " subjs.extend([i] * subj_counter)\n", - " if events[i]==True: event[-1]=True\n", - "\n", - "df = pd.DataFrame({\n", - " \"subj\": subjs,\n", - " #\"times\": times,\n", - " \"start\":start,\n", - " \"stop\": stop,\n", - " \"events\": event,\n", - " \"var\": vars, \n", - "})\n" + " if events[i]:\n", + " event[-1] = True\n", + "\n", + "df = pd.DataFrame(\n", + " {\n", + " \"subj\": subjs,\n", + " # \"times\": times,\n", + " \"start\": start,\n", + " \"stop\": stop,\n", + " \"events\": event,\n", + " \"var\": variables,\n", + " }\n", + ")" ] }, { @@ -532,7 +542,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -723,7 +733,14 @@ "from lifelines import CoxTimeVaryingFitter\n", "\n", "ctv = CoxTimeVaryingFitter(penalizer=0.1)\n", - "ctv.fit(df, id_col=\"subj\", event_col=\"events\", start_col=\"start\", stop_col=\"stop\", show_progress=True)\n", + "ctv.fit(\n", + " df,\n", + " id_col=\"subj\",\n", + " event_col=\"events\",\n", + " start_col=\"start\",\n", + " stop_col=\"stop\",\n", + " show_progress=True,\n", + ")\n", "ctv.print_summary()\n", "ctv.plot()" ] @@ -874,16 +891,18 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from lifelines.utils import to_long_format, add_covariate_to_timeline\n", "\n", - "base_df = pd.DataFrame([\n", - " {'id': 1, 'duration': 10, 'event': True, 'var1': 0.1},\n", - " {'id': 2, 'duration': 12, 'event': True, 'var1': 0.5}\n", - "])\n", + "base_df = pd.DataFrame(\n", + " [\n", + " {\"id\": 1, \"duration\": 10, \"event\": True, \"var1\": 0.1},\n", + " {\"id\": 2, \"duration\": 12, \"event\": True, \"var1\": 0.5},\n", + " ]\n", + ")\n", "\n", "base_df = to_long_format(base_df, duration_col=\"duration\")" ] @@ -982,39 +1001,39 @@ "metadata": {}, "outputs": [], "source": [ - "print('x_test', x_test.shape)\n", - "print('events', test_event.shape)\n", - "print('times', test_time.shape)\n", + "print(\"x_test\", 
x_test.shape)\n", + "print(\"events\", test_event.shape)\n", + "print(\"times\", test_time.shape)\n", "\n", "time_sorted, idx = torch.sort(time)\n", "log_hz_sorted = log_hz[idx]\n", "event_sorted = event[idx]\n", "time_unique = torch.unique(time_sorted)\n", - "print('')\n", + "print(\"\")\n", "print(\"time_sorted\", time_sorted.shape)\n", - "print('log_hz_sorted', log_hz_sorted.shape)\n", - "print('event_sorted', event_sorted.shape)\n", + "print(\"log_hz_sorted\", log_hz_sorted.shape)\n", + "print(\"event_sorted\", event_sorted.shape)\n", "print(\"time_unique\", time_unique.shape)\n", "\n", - "print('-'*30)\n", + "print(\"-\" * 30)\n", "cov_fake = torch.clone(x_test)\n", - "print('covariates', cov_fake.shape)\n", + "print(\"covariates\", cov_fake.shape)\n", "covariates_sorted = cov_fake[idx, :]\n", "covariate_inner_product = torch.matmul(covariates_sorted, covariates_sorted.T)\n", - "print('cov_inner', covariate_inner_product.shape)\n", + "print(\"cov_inner\", covariate_inner_product.shape)\n", "log_nominator_left = torch.matmul(log_hz_sorted.T, covariate_inner_product)\n", - "print('log_nom_left', log_nominator_left.shape)\n", + "print(\"log_nom_left\", log_nominator_left.shape)\n", "bracket = torch.mul(log_hz_sorted, covariates_sorted)\n", - "print('bracket', bracket.shape)\n", + "print(\"bracket\", bracket.shape)\n", "log_nominator_right = torch.matmul(bracket, bracket.T)\n", - "print('log_nom_right', log_nominator_right.shape)\n", + "print(\"log_nom_right\", log_nominator_right.shape)\n", "sum_nominator_right = log_nominator_right[0,].unsqueeze(0)\n", - "print('sum_nom', sum_nominator_right.shape)\n", + "print(\"sum_nom\", sum_nominator_right.shape)\n", "log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0).T\n", - "print('log_denom', log_denominator.shape)\n", + "print(\"log_denom\", log_denominator.shape)\n", "last_bit = torch.div(log_nominator_left - sum_nominator_right, log_denominator)\n", - "print('last_bit', last_bit.shape)\n", - "last_bit\n" + "print(\"last_bit\", last_bit.shape)\n", + "last_bit" ] }, { @@ -1047,7 +1066,9 @@ "\n", "# make random positive time to event\n", "time = torch.rand(batch_size) * 100\n", - "print(time) # tensor([32.8563, 38.3207, 24.6015, 72.2986, 19.9004, 65.2180, 73.2083, 21.2663])\n", + "print(\n", + " time\n", + ") # tensor([32.8563, 38.3207, 24.6015, 72.2986, 19.9004, 65.2180, 73.2083, 21.2663])\n", "\n", "# Create simple RNN model\n", "rnn = torch.nn.RNN(input_size, output_size, num_layers)\n", @@ -1058,11 +1079,15 @@ "outputs, _ = rnn(inputs, h0)\n", "estimates = outputs[-1] # Keep only last predictions, many to one approach\n", "print(estimates.size()) # torch.Size([8, 1])\n", - "print(f\"Estimate shape for {batch_size} samples = {estimates.size()}\") # Estimate shape for 8 samples = torch.Size([8, 1])\n", + "print(\n", + " f\"Estimate shape for {batch_size} samples = {estimates.size()}\"\n", + ") # Estimate shape for 8 samples = torch.Size([8, 1])\n", "\n", "\n", "loss = cox.neg_partial_log_likelihood(estimates, events, time)\n", - "print(f\"loss = {loss}, has gradient = {loss.requires_grad}\") # loss = 1.0389232635498047, has gradient = True\n", + "print(\n", + " f\"loss = {loss}, has gradient = {loss.requires_grad}\"\n", + ") # loss = 1.0389232635498047, has gradient = True\n", "\n", "cindex = ConcordanceIndex()\n", "print(f\"c-index = {cindex(estimates, events, time)}\") # c-index = 0.20000000298023224"