diff --git a/R/RcppExports.R b/R/RcppExports.R index e22b1d2..017fb2e 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -53,6 +53,10 @@ multiples_of <- function(x, divisor, subset_out = FALSE) { .Call(`_MADMMplasso_multiples_of`, x, divisor, subset_out) } +lm_arma <- function(R, Z) { + .Call(`_MADMMplasso_lm_arma`, R, Z) +} + reg <- function(r, Z) { .Call(`_MADMMplasso_reg`, r, Z) } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index dd8e899..5900628 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -101,6 +101,18 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// lm_arma +arma::vec lm_arma(const arma::vec& R, const arma::mat& Z); +RcppExport SEXP _MADMMplasso_lm_arma(SEXP RSEXP, SEXP ZSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::vec& >::type R(RSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type Z(ZSEXP); + rcpp_result_gen = Rcpp::wrap(lm_arma(R, Z)); + return rcpp_result_gen; +END_RCPP +} // reg Rcpp::List reg(const arma::mat r, const arma::mat Z); RcppExport SEXP _MADMMplasso_reg(SEXP rSEXP, SEXP ZSEXP) { @@ -143,6 +155,7 @@ static const R_CallMethodDef CallEntries[] = { {"_MADMMplasso_model_p", (DL_FUNC) &_MADMMplasso_model_p, 6}, {"_MADMMplasso_modulo", (DL_FUNC) &_MADMMplasso_modulo, 2}, {"_MADMMplasso_multiples_of", (DL_FUNC) &_MADMMplasso_multiples_of, 3}, + {"_MADMMplasso_lm_arma", (DL_FUNC) &_MADMMplasso_lm_arma, 2}, {"_MADMMplasso_reg", (DL_FUNC) &_MADMMplasso_reg, 2}, {"_MADMMplasso_scale_cpp", (DL_FUNC) &_MADMMplasso_scale_cpp, 2}, {"_MADMMplasso_sqrt_sum_squared_rows", (DL_FUNC) &_MADMMplasso_sqrt_sum_squared_rows, 1}, diff --git a/src/reg.cpp b/src/reg.cpp index a3cbe47..27efc6c 100644 --- a/src/reg.cpp +++ b/src/reg.cpp @@ -1,5 +1,17 @@ #include // [[Rcpp::depends(RcppArmadillo)]] + +// [[Rcpp::export]] +arma::vec lm_arma(const arma::vec &R, const arma::mat &Z) { + // Add a column of ones to Z + arma::mat Z_intercept = arma::join_rows(arma::ones(Z.n_rows), Z); + + // Solve the system of linear equations + arma::vec coefficients = arma::solve(Z_intercept, R); + + return coefficients; +} + // [[Rcpp::export]] Rcpp::List reg( const arma::mat r, @@ -10,9 +22,9 @@ Rcpp::List reg( arma::mat theta01(Z.n_cols, r.n_cols, arma::fill::zeros); for (arma::uword e = 0; e < r.n_cols; e++) { - arma::vec new1 = arma::solve(Z, r.col(e)); + arma::vec new1 = lm_arma(r.col(e), Z); beta01(e) = new1(0); - theta01.col(e) = new1.tail(new1.n_elem); + theta01.col(e) = new1.tail(new1.n_elem - 1); } Rcpp::List out = Rcpp::List::create( diff --git a/tests/testthat/test-reg.R b/tests/testthat/test-reg.R new file mode 100644 index 0000000..0d78536 --- /dev/null +++ b/tests/testthat/test-reg.R @@ -0,0 +1,23 @@ +# Original function ============================================================ +reg_R <- function(r, Z) { + beta01 <- matrix(0, 1, ncol(r)) + theta01 <- matrix(0, ncol(Z), ncol(r)) + for (e in seq_len(ncol(r))) { + new1 <- lm(r[, e] ~ Z, singular.ok = TRUE) + beta01[e] <- matrix(new1$coefficients[1]) + theta01[, e] <- as.vector(new1$coefficients[-1]) + } + return(list(beta0 = beta01, theta0 = theta01)) +} + +# Testing ====================================================================== +reps <- 10L +n_obs <- rpois(reps, lambda = 10L) +n_vars <- sample(2:10, reps, replace = TRUE) +test_that("reg() produces the correct output", { + for (rp in seq_len(reps)) { + r <- matrix(rnorm(n_obs[rp] * n_vars[rp]), n_obs[rp], n_vars[rp]) + z <- as.matrix(sample(0:1, n_obs[rp], replace = TRUE)) + expect_identical(reg(r, z), reg_R(r, z), tolerance = 1e-10) + } +})