Skip to content

Commit

Permalink
add chatGTP vectorized version
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Mar 1, 2024
1 parent 95c7400 commit c98ae29
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions R/antolini.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
cindex_chatGTP = function(pred, meth = c("A", "H"), tiex = 0.5) {
n_obs = length(pred$truth)
pred_times = pred$truth[, 1]
status = pred$truth[, 2]
surv = pred$data$distr
times = as.numeric(colnames(surv))
risk = unname(pred$data$crank)

# Assuming meth "A" optimization
if (meth == "A") {
extend_times = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6")
surv_mat = extend_times(x = pred_times, data = times, cdf = t(1 - surv), FALSE, FALSE)
rownames(surv_mat) = pred_times
}

n_seq = seq_len(n_obs)
pairs_i = rep(n_seq, each = n_obs)
pairs_j = rep(n_seq, n_obs)

comparable = function(ti, tj, di, cutoff) di & ti < tj & ti < cutoff
comp = comparable(pred_times[pairs_i], pred_times[pairs_j], status[pairs_i], cutoff = Inf)

if (meth == "A") {
surv_ii2 = rep(diag(surv_mat), times = n_obs)
surv_ij = surv_mat[cbind(as.character(pred_times[pairs_i]), pairs_j)]

conc = sum(surv_ii2[comp] < surv_ij[comp]) +
sum(surv_ii2[comp] == surv_ij[comp]) * tiex
} else {
ri = risk[pairs_i]
rj = risk[pairs_j]
conc = sum(ri[comp] > rj[comp]) +
sum(ri[comp] == rj[comp]) * tiex
}

conc / sum(comp)
}


cindex = function(pred, meth = c("A", "H"), tiex = 0.5) {
n_obs = length(pred$truth)
pred_times = pred$truth[, 1] # to differentiate with `times` below
Expand Down

0 comments on commit c98ae29

Please sign in to comment.