Skip to content

Commit

Permalink
Merge pull request #29 from ocbe-uio/fix-reg
Browse files Browse the repository at this point in the history
Fix `reg()`
  • Loading branch information
Theo-qua authored Mar 1, 2024
2 parents 4ace926 + 0c77491 commit a4dbea8
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 2 deletions.
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ multiples_of <- function(x, divisor, subset_out = FALSE) {
.Call(`_MADMMplasso_multiples_of`, x, divisor, subset_out)
}

lm_arma <- function(R, Z) {
.Call(`_MADMMplasso_lm_arma`, R, Z)
}

reg <- function(r, Z) {
.Call(`_MADMMplasso_reg`, r, Z)
}
Expand Down
13 changes: 13 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// lm_arma
arma::vec lm_arma(const arma::vec& R, const arma::mat& Z);
RcppExport SEXP _MADMMplasso_lm_arma(SEXP RSEXP, SEXP ZSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const arma::vec& >::type R(RSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type Z(ZSEXP);
rcpp_result_gen = Rcpp::wrap(lm_arma(R, Z));
return rcpp_result_gen;
END_RCPP
}
// reg
Rcpp::List reg(const arma::mat r, const arma::mat Z);
RcppExport SEXP _MADMMplasso_reg(SEXP rSEXP, SEXP ZSEXP) {
Expand Down Expand Up @@ -143,6 +155,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_MADMMplasso_model_p", (DL_FUNC) &_MADMMplasso_model_p, 6},
{"_MADMMplasso_modulo", (DL_FUNC) &_MADMMplasso_modulo, 2},
{"_MADMMplasso_multiples_of", (DL_FUNC) &_MADMMplasso_multiples_of, 3},
{"_MADMMplasso_lm_arma", (DL_FUNC) &_MADMMplasso_lm_arma, 2},
{"_MADMMplasso_reg", (DL_FUNC) &_MADMMplasso_reg, 2},
{"_MADMMplasso_scale_cpp", (DL_FUNC) &_MADMMplasso_scale_cpp, 2},
{"_MADMMplasso_sqrt_sum_squared_rows", (DL_FUNC) &_MADMMplasso_sqrt_sum_squared_rows, 1},
Expand Down
16 changes: 14 additions & 2 deletions src/reg.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::vec lm_arma(const arma::vec &R, const arma::mat &Z) {
// Add a column of ones to Z
arma::mat Z_intercept = arma::join_rows(arma::ones<arma::vec>(Z.n_rows), Z);

// Solve the system of linear equations
arma::vec coefficients = arma::solve(Z_intercept, R);

return coefficients;
}

// [[Rcpp::export]]
Rcpp::List reg(
const arma::mat r,
Expand All @@ -10,9 +22,9 @@ Rcpp::List reg(
arma::mat theta01(Z.n_cols, r.n_cols, arma::fill::zeros);

for (arma::uword e = 0; e < r.n_cols; e++) {
arma::vec new1 = arma::solve(Z, r.col(e));
arma::vec new1 = lm_arma(r.col(e), Z);
beta01(e) = new1(0);
theta01.col(e) = new1.tail(new1.n_elem);
theta01.col(e) = new1.tail(new1.n_elem - 1);
}

Rcpp::List out = Rcpp::List::create(
Expand Down
23 changes: 23 additions & 0 deletions tests/testthat/test-reg.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Original function ============================================================
reg_R <- function(r, Z) {
beta01 <- matrix(0, 1, ncol(r))
theta01 <- matrix(0, ncol(Z), ncol(r))
for (e in seq_len(ncol(r))) {
new1 <- lm(r[, e] ~ Z, singular.ok = TRUE)
beta01[e] <- matrix(new1$coefficients[1])
theta01[, e] <- as.vector(new1$coefficients[-1])
}
return(list(beta0 = beta01, theta0 = theta01))
}

# Testing ======================================================================
reps <- 10L
n_obs <- rpois(reps, lambda = 10L)
n_vars <- sample(2:10, reps, replace = TRUE)
test_that("reg() produces the correct output", {
for (rp in seq_len(reps)) {
r <- matrix(rnorm(n_obs[rp] * n_vars[rp]), n_obs[rp], n_vars[rp])
z <- as.matrix(sample(0:1, n_obs[rp], replace = TRUE))
expect_identical(reg(r, z), reg_R(r, z), tolerance = 1e-10)
}
})

0 comments on commit a4dbea8

Please sign in to comment.