Skip to content

Commit

Permalink
Ensure multiple precomputed neighbor data all has the same n_neighbors.
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmelville committed Dec 6, 2018
1 parent c93adcf commit 157abba
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
23 changes: 17 additions & 6 deletions R/uwot.R
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,12 @@ data2set <- function(X, Xcat, n_neighbors, metrics, nn_method,
# Extract this block of nn data from list of lists
if (metric == "precomputed") {
nn_sub <- nn_method[[i]]
if (i == 1) {
n_neighbors <- NULL
}
else {
n_neighbors <- ncol(nn_method[[1]]$idx)
}
}

x2set_res <- x2set(Xsub, n_neighbors, metric, nn_method = nn_sub,
Expand Down Expand Up @@ -1229,7 +1235,9 @@ x2nn <- function(X, n_neighbors, metric, nn_method,
n_vertices = x2nv(X),
verbose = FALSE) {
if (is.list(nn_method)) {
validate_nn(nn_method, n_vertices)
# on first iteration n_neighbors is NULL
# on subsequent iterations ensure n_neighbors is consistent for all data
validate_nn(nn_method, n_vertices, n_neighbors = n_neighbors)
nn <- nn_method
}
else {
Expand All @@ -1254,16 +1262,19 @@ x2nn <- function(X, n_neighbors, metric, nn_method,
nn
}

validate_nn <- function(nn_method, n_vertices) {
validate_nn <- function(nn_method, n_vertices, n_neighbors = NULL) {
if (!is.matrix(nn_method$idx)) {
stop("Couldn't find precalculated 'idx' matrix")
}
if (nrow(nn_method$idx) != n_vertices) {
stop("Precalculated 'idx' matrix must have ", n_vertices, " rows, but
found ", nrow(nn_method$idx))
stop("Precalculated 'idx' matrix must have ", n_vertices,
" rows, but found ", nrow(nn_method$idx))
}

# set n_neighbors from these matrices if it hasn't been already set
if (is.null(n_neighbors)) {
n_neighbors <- ncol(nn_method$idx)
}
n_neighbors <- ncol(nn_method$idx)

if (!is.matrix(nn_method$dist)) {
stop("Couldn't find precalculated 'dist' matrix")
}
Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/test_errors.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,8 @@ expect_error(lvish(iris10, n_threads = 0, perplexity = 50), "perplexity")
expect_error(umap(iris10, n_threads = 0, n_neighbors = 4, y = c(1:9, NA)), "numeric y")
expect_error(umap(X = NULL, n_threads = 0, n_neighbors = 4, nn_method = nn,
init = "spca"), "spca")
# add an extra column to nn
nn5 <- nn
nn5$idx <- cbind(nn5$idx, rep(100, nrow(nn5$idx)))
nn5$dist <- cbind(nn5$dist, rep(100.0, nrow(nn5$dist)))
expect_error(umap(X = NULL, n_threads = 0, nn_method = list(nn, nn5)), "Precalculated")

0 comments on commit 157abba

Please sign in to comment.