From 909365c4e62bcf6d8a850951fa0a7cfe7f4ebba5 Mon Sep 17 00:00:00 2001 From: Ray Douglass Date: Fri, 22 Sep 2023 10:14:37 -0400 Subject: [PATCH 01/10] v23.12 Updates [skip ci] --- .github/workflows/build.yaml | 16 +++++++------- .github/workflows/pr.yaml | 22 +++++++++---------- .github/workflows/test.yaml | 8 +++---- ci/build_docs.sh | 2 +- ci/test_wheel_raft_dask.sh | 2 +- .../all_cuda-118_arch-x86_64.yaml | 6 ++--- .../all_cuda-120_arch-x86_64.yaml | 6 ++--- .../recipes/raft-dask/conda_build_config.yaml | 2 +- cpp/CMakeLists.txt | 4 ++-- cpp/doxygen/Doxyfile | 2 +- .../cmake/thirdparty/fetch_rapids.cmake | 2 +- dependencies.yaml | 8 +++---- docs/source/build.md | 4 ++-- docs/source/conf.py | 4 ++-- docs/source/developer_guide.md | 8 +++---- fetch_rapids.cmake | 2 +- python/pylibraft/CMakeLists.txt | 2 +- python/pylibraft/pylibraft/__init__.py | 2 +- python/pylibraft/pyproject.toml | 6 ++--- python/raft-ann-bench/pyproject.toml | 2 +- python/raft-dask/CMakeLists.txt | 2 +- python/raft-dask/pyproject.toml | 8 +++---- python/raft-dask/raft_dask/__init__.py | 2 +- 23 files changed, 61 insertions(+), 61 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 107823d5ee..c2d564dfda 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -37,7 +37,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -46,7 +46,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-upload-packages.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-upload-packages.yaml@branch-23.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -57,7 +57,7 @@ jobs: if: github.ref_type == 'branch' needs: python-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.12 with: arch: "amd64" branch: ${{ inputs.branch }} @@ -69,7 +69,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibraft: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -79,7 +79,7 @@ jobs: wheel-publish-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -89,7 +89,7 @@ jobs: wheel-build-raft-dask: needs: wheel-publish-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -99,7 +99,7 @@ jobs: wheel-publish-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 4fa3c5df86..8e917bd91e 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -23,41 +23,41 @@ jobs: - wheel-build-raft-dask - wheel-tests-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@branch-23.12 checks: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/checks.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/checks.yaml@branch-23.12 with: enable_check_generated_files: false conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.12 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.12 with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.12 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.12 with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.12 with: build_type: pull-request node_type: "gpu-v100-latest-1" @@ -67,28 +67,28 @@ jobs: wheel-build-pylibraft: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.12 with: build_type: pull-request script: ci/build_wheel_pylibraft.sh wheel-tests-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.12 with: build_type: pull-request script: ci/test_wheel_pylibraft.sh wheel-build-raft-dask: needs: wheel-tests-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.12 with: build_type: pull-request script: "ci/build_wheel_raft_dask.sh" wheel-tests-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.12 with: build_type: pull-request script: ci/test_wheel_raft_dask.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a80d5ff0cf..4e45ae29f6 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.12 with: build_type: nightly branch: ${{ inputs.branch }} @@ -24,7 +24,7 @@ jobs: sha: ${{ inputs.sha }} conda-python-tests: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.12 with: build_type: nightly branch: ${{ inputs.branch }} @@ -32,7 +32,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibraft: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.12 with: build_type: nightly branch: ${{ inputs.branch }} @@ -41,7 +41,7 @@ jobs: script: ci/test_wheel_pylibraft.sh wheel-tests-raft-dask: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.12 with: build_type: nightly branch: ${{ inputs.branch }} diff --git a/ci/build_docs.sh b/ci/build_docs.sh index d1a1e2f44f..2eb9f7da6d 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -28,7 +28,7 @@ rapids-mamba-retry install \ pylibraft \ raft-dask -export RAPIDS_VERSION_NUMBER="23.10" +export RAPIDS_VERSION_NUMBER="23.12" export RAPIDS_DOCS_DIR="$(mktemp -d)" rapids-logger "Build CPP docs" diff --git a/ci/test_wheel_raft_dask.sh b/ci/test_wheel_raft_dask.sh index 676d642de9..a20e950313 100755 --- a/ci/test_wheel_raft_dask.sh +++ b/ci/test_wheel_raft_dask.sh @@ -12,7 +12,7 @@ RAPIDS_PY_WHEEL_NAME="pylibraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels python -m pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl # Always install latest dask for testing -python -m pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.10 +python -m pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.12 # echo to expand wildcard before adding `[extra]` requires for pip python -m pip install $(echo ./dist/raft_dask*.whl)[test] diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 65b4232d83..dc27a0aa32 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -20,7 +20,7 @@ dependencies: - cxx-compiler - cython>=3.0.0 - dask-core>=2023.7.1 -- dask-cuda==23.10.* +- dask-cuda==23.12.* - dask>=2023.7.1 - distributed>=2023.7.1 - doxygen>=1.8.20 @@ -47,7 +47,7 @@ dependencies: - pytest - pytest-cov - recommonmark -- rmm==23.10.* +- rmm==23.12.* - scikit-build>=0.13.1 - scikit-learn - scipy @@ -55,6 +55,6 @@ dependencies: - sphinx-markdown-tables - sysroot_linux-64==2.17 - ucx-proc=*=gpu -- ucx-py==0.34.* +- ucx-py==0.35.* - ucx>=1.13.0 name: all_cuda-118_arch-x86_64 diff --git a/conda/environments/all_cuda-120_arch-x86_64.yaml b/conda/environments/all_cuda-120_arch-x86_64.yaml index 9db38ed1de..019679592f 100644 --- a/conda/environments/all_cuda-120_arch-x86_64.yaml +++ b/conda/environments/all_cuda-120_arch-x86_64.yaml @@ -20,7 +20,7 @@ dependencies: - cxx-compiler - cython>=3.0.0 - dask-core>=2023.7.1 -- dask-cuda==23.10.* +- dask-cuda==23.12.* - dask>=2023.7.1 - distributed>=2023.7.1 - doxygen>=1.8.20 @@ -43,7 +43,7 @@ dependencies: - pytest - pytest-cov - recommonmark -- rmm==23.10.* +- rmm==23.12.* - scikit-build>=0.13.1 - scikit-learn - scipy @@ -51,6 +51,6 @@ dependencies: - sphinx-markdown-tables - sysroot_linux-64==2.17 - ucx-proc=*=gpu -- ucx-py==0.34.* +- ucx-py==0.35.* - ucx>=1.13.0 name: all_cuda-120_arch-x86_64 diff --git a/conda/recipes/raft-dask/conda_build_config.yaml b/conda/recipes/raft-dask/conda_build_config.yaml index d9a11329d9..e36adf2551 100644 --- a/conda/recipes/raft-dask/conda_build_config.yaml +++ b/conda/recipes/raft-dask/conda_build_config.yaml @@ -17,7 +17,7 @@ ucx_version: - ">=1.13.0,<1.15.0" ucx_py_version: - - "0.34.*" + - "0.35.*" cmake_version: - ">=3.26.4" diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d93b19f784..ada7753857 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -10,8 +10,8 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. -set(RAPIDS_VERSION "23.10") -set(RAFT_VERSION "23.10.00") +set(RAPIDS_VERSION "23.12") +set(RAFT_VERSION "23.12.00") cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) include(../fetch_rapids.cmake) diff --git a/cpp/doxygen/Doxyfile b/cpp/doxygen/Doxyfile index cb990d668e..f1dcfa9db7 100644 --- a/cpp/doxygen/Doxyfile +++ b/cpp/doxygen/Doxyfile @@ -38,7 +38,7 @@ PROJECT_NAME = "RAFT C++ API" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = "23.10" +PROJECT_NUMBER = "23.12" # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/cpp/template/cmake/thirdparty/fetch_rapids.cmake b/cpp/template/cmake/thirdparty/fetch_rapids.cmake index 42904dcb8a..22afb1a3f5 100644 --- a/cpp/template/cmake/thirdparty/fetch_rapids.cmake +++ b/cpp/template/cmake/thirdparty/fetch_rapids.cmake @@ -12,7 +12,7 @@ # the License. # Use this variable to update RAPIDS and RAFT versions -set(RAPIDS_VERSION "23.10") +set(RAPIDS_VERSION "23.12") if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake diff --git a/dependencies.yaml b/dependencies.yaml index 700a6db1bf..f1b74cfe49 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -149,7 +149,7 @@ dependencies: common: - output_types: [conda, requirements, pyproject] packages: - - &rmm rmm==23.10.* + - &rmm rmm==23.12.* specific: - output_types: [conda, requirements, pyproject] matrices: @@ -327,12 +327,12 @@ dependencies: - output_types: [conda, pyproject] packages: - dask>=2023.7.1 - - dask-cuda==23.10.* + - dask-cuda==23.12.* - distributed>=2023.7.1 - joblib>=0.11 - numba>=0.57 - *numpy - - ucx-py==0.34.* + - ucx-py==0.35.* - output_types: conda packages: - dask-core>=2023.7.1 @@ -340,7 +340,7 @@ dependencies: - ucx-proc=*=gpu - output_types: pyproject packages: - - pylibraft==23.10.* + - pylibraft==23.12.* test_python_common: common: - output_types: [conda, requirements, pyproject] diff --git a/docs/source/build.md b/docs/source/build.md index b0c2453972..4a8748deb6 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -264,7 +264,7 @@ While not a highly suggested method for building against RAFT, when all of the n set(RAFT_GIT_DIR ${CMAKE_CURRENT_BINARY_DIR}/raft CACHE STRING "Path to RAFT repo") ExternalProject_Add(raft GIT_REPOSITORY git@github.com:rapidsai/raft.git - GIT_TAG branch-23.10 + GIT_TAG branch-23.12 PREFIX ${RAFT_GIT_DIR} CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -294,7 +294,7 @@ The following `cmake` snippet enables a flexible configuration of RAFT: ```cmake -set(RAFT_VERSION "23.10") +set(RAFT_VERSION "23.12") set(RAFT_FORK "rapidsai") set(RAFT_PINNED_TAG "branch-${RAFT_VERSION}") diff --git a/docs/source/conf.py b/docs/source/conf.py index 7bb5a8298d..822bce12fc 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -67,9 +67,9 @@ # built documents. # # The short X.Y version. -version = '23.10' +version = '23.12' # The full version, including alpha/beta/rc tags. -release = '23.10.00' +release = '23.12.00' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/source/developer_guide.md b/docs/source/developer_guide.md index 59107e2212..c45cec1666 100644 --- a/docs/source/developer_guide.md +++ b/docs/source/developer_guide.md @@ -187,7 +187,7 @@ RAFT relies on `clang-format` to enforce code style across all C++ and CUDA sour 1. Do not split empty functions/records/namespaces. 2. Two-space indentation everywhere, including the line continuations. 3. Disable reflowing of comments. - The reasons behind these deviations from the Google style guide are given in comments [here](https://github.com/rapidsai/raft/blob/branch-23.10/cpp/.clang-format). + The reasons behind these deviations from the Google style guide are given in comments [here](https://github.com/rapidsai/raft/blob/branch-23.12/cpp/.clang-format). [`doxygen`](https://doxygen.nl/) is used as documentation generator and also as a documentation linter. In order to run doxygen as a linter on C++/CUDA code, run @@ -205,7 +205,7 @@ you can run `codespell -i 3 -w .` from the repository root directory. This will bring up an interactive prompt to select which spelling fixes to apply. ### #include style -[include_checker.py](https://github.com/rapidsai/raft/blob/branch-23.10/cpp/scripts/include_checker.py) is used to enforce the include style as follows: +[include_checker.py](https://github.com/rapidsai/raft/blob/branch-23.12/cpp/scripts/include_checker.py) is used to enforce the include style as follows: 1. `#include "..."` should be used for referencing local files only. It is acceptable to be used for referencing files in a sub-folder/parent-folder of the same algorithm, but should never be used to include files in other algorithms or between algorithms and the primitives or other dependencies. 2. `#include <...>` should be used for referencing everything else @@ -215,7 +215,7 @@ python ./cpp/scripts/include_checker.py --inplace [cpp/include cpp/test ... list ``` ### Copyright header -[copyright.py](https://github.com/rapidsai/raft/blob/branch-23.10/ci/checks/copyright.py) checks the Copyright header for all git-modified files +[copyright.py](https://github.com/rapidsai/raft/blob/branch-23.12/ci/checks/copyright.py) checks the Copyright header for all git-modified files Manually, you can run the following to bulk-fix the header if only the years need to be updated: ```bash @@ -229,7 +229,7 @@ Call CUDA APIs via the provided helper macros `RAFT_CUDA_TRY`, `RAFT_CUBLAS_TRY` ## Logging ### Introduction -Anything and everything about logging is defined inside [logger.hpp](https://github.com/rapidsai/raft/blob/branch-23.10/cpp/include/raft/core/logger.hpp). It uses [spdlog](https://github.com/gabime/spdlog) underneath, but this information is transparent to all. +Anything and everything about logging is defined inside [logger.hpp](https://github.com/rapidsai/raft/blob/branch-23.12/cpp/include/raft/core/logger.hpp). It uses [spdlog](https://github.com/gabime/spdlog) underneath, but this information is transparent to all. ### Usage ```cpp diff --git a/fetch_rapids.cmake b/fetch_rapids.cmake index eebf72e22b..c74c09fb26 100644 --- a/fetch_rapids.cmake +++ b/fetch_rapids.cmake @@ -12,7 +12,7 @@ # the License. # ============================================================================= if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) - file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-23.10/RAPIDS.cmake + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-23.12/RAPIDS.cmake ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake ) endif() diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt index 430eaed183..96489ea045 100644 --- a/python/pylibraft/CMakeLists.txt +++ b/python/pylibraft/CMakeLists.txt @@ -16,7 +16,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) include(../../fetch_rapids.cmake) -set(pylibraft_version 23.10.00) +set(pylibraft_version 23.12.00) # We always need CUDA for pylibraft because the raft dependency brings in a header-only cuco # dependency that enables CUDA unconditionally. diff --git a/python/pylibraft/pylibraft/__init__.py b/python/pylibraft/pylibraft/__init__.py index 3251c6eb68..b5ca0800a5 100644 --- a/python/pylibraft/pylibraft/__init__.py +++ b/python/pylibraft/pylibraft/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. # -__version__ = "23.10.00" +__version__ = "23.12.00" diff --git a/python/pylibraft/pyproject.toml b/python/pylibraft/pyproject.toml index bbf4c1d3cb..7f3a926d43 100644 --- a/python/pylibraft/pyproject.toml +++ b/python/pylibraft/pyproject.toml @@ -19,7 +19,7 @@ requires = [ "cuda-python>=11.7.1,<12.0a0", "cython>=3.0.0", "ninja", - "rmm==23.10.*", + "rmm==23.12.*", "scikit-build>=0.13.1", "setuptools", "wheel", @@ -28,7 +28,7 @@ build-backend = "setuptools.build_meta" [project] name = "pylibraft" -version = "23.10.00" +version = "23.12.00" description = "RAFT: Reusable Algorithms Functions and other Tools" readme = { file = "README.md", content-type = "text/markdown" } authors = [ @@ -39,7 +39,7 @@ requires-python = ">=3.9" dependencies = [ "cuda-python>=11.7.1,<12.0a0", "numpy>=1.21", - "rmm==23.10.*", + "rmm==23.12.*", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", diff --git a/python/raft-ann-bench/pyproject.toml b/python/raft-ann-bench/pyproject.toml index 7decc8858b..6562937548 100644 --- a/python/raft-ann-bench/pyproject.toml +++ b/python/raft-ann-bench/pyproject.toml @@ -9,7 +9,7 @@ requires = [ [project] name = "raft-ann-bench" -version = "23.10.00" +version = "23.12.00" description = "RAFT ANN benchmarks" authors = [ { name = "NVIDIA Corporation" }, diff --git a/python/raft-dask/CMakeLists.txt b/python/raft-dask/CMakeLists.txt index f9a20b46bb..53bb12c81c 100644 --- a/python/raft-dask/CMakeLists.txt +++ b/python/raft-dask/CMakeLists.txt @@ -14,7 +14,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -set(raft_dask_version 23.10.00) +set(raft_dask_version 23.12.00) include(../../fetch_rapids.cmake) diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index bdbcf61e0f..5c616806a2 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -25,7 +25,7 @@ requires = [ [project] name = "raft-dask" -version = "23.10.00" +version = "23.12.00" description = "Reusable Accelerated Functions & Tools Dask Infrastructure" readme = { file = "README.md", content-type = "text/markdown" } authors = [ @@ -34,14 +34,14 @@ authors = [ license = { text = "Apache 2.0" } requires-python = ">=3.9" dependencies = [ - "dask-cuda==23.10.*", + "dask-cuda==23.12.*", "dask>=2023.7.1", "distributed>=2023.7.1", "joblib>=0.11", "numba>=0.57", "numpy>=1.21", - "pylibraft==23.10.*", - "ucx-py==0.34.*", + "pylibraft==23.12.*", + "ucx-py==0.35.*", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", diff --git a/python/raft-dask/raft_dask/__init__.py b/python/raft-dask/raft_dask/__init__.py index 48cfd69bac..845c587caa 100644 --- a/python/raft-dask/raft_dask/__init__.py +++ b/python/raft-dask/raft_dask/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. # -__version__ = "23.10.00" +__version__ = "23.12.00" From 8522a143ba9fd25d1054b93edc308d219043751b Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Mon, 25 Sep 2023 14:25:23 -0500 Subject: [PATCH 02/10] Fix conf file for benchmarking glove datasets (#1846) Authors: - Dante Gama Dessavre (https://github.com/dantegd) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1846 --- .../run/conf/glove-100-angular.json | 126 +++++++++--------- .../run/conf/glove-100-inner.json | 126 +++++++++--------- .../run/conf/glove-50-angular.json | 126 +++++++++--------- .../run/conf/glove-50-inner.json | 126 +++++++++--------- 4 files changed, 252 insertions(+), 252 deletions(-) diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-100-angular.json b/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-100-angular.json index 526aef2db0..3595084d19 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-100-angular.json +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-100-angular.json @@ -735,37 +735,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "half", "smemLutDtype": "half" } @@ -785,55 +785,55 @@ "search_params": [ { "k": 10, - "numProbes": 1, + "nprobe": 1, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1, + "nprobe": 1, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 5, + "nprobe": 5, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "float" } @@ -853,37 +853,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "half" } @@ -903,37 +903,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -953,37 +953,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1003,37 +1003,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "half" } @@ -1053,37 +1053,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1103,37 +1103,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1153,37 +1153,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "half", "smemLutDtype": "float" } @@ -1203,37 +1203,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "float" } diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-100-inner.json b/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-100-inner.json index 54c8bf908c..8b9f1cfb35 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-100-inner.json +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-100-inner.json @@ -735,37 +735,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "half", "smemLutDtype": "half" } @@ -785,55 +785,55 @@ "search_params": [ { "k": 10, - "numProbes": 1, + "nprobe": 1, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1, + "nprobe": 1, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 5, + "nprobe": 5, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "float" } @@ -853,37 +853,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "half" } @@ -903,37 +903,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -953,37 +953,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1003,37 +1003,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "half" } @@ -1053,37 +1053,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1103,37 +1103,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1153,37 +1153,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "half", "smemLutDtype": "float" } @@ -1203,37 +1203,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "float" } diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-50-angular.json b/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-50-angular.json index 9b3f192c9f..0f02620cb2 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-50-angular.json +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-50-angular.json @@ -735,37 +735,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "half", "smemLutDtype": "half" } @@ -785,55 +785,55 @@ "search_params": [ { "k": 10, - "numProbes": 1, + "nprobe": 1, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1, + "nprobe": 1, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 5, + "nprobe": 5, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "float" } @@ -853,37 +853,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "half" } @@ -903,37 +903,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -953,37 +953,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1003,37 +1003,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "half" } @@ -1053,37 +1053,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1103,37 +1103,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1153,37 +1153,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "half", "smemLutDtype": "float" } @@ -1203,37 +1203,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "float" } diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-50-inner.json b/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-50-inner.json index 0fe7ac24df..41dec5adb3 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-50-inner.json +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/glove-50-inner.json @@ -735,37 +735,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "half", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "half", "smemLutDtype": "half" } @@ -785,55 +785,55 @@ "search_params": [ { "k": 10, - "numProbes": 1, + "nprobe": 1, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1, + "nprobe": 1, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 5, + "nprobe": 5, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "float" } @@ -853,37 +853,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "half" } @@ -903,37 +903,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -953,37 +953,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1003,37 +1003,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "half" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "half" } @@ -1053,37 +1053,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1103,37 +1103,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "fp8" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "fp8" } @@ -1153,37 +1153,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "half", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "half", "smemLutDtype": "float" } @@ -1203,37 +1203,37 @@ "search_params": [ { "k": 10, - "numProbes": 10, + "nprobe": 10, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 50, + "nprobe": 50, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 100, + "nprobe": 100, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 200, + "nprobe": 200, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 500, + "nprobe": 500, "internalDistanceDtype": "float", "smemLutDtype": "float" }, { "k": 10, - "numProbes": 1024, + "nprobe": 1024, "internalDistanceDtype": "float", "smemLutDtype": "float" } From cb24d998771a72ea6bad12a65cfb4aaf6ab0122e Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Tue, 26 Sep 2023 04:09:43 +0800 Subject: [PATCH 03/10] [FEA] Add pre-filtering to CAGRA (#1811) This PR adds the pre-filtering feature to the CAGRA search implementations. Rel: taken over from https://github.com/rapidsai/raft/pull/1765 ## Algorithm The pre-filtering algorithm removes a node that should not be in the final result after it has behaved as a parent node. This way, the nodes that should not be in the final result are also used in the graph traversal, avoiding potential performance degradation. ## Changes - Add filtering operation on a parent node after internal top-M buffer candidate calculation. - Add filtering operation to result buffer before storing them in the device memory. Authors: - tsuki (https://github.com/enp1s0) - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1811 --- cpp/include/raft/neighbors/cagra.cuh | 75 +++++- .../neighbors/detail/cagra/cagra_search.cuh | 60 ++++- .../detail/cagra/compute_distance.hpp | 13 +- .../raft/neighbors/detail/cagra/factory.cuh | 42 ++-- .../detail/cagra/search_multi_cta.cuh | 80 +++---- .../cagra/search_multi_cta_kernel-ext.cuh | 94 ++++---- .../cagra/search_multi_cta_kernel-inl.cuh | 89 ++++++-- .../detail/cagra/search_multi_kernel.cuh | 215 +++++++++++++----- .../neighbors/detail/cagra/search_plan.cuh | 9 +- .../detail/cagra/search_single_cta.cuh | 70 +++--- .../cagra/search_single_cta_kernel-ext.cuh | 96 ++++---- .../cagra/search_single_cta_kernel-inl.cuh | 214 ++++++++++++++--- cpp/include/raft/neighbors/ivf_flat-inl.cuh | 12 +- cpp/include/raft/neighbors/ivf_pq-inl.cuh | 12 +- .../raft/neighbors/sample_filter_types.hpp | 12 + .../cagra/search_multi_cta_00_generate.py | 58 ++--- ...arch_multi_cta_float_uint32_dim1024_t32.cu | 55 +++-- ...search_multi_cta_float_uint32_dim128_t8.cu | 55 +++-- ...earch_multi_cta_float_uint32_dim256_t16.cu | 55 +++-- ...earch_multi_cta_float_uint32_dim512_t32.cu | 55 +++-- ...arch_multi_cta_float_uint64_dim1024_t32.cu | 55 +++-- ...search_multi_cta_float_uint64_dim128_t8.cu | 55 +++-- ...earch_multi_cta_float_uint64_dim256_t16.cu | 55 +++-- ...earch_multi_cta_float_uint64_dim512_t32.cu | 55 +++-- ...earch_multi_cta_int8_uint32_dim1024_t32.cu | 55 +++-- .../search_multi_cta_int8_uint32_dim128_t8.cu | 55 +++-- ...search_multi_cta_int8_uint32_dim256_t16.cu | 55 +++-- ...search_multi_cta_int8_uint32_dim512_t32.cu | 55 +++-- ...arch_multi_cta_uint8_uint32_dim1024_t32.cu | 55 +++-- ...search_multi_cta_uint8_uint32_dim128_t8.cu | 55 +++-- ...earch_multi_cta_uint8_uint32_dim256_t16.cu | 55 +++-- ...earch_multi_cta_uint8_uint32_dim512_t32.cu | 55 +++-- .../cagra/search_single_cta_00_generate.py | 59 ++--- ...rch_single_cta_float_uint32_dim1024_t32.cu | 58 ++--- ...earch_single_cta_float_uint32_dim128_t8.cu | 58 ++--- ...arch_single_cta_float_uint32_dim256_t16.cu | 58 ++--- ...arch_single_cta_float_uint32_dim512_t32.cu | 58 ++--- ...rch_single_cta_float_uint64_dim1024_t32.cu | 58 ++--- ...earch_single_cta_float_uint64_dim128_t8.cu | 58 ++--- ...arch_single_cta_float_uint64_dim256_t16.cu | 58 ++--- ...arch_single_cta_float_uint64_dim512_t32.cu | 58 ++--- ...arch_single_cta_int8_uint32_dim1024_t32.cu | 58 ++--- ...search_single_cta_int8_uint32_dim128_t8.cu | 58 ++--- ...earch_single_cta_int8_uint32_dim256_t16.cu | 58 ++--- ...earch_single_cta_int8_uint32_dim512_t32.cu | 58 ++--- ...rch_single_cta_uint8_uint32_dim1024_t32.cu | 58 ++--- ...earch_single_cta_uint8_uint32_dim128_t8.cu | 58 ++--- ...arch_single_cta_uint8_uint32_dim256_t16.cu | 58 ++--- ...arch_single_cta_uint8_uint32_dim512_t32.cu | 58 ++--- cpp/test/neighbors/ann_cagra.cuh | 177 +++++++++++++- .../ann_cagra/search_kernel_uint64_t.cuh | 200 ++++++++-------- .../neighbors/ann_cagra/test_float_int64_t.cu | 4 +- .../ann_cagra/test_float_uint32_t.cu | 8 +- .../ann_cagra/test_int8_t_uint32_t.cu | 7 +- .../ann_cagra/test_uint8_t_uint32_t.cu | 8 +- 55 files changed, 2142 insertions(+), 1280 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 903d0571dc..1bd7010c83 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -54,14 +54,14 @@ namespace raft::neighbors::cagra { * // use default index parameters * cagra::index_params build_params; * cagra::search_params search_params - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // create knn graph * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); - * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); + * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); * // Construct an index from dataset and optimized knn_graph * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); + * optimized_graph.view()); * @endcode * * @tparam DataT data element type @@ -106,7 +106,7 @@ void build_knn_graph(raft::resources const& res, * @code{.cpp} * using namespace raft::neighbors; * cagra::index_params build_params; - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // build KNN graph not using `cagra::build_knn_graph` * // build(knn_graph, dataset, ...); * // sort graph index @@ -115,7 +115,7 @@ void build_knn_graph(raft::resources const& res, * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); * // Construct an index from dataset and optimized knn_graph * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); + * optimized_graph.view()); * @endcode * * @tparam DataT type of the data in the source dataset @@ -316,9 +316,70 @@ void search(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - cagra::detail::search_main( - res, params, idx, queries_internal, neighbors_internal, distances_internal); + cagra::detail::search_main(res, + params, + idx, + queries_internal, + neighbors_internal, + distances_internal, + raft::neighbors::filtering::none_cagra_sample_filter()); } + +/** + * @brief Search ANN using the constructed index with the given sample filter. + * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * @tparam CagraSampleFilterT Device filter function, with the signature + * `(uint32_t query ix, uint32_t sample_ix) -> bool` + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + * @param[in] sample_filter a device filter function that greenlights samples for a given query + */ +template +void search_with_filtering(raft::resources const& res, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + CagraSampleFilterT sample_filter = CagraSampleFilterT()) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + using internal_IdxT = typename std::make_unsigned::type; + auto queries_internal = raft::make_device_matrix_view( + queries.data_handle(), queries.extent(0), queries.extent(1)); + auto neighbors_internal = raft::make_device_matrix_view( + reinterpret_cast(neighbors.data_handle()), + neighbors.extent(0), + neighbors.extent(1)); + auto distances_internal = raft::make_device_matrix_view( + distances.data_handle(), distances.extent(0), distances.extent(1)); + + cagra::detail::search_main( + res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); +} + /** @} */ // end group cagra } // namespace raft::neighbors::cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index b484fa55f9..81e714dc4e 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -34,6 +35,48 @@ namespace raft::neighbors::cagra::detail { +template +struct CagraSampleFilterWithQueryIdOffset { + const uint32_t offset; + CagraSampleFilterT filter; + + CagraSampleFilterWithQueryIdOffset(const uint32_t offset, const CagraSampleFilterT filter) + : offset(offset), filter(filter) + { + } + + _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id) + { + return filter(query_id + offset, sample_id); + } +}; + +template +struct CagraSampleFilterT_Selector { + using type = CagraSampleFilterWithQueryIdOffset; +}; +template <> +struct CagraSampleFilterT_Selector { + using type = raft::neighbors::filtering::none_cagra_sample_filter; +}; + +// A helper function to set a query id offset +template +inline typename CagraSampleFilterT_Selector::type set_offset( + CagraSampleFilterT filter, const uint32_t offset) +{ + typename CagraSampleFilterT_Selector::type new_filter(offset, filter); + return new_filter; +} +template <> +inline + typename CagraSampleFilterT_Selector::type + set_offset( + raft::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t) +{ + return filter; +} + /** * @brief Search ANN using the constructed index. * @@ -54,13 +97,18 @@ namespace raft::neighbors::cagra::detail { * k] */ -template +template void search_main(raft::resources const& res, search_params params, const index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view distances, + CagraSampleFilterT sample_filter = CagraSampleFilterT()) { resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search"); RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", @@ -77,8 +125,9 @@ void search_main(raft::resources const& res, common::nvtx::range fun_scope( "cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim()); - std::unique_ptr> plan = - factory::create( + using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; + std::unique_ptr> plan = + factory::create( res, params, index.dim(), index.graph_degree(), topk); plan->check(neighbors.extent(1)); @@ -119,7 +168,8 @@ void search_main(raft::resources const& res, n_queries, _seed_ptr, _num_executed_iterations, - topk); + topk, + set_offset(sample_filter, qid)); } static_assert(std::is_same_v, diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index 47e976e252..624c1a35d6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -155,17 +155,20 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in INDEX_T* const visited_hashmap_ptr, const std::uint32_t hash_bitlen, const INDEX_T* const parent_indices, + const INDEX_T* const internal_topk_list, const std::uint32_t search_width) { - const INDEX_T invalid_index = utils::get_max_value(); + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); // Read child indices of parents from knn graph and check if the distance // computaiton is necessary. for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += BLOCK_SIZE) { - const INDEX_T parent_id = parent_indices[i / knn_k]; - INDEX_T child_id = invalid_index; - if (parent_id != invalid_index) { - child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)]; + const INDEX_T smem_parent_id = parent_indices[i / knn_k]; + INDEX_T child_id = invalid_index; + if (smem_parent_id != invalid_index) { + const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask; + child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)]; } if (child_id != invalid_index) { if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { diff --git a/cpp/include/raft/neighbors/detail/cagra/factory.cuh b/cpp/include/raft/neighbors/detail/cagra/factory.cuh index 625040194b..78111a9310 100644 --- a/cpp/include/raft/neighbors/detail/cagra/factory.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/factory.cuh @@ -20,20 +20,25 @@ #include "search_multi_kernel.cuh" #include "search_plan.cuh" #include "search_single_cta.cuh" +#include namespace raft::neighbors::cagra::detail { -template +template class factory { public: /** * Create a search structure for dataset with dim features. */ - static std::unique_ptr> create(raft::resources const& res, - search_params const& params, - int64_t dim, - int64_t graph_degree, - uint32_t topk) + static std::unique_ptr> create( + raft::resources const& res, + search_params const& params, + int64_t dim, + int64_t graph_degree, + uint32_t topk) { search_plan_impl_base plan(params, dim, graph_degree, topk); switch (plan.max_dim) { @@ -63,26 +68,29 @@ class factory { break; default: RAFT_LOG_DEBUG("Incorrect max_dim (%lu)\n", plan.max_dim); } - return std::unique_ptr>(); + return std::unique_ptr>(); } private: template - static std::unique_ptr> dispatch_kernel( + static std::unique_ptr> dispatch_kernel( raft::resources const& res, search_plan_impl_base& plan) { if (plan.algo == search_algo::SINGLE_CTA) { - return std::unique_ptr>( - new single_cta_search::search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); + return std::unique_ptr>( + new single_cta_search:: + search( + res, plan, plan.dim, plan.graph_degree, plan.topk)); } else if (plan.algo == search_algo::MULTI_CTA) { - return std::unique_ptr>( - new multi_cta_search::search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); + return std::unique_ptr>( + new multi_cta_search:: + search( + res, plan, plan.dim, plan.graph_degree, plan.topk)); } else { - return std::unique_ptr>( - new multi_kernel_search::search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); + return std::unique_ptr>( + new multi_kernel_search:: + search( + res, plan, plan.dim, plan.graph_degree, plan.topk)); } } }; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 6ea1e34032..9a722a6dfe 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -48,42 +48,43 @@ template - -struct search : public search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::max_dim; - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; + typename DISTANCE_T, + typename SAMPLE_FILTER_T> + +struct search : public search_plan_impl { + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::search_width; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; + + using search_plan_impl::max_dim; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; + + using search_plan_impl::hash_bitlen; + + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; + + using search_plan_impl::smem_size; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; uint32_t num_cta_per_query; rmm::device_uvector intermediate_indices; @@ -96,7 +97,8 @@ struct search : public search_plan_impl { int64_t dim, int64_t graph_degree, uint32_t topk) - : search_plan_impl(res, params, dim, graph_degree, topk), + : search_plan_impl( + res, params, dim, graph_degree, topk), intermediate_indices(0, resource::get_cuda_stream(res)), intermediate_distances(0, resource::get_cuda_stream(res)), topk_workspace(0, resource::get_cuda_stream(res)) @@ -196,7 +198,8 @@ struct search : public search_plan_impl { const uint32_t num_queries, const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk) + uint32_t topk, + SAMPLE_FILTER_T sample_filter) { cudaStream_t stream = resource::get_cuda_stream(res); @@ -223,6 +226,7 @@ struct search : public search_plan_impl { search_width, min_iterations, max_iterations, + sample_filter, stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh index de83acbb64..ee525587d7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include // RAFT_EXPLICIT +#include // none_cagra_sample_filter +#include // RAFT_EXPLICIT namespace raft::neighbors::cagra::detail { namespace multi_cta_search { @@ -26,7 +27,8 @@ template + class DISTANCE_T, + class SAMPLE_FILTER_T> void select_and_run(raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const topk_indices_ptr, @@ -49,47 +51,63 @@ void select_and_run(raft::device_matrix_view( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, float, uint32_t, float); -instantiate_kernel_selection(8, 128, float, uint32_t, float); -instantiate_kernel_selection(16, 256, float, uint32_t, float); -instantiate_kernel_selection(32, 512, float, uint32_t, float); -instantiate_kernel_selection(32, 1024, int8_t, uint32_t, float); -instantiate_kernel_selection(8, 128, int8_t, uint32_t, float); -instantiate_kernel_selection(16, 256, int8_t, uint32_t, float); -instantiate_kernel_selection(32, 512, int8_t, uint32_t, float); -instantiate_kernel_selection(32, 1024, uint8_t, uint32_t, float); -instantiate_kernel_selection(8, 128, uint8_t, uint32_t, float); -instantiate_kernel_selection(16, 256, uint8_t, uint32_t, float); -instantiate_kernel_selection(32, 512, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection } // namespace multi_cta_search diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 4fc051ac09..8bfbc48898 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include @@ -75,7 +76,7 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [sea if (new_parent) { const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; if (i < search_width) { - next_parent_indices[i] = index; + next_parent_indices[i] = j; itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node } } @@ -131,7 +132,8 @@ template + class LOAD_T, + class SAMPLE_FILTER_T> __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] @@ -152,8 +154,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( const uint32_t search_width, const uint32_t min_iteration, const uint32_t max_iteration, - uint32_t* const num_executed_iterations /* stats */ -) + uint32_t* const num_executed_iterations, /* stats */ + SAMPLE_FILTER_T sample_filter) { assert(blockDim.x == BLOCK_SIZE); assert(dataset_dim <= MAX_DATASET_DIM); @@ -287,13 +289,57 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( local_visited_hashmap_ptr, hash_bitlen, parent_indices_buffer, + result_indices_buffer, search_width); _CLK_REC(clk_compute_distance); __syncthreads(); + // Filtering + if constexpr (!std::is_same::value) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); + + for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { + if (parent_indices_buffer[p] != invalid_index) { + const auto parent_id = + result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; + if (!sample_filter(query_id, parent_id)) { + // If the parent must not be in the resulting top-k list, remove from the parent list + result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_indices_buffer[p]] = invalid_index; + } + } + } + __syncthreads(); + } + iter++; } + // Post process for filtering + if constexpr (!std::is_same::value) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); + + for (unsigned i = threadIdx.x; i < itopk_size + search_width * graph_degree; i += blockDim.x) { + const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; + if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { + // If the parent must not be in the resulting top-k list, remove from the parent list + result_distances_buffer[i] = utils::get_max_value(); + result_indices_buffer[i] = invalid_index; + } + } + + __syncthreads(); + topk_by_bitonic_sort(result_distances_buffer, + result_indices_buffer, + itopk_size + (search_width * graph_degree), + itopk_size); + __syncthreads(); + } + for (uint32_t i = threadIdx.x; i < itopk_size; i += BLOCK_SIZE) { uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } @@ -361,7 +407,8 @@ template + typename DISTANCE_T, + typename SAMPLE_FILTER_T> struct search_kernel_config { // Search kernel function type. Note that the actual values for the template value // parameters do not matter, because they are not part of the function signature. The @@ -374,7 +421,8 @@ struct search_kernel_config { DATA_T, DISTANCE_T, INDEX_T, - device::LOAD_128BIT_T>); + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>); static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t { @@ -401,7 +449,8 @@ struct search_kernel_config { DATA_T, DISTANCE_T, INDEX_T, - device::LOAD_128BIT_T>; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } else if (block_size == 128) { return search_kernel; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } else if (block_size == 256) { return search_kernel; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } else if (block_size == 512) { return search_kernel; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } else { return search_kernel; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } } }; @@ -450,7 +503,8 @@ template + typename DISTANCE_T, + typename SAMPLE_FILTER_T> void select_and_run( // raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, @@ -475,10 +529,12 @@ void select_and_run( // raft::resources const& res, size_t search_width, size_t min_iterations, size_t max_iterations, + SAMPLE_FILTER_T sample_filter, cudaStream_t stream) { - auto kernel = search_kernel_config:: - choose_buffer_size(result_buffer_size, block_size); + auto kernel = + search_kernel_config:: + choose_buffer_size(result_buffer_size, block_size); RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -489,7 +545,7 @@ void select_and_run( // raft::resources const& res, dim3 block_dims(block_size, 1, 1); dim3 grid_dims(num_cta_per_query, num_queries, 1); - RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %lu smem", + RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %u smem", block_size, num_cta_per_query, num_queries, @@ -513,7 +569,8 @@ void select_and_run( // raft::resources const& res, search_width, min_iterations, max_iterations, - num_executed_iterations); + num_executed_iterations, + sample_filter); } } // namespace multi_cta_search diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index f312226f42..ff1bb969e7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -242,7 +243,7 @@ __global__ void pickup_next_parents_kernel( if (new_parent) { const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; if (i < parent_list_size) { - parent_list_ptr[i + (ldd * query_id)] = index; + parent_list_ptr[i + (ldd * query_id)] = j; parent_candidates_ptr[j + (lds * query_id)] |= index_msb_1_mask; // set most significant bit as used node } @@ -306,9 +307,13 @@ template + class DISTANCE_T, + class SAMPLE_FILTER_T> __global__ void compute_distance_to_child_nodes_kernel( const INDEX_T* const parent_node_list, // [num_queries, search_width] + INDEX_T* const parent_candidates_ptr, // [num_queries, search_width] + DISTANCE_T* const parent_distance_ptr, // [num_queries, search_width] + const std::size_t lds, const std::uint32_t search_width, const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, @@ -321,16 +326,25 @@ __global__ void compute_distance_to_child_nodes_kernel( const std::uint32_t hash_bitlen, INDEX_T* const result_indices_ptr, // [num_queries, ldd] DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd // (*) ldd >= search_width * graph_degree -) + const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + SAMPLE_FILTER_T sample_filter) { const uint32_t ldb = hashmap::get_size(hash_bitlen); const auto tid = threadIdx.x + blockDim.x * blockIdx.x; const auto global_team_id = tid / TEAM_SIZE; + const auto query_id = blockIdx.y; + if (global_team_id >= search_width * graph_degree) { return; } - const std::size_t parent_index = + const std::size_t parent_list_index = parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)]; + + if (parent_list_index == utils::get_max_value()) { return; } + + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const auto parent_index = + parent_candidates_ptr[parent_list_index + (lds * query_id)] & ~index_msb_1_mask; + if (parent_index == utils::get_max_value()) { result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); return; @@ -361,15 +375,28 @@ __global__ void compute_distance_to_child_nodes_kernel( result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); } } + + if constexpr (!std::is_same::value) { + if (!sample_filter(query_id, parent_index)) { + parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value(); + parent_distance_ptr[parent_list_index + (lds * query_id)] = + utils::get_max_value(); + } + } } template + class DISTANCE_T, + class SAMPLE_FILTER_T> void compute_distance_to_child_nodes( const INDEX_T* const parent_node_list, // [num_queries, search_width] + INDEX_T* const parent_candidates_ptr, // [num_queries, search_width] + DISTANCE_T* const parent_distance_ptr, // [num_queries, search_width] + const std::size_t lds, const uint32_t search_width, const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, @@ -384,6 +411,7 @@ void compute_distance_to_child_nodes( INDEX_T* const result_indices_ptr, // [num_queries, ldd] DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + SAMPLE_FILTER_T sample_filter, cudaStream_t cuda_stream = 0) { const auto block_size = 128; @@ -392,6 +420,9 @@ void compute_distance_to_child_nodes( num_queries); compute_distance_to_child_nodes_kernel <<>>(parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, search_width, dataset_ptr, data_dim, @@ -404,7 +435,8 @@ void compute_distance_to_child_nodes( hash_bitlen, result_indices_ptr, result_distances_ptr, - ldd); + ldd, + sample_filter); } template @@ -436,6 +468,50 @@ void remove_parent_bit(const std::uint32_t num_queries, num_queries, num_topk, topk_indices_ptr, ld); } +// This function called after the `remove_parent_bit` function +template +__global__ void apply_filter_kernel(INDEX_T* const result_indices_ptr, + DISTANCE_T* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const INDEX_T query_id_offset, + SAMPLE_FILTER_T sample_filter) +{ + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= result_buffer_size * num_queries) { return; } + const auto i = tid % result_buffer_size; + const auto j = tid / result_buffer_size; + const auto index = i + j * lds; + + if (!sample_filter(query_id_offset + j, result_indices_ptr[index])) { + result_indices_ptr[index] = utils::get_max_value(); + result_distances_ptr[index] = utils::get_max_value(); + } +} + +template +void apply_filter(INDEX_T* const result_indices_ptr, + DISTANCE_T* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const INDEX_T query_id_offset, + SAMPLE_FILTER_T sample_filter, + cudaStream_t cuda_stream) +{ + const std::uint32_t block_size = 256; + const std::uint32_t grid_size = ceildiv(num_queries * result_buffer_size, block_size); + + apply_filter_kernel<<>>(result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + sample_filter); +} + template __global__ void batched_memcpy_kernel(T* const dst, // [batch_size, ld_dst] const uint64_t ld_dst, @@ -508,41 +584,42 @@ template -struct search : search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::max_dim; - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; + typename DISTANCE_T, + typename SAMPLE_FILTER_T> +struct search : search_plan_impl { + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::search_width; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; + + using search_plan_impl::max_dim; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; + + using search_plan_impl::hash_bitlen; + + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; + + using search_plan_impl::smem_size; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; size_t result_buffer_allocation_size; rmm::device_uvector result_indices; // results_indices_buffer @@ -557,7 +634,8 @@ struct search : search_plan_impl { int64_t dim, int64_t graph_degree, uint32_t topk) - : search_plan_impl(res, params, dim, graph_degree, topk), + : search_plan_impl( + res, params, dim, graph_degree, topk), result_indices(0, resource::get_cuda_stream(res)), result_distances(0, resource::get_cuda_stream(res)), parent_node_list(0, resource::get_cuda_stream(res)), @@ -602,7 +680,8 @@ struct search : search_plan_impl { const uint32_t num_queries, const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk) + uint32_t topk, + SAMPLE_FILTER_T sample_filter) { // Init hashmap cudaStream_t stream = resource::get_cuda_stream(res); @@ -684,6 +763,9 @@ struct search : search_plan_impl { // Compute distance to child nodes that are adjacent to the parent node compute_distance_to_child_nodes( parent_node_list.data(), + result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_buffer_allocation_size, search_width, dataset.data_handle(), dataset.extent(1), @@ -698,22 +780,53 @@ struct search : search_plan_impl { result_indices.data() + itopk_size, result_distances.data() + itopk_size, result_buffer_allocation_size, + sample_filter, stream); iter++; } // while ( 1 ) + auto result_indices_ptr = result_indices.data() + (iter & 0x1) * result_buffer_size; + auto result_distances_ptr = result_distances.data() + (iter & 0x1) * result_buffer_size; // Remove parent bit in search results - remove_parent_bit(num_queries, - itopk_size, - result_indices.data() + (iter & 0x1) * result_buffer_size, - result_buffer_allocation_size, - stream); + remove_parent_bit( + num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream); + + if (!std::is_same::value) { + apply_filter( + result_indices.data() + (iter & 0x1) * itopk_size, + result_distances.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_buffer_size, + num_queries, + 0, + sample_filter, + stream); + + result_indices_ptr = result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size; + result_distances_ptr = result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size; + _cuann_find_topk(itopk_size, + num_queries, + result_buffer_size, + result_distances.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_indices.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_distances_ptr, + result_buffer_allocation_size, + result_indices_ptr, + result_buffer_allocation_size, + topk_workspace.data(), + true, + topk_hint.data(), + stream); + } // Copy results from working buffer to final buffer batched_memcpy(topk_indices_ptr, topk, - result_indices.data() + (iter & 0x1) * result_buffer_size, + result_indices_ptr, result_buffer_allocation_size, topk, num_queries, @@ -721,7 +834,7 @@ struct search : search_plan_impl { if (topk_distances_ptr) { batched_memcpy(topk_distances_ptr, topk, - result_distances.data() + (iter & 0x1) * result_buffer_size, + result_distances_ptr, result_buffer_allocation_size, topk, num_queries, diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 33c77db61e..9419385836 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -65,7 +65,7 @@ struct search_plan_impl_base : public search_params { } }; -template +template struct search_plan_impl : public search_plan_impl_base { int64_t hash_bitlen; @@ -113,7 +113,8 @@ struct search_plan_impl : public search_plan_impl_base { const std::uint32_t num_queries, const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] std::uint32_t* const num_executed_iterations, // [num_queries] - uint32_t topk){}; + uint32_t topk, + SAMPLE_FILTER_T sample_filter){}; void adjust_search_params() { @@ -129,13 +130,13 @@ struct search_plan_impl : public search_plan_impl_base { if (max_iterations < min_iterations) { _max_iterations = min_iterations; } if (max_iterations < _max_iterations) { RAFT_LOG_DEBUG( - "# max_iterations is increased from %u to %u.", max_iterations, _max_iterations); + "# max_iterations is increased from %lu to %u.", max_iterations, _max_iterations); max_iterations = _max_iterations; } if (itopk_size % 32) { uint32_t itopk32 = itopk_size; itopk32 += 32 - (itopk_size % 32); - RAFT_LOG_DEBUG("# internal_topk is increased from %u to %u, as it must be multiple of 32.", + RAFT_LOG_DEBUG("# internal_topk is increased from %lu to %u, as it must be multiple of 32.", itopk_size, itopk32); itopk_size = itopk32; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 45dd535e1d..27d54f72cb 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -49,41 +49,42 @@ template -struct search : search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; + typename DISTANCE_T, + typename SAMPLE_FILTER_T> +struct search : search_plan_impl { + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::search_width; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; - using search_plan_impl::max_dim; - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; + using search_plan_impl::max_dim; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; - using search_plan_impl::hash_bitlen; + using search_plan_impl::hash_bitlen; - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; - using search_plan_impl::smem_size; + using search_plan_impl::smem_size; - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; uint32_t num_itopk_candidates; @@ -92,7 +93,8 @@ struct search : search_plan_impl { int64_t dim, int64_t graph_degree, uint32_t topk) - : search_plan_impl(res, params, dim, graph_degree, topk) + : search_plan_impl( + res, params, dim, graph_degree, topk) { set_params(res); } @@ -111,7 +113,7 @@ struct search : search_plan_impl { RAFT_EXPECTS(itopk_size <= max_itopk, "itopk_size cannot be larger than %u", max_itopk); RAFT_LOG_DEBUG("# num_itopk_candidates: %u", num_itopk_candidates); - RAFT_LOG_DEBUG("# num_itopk: %u", itopk_size); + RAFT_LOG_DEBUG("# num_itopk: %lu", itopk_size); // // Determine the thread block size // @@ -234,7 +236,8 @@ struct search : search_plan_impl { const std::uint32_t num_queries, const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] std::uint32_t* const num_executed_iterations, // [num_queries] - uint32_t topk) + uint32_t topk, + SAMPLE_FILTER_T sample_filter) { cudaStream_t stream = resource::get_cuda_stream(res); select_and_run( @@ -261,6 +264,7 @@ struct search : search_plan_impl { search_width, min_iterations, max_iterations, + sample_filter, stream); } }; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh index 5f5df1a818..35d239563a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh @@ -15,7 +15,9 @@ */ #pragma once +#include #include // RAFT_EXPLICIT + namespace raft::neighbors::cagra::detail { namespace single_cta_search { @@ -25,7 +27,8 @@ template + typename DISTANCE_T, + typename SAMPLE_FILTER_T> void select_and_run( // raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, @@ -50,50 +53,65 @@ void select_and_run( // raft::resources const& res, size_t search_width, size_t min_iterations, size_t max_iterations, + SAMPLE_FILTER_T sample_filter, cudaStream_t stream) RAFT_EXPLICIT; #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - extern template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, float, uint32_t, float); -instantiate_single_cta_select_and_run(8, 128, float, uint32_t, float); -instantiate_single_cta_select_and_run(16, 256, float, uint32_t, float); -instantiate_single_cta_select_and_run(32, 512, float, uint32_t, float); -instantiate_single_cta_select_and_run(32, 1024, int8_t, uint32_t, float); -instantiate_single_cta_select_and_run(8, 128, int8_t, uint32_t, float); -instantiate_single_cta_select_and_run(16, 256, int8_t, uint32_t, float); -instantiate_single_cta_select_and_run(32, 512, int8_t, uint32_t, float); -instantiate_single_cta_select_and_run(32, 1024, uint8_t, uint32_t, float); -instantiate_single_cta_select_and_run(8, 128, uint8_t, uint32_t, float); -instantiate_single_cta_select_and_run(16, 256, uint8_t, uint32_t, float); -instantiate_single_cta_select_and_run(32, 512, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_select_and_run diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 81325fd5da..128dc8d116 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -78,7 +79,7 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag, if (new_parent) { const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; if (i < search_width) { - next_parent_indices[i] = index; + next_parent_indices[i] = jj; // set most significant bit as used node internal_topk_indices[jj] |= index_msb_1_mask; } @@ -458,7 +459,8 @@ template + class INDEX_T, + class SAMPLE_FILTER_T> __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k] DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] @@ -482,7 +484,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ std::uint32_t* const num_executed_iterations, // [num_queries] const std::uint32_t hash_bitlen, const std::uint32_t small_hash_bitlen, - const std::uint32_t small_hash_reset_interval) + const std::uint32_t small_hash_reset_interval, + SAMPLE_FILTER_T sample_filter) { using LOAD_T = device::LOAD_128BIT_T; const auto query_id = blockIdx.y; @@ -527,6 +530,9 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ auto terminate_flag = reinterpret_cast(topk_ws + 3); auto smem_working_ptr = reinterpret_cast(terminate_flag + 1); + // A flag for filtering. + auto filter_flag = terminate_flag; + const DATA_T* const query_ptr = queries_ptr + query_id * dataset_dim; for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { unsigned j = device::swizzling(i); @@ -576,7 +582,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ std::uint32_t iter = 0; while (1) { // sort - if (TOPK_BY_BITONIC_SORT) { + if constexpr (TOPK_BY_BITONIC_SORT) { // [Notice] // It is good to use multiple warps in topk_by_bitonic_sort() when // batch size is small (short-latency), but it might not be always good @@ -614,17 +620,28 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ // topk with bitonic sort _CLK_START(); - topk_by_bitonic_sort( - result_distances_buffer, - result_indices_buffer, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - search_width * graph_degree, - topk_ws, - (iter == 0)); + if (std::is_same::value || + *filter_flag == 0) { + topk_by_bitonic_sort( + result_distances_buffer, + result_indices_buffer, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + search_width * graph_degree, + topk_ws, + (iter == 0)); + __syncthreads(); + } else { + topk_by_bitonic_sort_1st( + result_distances_buffer, + result_indices_buffer, + internal_topk + search_width * graph_degree, + internal_topk); + if (threadIdx.x == 0) { *terminate_flag = 0; } + } _CLK_REC(clk_topk); - } else { _CLK_START(); // topk with radix block sort @@ -693,12 +710,61 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ local_visited_hashmap_ptr, hash_bitlen, parent_list_buffer, + result_indices_buffer, search_width); __syncthreads(); _CLK_REC(clk_compute_distance); + // Filtering + if constexpr (!std::is_same::value) { + if (threadIdx.x == 0) { *filter_flag = 0; } + __syncthreads(); + + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); + + for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { + if (parent_list_buffer[p] != invalid_index) { + const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; + if (!sample_filter(query_id, parent_id)) { + // If the parent must not be in the resulting top-k list, remove from the parent list + result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_list_buffer[p]] = invalid_index; + *filter_flag = 1; + } + } + } + __syncthreads(); + } + iter++; } + + // Post process for filtering + if constexpr (!std::is_same::value) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); + + for (unsigned i = threadIdx.x; i < internal_topk + search_width * graph_degree; + i += blockDim.x) { + const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; + if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { + result_distances_buffer[i] = utils::get_max_value(); + result_indices_buffer[i] = invalid_index; + } + } + + __syncthreads(); + topk_by_bitonic_sort_1st( + result_distances_buffer, + result_indices_buffer, + internal_topk + search_width * graph_degree, + top_k); + __syncthreads(); + } + for (std::uint32_t i = threadIdx.x; i < top_k; i += BLOCK_SIZE) { unsigned j = i + (top_k * query_id); unsigned ii = i; @@ -737,9 +803,15 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ #endif } -template +template struct search_kernel_config { - using kernel_t = decltype(&search_kernel); + using kernel_t = + decltype(&search_kernel); template static auto choose_block_size(unsigned block_size) -> kernel_t @@ -747,24 +819,104 @@ struct search_kernel_config { constexpr unsigned BS = USE_BITONIC_SORT; if constexpr (BS) { if (block_size == 64) { - return search_kernel; + return search_kernel; } else if (block_size == 128) { - return search_kernel; + return search_kernel; } else if (block_size == 256) { - return search_kernel; + return search_kernel; } else if (block_size == 512) { - return search_kernel; + return search_kernel; } else { - return search_kernel; + return search_kernel; } } else { if (block_size == 256) { - return search_kernel; + return search_kernel; } else if (block_size == 512) { - return search_kernel; + return search_kernel; } else { - return search_kernel; + return search_kernel; } } } @@ -826,7 +978,8 @@ template + typename DISTANCE_T, + typename SAMPLE_FILTER_T> void select_and_run( // raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, @@ -851,16 +1004,18 @@ void select_and_run( // raft::resources const& res, size_t search_width, size_t min_iterations, size_t max_iterations, + SAMPLE_FILTER_T sample_filter, cudaStream_t stream) { - auto kernel = search_kernel_config:: - choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); + auto kernel = + search_kernel_config:: + choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); dim3 thread_dims(block_size, 1, 1); dim3 block_dims(1, num_queries, 1); RAFT_LOG_DEBUG( - "Launching kernel with %u threads, %u block %lu smem", block_size, num_queries, smem_size); + "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); kernel<<>>(topk_indices_ptr, topk_distances_ptr, topk, @@ -883,7 +1038,8 @@ void select_and_run( // raft::resources const& res, num_executed_iterations, hash_bitlen, small_hash_bitlen, - small_hash_reset_interval); + small_hash_reset_interval, + sample_filter); RAFT_CUDA_TRY(cudaPeekAtLastError()); } } // namespace single_cta_search diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index a18ee065bf..6641346a67 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -342,7 +342,7 @@ void extend(raft::resources const& handle, /** @} */ /** - * @brief Search ANN using the constructed index. + * @brief Search ANN using the constructed index with the given filter. * * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * @@ -374,6 +374,8 @@ void extend(raft::resources const& handle, * * @tparam T data element type * @tparam IdxT type of the indices + * @tparam IvfSampleFilterT Device filter function, with the signature + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -386,7 +388,7 @@ void extend(raft::resources const& handle, * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] * @param[in] mr an optional memory resource to use across the searches (you can provide a large * enough memory pool here to avoid memory allocations within search). - * @param[in] sample_filter a filter the greenlights samples for a given query + * @param[in] sample_filter a device filter function that greenlights samples for a given query */ template void search_with_filtering(raft::resources const& handle, @@ -475,7 +477,7 @@ void search(raft::resources const& handle, */ /** - * @brief Search ANN using the constructed index using the given filter. + * @brief Search ANN using the constructed index with the given filter. * * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * @@ -501,6 +503,8 @@ void search(raft::resources const& handle, * * @tparam T data element type * @tparam IdxT type of the indices + * @tparam IvfSampleFilterT Device filter function, with the signature + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -509,7 +513,7 @@ void search(raft::resources const& handle, * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter a filter the greenlights samples for a given query + * @param[in] sample_filter a device filter function that greenlights samples for a given query */ template void search_with_filtering(raft::resources const& handle, diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index ccf8717486..9f203d92fb 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -134,7 +134,7 @@ void extend(raft::resources const& handle, } /** - * @brief Search ANN using the constructed index using the given filter. + * @brief Search ANN using the constructed index with the given filter. * * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. * @@ -148,6 +148,8 @@ void extend(raft::resources const& handle, * * @tparam T data element type * @tparam IdxT type of the indices + * @tparam IvfSampleFilterT Device filter function, with the signature + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -157,7 +159,7 @@ void extend(raft::resources const& handle, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] - * @param[in] sample_filter a filter the greenlights samples for a given query. + * @param[in] sample_filter a device filter function that greenlights samples for a given query. */ template void search_with_filtering(raft::resources const& handle, @@ -343,7 +345,7 @@ void extend(raft::resources const& handle, } /** - * @brief Search ANN using the constructed index using the given filter. + * @brief Search ANN using the constructed index with the given filter. * * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. * @@ -372,6 +374,8 @@ void extend(raft::resources const& handle, * * @tparam T data element type * @tparam IdxT type of the indices + * @tparam IvfSampleFilterT Device filter function, with the signature + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -382,7 +386,7 @@ void extend(raft::resources const& handle, * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter a filter the greenlights samples for a given query + * @param[in] sample_filter a device filter function that greenlights samples for a given query */ template void search_with_filtering(raft::resources const& handle, diff --git a/cpp/include/raft/neighbors/sample_filter_types.hpp b/cpp/include/raft/neighbors/sample_filter_types.hpp index 5a301e9d2f..10c5e99372 100644 --- a/cpp/include/raft/neighbors/sample_filter_types.hpp +++ b/cpp/include/raft/neighbors/sample_filter_types.hpp @@ -37,6 +37,18 @@ struct none_ivf_sample_filter { } }; +/* A filter that filters nothing. This is the default behavior. */ +struct none_cagra_sample_filter { + inline _RAFT_HOST_DEVICE bool operator()( + // query index + const uint32_t query_ix, + // the index of the current sample + const uint32_t sample_ix) const + { + return true; + } +}; + /** * If the filtering depends on the index of a sample, then the following * filter template can be used: diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py index 784d116503..15eb0a9e65 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py @@ -39,41 +39,45 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \\ - template void select_and_run( \\ - raft::device_matrix_view dataset, \\ - raft::device_matrix_view graph, \\ - INDEX_T* const topk_indices_ptr, \\ - DISTANCE_T* const topk_distances_ptr, \\ - const DATA_T* const queries_ptr, \\ - const uint32_t num_queries, \\ - const INDEX_T* dev_seed_ptr, \\ - uint32_t* const num_executed_iterations, \\ - uint32_t topk, \\ - uint32_t block_size, \\ - uint32_t result_buffer_size, \\ - uint32_t smem_size, \\ - int64_t hash_bitlen, \\ - INDEX_T* hashmap_ptr, \\ - uint32_t num_cta_per_query, \\ - uint32_t num_random_samplings, \\ - uint64_t rand_xor_mask, \\ - uint32_t num_seeds, \\ - size_t itopk_size, \\ - size_t search_width, \\ - size_t min_iterations, \\ - size_t max_iterations, \\ - cudaStream_t stream); +#define instantiate_kernel_selection( \\ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \\ + template void \\ + select_and_run( \\ + raft::device_matrix_view dataset, \\ + raft::device_matrix_view graph, \\ + INDEX_T* const topk_indices_ptr, \\ + DISTANCE_T* const topk_distances_ptr, \\ + const DATA_T* const queries_ptr, \\ + const uint32_t num_queries, \\ + const INDEX_T* dev_seed_ptr, \\ + uint32_t* const num_executed_iterations, \\ + uint32_t topk, \\ + uint32_t block_size, \\ + uint32_t result_buffer_size, \\ + uint32_t smem_size, \\ + int64_t hash_bitlen, \\ + INDEX_T* hashmap_ptr, \\ + uint32_t num_cta_per_query, \\ + uint32_t num_random_samplings, \\ + uint64_t rand_xor_mask, \\ + uint32_t num_seeds, \\ + size_t itopk_size, \\ + size_t search_width, \\ + size_t min_iterations, \\ + size_t max_iterations, \\ + SAMPLE_FILTER_T sample_filter, \\ + cudaStream_t stream); """ trailer = """ #undef instantiate_kernel_selection -} // namespace raft::neighbors::cagra::detail::namespace multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search """ mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] @@ -97,7 +101,7 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_kernel_selection({team}, {mxdim}, {data_t}, {idx_t}, {distance_t});\n" + f"instantiate_kernel_selection(\n {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}, raft::neighbors::filtering::none_cagra_sample_filter);\n" ) f.write(trailer) # For pasting into CMakeLists.txt diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu index 2a4e7ac607..1a3b2284bd 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_widthhhhhhhhh, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, float, uint32_t, float); +instantiate_kernel_selection( + 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu index 115ce3b48b..36e86d9ed6 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(8, 128, float, uint32_t, float); +instantiate_kernel_selection( + 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu index c5e704a85f..6f1af2d93f 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(16, 256, float, uint32_t, float); +instantiate_kernel_selection( + 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu index 3469facf39..1279f8e415 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 512, float, uint32_t, float); +instantiate_kernel_selection( + 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu index 327bfc73b4..0dabff0df5 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, float, uint64_t, float); +instantiate_kernel_selection( + 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu index 1abe0cd8af..72bb74cdb8 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(8, 128, float, uint64_t, float); +instantiate_kernel_selection( + 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu index dd61810d06..dceea10b5d 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(16, 256, float, uint64_t, float); +instantiate_kernel_selection( + 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu index 8e12bab514..acb8bd6a12 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 512, float, uint64_t, float); +instantiate_kernel_selection( + 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu index d946ac9c79..0254f09ff0 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, int8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu index e4d7b44d1e..2b67e7e968 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(8, 128, int8_t, uint32_t, float); +instantiate_kernel_selection( + 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu index b8dc3b38a8..17d6722e58 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(16, 256, int8_t, uint32_t, float); +instantiate_kernel_selection( + 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu index 749b35bad6..38f02812e2 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 512, int8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu index 428d460ba8..fa111196c6 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_widthh, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu index 28a20b865e..1ef3c28aa3 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(8, 128, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu index e85a84ae8e..d26cb44843 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(16, 256, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu index 232b62ebcd..4d4322f261 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 512, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py index cf61a45b4a..249555082e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py @@ -39,35 +39,38 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \\ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \\ - template void select_and_run( \\ - raft::device_matrix_view dataset, \\ - raft::device_matrix_view graph, \\ - INDEX_T* const topk_indices_ptr, \\ - DISTANCE_T* const topk_distances_ptr, \\ - const DATA_T* const queries_ptr, \\ - const uint32_t num_queries, \\ - const INDEX_T* dev_seed_ptr, \\ - uint32_t* const num_executed_iterations, \\ - uint32_t topk, \\ - uint32_t num_itopk_candidates, \\ - uint32_t block_size, \\ - uint32_t smem_size, \\ - int64_t hash_bitlen, \\ - INDEX_T* hashmap_ptr, \\ - size_t small_hash_bitlen, \\ - size_t small_hash_reset_interval, \\ - uint32_t num_random_samplings, \\ - uint64_t rand_xor_mask, \\ - uint32_t num_seeds, \\ - size_t itopk_size, \\ - size_t search_width, \\ - size_t min_iterations, \\ - size_t max_iterations, \\ +#define instantiate_single_cta_select_and_run( \\ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \\ + template void \\ + select_and_run( \\ + raft::device_matrix_view dataset, \\ + raft::device_matrix_view graph, \\ + INDEX_T* const topk_indices_ptr, \\ + DISTANCE_T* const topk_distances_ptr, \\ + const DATA_T* const queries_ptr, \\ + const uint32_t num_queries, \\ + const INDEX_T* dev_seed_ptr, \\ + uint32_t* const num_executed_iterations, \\ + uint32_t topk, \\ + uint32_t num_itopk_candidates, \\ + uint32_t block_size, \\ + uint32_t smem_size, \\ + int64_t hash_bitlen, \\ + INDEX_T* hashmap_ptr, \\ + size_t small_hash_bitlen, \\ + size_t small_hash_reset_interval, \\ + uint32_t num_random_samplings, \\ + uint64_t rand_xor_mask, \\ + uint32_t num_seeds, \\ + size_t itopk_size, \\ + size_t search_width, \\ + size_t min_iterations, \\ + size_t max_iterations, \\ + SAMPLE_FILTER_T sample_filter, \\ cudaStream_t stream); """ @@ -75,7 +78,7 @@ trailer = """ #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search """ mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] @@ -102,7 +105,7 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_single_cta_select_and_run({team}, {mxdim},{data_t}, {idx_t}, {distance_t});\n" + f"instantiate_single_cta_select_and_run(\n {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}, raft::neighbors::filtering::none_cagra_sample_filter);\n" ) f.write(trailer) diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu index eb45d4ff08..b8c23103ba 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, float, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu index 049715aa20..8ab1897119 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(8, 128, float, uint32_t, float); +instantiate_single_cta_select_and_run( + 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu index 6028c283db..9fd36b4cb9 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(16, 256, float, uint32_t, float); +instantiate_single_cta_select_and_run( + 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu index 2566e9cbd9..a9ee2c864b 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 512, float, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu index 4cd96ad9c0..dadc574b65 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, float, uint64_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu index 822a2efb2f..30e043f47e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(8, 128, float, uint64_t, float); +instantiate_single_cta_select_and_run( + 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu index 80d1f76b9b..089e4c930f 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(16, 256, float, uint64_t, float); +instantiate_single_cta_select_and_run( + 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu index 06c3eaf10b..3e8ffb8bf8 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 512, float, uint64_t, float); +instantiate_single_cta_select_and_run( + 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu index b4c30ac943..279587738e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, int8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu index c8d0df3ac4..ef127d3f7d 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(8, 128, int8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu index 19ecee91af..7fcfdcc28e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(16, 256, int8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu index 52c4eb7d6b..a6c606d99b 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 512, int8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu index 4675e17084..0b8be56614 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu index e73e1071ee..4c193b9408 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(8, 128, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu index 01e26b5f29..bdf16d2f03 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(16, 256, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu index b0534b555f..93624df4aa 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 512, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index eadc88085f..90f271e3ee 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -15,6 +15,8 @@ */ #pragma once +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation + #include "../test_utils.cuh" #include "ann_utils.cuh" #include @@ -25,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -41,8 +44,22 @@ #include #include -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { namespace { + +/* A filter that excludes all indices below `offset`. */ +struct test_cagra_sample_filter { + static constexpr unsigned offset = 400; + inline _RAFT_HOST_DEVICE auto operator()( + // query index + const uint32_t query_ix, + // the index of the current sample inside the current inverted list + const uint32_t sample_ix) const + { + return sample_ix >= offset; + } +}; + // For sort_knn_graph test template void RandomSuffle(raft::host_matrix_view index) @@ -365,6 +382,162 @@ class AnnCagraSortTest : public ::testing::TestWithParam { rmm::device_uvector database; }; +template +class AnnCagraFilterTest : public ::testing::TestWithParam { + public: + AnnCagraFilterTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + protected: + void testCagraFilter() + { + size_t queries_size = ps.n_queries * ps.k; + std::vector indices_Cagra(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_Cagra(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.n_queries, + ps.n_rows - test_cagra_sample_filter::offset, + ps.dim, + ps.k, + ps.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_cagra_sample_filter::offset), + queries_size, + stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + rmm::device_uvector distances_dev(queries_size, stream_); + rmm::device_uvector indices_dev(queries_size, stream_); + + { + cagra::index_params index_params; + index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is + // not used for knn_graph building. + cagra::search_params search_params; + search_params.algo = ps.algo; + search_params.max_queries = ps.max_queries; + search_params.team_size = ps.team_size; + search_params.hashmap_mode = cagra::hash_mode::HASH; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + cagra::index index(handle_); + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + index = cagra::build(handle_, index_params, database_host_view); + } else { + index = cagra::build(handle_, index_params, database_view); + } + + if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.n_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_dev.data(), ps.n_queries, ps.k); + + cagra::search_with_filtering(handle_, + search_params, + index, + search_queries_view, + indices_out_view, + dists_out_view, + test_cagra_sample_filter()); + update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); + update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + // Test filter + bool unacceptable_node = false; + for (int q = 0; q < ps.n_queries; q++) { + for (int i = 0; i < ps.k; i++) { + const auto n = indices_Cagra[q * ps.k + i]; + unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n); + } + } + EXPECT_FALSE(unacceptable_node); + + double min_recall = ps.min_recall; + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_Cagra, + distances_naive, + distances_Cagra, + ps.n_queries, + ps.k, + 0.001, + min_recall)); + EXPECT_TRUE(eval_distances(handle_, + database.data(), + search_queries.data(), + indices_dev.data(), + distances_dev.data(), + ps.n_rows, + ps.dim, + ps.n_queries, + ps.k, + ps.metric, + 1.0e-4)); + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + search_queries.resize(ps.n_queries * ps.dim, stream_); + raft::random::Rng r(1234ULL); + if constexpr (std::is_same{}) { + r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_); + r.normal(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_); + } else { + r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); + r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_); + } + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnCagraInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + inline std::vector generate_inputs() { // TODO(tfeher): test MULTI_CTA kernel with search_width > 1 to allow multiple CTA per queries @@ -467,4 +640,4 @@ inline std::vector generate_inputs() const std::vector inputs = generate_inputs(); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh index f61e476652..175e4ef483 100644 --- a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh +++ b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh @@ -1,93 +1,107 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // RAFT_EXPLICIT - -namespace raft::neighbors::cagra::detail { - -namespace multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - extern template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - cudaStream_t stream); - -instantiate_kernel_selection(32, 1024, float, uint64_t, float); -instantiate_kernel_selection(8, 128, float, uint64_t, float); -instantiate_kernel_selection(16, 256, float, uint64_t, float); -instantiate_kernel_selection(32, 512, float, uint64_t, float); - -#undef instantiate_kernel_selection -} // namespace multi_cta_search - -namespace single_cta_search { - -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - extern template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - cudaStream_t stream); - -instantiate_single_cta_select_and_run(32, 1024, float, uint64_t, float); -instantiate_single_cta_select_and_run(8, 128, float, uint64_t, float); -instantiate_single_cta_select_and_run(16, 256, float, uint64_t, float); -instantiate_single_cta_select_and_run(32, 512, float, uint64_t, float); - -} // namespace single_cta_search -} // namespace raft::neighbors::cagra::detail +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // none_cagra_sample_filter +#include // RAFT_EXPLICIT + +namespace raft::neighbors::cagra::detail { + +namespace multi_cta_search { +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ + cudaStream_t stream); + +instantiate_kernel_selection( + 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); + +#undef instantiate_kernel_selection +} // namespace multi_cta_search + +namespace single_cta_search { + +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ + cudaStream_t stream); + +instantiate_single_cta_select_and_run( + 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace single_cta_search +} // namespace raft::neighbors::cagra::detail \ No newline at end of file diff --git a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu index fa3d76d066..6f9e8dbd43 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu @@ -19,11 +19,11 @@ #include "../ann_cagra.cuh" #include "search_kernel_uint64_t.cuh" -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { typedef AnnCagraTest AnnCagraTestF_I64; TEST_P(AnnCagraTestF_I64, AnnCagra) { this->testCagra(); } INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_I64, ::testing::ValuesIn(inputs)); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index dbaf4dedd9..01d7e1e1ea 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -18,7 +18,7 @@ #include "../ann_cagra.cuh" -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { typedef AnnCagraTest AnnCagraTestF_U32; TEST_P(AnnCagraTestF_U32, AnnCagra) { this->testCagra(); } @@ -26,7 +26,11 @@ TEST_P(AnnCagraTestF_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraSortTest AnnCagraSortTestF_U32; TEST_P(AnnCagraSortTestF_U32, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraFilterTest AnnCagraFilterTestF_U32; +TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter) { this->testCagraFilter(); } + INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs)); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu index ba60131677..ee06d369fa 100644 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -18,14 +18,17 @@ #include "../ann_cagra.cuh" -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { typedef AnnCagraTest AnnCagraTestI8_U32; TEST_P(AnnCagraTestI8_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraSortTest AnnCagraSortTestI8_U32; TEST_P(AnnCagraSortTestI8_U32, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraFilterTest AnnCagraFilterTestI8_U32; +TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter) { this->testCagraFilter(); } INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestI8_U32, ::testing::ValuesIn(inputs)); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu index cc172e4833..3243e73ccd 100644 --- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -18,7 +18,7 @@ #include "../ann_cagra.cuh" -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { typedef AnnCagraTest AnnCagraTestU8_U32; TEST_P(AnnCagraTestU8_U32, AnnCagra) { this->testCagra(); } @@ -26,7 +26,11 @@ TEST_P(AnnCagraTestU8_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraSortTest AnnCagraSortTestU8_U32; TEST_P(AnnCagraSortTestU8_U32, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraFilterTest AnnCagraFilterTestU8_U32; +TEST_P(AnnCagraFilterTestU8_U32, AnnCagraSort) { this->testCagraFilter(); } + INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestU8_U32, ::testing::ValuesIn(inputs)); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra From d4002b0921ba12a6a1296d39dd55802469a36550 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 26 Sep 2023 03:39:21 +0200 Subject: [PATCH 04/10] Add IVF-Flat C++ example (#1828) This PR adds a C++ example program that demonstrate the usage of IVF-Flat vector search. Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1828 --- cpp/template/CMakeLists.txt | 8 +- cpp/template/src/cagra_example.cu | 90 ++++++++++++++ cpp/template/src/common.cuh | 95 +++++++++++++++ cpp/template/src/ivf_flat_example.cu | 160 +++++++++++++++++++++++++ cpp/template/src/test_vector_search.cu | 59 --------- 5 files changed, 351 insertions(+), 61 deletions(-) create mode 100644 cpp/template/src/cagra_example.cu create mode 100644 cpp/template/src/common.cuh create mode 100644 cpp/template/src/ivf_flat_example.cu delete mode 100644 cpp/template/src/test_vector_search.cu diff --git a/cpp/template/CMakeLists.txt b/cpp/template/CMakeLists.txt index 44b06e1b5f..a1341f3609 100644 --- a/cpp/template/CMakeLists.txt +++ b/cpp/template/CMakeLists.txt @@ -34,5 +34,9 @@ rapids_cpm_init() include(cmake/thirdparty/get_raft.cmake) # -------------- compile tasks ----------------- # -add_executable(TEST_RAFT src/test_vector_search.cu) -target_link_libraries(TEST_RAFT PRIVATE raft::raft raft::compiled) +add_executable(CAGRA_EXAMPLE src/cagra_example.cu) +target_link_libraries(CAGRA_EXAMPLE PRIVATE raft::raft raft::compiled) + +add_executable(IVF_FLAT_EXAMPLE src/ivf_flat_example.cu) +target_link_libraries(IVF_FLAT_EXAMPLE PRIVATE raft::raft raft::compiled) + diff --git a/cpp/template/src/cagra_example.cu b/cpp/template/src/cagra_example.cu new file mode 100644 index 0000000000..7f3a7d6676 --- /dev/null +++ b/cpp/template/src/cagra_example.cu @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "common.cuh" + +void cagra_build_search_simple(raft::device_resources const& dev_resources, + raft::device_matrix_view dataset, + raft::device_matrix_view queries) +{ + using namespace raft::neighbors; + + int64_t topk = 12; + int64_t n_queries = queries.extent(0); + + // create output arrays + auto neighbors = raft::make_device_matrix(dev_resources, n_queries, topk); + auto distances = raft::make_device_matrix(dev_resources, n_queries, topk); + + // use default index parameters + cagra::index_params index_params; + + std::cout << "Building CAGRA index (search graph)" << std::endl; + auto index = cagra::build(dev_resources, index_params, dataset); + + std::cout << "CAGRA index has " << index.size() << " vectors" << std::endl; + std::cout << "CAGRA graph has degree " << index.graph_degree() << ", graph size [" + << index.graph().extent(0) << ", " << index.graph().extent(1) << "]" << std::endl; + + // use default search parameters + cagra::search_params search_params; + // search K nearest neighbors + cagra::search( + dev_resources, search_params, index, queries, neighbors.view(), distances.view()); + + // The call to ivf_flat::search is asynchronous. Before accessing the data, sync by calling + // raft::resource::sync_stream(dev_resources); + + print_results(dev_resources, neighbors.view(), distances.view()); +} + +int main() +{ + raft::device_resources dev_resources; + + // Set pool memory resource with 1 GiB initial pool size. All allocations use the same pool. + rmm::mr::pool_memory_resource pool_mr( + rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull); + rmm::mr::set_current_device_resource(&pool_mr); + + // Alternatively, one could define a pool allocator for temporary arrays (used within RAFT + // algorithms). In that case only the internal arrays would use the pool, any other allocation + // uses the default RMM memory resource. Here is how to change the workspace memory resource to + // a pool with 2 GiB upper limit. + // raft::resource::set_workspace_to_pool_resource(dev_resources, 2 * 1024 * 1024 * 1024ull); + + // Create input arrays. + int64_t n_samples = 10000; + int64_t n_dim = 90; + int64_t n_queries = 10; + auto dataset = raft::make_device_matrix(dev_resources, n_samples, n_dim); + auto queries = raft::make_device_matrix(dev_resources, n_queries, n_dim); + generate_dataset(dev_resources, dataset.view(), queries.view()); + + // Simple build and search example. + cagra_build_search_simple(dev_resources, + raft::make_const_mdspan(dataset.view()), + raft::make_const_mdspan(queries.view())); +} diff --git a/cpp/template/src/common.cuh b/cpp/template/src/common.cuh new file mode 100644 index 0000000000..0b72d3bf3b --- /dev/null +++ b/cpp/template/src/common.cuh @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// Fill dataset and queries with synthetic data. +void generate_dataset(raft::device_resources const& dev_resources, + raft::device_matrix_view dataset, + raft::device_matrix_view queries) +{ + auto labels = raft::make_device_vector(dev_resources, dataset.extent(0)); + raft::random::make_blobs(dev_resources, dataset, labels.view()); + raft::random::RngState r(1234ULL); + raft::random::uniform(dev_resources, + r, + raft::make_device_vector_view(queries.data_handle(), queries.size()), + -1.0f, + 1.0f); +} + +// Copy the results to host and print a few samples +template +void print_results(raft::device_resources const& dev_resources, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + int64_t topk = neighbors.extent(1); + auto neighbors_host = raft::make_host_matrix(neighbors.extent(0), topk); + auto distances_host = raft::make_host_matrix(distances.extent(0), topk); + + cudaStream_t stream = raft::resource::get_cuda_stream(dev_resources); + + raft::copy(neighbors_host.data_handle(), neighbors.data_handle(), neighbors.size(), stream); + raft::copy(distances_host.data_handle(), distances.data_handle(), distances.size(), stream); + + // The calls to RAFT algorithms and raft::copy is asynchronous. + // We need to sync the stream before accessing the data. + raft::resource::sync_stream(dev_resources, stream); + + for (int query_id = 0; query_id < 2; query_id++) { + std::cout << "Query " << query_id << " neighbor indices: "; + raft::print_host_vector("", &neighbors_host(query_id, 0), topk, std::cout); + std::cout << "Query " << query_id << " neighbor distances: "; + raft::print_host_vector("", &distances_host(query_id, 0), topk, std::cout); + } +} + +/** Subsample the dataset to create a training set*/ +raft::device_matrix subsample( + raft::device_resources const& dev_resources, + raft::device_matrix_view dataset, + raft::device_vector_view data_indices, + float fraction) +{ + int64_t n_samples = dataset.extent(0); + int64_t n_dim = dataset.extent(1); + int64_t n_train = n_samples * fraction; + auto trainset = raft::make_device_matrix(dev_resources, n_train, n_dim); + + int seed = 137; + raft::random::RngState rng(seed); + auto train_indices = raft::make_device_vector(dev_resources, n_train); + + raft::random::sample_without_replacement( + dev_resources, rng, data_indices, std::nullopt, train_indices.view(), std::nullopt); + + raft::matrix::copy_rows( + dev_resources, dataset, trainset.view(), raft::make_const_mdspan(train_indices.view())); + + return trainset; +} diff --git a/cpp/template/src/ivf_flat_example.cu b/cpp/template/src/ivf_flat_example.cu new file mode 100644 index 0000000000..5d91f8fe8b --- /dev/null +++ b/cpp/template/src/ivf_flat_example.cu @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include "common.cuh" + +void ivf_flat_build_search_simple(raft::device_resources const& dev_resources, + raft::device_matrix_view dataset, + raft::device_matrix_view queries) +{ + using namespace raft::neighbors; + + ivf_flat::index_params index_params; + index_params.n_lists = 1024; + index_params.kmeans_trainset_fraction = 0.1; + index_params.metric = raft::distance::DistanceType::L2Expanded; + + std::cout << "Building IVF-Flat index" << std::endl; + auto index = ivf_flat::build(dev_resources, index_params, dataset); + + std::cout << "Number of clusters " << index.n_lists() << ", number of vectors added to index " + << index.size() << std::endl; + + // Create output arrays. + int64_t topk = 10; + int64_t n_queries = queries.extent(0); + auto neighbors = raft::make_device_matrix(dev_resources, n_queries, topk); + auto distances = raft::make_device_matrix(dev_resources, n_queries, topk); + + // Set search parameters. + ivf_flat::search_params search_params; + search_params.n_probes = 50; + + // Search K nearest neighbors for each of the queries. + ivf_flat::search( + dev_resources, search_params, index, queries, neighbors.view(), distances.view()); + + // The call to ivf_flat::search is asynchronous. Before accessing the data, sync by calling + // raft::resource::sync_stream(dev_resources); + + print_results(dev_resources, neighbors.view(), distances.view()); +} + +void ivf_flat_build_extend_search(raft::device_resources const& dev_resources, + raft::device_matrix_view dataset, + raft::device_matrix_view queries) +{ + using namespace raft::neighbors; + + // Define dataset indices. + auto data_indices = raft::make_device_vector(dev_resources, dataset.extent(0)); + thrust::counting_iterator first(0); + thrust::device_ptr ptr(data_indices.data_handle()); + thrust::copy( + raft::resource::get_thrust_policy(dev_resources), first, first + dataset.extent(0), ptr); + + // Sub-sample the dataset to create a training set. + auto trainset = + subsample(dev_resources, dataset, raft::make_const_mdspan(data_indices.view()), 0.1); + + ivf_flat::index_params index_params; + index_params.n_lists = 100; + index_params.metric = raft::distance::DistanceType::L2Expanded; + index_params.add_data_on_build = false; + + std::cout << "\nRun k-means clustering using the training set" << std::endl; + auto index = + ivf_flat::build(dev_resources, index_params, raft::make_const_mdspan(trainset.view())); + + std::cout << "Number of clusters " << index.n_lists() << ", number of vectors added to index " + << index.size() << std::endl; + + std::cout << "Filling index with the dataset vectors" << std::endl; + index = ivf_flat::extend(dev_resources, + dataset, + std::make_optional(raft::make_const_mdspan(data_indices.view())), + index); + + std::cout << "Index size after addin dataset vectors " << index.size() << std::endl; + + // Set search parameters. + ivf_flat::search_params search_params; + search_params.n_probes = 10; + + // Create output arrays. + int64_t topk = 10; + int64_t n_queries = queries.extent(0); + auto neighbors = raft::make_device_matrix(dev_resources, n_queries, topk); + auto distances = raft::make_device_matrix(dev_resources, n_queries, topk); + + // Search K nearest neighbors for each queries. + ivf_flat::search( + dev_resources, search_params, index, queries, neighbors.view(), distances.view()); + + // The call to ivf_flat::search is asynchronous. Before accessing the data, sync using: + // raft::resource::sync_stream(dev_resources); + + print_results(dev_resources, neighbors.view(), distances.view()); +} + +int main() +{ + raft::device_resources dev_resources; + + // Set pool memory resource with 1 GiB initial pool size. All allocations use the same pool. + rmm::mr::pool_memory_resource pool_mr( + rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull); + rmm::mr::set_current_device_resource(&pool_mr); + + // Alternatively, one could define a pool allocator for temporary arrays (used within RAFT + // algorithms). In that case only the internal arrays would use the pool, any other allocation + // uses the default RMM memory resource. Here is how to change the workspace memory resource to + // a pool with 2 GiB upper limit. + // raft::resource::set_workspace_to_pool_resource(dev_resources, 2 * 1024 * 1024 * 1024ull); + + // Create input arrays. + int64_t n_samples = 10000; + int64_t n_dim = 3; + int64_t n_queries = 10; + auto dataset = raft::make_device_matrix(dev_resources, n_samples, n_dim); + auto queries = raft::make_device_matrix(dev_resources, n_queries, n_dim); + generate_dataset(dev_resources, dataset.view(), queries.view()); + + // Simple build and search example. + ivf_flat_build_search_simple(dev_resources, + raft::make_const_mdspan(dataset.view()), + raft::make_const_mdspan(queries.view())); + + // Build and extend example. + ivf_flat_build_extend_search(dev_resources, + raft::make_const_mdspan(dataset.view()), + raft::make_const_mdspan(queries.view())); +} diff --git a/cpp/template/src/test_vector_search.cu b/cpp/template/src/test_vector_search.cu deleted file mode 100644 index f54cfc03e7..0000000000 --- a/cpp/template/src/test_vector_search.cu +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include - -int main() -{ - using namespace raft::neighbors; - raft::device_resources dev_resources; - // Use 5 GB of pool memory - raft::resource::set_workspace_to_pool_resource( - dev_resources, std::make_optional(5 * 1024 * 1024 * 1024ull)); - - int64_t n_samples = 50000; - int64_t n_dim = 90; - int64_t topk = 12; - int64_t n_queries = 1; - - // create input and output arrays - auto input = raft::make_device_matrix(dev_resources, n_samples, n_dim); - auto labels = raft::make_device_vector(dev_resources, n_samples); - auto queries = raft::make_device_matrix(dev_resources, n_queries, n_dim); - auto neighbors = raft::make_device_matrix(dev_resources, n_queries, topk); - auto distances = raft::make_device_matrix(dev_resources, n_queries, topk); - - raft::random::make_blobs(dev_resources, input.view(), labels.view()); - - // use default index parameters - cagra::index_params index_params; - // create and fill the index from a [n_samples, n_dim] input - auto index = cagra::build( - dev_resources, index_params, raft::make_const_mdspan(input.view())); - // use default search parameters - cagra::search_params search_params; - // search K nearest neighbors - cagra::search(dev_resources, - search_params, - index, - raft::make_const_mdspan(queries.view()), - neighbors.view(), - distances.view()); -} From a1002f8c8f4debc52fbab7191297a2f54ff42856 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Tue, 26 Sep 2023 02:07:39 -0400 Subject: [PATCH 05/10] Port NN-descent algorithm to use in `cagra::build()` (#1748) - [x] Build Time comparison of end-to-end RAFT CAGRA+nn-descent against cuANN CAGRA+nn-descent - [x] Recall comparison of RAFT nn-descent against cuANN nn-descent - [x] RAFT types/APIs in ported code from cuANN - [x] End-to-end CAGRA+nn-descent tests - [x] Docs and code examples - [x] Add `graph_build_algo` build param to CAGRA ann-bench for benchmarking builds with IVF-PQ or NN-Descent - [x] All-neighbors knn graph nn-descent tests against brute-force knn Recall Value comparison of RAFT nn-descent vs cuANN nn-descent ``` Dataset graph_degree, intermediate_degree Iterations cuANN Recall RAFT Recall sift-128-euclidean (64, 98) 15 0.9265991875 0.9471194688 sift-128-euclidean (64, 98) 50 0.9831858594 0.9783938594 deep-image-96-inner (64, 98) 50 0.9806211946 0.9801508853 ``` Authors: - Divye Gala (https://github.com/divyegala) - Ray Wang (https://github.com/RayWang96) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ray Wang (https://github.com/RayWang96) - Ray Douglass (https://github.com/raydouglass) - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1748 --- build.sh | 2 +- ci/build_cpp.sh | 2 +- cpp/bench/ann/src/raft/raft_benchmark.cu | 7 + cpp/include/raft/neighbors/cagra.cuh | 76 +- cpp/include/raft/neighbors/cagra_types.hpp | 14 + .../neighbors/detail/cagra/cagra_build.cuh | 24 + .../neighbors/detail/cagra/graph_core.cuh | 2 +- .../raft/neighbors/detail/nn_descent.cuh | 1453 +++++++++++++++++ cpp/include/raft/neighbors/nn_descent.cuh | 181 ++ .../raft/neighbors/nn_descent_types.hpp | 147 ++ cpp/test/CMakeLists.txt | 15 + cpp/test/neighbors/ann_cagra.cuh | 136 +- cpp/test/neighbors/ann_nn_descent.cuh | 154 ++ .../ann_nn_descent/test_float_uint32_t.cu | 28 + .../ann_nn_descent/test_int8_t_uint32_t.cu | 28 + .../ann_nn_descent/test_uint8_t_uint32_t.cu | 28 + cpp/test/neighbors/ann_utils.cuh | 43 + docs/source/ann_benchmarks_param_tuning.md | 3 +- .../pylibraft/neighbors/cagra/cagra.pyx | 23 +- .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 5 + python/pylibraft/pylibraft/test/test_cagra.py | 12 +- 21 files changed, 2294 insertions(+), 89 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/nn_descent.cuh create mode 100644 cpp/include/raft/neighbors/nn_descent.cuh create mode 100644 cpp/include/raft/neighbors/nn_descent_types.hpp create mode 100644 cpp/test/neighbors/ann_nn_descent.cuh create mode 100644 cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu create mode 100644 cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu create mode 100644 cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu diff --git a/build.sh b/build.sh index 071820ba93..32582738c3 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" diff --git a/ci/build_cpp.sh b/ci/build_cpp.sh index d2d2d08b99..a41f81152d 100755 --- a/ci/build_cpp.sh +++ b/ci/build_cpp.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. set -euo pipefail diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index 7ba381ab0a..a9ff6c2922 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -147,6 +147,13 @@ void parse_build_param(const nlohmann::json& conf, if (conf.contains("intermediate_graph_degree")) { param.intermediate_graph_degree = conf.at("intermediate_graph_degree"); } + if (conf.contains("graph_build_algo")) { + if (conf.at("graph_build_algo") == "IVF_PQ") { + param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ; + } else if (conf.at("graph_build_algo") == "NN_DESCENT") { + param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; + } + } } template diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 1bd7010c83..f96dd34e05 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -35,12 +35,11 @@ namespace raft::neighbors::cagra { */ /** - * @brief Build a kNN graph. + * @brief Build a kNN graph using IVF-PQ. * * The kNN graph is the first building block for CAGRA index. - * This function uses the IVF-PQ method to build a kNN graph. * - * The output is a dense matrix that stores the neighbor indices for each pont in the dataset. + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. * Each point has the same number of neighbors. * * See [cagra::build](#cagra::build) for an alternative method. @@ -52,9 +51,9 @@ namespace raft::neighbors::cagra { * @code{.cpp} * using namespace raft::neighbors; * // use default index parameters - * cagra::index_params build_params; - * cagra::search_params search_params - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * ivf_pq::index_params build_params; + * ivf_pq::search_params search_params + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // create knn graph * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); @@ -70,7 +69,7 @@ namespace raft::neighbors::cagra { * @param[in] res raft resources * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] - * @param[in] refine_rate refinement rate for ivf-pq search + * @param[in] refine_rate (optional) refinement rate for ivf-pq search * @param[in] build_params (optional) ivf_pq index building parameters for knn graph * @param[in] search_params (optional) ivf_pq search parameters */ @@ -95,6 +94,58 @@ void build_knn_graph(raft::resources const& res, res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); } +/** + * @brief Build a kNN graph using NN-descent. + * + * The kNN graph is the first building block for CAGRA index. + * + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. + * Each point has the same number of neighbors. + * + * See [cagra::build](#cagra::build) for an alternative method. + * + * The following distance metrics are supported: + * - L2Expanded + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params build_params; + * build_params.graph_degree = 128; + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * // create knn graph + * cagra::build_knn_graph(res, dataset, knn_graph.view(), build_params); + * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); + * cagra::optimize(res, dataset, nn_descent_index.graph.view(), optimized_graph.view()); + * // Construct an index from dataset and optimized knn_graph + * auto index = cagra::index(res, build_params.metric(), dataset, + * optimized_graph.view()); + * @endcode + * + * @tparam DataT data element type + * @tparam IdxT type of the dataset vector indices + * @tparam accessor host or device accessor_type for the dataset + * @param[in] res raft::resources is an object mangaging resources + * @param[in] dataset input raft::host/device_matrix_view that can be located in + * in host or device memory + * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] + * @param[in] build_params an instance of experimental::nn_descent::index_params that are parameters + * to run the nn-descent algorithm + */ +template , memory_type::device>> +void build_knn_graph(raft::resources const& res, + mdspan, row_major, accessor> dataset, + raft::host_matrix_view knn_graph, + experimental::nn_descent::index_params build_params) +{ + detail::build_knn_graph(res, dataset, knn_graph, build_params); +} + /** * @brief Sort a KNN graph index. * Preprocessing step for `cagra::optimize`: If a KNN graph is not built using @@ -259,7 +310,16 @@ index build(raft::resources const& res, std::optional> knn_graph( raft::make_host_matrix(dataset.extent(0), intermediate_degree)); - build_knn_graph(res, dataset, knn_graph->view()); + if (params.build_algo == graph_build_algo::IVF_PQ) { + build_knn_graph(res, dataset, knn_graph->view()); + + } else { + // Use nn-descent to build CAGRA knn graph + auto nn_descent_params = experimental::nn_descent::index_params(); + nn_descent_params.graph_degree = intermediate_degree; + nn_descent_params.intermediate_graph_degree = 1.5 * intermediate_degree; + build_knn_graph(res, dataset, knn_graph->view(), nn_descent_params); + } auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 4728178194..5061d6082d 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -40,11 +40,24 @@ namespace raft::neighbors::cagra { * @{ */ +/** + * @brief ANN algorithm used by CAGRA to build knn graph + * + */ +enum class graph_build_algo { + /* Use IVF-PQ to build all-neighbors knn graph */ + IVF_PQ, + /* Experimental, use NN-Descent to build all-neighbors knn graph */ + NN_DESCENT +}; + struct index_params : ann::index_params { /** Degree of input graph for pruning. */ size_t intermediate_graph_degree = 128; /** Degree of output graph. */ size_t graph_degree = 64; + /** ANN algorithm to build knn graph. */ + graph_build_algo build_algo = graph_build_algo::IVF_PQ; }; enum class search_algo { @@ -362,6 +375,7 @@ struct index : ann::index { // TODO: Remove deprecated experimental namespace in 23.12 release namespace raft::neighbors::experimental::cagra { +using raft::neighbors::cagra::graph_build_algo; using raft::neighbors::cagra::hash_mode; using raft::neighbors::cagra::index; using raft::neighbors::cagra::index_params; diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 80e964df57..40024a3deb 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -35,6 +35,7 @@ #include #include #include +#include #include namespace raft::neighbors::cagra::detail { @@ -240,4 +241,27 @@ void build_knn_graph(raft::resources const& res, if (!first) RAFT_LOG_DEBUG("# Finished building kNN graph"); } +template +void build_knn_graph(raft::resources const& res, + mdspan, row_major, accessor> dataset, + raft::host_matrix_view knn_graph, + experimental::nn_descent::index_params build_params) +{ + auto nn_descent_idx = experimental::nn_descent::index(res, knn_graph); + experimental::nn_descent::build(res, build_params, dataset, nn_descent_idx); + + using internal_IdxT = typename std::make_unsigned::type; + using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type; + using g_accessor_internal = + host_device_accessor, g_accessor::mem_type>; + + auto knn_graph_internal = + mdspan, row_major, g_accessor_internal>( + reinterpret_cast(nn_descent_idx.graph().data_handle()), + nn_descent_idx.graph().extent(0), + nn_descent_idx.graph().extent(1)); + + graph::sort_knn_graph(res, dataset, knn_graph_internal); +} + } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index 18d451be60..8845e37973 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -244,7 +244,7 @@ void sort_knn_graph(raft::resources const& res, const uint32_t input_graph_degree = knn_graph.extent(1); IdxT* const input_graph_ptr = knn_graph.data_handle(); - auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); + auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); // // Sorting kNN graph diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh new file mode 100644 index 0000000000..3e4d0409bd --- /dev/null +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -0,0 +1,1453 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "../nn_descent_types.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include // raft::util::arch::SM_* +#include +#include +#include +#include + +namespace raft::neighbors::experimental::nn_descent::detail { + +using pinned_memory_resource = thrust::universal_host_pinned_memory_resource; +template +using pinned_memory_allocator = thrust::mr::stateless_resource_allocator; + +using DistData_t = float; +constexpr int DEGREE_ON_DEVICE{32}; +constexpr int SEGMENT_SIZE{32}; +constexpr int counter_interval{100}; +template +struct InternalID_t; + +// InternalID_t uses 1 bit for marking (new or old). +template <> +class InternalID_t { + private: + using Index_t = int; + Index_t id_{std::numeric_limits::max()}; + + public: + __host__ __device__ bool is_new() const { return id_ >= 0; } + __host__ __device__ Index_t& id_with_flag() { return id_; } + __host__ __device__ Index_t id() const + { + if (is_new()) return id_; + return -id_ - 1; + } + __host__ __device__ void mark_old() + { + if (id_ >= 0) id_ = -id_ - 1; + } + __host__ __device__ bool operator==(const InternalID_t& other) const + { + return id() == other.id(); + } +}; + +template +struct ResultItem; + +template <> +class ResultItem { + private: + using Index_t = int; + Index_t id_; + DistData_t dist_; + + public: + __host__ __device__ ResultItem() + : id_(std::numeric_limits::max()), dist_(std::numeric_limits::max()){}; + __host__ __device__ ResultItem(const Index_t id_with_flag, const DistData_t dist) + : id_(id_with_flag), dist_(dist){}; + __host__ __device__ bool is_new() const { return id_ >= 0; } + __host__ __device__ Index_t& id_with_flag() { return id_; } + __host__ __device__ Index_t id() const + { + if (is_new()) return id_; + return -id_ - 1; + } + __host__ __device__ DistData_t& dist() { return dist_; } + + __host__ __device__ void mark_old() + { + if (id_ >= 0) id_ = -id_ - 1; + } + + __host__ __device__ bool operator<(const ResultItem& other) const + { + if (dist_ == other.dist_) return id() < other.id(); + return dist_ < other.dist_; + } + __host__ __device__ bool operator==(const ResultItem& other) const + { + return id() == other.id(); + } + __host__ __device__ bool operator>=(const ResultItem& other) const + { + return !(*this < other); + } + __host__ __device__ bool operator<=(const ResultItem& other) const + { + return (*this == other) || (*this < other); + } + __host__ __device__ bool operator>(const ResultItem& other) const + { + return !(*this <= other); + } + __host__ __device__ bool operator!=(const ResultItem& other) const + { + return !(*this == other); + } +}; + +using align32 = raft::Pow2<32>; + +template +int get_batch_size(const int it_now, const T nrow, const int batch_size) +{ + int it_total = ceildiv(nrow, batch_size); + return (it_now == it_total - 1) ? nrow - it_now * batch_size : batch_size; +} + +// for avoiding bank conflict +template +constexpr __host__ __device__ __forceinline__ int skew_dim(int ndim) +{ + // all "4"s are for alignment + if constexpr (std::is_same::value) { + ndim = ceildiv(ndim, 4) * 4; + return ndim + (ndim % 32 == 0) * 4; + } +} + +template +__device__ __forceinline__ ResultItem xor_swap(ResultItem x, int mask, int dir) +{ + ResultItem y; + y.dist() = __shfl_xor_sync(raft::warp_full_mask(), x.dist(), mask, raft::warp_size()); + y.id_with_flag() = + __shfl_xor_sync(raft::warp_full_mask(), x.id_with_flag(), mask, raft::warp_size()); + return x < y == dir ? y : x; +} + +__device__ __forceinline__ int xor_swap(int x, int mask, int dir) +{ + int y = __shfl_xor_sync(raft::warp_full_mask(), x, mask, raft::warp_size()); + return x < y == dir ? y : x; +} + +// TODO: Move to RAFT utils https://github.com/rapidsai/raft/issues/1827 +__device__ __forceinline__ uint bfe(uint lane_id, uint pos) +{ + uint res; + asm("bfe.u32 %0,%1,%2,%3;" : "=r"(res) : "r"(lane_id), "r"(pos), "r"(1)); + return res; +} + +template +__device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane_id) +{ + static_assert(raft::warp_size() == 32); + auto& element = *element_ptr; + element = xor_swap(element, 0x01, bfe(lane_id, 1) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x02, bfe(lane_id, 2) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 2) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x04, bfe(lane_id, 3) ^ bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 3) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 3) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x08, bfe(lane_id, 4) ^ bfe(lane_id, 3)); + element = xor_swap(element, 0x04, bfe(lane_id, 4) ^ bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 4) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 4) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x10, bfe(lane_id, 4)); + element = xor_swap(element, 0x08, bfe(lane_id, 3)); + element = xor_swap(element, 0x04, bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 0)); + return; +} + +struct BuildConfig { + size_t max_dataset_size; + size_t dataset_dim; + size_t node_degree{64}; + size_t internal_node_degree{0}; + // If internal_node_degree == 0, the value of node_degree will be assigned to it + size_t max_iterations{50}; + float termination_threshold{0.0001}; +}; + +template +class BloomFilter { + public: + BloomFilter(size_t nrow, size_t num_sets_per_list, size_t num_hashs) + : nrow_(nrow), + num_sets_per_list_(num_sets_per_list), + num_hashs_(num_hashs), + bitsets_(nrow * num_bits_per_set_ * num_sets_per_list) + { + } + + void add(size_t list_id, Index_t key) + { + if (is_cleared) { is_cleared = false; } + uint32_t hash = hash_0(key); + size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + + key % num_sets_per_list_ * num_bits_per_set_; + bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; + for (size_t i = 1; i < num_hashs_; i++) { + hash = hash + hash_1(key); + bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; + } + } + + bool check(size_t list_id, Index_t key) + { + bool is_present = true; + uint32_t hash = hash_0(key); + size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + + key % num_sets_per_list_ * num_bits_per_set_; + is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; + + if (!is_present) return false; + for (size_t i = 1; i < num_hashs_; i++) { + hash = hash + hash_1(key); + is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; + if (!is_present) return false; + } + return true; + } + + void clear() + { + if (is_cleared) return; +#pragma omp parallel for + for (size_t i = 0; i < nrow_ * num_bits_per_set_ * num_sets_per_list_; i++) { + bitsets_[i] = 0; + } + is_cleared = true; + } + + private: + uint32_t hash_0(uint32_t value) + { + value *= 1103515245; + value += 12345; + value ^= value << 13; + value ^= value >> 17; + value ^= value << 5; + return value; + } + + uint32_t hash_1(uint32_t value) + { + value *= 1664525; + value += 1013904223; + value ^= value << 13; + value ^= value >> 17; + value ^= value << 5; + return value; + } + + static constexpr int num_bits_per_set_ = 512; + bool is_cleared{true}; + std::vector bitsets_; + size_t nrow_; + size_t num_sets_per_list_; + size_t num_hashs_; +}; + +template +struct GnndGraph { + static constexpr int segment_size = 32; + InternalID_t* h_graph; + + size_t nrow; + size_t node_degree; + int num_samples; + int num_segments; + + raft::host_matrix h_dists; + + thrust::host_vector> h_graph_new; + thrust::host_vector> h_list_sizes_new; + + thrust::host_vector> h_graph_old; + thrust::host_vector> h_list_sizes_old; + BloomFilter bloom_filter; + + GnndGraph(const GnndGraph&) = delete; + GnndGraph& operator=(const GnndGraph&) = delete; + GnndGraph(const size_t nrow, + const size_t node_degree, + const size_t internal_node_degree, + const size_t num_samples); + void init_random_graph(); + // TODO: Create a generic bloom filter utility https://github.com/rapidsai/raft/issues/1827 + // Use Bloom filter to sample "new" neighbors for local joining + void sample_graph_new(InternalID_t* new_neighbors, const size_t width); + void sample_graph(bool sample_new); + void update_graph(const InternalID_t* new_neighbors, + const DistData_t* new_dists, + const size_t width, + std::atomic& update_counter); + void sort_lists(); + void clear(); + ~GnndGraph(); +}; + +template +class GNND { + public: + GNND(raft::resources const& res, const BuildConfig& build_config); + GNND(const GNND&) = delete; + GNND& operator=(const GNND&) = delete; + + void build(Data_t* data, const Index_t nrow, Index_t* output_graph); + ~GNND() = default; + using ID_t = InternalID_t; + + private: + void add_reverse_edges(Index_t* graph_ptr, + Index_t* h_rev_graph_ptr, + Index_t* d_rev_graph_ptr, + int2* list_sizes, + cudaStream_t stream = 0); + void local_join(cudaStream_t stream = 0); + + raft::resources const& res; + + BuildConfig build_config_; + GnndGraph graph_; + std::atomic update_counter_; + + Index_t nrow_; + const int ndim_; + + raft::device_matrix<__half, Index_t, raft::row_major> d_data_; + raft::device_vector l2_norms_; + + raft::device_matrix graph_buffer_; + raft::device_matrix dists_buffer_; + + // TODO: Investigate using RMM/RAFT types https://github.com/rapidsai/raft/issues/1827 + thrust::host_vector> graph_host_buffer_; + thrust::host_vector> dists_host_buffer_; + + raft::device_vector d_locks_; + + thrust::host_vector> h_rev_graph_new_; + thrust::host_vector> h_graph_old_; + thrust::host_vector> h_rev_graph_old_; + // int2.x is the number of forward edges, int2.y is the number of reverse edges + + raft::device_vector d_list_sizes_new_; + raft::device_vector d_list_sizes_old_; +}; + +constexpr int TILE_ROW_WIDTH = 64; +constexpr int TILE_COL_WIDTH = 128; + +constexpr int NUM_SAMPLES = 32; +// For now, the max. number of samples is 32, so the sample cache size is fixed +// to 64 (32 * 2). +constexpr int MAX_NUM_BI_SAMPLES = 64; +constexpr int SKEWED_MAX_NUM_BI_SAMPLES = skew_dim(MAX_NUM_BI_SAMPLES); +constexpr int BLOCK_SIZE = 512; +constexpr int WMMA_M = 16; +constexpr int WMMA_N = 16; +constexpr int WMMA_K = 16; + +template +__device__ __forceinline__ void load_vec(Data_t* vec_buffer, + const Data_t* d_vec, + const int load_dims, + const int padding_dims, + const int lane_id) +{ + if constexpr (std::is_same_v or std::is_same_v or + std::is_same_v) { + constexpr int num_load_elems_per_warp = raft::warp_size(); + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { + int idx = step * num_load_elems_per_warp + lane_id; + if (idx < load_dims) { + vec_buffer[idx] = d_vec[idx]; + } else if (idx < padding_dims) { + vec_buffer[idx] = 0.0f; + } + } + } + if constexpr (std::is_same_v) { + if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 && + load_dims % 4 == 0 && padding_dims % 4 == 0) { + constexpr int num_load_elems_per_warp = raft::warp_size() * 4; +#pragma unroll + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { + int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; + if (idx_in_vec + 4 <= load_dims) { + *(float2*)(vec_buffer + idx_in_vec) = *(float2*)(d_vec + idx_in_vec); + } else if (idx_in_vec + 4 <= padding_dims) { + *(float2*)(vec_buffer + idx_in_vec) = float2({0.0f, 0.0f}); + } + } + } else { + constexpr int num_load_elems_per_warp = raft::warp_size(); + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { + int idx = step * num_load_elems_per_warp + lane_id; + if (idx < load_dims) { + vec_buffer[idx] = d_vec[idx]; + } else if (idx < padding_dims) { + vec_buffer[idx] = 0.0f; + } + } + } + } +} + +// TODO: Replace with RAFT utilities https://github.com/rapidsai/raft/issues/1827 +/** Calculate L2 norm, and cast data to __half */ +template +__global__ void preprocess_data_kernel(const Data_t* input_data, + __half* output_data, + int dim, + DistData_t* l2_norms, + size_t list_offset = 0) +{ + extern __shared__ char buffer[]; + __shared__ float l2_norm; + Data_t* s_vec = (Data_t*)buffer; + size_t list_id = list_offset + blockIdx.x; + + load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % raft::warp_size()); + if (threadIdx.x == 0) { l2_norm = 0; } + __syncthreads(); + int lane_id = threadIdx.x % raft::warp_size(); + for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { + int idx = step * raft::warp_size() + lane_id; + float part_dist = 0; + if (idx < dim) { + part_dist = s_vec[idx]; + part_dist = part_dist * part_dist; + } + __syncwarp(); + for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { + part_dist += __shfl_down_sync(raft::warp_full_mask(), part_dist, offset); + } + if (lane_id == 0) { l2_norm += part_dist; } + __syncwarp(); + } + + for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { + int idx = step * raft::warp_size() + threadIdx.x; + if (idx < dim) { + if (l2_norms == nullptr) { + output_data[list_id * dim + idx] = + (float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm); + } else { + output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; + if (idx == 0) { l2_norms[list_id] = l2_norm; } + } + } + } +} + +template +__global__ void add_rev_edges_kernel(const Index_t* graph, + Index_t* rev_graph, + int num_samples, + int2* list_sizes) +{ + size_t list_id = blockIdx.x; + int2 list_size = list_sizes[list_id]; + + for (int idx = threadIdx.x; idx < list_size.x; idx += blockDim.x) { + // each node has same number (num_samples) of forward and reverse edges + size_t rev_list_id = graph[list_id * num_samples + idx]; + // there are already num_samples forward edges + int idx_in_rev_list = atomicAdd(&list_sizes[rev_list_id].y, 1); + if (idx_in_rev_list >= num_samples) { + atomicExch(&list_sizes[rev_list_id].y, num_samples); + } else { + rev_graph[rev_list_id * num_samples + idx_in_rev_list] = list_id; + } + } +} + +template > +__device__ void insert_to_global_graph(ResultItem elem, + size_t list_id, + ID_t* graph, + DistData_t* dists, + int node_degree, + int* locks) +{ + int tx = threadIdx.x; + int lane_id = tx % raft::warp_size(); + size_t global_idx_base = list_id * node_degree; + if (elem.id() == list_id) return; + + const int num_segments = ceildiv(node_degree, raft::warp_size()); + + int loop_flag = 0; + do { + int segment_id = elem.id() % num_segments; + if (lane_id == 0) { + loop_flag = atomicCAS(&locks[list_id * num_segments + segment_id], 0, 1) == 0; + } + + loop_flag = __shfl_sync(raft::warp_full_mask(), loop_flag, 0); + + if (loop_flag == 1) { + ResultItem knn_list_frag; + int local_idx = segment_id * raft::warp_size() + lane_id; + size_t global_idx = global_idx_base + local_idx; + if (local_idx < node_degree) { + knn_list_frag.id_with_flag() = graph[global_idx].id_with_flag(); + knn_list_frag.dist() = dists[global_idx]; + } + + int pos_to_insert = -1; + ResultItem prev_elem; + + prev_elem.id_with_flag() = + __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.id_with_flag(), 1); + prev_elem.dist() = __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.dist(), 1); + + if (lane_id == 0) { + prev_elem = ResultItem{std::numeric_limits::min(), + std::numeric_limits::lowest()}; + } + if (elem > prev_elem && elem < knn_list_frag) { + pos_to_insert = segment_id * raft::warp_size() + lane_id; + } else if (elem == prev_elem || elem == knn_list_frag) { + pos_to_insert = -2; + } + uint mask = __ballot_sync(raft::warp_full_mask(), pos_to_insert >= 0); + if (mask) { + uint set_lane_id = __fns(mask, 0, 1); + pos_to_insert = __shfl_sync(raft::warp_full_mask(), pos_to_insert, set_lane_id); + } + + if (pos_to_insert >= 0) { + int local_idx = segment_id * raft::warp_size() + lane_id; + if (local_idx > pos_to_insert) { + local_idx++; + } else if (local_idx == pos_to_insert) { + graph[global_idx_base + local_idx].id_with_flag() = elem.id_with_flag(); + dists[global_idx_base + local_idx] = elem.dist(); + local_idx++; + } + size_t global_pos = global_idx_base + local_idx; + if (local_idx < (segment_id + 1) * raft::warp_size() && local_idx < node_degree) { + graph[global_pos].id_with_flag() = knn_list_frag.id_with_flag(); + dists[global_pos] = knn_list_frag.dist(); + } + } + __threadfence(); + if (loop_flag && lane_id == 0) { atomicExch(&locks[list_id * num_segments + segment_id], 0); } + } + } while (!loop_flag); +} + +template +__device__ ResultItem get_min_item(const Index_t id, + const int idx_in_list, + const Index_t* neighbs, + const DistData_t* distances, + const bool find_in_row = true) +{ + int lane_id = threadIdx.x % raft::warp_size(); + + static_assert(MAX_NUM_BI_SAMPLES == 64); + int idx[MAX_NUM_BI_SAMPLES / raft::warp_size()]; + float dist[MAX_NUM_BI_SAMPLES / raft::warp_size()] = {std::numeric_limits::max(), + std::numeric_limits::max()}; + idx[0] = lane_id; + idx[1] = raft::warp_size() + lane_id; + + if (neighbs[idx[0]] != id) { + dist[0] = find_in_row ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + lane_id] + : distances[idx_in_list + lane_id * SKEWED_MAX_NUM_BI_SAMPLES]; + } + + if (neighbs[idx[1]] != id) { + dist[1] = + find_in_row + ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + raft::warp_size() + lane_id] + : distances[idx_in_list + (raft::warp_size() + lane_id) * SKEWED_MAX_NUM_BI_SAMPLES]; + } + + if (dist[1] < dist[0]) { + dist[0] = dist[1]; + idx[0] = idx[1]; + } + __syncwarp(); + for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { + float other_idx = __shfl_down_sync(raft::warp_full_mask(), idx[0], offset); + float other_dist = __shfl_down_sync(raft::warp_full_mask(), dist[0], offset); + if (other_dist < dist[0]) { + dist[0] = other_dist; + idx[0] = other_idx; + } + } + + ResultItem result; + result.dist() = __shfl_sync(raft::warp_full_mask(), dist[0], 0); + result.id_with_flag() = neighbs[__shfl_sync(raft::warp_full_mask(), idx[0], 0)]; + return result; +} + +template +__device__ __forceinline__ void remove_duplicates( + T* list_a, int list_a_size, T* list_b, int list_b_size, int& unique_counter, int execute_warp_id) +{ + static_assert(raft::warp_size() == 32); + if (!(threadIdx.x >= execute_warp_id * raft::warp_size() && + threadIdx.x < execute_warp_id * raft::warp_size() + raft::warp_size())) { + return; + } + int lane_id = threadIdx.x % raft::warp_size(); + T elem = std::numeric_limits::max(); + if (lane_id < list_a_size) { elem = list_a[lane_id]; } + warp_bitonic_sort(&elem, lane_id); + + if (elem != std::numeric_limits::max()) { list_a[lane_id] = elem; } + + T elem_b = std::numeric_limits::max(); + + if (lane_id < list_b_size) { elem_b = list_b[lane_id]; } + __syncwarp(); + + int idx_l = 0; + int idx_r = list_a_size; + bool existed = false; + while (idx_l < idx_r) { + int idx = (idx_l + idx_r) / 2; + int elem = list_a[idx]; + if (elem == elem_b) { + existed = true; + break; + } + if (elem_b > elem) { + idx_l = idx + 1; + } else { + idx_r = idx; + } + } + if (!existed && elem_b != std::numeric_limits::max()) { + int idx = atomicAdd(&unique_counter, 1); + list_a[list_a_size + idx] = elem_b; + } +} + +// launch_bounds here denote BLOCK_SIZE = 512 and MIN_BLOCKS_PER_SM = 4 +// Per +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications, +// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048 +// For architectures 750 and 860, the values for MAX_RESIDENT_THREAD_PER_SM +// is 1024 and 1536 respectively, which means the bounds don't work anymore +template > +__global__ void +#ifdef __CUDA_ARCH__ +#if (__CUDA_ARCH__) == 750 || (__CUDA_ARCH__) == 860 +__launch_bounds__(BLOCK_SIZE) +#else +__launch_bounds__(BLOCK_SIZE, 4) +#endif +#endif + local_join_kernel(const Index_t* graph_new, + const Index_t* rev_graph_new, + const int2* sizes_new, + const Index_t* graph_old, + const Index_t* rev_graph_old, + const int2* sizes_old, + const int width, + const __half* data, + const int data_dim, + ID_t* graph, + DistData_t* dists, + int graph_width, + int* locks, + DistData_t* l2_norms) +{ +#if (__CUDA_ARCH__ >= 700) + using namespace nvcuda; + __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2]; + + constexpr int APAD = 8; + constexpr int BPAD = 8; + __shared__ __half s_nv[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + APAD]; // New vectors + __shared__ __half s_ov[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + BPAD]; // Old vectors + static_assert(sizeof(float) * MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES <= + sizeof(__half) * MAX_NUM_BI_SAMPLES * (TILE_COL_WIDTH + BPAD)); + // s_distances: MAX_NUM_BI_SAMPLES x SKEWED_MAX_NUM_BI_SAMPLES, reuse the space of s_ov + float* s_distances = (float*)&s_ov[0][0]; + int* s_unique_counter = (int*)&s_ov[0][0]; + + if (threadIdx.x == 0) { + s_unique_counter[0] = 0; + s_unique_counter[1] = 0; + } + + Index_t* new_neighbors = s_list; + Index_t* old_neighbors = s_list + MAX_NUM_BI_SAMPLES; + + size_t list_id = blockIdx.x; + int2 list_new_size2 = sizes_new[list_id]; + int list_new_size = list_new_size2.x + list_new_size2.y; + int2 list_old_size2 = sizes_old[list_id]; + int list_old_size = list_old_size2.x + list_old_size2.y; + + if (!list_new_size) return; + int tx = threadIdx.x; + + if (tx < list_new_size2.x) { + new_neighbors[tx] = graph_new[list_id * width + tx]; + } else if (tx >= list_new_size2.x && tx < list_new_size) { + new_neighbors[tx] = rev_graph_new[list_id * width + tx - list_new_size2.x]; + } + + if (tx < list_old_size2.x) { + old_neighbors[tx] = graph_old[list_id * width + tx]; + } else if (tx >= list_old_size2.x && tx < list_old_size) { + old_neighbors[tx] = rev_graph_old[list_id * width + tx - list_old_size2.x]; + } + + __syncthreads(); + + remove_duplicates(new_neighbors, + list_new_size2.x, + new_neighbors + list_new_size2.x, + list_new_size2.y, + s_unique_counter[0], + 0); + + remove_duplicates(old_neighbors, + list_old_size2.x, + old_neighbors + list_old_size2.x, + list_old_size2.y, + s_unique_counter[1], + 1); + __syncthreads(); + list_new_size = list_new_size2.x + s_unique_counter[0]; + list_old_size = list_old_size2.x + s_unique_counter[1]; + + int warp_id = threadIdx.x / raft::warp_size(); + int lane_id = threadIdx.x % raft::warp_size(); + constexpr int num_warps = BLOCK_SIZE / raft::warp_size(); + + int warp_id_y = warp_id / 4; + int warp_id_x = warp_id % 4; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0); + for (int step = 0; step < ceildiv(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == ceildiv(data_dim, TILE_COL_WIDTH) - 1) + ? data_dim - step * TILE_COL_WIDTH + : TILE_COL_WIDTH; +#pragma unroll + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_new_size) { + size_t neighbor_id = new_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_nv[idx], + data + idx_in_data + step * TILE_COL_WIDTH, + num_load_elems, + TILE_COL_WIDTH, + lane_id); + } + } + __syncthreads(); + + for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { + wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); + wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); + } + } + + wmma::store_matrix_sync( + s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, + c_frag, + SKEWED_MAX_NUM_BI_SAMPLES, + wmma::mem_row_major); + __syncthreads(); + + for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { + if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size && + i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + if (l2_norms == nullptr) { + s_distances[i] = -s_distances[i]; + } else { + s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + + l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - + 2.0 * s_distances[i]; + } + } else { + s_distances[i] = std::numeric_limits::max(); + } + } + __syncthreads(); + + for (int step = 0; step < ceildiv(list_new_size, num_warps); step++) { + int idx_in_list = step * num_warps + tx / raft::warp_size(); + if (idx_in_list >= list_new_size) continue; + auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances); + if (min_elem.id() < gridDim.x) { + insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); + } + } + + if (!list_old_size) return; + + __syncthreads(); + + wmma::fill_fragment(c_frag, 0.0); + for (int step = 0; step < ceildiv(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == ceildiv(data_dim, TILE_COL_WIDTH) - 1) + ? data_dim - step * TILE_COL_WIDTH + : TILE_COL_WIDTH; + if (TILE_COL_WIDTH < data_dim) { +#pragma unroll + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_new_size) { + size_t neighbor_id = new_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_nv[idx], + data + idx_in_data + step * TILE_COL_WIDTH, + num_load_elems, + TILE_COL_WIDTH, + lane_id); + } + } + } +#pragma unroll + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_old_size) { + size_t neighbor_id = old_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_ov[idx], + data + idx_in_data + step * TILE_COL_WIDTH, + num_load_elems, + TILE_COL_WIDTH, + lane_id); + } + } + __syncthreads(); + + for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { + wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); + wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); + } + } + + wmma::store_matrix_sync( + s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, + c_frag, + SKEWED_MAX_NUM_BI_SAMPLES, + wmma::mem_row_major); + __syncthreads(); + + for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { + if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size && + i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + if (l2_norms == nullptr) { + s_distances[i] = -s_distances[i]; + } else { + s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + + l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - + 2.0 * s_distances[i]; + } + } else { + s_distances[i] = std::numeric_limits::max(); + } + } + __syncthreads(); + + for (int step = 0; step < ceildiv(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) { + int idx_in_list = step * num_warps + tx / raft::warp_size(); + if (idx_in_list >= list_new_size && idx_in_list < MAX_NUM_BI_SAMPLES) continue; + if (idx_in_list >= MAX_NUM_BI_SAMPLES + list_old_size && idx_in_list < MAX_NUM_BI_SAMPLES * 2) + continue; + ResultItem min_elem{std::numeric_limits::max(), + std::numeric_limits::max()}; + if (idx_in_list < MAX_NUM_BI_SAMPLES) { + auto temp_min_item = + get_min_item(s_list[idx_in_list], idx_in_list, old_neighbors, s_distances); + if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } + } else { + auto temp_min_item = get_min_item( + s_list[idx_in_list], idx_in_list - MAX_NUM_BI_SAMPLES, new_neighbors, s_distances, false); + if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } + } + + if (min_elem.id() < gridDim.x) { + insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); + } + } +#endif +} + +namespace { +template +int insert_to_ordered_list(InternalID_t* list, + DistData_t* dist_list, + const int width, + const InternalID_t neighb_id, + const DistData_t dist) +{ + if (dist > dist_list[width - 1]) { return width; } + + int idx_insert = width; + bool position_found = false; + for (int i = 0; i < width; i++) { + if (list[i].id() == neighb_id.id()) { return width; } + if (!position_found && dist_list[i] > dist) { + idx_insert = i; + position_found = true; + } + } + if (idx_insert == width) return idx_insert; + + memmove(list + idx_insert + 1, list + idx_insert, sizeof(*list) * (width - idx_insert - 1)); + memmove(dist_list + idx_insert + 1, + dist_list + idx_insert, + sizeof(*dist_list) * (width - idx_insert - 1)); + + list[idx_insert] = neighb_id; + dist_list[idx_insert] = dist; + return idx_insert; +}; + +} // namespace + +template +GnndGraph::GnndGraph(const size_t nrow, + const size_t node_degree, + const size_t internal_node_degree, + const size_t num_samples) + : nrow(nrow), + node_degree(node_degree), + num_samples(num_samples), + bloom_filter(nrow, internal_node_degree / segment_size, 3), + h_dists{raft::make_host_matrix(nrow, node_degree)}, + h_graph_new{nrow * num_samples}, + h_list_sizes_new{nrow}, + h_graph_old{nrow * num_samples}, + h_list_sizes_old{nrow} +{ + // node_degree must be a multiple of segment_size; + assert(node_degree % segment_size == 0); + assert(internal_node_degree % segment_size == 0); + + num_segments = node_degree / segment_size; + // To save the CPU memory, graph should be allocated by external function + h_graph = nullptr; +} + +// This is the only operation on the CPU that cannot be overlapped. +// So it should be as fast as possible. +template +void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, const size_t width) +{ +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + auto list_new = h_graph_new.data() + i * num_samples; + h_list_sizes_new[i].x = 0; + h_list_sizes_new[i].y = 0; + + for (size_t j = 0; j < width; j++) { + auto new_neighb_id = new_neighbors[i * width + j].id(); + if ((size_t)new_neighb_id >= nrow) break; + if (bloom_filter.check(i, new_neighb_id)) { continue; } + bloom_filter.add(i, new_neighb_id); + new_neighbors[i * width + j].mark_old(); + list_new[h_list_sizes_new[i].x++] = new_neighb_id; + if (h_list_sizes_new[i].x == num_samples) break; + } + } +} + +template +void GnndGraph::init_random_graph() +{ + for (size_t seg_idx = 0; seg_idx < static_cast(num_segments); seg_idx++) { + // random sequence (range: 0~nrow) + // segment_x stores neighbors which id % num_segments == x + std::vector rand_seq(nrow / num_segments); + std::iota(rand_seq.begin(), rand_seq.end(), 0); + std::random_shuffle(rand_seq.begin(), rand_seq.end()); + +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + size_t base_idx = i * node_degree + seg_idx * segment_size; + auto h_neighbor_list = h_graph + base_idx; + auto h_dist_list = h_dists.data_handle() + base_idx; + for (size_t j = 0; j < static_cast(segment_size); j++) { + size_t idx = base_idx + j; + Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; + if ((size_t)id == i) { + id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx; + } + h_neighbor_list[j].id_with_flag() = id; + h_dist_list[j] = std::numeric_limits::max(); + } + } + } +} + +template +void GnndGraph::sample_graph(bool sample_new) +{ +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + h_list_sizes_old[i].x = 0; + h_list_sizes_old[i].y = 0; + h_list_sizes_new[i].x = 0; + h_list_sizes_new[i].y = 0; + + auto list = h_graph + i * node_degree; + auto list_old = h_graph_old.data() + i * num_samples; + auto list_new = h_graph_new.data() + i * num_samples; + for (int j = 0; j < segment_size; j++) { + for (int k = 0; k < num_segments; k++) { + auto neighbor = list[k * segment_size + j]; + if ((size_t)neighbor.id() >= nrow) continue; + if (!neighbor.is_new()) { + if (h_list_sizes_old[i].x < num_samples) { + list_old[h_list_sizes_old[i].x++] = neighbor.id(); + } + } else if (sample_new) { + if (h_list_sizes_new[i].x < num_samples) { + list[k * segment_size + j].mark_old(); + list_new[h_list_sizes_new[i].x++] = neighbor.id(); + } + } + if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } + } + if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } + } + } +} + +template +void GnndGraph::update_graph(const InternalID_t* new_neighbors, + const DistData_t* new_dists, + const size_t width, + std::atomic& update_counter) +{ +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + for (size_t j = 0; j < width; j++) { + auto new_neighb_id = new_neighbors[i * width + j]; + auto new_dist = new_dists[i * width + j]; + if (new_dist == std::numeric_limits::max()) break; + if ((size_t)new_neighb_id.id() == i) continue; + int seg_idx = new_neighb_id.id() % num_segments; + auto list = h_graph + i * node_degree + seg_idx * segment_size; + auto dist_list = h_dists.data_handle() + i * node_degree + seg_idx * segment_size; + int insert_pos = + insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist); + if (i % counter_interval == 0 && insert_pos != segment_size) { update_counter++; } + } + } +} + +template +void GnndGraph::sort_lists() +{ +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + std::vector> new_list; + for (size_t j = 0; j < node_degree; j++) { + new_list.emplace_back(h_dists.data_handle()[i * node_degree + j], + h_graph[i * node_degree + j].id()); + } + std::sort(new_list.begin(), new_list.end()); + for (size_t j = 0; j < node_degree; j++) { + h_graph[i * node_degree + j].id_with_flag() = new_list[j].second; + h_dists.data_handle()[i * node_degree + j] = new_list[j].first; + } + } +} + +template +void GnndGraph::clear() +{ + bloom_filter.clear(); +} + +template +GnndGraph::~GnndGraph() +{ + assert(h_graph == nullptr); +} + +template +GNND::GNND(raft::resources const& res, const BuildConfig& build_config) + : res(res), + build_config_(build_config), + graph_(build_config.max_dataset_size, + align32::roundUp(build_config.node_degree), + align32::roundUp(build_config.internal_node_degree ? build_config.internal_node_degree + : build_config.node_degree), + NUM_SAMPLES), + nrow_(build_config.max_dataset_size), + ndim_(build_config.dataset_dim), + d_data_{raft::make_device_matrix<__half, Index_t, raft::row_major>( + res, nrow_, build_config.dataset_dim)}, + l2_norms_{raft::make_device_vector(res, nrow_)}, + graph_buffer_{ + raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, + dists_buffer_{ + raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, + graph_host_buffer_{static_cast(nrow_ * DEGREE_ON_DEVICE)}, + dists_host_buffer_{static_cast(nrow_ * DEGREE_ON_DEVICE)}, + d_locks_{raft::make_device_vector(res, nrow_)}, + h_rev_graph_new_{static_cast(nrow_ * NUM_SAMPLES)}, + h_graph_old_{static_cast(nrow_ * NUM_SAMPLES)}, + h_rev_graph_old_{static_cast(nrow_ * NUM_SAMPLES)}, + d_list_sizes_new_{raft::make_device_vector(res, nrow_)}, + d_list_sizes_old_{raft::make_device_vector(res, nrow_)} +{ + static_assert(NUM_SAMPLES <= 32); + + thrust::fill(thrust::device, + dists_buffer_.data_handle(), + dists_buffer_.data_handle() + dists_buffer_.size(), + std::numeric_limits::max()); + thrust::fill(thrust::device, + reinterpret_cast(graph_buffer_.data_handle()), + reinterpret_cast(graph_buffer_.data_handle()) + graph_buffer_.size(), + std::numeric_limits::max()); + thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); +}; + +template +void GNND::add_reverse_edges(Index_t* graph_ptr, + Index_t* h_rev_graph_ptr, + Index_t* d_rev_graph_ptr, + int2* list_sizes, + cudaStream_t stream) +{ + add_rev_edges_kernel<<>>( + graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes); + raft::copy( + h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, raft::resource::get_cuda_stream(res)); +} + +template +void GNND::local_join(cudaStream_t stream) +{ + thrust::fill(thrust::device.on(stream), + dists_buffer_.data_handle(), + dists_buffer_.data_handle() + dists_buffer_.size(), + std::numeric_limits::max()); + local_join_kernel<<>>( + thrust::raw_pointer_cast(graph_.h_graph_new.data()), + thrust::raw_pointer_cast(h_rev_graph_new_.data()), + d_list_sizes_new_.data_handle(), + thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(h_rev_graph_old_.data()), + d_list_sizes_old_.data_handle(), + NUM_SAMPLES, + d_data_.data_handle(), + ndim_, + graph_buffer_.data_handle(), + dists_buffer_.data_handle(), + DEGREE_ON_DEVICE, + d_locks_.data_handle(), + l2_norms_.data_handle()); +} + +template +void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph) +{ + using input_t = typename std::remove_const::type; + + cudaStream_t stream = raft::resource::get_cuda_stream(res); + nrow_ = nrow; + graph_.h_graph = (InternalID_t*)output_graph; + + cudaPointerAttributes data_ptr_attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); + size_t batch_size = (data_ptr_attr.devicePointer == nullptr) ? 100000 : nrow_; + + raft::spatial::knn::detail::utils::batch_load_iterator vec_batches{ + data, static_cast(nrow_), build_config_.dataset_dim, batch_size, stream}; + for (auto const& batch : vec_batches) { + preprocess_data_kernel<<< + batch.size(), + raft::warp_size(), + sizeof(Data_t) * ceildiv(build_config_.dataset_dim, static_cast(raft::warp_size())) * + raft::warp_size(), + stream>>>(batch.data(), + d_data_.data_handle(), + build_config_.dataset_dim, + l2_norms_.data_handle(), + batch.offset()); + } + + thrust::fill(thrust::device.on(stream), + (Index_t*)graph_buffer_.data_handle(), + (Index_t*)graph_buffer_.data_handle() + graph_buffer_.size(), + std::numeric_limits::max()); + + graph_.clear(); + graph_.init_random_graph(); + graph_.sample_graph(true); + + auto update_and_sample = [&](bool update_graph) { + if (update_graph) { + update_counter_ = 0; + graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), + thrust::raw_pointer_cast(dists_host_buffer_.data()), + DEGREE_ON_DEVICE, + update_counter_); + if (update_counter_ < build_config_.termination_threshold * nrow_ * + build_config_.dataset_dim / counter_interval) { + update_counter_ = -1; + } + } + graph_.sample_graph(false); + }; + + for (size_t it = 0; it < build_config_.max_iterations; it++) { + raft::copy(d_list_sizes_new_.data_handle(), + thrust::raw_pointer_cast(graph_.h_list_sizes_new.data()), + nrow_, + raft::resource::get_cuda_stream(res)); + raft::copy(thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(graph_.h_graph_old.data()), + nrow_ * NUM_SAMPLES, + raft::resource::get_cuda_stream(res)); + raft::copy(d_list_sizes_old_.data_handle(), + thrust::raw_pointer_cast(graph_.h_list_sizes_old.data()), + nrow_, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + + std::thread update_and_sample_thread(update_and_sample, it); + + std::cout << "# GNND iteraton: " << it + 1 << "/" << build_config_.max_iterations << "\r"; + std::fflush(stdout); + + // Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it + // contains some information for local_join. + static_assert(DEGREE_ON_DEVICE * sizeof(*(dists_buffer_.data_handle())) >= + NUM_SAMPLES * sizeof(*(graph_buffer_.data_handle()))); + add_reverse_edges(thrust::raw_pointer_cast(graph_.h_graph_new.data()), + thrust::raw_pointer_cast(h_rev_graph_new_.data()), + (Index_t*)dists_buffer_.data_handle(), + d_list_sizes_new_.data_handle(), + stream); + add_reverse_edges(thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(h_rev_graph_old_.data()), + (Index_t*)dists_buffer_.data_handle(), + d_list_sizes_old_.data_handle(), + stream); + + // Tensor operations from `mma.h` are guarded with archicteture + // __CUDA_ARCH__ >= 700. Since RAFT supports compilation for ARCH 600, + // we need to ensure that `local_join_kernel` (which uses tensor) operations + // is not only not compiled, but also a runtime error is presented to the user + auto kernel = preprocess_data_kernel; + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = raft::util::arch::kernel_virtual_arch(kernel_ptr); + auto wmma_range = + raft::util::arch::SM_range(raft::util::arch::SM_70(), raft::util::arch::SM_future()); + + if (wmma_range.contains(runtime_arch)) { + local_join(stream); + } else { + THROW("NN_DESCENT cannot be run for __CUDA_ARCH__ < 700"); + } + + update_and_sample_thread.join(); + + if (update_counter_ == -1) { break; } + raft::copy(thrust::raw_pointer_cast(graph_host_buffer_.data()), + graph_buffer_.data_handle(), + nrow_ * DEGREE_ON_DEVICE, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + raft::copy(thrust::raw_pointer_cast(dists_host_buffer_.data()), + dists_buffer_.data_handle(), + nrow_ * DEGREE_ON_DEVICE, + raft::resource::get_cuda_stream(res)); + + graph_.sample_graph_new(thrust::raw_pointer_cast(graph_host_buffer_.data()), DEGREE_ON_DEVICE); + } + + graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), + thrust::raw_pointer_cast(dists_host_buffer_.data()), + DEGREE_ON_DEVICE, + update_counter_); + raft::resource::sync_stream(res); + graph_.sort_lists(); + + // Reuse graph_.h_dists as the buffer for shrink the lists in graph + static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t)); + Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); + +#pragma omp parallel for + for (size_t i = 0; i < (size_t)nrow_; i++) { + for (size_t j = 0; j < build_config_.node_degree; j++) { + size_t idx = i * graph_.node_degree + j; + Index_t id = graph_.h_graph[idx].id(); + if (id < nrow_) { + graph_shrink_buffer[i * build_config_.node_degree + j] = id; + } else { + graph_shrink_buffer[i * build_config_.node_degree + j] = + raft::neighbors::cagra::detail::device::xorshift64(idx) % nrow_; + } + } + } + graph_.h_graph = nullptr; + +#pragma omp parallel for + for (size_t i = 0; i < (size_t)nrow_; i++) { + for (size_t j = 0; j < build_config_.node_degree; j++) { + output_graph[i * build_config_.node_degree + j] = + graph_shrink_buffer[i * build_config_.node_degree + j]; + } + } +} + +template , memory_type::host>> +void build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset, + index& idx) +{ + RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, + "The dataset size for GNND should be less than %d", + std::numeric_limits::max() - 1); + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + + if (intermediate_degree >= static_cast(dataset.extent(0))) { + RAFT_LOG_WARN( + "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", + dataset.extent(0)); + intermediate_degree = dataset.extent(0) - 1; + } + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + // The elements in each knn-list are partitioned into different buckets, and we need more buckets + // to mitigate bucket collisions. `intermediate_degree` is OK to larger than + // extended_graph_degree. + size_t extended_graph_degree = + align32::roundUp(static_cast(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3))); + size_t extended_intermediate_degree = align32::roundUp( + static_cast(intermediate_degree * (intermediate_degree <= 32 ? 1.0 : 1.3))); + + auto int_graph = raft::make_host_matrix( + dataset.extent(0), static_cast(extended_graph_degree)); + + BuildConfig build_config{.max_dataset_size = static_cast(dataset.extent(0)), + .dataset_dim = static_cast(dataset.extent(1)), + .node_degree = extended_graph_degree, + .internal_node_degree = extended_intermediate_degree, + .max_iterations = params.max_iterations, + .termination_threshold = params.termination_threshold}; + + GNND nnd(res, build_config); + nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); + +#pragma omp parallel for + for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { + for (size_t j = 0; j < graph_degree; j++) { + auto graph = idx.graph().data_handle(); + graph[i * graph_degree + j] = int_graph.data_handle()[i * extended_graph_degree + j]; + } + } +} + +template , memory_type::host>> +index build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset) +{ + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + index idx{res, dataset.extent(0), static_cast(graph_degree)}; + + build(res, params, dataset, idx); + + return idx; +} + +} // namespace raft::neighbors::experimental::nn_descent::detail diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh new file mode 100644 index 0000000000..ceb5ae5643 --- /dev/null +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/nn_descent.cuh" + +#include +#include + +namespace raft::neighbors::experimental::nn_descent { + +/** + * @defgroup nn-descent CUDA gradient descent nearest neighbor + * @{ + */ + +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto index = cagra::build(res, index_params, dataset); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param[in] res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::device_matrix_view input dataset expected to be located + * in device memory + * @return index index containing all-neighbors knn graph in host memory + */ +template +index build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset) +{ + return detail::build(res, params, dataset); +} + +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto knn_graph = raft::make_host_matrix(N, D); + * auto index = nn_descent::index{res, knn_graph.view()}; + * cagra::build(res, index_params, dataset, index); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::device_matrix_view input dataset expected to be located + * in device memory + * @param[out] idx raft::neighbors::experimental::nn_descentindex containing all-neighbors knn graph + * in host memory + */ +template +void build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset, + index& idx) +{ + detail::build(res, params, dataset, idx); +} + +/** + * @brief Build nn-descent Index with dataset in host memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto index = cagra::build(res, index_params, dataset); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::host_matrix_view input dataset expected to be located + * in host memory + * @return index index containing all-neighbors knn graph in host memory + */ +template +index build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset) +{ + return detail::build(res, params, dataset); +} + +/** + * @brief Build nn-descent Index with dataset in host memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto knn_graph = raft::make_host_matrix(N, D); + * auto index = nn_descent::index{res, knn_graph.view()}; + * cagra::build(res, index_params, dataset, index); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param[in] res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::host_matrix_view input dataset expected to be located + * in host memory + * @param[out] idx raft::neighbors::experimental::nn_descentindex containing all-neighbors knn graph + * in host memory + */ +template +void build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset, + index& idx) +{ + detail::build(res, params, dataset, idx); +} + +/** @} */ // end group nn-descent + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp new file mode 100644 index 0000000000..64e464c618 --- /dev/null +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include + +#include +#include +#include +#include +#include + +namespace raft::neighbors::experimental::nn_descent { +/** + * @ingroup nn_descent + * @{ + */ + +/** + * @brief Parameters used to build an nn-descent index + * + * `graph_degree`: For an input dataset of dimensions (N, D), + * determines the final dimensions of the all-neighbors knn graph + * which turns out to be of dimensions (N, graph_degree) + * `intermediate_graph_degree`: Internally, nn-descent builds an + * all-neighbors knn graph of dimensions (N, intermediate_graph_degree) + * before selecting the final `graph_degree` neighbors. It's recommended + * that `intermediate_graph_degree` >= 1.5 * graph_degree + * `max_iterations`: The number of iterations that nn-descent will refine + * the graph for. More iterations produce a better quality graph at cost of performance + * `termination_threshold`: The delta at which nn-descent will terminate its iterations + * + */ +struct index_params : ann::index_params { + size_t graph_degree = 64; // Degree of output graph. + size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. + size_t max_iterations = 20; // Number of nn-descent iterations. + float termination_threshold = 0.0001; // Termination threshold of nn-descent. +}; + +/** + * @brief nn-descent Build an nn-descent index + * The index contains an all-neighbors graph of the input dataset + * stored in host memory of dimensions (n_rows, n_cols) + * + * @tparam IdxT dtype to be used for constructing knn-graph + */ +template +struct index : ann::index { + public: + /** + * @brief Construct a new index object + * + * This constructor creates an nn-descent index which is a knn-graph in host memory. + * The type of the knn-graph is a dense raft::host_matrix and dimensions are + * (n_rows, n_cols). + * + * @param res raft::resources is an object mangaging resources + * @param n_rows number of rows in knn-graph + * @param n_cols number of cols in knn-graph + */ + index(raft::resources const& res, int64_t n_rows, int64_t n_cols) + : ann::index(), + res_{res}, + metric_{raft::distance::DistanceType::L2Expanded}, + graph_{raft::make_host_matrix(n_rows, n_cols)}, + graph_view_{graph_.view()} + { + } + + /** + * @brief Construct a new index object + * + * This constructor creates an nn-descent index using a user allocated host memory knn-graph. + * The type of the knn-graph is a dense raft::host_matrix and dimensions are + * (n_rows, n_cols). + * + * @param res raft::resources is an object mangaging resources + * @param graph_view raft::host_matrix_view for storing knn-graph + */ + index(raft::resources const& res, + raft::host_matrix_view graph_view) + : ann::index(), + res_{res}, + metric_{raft::distance::DistanceType::L2Expanded}, + graph_{raft::make_host_matrix(0, 0)}, + graph_view_{graph_view} + { + } + + /** Distance metric used for clustering. */ + [[nodiscard]] constexpr inline auto metric() const noexcept -> raft::distance::DistanceType + { + return metric_; + } + + // /** Total length of the index (number of vectors). */ + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT + { + return graph_view_.extent(0); + } + + /** Graph degree */ + [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t + { + return graph_view_.extent(1); + } + + /** neighborhood graph [size, graph-degree] */ + [[nodiscard]] inline auto graph() noexcept -> host_matrix_view + { + return graph_view_; + } + + // Don't allow copying the index for performance reasons (try avoiding copying data) + index(const index&) = delete; + index(index&&) = default; + auto operator=(const index&) -> index& = delete; + auto operator=(index&&) -> index& = default; + ~index() = default; + + private: + raft::resources const& res_; + raft::distance::DistanceType metric_; + raft::host_matrix graph_; // graph to return for non-int IdxT + raft::host_matrix_view + graph_view_; // view of graph for user provided matrix +}; + +/** @} */ + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index db4c59c807..71de21e64a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -379,6 +379,21 @@ if(BUILD_TESTS) 100 ) + ConfigureTest( + NAME + NEIGHBORS_ANN_NN_DESCENT_TEST + PATH + test/neighbors/ann_nn_descent/test_float_uint32_t.cu + test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu + test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu + LIB + EXPLICIT_INSTANTIATE_ONLY + GPUS + 1 + PERCENT + 100 + ) + ConfigureTest( NAME NEIGHBORS_SELECTION_TEST PATH test/neighbors/selection.cu LIB EXPLICIT_INSTANTIATE_ONLY GPUS 1 PERCENT 50 diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 90f271e3ee..343afd04ec 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -147,6 +147,7 @@ struct AnnCagraInputs { int n_rows; int dim; int k; + graph_build_algo build_algo; search_algo algo; int max_queries; int team_size; @@ -161,12 +162,13 @@ struct AnnCagraInputs { inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) { - std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; + std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; + std::vector build_algo = {"IVF_PQ", "NN_DESCENT"}; os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim << ", k=" << p.k << ", " << algo.at((int)p.algo) << ", max_queries=" << p.max_queries << ", itopk_size=" << p.itopk_size << ", search_width=" << p.search_width - << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") << '}' - << std::endl; + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") + << ", build_algo=" << build_algo.at((int)p.build_algo) << '}' << std::endl; return os; } @@ -216,6 +218,7 @@ class AnnCagraTest : public ::testing::TestWithParam { cagra::index_params index_params; index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is // not used for knn_graph building. + index_params.build_algo = ps.build_algo; cagra::search_params search_params; search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; @@ -340,11 +343,25 @@ class AnnCagraSortTest : public ::testing::TestWithParam { auto knn_graph = raft::make_host_matrix(ps.n_rows, index_params.intermediate_graph_degree); - if (ps.host_dataset) { - cagra::build_knn_graph(handle_, database_host_view, knn_graph.view()); + if (ps.build_algo == graph_build_algo::IVF_PQ) { + if (ps.host_dataset) { + cagra::build_knn_graph(handle_, database_host_view, knn_graph.view()); + } else { + cagra::build_knn_graph(handle_, database_view, knn_graph.view()); + } } else { - cagra::build_knn_graph(handle_, database_view, knn_graph.view()); - }; + auto nn_descent_idx_params = experimental::nn_descent::index_params{}; + nn_descent_idx_params.graph_degree = index_params.intermediate_graph_degree; + nn_descent_idx_params.intermediate_graph_degree = index_params.intermediate_graph_degree; + + if (ps.host_dataset) { + cagra::build_knn_graph( + handle_, database_host_view, knn_graph.view(), nn_descent_idx_params); + } else { + cagra::build_knn_graph( + handle_, database_host_view, knn_graph.view(), nn_descent_idx_params); + } + } handle_.sync_stream(); ASSERT_TRUE(CheckOrder(knn_graph.view(), database_host.view())); @@ -546,6 +563,7 @@ inline std::vector generate_inputs() {1000}, {1, 8, 17}, {1, 16}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, {0, 1, 10, 100}, // query size {0}, @@ -561,6 +579,7 @@ inline std::vector generate_inputs() {1000}, {1, 3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim {16}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::AUTO}, {10}, {0}, @@ -571,68 +590,55 @@ inline std::vector generate_inputs() {true}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {1000}, - {64}, - {16}, - {search_algo::AUTO}, - {10}, - {0, 4, 8, 16, 32}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false}, - {false}, - {0.995}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - - inputs2 = - raft::util::itertools::product({100}, - {1000}, - {64}, - {16}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {32, 64, 128, 256, 512, 768}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false}, - {true}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {1000}, + {64}, + {16}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0, 4, 8, 16, 32}, // team_size + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {false}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {10000, 20000}, - {32}, - {10}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false, true}, - {false}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {1000}, + {64}, + {16}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {32, 64, 128, 256, 512, 768}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {true}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {10000, 20000}, - {32}, - {10}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false, true}, - {true}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {10000, 20000}, + {32}, + {10}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false, true}, + {false}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); return inputs; diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh new file mode 100644 index 0000000000..948323cf6e --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" + +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace raft::neighbors::experimental::nn_descent { + +struct AnnNNDescentInputs { + int n_rows; + int dim; + int graph_degree; + raft::distance::DistanceType metric; + bool host_dataset; + double min_recall; +}; + +inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& p) +{ + os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") + << std::endl; + return os; +} + +template +class AnnNNDescentTest : public ::testing::TestWithParam { + public: + AnnNNDescentTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_) + { + } + + protected: + void testNNDescent() + { + size_t queries_size = ps.n_rows * ps.graph_degree; + std::vector indices_NNDescent(queries_size); + std::vector indices_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + database.data(), + database.data(), + ps.n_rows, + ps.n_rows, + ps.dim, + ps.graph_degree, + ps.metric); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + { + nn_descent::index_params index_params; + index_params.metric = ps.metric; + index_params.graph_degree = ps.graph_degree; + index_params.intermediate_graph_degree = 2 * ps.graph_degree; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + { + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + auto index = nn_descent::build(handle_, index_params, database_host_view); + update_host( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + } else { + auto index = nn_descent::build(handle_, index_params, database_view); + update_host( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + }; + } + resource::sync_stream(handle_); + } + + double min_recall = ps.min_recall; + EXPECT_TRUE(eval_recall( + indices_naive, indices_NNDescent, ps.n_rows, ps.graph_degree, 0.001, min_recall)); + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + raft::random::Rng r(1234ULL); + if constexpr (std::is_same{}) { + r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_); + } else { + r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); + } + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnNNDescentInputs ps; + rmm::device_uvector database; +}; + +const std::vector inputs = raft::util::itertools::product( + {1000, 2000}, // n_rows + {3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim + {32, 64}, // graph_degree + {raft::distance::DistanceType::L2Expanded}, + {false, true}, + {0.92}); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu new file mode 100644 index 0000000000..13bff6ac90 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestF_U32; +TEST_P(AnnNNDescentTestF_U32, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_U32, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu new file mode 100644 index 0000000000..5895303e09 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestI8_U32; +TEST_P(AnnNNDescentTestI8_U32, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestI8_U32, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu new file mode 100644 index 0000000000..a034e84074 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestUI8_U32; +TEST_P(AnnNNDescentTestUI8_U32, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestUI8_U32, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 0e54e29c01..be60ec5b6d 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -123,6 +123,49 @@ struct idx_dist_pair { idx_dist_pair(IdxT x, DistT y, CompareDist op) : idx(x), dist(y), eq_compare(op) {} }; +template +auto eval_recall(const std::vector& expected_idx, + const std::vector& actual_idx, + size_t rows, + size_t cols, + double eps, + double min_recall) -> testing::AssertionResult +{ + size_t match_count = 0; + size_t total_count = static_cast(rows) * static_cast(cols); + for (size_t i = 0; i < rows; ++i) { + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + for (size_t j = 0; j < cols; ++j) { + size_t idx = i * cols + j; // row major assumption! + auto exp_idx = expected_idx[idx]; + if (act_idx == exp_idx) { + match_count++; + break; + } + } + } + } + double actual_recall = static_cast(match_count) / static_cast(total_count); + double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps); + RAFT_LOG_INFO("Recall = %f (%zu/%zu), the error is %2.1f%% %s the threshold (eps = %f).", + actual_recall, + match_count, + total_count, + std::abs(error_margin * 100.0), + error_margin < 0 ? "above" : "below", + eps); + if (actual_recall < min_recall - eps) { + return testing::AssertionFailure() + << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" + << min_recall << "); eps = " << eps << ". "; + } + return testing::AssertionSuccess(); +} + +/** same as eval_recall, but in case indices do not match, + * then check distances as well, and accept match if actual dist is equal to expected_dist */ template auto eval_neighbours(const std::vector& expected_idx, const std::vector& actual_idx, diff --git a/docs/source/ann_benchmarks_param_tuning.md b/docs/source/ann_benchmarks_param_tuning.md index dd6090c5e2..433df2ae2f 100644 --- a/docs/source/ann_benchmarks_param_tuning.md +++ b/docs/source/ann_benchmarks_param_tuning.md @@ -48,7 +48,8 @@ CAGRA uses a graph-based index, which creates an intermediate, approximate kNN g |-----------------------------|----------------|----------|----------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | `graph_degree` | `build_param` | N | Positive Integer >0 | 64 | Degree of the final kNN graph index. | | `intermediate_graph_degree` | `build_param` | N | Positive Integer >0 | 128 | Degree of the intermediate kNN graph. | -| `dataset_memory_type` | `build_param` | N | ["device", "host", "mmap"] | "device" | What memory type should the dataset reside? | +| `graph_build_algo` | `build_param` | N | ["IVF_PQ", "NN_DESCENT"] | "IVF_PQ" | Algorithm to use for search | +| `dataset_memory_type` | `build_param` | N | ["device", "host", "mmap"] | "device" | What memory type should the dataset reside? | | `query_memory_type` | `search_params` | N | ["device", "host", "mmap"] | "device | What memory type should the queries reside? | | `itopk` | `search_wdith` | N | Positive Integer >0 | 64 | Number of intermediate search results retained during the search. Higher values improve search accuracy at the cost of speed. | | `search_width` | `search_param` | N | Positive Integer >0 | 1 | Number of graph nodes to select as the starting point for the search in each iteration. | diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index e0c59a5ed3..c11d933b27 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -104,11 +104,13 @@ cdef class IndexParams: graph_degree : int, default = 64 - add_data_on_build : bool, default = True - After training the coarse and fine quantizers, we will populate - the index with the dataset if add_data_on_build == True, otherwise - the index is left empty, and the extend method can be used - to add new vectors to the index. + build_algo: string denoting the graph building algorithm to use, + default = "ivf_pq" + Valid values for algo: ["ivf_pq", "nn_descent"], where + - ivf_pq will use the IVF-PQ algorithm for building the knn graph + - nn_descent (experimental) will use the NN-Descent algorithm for + building the knn graph. It is expected to be generally + faster than ivf_pq. """ cdef c_cagra.index_params params @@ -116,12 +118,15 @@ cdef class IndexParams: metric="sqeuclidean", intermediate_graph_degree=128, graph_degree=64, - add_data_on_build=True): + build_algo="ivf_pq"): self.params.metric = _get_metric(metric) self.params.metric_arg = 0 self.params.intermediate_graph_degree = intermediate_graph_degree self.params.graph_degree = graph_degree - self.params.add_data_on_build = add_data_on_build + if build_algo == "ivf_pq": + self.params.build_algo = c_cagra.graph_build_algo.IVF_PQ + elif build_algo == "nn_descent": + self.params.build_algo = c_cagra.graph_build_algo.NN_DESCENT @property def metric(self): @@ -135,10 +140,6 @@ cdef class IndexParams: def graph_degree(self): return self.params.graph_degree - @property - def add_data_on_build(self): - return self.params.add_data_on_build - cdef class Index: cdef readonly bool trained diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 0c683bcd9b..7e22f274e9 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -51,9 +51,14 @@ from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( cdef extern from "raft/neighbors/cagra_types.hpp" \ namespace "raft::neighbors::cagra" nogil: + ctypedef enum graph_build_algo: + IVF_PQ "raft::neighbors::cagra::graph_build_algo::IVF_PQ", + NN_DESCENT "raft::neighbors::cagra::graph_build_algo::NN_DESCENT" + cpdef cppclass index_params(ann_index_params): size_t intermediate_graph_degree size_t graph_degree + graph_build_algo build_algo ctypedef enum search_algo: SINGLE_CTA "raft::neighbors::cagra::search_algo::SINGLE_CTA", diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index 74e9f53b91..f74fc5ae62 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -52,6 +52,7 @@ def run_cagra_build_search_test( metric="euclidean", intermediate_graph_degree=128, graph_degree=64, + build_algo="ivf_pq", array_type="device", compare=True, inplace=True, @@ -67,6 +68,7 @@ def run_cagra_build_search_test( metric=metric, intermediate_graph_degree=intermediate_graph_degree, graph_degree=graph_degree, + build_algo=build_algo, ) if array_type == "device": @@ -139,13 +141,17 @@ def run_cagra_build_search_test( @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) @pytest.mark.parametrize("array_type", ["device", "host"]) -def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): +@pytest.mark.parametrize("build_algo", ["ivf_pq", "nn_descent"]) +def test_cagra_dataset_dtype_host_device( + dtype, array_type, inplace, build_algo +): # Note that inner_product tests use normalized input which we cannot # represent in int8, therefore we test only sqeuclidean metric here. run_cagra_build_search_test( dtype=dtype, inplace=inplace, array_type=array_type, + build_algo=build_algo, ) @@ -158,6 +164,7 @@ def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): "add_data_on_build": True, "k": 1, "metric": "euclidean", + "build_algo": "ivf_pq", }, { "intermediate_graph_degree": 32, @@ -165,6 +172,7 @@ def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): "add_data_on_build": False, "k": 5, "metric": "sqeuclidean", + "build_algo": "ivf_pq", }, { "intermediate_graph_degree": 128, @@ -172,6 +180,7 @@ def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): "add_data_on_build": True, "k": 10, "metric": "inner_product", + "build_algo": "nn_descent", }, ], ) @@ -184,6 +193,7 @@ def test_cagra_index_params(params): graph_degree=params["graph_degree"], intermediate_graph_degree=params["intermediate_graph_degree"], compare=False, + build_algo=params["build_algo"], ) From ed42bb5c706c9903b51388f5b214ceea360d9b0a Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Wed, 27 Sep 2023 00:48:53 +0800 Subject: [PATCH 06/10] Change CAGRA auto mode selection (#1830) This PR changes the rule to select a CAGRA search mode as written in the [CAGRA paper](https://arxiv.org/abs/2308.15136). Authors: - tsuki (https://github.com/enp1s0) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1830 --- .../raft/neighbors/detail/cagra/search_multi_cta.cuh | 7 ++++--- cpp/include/raft/neighbors/detail/cagra/search_plan.cuh | 7 ++++--- python/pylibraft/pylibraft/test/test_cagra.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 9a722a6dfe..c6478bef84 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -109,9 +109,10 @@ struct search : public search_plan_implitopk_size = 32; - search_width = 1; - num_cta_per_query = max(params.search_width, params.itopk_size / 32); + constexpr unsigned muti_cta_itopk_size = 32; + this->itopk_size = muti_cta_itopk_size; + search_width = 1; + num_cta_per_query = max(params.search_width, params.itopk_size / muti_cta_itopk_size); result_buffer_size = itopk_size + search_width * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 9419385836..a0f346ab51 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -38,12 +38,13 @@ struct search_plan_impl_base : public search_params { { set_max_dim_team(dim); if (algo == search_algo::AUTO) { - if (itopk_size <= 512) { + const size_t num_sm = raft::getMultiProcessorCount(); + if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) { algo = search_algo::SINGLE_CTA; RAFT_LOG_DEBUG("Auto strategy: selecting single-cta"); } else { - algo = search_algo::MULTI_KERNEL; - RAFT_LOG_DEBUG("Auto strategy: selecting multi-kernel"); + algo = search_algo::MULTI_CTA; + RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta"); } } } diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index f74fc5ae62..24126c0c5a 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -251,7 +251,7 @@ def test_cagra_index_params(params): "search_width": 4, "min_iterations": 0, "thread_block_size": 0, - "hashmap_mode": "small", + "hashmap_mode": "auto", "hashmap_min_bitlen": 0, "hashmap_max_fill_rate": 0.5, "num_random_samplings": 1, From b9b5f4444a38ee0d82d5823e8348470057c9d149 Mon Sep 17 00:00:00 2001 From: Micka Date: Tue, 26 Sep 2023 22:19:22 +0200 Subject: [PATCH 07/10] [FEA] Add bitset for ANN pre-filtering and deletion (#1803) Related to #1600 Authors: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1803 --- build.sh | 2 +- cpp/bench/prims/CMakeLists.txt | 2 + cpp/bench/prims/core/bitset.cu | 74 ++++++ cpp/include/raft/core/bitset.cuh | 308 ++++++++++++++++++++++ cpp/include/raft/util/memory_pool-inl.hpp | 5 + cpp/test/CMakeLists.txt | 1 + cpp/test/core/bitset.cu | 188 +++++++++++++ docs/source/cpp_api.rst | 3 +- docs/source/cpp_api/core.rst | 3 +- docs/source/cpp_api/core_bitset.rst | 15 ++ docs/source/cpp_api/utils.rst | 21 ++ 11 files changed, 619 insertions(+), 3 deletions(-) create mode 100644 cpp/bench/prims/core/bitset.cu create mode 100644 cpp/include/raft/core/bitset.cuh create mode 100644 cpp/test/core/bitset.cu create mode 100644 docs/source/cpp_api/core_bitset.rst create mode 100644 docs/source/cpp_api/utils.rst diff --git a/build.sh b/build.sh index 32582738c3..6200e6a2fa 100755 --- a/build.sh +++ b/build.sh @@ -79,7 +79,7 @@ BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" -BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" +BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" NVTX=ON diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index e8d4739384..ca4b0f099d 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -77,6 +77,7 @@ if(BUILD_PRIMS_BENCH) NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) + ConfigureBench(NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/main.cpp) ConfigureBench( NAME TUNE_DISTANCE PATH bench/prims/distance/tune_pairwise/kernel.cu @@ -155,4 +156,5 @@ if(BUILD_PRIMS_BENCH) LIB EXPLICIT_INSTANTIATE_ONLY ) + endif() diff --git a/cpp/bench/prims/core/bitset.cu b/cpp/bench/prims/core/bitset.cu new file mode 100644 index 0000000000..5f44aa9af5 --- /dev/null +++ b/cpp/bench/prims/core/bitset.cu @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +namespace raft::bench::core { + +struct bitset_inputs { + uint32_t bitset_len; + uint32_t mask_len; + uint32_t query_len; +}; // struct bitset_inputs + +template +struct bitset_bench : public fixture { + bitset_bench(const bitset_inputs& p) + : params(p), + mask{raft::make_device_vector(res, p.mask_len)}, + queries{raft::make_device_vector(res, p.query_len)}, + outputs{raft::make_device_vector(res, p.query_len)} + { + raft::random::RngState state{42}; + raft::random::uniformInt(res, state, mask.view(), index_t{0}, index_t{p.bitset_len}); + } + + void run_benchmark(::benchmark::State& state) override + { + loop_on_state(state, [this]() { + auto my_bitset = raft::core::bitset( + this->res, raft::make_const_mdspan(mask.view()), params.bitset_len); + my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view()); + }); + } + + private: + raft::resources res; + bitset_inputs params; + raft::device_vector mask, queries; + raft::device_vector outputs; +}; // struct bitset + +const std::vector bitset_input_vecs{ + {256 * 1024 * 1024, 64 * 1024 * 1024, 256 * 1024 * 1024}, // Standard Bench + {256 * 1024 * 1024, 64 * 1024 * 1024, 1024 * 1024 * 1024}, // Extra queries + {128 * 1024 * 1024, 1024 * 1024 * 1024, 256 * 1024 * 1024}, // Extra mask to test atomics impact +}; + +using Uint8_32 = bitset_bench; +using Uint16_64 = bitset_bench; +using Uint32_32 = bitset_bench; +using Uint32_64 = bitset_bench; + +RAFT_BENCH_REGISTER(Uint8_32, "", bitset_input_vecs); +RAFT_BENCH_REGISTER(Uint16_64, "", bitset_input_vecs); +RAFT_BENCH_REGISTER(Uint32_32, "", bitset_input_vecs); +RAFT_BENCH_REGISTER(Uint32_64, "", bitset_input_vecs); + +} // namespace raft::bench::core diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh new file mode 100644 index 0000000000..6747c5fab0 --- /dev/null +++ b/cpp/include/raft/core/bitset.cuh @@ -0,0 +1,308 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::core { +/** + * @defgroup bitset Bitset + * @{ + */ +/** + * @brief View of a RAFT Bitset. + * + * This lightweight structure stores a pointer to a bitset in device memory with it's length. + * It provides a test() device function to check if a given index is set in the bitset. + * + * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitset_view { + index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; + + _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len) + : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} + { + } + /** + * @brief Create a bitset view from a device vector view of the bitset. + * + * @param bitset_span Device vector view of the bitset + * @param bitset_len Number of bits in the bitset + */ + _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span, + index_t bitset_len) + : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_len} + { + } + /** + * @brief Device function to test if a given index is set in the bitset. + * + * @param sample_index Single index to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_DEVICE auto test(const index_t sample_index) const -> bool + { + const bitset_t bit_element = bitset_ptr_[sample_index / bitset_element_size]; + const index_t bit_index = sample_index % bitset_element_size; + const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; + return is_bit_set; + } + + /** + * @brief Get the device pointer to the bitset. + */ + inline _RAFT_HOST_DEVICE auto data_handle() -> bitset_t* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto data_handle() const -> const bitset_t* { return bitset_ptr_; } + /** + * @brief Get the number of bits of the bitset representation. + */ + inline _RAFT_HOST_DEVICE auto size() const -> index_t { return bitset_len_; } + + /** + * @brief Get the number of elements used by the bitset representation. + */ + inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t + { + return raft::ceildiv(bitset_len_, bitset_element_size); + } + + inline auto to_mdspan() -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_ptr_, n_elements()); + } + inline auto to_mdspan() const -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_ptr_, n_elements()); + } + + private: + bitset_t* bitset_ptr_; + index_t bitset_len_; +}; + +/** + * @brief RAFT Bitset. + * + * This structure encapsulates a bitset in device memory. It provides a view() method to get a + * device-usable lightweight view of the bitset. + * Each index is represented by a single bit in the bitset. The total number of bytes used is + * ceil(bitset_len / 8). + * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitset { + index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; + + /** + * @brief Construct a new bitset object with a list of indices to unset. + * + * @param res RAFT resources + * @param mask_index List of indices to unset in the bitset + * @param bitset_len Length of the bitset + * @param default_value Default value to set the bits to. Default is true. + */ + bitset(const raft::resources& res, + raft::device_vector_view mask_index, + index_t bitset_len, + bool default_value = true) + : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), + raft::resource::get_cuda_stream(res)}, + bitset_len_{bitset_len}, + default_value_{default_value} + { + cudaMemsetAsync(bitset_.data(), + default_value ? 0xff : 0x00, + n_elements() * sizeof(bitset_t), + resource::get_cuda_stream(res)); + set(res, mask_index, !default_value); + } + + /** + * @brief Construct a new bitset object + * + * @param res RAFT resources + * @param bitset_len Length of the bitset + * @param default_value Default value to set the bits to. Default is true. + */ + bitset(const raft::resources& res, index_t bitset_len, bool default_value = true) + : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), + resource::get_cuda_stream(res)}, + bitset_len_{bitset_len}, + default_value_{default_value} + { + cudaMemsetAsync(bitset_.data(), + default_value ? 0xff : 0x00, + n_elements() * sizeof(bitset_t), + resource::get_cuda_stream(res)); + } + // Disable copy constructor + bitset(const bitset&) = delete; + bitset(bitset&&) = default; + bitset& operator=(const bitset&) = delete; + bitset& operator=(bitset&&) = default; + + /** + * @brief Create a device-usable view of the bitset. + * + * @return bitset_view + */ + inline auto view() -> raft::core::bitset_view + { + return bitset_view(to_mdspan(), bitset_len_); + } + [[nodiscard]] inline auto view() const -> raft::core::bitset_view + { + return bitset_view(to_mdspan(), bitset_len_); + } + + /** + * @brief Get the device pointer to the bitset. + */ + inline auto data_handle() -> bitset_t* { return bitset_.data(); } + inline auto data_handle() const -> const bitset_t* { return bitset_.data(); } + /** + * @brief Get the number of bits of the bitset representation. + */ + inline auto size() const -> index_t { return bitset_len_; } + + /** + * @brief Get the number of elements used by the bitset representation. + */ + inline auto n_elements() const -> index_t + { + return raft::ceildiv(bitset_len_, bitset_element_size); + } + + /** @brief Get an mdspan view of the current bitset */ + inline auto to_mdspan() -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_.data(), n_elements()); + } + [[nodiscard]] inline auto to_mdspan() const -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_.data(), n_elements()); + } + + /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to + * the default value. */ + void resize(const raft::resources& res, index_t new_bitset_len) + { + auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); + auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); + bitset_.resize(new_size); + bitset_len_ = new_bitset_len; + if (old_size < new_size) { + // If the new size is larger, set the new bits to the default value + cudaMemsetAsync(bitset_.data() + old_size, + default_value_ ? 0xff : 0x00, + (new_size - old_size) * sizeof(bitset_t), + resource::get_cuda_stream(res)); + } + } + + /** + * @brief Test a list of indices in a bitset. + * + * @tparam output_t Output type of the test. Default is bool. + * @param res RAFT resources + * @param queries List of indices to test + * @param output List of outputs + */ + template + void test(const raft::resources& res, + raft::device_vector_view queries, + raft::device_vector_view output) const + { + RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); + auto bitset_view = view(); + raft::linalg::map( + res, + output, + [bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); }, + queries); + } + /** + * @brief Set a list of indices in a bitset to set_value. + * + * @param res RAFT resources + * @param mask_index indices to remove from the bitset + * @param set_value Value to set the bits to (true or false) + */ + void set(const raft::resources& res, + raft::device_vector_view mask_index, + bool set_value = false) + { + auto* bitset_ptr = this->data_handle(); + thrust::for_each_n(resource::get_thrust_policy(res), + mask_index.data_handle(), + mask_index.extent(0), + [bitset_ptr, set_value] __device__(const index_t sample_index) { + const index_t bit_element = sample_index / bitset_element_size; + const index_t bit_index = sample_index % bitset_element_size; + const bitset_t bitmask = bitset_t{1} << bit_index; + if (set_value) { + atomicOr(bitset_ptr + bit_element, bitmask); + } else { + const bitset_t bitmask2 = ~bitmask; + atomicAnd(bitset_ptr + bit_element, bitmask2); + } + }); + } + /** + * @brief Flip all the bits in a bitset. + * + * @param res RAFT resources + */ + void flip(const raft::resources& res) + { + auto bitset_span = this->to_mdspan(); + raft::linalg::map( + res, + bitset_span, + [] __device__(bitset_t element) { return bitset_t(~element); }, + raft::make_const_mdspan(bitset_span)); + } + /** + * @brief Reset the bits in a bitset. + * + * @param res RAFT resources + */ + void reset(const raft::resources& res) + { + cudaMemsetAsync(bitset_.data(), + default_value_ ? 0xff : 0x00, + n_elements() * sizeof(bitset_t), + resource::get_cuda_stream(res)); + } + + private: + raft::device_uvector bitset_; + index_t bitset_len_; + bool default_value_; +}; + +/** @} */ +} // end namespace raft::core diff --git a/cpp/include/raft/util/memory_pool-inl.hpp b/cpp/include/raft/util/memory_pool-inl.hpp index 070c8f4e30..ad94ee0096 100644 --- a/cpp/include/raft/util/memory_pool-inl.hpp +++ b/cpp/include/raft/util/memory_pool-inl.hpp @@ -25,6 +25,10 @@ namespace raft { +/** + * @defgroup memory_pool Memory Pool + * @{ + */ /** * @brief Get a pointer to a pooled memory resource within the scope of the lifetime of the returned * unique pointer. @@ -73,4 +77,5 @@ RAFT_INLINE_CONDITIONAL std::unique_ptr get_poo return pool_res; } +/** @} */ } // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 71de21e64a..0651ccac86 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -105,6 +105,7 @@ if(BUILD_TESTS) NAME CORE_TEST PATH + test/core/bitset.cu test/core/device_resources_manager.cpp test/core/device_setter.cpp test/core/logger.cpp diff --git a/cpp/test/core/bitset.cu b/cpp/test/core/bitset.cu new file mode 100644 index 0000000000..215de98aaf --- /dev/null +++ b/cpp/test/core/bitset.cu @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include + +#include + +#include +#include + +namespace raft::core { + +struct test_spec_bitset { + uint64_t bitset_len; + uint64_t mask_len; + uint64_t query_len; +}; + +auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream& +{ + os << "bitset{bitset_len: " << ss.bitset_len << ", mask_len: " << ss.mask_len + << ", query_len: " << ss.query_len << "}"; + return os; +} + +template +void add_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) +{ + static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8; + for (size_t i = 0; i < mask_idx.size(); i++) { + auto idx = mask_idx[i]; + bitset[idx / bitset_element_size] &= ~(bitset_t{1} << (idx % bitset_element_size)); + } +} + +template +void create_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) +{ + for (size_t i = 0; i < bitset.size(); i++) { + bitset[i] = ~bitset_t(0x00); + } + add_cpu_bitset(bitset, mask_idx); +} + +template +void test_cpu_bitset(const std::vector& bitset, + const std::vector& queries, + std::vector& result) +{ + static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8; + for (size_t i = 0; i < queries.size(); i++) { + result[i] = uint8_t((bitset[queries[i] / bitset_element_size] & + (bitset_t{1} << (queries[i] % bitset_element_size))) != 0); + } +} + +template +void flip_cpu_bitset(std::vector& bitset) +{ + for (size_t i = 0; i < bitset.size(); i++) { + bitset[i] = ~bitset[i]; + } +} + +template +class BitsetTest : public testing::TestWithParam { + protected: + index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; + const test_spec_bitset spec; + std::vector bitset_result; + std::vector bitset_ref; + raft::resources res; + + public: + explicit BitsetTest() + : spec(testing::TestWithParam::GetParam()), + bitset_result(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))), + bitset_ref(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))) + { + } + + void run() + { + auto stream = resource::get_cuda_stream(res); + + // generate input and mask + raft::random::RngState rng(42); + auto mask_device = raft::make_device_vector(res, spec.mask_len); + std::vector mask_cpu(spec.mask_len); + raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); + update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); + resource::sync_stream(res, stream); + + // calculate the results + auto my_bitset = raft::core::bitset( + res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len)); + update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); + + // calculate the reference + create_cpu_bitset(bitset_ref, mask_cpu); + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + + auto query_device = raft::make_device_vector(res, spec.query_len); + auto result_device = raft::make_device_vector(res, spec.query_len); + auto query_cpu = std::vector(spec.query_len); + auto result_cpu = std::vector(spec.query_len); + auto result_ref = std::vector(spec.query_len); + + // Create queries and verify the test results + raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len)); + update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); + my_bitset.test(res, raft::make_const_mdspan(query_device.view()), result_device.view()); + update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); + test_cpu_bitset(bitset_ref, query_cpu, result_ref); + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(result_cpu, result_ref, Compare())); + + // Add more sample to the bitset and re-test + raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); + update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); + resource::sync_stream(res, stream); + my_bitset.set(res, mask_device.view()); + update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); + + add_cpu_bitset(bitset_ref, mask_cpu); + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + + // Flip the bitset and re-test + my_bitset.flip(res); + update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); + flip_cpu_bitset(bitset_ref); + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + } +}; + +auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10}, + test_spec_bitset{100, 30, 10}, + test_spec_bitset{1024, 55, 100}, + test_spec_bitset{10000, 1000, 1000}, + test_spec_bitset{1 << 15, 1 << 3, 1 << 12}, + test_spec_bitset{1 << 15, 1 << 24, 1 << 13}, + test_spec_bitset{1 << 25, 1 << 23, 1 << 14}); + +using Uint16_32 = BitsetTest; +TEST_P(Uint16_32, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint16_32, inputs_bitset); + +using Uint32_32 = BitsetTest; +TEST_P(Uint32_32, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32_32, inputs_bitset); + +using Uint64_32 = BitsetTest; +TEST_P(Uint64_32, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_32, inputs_bitset); + +using Uint8_64 = BitsetTest; +TEST_P(Uint8_64, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint8_64, inputs_bitset); + +using Uint32_64 = BitsetTest; +TEST_P(Uint32_64, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32_64, inputs_bitset); + +using Uint64_64 = BitsetTest; +TEST_P(Uint64_64, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_64, inputs_bitset); + +} // namespace raft::core diff --git a/docs/source/cpp_api.rst b/docs/source/cpp_api.rst index 0e82d81e35..e60ef4e697 100644 --- a/docs/source/cpp_api.rst +++ b/docs/source/cpp_api.rst @@ -18,4 +18,5 @@ C++ API cpp_api/random.rst cpp_api/solver.rst cpp_api/sparse.rst - cpp_api/stats.rst \ No newline at end of file + cpp_api/stats.rst + cpp_api/utils.rst \ No newline at end of file diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index 7e69f92948..39e57fd69a 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -20,4 +20,5 @@ expose in public APIs. core_nvtx.rst core_interruptible.rst core_operators.rst - core_math.rst \ No newline at end of file + core_math.rst + core_bitset.rst \ No newline at end of file diff --git a/docs/source/cpp_api/core_bitset.rst b/docs/source/cpp_api/core_bitset.rst new file mode 100644 index 0000000000..af1cff6d37 --- /dev/null +++ b/docs/source/cpp_api/core_bitset.rst @@ -0,0 +1,15 @@ +Bitset +====== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::core* + +.. doxygengroup:: bitset + :project: RAFT + :members: + :content-only: \ No newline at end of file diff --git a/docs/source/cpp_api/utils.rst b/docs/source/cpp_api/utils.rst new file mode 100644 index 0000000000..4471093c8b --- /dev/null +++ b/docs/source/cpp_api/utils.rst @@ -0,0 +1,21 @@ +Utilities +========= + +RAFT contains numerous utility functions and primitives that are easily usable. +This page provides C++ API references for the publicly-exposed utility functions. + +.. role:: py(code) + :language: c++ + :class: highlight + +Memory Pool +----------- + +``#include `` + +namespace *raft* + +.. doxygengroup:: memory_pool + :project: RAFT + :members: + :content-only: From 687ea813b14d791537b67365dc48fa41a3d6dbda Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Tue, 26 Sep 2023 18:22:02 -0500 Subject: [PATCH 08/10] raft: Build CUDA 12.0 ARM conda packages. (#1853) This PR builds conda packages using CUDA 12 on ARM. Closes #1834. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/raft/pull/1853 --- .github/workflows/build.yaml | 16 ++++++++-------- .github/workflows/pr.yaml | 22 +++++++++++----------- .github/workflows/test.yaml | 8 ++++---- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index c2d564dfda..60ea2e4a9f 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@cuda-120-arm with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -37,7 +37,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@cuda-120-arm with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -46,7 +46,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-upload-packages.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-upload-packages.yaml@cuda-120-arm with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -57,7 +57,7 @@ jobs: if: github.ref_type == 'branch' needs: python-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@cuda-120-arm with: arch: "amd64" branch: ${{ inputs.branch }} @@ -69,7 +69,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibraft: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@cuda-120-arm with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -79,7 +79,7 @@ jobs: wheel-publish-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@cuda-120-arm with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -89,7 +89,7 @@ jobs: wheel-build-raft-dask: needs: wheel-publish-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@cuda-120-arm with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -99,7 +99,7 @@ jobs: wheel-publish-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@cuda-120-arm with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 8e917bd91e..c2b318df47 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -23,41 +23,41 @@ jobs: - wheel-build-raft-dask - wheel-tests-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@cuda-120-arm checks: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/checks.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/checks.yaml@cuda-120-arm with: enable_check_generated_files: false conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@cuda-120-arm with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@cuda-120-arm with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@cuda-120-arm with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@cuda-120-arm with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@cuda-120-arm with: build_type: pull-request node_type: "gpu-v100-latest-1" @@ -67,28 +67,28 @@ jobs: wheel-build-pylibraft: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@cuda-120-arm with: build_type: pull-request script: ci/build_wheel_pylibraft.sh wheel-tests-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@cuda-120-arm with: build_type: pull-request script: ci/test_wheel_pylibraft.sh wheel-build-raft-dask: needs: wheel-tests-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@cuda-120-arm with: build_type: pull-request script: "ci/build_wheel_raft_dask.sh" wheel-tests-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@cuda-120-arm with: build_type: pull-request script: ci/test_wheel_raft_dask.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 4e45ae29f6..5c0c520188 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@cuda-120-arm with: build_type: nightly branch: ${{ inputs.branch }} @@ -24,7 +24,7 @@ jobs: sha: ${{ inputs.sha }} conda-python-tests: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@cuda-120-arm with: build_type: nightly branch: ${{ inputs.branch }} @@ -32,7 +32,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibraft: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@cuda-120-arm with: build_type: nightly branch: ${{ inputs.branch }} @@ -41,7 +41,7 @@ jobs: script: ci/test_wheel_pylibraft.sh wheel-tests-raft-dask: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.12 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@cuda-120-arm with: build_type: nightly branch: ${{ inputs.branch }} From 6c7cadad24dc7b4c33b6cf121f0926831a1eb5b4 Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Wed, 27 Sep 2023 07:23:03 +0800 Subject: [PATCH 09/10] Remove block size template parameter from CAGRA search (#1740) This PR removes block size template parameters from CAGRA search kernel functions to reduce the library size and build time. rel: https://github.com/rapidsai/raft/issues/1459 Authors: - tsuki (https://github.com/enp1s0) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1740 --- .../detail/cagra/compute_distance.hpp | 6 +- .../raft/neighbors/detail/cagra/hashmap.hpp | 13 +- .../cagra/search_multi_cta_kernel-inl.cuh | 87 ++--- .../detail/cagra/search_multi_kernel.cuh | 2 +- .../detail/cagra/search_single_cta.cuh | 34 +- .../cagra/search_single_cta_kernel-inl.cuh | 322 +++++++----------- .../neighbors/detail/cagra/topk_by_radix.cuh | 20 +- .../detail/cagra/topk_for_cagra/topk_core.cuh | 260 +++++++++----- cpp/test/neighbors/ann_cagra.cuh | 2 +- 9 files changed, 338 insertions(+), 408 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index 624c1a35d6..55b7b47508 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -133,7 +133,6 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( } template -_RAFT_DEVICE inline void init(IdxT* const table, const unsigned bitlen) +template +_RAFT_DEVICE inline void init(IdxT* const table, const unsigned bitlen, unsigned FIRST_TID = 0) { if (threadIdx.x < FIRST_TID) return; for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) { @@ -37,15 +37,6 @@ _RAFT_DEVICE inline void init(IdxT* const table, const unsigned bitlen) } } -template -_RAFT_DEVICE inline void init(IdxT* const table, const uint32_t bitlen) -{ - if ((FIRST_TID > 0 && threadIdx.x < FIRST_TID) || threadIdx.x >= LAST_TID) return; - for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += LAST_TID - FIRST_TID) { - table[i] = utils::get_max_value(); - } -} - template _RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) { diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 8bfbc48898..358a183971 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -125,8 +125,6 @@ __device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements] // multiple CTAs per single query // template -__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( +__launch_bounds__(1024, 1) __global__ void search_kernel( INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] @@ -157,7 +155,6 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( uint32_t* const num_executed_iterations, /* stats */ SAMPLE_FILTER_T sample_filter) { - assert(blockDim.x == BLOCK_SIZE); assert(dataset_dim <= MAX_DATASET_DIM); const auto num_queries = gridDim.y; @@ -209,7 +206,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( } #endif const DATA_T* const query_ptr = queries_ptr + (dataset_dim * query_id); - for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { + for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += blockDim.x) { unsigned j = device::swizzling(i); if (i < dataset_dim) { query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); @@ -276,21 +273,20 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( _CLK_START(); // constexpr unsigned max_n_frags = 16; constexpr unsigned max_n_frags = 0; - device:: - compute_distance_to_child_nodes( - result_indices_buffer + itopk_size, - result_distances_buffer + itopk_size, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_indices_buffer, - result_indices_buffer, - search_width); + device::compute_distance_to_child_nodes( + result_indices_buffer + itopk_size, + result_distances_buffer + itopk_size, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_ld, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + parent_indices_buffer, + result_indices_buffer, + search_width); _CLK_REC(clk_compute_distance); __syncthreads(); @@ -340,7 +336,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( __syncthreads(); } - for (uint32_t i = threadIdx.x; i < itopk_size; i += BLOCK_SIZE) { + for (uint32_t i = threadIdx.x; i < itopk_size; i += blockDim.x) { uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; @@ -414,8 +410,6 @@ struct search_kernel_config { // parameters do not matter, because they are not part of the function signature. The // second to fourth value parameters will be selected by the choose_* functions below. using kernel_t = decltype(&search_kernel kernel_t { if (result_buffer_size <= 64) { - return choose_max_elements<64>(block_size); - } else if (result_buffer_size <= 128) { - return choose_max_elements<128>(block_size); - } else if (result_buffer_size <= 256) { - return choose_max_elements<256>(block_size); - } - THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); - } - - template - // Todo: rename this to choose block_size - static auto choose_max_elements(unsigned block_size) -> kernel_t - { - if (block_size == 64) { return search_kernel; - } else if (block_size == 128) { + } else if (result_buffer_size <= 128) { return search_kernel; - } else if (block_size == 256) { + } else if (result_buffer_size <= 256) { return search_kernel; - } else if (block_size == 512) { - return search_kernel; - } else { - return search_kernel; } + THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); } }; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index ff1bb969e7..5dcfcb3929 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -254,7 +254,7 @@ __global__ void pickup_next_parents_kernel( if ((num_new_parents > 0) && (threadIdx.x == 0)) { *terminate_flag = 0; } } else if (small_hash_bitlen) { // reset small-hash - hashmap::init<32>(visited_hashmap_ptr + (ldb * query_id), hash_bitlen); + hashmap::init(visited_hashmap_ptr + (ldb * query_id), hash_bitlen, 32); } if (small_hash_bitlen) { diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 27d54f72cb..b36bc6f77b 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -131,11 +131,9 @@ struct search : search_plan_impl { // Tentatively calculate the required share memory size when radix // sort based topk is used, assuming the block size is the maximum. if (itopk_size <= 256) { - smem_size += - topk_by_radix_sort<256, max_block_size, INDEX_T>::smem_size * sizeof(std::uint32_t); + smem_size += topk_by_radix_sort<256, INDEX_T>::smem_size * sizeof(std::uint32_t); } else { - smem_size += - topk_by_radix_sort<512, max_block_size, INDEX_T>::smem_size * sizeof(std::uint32_t); + smem_size += topk_by_radix_sort<512, INDEX_T>::smem_size * sizeof(std::uint32_t); } } @@ -188,34 +186,10 @@ struct search : search_plan_impl { smem_size = base_smem_size; if (itopk_size <= 256) { constexpr unsigned MAX_ITOPK = 256; - if (block_size == 256) { - constexpr unsigned BLOCK_SIZE = 256; - smem_size += - topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } else if (block_size == 512) { - constexpr unsigned BLOCK_SIZE = 512; - smem_size += - topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } else { - constexpr unsigned BLOCK_SIZE = 1024; - smem_size += - topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } + smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } else { constexpr unsigned MAX_ITOPK = 512; - if (block_size == 256) { - constexpr unsigned BLOCK_SIZE = 256; - smem_size += - topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } else if (block_size == 512) { - constexpr unsigned BLOCK_SIZE = 512; - smem_size += - topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } else { - constexpr unsigned BLOCK_SIZE = 1024; - smem_size += - topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } + smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } } RAFT_LOG_DEBUG("# smem_size: %u", smem_size); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 128dc8d116..3a5501f545 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -90,11 +90,12 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag, if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } } -template +template __device__ inline void topk_by_bitonic_sort_1st(float* candidate_distances, // [num_candidates] IdxT* candidate_indices, // [num_candidates] const std::uint32_t num_candidates, - const std::uint32_t num_itopk) + const std::uint32_t num_itopk, + unsigned MULTI_WARPS = 0) { const unsigned lane_id = threadIdx.x % 32; const unsigned warp_id = threadIdx.x / 32; @@ -192,7 +193,7 @@ __device__ inline void topk_by_bitonic_sort_1st(float* candidate_distances, // } } -template +template __device__ inline void topk_by_bitonic_sort_2nd(float* itopk_distances, // [num_itopk] IdxT* itopk_indices, // [num_itopk] const std::uint32_t num_itopk, @@ -200,7 +201,8 @@ __device__ inline void topk_by_bitonic_sort_2nd(float* itopk_distances, // [num IdxT* candidate_indices, // [num_candidates] const std::uint32_t num_candidates, std::uint32_t* work_buf, - const bool first) + const bool first, + unsigned MULTI_WARPS = 0) { const unsigned lane_id = threadIdx.x % 32; const unsigned warp_id = threadIdx.x / 32; @@ -399,8 +401,6 @@ __device__ inline void topk_by_bitonic_sort_2nd(float* itopk_distances, // [num template __device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] IdxT* itopk_indices, // [num_itopk] @@ -409,33 +409,37 @@ __device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] IdxT* candidate_indices, // [num_candidates] const std::uint32_t num_candidates, std::uint32_t* work_buf, - const bool first) + const bool first, + const unsigned MULTI_WARPS_1, + const unsigned MULTI_WARPS_2) { // The results in candidate_distances/indices are sorted by bitonic sort. - topk_by_bitonic_sort_1st( - candidate_distances, candidate_indices, num_candidates, num_itopk); + topk_by_bitonic_sort_1st( + candidate_distances, candidate_indices, num_candidates, num_itopk, MULTI_WARPS_1); // The results sorted above are merged with the internal intermediate top-k // results so far using bitonic merge. - topk_by_bitonic_sort_2nd(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); + topk_by_bitonic_sort_2nd(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first, + MULTI_WARPS_2); } -template +template __device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr, const size_t hashmap_bitlen, const INDEX_T* itopk_indices, - uint32_t itopk_size) + const uint32_t itopk_size, + const uint32_t first_tid = 0) { constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - if (threadIdx.x < FIRST_TID || threadIdx.x >= LAST_TID) return; - for (unsigned i = threadIdx.x - FIRST_TID; i < itopk_size; i += LAST_TID - FIRST_TID) { + if (threadIdx.x < first_tid) return; + for (unsigned i = threadIdx.x - first_tid; i < itopk_size; i += blockDim.x - first_tid) { auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit hashmap::insert(hashmap_ptr, hashmap_bitlen, key); } @@ -451,8 +455,6 @@ __device__ inline void set_value_device(T* const ptr, const T fill, const std::u // One query one thread block template -__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ +__launch_bounds__(1024, 1) __global__ void search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k] DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] const std::uint32_t top_k, @@ -534,7 +536,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ auto filter_flag = terminate_flag; const DATA_T* const query_ptr = queries_ptr + query_id * dataset_dim; - for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { + for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += blockDim.x) { unsigned j = device::swizzling(i); if (i < dataset_dim) { query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); @@ -554,7 +556,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ } else { local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); } - hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + hashmap::init(local_visited_hashmap_ptr, hash_bitlen, 0); __syncthreads(); _CLK_REC(clk_init); @@ -590,8 +592,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ // topk_by_bitonic_sort() consists of two operations: // if MAX_CANDIDATES is greater than 128, the first operation uses two warps; // if MAX_ITOPK is greater than 256, the second operation used two warps. - constexpr unsigned multi_warps_1 = ((BLOCK_SIZE >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; - constexpr unsigned multi_warps_2 = ((BLOCK_SIZE >= 64) && (MAX_ITOPK > 256)) ? 1 : 0; + const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; + const unsigned multi_warps_2 = ((blockDim.x >= 64) && (MAX_ITOPK > 256)) ? 1 : 0; // reset small-hash table. if ((iter + 1) % small_hash_reset_interval == 0) { @@ -600,21 +602,23 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ // the small hash and whether they are performed in overlap with // topk_by_bitonic_sort(). _CLK_START(); - if (BLOCK_SIZE == 32) { - hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); - } else if (BLOCK_SIZE == 64) { + unsigned hash_start_tid; + if (blockDim.x == 32) { + hash_start_tid = 0; + } else if (blockDim.x == 64) { if (multi_warps_1 || multi_warps_2) { - hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + hash_start_tid = 0; } else { - hashmap::init<32, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + hash_start_tid = 32; } } else { if (multi_warps_1 || multi_warps_2) { - hashmap::init<64, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + hash_start_tid = 64; } else { - hashmap::init<32, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + hash_start_tid = 32; } } + hashmap::init(local_visited_hashmap_ptr, hash_bitlen, hash_start_tid); _CLK_REC(clk_reset_hash); } @@ -623,29 +627,31 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ if (std::is_same::value || *filter_flag == 0) { - topk_by_bitonic_sort( - result_distances_buffer, - result_indices_buffer, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - search_width * graph_degree, - topk_ws, - (iter == 0)); + topk_by_bitonic_sort(result_distances_buffer, + result_indices_buffer, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + search_width * graph_degree, + topk_ws, + (iter == 0), + multi_warps_1, + multi_warps_2); __syncthreads(); } else { - topk_by_bitonic_sort_1st( + topk_by_bitonic_sort_1st( result_distances_buffer, result_indices_buffer, internal_topk + search_width * graph_degree, - internal_topk); + internal_topk, + false); if (threadIdx.x == 0) { *terminate_flag = 0; } } _CLK_REC(clk_topk); } else { _CLK_START(); // topk with radix block sort - topk_by_radix_sort{}( + topk_by_radix_sort{}( internal_topk, gridDim.x, result_buffer_size, @@ -662,7 +668,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ // reset small-hash table if ((iter + 1) % small_hash_reset_interval == 0) { _CLK_START(); - hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + hashmap::init(local_visited_hashmap_ptr, hash_bitlen); _CLK_REC(clk_reset_hash); } } @@ -684,10 +690,10 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ // restore small-hash table by putting internal-topk indices in it if ((iter + 1) % small_hash_reset_interval == 0) { - constexpr unsigned first_tid = ((BLOCK_SIZE <= 32) ? 0 : 32); + const unsigned first_tid = ((blockDim.x <= 32) ? 0 : 32); _CLK_START(); - hashmap_restore( - local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk); + hashmap_restore( + local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk, first_tid); _CLK_REC(clk_restore_hash); } __syncthreads(); @@ -697,21 +703,20 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ // compute the norms between child nodes and query node _CLK_START(); constexpr unsigned max_n_frags = 16; - device:: - compute_distance_to_child_nodes( - result_indices_buffer + internal_topk, - result_distances_buffer + internal_topk, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_list_buffer, - result_indices_buffer, - search_width); + device::compute_distance_to_child_nodes( + result_indices_buffer + internal_topk, + result_distances_buffer + internal_topk, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_ld, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + parent_list_buffer, + result_indices_buffer, + search_width); __syncthreads(); _CLK_REC(clk_compute_distance); @@ -757,15 +762,16 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ } __syncthreads(); - topk_by_bitonic_sort_1st( + topk_by_bitonic_sort_1st( result_distances_buffer, result_indices_buffer, internal_topk + search_width * graph_degree, - top_k); + top_k, + false); __syncthreads(); } - for (std::uint32_t i = threadIdx.x; i < top_k; i += BLOCK_SIZE) { + for (std::uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { unsigned j = i + (top_k * query_id); unsigned ii = i; if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } @@ -811,114 +817,45 @@ template struct search_kernel_config { using kernel_t = - decltype(&search_kernel); + decltype(&search_kernel); - template - static auto choose_block_size(unsigned block_size) -> kernel_t + template + static auto choose_search_kernel(unsigned itopk_size) -> kernel_t { - constexpr unsigned BS = USE_BITONIC_SORT; - if constexpr (BS) { - if (block_size == 64) { - return search_kernel; - } else if (block_size == 128) { - return search_kernel; - } else if (block_size == 256) { - return search_kernel; - } else if (block_size == 512) { - return search_kernel; - } else { - return search_kernel; - } - - } else { - if (block_size == 256) { - return search_kernel; - } else if (block_size == 512) { - return search_kernel; - } else { - return search_kernel; - } + if (itopk_size <= 64) { + return search_kernel; + } else if (itopk_size <= 128) { + return search_kernel; + } else if (itopk_size <= 256) { + return search_kernel; + } else if (itopk_size <= 512) { + return search_kernel; } + THROW("No kernel for parametels itopk_size %u, max_candidates %u", itopk_size, MAX_CANDIDATES); } static auto choose_itopk_and_mx_candidates(unsigned itopk_size, @@ -927,45 +864,18 @@ struct search_kernel_config { { if (num_itopk_candidates <= 64) { // use bitonic sort based topk - constexpr unsigned max_candidates = 64; - if (itopk_size <= 64) { - return choose_block_size<64, max_candidates, 1>(block_size); - } else if (itopk_size <= 128) { - return choose_block_size<128, max_candidates, 1>(block_size); - } else if (itopk_size <= 256) { - return choose_block_size<256, max_candidates, 1>(block_size); - } else if (itopk_size <= 512) { - return choose_block_size<512, max_candidates, 1>(block_size); - } + return choose_search_kernel<64, 1>(itopk_size); } else if (num_itopk_candidates <= 128) { - constexpr unsigned max_candidates = 128; - if (itopk_size <= 64) { - return choose_block_size<64, max_candidates, 1>(block_size); - } else if (itopk_size <= 128) { - return choose_block_size<128, max_candidates, 1>(block_size); - } else if (itopk_size <= 256) { - return choose_block_size<256, max_candidates, 1>(block_size); - } else if (itopk_size <= 512) { - return choose_block_size<512, max_candidates, 1>(block_size); - } + return choose_search_kernel<128, 1>(itopk_size); } else if (num_itopk_candidates <= 256) { - constexpr unsigned max_candidates = 256; - if (itopk_size <= 64) { - return choose_block_size<64, max_candidates, 1>(block_size); - } else if (itopk_size <= 128) { - return choose_block_size<128, max_candidates, 1>(block_size); - } else if (itopk_size <= 256) { - return choose_block_size<256, max_candidates, 1>(block_size); - } else if (itopk_size <= 512) { - return choose_block_size<512, max_candidates, 1>(block_size); - } + return choose_search_kernel<256, 1>(itopk_size); } else { // Radix-based topk is used constexpr unsigned max_candidates = 32; // to avoid build failure if (itopk_size <= 256) { - return choose_block_size<256, max_candidates, 0>(block_size); + return search_kernel; } else if (itopk_size <= 512) { - return choose_block_size<512, max_candidates, 0>(block_size); + return search_kernel; } } THROW("No kernel for parametels itopk_size %u, num_itopk_candidates %u", diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh index a1b7f930d3..6a6a3cddf4 100644 --- a/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh @@ -26,14 +26,11 @@ struct topk_by_radix_sort_base { static constexpr std::uint32_t state_bit_lenght = 0; static constexpr std::uint32_t vecLen = 2; // TODO }; -template +template struct topk_by_radix_sort : topk_by_radix_sort_base {}; -template -struct topk_by_radix_sort> +template +struct topk_by_radix_sort> : topk_by_radix_sort_base { __device__ void operator()(uint32_t topk, uint32_t batch_size, @@ -48,8 +45,7 @@ struct topk_by_radix_sort(work); - topk_cta_11_core::state_bit_lenght, + topk_cta_11_core::state_bit_lenght, topk_by_radix_sort_base::vecLen, 64, 32, @@ -58,10 +54,9 @@ struct topk_by_radix_sort \ + template \ struct topk_by_radix_sort< \ MAX_INTERNAL_TOPK, \ - BLOCK_SIZE, \ IdxT, \ std::enable_if_t<((MAX_INTERNAL_TOPK <= V) && (2 * MAX_INTERNAL_TOPK > V))>> \ : topk_by_radix_sort_base { \ @@ -77,10 +72,9 @@ struct topk_by_radix_sort= V / 4); \ + assert(blockDim.x >= V / 4); \ std::uint8_t* state = (std::uint8_t*)work; \ - topk_cta_11_core::state_bit_lenght, \ + topk_cta_11_core::state_bit_lenght, \ topk_by_radix_sort_base::vecLen, \ V, \ V / 4, \ diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh index 0fcfe2cc16..fd4aeb9bb3 100644 --- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh @@ -22,8 +22,6 @@ #include namespace raft::neighbors::cagra::detail { -using namespace cub; - // __device__ inline uint32_t convert(uint32_t x) { @@ -174,8 +172,46 @@ __device__ inline uint16_t get_element_from_u16_vector(struct u16_vector& vec, i return xi; } +template +__device__ inline void block_scan(const T input, T& output) +{ + switch (blockDim.x) { + case 32: { + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + BlockScanT(temp_storage).InclusiveSum(input, output); + } break; + case 64: { + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + BlockScanT(temp_storage).InclusiveSum(input, output); + } break; + case 128: { + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + BlockScanT(temp_storage).InclusiveSum(input, output); + } break; + case 256: { + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + BlockScanT(temp_storage).InclusiveSum(input, output); + } break; + case 512: { + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + BlockScanT(temp_storage).InclusiveSum(input, output); + } break; + case 1024: { + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + BlockScanT(temp_storage).InclusiveSum(input, output); + } break; + default: break; + } +} + // -template +template __device__ inline void update_histogram(int itr, uint32_t thread_id, uint32_t num_threads, @@ -220,7 +256,7 @@ __device__ inline void update_histogram(int itr, return; } if (itr > 0) { - for (int i = threadIdx.x; i < num_bins; i += blockDim_x) { + for (int i = threadIdx.x; i < num_bins; i += blockDim.x) { hist[i] = 0; } __syncthreads(); @@ -285,8 +321,53 @@ __device__ inline void update_histogram(int itr, __syncthreads(); } +template +__device__ inline void select_best_index_for_next_threshold_core(uint32_t& my_index, + uint32_t& my_csum, + const unsigned num_bins, + const uint32_t* const hist, + const uint32_t nx_below_threshold, + const uint32_t max_threshold, + const uint32_t threshold, + const uint32_t shift, + const uint32_t topk) +{ + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + if (num_bins == 2048) { + constexpr int n_data = 2048 / blockDim_x; + uint32_t csum[n_data]; + for (int i = 0; i < n_data; i++) { + csum[i] = hist[i + (n_data * threadIdx.x)]; + } + BlockScanT(temp_storage).InclusiveSum(csum, csum); + for (int i = n_data - 1; i >= 0; i--) { + if (nx_below_threshold + csum[i] > topk) continue; + const uint32_t index = i + (n_data * threadIdx.x); + if (threshold + (index << shift) > max_threshold) continue; + my_index = index; + my_csum = csum[i]; + break; + } + } else if (num_bins == 1024) { + constexpr int n_data = 1024 / blockDim_x; + uint32_t csum[n_data]; + for (int i = 0; i < n_data; i++) { + csum[i] = hist[i + (n_data * threadIdx.x)]; + } + BlockScanT(temp_storage).InclusiveSum(csum, csum); + for (int i = n_data - 1; i >= 0; i--) { + if (nx_below_threshold + csum[i] > topk) continue; + const uint32_t index = i + (n_data * threadIdx.x); + if (threshold + (index << shift) > max_threshold) continue; + my_index = index; + my_csum = csum[i]; + break; + } + } +} + // -template __device__ inline void select_best_index_for_next_threshold( const uint32_t topk, const uint32_t threshold, @@ -302,15 +383,12 @@ __device__ inline void select_best_index_for_next_threshold( // index under the condition that the sum of the number of elements found // so far ('nx_below_threshold') and the csum value does not exceed the // topk value. - typedef BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - uint32_t my_index = 0xffffffff; uint32_t my_csum = 0; - if (num_bins <= blockDim_x) { + if (num_bins <= blockDim.x) { uint32_t csum = 0; if (threadIdx.x < num_bins) { csum = hist[threadIdx.x]; } - BlockScanT(temp_storage).InclusiveSum(csum, csum); + detail::block_scan(csum, csum); if (threadIdx.x < num_bins) { const uint32_t index = threadIdx.x; if ((nx_below_threshold + csum <= topk) && (threshold + (index << shift) <= max_threshold)) { @@ -319,36 +397,62 @@ __device__ inline void select_best_index_for_next_threshold( } } } else { - if (num_bins == 2048) { - constexpr int n_data = 2048 / blockDim_x; - uint32_t csum[n_data]; - for (int i = 0; i < n_data; i++) { - csum[i] = hist[i + (n_data * threadIdx.x)]; - } - BlockScanT(temp_storage).InclusiveSum(csum, csum); - for (int i = n_data - 1; i >= 0; i--) { - if (nx_below_threshold + csum[i] > topk) continue; - const uint32_t index = i + (n_data * threadIdx.x); - if (threshold + (index << shift) > max_threshold) continue; - my_index = index; - my_csum = csum[i]; + switch (blockDim.x) { + case 64: + select_best_index_for_next_threshold_core<64>(my_index, + my_csum, + num_bins, + hist, + nx_below_threshold, + max_threshold, + threshold, + shift, + topk); break; - } - } else if (num_bins == 1024) { - constexpr int n_data = 1024 / blockDim_x; - uint32_t csum[n_data]; - for (int i = 0; i < n_data; i++) { - csum[i] = hist[i + (n_data * threadIdx.x)]; - } - BlockScanT(temp_storage).InclusiveSum(csum, csum); - for (int i = n_data - 1; i >= 0; i--) { - if (nx_below_threshold + csum[i] > topk) continue; - const uint32_t index = i + (n_data * threadIdx.x); - if (threshold + (index << shift) > max_threshold) continue; - my_index = index; - my_csum = csum[i]; + case 128: + select_best_index_for_next_threshold_core<128>(my_index, + my_csum, + num_bins, + hist, + nx_below_threshold, + max_threshold, + threshold, + shift, + topk); + break; + case 256: + select_best_index_for_next_threshold_core<256>(my_index, + my_csum, + num_bins, + hist, + nx_below_threshold, + max_threshold, + threshold, + shift, + topk); + break; + case 512: + select_best_index_for_next_threshold_core<512>(my_index, + my_csum, + num_bins, + hist, + nx_below_threshold, + max_threshold, + threshold, + shift, + topk); + break; + case 1024: + select_best_index_for_next_threshold_core<1024>(my_index, + my_csum, + num_bins, + hist, + nx_below_threshold, + max_threshold, + threshold, + shift, + topk); break; - } } } if (threadIdx.x < num_bins) { @@ -481,10 +585,14 @@ __device__ inline uint32_t max_value_of() return ~0u; } -template +template __device__ __host__ inline uint32_t get_state_size(uint32_t len_x) { - const uint32_t num_threads = blockDim_x; +#ifdef __CUDA_ARCH__ + const uint32_t num_threads = blockDim.x; +#else + const uint32_t num_threads = BLOCK_SIZE; +#endif if (stateBitLen == 8) { uint32_t numElements_perThread = (len_x + num_threads - 1) / num_threads; uint32_t numState_perThread = (numElements_perThread + stateBitLen - 1) / stateBitLen; @@ -494,7 +602,7 @@ __device__ __host__ inline uint32_t get_state_size(uint32_t len_x) } // -template +template __device__ inline void topk_cta_11_core(uint32_t topk, uint32_t len_x, const uint32_t* _x, // [size_batch, ld_x,] @@ -511,7 +619,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, uint32_t* const best_index = &(_smem[2 * maxTopk + 2048]); uint32_t* const best_csum = &(_smem[2 * maxTopk + 2048 + 3]); - const uint32_t num_threads = blockDim_x; + const uint32_t num_threads = blockDim.x; const uint32_t thread_id = threadIdx.x; uint32_t nx = len_x; const uint32_t* const x = _x; @@ -541,29 +649,29 @@ __device__ inline void topk_cta_11_core(uint32_t topk, for (int j = 0; j < 3; j += 1) { uint32_t num_bins; uint32_t shift; - update_histogram(j, - thread_id, - num_threads, - hint, - threshold, - num_bins, - shift, - x, - nx, - hist, - state, - smem_out_vals, - output_count); - - select_best_index_for_next_threshold(topk, - threshold, - hint, - nx_below_threshold, - num_bins, - shift, - hist, - best_index + j, - best_csum + j); + + update_histogram(j, + thread_id, + num_threads, + hint, + threshold, + num_bins, + shift, + x, + nx, + hist, + state, + smem_out_vals, + output_count); + select_best_index_for_next_threshold(topk, + threshold, + hint, + nx_below_threshold, + num_bins, + shift, + hist, + best_index + j, + best_csum + j); threshold += (best_index[j] << shift); nx_below_threshold += best_csum[j]; @@ -601,7 +709,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, #endif if (!sort) { - for (int k = thread_id; k < topk; k += blockDim_x) { + for (int k = thread_id; k < topk; k += blockDim.x) { const uint32_t i = smem_out_vals[k]; if (y) { y[k] = x[i]; } if (out_vals) { @@ -756,7 +864,7 @@ int _get_vecLen(uint32_t maxSamples, int maxVecLen = MAX_VEC_LENGTH) } } // unnamed namespace -template +template __launch_bounds__(1024, 1) __global__ void kern_topk_cta_11(uint32_t topk, uint32_t size_batch, @@ -781,14 +889,14 @@ __launch_bounds__(1024, 1) __global__ "maxTopk * sizeof(ValT) must be smaller or equal to 8192 byte"); __shared__ uint32_t _smem[smem_len]; - topk_cta_11_core( + topk_cta_11_core( topk, len_x, (_x == NULL ? NULL : _x + i_batch * ld_x), (_in_vals == NULL ? NULL : _in_vals + i_batch * ld_iv), (_y == NULL ? NULL : _y + i_batch * ld_y), (_out_vals == NULL ? NULL : _out_vals + i_batch * ld_ov), - (_state == NULL ? NULL : _state + i_batch * get_state_size(len_x)), + (_state == NULL ? NULL : _state + i_batch * get_state_size(len_x)), (_hints == NULL ? NULL : _hints + i_batch), sort, _smem); @@ -808,7 +916,7 @@ size_t inline _cuann_find_topk_bufferSize(uint32_t topK, // state if (stateBitLen == 8) { workspaceSize = _cuann_aligned( - sizeof(uint8_t) * get_state_size(numElements) * sizeBatch); + sizeof(uint8_t) * get_state_size(numElements) * sizeBatch); } return workspaceSize; @@ -862,12 +970,12 @@ inline void _cuann_find_topk(uint32_t topK, bool) = nullptr; // V:vecLen, K:maxTopk, T:numSortThreads -#define SET_KERNEL_VKT(V, K, T, ValT) \ - do { \ - assert(numThreads >= T); \ - assert((K % T) == 0); \ - assert((K / T) <= 4); \ - cta_kernel = kern_topk_cta_11; \ +#define SET_KERNEL_VKT(V, K, T, ValT) \ + do { \ + assert(numThreads >= T); \ + assert((K % T) == 0); \ + assert((K / T) <= 4); \ + cta_kernel = kern_topk_cta_11; \ } while (0) // V: vecLen diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 343afd04ec..b750372244 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -49,7 +49,7 @@ namespace { /* A filter that excludes all indices below `offset`. */ struct test_cagra_sample_filter { - static constexpr unsigned offset = 400; + static constexpr unsigned offset = 300; inline _RAFT_HOST_DEVICE auto operator()( // query index const uint32_t query_ix, From 25858c598ce8b5ad56ed9d84722c37cb5bf2e61b Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 26 Sep 2023 17:31:16 -0700 Subject: [PATCH 10/10] Add index class for brute_force knn (#1817) This adds an index class to match the ANN methods. This allows us to precompute norms for the dataset in `brute_force::build` and then use them in `brute_force::search` - meaning we don't have to compute norms for the entire dataset on every query. Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1817 --- cpp/CMakeLists.txt | 32 ++-- .../raft/neighbors/brute_force-ext.cuh | 39 ++++- .../raft/neighbors/brute_force-inl.cuh | 98 +++++++++++- .../raft/neighbors/brute_force_types.hpp | 144 ++++++++++++++++++ .../raft/neighbors/detail/knn_brute_force.cuh | 49 +++--- .../spatial/knn/detail/fused_l2_knn-ext.cuh | 8 +- .../spatial/knn/detail/fused_l2_knn-inl.cuh | 46 ++++-- .../neighbors/brute_force_knn_index_float.cu | 39 +++++ .../knn/detail/fused_l2_knn_int32_t_float.cu | 4 +- .../knn/detail/fused_l2_knn_int64_t_float.cu | 4 +- .../knn/detail/fused_l2_knn_uint32_t_float.cu | 4 +- cpp/template/CMakeLists.txt | 1 - cpp/test/neighbors/tiled_knn.cu | 30 ++++ 13 files changed, 446 insertions(+), 52 deletions(-) create mode 100644 cpp/include/raft/neighbors/brute_force_types.hpp create mode 100644 cpp/src/neighbors/brute_force_knn_index_float.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d93b19f784..7d63751906 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -22,7 +22,8 @@ include(rapids-find) option(BUILD_CPU_ONLY "Build CPU only components. Applies to RAFT ANN benchmarks currently" OFF) -# workaround for rapids_cuda_init_architectures not working for arch detection with enable_language(CUDA) +# workaround for rapids_cuda_init_architectures not working for arch detection with +# enable_language(CUDA) set(lang_list "CXX") if(NOT BUILD_CPU_ONLY) @@ -286,7 +287,8 @@ endif() set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled) if(RAFT_COMPILE_LIBRARY) - add_library(raft_objs OBJECT + add_library( + raft_objs OBJECT src/core/logger.cpp src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu @@ -331,6 +333,7 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu src/neighbors/brute_force_knn_int_float_int.cu src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu + src/neighbors/brute_force_knn_index_float.cu src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu @@ -452,18 +455,21 @@ if(RAFT_COMPILE_LIBRARY) src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu src/util/memory_pool.cpp - ) + ) set_target_properties( raft_objs PROPERTIES CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON CUDA_STANDARD 17 CUDA_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON) + POSITION_INDEPENDENT_CODE ON + ) target_compile_definitions(raft_objs PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") - target_compile_options(raft_objs PRIVATE "$<$:${RAFT_CXX_FLAGS}>" - "$<$:${RAFT_CUDA_FLAGS}>") + target_compile_options( + raft_objs PRIVATE "$<$:${RAFT_CXX_FLAGS}>" + "$<$:${RAFT_CUDA_FLAGS}>" + ) add_library(raft_lib SHARED $) add_library(raft_lib_static STATIC $) @@ -477,13 +483,15 @@ if(RAFT_COMPILE_LIBRARY) ) foreach(target raft_lib raft_lib_static raft_objs) - target_link_libraries(${target} PUBLIC - raft::raft - ${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this - # will just be cublas - $) + target_link_libraries( + ${target} + PUBLIC raft::raft + ${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this + # will just be cublas + $ + ) - #So consumers know when using libraft.so/libraft.a + # So consumers know when using libraft.so/libraft.a target_compile_definitions(${target} PUBLIC "RAFT_COMPILED") # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(${target} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index 862db75866..b8c00616da 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -22,7 +22,8 @@ #include // raft::identity_op #include // raft::resources #include // raft::distance::DistanceType -#include // RAFT_EXPLICIT +#include +#include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -38,6 +39,19 @@ inline void knn_merge_parts( size_t n_samples, std::optional> translations = std::nullopt) RAFT_EXPLICIT; +template +index build(raft::resources const& res, + mdspan, row_major, Accessor> dataset, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + T metric_arg = 0.0) RAFT_EXPLICIT; + +template +void search(raft::resources const& res, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + template ( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +extern template void search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + raft::device_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); +} // namespace raft::neighbors::brute_force + #define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ value_t, idx_t, idx_layout, query_layout) \ extern template void raft::neighbors::brute_force::fused_l2_knn( \ diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index bc9e09e5b0..88439a738b 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -280,6 +281,101 @@ void fused_l2_knn(raft::resources const& handle, metric); } -/** @} */ // end group brute_force_knn +/** + * @brief Build the index from the dataset for efficient search. + * + * @tparam T data element type + * + * @param[in] res + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * @param[in] metric: distance metric to use. Euclidean (L2) is used by default + * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This + * is ignored if the metric_type is not Minkowski. + * + * @return the constructed brute force index + */ +template +index build(raft::resources const& res, + mdspan, row_major, Accessor> dataset, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + T metric_arg = 0.0) +{ + // certain distance metrics can benefit by pre-calculating the norms for the index dataset + // which lets us avoid calculating these at query time + std::optional> norms; + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::CosineExpanded) { + norms = make_device_vector(res, dataset.extent(0)); + // cosine needs the l2norm, where as l2 distances needs the squared norm + if (metric == raft::distance::DistanceType::CosineExpanded) { + raft::linalg::norm(res, + dataset, + norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op{}); + } else { + raft::linalg::norm(res, + dataset, + norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + } + } + + return index(res, dataset, std::move(norms), metric, metric_arg); +} +/** + * @brief Brute Force search using the constructed index. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] idx brute force index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& res, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), + "Number of columns in queries must match brute force index"); + + auto k = neighbors.extent(1); + auto d = idx.dataset().extent(1); + + std::vector dataset = {const_cast(idx.dataset().data_handle())}; + std::vector sizes = {idx.dataset().extent(0)}; + std::vector norms; + if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } + + detail::brute_force_knn_impl(res, + dataset, + sizes, + d, + const_cast(queries.data_handle()), + queries.extent(0), + neighbors.data_handle(), + distances.data_handle(), + k, + true, + true, + nullptr, + idx.metric(), + idx.metric_arg(), + raft::identity_op(), + norms.size() ? &norms : nullptr); +} +/** @} */ // end group brute_force_knn } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp new file mode 100644 index 0000000000..cc934b7a98 --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::neighbors::brute_force { +/** + * @addtogroup brute_force + * @{ + */ + +/** + * @brief Brute Force index. + * + * The index stores the dataset and norms for the dataset in device memory. + * + * @tparam T data element type + */ +template +struct index : ann::index { + public: + /** Distance metric used for retrieval */ + [[nodiscard]] constexpr inline raft::distance::DistanceType metric() const noexcept + { + return metric_; + } + + /** Total length of the index (number of vectors). */ + [[nodiscard]] constexpr inline int64_t size() const noexcept { return dataset_view_.extent(0); } + + /** Dimensionality of the data. */ + [[nodiscard]] constexpr inline uint32_t dim() const noexcept { return dataset_view_.extent(1); } + + /** Dataset [size, dim] */ + [[nodiscard]] inline auto dataset() const noexcept + -> device_matrix_view + { + return dataset_view_; + } + + /** Dataset norms */ + [[nodiscard]] inline auto norms() const -> device_vector_view + { + return make_const_mdspan(norms_.value().view()); + } + + /** Whether ot not this index has dataset norms */ + [[nodiscard]] inline bool has_norms() const noexcept { return norms_.has_value(); } + + [[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; } + + // Don't allow copying the index for performance reasons (try avoiding copying data) + index(const index&) = delete; + index(index&&) = default; + auto operator=(const index&) -> index& = delete; + auto operator=(index&&) -> index& = default; + ~index() = default; + + /** Construct a brute force index from dataset + * + * Constructs a brute force index from a dataset. This lets us precompute norms for + * the dataset, providing a speed benefit over doing this at query time. + + * If the dataset is already in GPU memory, then this class stores a non-owning reference to + * the dataset. If the dataset is in host memory, it will be copied to the device and the + * index will own the device memory. + */ + template + index(raft::resources const& res, + mdspan, row_major, data_accessor> dataset, + std::optional>&& norms, + raft::distance::DistanceType metric, + T metric_arg = 0.0) + : ann::index(), + metric_(metric), + dataset_(make_device_matrix(res, 0, 0)), + norms_(std::move(norms)), + metric_arg_(metric_arg) + { + update_dataset(res, dataset); + resource::sync_stream(res); + } + + private: + /** + * Replace the dataset with a new dataset. + */ + void update_dataset(raft::resources const& res, + raft::device_matrix_view dataset) + { + dataset_view_ = dataset; + } + + /** + * Replace the dataset with a new dataset. + * + * We create a copy of the dataset on the device. The index manages the lifetime of this copy. + */ + void update_dataset(raft::resources const& res, + raft::host_matrix_view dataset) + { + dataset_ = make_device_matrix(dataset.extents(0), dataset.extents(1)); + raft::copy(dataset_.data_handle(), + dataset.data_handle(), + dataset.size(), + resource::get_cuda_stream(res)); + dataset_view_ = make_const_mdspan(dataset_.view()); + } + + raft::distance::DistanceType metric_; + raft::device_matrix dataset_; + std::optional> norms_; + raft::device_matrix_view dataset_view_; + T metric_arg_; +}; + +/** @} */ + +} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 123a902ef9..be05d5545f 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -64,10 +64,11 @@ void tiled_brute_force_knn(const raft::resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 2.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0, - DistanceEpilogue distance_epilogue = raft::identity_op()) + float metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + DistanceEpilogue distance_epilogue = raft::identity_op(), + const ElementType* precomputed_index_norms = nullptr) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -97,7 +98,7 @@ void tiled_brute_force_knn(const raft::resources& handle, metric == raft::distance::DistanceType::L2SqrtExpanded || metric == raft::distance::DistanceType::CosineExpanded) { search_norms.resize(m, stream); - index_norms.resize(n, stream); + if (!precomputed_index_norms) { index_norms.resize(n, stream); } // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric == raft::distance::DistanceType::CosineExpanded) { raft::linalg::rowNorm(search_norms.data(), @@ -108,19 +109,24 @@ void tiled_brute_force_knn(const raft::resources& handle, true, stream, raft::sqrt_op{}); - raft::linalg::rowNorm(index_norms.data(), - index, - d, - n, - raft::linalg::NormType::L2Norm, - true, - stream, - raft::sqrt_op{}); + if (!precomputed_index_norms) { + raft::linalg::rowNorm(index_norms.data(), + index, + d, + n, + raft::linalg::NormType::L2Norm, + true, + stream, + raft::sqrt_op{}); + } } else { raft::linalg::rowNorm( search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); - raft::linalg::rowNorm( - index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + + if (!precomputed_index_norms) { + raft::linalg::rowNorm( + index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + } } pairwise_metric = raft::distance::DistanceType::InnerProduct; } @@ -178,7 +184,7 @@ void tiled_brute_force_knn(const raft::resources& handle, if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { auto row_norms = search_norms.data(); - auto col_norms = index_norms.data(); + auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); raft::linalg::map_offset( @@ -200,7 +206,7 @@ void tiled_brute_force_knn(const raft::resources& handle, }); } else if (metric == raft::distance::DistanceType::CosineExpanded) { auto row_norms = search_norms.data(); - auto col_norms = index_norms.data(); + auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); raft::linalg::map_offset( @@ -330,7 +336,8 @@ void brute_force_knn_impl( std::vector* translations = nullptr, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, float metricArg = 0, - DistanceEpilogue distance_epilogue = raft::identity_op()) + DistanceEpilogue distance_epilogue = raft::identity_op(), + std::vector* input_norms = nullptr) { auto userStream = resource::get_cuda_stream(handle); @@ -424,7 +431,8 @@ void brute_force_knn_impl( rowMajorIndex, rowMajorQuery, stream, - metric); + metric, + input_norms ? (*input_norms)[i] : nullptr); // Perform necessary post-processing if (metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -473,7 +481,8 @@ void brute_force_knn_impl( metricArg, 0, 0, - distance_epilogue); + distance_epilogue, + input_norms ? (*input_norms)[i] : nullptr); break; } } diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh index 390436939f..1a48e1adde 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh @@ -36,7 +36,9 @@ void fusedL2Knn(size_t D, bool rowMajorIndex, bool rowMajorQuery, cudaStream_t stream, - raft::distance::DistanceType metric) RAFT_EXPLICIT; + raft::distance::DistanceType metric, + const value_t* index_norms = NULL, + const value_t* query_norms = NULL) RAFT_EXPLICIT; } // namespace raft::spatial::knn::detail @@ -56,7 +58,9 @@ void fusedL2Knn(size_t D, bool rowMajorIndex, \ bool rowMajorQuery, \ cudaStream_t stream, \ - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, \ + const Mvalue_t* index_norms, \ + const Mvalue_t* query_norms); instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh index 4a571c1447..67abab3d1e 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh @@ -706,6 +706,8 @@ template void fusedL2ExpKnnImpl(const DataT* x, const DataT* y, + const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, @@ -787,19 +789,25 @@ void fusedL2ExpKnnImpl(const DataT* x, } } - DataT* xn = (DataT*)workspace; - DataT* yn = (DataT*)workspace; - - if (x != y) { - yn += m; - raft::linalg::rowNorm( - xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - raft::linalg::rowNorm( - yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - } else { + // calculate norms if they haven't been passed in + if (!xn) { + DataT* xn_ = (DataT*)workspace; + workspace = xn_ + m; raft::linalg::rowNorm( - xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + xn_, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + xn = xn_; } + if (!yn) { + if (x == y) { + yn = xn; + } else { + DataT* yn_ = (DataT*)(workspace); + raft::linalg::rowNorm( + yn_, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + yn = yn_; + } + } + fusedL2ExpKnnRowMajor<<>>(x, y, xn, @@ -836,6 +844,8 @@ void fusedL2ExpKnn(IdxT m, IdxT ldd, const DataT* x, const DataT* y, + const DataT* xn, + const DataT* yn, bool sqrt, OutT* out_dists, IdxT* out_inds, @@ -850,6 +860,8 @@ void fusedL2ExpKnn(IdxT m, fusedL2ExpKnnImpl( x, y, + xn, + yn, m, n, k, @@ -867,6 +879,8 @@ void fusedL2ExpKnn(IdxT m, fusedL2ExpKnnImpl( x, y, + xn, + yn, m, n, k, @@ -883,6 +897,8 @@ void fusedL2ExpKnn(IdxT m, } else { fusedL2ExpKnnImpl(x, y, + xn, + yn, m, n, k, @@ -927,7 +943,9 @@ void fusedL2Knn(size_t D, bool rowMajorIndex, bool rowMajorQuery, cudaStream_t stream, - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, + const value_t* index_norms = NULL, + const value_t* query_norms = NULL) { // Validate the input data ASSERT(k > 0, "l2Knn: k must be > 0"); @@ -968,6 +986,8 @@ void fusedL2Knn(size_t D, ldd, query, index, + query_norms, + index_norms, sqrt, out_dists, out_inds, @@ -985,6 +1005,8 @@ void fusedL2Knn(size_t D, ldd, query, index, + query_norms, + index_norms, sqrt, out_dists, out_inds, diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu new file mode 100644 index 0000000000..f2fda93a97 --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -0,0 +1,39 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +template void raft::neighbors::brute_force::search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +template void raft::neighbors::brute_force::search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +template raft::neighbors::brute_force::index raft::neighbors::brute_force::build( + raft::resources const& res, + raft::device_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); \ No newline at end of file diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu index 67b08655e6..b73cf31c58 100644 --- a/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu @@ -32,7 +32,9 @@ bool rowMajorIndex, \ bool rowMajorQuery, \ cudaStream_t stream, \ - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, \ + const Mvalue_t* index_norms, \ + const Mvalue_t* query_norms) instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu index 3c0d13710e..35ef37c984 100644 --- a/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu @@ -32,7 +32,9 @@ bool rowMajorIndex, \ bool rowMajorQuery, \ cudaStream_t stream, \ - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, \ + const Mvalue_t* index_norms, \ + const Mvalue_t* query_norms) instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, true); instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, false); diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu index e799c5181f..ff23d9c41b 100644 --- a/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu @@ -32,7 +32,9 @@ bool rowMajorIndex, \ bool rowMajorQuery, \ cudaStream_t stream, \ - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, \ + const Mvalue_t* index_norms, \ + const Mvalue_t* query_norms) // These are used by brute_force_knn: instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, true); diff --git a/cpp/template/CMakeLists.txt b/cpp/template/CMakeLists.txt index a1341f3609..538eac07ef 100644 --- a/cpp/template/CMakeLists.txt +++ b/cpp/template/CMakeLists.txt @@ -39,4 +39,3 @@ target_link_libraries(CAGRA_EXAMPLE PRIVATE raft::raft raft::compiled) add_executable(IVF_FLAT_EXAMPLE src/ivf_flat_example.cu) target_link_libraries(IVF_FLAT_EXAMPLE PRIVATE raft::raft raft::compiled) - diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 2ab82b845e..ebde8e6d35 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -180,6 +180,36 @@ class TiledKNNTest : public ::testing::TestWithParam { float(0.001), stream_, true)); + + // Also test out the 'index' api - where we can use precomputed norms + if (params_.row_major) { + auto idx = + raft::neighbors::brute_force::build(handle_, + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + metric, + metric_arg); + + raft::neighbors::brute_force::search( + handle_, + idx, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + raft::make_device_matrix_view( + raft_indices_.data(), params_.num_queries, params_.k), + raft::make_device_matrix_view( + raft_distances_.data(), params_.num_queries, params_.k)); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(ref_indices_.data(), + raft_indices_.data(), + ref_distances_.data(), + raft_distances_.data(), + num_queries, + k_, + float(0.001), + stream_, + true)); + } } void SetUp() override