Translated the hh-lambda loop #32

Merged
merged 49 commits into from
Mar 5, 2024
49 commits
c8a3c45
Extracted `hh <= nlambda` loop as R function (#17)
wleoncio Jan 26, 2024
f2b3fed
Added functional `legacy` switch (#17)
wleoncio Jan 26, 2024
8b1eeb6
Fixed spacing in printed text
wleoncio Jan 26, 2024
dcfb2a4
Added skeleton of loop function on C++ (#17)
wleoncio Jan 26, 2024
08918bd
Added `MADMMplasso.h` (#17)
wleoncio Jan 26, 2024
1ebe86d
Translated a bit of `hh_nlambda_loop()` to C++ (#17)
wleoncio Jan 26, 2024
b84e55a
Removed duplicated `if` case (#17)
wleoncio Jan 26, 2024
4023f01
Added `my_values` as argument to loop functions (#17)
wleoncio Jan 26, 2024
411df0f
Revert "Removed duplicated `if` case (#17)"
wleoncio Jan 26, 2024
5f2d38d
Fixed if-statements (#17)
wleoncio Jan 26, 2024
916706e
Fixed object types (#17)
wleoncio Jan 26, 2024
998845b
Using safe element access (#17)
wleoncio Jan 26, 2024
9d3caac
Translated more of `hh_nlambda_loop()` to C++ (#17)
wleoncio Jan 26, 2024
141f815
Fixed `hh` index (#17)
wleoncio Jan 26, 2024
5992b60
Converting `my_values` to list in all cases (#17)
wleoncio Feb 6, 2024
1727540
Syntax fix (#17)
wleoncio Feb 6, 2024
93a7237
Translated more of `hh_nlambda_loop()` to C++ (#17)
wleoncio Feb 6, 2024
fdc1bc4
Added TODO
wleoncio Feb 6, 2024
328f1ee
Translated almost all `hh_nlambda_loop()` (#17)
wleoncio Feb 14, 2024
bdab078
Exporting and using `hh_nlambda_loop_cpp()` (#17)
wleoncio Feb 14, 2024
3673011
Syntax fix (#17)
wleoncio Feb 14, 2024
364d275
Translated rest of `hh_nlambda_loop()` (#17)
wleoncio Feb 14, 2024
aff62ec
Fixed package documentation
wleoncio Feb 14, 2024
399729a
Increment version number to 0.0.0.9011
wleoncio Feb 14, 2024
585bad6
Removed commented out code (#17)
wleoncio Feb 14, 2024
5f0d1d8
Added unit tests (#17)
wleoncio Feb 14, 2024
0862485
Translated `count_nonzero_a()` to C++ (#17)
wleoncio Feb 14, 2024
9e98c98
Using `count_nonzero_a_cpp()` on `hh_nlambda_loop_cpp()` (#17)
wleoncio Feb 14, 2024
7dd48f0
Fixes to header files (#17, #2)
wleoncio Feb 14, 2024
306a033
Adding separate functions to handle `sp_mat` and `cube` (#17)
wleoncio Feb 14, 2024
5bf822b
Increment version number to 0.0.0.9012
wleoncio Feb 15, 2024
c55fc71
Exporting remaining `count_nonzero_a*()` (#17)
wleoncio Feb 15, 2024
8051f0f
Added more unit tests for #17
wleoncio Feb 15, 2024
015a8c5
Updates `TODOs` and `FIXMEs` (#17)
wleoncio Feb 15, 2024
1577ddf
Merge branch 'main' into issue-17
wleoncio Mar 1, 2024
33b5848
Added `aux/` folder to `.gitignore`
wleoncio Mar 1, 2024
80c2bd2
Increment version number to 0.0.0.9013
wleoncio Mar 1, 2024
9f6186e
Fixed duplicated condition
wleoncio Mar 4, 2024
071c49b
`print(cost_time)` only triggers if `my_print == TRUE`
wleoncio Mar 4, 2024
fa7e57a
Merge branch 'main' into issue-17
wleoncio Mar 4, 2024
b1b5b46
Merge branch 'main' into issue-17
wleoncio Mar 4, 2024
ab9676d
More consistent formatting
wleoncio Mar 4, 2024
7573dbc
Increment version number to 0.0.0.9014
wleoncio Mar 4, 2024
2b4d267
Adjusted unit tests
wleoncio Mar 4, 2024
ca1603f
Fixed code smell
wleoncio Mar 4, 2024
9ecbac7
Fixing lints on unit tests
wleoncio Mar 4, 2024
75796aa
Converted TODOs to GitHub issues
wleoncio Mar 4, 2024
6cb3d04
Merge branch 'main' into issue-17
wleoncio Mar 5, 2024
5e24bdd
Increment version number to 0.0.0.9015
wleoncio Mar 5, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@
src/*.o
src/*.so
src/*.dll
aux/
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: MADMMplasso
Title: Multi Variate Multi Response 'ADMM' with Interaction Effects
Version: 0.0.0.9013
Version: 0.0.0.9015
Authors@R:
c(
person(
116 changes: 35 additions & 81 deletions R/MADMMplasso.R
@@ -123,12 +123,6 @@ MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, ma
gg <- gg1
}

lam_list <- list()
obj <- NULL
n_main_terms <- NULL
non_zero_theta <- NULL
my_obj <- list()

my_W_hat <- generate_my_w(X = X, Z = Z)

svd.w <- svd(my_W_hat)
@@ -202,14 +196,19 @@ MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, ma
doParallel::registerDoParallel(cl = cl)
foreach::getDoParRegistered()

my_values <- foreach(i = 1:nlambda, .packages = "MADMMplasso", .combine = rbind) %dopar% {
my_values_matrix <- foreach(i = 1:nlambda, .packages = "MADMMplasso", .combine = rbind) %dopar% {
admm_MADMMplasso(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e.abs, e.rel, alpha, lam[i, ], alph, svd.w, tree, my_print,
invmat, gg[i, ], legacy
)
}
parallel::stopCluster(cl)

# Converting to list so hh_nlambda_loop_cpp can handle it
my_values <- list()
for (hh in seq_len(nrow(my_values_matrix))) {
my_values[[hh]] <- my_values_matrix[hh, ]
}
} else if (!parallel && !pal) {
my_values <- lapply(
seq_len(nlambda),
@@ -221,87 +220,42 @@ MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, ma
)
}
)
} else {
# This is triggered when parallel is FALSE and pal is 1
my_values <- list()
}

hh <- 1
while (hh <= nlambda) {
lambda <- lam[hh, ]

start_time <- Sys.time()
if (pal) {
my_values <- admm_MADMMplasso(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e.abs, e.rel, alpha, lambda, alph, svd.w, tree, my_print, invmat,
gg[hh, ], legacy
)

beta <- my_values$beta
theta <- my_values$theta
my_obj[[hh]] <- list(my_values$obj)
beta0 <- my_values$beta0
theta0 <- my_values$theta0 ### iteration
beta_hat <- my_values$beta_hat
y_hat <- my_values$y_hat
}
cost_time <- Sys.time() - start_time
print(cost_time)
if (parallel && !pal) {
beta <- my_values[hh, ]$beta
theta <- my_values[hh, ]$theta
my_obj[[hh]] <- list(my_values[hh, ]$obj)
beta0 <- my_values[hh, ]$beta0
theta0 <- my_values[hh, ]$theta0 ### iteration
beta_hat <- my_values[hh, ]$beta_hat
y_hat <- my_values[hh, ]$y_hat
} else if (!parallel && !pal) {
beta <- my_values[[hh]]$beta
theta <- my_values[[hh]]$theta
my_obj[[hh]] <- list(my_values[[hh]]$obj)
beta0 <- my_values[[hh]]$beta0
theta0 <- my_values[[hh]]$theta0 ### iteration
beta_hat <- my_values[[hh]]$beta_hat
y_hat <- my_values[[hh]]$y_hat
}

beta1 <- as(beta * (abs(beta) > tol), "sparseMatrix")
theta1 <- as.sparse3Darray(theta * (abs(theta) > tol))
beta_hat1 <- as(beta_hat * (abs(beta_hat) > tol), "sparseMatrix")

n_interaction_terms <- count_nonzero_a((theta1))

n_main_terms <- (c(n_main_terms, count_nonzero_a((beta1))))

obj1 <- (sum(as.vector((y - y_hat)^2))) / (D * N)
obj <- c(obj, obj1)

non_zero_theta <- (c(non_zero_theta, n_interaction_terms))
lam_list <- (c(lam_list, lambda))

BETA0[[hh]] <- beta0
THETA0[[hh]] <- theta0
BETA[[hh]] <- as(beta1, "sparseMatrix")
BETA_hat[[hh]] <- as(beta_hat1, "sparseMatrix")

Y_HAT[[hh]] <- y_hat
THETA[[hh]] <- as.sparse3Darray(theta1)

if (hh == 1) {
print(c(hh, (n_main_terms[hh]), non_zero_theta[hh], obj1))
} else {
print(c(hh, (n_main_terms[hh]), non_zero_theta[hh], obj[hh - 1], obj1))
}

hh <- hh + 1
} ### lambda
loop_output <- hh_nlambda_loop(
lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it,
my_W_hat, XtY, y, N, e.abs, e.rel, alpha, alph, svd.w, tree, my_print,
invmat, gg, tol, parallel, pal, BETA0, THETA0, BETA,
BETA_hat, Y_HAT, THETA, D, my_values, legacy
)

remove(invmat)
remove(my_values)
remove(my_W_hat)

obj[1] <- obj[2]
loop_output$obj[1] <- loop_output$obj[2]

pred <- data.frame(Lambda = lam, nzero = n_main_terms, nzero_inter = non_zero_theta, OBJ_main = obj)
out <- list(beta0 = BETA0, beta = BETA, BETA_hat = BETA_hat, theta0 = THETA0, theta = THETA, path = pred, Lambdas = lam, non_zero = n_main_terms, LOSS = obj, Y_HAT = Y_HAT, gg = gg)
pred <- data.frame(
Lambda = lam,
nzero = loop_output$n_main_terms,
nzero_inter = loop_output$non_zero_theta,
OBJ_main = loop_output$obj
)
out <- list(
beta0 = loop_output$BETA0,
beta = loop_output$BETA,
BETA_hat = loop_output$BETA_hat,
theta0 = loop_output$THETA0,
theta = loop_output$THETA,
path = pred,
Lambdas = lam,
non_zero = loop_output$n_main_terms,
LOSS = loop_output$obj,
Y_HAT = loop_output$Y_HAT,
gg = gg
)
class(out) <- "MADMMplasso"
# Return results
return(out)
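A note for readers following the `pal` switch: in the removed loop, the `pal` branch feeds each fit's `beta`, `theta`, `beta0`, `theta0` and `beta_hat` into the next call, whereas the `lapply` and `foreach` branches start every fit from the same initial values. A minimal C++ sketch of that difference follows. The `Solver` type, the two function names and the flat vectors are hypothetical stand-ins for `admm_MADMMplasso()` and its coefficient objects; none of them are part of the package.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical solver signature (an assumption for illustration): takes a
// starting point and a penalty value, returns the fitted coefficients.
using Solver =
    std::function<std::vector<double>(const std::vector<double>&, double)>;

// Sequential path: every fit starts from the previous fit, as the `pal`
// branch does when it overwrites beta, theta, ... inside the loop.
std::vector<std::vector<double>> fit_path_warm(
    const Solver& solve, std::vector<double> start,
    const std::vector<double>& lambdas) {
    std::vector<std::vector<double>> fits;
    fits.reserve(lambdas.size());
    for (double lambda : lambdas) {
        start = solve(start, lambda);  // carry the solution forward
        fits.push_back(start);
    }
    return fits;
}

// Independent path: every fit starts from the same initial values, as the
// lapply() and foreach() branches do. Fits do not depend on each other.
std::vector<std::vector<double>> fit_path_independent(
    const Solver& solve, const std::vector<double>& start,
    const std::vector<double>& lambdas) {
    std::vector<std::vector<double>> fits;
    fits.reserve(lambdas.size());
    for (double lambda : lambdas) {
        fits.push_back(solve(start, lambda));
    }
    return fits;
}
```

Only the independent variant can be handed to parallel workers as is, which may be why the diff keeps `pal` and `parallel` as separate switches.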
16 changes: 16 additions & 0 deletions R/RcppExports.R
@@ -37,6 +37,22 @@ admm_MADMMplasso_cpp <- function(beta0, theta0, beta, beta_hat, theta, rho1, X,
.Call(`_MADMMplasso_admm_MADMMplasso_cpp`, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, W_hat, XtY, y, N, e_abs, e_rel, alpha, lambda, alph, svd_w, tree, invmat, gg, my_print)
}

count_nonzero_a_cpp <- function(x) {
.Call(`_MADMMplasso_count_nonzero_a_cpp`, x)
}

count_nonzero_a_sp_mat <- function(x) {
.Call(`_MADMMplasso_count_nonzero_a_sp_mat`, x)
}

count_nonzero_a_cube <- function(x) {
.Call(`_MADMMplasso_count_nonzero_a_cube`, x)
}

hh_nlambda_loop_cpp <- function(lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY, y, N, e_abs, e_rel, alpha, alph, svd_w, tree, my_print, invmat, gg, tol, parallel, pal, BETA0, THETA0, BETA, BETA_hat, Y_HAT, THETA, D, my_values) {
.Call(`_MADMMplasso_hh_nlambda_loop_cpp`, lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY, y, N, e_abs, e_rel, alpha, alph, svd_w, tree, my_print, invmat, gg, tol, parallel, pal, BETA0, THETA0, BETA, BETA_hat, Y_HAT, THETA, D, my_values)
}

model_intercept <- function(beta0, theta0, beta, theta, X, Z) {
.Call(`_MADMMplasso_model_intercept`, beta0, theta0, beta, theta, X, Z)
}
4 changes: 2 additions & 2 deletions R/admm_MADMMplasso.R
@@ -41,8 +41,8 @@ admm_MADMMplasso <- function(beta0, theta0, beta, beta_hat, theta, rho1, X, Z, m
return(out)
}
warning(
"Using legacy R code for MADMMplasso.",
"This functionality will be removed in a future release.",
"Using legacy R code for MADMMplasso. ",
"This functionality will be removed in a future release. ",
"Please consider using legacy = FALSE instead."
)
TT <- tree
4 changes: 3 additions & 1 deletion R/cv_MADMMplasso.R
@@ -25,7 +25,9 @@ cv_MADMMplasso <- function(fit, nfolds, X, Z, y, alpha = 0.5, lambda = fit$Lambd
nfolds <- length(table(foldid))

for (ii in 1:nfolds) {
print(c("fold,", ii))
if (my_print) {
print(c("fold,", ii))
}
oo <- foldid == ii

ggg[[ii]] <- MADMMplasso(X = X[!oo, , drop = FALSE], Z = Z[!oo, , drop = FALSE], y = y[!oo, , drop = FALSE], alpha = alpha, my_lambda = lambda, lambda_min = 0.01, max_it = max_it, e.abs = e.abs, e.rel = e.rel, nlambda = length(lambda[, 1]), rho = rho, tree = TT, my_print = my_print, alph = alph, parallel = parallel, pal = pal, gg = gg, tol = tol, cl = cl, legacy)
102 changes: 102 additions & 0 deletions R/hh_nlambda_loop.R
@@ -0,0 +1,102 @@
hh_nlambda_loop <- function(
lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it,
my_W_hat, XtY, y, N, e.abs, e.rel, alpha, alph, svd.w, tree, my_print,
invmat, gg, tol, parallel, pal, BETA0, THETA0, BETA,
BETA_hat, Y_HAT, THETA, D, my_values, legacy = TRUE
) {
if (legacy) {
obj <- NULL
non_zero_theta <- NULL
my_obj <- list()
n_main_terms <- NULL
lam_list <- list()
hh <- 1
while (hh <= nlambda) {
lambda <- lam[hh, ]

start_time <- Sys.time()
if (pal) {
my_values <- admm_MADMMplasso(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e.abs, e.rel, alpha, lambda, alph, svd.w, tree, my_print, invmat,
gg[hh, ], legacy
)

beta <- my_values$beta
theta <- my_values$theta
my_obj[[hh]] <- list(my_values$obj)
beta0 <- my_values$beta0
theta0 <- my_values$theta0 ### iteration
beta_hat <- my_values$beta_hat
y_hat <- my_values$y_hat
}
cost_time <- Sys.time() - start_time
if (my_print) {
print(cost_time)
}
if (parallel && !pal) {
beta <- my_values[hh, ]$beta
theta <- my_values[hh, ]$theta
my_obj[[hh]] <- list(my_values[hh, ]$obj)
beta0 <- my_values[hh, ]$beta0
theta0 <- my_values[hh, ]$theta0 ### iteration
beta_hat <- my_values[hh, ]$beta_hat
y_hat <- my_values[hh, ]$y_hat
} else if (!parallel && !pal) {
beta <- my_values[[hh]]$beta
theta <- my_values[[hh]]$theta
my_obj[[hh]] <- list(my_values[[hh]]$obj)
beta0 <- my_values[[hh]]$beta0
theta0 <- my_values[[hh]]$theta0 ### iteration
beta_hat <- my_values[[hh]]$beta_hat
y_hat <- my_values[[hh]]$y_hat
}
# Executed if pal == TRUE, independent of parallel

beta1 <- as(beta * (abs(beta) > tol), "sparseMatrix")
theta1 <- as.sparse3Darray(theta * (abs(theta) > tol))
beta_hat1 <- as(beta_hat * (abs(beta_hat) > tol), "sparseMatrix")

n_interaction_terms <- count_nonzero_a((theta1))

n_main_terms <- (c(n_main_terms, count_nonzero_a((beta1))))

obj1 <- (sum(as.vector((y - y_hat)^2))) / (D * N)
obj <- c(obj, obj1)

non_zero_theta <- (c(non_zero_theta, n_interaction_terms))
lam_list <- (c(lam_list, lambda))

BETA0[[hh]] <- beta0
THETA0[[hh]] <- theta0
BETA[[hh]] <- as(beta1, "sparseMatrix")
BETA_hat[[hh]] <- as(beta_hat1, "sparseMatrix")

Y_HAT[[hh]] <- y_hat
THETA[[hh]] <- as.sparse3Darray(theta1)

if (my_print) {
if (hh == 1) {
print(c(hh, (n_main_terms[hh]), non_zero_theta[hh], obj1))
} else {
print(c(hh, (n_main_terms[hh]), non_zero_theta[hh], obj[hh - 1], obj1))
}
}

hh <- hh + 1
} ### lambda
out <- list(
obj = obj, n_main_terms = n_main_terms, non_zero_theta = non_zero_theta,
BETA0 = BETA0, THETA0 = THETA0, BETA = BETA, BETA_hat = BETA_hat,
Y_HAT = Y_HAT, THETA = THETA
)
} else {
out <- hh_nlambda_loop_cpp(
lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it,
my_W_hat, XtY, y, N, e.abs, e.rel, alpha, alph, svd.w, tree, my_print,
invmat, gg, tol, parallel, pal, BETA0, THETA0, BETA,
BETA_hat, Y_HAT, THETA, D, my_values
)
}
return(out)
}
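The two numeric steps each iteration of the legacy branch performs after a fit, `beta * (abs(beta) > tol)` and `obj1 <- sum((y - y_hat)^2) / (D * N)`, can be sketched in plain C++ as below. This is only an illustration under stated assumptions: flat `std::vector<double>` inputs stand in for the matrices and cubes, and the function names `threshold` and `scaled_squared_error` are made up here. It is not the code in `hh_nlambda_loop_cpp()`, which this diff does not show.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Zero out every entry whose magnitude does not exceed tol, mirroring the
// R expression beta * (abs(beta) > tol). Hypothetical helper name.
std::vector<double> threshold(const std::vector<double>& values, double tol) {
    std::vector<double> out(values.size(), 0.0);
    for (std::size_t i = 0; i < values.size(); ++i) {
        if (std::abs(values[i]) > tol) {
            out[i] = values[i];
        }
    }
    return out;
}

// Sum of squared residuals divided by D * N, mirroring
// obj1 <- sum((y - y_hat)^2) / (D * N). Hypothetical helper name.
double scaled_squared_error(const std::vector<double>& y,
                            const std::vector<double>& y_hat,
                            std::size_t D, std::size_t N) {
    assert(y.size() == y_hat.size());  // observed and fitted must align
    assert(D > 0 && N > 0);            // avoid dividing by zero
    double total = 0.0;
    for (std::size_t i = 0; i < y.size(); ++i) {
        const double residual = y[i] - y_hat[i];
        total += residual * residual;
    }
    return total / static_cast<double>(D * N);
}
```

Note that the comparison is strict, so an entry whose magnitude equals `tol` exactly is set to zero, the same as in the R expression.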
26 changes: 26 additions & 0 deletions src/MADMMplasso.h
@@ -0,0 +1,26 @@
#include <RcppArmadillo.h>
Rcpp::List admm_MADMMplasso_cpp(
const arma::vec beta0,
const arma::mat theta0,
arma::mat beta,
arma::mat beta_hat,
arma::cube theta,
const double rho1,
const arma::mat X,
const arma::mat Z,
const int max_it,
const arma::mat W_hat,
const arma::mat XtY,
const arma::mat y,
const int N,
const double e_abs,
const double e_rel,
const double alpha,
const arma::vec lambda,
const double alph,
const Rcpp::List svd_w,
const Rcpp::List tree,
const Rcpp::List invmat,
const arma::vec gg,
const bool my_print = true
);