Skip to content

Commit

Permalink
Merge branch 'branch-25.02' into fix-cagra-graph-optimization-bug
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 authored Jan 16, 2025
2 parents d0faea9 + 86b4ee8 commit 263286b
Show file tree
Hide file tree
Showing 12 changed files with 74 additions and 39 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ There are several benefits to using cuVS and GPUs for vector search, including
6. Multiple language support
7. Building blocks for composing new or accelerating existing algorithms

In addition to the items above, cuVS takes on the burden of keeping non-trivial accelerated code up to date as new NVIDIA architectures and CUDA versions are released. This provides a deslightful development experimence, guaranteeing that any libraries, databases, or applications built on top of it will always be getting the best performance and scale.
In addition to the items above, cuVS takes on the burden of keeping non-trivial accelerated code up to date as new NVIDIA architectures and CUDA versions are released. This provides a delightful development experience, guaranteeing that any libraries, databases, or applications built on top of it will always be getting the best performance and scale.

## cuVS Technology Stack

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/pairwise_distance_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ extern "C" cuvsError_t cuvsPairwiseDistance(cuvsResources_t res,

if ((x_row_major != y_row_major) || (x_row_major != distances_row_major)) {
RAFT_FAIL(
"Inputs to cuvsPairwiseDistance must all have the same layout (row-major or col-major");
"Inputs to cuvsPairwiseDistance must all have the same layout (row-major or col-major)");
}

if (x_row_major) {
Expand Down
28 changes: 21 additions & 7 deletions cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@

namespace {

template <typename T>
template <typename T, typename LayoutT = raft::row_major>
void* _build(cuvsResources_t res,
DLManagedTensor* dataset_tensor,
cuvsDistanceType metric,
T metric_arg)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);

using mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
using mdspan_type = raft::device_matrix_view<T const, int64_t, LayoutT>;
auto mds = cuvs::core::from_dlpack<mdspan_type>(dataset_tensor);

cuvs::neighbors::brute_force::index_params params;
Expand All @@ -53,7 +53,7 @@ void* _build(cuvsResources_t res,
return index_on_heap;
}

template <typename T>
template <typename T, typename QueriesLayoutT = raft::row_major>
void _search(cuvsResources_t res,
cuvsBruteForceIndex index,
DLManagedTensor* queries_tensor,
Expand All @@ -64,7 +64,7 @@ void _search(cuvsResources_t res,
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<T>*>(index.addr);

using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, QueriesLayoutT>;
using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
using prefilter_mds_type = raft::device_vector_view<const uint32_t, int64_t>;
Expand Down Expand Up @@ -150,8 +150,15 @@ extern "C" cuvsError_t cuvsBruteForceBuild(cuvsResources_t res,
auto dataset = dataset_tensor->dl_tensor;

if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
index->addr =
reinterpret_cast<uintptr_t>(_build<float>(res, dataset_tensor, metric, metric_arg));
if (cuvs::core::is_c_contiguous(dataset_tensor)) {
index->addr =
reinterpret_cast<uintptr_t>(_build<float>(res, dataset_tensor, metric, metric_arg));
} else if (cuvs::core::is_f_contiguous(dataset_tensor)) {
index->addr = reinterpret_cast<uintptr_t>(
_build<float, raft::col_major>(res, dataset_tensor, metric, metric_arg));
} else {
RAFT_FAIL("dataset input to cuvsBruteForceBuild must be contiguous (non-strided)");
}
index->dtype = dataset.dtype;
} else {
RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d",
Expand Down Expand Up @@ -189,7 +196,14 @@ extern "C" cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");

if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
_search<float>(res, index, queries_tensor, neighbors_tensor, distances_tensor, prefilter);
if (cuvs::core::is_c_contiguous(queries_tensor)) {
_search<float>(res, index, queries_tensor, neighbors_tensor, distances_tensor, prefilter);
} else if (cuvs::core::is_f_contiguous(queries_tensor)) {
_search<float, raft::col_major>(
res, index, queries_tensor, neighbors_tensor, distances_tensor, prefilter);
} else {
RAFT_FAIL("queries input to cuvsBruteForceSearch must be contiguous (non-strided)");
}
} else {
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
queries.dtype.code,
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ inline auto enum_variety_ip() -> test_cases_t
// InnerProduct score is signed,
// thus we're forced to used signed 8-bit representation,
// thus we have one bit less precision
y.min_recall = y.min_recall.value() * 0.90;
y.min_recall = y.min_recall.value() * 0.88;
} else {
// In other cases it seems to perform a little bit better, still worse than L2
y.min_recall = y.min_recall.value() * 0.94;
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def setup(app):
linkcode_resolve = make_linkcode_resolve(
"cuvs",
"https://github.com/rapidsai/cuvs/"
"blob/{revision}/python/cuvs/cuvs/"
"blob/{revision}/python/cuvs/"
"{package}/{path}#L{lineno}",
)

Expand Down
24 changes: 14 additions & 10 deletions docs/source/cuvs_bench/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ This tool offers several benefits, including

* `Docker`_

- `How to run the benchmarks`_
- `How benchmarks are run`_

* `Step 1: Prepare the dataset`_

Expand Down Expand Up @@ -93,32 +93,36 @@ We provide images for GPU enabled systems, as well as systems without a GPU. The
- `cuvs-bench-datasets`: Contains the GPU and CPU benchmarks with million-scale datasets already included in the container. Best suited for users that want to run multiple million scale datasets already included in the image.
- `cuvs-bench-cpu`: Contains only CPU benchmarks with minimal size. Best suited for users that want the smallest containers to reproduce benchmarks on systems without a GPU.

Nightly images are located in `dockerhub <https://hub.docker.com/r/rapidsai/cuvs-bench/tags>`_, meanwhile release (stable) versions are located in `NGC <https://hub.docker.com/r/rapidsai/cuvs-bench>`_, starting with release 24.10.
Nightly images are located in `dockerhub <https://hub.docker.com/r/rapidsai/cuvs-bench/tags>`_.

The following command pulls the nightly container for Python version 3.10, CUDA version 12.0, and cuVS version 24.10:
The following command pulls the nightly container for Python version 3.10, CUDA version 12.5, and cuVS version 24.12:

.. code-block:: bash
docker pull rapidsai/cuvs-bench:24.10a-cuda12.0-py3.10 #substitute cuvs-bench for the exact desired container.
docker pull rapidsai/cuvs-bench:24.12a-cuda12.5-py3.10 #substitute cuvs-bench for the exact desired container.
The CUDA and python versions can be changed for the supported values:
- Supported CUDA versions: 11.4 and 12.x
- Supported Python versions: 3.9 and 3.10.
- Supported CUDA versions: 11.8 and 12.5
- Supported Python versions: 3.10 and 3.11.

You can see the exact versions as well in the dockerhub site:
- `cuVS bench images <https://hub.docker.com/r/rapidsai/cuvs-bench/tags>`_
- `cuVS bench with datasets preloaded images <https://hub.docker.com/r/rapidsai/cuvs-bench-cpu/tags>`_
- `cuVS bench with pre-loaded million-scale datasets images <https://hub.docker.com/r/rapidsai/cuvs-bench-cpu/tags>`_
- `cuVS bench CPU only images <https://hub.docker.com/r/rapidsai/cuvs-bench-datasets/tags>`_

**Note:** GPU containers use the CUDA toolkit from inside the container, the only requirement is a driver installed on the host machine that supports that version. So, for example, CUDA 11.8 containers can run in systems with a CUDA 12.x capable driver. Please also note that the Nvidia-Docker runtime from the `Nvidia Container Toolkit <https://github.com/NVIDIA/nvidia-docker>`_ is required to use GPUs inside docker containers.

How to run the benchmarks
=========================
How benchmarks are run
======================

The `cuvs-bench` package contains lightweight Python scripts to run the benchmarks. There are 4 general steps to running the benchmarks and visualizing the results.

We provide a collection of lightweight Python scripts to run the benchmarks. There are 4 general steps to running the benchmarks and visualizing the results.
#. Prepare Dataset

#. Build Index and Search Index

#. Data Export

#. Plot Results

Step 1: Prepare the dataset
Expand Down
22 changes: 18 additions & 4 deletions docs/source/sphinxext/github_link.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# This contains code with copyright by the scikit-learn project, subject to the
# license in /thirdparty/LICENSES/LICENSE.scikit_learn
#
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import inspect
import os
Expand Down Expand Up @@ -101,10 +116,9 @@ def _linkcode_resolve(domain, info, package, url_fmt, revision):
else:
return
else:
# Test if we are absolute or not (pyx are relative)
if (not os.path.isabs(fn)):
# Should be relative to docs right now
fn = os.path.abspath(os.path.join("..", "python", fn))
if fn.endswith(".pyx"):
sp_path = next(x for x in sys.path if re.match(".*site-packages$", x))
fn = fn.replace("/opt/conda/conda-bld/work/python/cuvs", sp_path)

# Convert to relative from module root
fn = os.path.relpath(fn,
Expand Down
6 changes: 3 additions & 3 deletions python/cuvs/cuvs/distance/distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product",

@auto_sync_resources
@auto_convert_output
def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0,
def pairwise_distance(X, Y, out=None, metric="euclidean", p=2.0,
resources=None):
"""
Compute pairwise distances between X and Y
Expand All @@ -74,7 +74,7 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0,
Y : CUDA array interface compliant matrix shape (n, k)
out : Optional writable CUDA array interface matrix shape (m, n)
metric : string denoting the metric type (default="euclidean")
metric_arg : metric parameter (currently used only for "minkowski")
p : metric parameter (currently used only for "minkowski")
{resources_docstring}
Examples
Expand Down Expand Up @@ -139,6 +139,6 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0,
y_dlpack,
out_dlpack,
distance_type,
metric_arg))
p))

return out
4 changes: 2 additions & 2 deletions python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def build(dataset, metric="sqeuclidean", metric_arg=2.0, resources=None):
"""

dataset_ai = wrap_array(dataset)
_check_input_array(dataset_ai, [np.dtype('float32')])
_check_input_array(dataset_ai, [np.dtype('float32')], exp_row_major=False)

cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

Expand Down Expand Up @@ -218,7 +218,7 @@ def search(Index index,
cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

queries_cai = wrap_array(queries)
_check_input_array(queries_cai, [np.dtype('float32')])
_check_input_array(queries_cai, [np.dtype('float32')], exp_row_major=False)

cdef uint32_t n_queries = queries_cai.shape[0]

Expand Down
6 changes: 4 additions & 2 deletions python/cuvs/cuvs/neighbors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# limitations under the License.


def _check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None):
def _check_input_array(
cai, exp_dt, exp_rows=None, exp_cols=None, exp_row_major=True
):
if cai.dtype not in exp_dt:
raise TypeError("dtype %s not supported" % cai.dtype)

if not cai.c_contiguous:
if exp_row_major and not cai.c_contiguous:
raise ValueError("Row major input is expected")

if exp_cols is not None and cai.shape[1] != exp_cols:
Expand Down
9 changes: 6 additions & 3 deletions python/cuvs/cuvs/test/test_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,15 @@
],
)
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize("order", ["F", "C"])
@pytest.mark.parametrize("dtype", [np.float32])
def test_brute_force_knn(
n_index_rows, n_query_rows, n_cols, k, inplace, metric, dtype
n_index_rows, n_query_rows, n_cols, k, inplace, order, metric, dtype
):
index = np.random.random_sample((n_index_rows, n_cols)).astype(dtype)
queries = np.random.random_sample((n_query_rows, n_cols)).astype(dtype)
index = np.random.random_sample((n_index_rows, n_cols))
index = np.asarray(index, order=order).astype(dtype)
queries = np.random.random_sample((n_query_rows, n_cols))
queries = np.asarray(queries, order=order).astype(dtype)

# RussellRao expects boolean arrays
if metric == "russellrao":
Expand Down
6 changes: 2 additions & 4 deletions python/cuvs/cuvs/test/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"jensenshannon",
"russellrao",
"cosine",
"minkowski",
"sqeuclidean",
"inner_product",
],
Expand Down Expand Up @@ -70,10 +71,7 @@ def test_distance(n_rows, n_cols, inplace, order, metric, dtype):
output_device = device_ndarray(output) if inplace else None

ret_output = pairwise_distance(
input1_device,
input1_device,
output_device,
metric,
input1_device, input1_device, output_device, metric, p=2.0
)

output_device = ret_output if not inplace else output_device
Expand Down

0 comments on commit 263286b

Please sign in to comment.