Merge pull request #80 from mlverse/updates

Improvements to tests

edgararuiz authored Dec 2, 2023
2 parents 8d30970 + 1e9d1e9 commit 502d447
Showing 31 changed files with 432 additions and 315 deletions.
33 changes: 4 additions & 29 deletions .github/workflows/spark-tests.yaml
@@ -22,9 +22,12 @@ jobs:
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
R_KEEP_PKG_SOURCE: yes
SPARK_VERSION: ${{ matrix.config.spark }}
SPARK_VERSION: ${{ matrix.config.pyspark }}
HADOOP_VERSION: ${{ matrix.config.hadoop }}
PYSPARK_VERSION: ${{ matrix.config.pyspark }}
DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
DATABRICKS_HOST: "https://rstudio-partner-posit-default.cloud.databricks.com"
DATABRICKS_CLUSTER_ID: "1026-175310-7cpsh3g8"

steps:
- uses: actions/checkout@v3
@@ -89,37 +92,9 @@ jobs:
print(t_pkgs[, c("Package", "Version")])
shell: Rscript {0}

- name: Cache Environment
id: cache-venv
uses: actions/cache@v3
with:
path: /home/runner/.virtualenvs/r-sparklyr-pyspark-${{ matrix.config.pyspark }}
key: sparklyr-virtualenv-${{ matrix.config.pyspark }}

- name: Virtual Environment
#if: steps.cache-venv.outputs.cache-hit != 'true'
run: |
devtools::load_all()
install_pyspark(
Sys.getenv("SPARK_VERSION"),
python = Sys.which("python")
)
shell: Rscript {0}

- name: R Tests
run: |
devtools::load_all()
library(sparklyr)
sv <- Sys.getenv("SPARK_VERSION")
if( sv >="3.5") {
env_name <- use_envname(version = sv)
loc <- paste0("/home/runner/.virtualenvs/", env_name,"/bin/python")
Sys.setenv("PYTHON_VERSION_MISMATCH" = loc)
Sys.setenv("PYSPARK_DRIVER_PYTHON" = loc)
}
Sys.getenv("JAVA_HOME")
Sys.getenv("PYTHON_VERSION_MISMATCH")
Sys.getenv("PYSPARK_DRIVER_PYTHON")
devtools::test(reporter = sparklyr_reporter())
shell: Rscript {0}

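Taken together, the reworked R Tests step points both PYTHON_VERSION_MISMATCH and PYSPARK_DRIVER_PYTHON at the virtualenv's python before running the suite. A rough local equivalent, assuming a checkout of the package and the default virtualenv location (the ~/.virtualenvs path is illustrative):

devtools::load_all()
library(sparklyr)
sv <- Sys.getenv("SPARK_VERSION", unset = "3.5")
if (sv >= "3.5") {
  env_name <- use_envname(version = sv)            # e.g. the pyspark env for this Spark version
  loc <- path.expand(paste0("~/.virtualenvs/", env_name, "/bin/python"))
  Sys.setenv(PYTHON_VERSION_MISMATCH = loc, PYSPARK_DRIVER_PYTHON = loc)
}
devtools::test(reporter = sparklyr_reporter())
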
35 changes: 5 additions & 30 deletions .github/workflows/test-coverage.yaml
@@ -1,5 +1,3 @@
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
branches: main
@@ -10,7 +8,11 @@ jobs:
test-coverage:
runs-on: ubuntu-latest
env:
SPARK_VERSION: "3.5"
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
DATABRICKS_HOST: "https://rstudio-partner-posit-default.cloud.databricks.com"
DATABRICKS_CLUSTER_ID: "1026-175310-7cpsh3g8"

steps:
- uses: actions/checkout@v3
@@ -57,37 +59,10 @@ jobs:
sparklyr::download_scalac()
shell: Rscript {0}

- name: Cache Environment
id: cache-venv
uses: actions/cache@v3
with:
path: /home/runner/.virtualenvs/r-sparklyr-pyspark-3.5
key: sparklyr-virtualenv-3.5

- name: Virtual Environment
run: |
devtools::load_all()
install_pyspark("3.5", python = Sys.which("python"))
shell: Rscript {0}

- name: Test coverage
run: |
devtools::load_all()
env_name <- use_envname(version = 3.5)
loc <- paste0("/home/runner/.virtualenvs/", env_name,"/bin/python")
Sys.setenv("PYTHON_VERSION_MISMATCH" = loc)
Sys.setenv("PYSPARK_DRIVER_PYTHON" = loc)
Sys.setenv("SPARK_VERSION" = 3.5)
Sys.setenv("SCALA_VERSION" = 2.12)
Sys.setenv("CODE_COVERAGE" = "true")
Sys.getenv("JAVA_HOME")
Sys.getenv("SPARK_VERSION")
Sys.getenv("PYTHON_VERSION_MISMATCH")
Sys.getenv("PYSPARK_DRIVER_PYTHON")
devtools::load_all()
covr::codecov(
quiet = FALSE,
clean = FALSE,
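To run the same coverage check locally without uploading to Codecov, a minimal sketch; covr::package_coverage() and covr::report() are substitutions for the workflow's covr::codecov() call, not what the workflow itself runs:

# From the package root, with the same environment variables the workflow sets
Sys.setenv(CODE_COVERAGE = "true", SPARK_VERSION = "3.5", SCALA_VERSION = "2.12")
cov <- covr::package_coverage(quiet = FALSE, clean = FALSE)
covr::report(cov)   # browse an HTML report instead of uploading results
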
1 change: 1 addition & 0 deletions R/sparklyr-spark-connect.R
@@ -227,6 +227,7 @@ build_user_agent <- function() {
}

connection_label <- function(x) {
x <- x[[1]]
ret <- "Connection"
method <- NULL
con <- spark_connection(x)
6 changes: 3 additions & 3 deletions R/start-stop-service.R
@@ -8,7 +8,7 @@
#' @returns It returns messages to the console with the status of starting, and
#' stopping the local Spark Connect service.
#' @export
spark_connect_service_start <- function(version = "3.4",
spark_connect_service_start <- function(version = "3.5",
scala_version = "2.12",
include_args = TRUE,
...) {
@@ -38,12 +38,12 @@ spark_connect_service_start <- function(version = "3.4",

#' @rdname spark_connect_service_start
#' @export
spark_connect_service_stop <- function(version = "3.4",
spark_connect_service_stop <- function(version = "3.5",
...) {
get_version <- spark_install_find(version = version)
cmd <- path(get_version$sparkVersionDir, "sbin", "stop-connect-server.sh")
cli_div(theme = cli_colors())
cli_text("{.header Stopping {.emph Spark Connect}}")
cli_h3("{.header Stopping {.emph Spark Connect}}")
prs <- process$new(
command = cmd,
stdout = "|",
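A short usage sketch of the two service helpers with the new default version; the connection call mirrors the one used in the test helpers further down:

library(pysparklyr)

spark_connect_service_start()        # now defaults to version = "3.5", scala_version = "2.12"
sc <- sparklyr::spark_connect(
  master = "sc://localhost",
  method = "spark_connect",
  version = "3.5"
)
# ... run queries against sc ...
sparklyr::spark_disconnect(sc)
spark_connect_service_stop()         # also defaults to "3.5"
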
8 changes: 4 additions & 4 deletions inst/rstudio/shinycon/app.R
@@ -59,7 +59,7 @@ connection_spark_ui <- function() {
tags$td(style = paste("height: 5px"))
),
tags$tr(
tags$td("Master:"),
tags$td("Host URL:"),
div(
tags$td(
textInput(
@@ -79,7 +7,7 @@
tags$td(style = paste("height: 5px"))
),
tags$tr(
tags$td("Auth:"),
tags$td("Password:"),
tags$td(textOutput("auth"))
)
)
@@ -93,10 +93,10 @@ connection_spark_server <- function(input, output, session) {
output$auth <- reactive({
t_source <- names(token)
if (t_source == "environment") {
ret <- "✓ Found - Using value from 'DATABRICKS_TOKEN'"
ret <- "✓ Found - Using 'DATABRICKS_TOKEN'"
}
if (t_source == "oauth") {
ret <- "✓ Found - Using value from Posit Workbench OAuth"
ret <- "✓ Found - Managed by Posit Workbench OAuth"
}
if (t_source == "") {
ret <- "✘ Not Found - Add it to your 'DATABRICKS_TOKEN' env variable"
2 changes: 1 addition & 1 deletion man/install_pyspark.Rd

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion man/installed_components.Rd

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions man/spark_connect_service_start.Rd

Some generated files are not rendered by default.

7 changes: 7 additions & 0 deletions tests/testthat/_snaps/python-install.md
@@ -0,0 +1,7 @@
# Install code is correctly created

Code
build_job_code(list(a = 1))
Output
[1] "pysparklyr:::install_environment(a = 1)"

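The snapshot above pairs with a testthat call along these lines (a sketch; the actual test file is not part of the rendered diff):

test_that("Install code is correctly created", {
  expect_snapshot(
    build_job_code(list(a = 1))
  )
})
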
152 changes: 70 additions & 82 deletions tests/testthat/helper-init.R
@@ -1,33 +1,69 @@
.test_env <- new.env()
.test_env$sc <- NULL
.test_env$lr_model <- NULL
.test_env$env <- NULL
.test_env$started <- NULL

use_test_env <- function() {
if (is.null(.test_env$env)) {
base <- fs::path_expand("~/test-spark")
.test_env$env <- fs::path(base, random_table_name("env"))
fs::dir_create(.test_env$env)
}
.test_env$env
}

test_version_spark <- function() {
use_test_version_spark <- function() {
version <- Sys.getenv("SPARK_VERSION", unset = NA)
if (is.na(version)) version <- "3.4"
if (is.na(version)) version <- "3.5"
version
}

test_scala_spark <- function() {
use_test_scala_spark <- function() {
version <- Sys.getenv("SCALA_VERSION", unset = NA)
if (is.na(version)) version <- "2.12"
version
}

test_spark_connect <- function() {
use_test_connect_start <- function() {
if (is.null(.test_env$started)) {
env_path <- path(use_test_python_environment(), "bin", "python")
version <- use_test_version_spark()
Sys.setenv("PYTHON_VERSION_MISMATCH" = env_path)
Sys.setenv("PYSPARK_DRIVER_PYTHON" = env_path)
cli_h1("Starting Spark Connect service version {version}")
cli_h3("PYTHON_VERSION_MISMATCH: {Sys.getenv('PYTHON_VERSION_MISMATCH')}")
cli_h3("PYSPARK_DRIVER_PYTHON: {Sys.getenv('PYSPARK_DRIVER_PYTHON')}")
spark_connect_service_start(
version = version,
scala_version = use_test_scala_spark()
)
.test_env$started <- 0
} else {
invisible()
}
}

use_test_spark_connect <- function() {
if (is.null(.test_env$sc)) {
cli_h2("Connecting to Spark cluster")
.test_env$sc <- sparklyr::spark_connect(
master = "sc://localhost",
method = "spark_connect",
version = test_version_spark()
use_test_connect_start()
cli_h1("Connecting to Spark cluster")
withr::with_envvar(
new = c("WORKON_HOME" = use_test_env()),
{
.test_env$sc <- sparklyr::spark_connect(
master = "sc://localhost",
method = "spark_connect",
version = use_test_version_spark()
)
}
)
}
.test_env$sc
}

test_table_mtcars <- function() {
sc <- test_spark_connect()
use_test_table_mtcars <- function() {
sc <- use_test_spark_connect()
if (!"mtcars" %in% dbListTables(sc)) {
ret <- dplyr::copy_to(sc, mtcars, overwrite = TRUE)
} else {
@@ -36,82 +36,34 @@ test_table_mtcars <- function() {
ret
}

test_lr_model <- function() {
use_test_lr_model <- function() {
if (is.null(.test_env$lr_model)) {
tbl_mtcars <- test_table_mtcars()
tbl_mtcars <- use_test_table_mtcars()
.test_env$lr_model <- ml_logistic_regression(tbl_mtcars, am ~ ., max_iter = 10)
}
.test_env$lr_model
}

test_coverage_enable <- function() {
Sys.setenv("CODE_COVERAGE" = "true")
}

expect_same_remote_result <- function(.data, pipeline) {
sc <- test_spark_connect()
temp_name <- random_table_name("test_")
spark_data <- copy_to(sc, .data, temp_name)

local <- pipeline(.data)

remote <- try(
spark_data %>%
pipeline() %>%
collect()
)

if (inherits(remote, "try-error")) {
expect_equal(remote[[1]], "")
} else {
expect_equal(local, remote, ignore_attr = TRUE)
}

DBI::dbRemoveTable(sc, temp_name)
}

testthat_tbl <- function(name, data = NULL, repartition = 0L) {
sc <- test_spark_connect()

tbl <- tryCatch(dplyr::tbl(sc, name), error = identity)
if (inherits(tbl, "error")) {
if (is.null(data)) data <- eval(as.name(name), envir = parent.frame())
tbl <- dplyr::copy_to(sc, data, name = name, repartition = repartition)
}

tbl
}

random_table_name <- function(prefix) {
paste0(prefix, paste0(floor(runif(10, 0, 10)), collapse = ""))
}


skip_spark_min_version <- function(version) {
sc <- test_spark_connect()
sp_version <- spark_version(sc)
comp_ver <- compareVersion(as.character(version), sp_version)
if (comp_ver != -1) {
skip(glue("Skips on Spark version {version}"))
}
}


test_remove_python_envs <- function(x = "") {
found <- find_environments(x)
cli_inform("Environments found: {length(found)}")

invisible(
lapply(found,
function(x)
try(virtualenv_remove(x, confirm = FALSE), silent = TRUE)
)
)

invisible(
lapply(found,
function(x)
try(conda_remove(x), silent = TRUE)
)
use_test_python_environment <- function() {
withr::with_envvar(
new = c("WORKON_HOME" = use_test_env()),
{
version <- use_test_version_spark()
env <- use_envname(method = "spark_connect", version = version)
env_avail <- names(env)
target <- path(use_test_env(), env)
if (!dir_exists(target)) {
if (env_avail != "exact") {
cli_h1("Creating Python environment")
install_pyspark(
version = version,
as_job = FALSE,
python = Sys.which("python")
)
env <- use_envname(method = "spark_connect", version = version)
}
}
}
)
target
}
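
A sketch of how a test might use the renamed helpers (the test name and expectation are illustrative, not part of this commit):

test_that("the shared mtcars table is available", {
  tbl_mtcars <- use_test_table_mtcars()   # connects through use_test_spark_connect()
  expect_equal(nrow(dplyr::collect(tbl_mtcars)), nrow(mtcars))
})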