diff --git a/R/args.R b/R/args.R index 6373eb07..ffc50654 100644 --- a/R/args.R +++ b/R/args.R @@ -715,12 +715,6 @@ validate_cmdstan_args <- function(self) { } validate_init(self$init, num_inits) validate_seed(self$seed, num_procs) - if (!is.null(self$opencl_ids)) { - if (cmdstan_version() < "2.26") { - stop("Runtime selection of OpenCL devices is only supported with CmdStan version 2.26 or newer.", call. = FALSE) - } - checkmate::assert_vector(self$opencl_ids, len = 2) - } invisible(TRUE) } diff --git a/R/model.R b/R/model.R index e17ffbd4..fbb06918 100644 --- a/R/model.R +++ b/R/model.R @@ -1,3 +1,4 @@ + #' Create a new CmdStanModel object #' #' @description \if{html}{\figure{logo.png}{options: width=25}} @@ -230,10 +231,13 @@ CmdStanModel <- R6::R6Class( stanc_options_ = list(), include_paths_ = NULL, using_user_header_ = FALSE, - precompile_cpp_options_ = NULL, + precompile_cpp_options_ = list(), precompile_stanc_options_ = NULL, precompile_include_paths_ = NULL, - variables_ = NULL + variables_ = NULL, + exe_info_ = list(), + # intentionally only set at compile(), not initialize() + cmdstan_version_ = NULL ), public = list( functions = NULL, @@ -248,7 +252,7 @@ CmdStanModel <- R6::R6Class( private$stan_file_ <- absolute_path(stan_file) private$stan_code_ <- readLines(stan_file) private$model_name_ <- sub(" ", "_", strip_ext(basename(private$stan_file_))) - private$precompile_cpp_options_ <- args$cpp_options %||% list() + private$precompile_cpp_options_ <- validate_precompile_cpp_options(args$cpp_options) %||% list() private$precompile_stanc_options_ <- assert_valid_stanc_options(args$stanc_options) %||% list() if (!is.null(args$user_header) || !is.null(args$cpp_options[["USER_HEADER"]]) || !is.null(args$cpp_options[["user_header"]])) { @@ -270,22 +274,43 @@ CmdStanModel <- R6::R6Class( } if (!is.null(stan_file) && compile) { self$compile(...) - } - if (length(self$exe_file()) > 0 && file.exists(self$exe_file())) { - cpp_options <- model_compile_info(self$exe_file()) - for (cpp_option_name in names(cpp_options)) { - if (cpp_option_name != "stan_version" && - (!is.logical(cpp_options[[cpp_option_name]]) || isTRUE(cpp_options[[cpp_option_name]]))) { - private$cpp_options_[[cpp_option_name]] <- cpp_options[[cpp_option_name]] + } else { + # set exe path, same logic as in compile + if(!is.null(private$dir_)){ + dir <- repair_path(absolute_path(private$dir_)) + assert_dir_exists(dir, access = "rw") + if (length(self$exe_file()) != 0) { + self$exe_file(file.path(dir, basename(self$exe_file()))) + } + } + if (length(self$exe_file()) == 0) { + if (is.null(private$dir_)) { + exe_base <- self$stan_file() + } else { + exe_base <- file.path(private$dir_, basename(self$stan_file())) + } + self$exe_file(cmdstan_ext(strip_ext(exe_base))) + if (dir.exists(self$exe_file())) { + stop("There is a subfolder matching the model name in the same folder as the model! Please remove or rename the subfolder and try again.", call. = FALSE) } } + + # exe_info is updated inside the compile method (if compile command is run) + exe_info <- self$exe_info(update = TRUE) + if(file.exists(self$exe_file())) exe_info_reflects_cpp_options(self$exe_info(), args$cpp_options) + } + if (length(self$exe_file()) > 0 && file.exists(self$exe_file())) { + private$cpp_options_ <- model_compile_info_legacy(self$exe_file()) } invisible(self) }, include_paths = function() { - if (length(self$exe_file()) > 0 && file.exists(self$exe_file())) { + # checks whether a compile has occurred since object creation + if (!is.null(private$cmdstan_version_)) { + # yes, compile occurred return(private$include_paths_) } else { + # no, compile did not occur return(private$precompile_include_paths_) } }, @@ -328,9 +353,72 @@ CmdStanModel <- R6::R6Class( } private$exe_file_ }, + exe_info = function(update = FALSE) { + if (update) { + if (!file.exists(private$exe_file_)) return(NULL) + ret <- run_info_cli(private$exe_file_) + # Above command will return non-zero if + # cmdstan version < "2.26.1" + + cli_info_success <- !is.null(ret$status) && ret$status == 0 + exe_info <- if (cli_info_success) parse_exe_info_string(ret$stdout) else list() + cpp_options <- exe_info_style_cpp_options(private$precompile_cpp_options_) + compiled_with_cpp_options <- !is.null(private$cmdstan_version_) + + private$exe_info_ <- if (compiled_with_cpp_options) { + # recompile has occurred since the CmdStanModel was created + # cpp_options as were used as configured + c( + # info cli as source of truth + exe_info, + # use cpp_options for options not provided in info + cpp_options[!names(cpp_options) %in% names(exe_info)] + ) + } else if (cli_info_success) { + # no compile/recompile has occurred, we only trust info cli + # don't know if other cpp_options were applied, so skip them + exe_info + } else { + # info cli failure + no compile/recompile has occurred + warning( + 'Retrieving exe_file info failed. ', + 'This may be due to running a model that was compiled with pre-2.26.1 cmdstan.' + ) + NULL + } + } + private$exe_info_ + }, + exe_info_fallback = function() { + c( + # current cmdstan_version, may or may not be compiled with this version + list(stan_version = cmdstan_version()), + + # user provided args, may or may not match binary + exe_info_style_cpp_options(private$precompile_cpp_options_) + ) + }, + cmdstan_version = function(fallback = TRUE) { + # this is intentionally not private$cmdstan_version_ + # because that value is only set if model has been recomplied + # since CmdStanModel instantiation + if (!fallback) self$exe_info()[['stan_version']] + for (candidate in c( + self$exe_info()[['stan_version']], + self$exe_info_fallback()[['stan_version']] + )) if (!is.null(candidate)) return (candidate) + }, cpp_options = function() { + warning( + 'mod$cpp_options() will be deprecated in the next major version of cmdstanr. ', + 'Use mod$exe_info() to see options from last compilation. ', + 'Use mod$precompile_cpp_options() to see default options for next compilation.' + ) private$cpp_options_ }, + precompile_cpp_options = function() { + private$precompile_cpp_options_ + }, hpp_file = function() { if (!length(private$hpp_file_)) { stop("The .hpp file does not exists. Please (re)compile the model.", call. = FALSE) @@ -398,10 +486,11 @@ CmdStanModel <- R6::R6Class( #' program. #' @param user_header (string) The path to a C++ file (with a .hpp extension) #' to compile with the Stan model. -#' @param cpp_options (list) Any makefile options to be used when compiling the +#' @param cpp_options (list) Makefile options to be used when compiling the #' model (`STAN_THREADS`, `STAN_MPI`, `STAN_OPENCL`, etc.). Anything you would -#' otherwise write in the `make/local` file. For an example of using threading -#' see the Stan case study +#' otherwise write in the `make/local` file. Setting a value to `NULL` or `""` +#' within the list unsets the flag. +#' For an example of using threading see the Stan case study. #' [Reduce Sum: A Minimal Example](https://mc-stan.org/users/documentation/case-studies/reduce_sum_tutorial.html). #' @param stanc_options (list) Any Stan-to-C++ transpiler options to be used #' when compiling the model. See the **Examples** section below as well as the @@ -478,14 +567,20 @@ compile <- function(quiet = TRUE, #deprecated compile_hessian_method = FALSE, threads = FALSE) { - if (length(self$stan_file()) == 0) { stop("'$compile()' cannot be used because the 'CmdStanModel' was not created with a Stan file.", call. = FALSE) } assert_stan_file_exists(self$stan_file()) + + if (!is.null(user_header) && ( + !is.null(cpp_options[["USER_HEADER"]]) || !is.null(cpp_options[["user_header"]]) + )) warning("User header specified both via user_header argument and via cpp_options arguments") + if (length(cpp_options) == 0 && !is.null(private$precompile_cpp_options_)) { cpp_options <- private$precompile_cpp_options_ } + cpp_options <- validate_precompile_cpp_options(cpp_options) + if (length(stanc_options) == 0 && !is.null(private$precompile_stanc_options_)) { stanc_options <- private$precompile_stanc_options_ } @@ -544,21 +639,9 @@ compile <- function(quiet = TRUE, # Note that unlike cpp_options["USER_HEADER"], the user_header variable is deliberately # not transformed with wsl_safe_path() as that breaks the check below on WSLv1 if (!is.null(user_header)) { - if (!is.null(cpp_options[["USER_HEADER"]]) || !is.null(cpp_options[["user_header"]])) { - warning("User header specified both via user_header argument and via cpp_options arguments") - } - - cpp_options[["USER_HEADER"]] <- wsl_safe_path(absolute_path(user_header)) + cpp_options[["user_header"]] <- wsl_safe_path(absolute_path(user_header)) stanc_options[["allow-undefined"]] <- TRUE private$using_user_header_ <- TRUE - } else if (!is.null(cpp_options[["USER_HEADER"]])) { - if (!is.null(cpp_options[["user_header"]])) { - warning('User header specified both via cpp_options[["USER_HEADER"]] and cpp_options[["user_header"]].', call. = FALSE) - } - - user_header <- cpp_options[["USER_HEADER"]] - cpp_options[["USER_HEADER"]] <- wsl_safe_path(absolute_path(cpp_options[["USER_HEADER"]])) - private$using_user_header_ <- TRUE } else if (!is.null(cpp_options[["user_header"]])) { user_header <- cpp_options[["user_header"]] cpp_options[["user_header"]] <- wsl_safe_path(absolute_path(cpp_options[["user_header"]])) @@ -578,6 +661,9 @@ compile <- function(quiet = TRUE, # - the executable does not exist # - the stan model was changed since last compilation # - a user header is used and the user header changed since last compilation (#813) + self$exe_file(exe) + self$exe_info(update = TRUE) + if (!file.exists(exe)) { force_recompile <- TRUE } else if (file.exists(self$stan_file()) @@ -587,18 +673,20 @@ compile <- function(quiet = TRUE, && file.exists(user_header) && file.mtime(exe) < file.mtime(user_header)) { force_recompile <- TRUE + } else if (!isTRUE(exe_info_reflects_cpp_options(self$exe_info(), cpp_options))) { + force_recompile <- TRUE } + if (!force_recompile && rlang::is_interactive()) { + message("Model executable is up to date!") + } + if (!force_recompile) { - if (rlang::is_interactive()) { - message("Model executable is up to date!") - } private$cpp_options_ <- cpp_options - private$precompile_cpp_options_ <- NULL + private$precompile_cpp_options_ <- cpp_options private$precompile_stanc_options_ <- NULL private$precompile_include_paths_ <- NULL self$functions$existing_exe <- TRUE - self$exe_file(exe) return(invisible(self)) } else { if (rlang::is_interactive()) { @@ -654,7 +742,6 @@ compile <- function(quiet = TRUE, self$functions$existing_exe <- FALSE stancflags_val <- paste0("STANCFLAGS += ", stancflags_val, paste0(" ", stancflags_combined, collapse = " ")) - if (!dry_run) { if (compile_standalone) { @@ -737,11 +824,15 @@ compile <- function(quiet = TRUE, con = wsl_safe_path(private$hpp_file_, revert = TRUE)) } # End - if(!dry_run) + private$cmdstan_version_ <- cmdstan_version() private$exe_file_ <- exe - private$cpp_options_ <- cpp_options - private$precompile_cpp_options_ <- NULL + private$precompile_cpp_options_ <- cpp_options private$precompile_stanc_options_ <- NULL private$precompile_include_paths_ <- NULL + + # Must be run after private$cmdstan_version_, private$exe_file_, and private$precompiled_cpp_options_ + # are all up to date + self$exe_info(update=TRUE) if(!dry_run) { if (compile_model_methods) { @@ -786,7 +877,7 @@ CmdStanModel$set("public", name = "compile", value = compile) #' } #' variables <- function() { - if (cmdstan_version() < "2.27.0") { + if (self$cmdstan_version() < "2.27.0") { stop("$variables() is only supported for CmdStan 2.27 or newer.", call. = FALSE) } if (length(self$stan_file()) == 0) { @@ -864,6 +955,7 @@ check_syntax <- function(pedantic = FALSE, include_paths = NULL, stanc_options = list(), quiet = FALSE) { + if (length(self$stan_file()) == 0) { stop("'$check_syntax()' cannot be used because the 'CmdStanModel' was not created with a Stan file.", call. = FALSE) } @@ -1208,7 +1300,7 @@ sample <- function(data = NULL, } } - if (cmdstan_version() >= "2.27.0" && !fixed_param) { + if (self$cmdstan_version() >= "2.27.0" && !fixed_param) { if (self$has_stan_file() && file.exists(self$stan_file())) { if (!is.null(self$variables()) && length(self$variables()$parameters) == 0) { stop("Model contains no parameters. Please use 'fixed_param = TRUE'.", call. = FALSE) @@ -1221,7 +1313,7 @@ sample <- function(data = NULL, procs <- CmdStanMCMCProcs$new( num_procs = checkmate::assert_integerish(chains, lower = 1, len = 1), parallel_procs = checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE), - threads_per_proc = assert_valid_threads(threads_per_chain, self$cpp_options(), multiple_chains = TRUE), + threads_per_proc = assert_valid_threads(threads_per_chain, self$exe_info(), self$exe_info_fallback(), multiple_chains = TRUE), show_stderr_messages = show_exceptions, show_stdout_messages = show_messages ) @@ -1265,7 +1357,7 @@ sample <- function(data = NULL, output_dir = output_dir, output_basename = output_basename, sig_figs = sig_figs, - opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()), + opencl_ids = assert_valid_opencl(opencl_ids, self$exe_info(), self$exe_info_fallback()), model_variables = model_variables, save_cmdstan_config = save_cmdstan_config ) @@ -1515,7 +1607,7 @@ optimize <- function(data = NULL, num_procs = 1, show_stderr_messages = show_exceptions, show_stdout_messages = show_messages, - threads_per_proc = assert_valid_threads(threads, self$cpp_options()) + threads_per_proc = assert_valid_threads(threads, self$exe_info(), self$exe_info_fallback()) ) model_variables <- NULL if (is_variables_method_supported(self)) { @@ -1550,7 +1642,7 @@ optimize <- function(data = NULL, output_dir = output_dir, output_basename = output_basename, sig_figs = sig_figs, - opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()), + opencl_ids = assert_valid_opencl(opencl_ids, self$exe_info(), self$exe_info_fallback()), model_variables = model_variables, save_cmdstan_config = save_cmdstan_config ) @@ -1655,7 +1747,7 @@ laplace <- function(data = NULL, num_procs = 1, show_stderr_messages = show_exceptions, show_stdout_messages = show_messages, - threads_per_proc = assert_valid_threads(threads, self$cpp_options()) + threads_per_proc = assert_valid_threads(threads, self$exe_info(), self$exe_info_fallback()) ) model_variables <- NULL if (is_variables_method_supported(self)) { @@ -1717,7 +1809,7 @@ laplace <- function(data = NULL, output_dir = output_dir, output_basename = output_basename, sig_figs = sig_figs, - opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()), + opencl_ids = assert_valid_opencl(opencl_ids, self$exe_info(), self$exe_info_fallback()), model_variables = model_variables, save_cmdstan_config = save_cmdstan_config ) @@ -1805,7 +1897,7 @@ variational <- function(data = NULL, num_procs = 1, show_stderr_messages = show_exceptions, show_stdout_messages = show_messages, - threads_per_proc = assert_valid_threads(threads, self$cpp_options()) + threads_per_proc = assert_valid_threads(threads, self$exe_info(), self$exe_info_fallback()) ) model_variables <- NULL if (is_variables_method_supported(self)) { @@ -1840,7 +1932,7 @@ variational <- function(data = NULL, output_dir = output_dir, output_basename = output_basename, sig_figs = sig_figs, - opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()), + opencl_ids = assert_valid_opencl(opencl_ids, self$exe_info(), self$exe_info_fallback()), model_variables = model_variables, save_cmdstan_config = save_cmdstan_config ) @@ -1950,7 +2042,7 @@ pathfinder <- function(data = NULL, num_procs = 1, show_stderr_messages = show_exceptions, show_stdout_messages = show_messages, - threads_per_proc = assert_valid_threads(num_threads, self$cpp_options()) + threads_per_proc = assert_valid_threads(num_threads, self$exe_info(), self$exe_info_fallback()) ) model_variables <- NULL if (is_variables_method_supported(self)) { @@ -1990,7 +2082,7 @@ pathfinder <- function(data = NULL, output_dir = output_dir, output_basename = output_basename, sig_figs = sig_figs, - opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()), + opencl_ids = assert_valid_opencl(opencl_ids, self$exe_info(), self$exe_info_fallback()), model_variables = model_variables, num_threads = num_threads, save_cmdstan_config = save_cmdstan_config @@ -2087,7 +2179,7 @@ generate_quantities <- function(fitted_params, procs <- CmdStanGQProcs$new( num_procs = length(fitted_params_files), parallel_procs = checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE), - threads_per_proc = assert_valid_threads(threads_per_chain, self$cpp_options(), multiple_chains = TRUE) + threads_per_proc = assert_valid_threads(threads_per_chain, self$exe_info(), self$exe_info_fallback(), multiple_chains = TRUE) ) model_variables <- NULL if (is_variables_method_supported(self)) { @@ -2108,7 +2200,7 @@ generate_quantities <- function(fitted_params, output_dir = output_dir, output_basename = output_basename, sig_figs = sig_figs, - opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()), + opencl_ids = assert_valid_opencl(opencl_ids, self$exe_info(), self$exe_info_fallback()), model_variables = model_variables ) runset <- CmdStanRun$new(args, procs) @@ -2243,40 +2335,111 @@ CmdStanModel$set("public", name = "expose_functions", value = expose_functions) # internal ---------------------------------------------------------------- -assert_valid_opencl <- function(opencl_ids, cpp_options) { - if (is.null(cpp_options[["stan_opencl"]]) - && !is.null(opencl_ids)) { +assert_valid_opencl <- function( + opencl_ids, + exe_info, + fallback_exe_info = list('stan_version' = '2.0.0', 'stan_opencl' = FALSE) +) { + if (is.null(opencl_ids)) return(invisible(opencl_ids)) + + fallback <- length(exe_info) == 0 + if(fallback) exe_info <- fallback_exe_info + # If we're unsure if this info is accurate, we shouldn't stop the user from attempting on that basis + # the user should have been warned about this in initialize(), so no need to re-warn here. + if(fallback) stop <- warning + + if (exe_info[['stan_version']] < "2.26.0") { + stop("Runtime selection of OpenCL devices is only supported with CmdStan version 2.26 or newer.", call. = FALSE) + } + + if (isFALSE(exe_info[["stan_opencl"]])) { stop("'opencl_ids' is set but the model was not compiled with for use with OpenCL.", "\nRecompile the model with 'cpp_options = list(stan_opencl = TRUE)'", call. = FALSE) } + checkmate::assert_vector(opencl_ids, len = 2) invisible(opencl_ids) } -assert_valid_threads <- function(threads, cpp_options, multiple_chains = FALSE) { +assert_valid_threads <- function(threads, exe_info, fallback_exe_info, multiple_chains = FALSE) { + fallback <- length(exe_info) == 0 + if(fallback) exe_info <- fallback_exe_info + # If we're unsure if this info is accurate, we shouldn't stop the user from attempting on that basis + # the user should have been warned about this in initialize(), so no need to re-warn here. + if(fallback) stop <- warning + threads_arg <- if (multiple_chains) "threads_per_chain" else "threads" checkmate::assert_integerish(threads, .var.name = threads_arg, null.ok = TRUE, lower = 1, len = 1) - if (is.null(cpp_options[["stan_threads"]]) || !isTRUE(cpp_options[["stan_threads"]])) { - if (!is.null(threads)) { - warning( - "'", threads_arg, "' is set but the model was not compiled with ", - "'cpp_options = list(stan_threads = TRUE)' ", - "so '", threads_arg, "' will have no effect!", - call. = FALSE - ) - threads <- NULL - } - } else if (isTRUE(cpp_options[["stan_threads"]]) && is.null(threads)) { + if (isTRUE(exe_info[["stan_threads"]]) && is.null(threads)) { stop( "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' ", - "but '", threads_arg, "' was not set!", + "or equivalent, but '", threads_arg, "' was not set!", call. = FALSE ) + } else if (!exe_info[["stan_threads"]] && !is.null(threads)) { + warning( + "'", threads_arg, "' is set but the model was not compiled with ", + "'cpp_options = list(stan_threads = TRUE)' or equivalent ", + "so '", threads_arg, "' will have no effect!", + call. = FALSE + ) + if (!fallback) threads <- NULL } invisible(threads) } +validate_precompile_cpp_options <- function(cpp_options) { + if(is.null(cpp_options) || length(cpp_options) == 0) return(list()) + + if (!is.null(cpp_options[["user_header"]]) && !is.null(cpp_options[['USER_HEADER']])) { + warning('User header specified both via cpp_options[["USER_HEADER"]] and cpp_options[["user_header"]].', call. = FALSE) + } + + names(cpp_options) <- tolower(names(cpp_options)) + flags_set_if_defined <- c( + # cmdstan + "stan_threads", "stan_mpi", "stan_opencl", "stan_no_range_checks", "stan_cpp_optims", + # stan math + "integrated_opencl", "tbb_lib", "tbb_inc", "tbb_interface_new" + ) + for (flag in flags_set_if_defined) { + if (isFALSE(cpp_options[[flag]])) warning( + toupper(flag), " set to ", cpp_options[flag], " Since this is a non-empty value, ", + "it will result in the corresponding ccp option being turned ON. To turn this", + " option off, use cpp_options = list(", flag, " = NULL)." + ) + } + cpp_options +} + +exe_info_style_cpp_options <- function(cpp_options) { + names(cpp_options) <- tolower(names(cpp_options)) + flags_reported_in_exe_info <- c( + "stan_threads", "stan_mpi", "stan_opencl", "stan_no_range_checks", "stan_cpp_optims" + ) + for (flag in flags_reported_in_exe_info) { + cpp_options[[flag]] <- !(is.null(cpp_options[[flag]]) || cpp_options[[flag]] == '') + } + cpp_options +} + +exe_info_reflects_cpp_options <- function(exe_info, cpp_options) { + if(length(exe_info) == 0) { + warning('Recompiling is recommended due to missing exe_info.') + return(TRUE) + } + if(is.null(cpp_options)) return(TRUE) + + cpp_options <- exe_info_style_cpp_options(cpp_options)[tolower(names(cpp_options))] + overlap <- names(cpp_options)[names(cpp_options) %in% names(exe_info)] + + if(length(overlap) == 0) TRUE else all.equal( + exe_info[overlap], + cpp_options[overlap] + ) +} + assert_valid_stanc_options <- function(stanc_options) { i <- 1 names <- names(stanc_options) @@ -2375,7 +2538,51 @@ model_variables <- function(stan_file, include_paths = NULL, allow_undefined = F variables } -model_compile_info <- function(exe_file) { +# Parse the string output of `info` into an R object (list) +parse_exe_info_string <- function(ret_stdout) { + info <- list() + info_raw <- strsplit(strsplit(ret_stdout, "\n")[[1]], "=") + for (key_val in info_raw) { + if (length(key_val) > 1) { + key_val <- trimws(key_val) + val <- key_val[2] + if (!is.na(as.logical(val))) { + val <- as.logical(val) + } + info[[tolower(key_val[1])]] <- val + } + } + + info[["stan_version"]] <- paste0(info[["stan_version_major"]], ".", info[["stan_version_minor"]], ".", info[["stan_version_patch"]]) + info[["stan_version_major"]] <- NULL + info[["stan_version_minor"]] <- NULL + info[["stan_version_patch"]] <- NULL + + info +} + +# run info command +run_info_cli <- function(exe_file) { + withr::with_path( + c( + toolchain_PATH_env_var(), + tbb_path() + ), + ret <- wsl_compatible_run( + command = wsl_safe_path(exe_file), + args = "info", + error_on_status = FALSE + ) + ) + ret +} + + +is_variables_method_supported <- function(mod) { + cmdstan_version() >= "2.27.0" && mod$has_stan_file() && file.exists(mod$stan_file()) +} + +model_compile_info_legacy <- function(exe_file) { info <- NULL if (cmdstan_version() > "2.26.1") { withr::with_path( @@ -2399,18 +2606,12 @@ model_compile_info <- function(exe_file) { if (!is.na(as.logical(val))) { val <- as.logical(val) } - info[[toupper(key_val[1])]] <- val + if(!is.logical(val) || isTRUE(val)) { + info[[tolower(key_val[1])]] <- val + } } } - info[["STAN_VERSION"]] <- paste0(info[["STAN_VERSION_MAJOR"]], ".", info[["STAN_VERSION_MINOR"]], ".", info[["STAN_VERSION_PATCH"]]) - info[["STAN_VERSION_MAJOR"]] <- NULL - info[["STAN_VERSION_MINOR"]] <- NULL - info[["STAN_VERSION_PATCH"]] <- NULL } } info } - -is_variables_method_supported <- function(mod) { - cmdstan_version() >= "2.27.0" && mod$has_stan_file() && file.exists(mod$stan_file()) -} diff --git a/R/path.R b/R/path.R index 15bbeae6..4172c86d 100644 --- a/R/path.R +++ b/R/path.R @@ -234,8 +234,16 @@ unset_cmdstan_path <- function() { } # fake a cmdstan version (only used in tests) -fake_cmdstan_version <- function(version) { +fake_cmdstan_version <- function(version, mod=NULL) { .cmdstanr$VERSION <- version + if(!is.null(mod)) { + if (!is.null(mod$.__enclos_env__$private$exe_info_)) { + mod$.__enclos_env__$private$exe_info_$stan_version <- version + } + if (!is.null(mod$.__enclos_env__$private$cmdstan_version_)) { + mod$.__enclos_env__$private$cmdstan_version_ <- version + } + } } reset_cmdstan_version <- function() { .cmdstanr$VERSION <- read_cmdstan_version(cmdstan_path()) diff --git a/man/model-method-compile.Rd b/man/model-method-compile.Rd index c92f2704..40e0f41e 100644 --- a/man/model-method-compile.Rd +++ b/man/model-method-compile.Rd @@ -45,10 +45,11 @@ program.} \item{user_header}{(string) The path to a C++ file (with a .hpp extension) to compile with the Stan model.} -\item{cpp_options}{(list) Any makefile options to be used when compiling the +\item{cpp_options}{(list) Makefile options to be used when compiling the model (\code{STAN_THREADS}, \code{STAN_MPI}, \code{STAN_OPENCL}, etc.). Anything you would -otherwise write in the \code{make/local} file. For an example of using threading -see the Stan case study +otherwise write in the \code{make/local} file. Setting a value to \code{NULL} or \code{""} +within the list unsets the flag. +For an example of using threading see the Stan case study. \href{https://mc-stan.org/users/documentation/case-studies/reduce_sum_tutorial.html}{Reduce Sum: A Minimal Example}.} \item{stanc_options}{(list) Any Stan-to-C++ transpiler options to be used diff --git a/tests/testthat/helper-custom-expectations.R b/tests/testthat/helper-custom-expectations.R index fd8d5565..86244d1e 100644 --- a/tests/testthat/helper-custom-expectations.R +++ b/tests/testthat/helper-custom-expectations.R @@ -100,3 +100,11 @@ expect_noninteractive_silent <- function(object) { rlang::with_interactive(value = FALSE, expect_silent(object)) } + +expect_equal_ignore_order <- function(object, expected, ...){ + object <- expected[sort(names(object))] + expected <- expected[sort(names(expected))] + expect_equal(object, expected, ...) +} + +expect_not_true <- function(...) expect_false(isTRUE(...)) diff --git a/tests/testthat/helper-mock-cli.R b/tests/testthat/helper-mock-cli.R new file mode 100644 index 00000000..a654ab5a --- /dev/null +++ b/tests/testthat/helper-mock-cli.R @@ -0,0 +1,22 @@ +real_wcr <- wsl_compatible_run + +with_mocked_cli <- function(code, compile_ret, info_ret){ + with_mocked_bindings( + code, + wsl_compatible_run = function(command, args, ...) { + if ( + !is.null(command) + && command == 'make' + && !is.null(args) + && startsWith(basename(args[1]), 'model-') + ) { + message("mock-compile-was-called") + compile_ret + } else if (!is.null(args) && args[1] == "info") info_ret + else real_wcr(command = command, args = args, ...) + } + ) +} + +expect_mock_compile <- function(object, ...) expect_message(object, regexp = 'mock-compile-was-called', ...) +expect_no_mock_compile <- function(object, ...) expect_no_message(object, message = 'mock-compile-was-called' , ...) diff --git a/tests/testthat/helper-models.R b/tests/testthat/helper-models.R index b0773e8b..0ffdfc61 100644 --- a/tests/testthat/helper-models.R +++ b/tests/testthat/helper-models.R @@ -14,6 +14,11 @@ cmdstan_example_file <- function() { file.path(cmdstan_path(), "examples", "bernoulli", "bernoulli.stan") } +cmdstan_example_exe_file <- function() { + # stan program in different directory from the others + file.path(cmdstan_path(), "examples", "bernoulli", "bernoulli.stan") +} + testing_model <- function(name) { cmdstan_model(stan_file = testing_stan_file(name)) } diff --git a/tests/testthat/test-example.R b/tests/testthat/test-example.R index 8d14d5d0..157929e0 100644 --- a/tests/testthat/test-example.R +++ b/tests/testthat/test-example.R @@ -1,7 +1,7 @@ context("cmdstanr_example") test_that("cmdstanr_example works", { - fit_mcmc <- cmdstanr_example("logistic", chains = 2) + fit_mcmc <- cmdstanr_example("logistic", chains = 2, force_recompile = TRUE) checkmate::expect_r6(fit_mcmc, "CmdStanMCMC") expect_equal(fit_mcmc$num_chains(), 2) diff --git a/tests/testthat/test-model-compile-user_header.R b/tests/testthat/test-model-compile-user_header.R new file mode 100644 index 00000000..57047804 --- /dev/null +++ b/tests/testthat/test-model-compile-user_header.R @@ -0,0 +1,104 @@ + +file_that_exists <- 'placeholder_exists' +file_that_doesnt_exist <- 'placeholder_doesnt_exist' +file.create(file_that_exists) +on.exit(if(file.exists(file_that_exists)) file.remove(file_that_exists), add=TRUE, after=FALSE) + +make_local_orig <- cmdstan_make_local() +cmdstan_make_local(cpp_options = list("PRECOMPILED_HEADERS"="false")) +on.exit(cmdstan_make_local(cpp_options = make_local_orig, append = FALSE), add = TRUE, after = FALSE) + +test_that("cmdstan_model works with user_header with mock", { + skip_if(os_is_macos()) + tmpfile <- tempfile(fileext = ".hpp") + hpp <- + " + #include + #include + #include + + namespace bernoulli_external_model_namespace + { + template >* = nullptr> + inline typename boost::math::tools::promote_args::type make_odds(const T0__ & + theta, + std::ostream *pstream__) + { + return theta / (1 - theta); + } + }" + cat(hpp, file = tmpfile, sep = "\n") + + with_mocked_cli(compile_ret = list(status = 0), info_ret = list(), code = expect_mock_compile( + expect_warning( + expect_no_warning({ + mod <- cmdstan_model( + stan_file = testing_stan_file("bernoulli_external"), + exe_file = file_that_exists, + user_header = tmpfile + ) + }, message = 'Recompiling is recommended'), # this warning should not occur because recompile happens automatically + 'Retrieving exe_file info failed' # this warning should occur + ) + )) + + with_mocked_cli(compile_ret = list(status = 0), info_ret = list(), code = expect_mock_compile({ + mod_2 <- cmdstan_model( + stan_file = testing_stan_file("bernoulli_external"), + exe_file = file_that_doesnt_exist, + cpp_options=list(USER_HEADER=tmpfile), + stanc_options = list("allow-undefined") + ) + })) + + # Check recompilation upon changing header + file.create(file_that_exists) + with_mocked_cli(compile_ret = list(status = 0), info_ret = list(), code = expect_no_mock_compile({ + mod$compile(quiet = TRUE, user_header = tmpfile) + })) + + Sys.setFileTime(tmpfile, Sys.time() + 1) # touch file to trigger recompile + with_mocked_cli(compile_ret = list(status = 0), info_ret = list(), code = expect_mock_compile({ + mod$compile(quiet = TRUE, user_header = tmpfile) + })) + + # mock does not automatically update file mtime + Sys.setFileTime(mod$exe_file(), Sys.time() + 1) # touch file to trigger recompile + + # Alternative spec of user header + with_mocked_cli(compile_ret = list(status = 0), info_ret = list(), code = expect_no_mock_compile({ + mod$compile( + quiet = TRUE, + cpp_options = list(user_header = tmpfile), + dry_run = TRUE + )})) + + # Error/warning messages + with_mocked_cli(compile_ret = list(status = 1), info_ret = list(), code = expect_error( + cmdstan_model( + stan_file = testing_stan_file("bernoulli_external"), + cpp_options = list(USER_HEADER = "non_existent.hpp"), + stanc_options = list("allow-undefined") + ), + "header file '[^']*' does not exist" + )) + + with_mocked_cli(compile_ret = list(status = 1), info_ret = list(), code = expect_warning( + cmdstan_model( + stan_file = testing_stan_file("bernoulli_external"), + cpp_options = list(USER_HEADER = tmpfile, user_header = tmpfile), + dry_run = TRUE + ), + "User header specified both" + )) + with_mocked_cli(compile_ret = list(status = 1), info_ret = list(), code = expect_warning( + cmdstan_model( + stan_file = testing_stan_file("bernoulli_external"), + user_header = tmpfile, + cpp_options = list(USER_HEADER = tmpfile), + dry_run = TRUE + ), + "User header specified both" + )) +}) diff --git a/tests/testthat/test-model-compile.R b/tests/testthat/test-model-compile.R index 2be8390f..dde9ee24 100644 --- a/tests/testthat/test-model-compile.R +++ b/tests/testthat/test-model-compile.R @@ -2,19 +2,25 @@ context("model-compile") set_cmdstan_path() stan_program <- cmdstan_example_file() +exe <- cmdstan_ext(strip_ext(stan_program)) +if (file.exists(exe)) file.remove(exe) + mod <- cmdstan_model(stan_file = stan_program, compile = FALSE) + +make_local_orig <- cmdstan_make_local() cmdstan_make_local(cpp_options = list("PRECOMPILED_HEADERS"="false")) +on.exit(cmdstan_make_local(cpp_options = make_local_orig, append = FALSE), add = TRUE, after = FALSE) test_that("object initialized correctly", { expect_equal(mod$stan_file(), stan_program) - expect_equal(mod$exe_file(), character(0)) + expect_equal(mod$exe_file(), exe) + expect_false(file.exists(mod$exe_file())) expect_error( mod$hpp_file(), "The .hpp file does not exists. Please (re)compile the model.", fixed = TRUE ) }) - test_that("error if no compile() before model fitting", { expect_error( mod$sample(), @@ -25,7 +31,6 @@ test_that("error if no compile() before model fitting", { test_that("compile() method works", { # remove executable if exists - exe <- cmdstan_ext(strip_ext(mod$stan_file())) if (file.exists(exe)) { file.remove(exe) } @@ -381,7 +386,6 @@ test_that("check_syntax() works with pedantic=TRUE", { fixed = TRUE ) }) - test_that("check_syntax() works with include_paths", { stan_program_w_include <- testing_stan_file("bernoulli_include") @@ -391,15 +395,20 @@ test_that("check_syntax() works with include_paths", { }) + +# Test Failing Due to Side effect ----- + test_that("check_syntax() works with include_paths on compiled model", { stan_program_w_include <- testing_stan_file("bernoulli_include") mod_w_include <- cmdstan_model(stan_file = stan_program_w_include, compile=TRUE, - include_paths = test_path("resources", "stan")) + include_paths = test_path("resources", "stan"), + force_recompile = TRUE) expect_true(mod_w_include$check_syntax()) }) + test_that("check_syntax() works with pedantic=TRUE", { model_code <- " transformed data { @@ -496,7 +505,11 @@ test_that("cpp_options work with settings in make/local", { rebuild_cmdstan() mod <- cmdstan_model(stan_file = stan_program) - expect_null(mod$cpp_options()$STAN_THREADS) + expect_null( + expect_warning(mod$cpp_options()$stan_threads, "Use mod\\$exe_info()") + ) + expect_false(mod$exe_info()$stan_threads) + expect_null(mod$precompile_cpp_options()$stan_threads) file.remove(mod$exe_file()) @@ -504,7 +517,10 @@ test_that("cpp_options work with settings in make/local", { file <- file.path(cmdstan_path(), "examples", "bernoulli", "bernoulli.stan") mod <- cmdstan_model(file) - expect_true(mod$cpp_options()$STAN_THREADS) + expect_true( + expect_warning(mod$cpp_options()$stan_threads, "Use mod\\$exe_info()") + ) + expect_true(mod$exe_info()$stan_threads) file.remove(mod$exe_file()) @@ -761,7 +777,8 @@ test_that("format() works with include_paths on compiled model", { stan_program_w_include <- testing_stan_file("bernoulli_include") mod_w_include <- cmdstan_model(stan_file = stan_program_w_include, compile=TRUE, - include_paths = test_path("resources", "stan")) + include_paths = test_path("resources", "stan"), + force_recompile = TRUE) expect_output( mod_w_include$format(), "#include ", @@ -789,6 +806,8 @@ test_that("overwrite_file works with format()", { } " stan_file_tmp <- write_stan_file(code) + on.exit(file.remove(stan_file_tmp)) + mod_1 <- cmdstan_model(stan_file_tmp, compile = FALSE) expect_false( any( @@ -852,4 +871,4 @@ test_that("STANCFLAGS included from make/local", { } expect_output(print(out), out_w_flags) cmdstan_make_local(cpp_options = make_local_old, append = FALSE) -}) +}) \ No newline at end of file diff --git a/tests/testthat/test-model-generate_quantities.R b/tests/testthat/test-model-generate_quantities.R index 7df7f2b7..9f5ec1fb 100644 --- a/tests/testthat/test-model-generate_quantities.R +++ b/tests/testthat/test-model-generate_quantities.R @@ -21,7 +21,6 @@ bad_arg_values <- list( parallel_chains = -20 ) - test_that("generate_quantities() method runs when all arguments specified validly", { # specifying all arguments validly expect_gq_output(fit1 <- do.call(mod_gq$generate_quantities, ok_arg_values)) @@ -52,7 +51,11 @@ test_that("generate_quantities work for different chains and parallel_chains", { expect_gq_output( mod_gq$generate_quantities(data = data_list, fitted_params = fit, parallel_chains = 4) ) - mod_gq <- cmdstan_model(testing_stan_file("bernoulli_ppc"), cpp_options = list(stan_threads = TRUE)) + + expect_call_compilation({ + mod_gq <- cmdstan_model(testing_stan_file("bernoulli_ppc"), cpp_options = list(stan_threads = TRUE)) + }) + expect_gq_output( mod_gq$generate_quantities(data = data_list, fitted_params = fit_1_chain, threads_per_chain = 2) ) @@ -91,4 +94,4 @@ test_that("generate_quantities() warns if threads specified but not enabled", { expect_gq_output(fit_gq <- mod_gq$generate_quantities(data = data_list, fitted_params = fit, threads_per_chain = 4)), "'threads_per_chain' will have no effect" ) -}) +}) \ No newline at end of file diff --git a/tests/testthat/test-model-internal.R b/tests/testthat/test-model-internal.R new file mode 100644 index 00000000..24c42fad --- /dev/null +++ b/tests/testthat/test-model-internal.R @@ -0,0 +1,55 @@ +test_that("parse_exe_info_string works", { + expect_equal_ignore_order( + parse_exe_info_string(" + stan_version_major = 2 + stan_version_minor = 38 + stan_version_patch = 0 + STAN_THREADS=false + STAN_MPI=false + STAN_OPENCL=true + STAN_NO_RANGE_CHECKS=false + STAN_CPP_OPTIMS=false + "), + list( + stan_version = '2.38.0', + stan_threads = FALSE, + stan_mpi = FALSE, + stan_opencl = TRUE, + stan_no_range_checks = FALSE, + stan_cpp_optims = FALSE + ) + ) +}) + +test_that("validate_precompile_cpp_options works", { + expect_equal_ignore_order( + validate_precompile_cpp_options(list(Stan_Threads = TRUE, STAN_OPENCL = NULL, aBc = FALSE)), + list( + stan_threads = TRUE, + stan_opencl = NULL, + abc = FALSE + ) + ) + expect_warning(validate_precompile_cpp_options(list(STAN_OPENCL= FALSE))) +}) + + +test_that('exe_info cpp_options comparison works', { + exe_info_all_flags_off <- exe_info_style_cpp_options(list()) + exe_info_all_flags_off[['stan_version']] <- '35.0.0' + + expect_true(exe_info_reflects_cpp_options(exe_info_all_flags_off, list())) + expect_true(exe_info_reflects_cpp_options(list(stan_opencl = FALSE), list(stan_opencl = NULL))) + expect_not_true(exe_info_reflects_cpp_options(list(stan_opencl = FALSE), list(stan_opencl = FALSE))) + expect_not_true(exe_info_reflects_cpp_options(list(stan_opencl = FALSE, stan_threads = FALSE), list(stan_opencl = NULL, stan_threads = TRUE))) + expect_not_true(exe_info_reflects_cpp_options( + list(stan_opencl = FALSE, stan_threads = FALSE), + list(stan_opencl = NULL, stan_threads = TRUE, EXTRA_ARG = TRUE) + )) + + # no exe_info -> no recompile based on cpp info + expect_warning( + expect_true(exe_info_reflects_cpp_options(list(), list())), + 'Recompiling is recommended' + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-model-recompile-logic.R b/tests/testthat/test-model-recompile-logic.R new file mode 100644 index 00000000..652cccb8 --- /dev/null +++ b/tests/testthat/test-model-recompile-logic.R @@ -0,0 +1,238 @@ +stan_program <- cmdstan_example_file() +file_that_doesnt_exist <- 'placeholder_doesnt_exist' +file_that_exists <- 'placeholder_exists' +file.create(file_that_exists) +on.exit(if(file.exists(file_that_exists)) file.remove(file_that_exists)) + +test_that("warning when no recompile and no info", + with_mocked_cli(compile_ret = list(), info_ret = list(status = 1), code = expect_warning({ + mod <- cmdstan_model(stan_file = stan_program, exe_file = file_that_exists, compile = FALSE) + }, "Recompiling is recommended.")) +) + +test_that("recompiles when force_recompile flag set", + with_mocked_cli(compile_ret = list(status=0), info_ret = list(), code = expect_mock_compile({ + mod <- cmdstan_model(stan_file = stan_program, force_recompile = TRUE) + })) +) + +test_that("No mismatch results in no recompile.", with_mocked_cli( + compile_ret = list(status = 0), + info_ret = list( + status = 0, + stdout = " + stan_version_major = 2 + stan_version_minor = 35 + stan_version_patch = 0 + STAN_THREADS=false + STAN_MPI=false + STAN_OPENCL=false + STAN_NO_RANGE_CHECKS=false + STAN_CPP_OPTIMS=false + " + ), + code = expect_no_mock_compile({ + mod <- cmdstan_model(stan_file = stan_program, exe_file = file_that_exists) + }) +)) + +test_that("Mismatch results in recompile.", with_mocked_cli( + compile_ret = list(status=0), + info_ret = list( + status=0, + stdout= " + stan_version_major = 2 + stan_version_minor = 35 + stan_version_patch = 0 + STAN_THREADS=false + STAN_MPI=false + STAN_OPENCL=false + STAN_NO_RANGE_CHECKS=false + STAN_CPP_OPTIMS=false + " + ), + code = expect_mock_compile({ + mod <- cmdstan_model(stan_file = stan_program, exe_file = file_that_exists, cpp_options = list(stan_threads = TRUE)) + }) +)) +test_that("$exe_info(), $precompile_cpp_options() return expected data without recompile", + with_mocked_cli( + compile_ret = list(status=0), + info_ret = list( + status=0, + stdout= " + stan_version_major = 2 + stan_version_minor = 38 + stan_version_patch = 0 + STAN_THREADS=false + STAN_MPI=false + STAN_OPENCL=true + STAN_NO_RANGE_CHECKS=false + STAN_CPP_OPTIMS=false + " + ), + code = { + file.create(file_that_exists) + expect_no_mock_compile({ + mod <- cmdstan_model( + stan_file = stan_program, + exe_file = file_that_exists, + compile = FALSE, + cpp_options = list(Stan_Threads = TRUE, stan_opencl = NULL, aBc = FALSE) + ) + }) + expect_equal_ignore_order( + mod$exe_info(), + list( + stan_version = '2.38.0', + stan_threads = FALSE, + stan_mpi = FALSE, + stan_opencl = TRUE, + stan_no_range_checks = FALSE, + stan_cpp_optims = FALSE + ) + ) + expect_equal_ignore_order( + mod$precompile_cpp_options(), + list( + stan_threads = TRUE, + stan_opencl = NULL, + abc = FALSE + ) + ) + } + ) +) + +test_that("$exe_info_fallback() logic works as expected with cpp_options", + with_mocked_cli( + compile_ret = list(status=0), + info_ret = list( + status = 1, + stdout = '' + ), + code = { + expect_warning( + expect_no_mock_compile({ + mod <- cmdstan_model( + stan_file = stan_program, + exe_file = file_that_exists, + compile = FALSE, + cpp_options = list(Stan_Threads = TRUE, stan_Opencl = NULL, aBc = FALSE, dEf = NULL) + ) + }), + 'Retrieving exe_file info failed' + ) + # cmdstan_model call same as above + # Because we use testthat 3e, cannot nest expect_warning() with itself + expect_warning( + expect_no_mock_compile({ + mod <- cmdstan_model( + stan_file = stan_program, + exe_file = file_that_exists, + compile = FALSE, + cpp_options = list(Stan_Threads = TRUE, stan_Opencl = NULL, aBc = FALSE, dEf = NULL) + ) + }), + 'Recompiling is recommended' + ) + expect_equal( + mod$exe_info(), + NULL + ) + expect_equal_ignore_order( + mod$exe_info_fallback(), + list( + stan_version = cmdstan_version(), + stan_threads = TRUE, + stan_mpi = FALSE, + stan_opencl = FALSE, + stan_no_range_checks = FALSE, + stan_cpp_optims = FALSE, + abc = FALSE, + def = NULL + ) + ) + expect_equal_ignore_order( + mod$precompile_cpp_options(), + list( + stan_threads = TRUE, + stan_opencl = NULL, + abc = FALSE, + def = NULL + ) + ) + } + ) +) + +test_that("$exe_info_fallback() logic works as expected without cpp_options", + with_mocked_cli( + compile_ret = list(status=0), + info_ret = list( + status = 1, + stdout = "" + ), + code = { + expect_warning( + expect_no_mock_compile({ + mod <- cmdstan_model( + exe_file = file_that_exists + ) + }), + 'Retrieving exe_file info failed' + ) + # cmdstan_model call same as above + # Because we use testthat 3e, cannot nest expect_warning() with itself + expect_warning( + expect_no_mock_compile({ + mod <- cmdstan_model( + exe_file = file_that_exists + ) + }), + "Recompiling is recommended" + ) + expect_equal( + mod$exe_info(), + NULL + ) + expect_equal_ignore_order( + mod$exe_info_fallback(), + list( + stan_version = cmdstan_version(), + stan_threads = FALSE, + stan_mpi = FALSE, + stan_opencl = FALSE, + stan_no_range_checks = FALSE, + stan_cpp_optims = FALSE + ) + ) + expect_equal_ignore_order( + mod$precompile_cpp_options(), + list() + ) + } + ) +) + +test_that("Recompile when cpp args don't match binary", { + with_mocked_cli( + compile_ret = list(status=0), + info_ret = list( + status=0, + stdout= " + stan_version_major = 2 + stan_version_minor = 38 + stan_version_patch = 0 + STAN_THREADS=false + STAN_MPI=false + STAN_OPENCL=true + STAN_NO_RANGE_CHECKS=false + STAN_CPP_OPTIMS=false + " + ), + expect_mock_compile({ + mod_gq <- cmdstan_model(testing_stan_file("bernoulli_ppc"), exe_file = file_that_exists, cpp_options = list(stan_threads = TRUE)) + }) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-model-variables.R b/tests/testthat/test-model-variables.R index 680bb7cd..62e6a0c2 100644 --- a/tests/testthat/test-model-variables.R +++ b/tests/testthat/test-model-variables.R @@ -5,7 +5,7 @@ set_cmdstan_path() test_that("$variables() errors if version less than 2.27", { mod <- testing_model("bernoulli") ver <- cmdstan_version() - .cmdstanr$VERSION <- "2.26.0" + fake_cmdstan_version("2.26.0", mod = mod) expect_error( mod$variables(), "$variables() is only supported for CmdStan 2.27 or newer", diff --git a/tests/testthat/test-opencl.R b/tests/testthat/test-opencl.R index 55c59e2c..cc87184c 100644 --- a/tests/testthat/test-opencl.R +++ b/tests/testthat/test-opencl.R @@ -5,7 +5,7 @@ fit <- testing_fit("bernoulli", method = "sample", seed = 123, chains = 1) test_that("all methods error when opencl_ids is used with non OpenCL model", { stan_file <- testing_stan_file("bernoulli") - mod <- cmdstan_model(stan_file = stan_file) + mod <- cmdstan_model(stan_file = stan_file, force_recompile = TRUE) expect_error( mod$sample(data = testing_data("bernoulli"), opencl_ids = c(0, 0), chains = 1), "'opencl_ids' is set but the model was not compiled with for use with OpenCL.", @@ -22,7 +22,7 @@ test_that("all methods error when opencl_ids is used with non OpenCL model", { fixed = TRUE ) stan_file_gq <- testing_stan_file("bernoulli_ppc") - mod_gq <- cmdstan_model(stan_file = stan_file_gq) + mod_gq <- cmdstan_model(stan_file = stan_file_gq, force_recompile = TRUE) expect_error( mod_gq$generate_quantities(fitted_params = fit, data = testing_data("bernoulli"), opencl_ids = c(0, 0)), "'opencl_ids' is set but the model was not compiled with for use with OpenCL.", @@ -33,7 +33,7 @@ test_that("all methods error when opencl_ids is used with non OpenCL model", { test_that("all methods error on invalid opencl_ids", { skip_if_not(Sys.getenv("CMDSTANR_OPENCL_TESTS") %in% c("1", "true")) stan_file <- testing_stan_file("bernoulli") - mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE)) + mod <- cmdstan_model(stan_file = stan_file, force_recompile = TRUE, cpp_options = list(stan_opencl = TRUE)) utils::capture.output( expect_warning( mod$sample(data = testing_data("bernoulli"), opencl_ids = c(1000, 1000), chains = 1), @@ -56,7 +56,7 @@ test_that("all methods error on invalid opencl_ids", { ) ) stan_file_gq <- testing_stan_file("bernoulli_ppc") - mod_gq <- cmdstan_model(stan_file = stan_file_gq, cpp_options = list(stan_opencl = TRUE)) + mod_gq <- cmdstan_model(stan_file = stan_file_gq, force_recompile = TRUE, cpp_options = list(stan_opencl = TRUE)) utils::capture.output( expect_warning( mod_gq$generate_quantities(fitted_params = fit, data = testing_data("bernoulli"), opencl_ids = c(1000, 1000)), @@ -69,51 +69,71 @@ test_that("all methods error on invalid opencl_ids", { test_that("all methods run with valid opencl_ids", { skip_if_not(Sys.getenv("CMDSTANR_OPENCL_TESTS") %in% c("1", "true")) stan_file <- testing_stan_file("bernoulli") - mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE)) + mod <- cmdstan_model(stan_file = stan_file, force_recompile = TRUE, cpp_options = list(stan_opencl = TRUE)) expect_sample_output( fit <- mod$sample(data = testing_data("bernoulli"), opencl_ids = c(0, 0), chains = 1) ) expect_false(is.null(fit$metadata()$opencl_platform_name)) - expect_false(is.null(fit$metadata()$opencl_ids_name)) + expect_false(is.null(fit$metadata()$opencl_device_name)) + expect_false(is.null(fit$metadata()$device)) + expect_false(is.null(fit$metadata()$platform)) stan_file_gq <- testing_stan_file("bernoulli_ppc") - mod_gq <- cmdstan_model(stan_file = stan_file_gq, cpp_options = list(stan_opencl = TRUE)) + mod_gq <- cmdstan_model(stan_file = stan_file_gq, force_recompile = TRUE, cpp_options = list(stan_opencl = TRUE)) expect_gq_output( fit <- mod_gq$generate_quantities(fitted_params = fit, data = testing_data("bernoulli"), opencl_ids = c(0, 0)), ) expect_false(is.null(fit$metadata()$opencl_platform_name)) - expect_false(is.null(fit$metadata()$opencl_ids_name)) + expect_false(is.null(fit$metadata()$opencl_device_name)) + expect_false(is.null(fit$metadata()$device)) + expect_false(is.null(fit$metadata()$platform)) expect_sample_output( fit <- mod$sample(data = testing_data("bernoulli"), opencl_ids = c(0, 0)) ) expect_false(is.null(fit$metadata()$opencl_platform_name)) - expect_false(is.null(fit$metadata()$opencl_ids_name)) + expect_false(is.null(fit$metadata()$opencl_device_name)) + expect_false(is.null(fit$metadata()$device)) + expect_false(is.null(fit$metadata()$platform)) expect_optim_output( fit <- mod$optimize(data = testing_data("bernoulli"), opencl_ids = c(0, 0)) ) expect_false(is.null(fit$metadata()$opencl_platform_name)) - expect_false(is.null(fit$metadata()$opencl_ids_name)) + expect_false(is.null(fit$metadata()$opencl_device_name)) + expect_false(is.null(fit$metadata()$device)) + expect_false(is.null(fit$metadata()$platform)) expect_vb_output( fit <- mod$variational(data = testing_data("bernoulli"), opencl_ids = c(0, 0)) ) expect_false(is.null(fit$metadata()$opencl_platform_name)) - expect_false(is.null(fit$metadata()$opencl_ids_name)) + expect_false(is.null(fit$metadata()$opencl_device_name)) + expect_false(is.null(fit$metadata()$device)) + expect_false(is.null(fit$metadata()$platform)) }) test_that("error for runtime selection of OpenCL devices if version less than 2.26", { skip_if_not(Sys.getenv("CMDSTANR_OPENCL_TESTS") %in% c("1", "true")) - fake_cmdstan_version("2.25.0") stan_file <- testing_stan_file("bernoulli") mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE), force_recompile = TRUE) + fake_cmdstan_version("2.25.0", mod) expect_error( - mod$sample(data = data_list, chains = 1, refresh = 0, opencl_ids = c(1,1)), + mod$sample(data = testing_data("bernoulli"), chains = 1, refresh = 0, opencl_ids = c(0,0)), "Runtime selection of OpenCL devices is only supported with CmdStan version 2.26 or newer", fixed = TRUE ) reset_cmdstan_version() }) + +test_that("model from exe_file retains open_cl option", { + skip_if_not(Sys.getenv("CMDSTANR_OPENCL_TESTS") %in% c("1", "true")) + stan_file <- testing_stan_file("bernoulli") + mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE)) + mod_from_exe <- cmdstan_model(exe_file = mod$exe_file()) + expect_sample_output( + fit <- mod_from_exe$sample(data = testing_data("bernoulli"), opencl_ids = c(0, 0), chains = 1) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-threads.R b/tests/testthat/test-threads.R index 1a333e82..8652faf2 100644 --- a/tests/testthat/test-threads.R +++ b/tests/testthat/test-threads.R @@ -15,7 +15,7 @@ test_that("using threads_per_chain without stan_threads set in compile() warns", "Running MCMC with 4 sequential chains", fixed = TRUE ), - "'threads_per_chain' is set but the model was not compiled with 'cpp_options = list(stan_threads = TRUE)' so 'threads_per_chain' will have no effect!", + "'threads_per_chain' is set but the model was not compiled with 'cpp_options = list(stan_threads = TRUE)' or equivalent so 'threads_per_chain' will have no effect!", fixed = TRUE) }) @@ -24,7 +24,7 @@ test_that("threading works with sample()", { expect_error( mod$sample(data = data_file_json), - "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' but 'threads_per_chain' was not set!", + "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' or equivalent, but 'threads_per_chain' was not set!", fixed = TRUE ) @@ -57,7 +57,7 @@ test_that("threading works with optimize()", { expect_error( mod$optimize(data = data_file_json), - "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' but 'threads' was not set!", + "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' or equivalent, but 'threads' was not set!", fixed = TRUE ) @@ -91,7 +91,7 @@ test_that("threading works with variational()", { expect_error( mod$variational(data = data_file_json), - "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' but 'threads' was not set!", + "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' or equivalent, but 'threads' was not set!", fixed = TRUE ) @@ -130,7 +130,7 @@ test_that("threading works with generate_quantities()", { ) expect_error( mod_gq$generate_quantities(fitted_params = f, data = data_file_json), - "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' but 'threads_per_chain' was not set!", + "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' or equivalent, but 'threads_per_chain' was not set!", fixed = TRUE ) expect_output( @@ -158,23 +158,49 @@ test_that("threading works with generate_quantities()", { expect_equal(f_gq$metadata()$threads_per_chain, 4) }) -test_that("correct output when stan_threads not TRUE", { - mod <- cmdstan_model(stan_program, cpp_options = list(stan_threads = FALSE), force_recompile = TRUE) +test_that("correct output when stan_threads unset", { + mod <- cmdstan_model(stan_program, cpp_options = list(stan_threads = NULL), force_recompile = TRUE) expect_output( mod$sample(data = data_file_json), "Running MCMC with 4 sequential chains", fixed = TRUE ) mod <- cmdstan_model(stan_program, cpp_options = list(stan_threads = "dummy string"), force_recompile = TRUE) - expect_output( + expect_error( mod$sample(data = data_file_json), - "Running MCMC with 4 sequential chains", + "The model was compiled with 'cpp_options = list(stan_threads = TRUE)' or equivalent, but 'threads_per_chain' was not set!", fixed = TRUE ) - mod <- cmdstan_model(stan_program, cpp_options = list(stan_threads = FALSE), force_recompile = TRUE) + + mod <- cmdstan_model(stan_program, cpp_options = list(stan_threads = NULL), force_recompile = TRUE) expect_warning( mod$sample(data = data_file_json, threads_per_chain = 4), - "'threads_per_chain' is set but the model was not compiled with 'cpp_options = list(stan_threads = TRUE)' so 'threads_per_chain' will have no effect!", + "'threads_per_chain' is set but the model was not compiled with 'cpp_options = list(stan_threads = TRUE)' or equivalent so 'threads_per_chain' will have no effect!", + fixed = TRUE + ) + + expect_warning( + cmdstan_model(stan_program, cpp_options = list(stan_threads = FALSE), force_recompile = TRUE), + "STAN_THREADS set to FALSE Since this is a non-empty value, it will result in the corresponding ccp option being turned ON. To turn this option off, use cpp_options = list(stan_threads = NULL).", + fixed = TRUE + ) +}) + +test_that('correct output when stan threads set via make local',{ + #TODO clean this up so no leftover changes to make local + file.copy( + file.path(cmdstan_path(), 'make', 'local'), + file.path(cmdstan_path(), 'make', 'local.save') + ) + on.exit(file.rename( + file.path(cmdstan_path(), 'make', 'local.save'), + file.path(cmdstan_path(), 'make', 'local') + ), add = TRUE, after = FALSE) + cmdstan_make_local(cpp_options = list(stan_threads = TRUE)) + mod <- cmdstan_model(stan_program, force_recompile = TRUE) + expect_output( + f <- mod$sample(data = data_file_json, parallel_chains = 4, threads_per_chain = 1), + "Running MCMC with 4 parallel chains, with 1 thread(s) per chain..", fixed = TRUE ) })