From ffe4a1bb847335afd02634b44a96e26b2a064e65 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Tue, 30 Apr 2024 11:57:26 +0800 Subject: [PATCH] Refactor milvuslite Signed-off-by: junjie.jiang --- .devcontainer/Dockerfile | 34 - .devcontainer/devcontainer.json | 20 - .devcontainer/setup/.gitignore | 1 - .devcontainer/setup/gen_env.sh | 12 - .devcontainer/setup/setup_dev.sh | 8 - .gitmodules | 9 + CMakeLists.txt | 259 +++ CONTRIBUTING.md | 44 - README.md | 204 --- cmake/milvus-storage.cmake | 11 + conanfile.py | 72 + examples/bfloat16_example.py | 75 + examples/binary_example.py | 76 + examples/customize_schema.py | 78 + examples/customize_schema_auto_id.py | 78 + examples/dynamic_field.py | 100 ++ examples/example.py | 188 -- examples/float16_example.py | 75 + examples/fuzzy_match.py | 83 + examples/hello_milvus.py | 189 ++ examples/hello_milvus_array.py | 88 + examples/hello_milvus_delete.py | 80 + examples/index.py | 95 + examples/non_ascii_encode.py | 41 + examples/simple.py | 81 + examples/simple_auto_id.py | 66 + examples/sparse.py | 99 ++ milvus_binary/.gitignore | 2 - milvus_binary/build.sh | 243 --- milvus_binary/env.sh | 5 - milvus_binary/patches/knowhere-v1.3.11.patch | 26 - milvus_binary/patches/macosx-v2.2.5.patch | 106 -- milvus_binary/patches/milvus-v2.2.5.patch | 278 --- milvus_binary/patches/msys-v2.2.4.patch | 156 -- milvus_binary/patches/msys-v2.2.5.patch | 156 -- milvus_build_backend/backend.py | 57 - pyproject.toml | 9 - python/setup.py | 138 ++ python/src/milvus/__init__.py | 0 python/src/milvus/server.py | 92 + python/src/milvus/server_manager.py | 47 + requirements.txt | 7 - setup.cfg | 38 - src/CMakeLists.txt | 56 + src/collection_data.cpp | 131 ++ src/collection_data.h | 92 + src/collection_meta.cpp | 203 +++ src/collection_meta.h | 219 +++ src/common.h | 107 ++ src/create_collection_task.cpp | 323 ++++ src/create_collection_task.h | 56 + src/create_index_task.cpp | 507 ++++++ src/create_index_task.h | 80 + src/delete_task.cpp | 32 + src/delete_task.h | 27 + src/index.cpp | 90 + src/index.h | 60 + src/insert_task.cpp | 238 +++ src/insert_task.h | 61 + src/milvus/__init__.py | 564 ------ src/milvus/data/.gitignore | 1 - src/milvus/data/config.yaml.template | 541 ------ src/milvus_local.cpp | 331 ++++ src/milvus_local.h | 105 ++ src/milvus_proxy.cpp | 356 ++++ src/milvus_proxy.h | 94 + src/milvus_service_impl.cpp | 252 +++ src/milvus_service_impl.h | 139 ++ src/parser/Plan.g4 | 151 ++ src/parser/antlr/PlanBaseVisitor.cpp | 7 + src/parser/antlr/PlanBaseVisitor.h | 140 ++ src/parser/antlr/PlanLexer.cpp | 417 +++++ src/parser/antlr/PlanLexer.h | 56 + src/parser/antlr/PlanParser.cpp | 1642 ++++++++++++++++++ src/parser/antlr/PlanParser.h | 426 +++++ src/parser/antlr/PlanVisitor.cpp | 7 + src/parser/antlr/PlanVisitor.h | 84 + src/parser/parser.cc | 35 + src/parser/parser.h | 1447 +++++++++++++++ src/parser/utils.h | 789 +++++++++ src/proto/common.proto | 1 + src/proto/feder.proto | 1 + src/proto/manifest.proto | 1 + src/proto/milvus.proto | 1 + src/proto/msg.proto | 1 + src/proto/plan.proto | 1 + src/proto/rg.proto | 1 + src/proto/schema.proto | 1 + src/proto/segcore.proto | 1 + src/query_task.cpp | 250 +++ src/query_task.h | 51 + src/retrieve_result.h | 27 + src/schema_util.cpp | 684 ++++++++ src/schema_util.h | 115 ++ src/search_result.h | 39 + src/search_task.cpp | 294 ++++ src/search_task.h | 55 + src/segcore_wrapper.cpp | 269 +++ src/segcore_wrapper.h | 64 + src/server.cpp | 79 + src/status.h | 255 +++ src/storage.cpp | 164 ++ src/storage.h | 114 ++ 
src/string_util.hpp | 65 + src/type.h | 12 + src/unittest/CMakeLists.txt | 56 + src/unittest/grpc_server_test.cpp | 284 +++ src/unittest/milvus_local_test.cpp | 154 ++ src/unittest/milvus_proxy_test.cpp | 203 +++ src/unittest/run_examples.py | 25 + src/unittest/storage_test.cpp | 39 + src/unittest/test_util.cpp | 283 +++ src/unittest/test_util.h | 75 + tests/test_milvus_config.py | 47 - thirdparty/milvus | 1 + thirdparty/milvus-proto | 1 + thirdparty/milvus-storage | 1 + 117 files changed, 14360 insertions(+), 2747 deletions(-) delete mode 100644 .devcontainer/Dockerfile delete mode 100644 .devcontainer/devcontainer.json delete mode 100644 .devcontainer/setup/.gitignore delete mode 100644 .devcontainer/setup/gen_env.sh delete mode 100644 .devcontainer/setup/setup_dev.sh create mode 100644 .gitmodules create mode 100644 CMakeLists.txt delete mode 100644 CONTRIBUTING.md delete mode 100644 README.md create mode 100644 cmake/milvus-storage.cmake create mode 100644 conanfile.py create mode 100644 examples/bfloat16_example.py create mode 100644 examples/binary_example.py create mode 100644 examples/customize_schema.py create mode 100644 examples/customize_schema_auto_id.py create mode 100644 examples/dynamic_field.py delete mode 100644 examples/example.py create mode 100644 examples/float16_example.py create mode 100644 examples/fuzzy_match.py create mode 100644 examples/hello_milvus.py create mode 100644 examples/hello_milvus_array.py create mode 100644 examples/hello_milvus_delete.py create mode 100644 examples/index.py create mode 100644 examples/non_ascii_encode.py create mode 100644 examples/simple.py create mode 100644 examples/simple_auto_id.py create mode 100644 examples/sparse.py delete mode 100644 milvus_binary/.gitignore delete mode 100644 milvus_binary/build.sh delete mode 100644 milvus_binary/env.sh delete mode 100644 milvus_binary/patches/knowhere-v1.3.11.patch delete mode 100644 milvus_binary/patches/macosx-v2.2.5.patch delete mode 100644 milvus_binary/patches/milvus-v2.2.5.patch delete mode 100644 milvus_binary/patches/msys-v2.2.4.patch delete mode 100644 milvus_binary/patches/msys-v2.2.5.patch delete mode 100644 milvus_build_backend/backend.py delete mode 100644 pyproject.toml create mode 100644 python/setup.py create mode 100644 python/src/milvus/__init__.py create mode 100644 python/src/milvus/server.py create mode 100644 python/src/milvus/server_manager.py delete mode 100644 requirements.txt delete mode 100644 setup.cfg create mode 100644 src/CMakeLists.txt create mode 100644 src/collection_data.cpp create mode 100644 src/collection_data.h create mode 100644 src/collection_meta.cpp create mode 100644 src/collection_meta.h create mode 100644 src/common.h create mode 100644 src/create_collection_task.cpp create mode 100644 src/create_collection_task.h create mode 100644 src/create_index_task.cpp create mode 100644 src/create_index_task.h create mode 100644 src/delete_task.cpp create mode 100644 src/delete_task.h create mode 100644 src/index.cpp create mode 100644 src/index.h create mode 100644 src/insert_task.cpp create mode 100644 src/insert_task.h delete mode 100644 src/milvus/__init__.py delete mode 100644 src/milvus/data/.gitignore delete mode 100644 src/milvus/data/config.yaml.template create mode 100644 src/milvus_local.cpp create mode 100644 src/milvus_local.h create mode 100644 src/milvus_proxy.cpp create mode 100644 src/milvus_proxy.h create mode 100644 src/milvus_service_impl.cpp create mode 100644 src/milvus_service_impl.h create mode 100644 src/parser/Plan.g4 create 
mode 100644 src/parser/antlr/PlanBaseVisitor.cpp create mode 100644 src/parser/antlr/PlanBaseVisitor.h create mode 100644 src/parser/antlr/PlanLexer.cpp create mode 100644 src/parser/antlr/PlanLexer.h create mode 100644 src/parser/antlr/PlanParser.cpp create mode 100644 src/parser/antlr/PlanParser.h create mode 100644 src/parser/antlr/PlanVisitor.cpp create mode 100644 src/parser/antlr/PlanVisitor.h create mode 100644 src/parser/parser.cc create mode 100644 src/parser/parser.h create mode 100644 src/parser/utils.h create mode 120000 src/proto/common.proto create mode 120000 src/proto/feder.proto create mode 120000 src/proto/manifest.proto create mode 120000 src/proto/milvus.proto create mode 120000 src/proto/msg.proto create mode 120000 src/proto/plan.proto create mode 120000 src/proto/rg.proto create mode 120000 src/proto/schema.proto create mode 120000 src/proto/segcore.proto create mode 100644 src/query_task.cpp create mode 100644 src/query_task.h create mode 100644 src/retrieve_result.h create mode 100644 src/schema_util.cpp create mode 100644 src/schema_util.h create mode 100644 src/search_result.h create mode 100644 src/search_task.cpp create mode 100644 src/search_task.h create mode 100644 src/segcore_wrapper.cpp create mode 100644 src/segcore_wrapper.h create mode 100644 src/server.cpp create mode 100644 src/status.h create mode 100644 src/storage.cpp create mode 100644 src/storage.h create mode 100644 src/string_util.hpp create mode 100644 src/type.h create mode 100644 src/unittest/CMakeLists.txt create mode 100644 src/unittest/grpc_server_test.cpp create mode 100644 src/unittest/milvus_local_test.cpp create mode 100644 src/unittest/milvus_proxy_test.cpp create mode 100644 src/unittest/run_examples.py create mode 100644 src/unittest/storage_test.cpp create mode 100644 src/unittest/test_util.cpp create mode 100644 src/unittest/test_util.h delete mode 100644 tests/test_milvus_config.py create mode 160000 thirdparty/milvus create mode 160000 thirdparty/milvus-proto create mode 160000 thirdparty/milvus-storage diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile deleted file mode 100644 index 3eff0b7..0000000 --- a/.devcontainer/Dockerfile +++ /dev/null @@ -1,34 +0,0 @@ -FROM milvusdb/milvus-env:amd64-centos7-20230329-1431037 - -# Ignore tool missing warnings, and wo also not need clang-tools -RUN rm -fr /etc/profile.d/llvm-toolset*.sh - -# python38 for dev -RUN yum -y install rh-python38 -RUN echo 'source scl_source enable rh-python38' > /etc/profile.d/rh-python38.sh - -# git new for devel -RUN yum -y install rh-git227-git-all -RUN echo 'source scl_source enable rh-git227' > /etc/profile.d/rh-git227.sh - -# Add local user in container for dev -RUN yum -y install sudo -ADD setup/env /tmp/env -RUN . 
/tmp/env && \ - if getent group $DEV_GROUP 2>/dev/null 1>/dev/null ; then \ - echo group $DEV_GROUP already exist ; \ - elif getent group $DEV_GID 2>/dev/null 1>/dev/null ; then \ - echo group $DEV_GID already exist ; \ - else \ - groupadd -g $DEV_GID $DEV_GROUP ; \ - fi ; \ - if id $DEV_USER 2>/dev/null 1>/dev/null ; then \ - echo user $DEV_USER already exist ; \ - exit 1 ; \ - elif id $DEV_UID 2>/dev/null 1>/dev/null ; then \ - echo user $DEV_UID already exist ; \ - exit 1 ; \ - else \ - useradd -g $DEV_GID -u $DEV_UID -m -d $DEV_HOME -s /bin/bash $DEV_USER ; \ - fi && \ - echo "${DEV_USER} ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/user diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json deleted file mode 100644 index 26d3ef3..0000000 --- a/.devcontainer/devcontainer.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "name": "Embd Milvus Development", - "build": { - "dockerfile": "Dockerfile" - }, - "initializeCommand": "bash .devcontainer/setup/gen_env.sh", - "onCreateCommand": "bash .devcontainer/setup/setup_dev.sh", - "containerUser": "${localEnv:USER}", - "customizations": { - "vscode": { - "extensions": [ - "ms-python.python", - "bungcip.better-toml" - ] - } - }, - "mounts": [ - "source=${localEnv:HOME}/.ssh,target=/home/${localEnv:USER}/.ssh,type=bind,consistency=cached" - ] -} diff --git a/.devcontainer/setup/.gitignore b/.devcontainer/setup/.gitignore deleted file mode 100644 index 8fa5b33..0000000 --- a/.devcontainer/setup/.gitignore +++ /dev/null @@ -1 +0,0 @@ -env \ No newline at end of file diff --git a/.devcontainer/setup/gen_env.sh b/.devcontainer/setup/gen_env.sh deleted file mode 100644 index 5bdf33c..0000000 --- a/.devcontainer/setup/gen_env.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -this_dir=$(dirname $0) - -echo -n > ${this_dir}/env - -echo DEV_USER=$(id -un) >> ${this_dir}/env -echo DEV_UID=$(id -u) >> ${this_dir}/env -echo DEV_GROUP=$(id -gn) >> ${this_dir}/env -echo DEV_GID=$(id -g) >> ${this_dir}/env -echo DEV_HOME=$HOME >> ${this_dir}/env -echo DEV_SHELL=$SHELL >> ${this_dir}/env diff --git a/.devcontainer/setup/setup_dev.sh b/.devcontainer/setup/setup_dev.sh deleted file mode 100644 index 3e1213f..0000000 --- a/.devcontainer/setup/setup_dev.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -set -e - -project_dir=$(dirname $(dirname $(cd $(dirname $0); pwd))) - -python3 -m pip install --user -U pip -python3 -m pip install --user -r ${project_dir}/requirements.txt diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..50a3a48 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,9 @@ +[submodule "thirdparty/milvus"] + path = thirdparty/milvus + url = https://github.com/milvus-io/milvus.git +[submodule "thirdparty/milvus-proto"] + path = thirdparty/milvus-proto + url = https://github.com/milvus-io/milvus-proto.git +[submodule "thirdparty/milvus-storage"] + path = thirdparty/milvus-storage + url = https://github.com/milvus-io/milvus-storage.git diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..095e765 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,259 @@ +cmake_minimum_required(VERSION 3.20) +project(milvus-lite) + +option(ENABLE_UNIT_TESTS "Enable unit tests" ON) +message(STATUS "Enable testing: ${ENABLE_UNIT_TESTS}") + +list(APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR}/) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_VERBOSE_MAKEFILE ON) +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CMAKE_CXX_FLAGS "-g -Wall -fPIC ${CMAKE_CXX_FLAGS}") 
+else() + set(CMAKE_CXX_FLAGS "-O3 -Wall -fPIC ${CMAKE_CXX_FLAGS}") +endif() + +if(APPLE) + include_directories(/opt/homebrew/opt/libomp/include) + link_directories(/opt/homebrew/opt/libomp/lib) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lomp") +endif() + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + +include(FetchContent) + +set(CONAN_LIBS "${CONAN_LIBS};marisa::marisa;google-cloud-cpp::rest_internal;AWS::aws-sdk-cpp-identity-management") + +find_package(SQLiteCpp REQUIRED) +include_directories(${SQLiteCpp_INCLUDE_DIRS}) + +find_package(antlr4-runtime REQUIRED) +include_directories(${antlr4-cppruntime_INCLUDES}) + +find_package(Protobuf REQUIRED) +include_directories(${protobuf_INCLUDE_DIRS}) + +find_package(gRPC REQUIRED) +include_directories(${gRPC_INCLUDE_DIRS}) + +find_package(TBB REQUIRED) +include_directories(${TBB_tbb_INCLUDE_DIRS}) + +find_package(nlohmann_json REQUIRED) +include_directories(${nlohmann_json_INCLUDE_DIRS}) + +find_package(Boost REQUIRED) +include_directories(${Boost_INCLUDE_DIRS}) + +find_package(folly REQUIRED) +include_directories(${Folly_INCLUDE_DIRS}) + +find_package(fmt REQUIRED) +include_directories(${fmt_INCLUDE_DIRS}) + +find_package(opentelemetry-cpp REQUIRED) +include_directories(${opentelemetry-cpp_INCLUDE_DIRS}) + +find_package(glog REQUIRED) +include_directories(${glog_INCLUDE_DIRS}) + +find_package(gflags REQUIRED) +include_directories(${gflags_INCLUDE_DIRS}) + +find_package(Arrow REQUIRED) +include_directories(${arrow_INCLUDE_DIRS}) + +find_package(re2 REQUIRED) +include_directories(${re2_INCLUDE_DIRS}) +link_directories(${re2_LIB_DIRS}) + +find_package(double-conversion REQUIRED) +include_directories(${double-conversion_INCLUDE_DIRS}) + +find_package(prometheus-cpp REQUIRED) +include_directories(${prometheus-cpp_INCLUDE_DIRS}) + +find_package(marisa REQUIRED) +include_directories(${marisa_INCLUDE_DIRS}) + +find_package(yaml-cpp REQUIRED) +include_directories(${yaml-cpp_INCLUDE_DIRS}) + +find_package(google-cloud-cpp REQUIRED) +include_directories(${google-cloud-cpp_INCLUDE_DIRS}) + +find_package(absl REQUIRED) +include_directories(${absl_type_traits_INCLUDE_DIRS}) + +find_package(AWSSDK REQUIRED) + +add_definitions(-DANTLR4CPP_STATIC) +add_definitions(-DHAVE_CPP_STDLIB) + +IF(APPLE) + add_definitions("-D_GNU_SOURCE") +endif() + +file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/pb) + +add_library(milvus_proto STATIC "${CMAKE_SOURCE_DIR}/src/proto/plan.proto" + "${CMAKE_SOURCE_DIR}/src/proto/schema.proto" + "${CMAKE_SOURCE_DIR}/src/proto/common.proto" + "${CMAKE_SOURCE_DIR}/src/proto/segcore.proto" + "${CMAKE_SOURCE_DIR}/src/proto/milvus.proto" + "${CMAKE_SOURCE_DIR}/src/proto/msg.proto" + "${CMAKE_SOURCE_DIR}/src/proto/feder.proto" + "${CMAKE_SOURCE_DIR}/src/proto/rg.proto" +) + +set(PROTO_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pb") + +protobuf_generate( + TARGET milvus_proto + IMPORT_DIRS "${CMAKE_SOURCE_DIR}/src/proto" + PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") + + +add_library(milvus_grpc_service STATIC + "${CMAKE_SOURCE_DIR}/src/proto/milvus.proto" +) + +target_link_libraries(milvus_grpc_service milvus_proto) + +protobuf_generate( + TARGET milvus_grpc_service + # OUT_VAR PROTO_GENERATED_FILES + LANGUAGE grpc + GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc + PLUGIN "protoc-gen-grpc=\$" + IMPORT_DIRS "${CMAKE_SOURCE_DIR}/src/proto" + PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") + +include_directories("${CMAKE_SOURCE_DIR}/src/parser/") 
+include_directories("${CMAKE_SOURCE_DIR}/src/parser/antlr/") +include_directories("${CMAKE_BINARY_DIR}/pb") +include_directories("${CMAKE_BINARY_DIR}") + +add_library(parser STATIC "${CMAKE_SOURCE_DIR}/src/parser/parser.cc" + "${CMAKE_SOURCE_DIR}/src/parser/antlr/PlanBaseVisitor.cpp" + "${CMAKE_SOURCE_DIR}/src/parser/antlr/PlanLexer.cpp" + "${CMAKE_SOURCE_DIR}/src/parser/antlr/PlanParser.cpp" + "${CMAKE_SOURCE_DIR}/src/parser/antlr/PlanVisitor.cpp" + ) + + +target_link_libraries(parser milvus_proto ${antlr4-cppruntime_LIBRARIES}) + +add_subdirectory(thirdparty/milvus/internal/core/thirdparty/knowhere) + +include_directories("${CMAKE_SOURCE_DIR}/thirdparty/milvus/internal/core/thirdparty/") +include_directories("${CMAKE_SOURCE_DIR}/thirdparty/milvus/internal/core/thirdparty/tantivy/tantivy-binding/include/") +include_directories("${CMAKE_SOURCE_DIR}/thirdparty/milvus/internal/core/thirdparty/tantivy/") + +include_directories("${knowhere_SOURCE_DIR}/include") +function(MILVUS_ADD_PKG_CONFIG MODULE) + configure_file(${MODULE}.pc.in "${CMAKE_CURRENT_BINARY_DIR}/${MODULE}.pc" @ONLY) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/${MODULE}.pc" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig/") +endfunction() + + +include(cmake/milvus-storage.cmake) + +include_directories("${CMAKE_SOURCE_DIR}/thirdparty/milvus/internal/core/src/") +add_library(boost_bitset_ext + thirdparty/milvus/internal/core/thirdparty/boost_ext/dynamic_bitset_ext.cpp) + + +FetchContent_Declare( + simdjson + GIT_REPOSITORY https://github.com/simdjson/simdjson.git + GIT_TAG v3.1.7 +) +FetchContent_MakeAvailable(simdjson) + + +if (CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CARGO_CMD cargo build) +else () + set(CARGO_CMD cargo build --release) +endif () + +set(HOME_VAR $ENV{HOME}) + +add_custom_command(OUTPUT ls_cargo + COMMENT "ls cargo" + COMMAND ls ${HOME_VAR}/.cargo/bin/ + ) + +add_custom_target(ls_cargo_target DEPENDS ls_cargo) + +add_custom_command(OUTPUT compile_tantivy + COMMENT "Compiling tantivy binding" + COMMAND CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/thirdparty/milvus/internal/core/thirdparty/tantivy/tantivy-binding + DEPENDS ls_cargo_target + ) + +add_custom_target(tantivy_binding_target DEPENDS compile_tantivy) + +add_library(tantivy_binding STATIC IMPORTED) +set_target_properties(tantivy_binding + PROPERTIES + IMPORTED_GLOBAL TRUE + IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/release/libtantivy_binding.a" + INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/thirdparty/milvus/internal/core/thirdparty/tantivy/tantivy-binding/include/" +) + +add_dependencies(tantivy_binding tantivy_binding_target) + +FetchContent_Declare(clone_opendal + GIT_REPOSITORY "https://github.com/apache/incubator-opendal.git" + GIT_TAG "v0.43.0-rc.2" # it's much better to use a specific Git revision or Git tag for reproducibility + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/incubator-opendal" + GIT_SHALLOW 1 +) + +FetchContent_MakeAvailable(clone_opendal) + +add_custom_command(OUTPUT compile_opendal + COMMENT "Compiling opendal" + COMMAND bash -c "cargo build --release --verbose" + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/incubator-opendal/bindings/c +) +add_custom_target(compile_opendal_target DEPENDS compile_opendal) + +add_library(opendal STATIC IMPORTED) +set_target_properties(opendal + PROPERTIES + IMPORTED_GLOBAL TRUE + IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/incubator-opendal/target/release/libopendal_c.a" + INTERFACE_INCLUDE_DIRECTORIES 
"${CMAKE_CURRENT_BINARY_DIR}/incubator-opendal/bindings/c/include/") + +add_dependencies(opendal compile_opendal_target clone_opendal_target) + +set(MILVUS_ENGINE_SRC + "${CMAKE_SOURCE_DIR}/thirdparty/milvus/internal/core/src") +add_subdirectory(thirdparty/milvus/internal/core/src/log) +add_subdirectory(thirdparty/milvus/internal/core/src/config) +add_subdirectory(thirdparty/milvus/internal/core/src/common) +add_subdirectory(thirdparty/milvus/internal/core/src/storage) +add_subdirectory(thirdparty/milvus/internal/core/src/query) +add_subdirectory(thirdparty/milvus/internal/core/src/exec) +add_subdirectory(thirdparty/milvus/internal/core/src/index) +add_subdirectory(thirdparty/milvus/internal/core/src/segcore) +add_subdirectory(thirdparty/milvus/internal/core/src/bitset) + +if(ENABLE_UNIT_TESTS) + include(CTest) + enable_testing() +endif() + +add_subdirectory(src) + +find_program(MEMORYCHECK_COMMAND NAMES valgrind) +set(MEMORYCHECK_COMMAND_OPTIONS "--trace-children=yes --track-origins=yes --leak-check=full --show-leak-kinds=all") diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index 3ffe490..0000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,44 +0,0 @@ -# Contributing Guide - -Welcome contributors! This guide will help you get started with contributing to Milvus-lite. - -Please always find the latest version of this guide at [CONTRIBUTING.md:main](https://github.com/milvus-io/milvus-lite/blob/main/CONTRIBUTING.md) - -## How to set up the development environment -The Milvus-lite project is written in Python. To set up the development environment, you need to install Python 3.8 or later(Our release is supported Python3.6+). We recommend that you use a virtual environment to install the dependencies, although the Milvus-lite project requires a very small number of external packages. - -The main dependencies for build Milvus-lite is to install dependencies of milvis, so generally you could refer to Milvus's [install_deps.sh](https://github.com/milvus-io/milvus/blob/master/scripts/install_deps.sh) as a reference. Please note, you should follow the related branch of Milvus. For example, if you want to build Milvus-lite with Milvus 2.2.0, you should checkout the branch of Milvus 2.2.0. - -For python3 build wheel, we use the build module and requires newer version of setuptools. So you should install the latest version of setuptools and build. - -### Setup development environment under linux -We release the Milvus-lite with CentOS image, which reuses from milvusdb/milvus-env, so the binary distribution is compatiable with manylinux2014. - -If you open the project with VSCode, you could alse use devcontainer to setup the development environment(recommanded). That will help you install all dependencies automatically. - -### Setup development environment under macOS -As we build Milvus-lite with macos 11 and 12, and all dependencies are resloved during build. 
-Generally, you need first install [brew](https://brew.sh/), then install the following packages: - -```bash -brew install boost libomp ninja tbb openblas ccache pkg-config md5sha1sum llvm@15 -``` - -You could also find more details in [build.sh](milvus_binary/build.sh) - - -### Setup development environment under Windows/msys2 -We only support specific version of msys/mingw, currently we **msys2-base-x86_64-20220603**, which could be found at [MSYS2 Install Release](https://github.com/msys2/msys2-installer/releases/tag/2022-06-03) - -After install mingw/msys2, you need install the following packages by pacman: -```bash -pacman -S git patch -pacman -S mingw-w64-x86_64-python mingw-w64-x86_64-python-wheel mingw-w64-x86_64-python-pip -``` - -## Build Milvus-lite -```bash -python3 -m build --wheel -``` - -After build, you shoud have wheel package under dist folder. diff --git a/README.md b/README.md deleted file mode 100644 index 9828605..0000000 --- a/README.md +++ /dev/null @@ -1,204 +0,0 @@ -# Milvus Lite - -[![PyPI Version](https://img.shields.io/pypi/v/milvus.svg)](https://pypi.python.org/pypi/milvus) - -## Introduction - -Milvus Lite is a lightweight version of Milvus that can be embedded into your Python application. It is a single binary that can be easily installed and run on your machine. - -It could be imported as a Python library, as well as use it as a command standalone server. - -Thanks to Milvus standalone version could be run with embedded etcd and local storage, Milvus Lite does not have any other dependencies. - -Everything you do with Milvus Lite, every piece of code you write for Milvus Lite can be safely migrated to other forms of Milvus (standalone version, cluster version, cloud version, etc.). - -Please note that it is not suggested to use Milvus Lite in a production environment. Consider using Milvus clustered or the fully managed Milvus on Cloud. - - -## Requirements - -Milvus Lite is available in: -- Google Colab [example](https://github.com/milvus-io/milvus-lite/blob/main/examples/example.ipynb) -- Jupyter Notebook - -Here's also a list of verified OS types where Milvus Lite can successfully build and run: -- Ubuntu >= 18.04 (x86_64) -- CentOS >= 7.0 (x86_64) -- MacOS >= 11.0 (Apple Silicon) - -*NOTE* -* For linux we use manylinux2014 as the base image, so it should be able to run on most linux distributions. -* Milvus Lite can also run on **Windows**. However, this is not strictly verified. - -## Installation - -Milvus Lite is available on PyPI. You can install it via `pip` for Python 3.6+: - -```bash -$ python3 -m pip install milvus -``` - -Or, install with client(pymilvus): -```bash -$ python3 -m pip install "milvus[client]" -``` -Note: pymilvus now requires Python 3.7+ - -## Usage - -### Import as Python library -Simply import `milvus.default_server`. - -```python -from milvus import default_server -from pymilvus import connections, utility - -# (OPTIONAL) Set if you want store all related data to specific location -# Default location: -# %APPDATA%/milvus-io/milvus-server on windows -# ~/.milvus-io/milvus-server on linux -# default_server.set_base_dir('milvus_data') - -# (OPTIONAL) if you want cleanup previous data -# default_server.cleanup() - -# Start you milvus server -default_server.start() - -# Now you could connect with localhost and the given port -# Port is defined by default_server.listen_port -connections.connect(host='127.0.0.1', port=default_server.listen_port) - -# Check if the server is ready. 
-print(utility.get_server_version()) -``` - -### CLI milvus-server - -You could also use the `milvus-server` command to start the server. - -```bash -$ milvus-serevr -``` - -The full options cloud be found by `milvus-server --help`. - - -## Advanced usage - -### Debug startup - -You could use `debug_server` instead of `default_server` for checking startup failures. - -```python -from milvus import debug_server -``` - -and you could also try create server instance by your self - -```python -from milvus import MilvusServer - -server = MilvusServer(debug=True) -``` - -If you're using CLI `milvus-server`, you could use `--debug` to enable debug mode. - -```bash -$ milvus-server --debug -``` - -### Configurations for Milvus -Milvus Lite could set configure by API as well as by CLI. We seperate the configurations into two parts: `basic` and `extra`. - -#### The basic configurations -You could find available configurations by `milvus-server --help` for got the list of `basic` configurations. - -These basic configurations including: -- Some listen ports for service, e.g. `--proxy-port` for specifying the port of proxy service. -- Some storage configurations, e.g. `--data` for specifying the data directory. -- Some log configurations. e.g. `--system-log-level` for specifying the log level. - -If you using Python API, you could set these configurations by `MilvusServer.config.set` method. - -```python -# this have the same effect as `milvus-server --system-log-level info` -default_server.config.set('system_log_level', 'info') -``` - -All configuable basic configurations could be found in config yaml template, which is installed with milvus package. - -#### The extra configurations -Other configurations are `extra` configurations, which could also be set by `MilvusServer.config.set` method. - -for example, if we want to set `dataCoord.segment.maxSize` to 1024, we could do: - -```python -default_server.config.set('dataCoord.segment.maxSize', 1024) -``` - -or by CLI: - -``` bash -milvus-server --extra-config dataCoord.segment.maxSize=1024 -``` - -Both of them will update the content of Milvus config yaml with: -``` yaml -dataCoord: - segment: - maxSize: 1024 -``` - -### Context - -You could close server while you not need it anymore. -Or, you're able to using `with` context to start/stop it. - -```python -from milvus import default_server - -with default_server: - # milvus started, using default server here - ... -``` - -### Data and Log Persistence - -By default all data and logs are stored in the following locations: `~/.milvus.io/milvus-server/VERSION` (VERSION is the versiong string of Milvus Lite). - -You could also set it at runtime(before the server started), by Python code: - -```python -from milvus import default_server -default_server.set_base_dir('milvus_data') -``` - -Or with CLI: - -```bash -$ milvus-server --data milvus_data -``` - -### Working with PyMilvus - -Milvus Lite could be run without pymilvus if you just want run as a server. -You could also install with extras `client` to get pymilvus. - -```bash -$ python3 -m pip install "milvus[client]" -``` - -## Examples - -Milvus Lite is friendly with jupyter notebook, you could find more examples under [examples](https://github.com/milvus-io/milvus-lite/blob/main/examples) folder. - -## Contributing -If you want to contribute to Milvus Lite, please read the [Contributing Guide](https://github.com/milvus-io/milvus-lite/blob/main/CONTRIBUTING.md) first. 
- -## Report a bug -When you use or develop milvus-lite, if you find any bug, please report it to us. You could submit an issue in [milvus-lite]( -https://github.com/milvus-io/milvus-lite/issues/new/choose) or report you [milvus](https://github.com/milvus-io/milvus/issues/new/choose) repo if you think is a Milvus issue. - -## License -Milvus Lite is under the Apache 2.0 license. See the [LICENSE](https://github.com/milvus-io/milvus-lite/blob/main/LICENSE) file for details. diff --git a/cmake/milvus-storage.cmake b/cmake/milvus-storage.cmake new file mode 100644 index 0000000..9fb045b --- /dev/null +++ b/cmake/milvus-storage.cmake @@ -0,0 +1,11 @@ + +file(GLOB_RECURSE SRC_FILES thirdparty/milvus-storage/cpp/src/*.cpp thirdparty/milvus-storage/cpp/src/*.cc) +add_library(milvus-storage ${SRC_FILES}) +target_include_directories(milvus-storage PUBLIC BEFORE thirdparty/milvus-storage/cpp/include/milvus-storage thirdparty/milvus-storage/cpp/src) +target_link_libraries(milvus-storage PUBLIC + arrow::arrow + Boost::boost + protobuf::protobuf + glog::glog + opendal +) diff --git a/conanfile.py b/conanfile.py new file mode 100644 index 0000000..e1a0a71 --- /dev/null +++ b/conanfile.py @@ -0,0 +1,72 @@ +from conans import ConanFile +from conan.tools.cmake import CMake + + +class MilvusLiteConan(ConanFile): + settings = "os", "compiler", "build_type", "arch" + requires = ( + # gtest + "gtest/1.13.0", + # glog + "xz_utils/5.4.5", + "zlib/1.2.13", + "libunwind/1.7.2", + "glog/0.6.0", + # protobuf + "protobuf/3.21.4", + # folly + "fmt/9.1.0", + "folly/2023.10.30.05@milvus/dev", + # antlr + "antlr4-cppruntime/4.13.1", + # sqlite + "sqlitecpp/3.3.1", + "onetbb/2021.9.0", + "nlohmann_json/3.11.2", + "boost/1.82.0", + "fmt/9.1.0", + "openssl/1.1.1t", + "libcurl/7.86.0", + "opentelemetry-cpp/1.8.1.1@milvus/dev", + "prometheus-cpp/1.1.0", + "re2/20230301", + "simdjson/3.7.0", + "arrow/12.0.1", + "double-conversion/3.2.1", + "marisa/0.2.6", + "zstd/1.5.4", + "yaml-cpp/0.7.0", + "libdwarf/0.9.1", + "google-cloud-cpp/2.5.0@milvus/dev", + ) + + generators = {"cmake", "cmake_find_package"} + + default_options = { + "glog:with_gflags": True, + "glog:shared": True, + "gtest:build_gmock": False, + "onetbb:tbbmalloc": False, + "onetbb:tbbproxy": False, + "boost:without_locale": False, + "boost:without_test": True, + "fmt:header_only": True, + "prometheus-cpp:with_pull": False, + "double-conversion:shared": True, + "arrow:filesystem_layer": True, + "arrow:parquet": True, + "arrow:compute": True, + "arrow:with_re2": True, + "arrow:with_zstd": True, + "arrow:with_boost": True, + "arrow:with_thrift": True, + "arrow:with_jemalloc": True, + "arrow:shared": False, + "arrow:with_s3": True, + "aws-sdk-cpp:config": True, + "aws-sdk-cpp:text-to-speech": False, + "aws-sdk-cpp:transfer": False, + } + + def imports(self): + self.copy("*.so*", "./lib", "lib") diff --git a/examples/bfloat16_example.py b/examples/bfloat16_example.py new file mode 100644 index 0000000..8237853 --- /dev/null +++ b/examples/bfloat16_example.py @@ -0,0 +1,75 @@ +import time +import random +import numpy as np +import tensorflow as tf +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, + ) +from pymilvus import MilvusClient + +from milvus.server_manager import server_manager_instance + +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +bf16_index_types = ["FLAT"] + +default_bf16_index_params = [{"nlist": 128}] + +def 
gen_bf16_vectors(num, dim): + raw_vectors = [] + bf16_vectors = [] + for _ in range(num): + raw_vector = [random.random() for _ in range(dim)] + raw_vectors.append(raw_vector) + bf16_vector = tf.cast(raw_vector, dtype=tf.bfloat16).numpy() + bf16_vectors.append(bf16_vector) + return raw_vectors, bf16_vectors + +def bf16_vector_search(): + connections.connect(uri=uri) + + int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True) + dim = 128 + nb = 3000 + vector_field_name = "bfloat16_vector" + bf16_vector = FieldSchema(name=vector_field_name, dtype=DataType.BFLOAT16_VECTOR, dim=dim) + schema = CollectionSchema(fields=[int64_field, bf16_vector]) + + if utility.has_collection("hello_milvus_fp16"): + utility.drop_collection("hello_milvus_fp16") + hello_milvus = Collection("hello_milvus_fp16", schema, consistency_level="Strong") + + _, vectors = gen_bf16_vectors(nb, dim) + hello_milvus.insert([vectors[:6]]) + rows = [ + {vector_field_name: vectors[6]}, + {vector_field_name: vectors[7]}, + {vector_field_name: vectors[8]}, + {vector_field_name: vectors[9]}, + {vector_field_name: vectors[10]}, + {vector_field_name: vectors[11]}, + ] + hello_milvus.insert(rows) + hello_milvus.flush() + + for i, index_type in enumerate(bf16_index_types): + index_params = default_bf16_index_params[i] + hello_milvus.create_index(vector_field_name, + index_params={"index_type": index_type, "params": index_params, "metric_type": "L2"}) + hello_milvus.load() + print("index_type = ", index_type) + res = hello_milvus.search(vectors[0:10], vector_field_name, {"metric_type": "L2"}, limit=1) + print(res) + hello_milvus.release() + hello_milvus.drop_index() + + hello_milvus.drop() + +if __name__ == "__main__": + bf16_vector_search() diff --git a/examples/binary_example.py b/examples/binary_example.py new file mode 100644 index 0000000..cd97711 --- /dev/null +++ b/examples/binary_example.py @@ -0,0 +1,76 @@ +import time +import random +import numpy as np +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, + ) + +from milvus.server_manager import server_manager_instance +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + + +bin_index_types = ["BIN_FLAT"] + +default_bin_index_params = [{"nlist": 128}, {"nlist": 128}] + +def gen_binary_vectors(num, dim): + raw_vectors = [] + binary_vectors = [] + for _ in range(num): + raw_vector = [random.randint(0, 1) for _ in range(dim)] + raw_vectors.append(raw_vector) + # packs the binary-valued array into bits in a uint8 array, then converts it to bytes + binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist())) + return raw_vectors, binary_vectors + + +def binary_vector_search(): + connections.connect(uri=uri) + + int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True) + dim = 128 + nb = 3000 + vector_field_name = "binary_vector" + binary_vector = FieldSchema(name=vector_field_name, dtype=DataType.BINARY_VECTOR, dim=dim) + schema = CollectionSchema(fields=[int64_field, binary_vector], enable_dynamic_field=True) + + has = utility.has_collection("hello_milvus_bin") + if has: + hello_milvus = Collection("hello_milvus_bin") + hello_milvus.drop() + else: + hello_milvus = Collection("hello_milvus_bin", schema) + + _, vectors = gen_binary_vectors(nb, dim) + rows = [ + {vector_field_name: vectors[0]}, + {vector_field_name: vectors[1]}, + {vector_field_name: vectors[2]}, + {vector_field_name: 
vectors[3]}, + {vector_field_name: vectors[4]}, + {vector_field_name: vectors[5]}, + ] + + hello_milvus.insert(rows) + hello_milvus.flush() + for i, index_type in enumerate(bin_index_types): + index_params = default_bin_index_params[i] + hello_milvus.create_index(vector_field_name, + index_params={"index_type": index_type, "params": index_params, "metric_type": "HAMMING"}) + hello_milvus.load() + print("index_type = ", index_type) + res = hello_milvus.search(vectors[:1], vector_field_name, {"metric_type": "HAMMING"}, limit=1) + print("res = ", res) + hello_milvus.release() + hello_milvus.drop_index() + hello_milvus.drop() + + +if __name__ == "__main__": + binary_vector_search() diff --git a/examples/customize_schema.py b/examples/customize_schema.py new file mode 100644 index 0000000..7eebabe --- /dev/null +++ b/examples/customize_schema.py @@ -0,0 +1,78 @@ +import time +import numpy as np +from pymilvus import ( + MilvusClient, + DataType +) +from milvus.server_manager import server_manager_instance + +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + + +fmt = "\n=== {:30} ===\n" +dim = 8 +collection_name = "hello_milvus" +milvus_client = MilvusClient(uri) + + +has_collection = milvus_client.has_collection(collection_name, timeout=5) +if has_collection: + milvus_client.drop_collection(collection_name) + +schema = milvus_client.create_schema(enable_dynamic_field=True) +schema.add_field("id", DataType.INT64, is_primary=True) +schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim) +schema.add_field("title", DataType.VARCHAR, max_length=64) + + +index_params = milvus_client.prepare_index_params() +index_params.add_index(field_name = "embeddings", metric_type="L2") +milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong") + +print(fmt.format(" all collections ")) +print(milvus_client.list_collections()) + +print(fmt.format(f"schema of collection {collection_name}")) +print(milvus_client.describe_collection(collection_name)) + +rng = np.random.default_rng(seed=19530) +rows = [ + {"id": 1, "embeddings": rng.random((1, dim))[0], "a": 100, "title": "t1"}, + {"id": 2, "embeddings": rng.random((1, dim))[0], "b": 200, "title": "t2"}, + {"id": 3, "embeddings": rng.random((1, dim))[0], "c": 300, "title": "t3"}, + {"id": 4, "embeddings": rng.random((1, dim))[0], "d": 400, "title": "t4"}, + {"id": 5, "embeddings": rng.random((1, dim))[0], "e": 500, "title": "t5"}, + {"id": 6, "embeddings": rng.random((1, dim))[0], "f": 600, "title": "t6"}, +] + +print(fmt.format("Start inserting entities")) +insert_result = milvus_client.insert(collection_name, rows) +print(fmt.format("Inserting entities done")) +print(insert_result) + + +print(fmt.format("Start load collection ")) +milvus_client.load_collection(collection_name) + +print(fmt.format("Start query by specifying primary keys")) +query_results = milvus_client.query(collection_name, ids=[2]) +print(query_results[0]) + +print(fmt.format("Start query by specifying filtering expression")) +query_results = milvus_client.query(collection_name, filter= "f == 600 or title == 't2'") +for ret in query_results: + print(ret) + +rng = np.random.default_rng(seed=19530) +vectors_to_search = rng.random((1, dim)) + +print(fmt.format(f"Start search with retrieving several fields.")) +result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["pk", "a", "b"]) +for hits in result: + for hit in hits: + 
print(f"hit: {hit}") + +milvus_client.drop_collection(collection_name) diff --git a/examples/customize_schema_auto_id.py b/examples/customize_schema_auto_id.py new file mode 100644 index 0000000..f7f9ce5 --- /dev/null +++ b/examples/customize_schema_auto_id.py @@ -0,0 +1,78 @@ +import time +import numpy as np +from pymilvus import ( + MilvusClient, + DataType +) + +from milvus.server_manager import server_manager_instance + +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +fmt = "\n=== {:30} ===\n" +dim = 8 +collection_name = "hello_milvus" + +milvus_client = MilvusClient(uri) + +has_collection = milvus_client.has_collection(collection_name, timeout=5) +if has_collection: + milvus_client.drop_collection(collection_name) + +schema = milvus_client.create_schema(enable_dynamic_field=True, auto_id=True) +schema.add_field("id", DataType.INT64, is_primary=True) +schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim) +schema.add_field("title", DataType.VARCHAR, max_length=64) + + +index_params = milvus_client.prepare_index_params() +index_params.add_index(field_name = "embeddings", metric_type="L2") +milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong") + +print(fmt.format(" all collections ")) +print(milvus_client.list_collections()) + +print(fmt.format(f"schema of collection {collection_name}")) +print(milvus_client.describe_collection(collection_name)) + +rng = np.random.default_rng(seed=19530) +rows = [ + {"embeddings": rng.random((1, dim))[0], "a": 100, "title": "t1"}, + {"embeddings": rng.random((1, dim))[0], "b": 200, "title": "t2"}, + {"embeddings": rng.random((1, dim))[0], "c": 300, "title": "t3"}, + {"embeddings": rng.random((1, dim))[0], "d": 400, "title": "t4"}, + {"embeddings": rng.random((1, dim))[0], "e": 500, "title": "t5"}, + {"embeddings": rng.random((1, dim))[0], "f": 600, "title": "t6"}, +] + +print(fmt.format("Start inserting entities")) +insert_result = milvus_client.insert(collection_name, rows) +print(fmt.format("Inserting entities done")) +print(insert_result) + + +print(fmt.format("Start load collection ")) +milvus_client.load_collection(collection_name) + +print(fmt.format("Start query by specifying primary keys")) +query_results = milvus_client.query(collection_name, ids=insert_result['ids'][0]) +print(query_results[0]) + +print(fmt.format("Start query by specifying filtering expression")) +query_results = milvus_client.query(collection_name, filter= "f == 600 or title == 't2'") +for ret in query_results: + print(ret) + +rng = np.random.default_rng(seed=19530) +vectors_to_search = rng.random((1, dim)) + +print(fmt.format(f"Start search with retrieving several fields.")) +result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["pk", "a", "b"]) +for hits in result: + for hit in hits: + print(f"hit: {hit}") + +milvus_client.drop_collection(collection_name) diff --git a/examples/dynamic_field.py b/examples/dynamic_field.py new file mode 100644 index 0000000..d24062c --- /dev/null +++ b/examples/dynamic_field.py @@ -0,0 +1,100 @@ +import time +import numpy as np +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, +) + +fmt = "\n=== {:30} ===\n" +dim = 8 + +print(fmt.format("start connecting to Milvus")) +connections.connect("default", host="localhost", port="19530") + +has = utility.has_collection("hello_milvus") +print(f"Does collection 
hello_milvus exist in Milvus: {has}") +if has: + utility.drop_collection("hello_milvus") + +fields = [ + FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), + FieldSchema(name="random", dtype=DataType.DOUBLE), + FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim) +] + +schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs", enable_dynamic_field=True) + +print(fmt.format("Create collection `hello_milvus`")) +hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong") + +################################################################################ +# 3. insert data +hello_milvus2 = Collection("hello_milvus") +print(fmt.format("Start inserting entities")) +rng = np.random.default_rng(seed=19530) + +rows = [ + {"pk": "1", "random": 1.0, "embeddings": rng.random((1, dim))[0], "a": 1}, + {"pk": "2", "random": 1.0, "embeddings": rng.random((1, dim))[0], "b": 1}, + {"pk": "3", "random": 1.0, "embeddings": rng.random((1, dim))[0], "c": 1}, + {"pk": "4", "random": 1.0, "embeddings": rng.random((1, dim))[0], "d": 1}, + {"pk": "5", "random": 1.0, "embeddings": rng.random((1, dim))[0], "e": 1}, + {"pk": "6", "random": 1.0, "embeddings": rng.random((1, dim))[0], "f": 1}, + ] + +insert_result = hello_milvus.insert(rows) + +hello_milvus.insert({"pk": "7", "random": 1.0, "embeddings": rng.random((1, dim))[0], "g": 1}) +hello_milvus.flush() +print(f"Number of entities in Milvus: {hello_milvus.num_entities}") # check the num_entites + +# 4. create index +print(fmt.format("Start Creating index IVF_FLAT")) +index = { + "index_type": "IVF_FLAT", + "metric_type": "L2", + "params": {"nlist": 128}, +} + +hello_milvus.create_index("embeddings", index) + +print(fmt.format("Start loading")) +hello_milvus.load() +# ----------------------------------------------------------------------------- +# search based on vector similarity +print(fmt.format("Start searching based on vector similarity")) + +rng = np.random.default_rng(seed=19530) +vectors_to_search = rng.random((1, dim)) +search_params = { + "metric_type": "L2", + "params": {"nprobe": 10}, +} + +start_time = time.time() +result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["pk", "embeddings"]) +end_time = time.time() + +for hits in result: + for hit in hits: + print(f"hit: {hit}") + + +result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["pk", "embeddings", "$meta"]) +for hits in result: + for hit in hits: + print(f"hit: {hit}") + +expr = f'pk in ["1" , "2"] || g == 1' + +print(fmt.format(f"Start query with expr `{expr}`")) +result = hello_milvus.query(expr=expr, output_fields=["random", "a", "g"]) +for hit in result: + print("hit:", hit) + +############################################################################### +# 7. drop collection +print(fmt.format("Drop collection `hello_milvus`")) +utility.drop_collection("hello_milvus") diff --git a/examples/example.py b/examples/example.py deleted file mode 100644 index 75e9901..0000000 --- a/examples/example.py +++ /dev/null @@ -1,188 +0,0 @@ -""" example.py based from pymilvus -""" - -import random - -from milvus import default_server - - -from pymilvus import ( - connections, - FieldSchema, CollectionSchema, DataType, - Collection, - utility -) - -# This example shows how to: -# 1. connect to Milvus server -# 2. create a collection -# 3. insert entities -# 4. create index -# 5. 
search - -# Optional, if you want store all related data to specific location -# default it wil using %APPDATA%/milvus-io/milvus-server -default_server.set_base_dir('test_milvus') - -# Optional, if you want cleanup previous data -default_server.cleanup() - -# star you milvus server -default_server.start() - -_HOST = '127.0.0.1' -# The port may be changed, by default it's 19530 -_PORT = default_server.listen_port - -# Const names -_COLLECTION_NAME = 'demo' -_ID_FIELD_NAME = 'id_field' -_VECTOR_FIELD_NAME = 'float_vector_field' - -# Vector parameters -_DIM = 128 -_INDEX_FILE_SIZE = 32 # max file size of stored index - -# Index parameters -_METRIC_TYPE = 'L2' -_INDEX_TYPE = 'IVF_FLAT' -_NLIST = 1024 -_NPROBE = 16 -_TOPK = 3 - - -# Create a Milvus connection -def create_connection(): - print(f"\nCreate connection...") - connections.connect(host=_HOST, port=_PORT) - print(f"\nList connections:") - print(connections.list_connections()) - - -# Create a collection named 'demo' -def create_collection(name, id_field, vector_field): - field1 = FieldSchema(name=id_field, dtype=DataType.INT64, description="int64", is_primary=True) - field2 = FieldSchema(name=vector_field, dtype=DataType.FLOAT_VECTOR, description="float vector", dim=_DIM, - is_primary=False) - schema = CollectionSchema(fields=[field1, field2], description="collection description") - collection = Collection(name=name, data=None, schema=schema, properties={"collection.ttl.seconds": 15}) - print("\ncollection created:", name) - return collection - - -def has_collection(name): - return utility.has_collection(name) - - -# Drop a collection in Milvus -def drop_collection(name): - collection = Collection(name) - collection.drop() - print("\nDrop collection: {}".format(name)) - - -# List all collections in Milvus -def list_collections(): - print("\nlist collections:") - print(utility.list_collections()) - - -def insert(collection, num, dim): - data = [ - [i for i in range(num)], - [[random.random() for _ in range(dim)] for _ in range(num)], - ] - collection.insert(data) - return data[1] - - -def get_entity_num(collection): - print("\nThe number of entity:") - print(collection.num_entities) - - -def create_index(collection, filed_name): - index_param = { - "index_type": _INDEX_TYPE, - "params": {"nlist": _NLIST}, - "metric_type": _METRIC_TYPE} - collection.create_index(filed_name, index_param) - print("\nCreated index:\n{}".format(collection.index().params)) - - -def drop_index(collection): - collection.drop_index() - print("\nDrop index sucessfully") - - -def load_collection(collection): - collection.load() - - -def release_collection(collection): - collection.release() - - -def search(collection, vector_field, id_field, search_vectors): - search_param = { - "data": search_vectors, - "anns_field": vector_field, - "param": {"metric_type": _METRIC_TYPE, "params": {"nprobe": _NPROBE}}, - "limit": _TOPK, - "expr": "id_field >= 0"} - results = collection.search(**search_param) - for i, result in enumerate(results): - print("\nSearch result for {}th vector: ".format(i)) - for j, res in enumerate(result): - print("Top {}: {}".format(j, res)) - - -def set_properties(collection): - collection.set_properties(properties={"collection.ttl.seconds": 1800}) - - -def main(): - # create a connection - create_connection() - - # drop collection if the collection exists - if has_collection(_COLLECTION_NAME): - drop_collection(_COLLECTION_NAME) - - # create collection - collection = create_collection(_COLLECTION_NAME, _ID_FIELD_NAME, _VECTOR_FIELD_NAME) - - # alter 
ttl properties of collection level - set_properties(collection) - - # show collections - list_collections() - - # insert 10000 vectors with 128 dimension - vectors = insert(collection, 10000, _DIM) - - collection.flush() - # get the number of entities - get_entity_num(collection) - - # create index - create_index(collection, _VECTOR_FIELD_NAME) - - # load data to memory - load_collection(collection) - - # search - search(collection, _VECTOR_FIELD_NAME, _ID_FIELD_NAME, vectors[:3]) - - # release memory - release_collection(collection) - - # drop collection index - drop_index(collection) - - # drop collection - drop_collection(_COLLECTION_NAME) - - -if __name__ == '__main__': - main() diff --git a/examples/float16_example.py b/examples/float16_example.py new file mode 100644 index 0000000..96e864d --- /dev/null +++ b/examples/float16_example.py @@ -0,0 +1,75 @@ +import time +import random +import numpy as np +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, + ) +from pymilvus import MilvusClient + +from milvus.server_manager import server_manager_instance + +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +fp16_index_types = ["FLAT"] + +default_fp16_index_params = [{"nlist": 128}] + +def gen_fp16_vectors(num, dim): + raw_vectors = [] + fp16_vectors = [] + for _ in range(num): + raw_vector = [random.random() for _ in range(dim)] + raw_vectors.append(raw_vector) + fp16_vector = np.array(raw_vector, dtype=np.float16) + fp16_vectors.append(fp16_vector) + return raw_vectors, fp16_vectors + +def fp16_vector_search(): + connections.connect(uri=uri) + + int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True) + dim = 128 + nb = 3000 + vector_field_name = "float16_vector" + fp16_vector = FieldSchema(name=vector_field_name, dtype=DataType.FLOAT16_VECTOR, dim=dim) + schema = CollectionSchema(fields=[int64_field, fp16_vector]) + + if utility.has_collection("hello_milvus_fp16"): + utility.drop_collection("hello_milvus_fp16") + + hello_milvus = Collection("hello_milvus_fp16", schema) + + _, vectors = gen_fp16_vectors(nb, dim) + hello_milvus.insert([vectors[:6]]) + rows = [ + {vector_field_name: vectors[6]}, + {vector_field_name: vectors[7]}, + {vector_field_name: vectors[8]}, + {vector_field_name: vectors[9]}, + {vector_field_name: vectors[10]}, + {vector_field_name: vectors[11]}, + ] + hello_milvus.insert(rows) + hello_milvus.flush() + + for i, index_type in enumerate(fp16_index_types): + index_params = default_fp16_index_params[i] + hello_milvus.create_index(vector_field_name, + index_params={"index_type": index_type, "params": index_params, "metric_type": "L2"}) + hello_milvus.load() + print("index_type = ", index_type) + res = hello_milvus.search(vectors[0:10], vector_field_name, {"metric_type": "L2"}, limit=1) + print(res) + hello_milvus.release() + hello_milvus.drop_index() + + hello_milvus.drop() + +if __name__ == "__main__": + fp16_vector_search() diff --git a/examples/fuzzy_match.py b/examples/fuzzy_match.py new file mode 100644 index 0000000..25b5d8a --- /dev/null +++ b/examples/fuzzy_match.py @@ -0,0 +1,83 @@ +from pymilvus import ( + connections, + FieldSchema, CollectionSchema, DataType, + Collection, +) + +from milvus.server_manager import server_manager_instance +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +DIMENSION = 8 +COLLECTION_NAME = 
"books2" +connections.connect(uri=uri) + +fields = [ + FieldSchema(name='id', dtype=DataType.INT64, is_primary=True), + FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=200), + FieldSchema(name='release_year', dtype=DataType.INT64), + FieldSchema(name='embeddings', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION), +] +schema = CollectionSchema(fields=fields, enable_dynamic_field=True) +collection = Collection(name=COLLECTION_NAME, schema=schema) + +data_rows = [ + { + "id": 1, + "title": "Lord of the Flies", + "release_year": 1954, + "embeddings": [0.64, 0.44, 0.13, 0.47, 0.74, 0.03, 0.32, 0.6], + }, + { + "id": 2, + "title": "The Great Gatsby", + "release_year": 1925, + "embeddings": [0.9, 0.45, 0.18, 0.43, 0.4, 0.4, 0.7, 0.24], + }, + { + "id": 3, + "title": "The Catcher in the Rye", + "release_year": 1951, + "embeddings": [0.43, 0.57, 0.43, 0.88, 0.84, 0.69, 0.27, 0.98], + }, + { + "id": 4, + "title": "Flipped", + "release_year": 2010, + "embeddings": [0.84, 0.69, 0.27, 0.43, 0.57, 0.43, 0.88, 0.98], + }, +] + +collection.insert(data_rows) +collection.create_index( + "embeddings", {"index_type": "FLAT", "metric_type": "L2"}) + +collection.load() + +# prefix match. +res = collection.query(expr='title like "The%"', output_fields=["id", "title"]) +print(res) + +# infix match. +res = collection.query(expr='title like "%the%"', output_fields=["id", "title"]) +print(res) + +# postfix match. +res = collection.query(expr='title like "%Rye"', output_fields=["id", "title"]) +print(res) + +# _ match any one and only one character. +res = collection.query(expr='title like "Flip_ed"', output_fields=["id", "title"]) +print(res) + +# you can create inverted index to accelerate the fuzzy match. +collection.release() +collection.create_index( + "title", {"index_type": "INVERTED"}) +collection.load() + +# _ match any one and only one character. +res = collection.query(expr='title like "Flip_ed"', output_fields=["id", "title"]) +print(res) diff --git a/examples/hello_milvus.py b/examples/hello_milvus.py new file mode 100644 index 0000000..7b6a929 --- /dev/null +++ b/examples/hello_milvus.py @@ -0,0 +1,189 @@ +# hello_milvus.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus. +# 1. connect to Milvus +# 2. create collection +# 3. insert data +# 4. create index +# 5. search, query, and hybrid search on entities +# 6. delete entities by PK +# 7. drop collection +import time + +import numpy as np +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, +) + +from milvus.server_manager import server_manager_instance +uri = server_manager_instance.start_and_get_uri("./110.db") +if uri is None: + print("Start milvus failed") + exit() + +fmt = "\n=== {:30} ===\n" +search_latency_fmt = "search latency = {:.4f}s" +num_entities, dim = 3000, 8 + +################################################################################# +# 1. connect to Milvus +# Add a new connection alias `default` for Milvus server in `localhost:19530` +# Actually the "default" alias is a buildin in PyMilvus. +# If the address of Milvus is the same as `localhost:19530`, you can omit all +# parameters and call the method as: `connections.connect()`. +# +# Note: the `using` parameter of the following methods is default to "default". 
+print(fmt.format("start connecting to Milvus")) +connections.connect("default", uri=uri) +# connections.connect("default", host="localhost", port="19530") + +has = utility.has_collection("hello_milvus") +print(f"Does collection hello_milvus exist in Milvus: {has}") + +################################################################################# +# 2. create collection +# We're going to create a collection with 3 fields. +# +-+------------+------------+------------------+------------------------------+ +# | | field name | field type | other attributes | field description | +# +-+------------+------------+------------------+------------------------------+ +# |1| "pk" | VarChar | is_primary=True | "primary field" | +# | | | | auto_id=False | | +# +-+------------+------------+------------------+------------------------------+ +# |2| "random" | Double | | "a double field" | +# +-+------------+------------+------------------+------------------------------+ +# |3|"embeddings"| FloatVector| dim=8 | "float vector with dim 8" | +# +-+------------+------------+------------------+------------------------------+ +fields = [ + FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), + FieldSchema(name="random", dtype=DataType.DOUBLE), + FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim) +] + +schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs") + +print(fmt.format("Create collection `hello_milvus`")) +hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong") +# works ok +hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong") + +################################################################################ +# 3. insert data +# We are going to insert 3000 rows of data into `hello_milvus` +# Data to be inserted must be organized in fields. +# +# The insert() method returns: +# - either automatically generated primary keys by Milvus if auto_id=True in the schema; +# - or the existing primary key field from the entities if auto_id=False in the schema. + +print(fmt.format("Start inserting entities")) +rng = np.random.default_rng(seed=19530) +entities = [ + # provide the pk field because `auto_id` is set to False + [str(i) for i in range(num_entities)], + rng.random(num_entities).tolist(), # field random, only supports list + rng.random((num_entities, dim)), # field embeddings, supports numpy.ndarray and list +] + +insert_result = hello_milvus.insert(entities) + +# hello_milvus.flush() +# print(f"Number of entities in Milvus: {hello_milvus.num_entities}") # check the num_entities + +################################################################################ +# 4. create index +# We are going to create an IVF_FLAT index for hello_milvus collection. +# create_index() can only be applied to `FloatVector` and `BinaryVector` fields. +print(fmt.format("Start Creating index FLAT")) +index = { + "index_type": "FLAT", + "metric_type": "L2", + "params": {"nlist": 128}, +} + +hello_milvus.create_index("embeddings", index) + +################################################################################ +# 5. search, query, and hybrid search +# After data were inserted into Milvus and indexed, you can perform: +# - search based on vector similarity +# - query based on scalar filtering(boolean, int, etc.) +# - hybrid search based on vector similarity and scalar filtering. 
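# For reference, the scalar-filter expressions exercised later in this script and
# in the other examples take forms such as (no operators beyond these are assumed
# here):
#   "random > 0.5"                     # numeric comparison
#   'pk in ["0", "1"]'                 # membership test on the primary-key field
#   "f == 600 or title == 't2'"        # boolean combination of conditions
#   'title like "The%"'                # prefix match on a VARCHAR field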
+# + +# Before conducting a search or a query, you need to load the data in `hello_milvus` into memory. +print(fmt.format("Start loading")) +hello_milvus.load() + +# ----------------------------------------------------------------------------- +# search based on vector similarity +print(fmt.format("Start searching based on vector similarity")) +vectors_to_search = entities[-1][-2:] +search_params = { + "metric_type": "L2", + "params": {"nprobe": 10}, +} + +start_time = time.time() +result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["random"]) +end_time = time.time() + +for hits in result: + for hit in hits: + print(f"hit: {hit}, random field: {hit.entity.get('random')}") +print(search_latency_fmt.format(end_time - start_time)) + +# ----------------------------------------------------------------------------- +# query based on scalar filtering(boolean, int, etc.) +print(fmt.format("Start querying with `random > 0.5`")) + +start_time = time.time() +result = hello_milvus.query(expr="random > 0.5", output_fields=["random", "embeddings"]) +end_time = time.time() + +print(f"query result:\n-{result[0]}") +print(search_latency_fmt.format(end_time - start_time)) + +# ----------------------------------------------------------------------------- +# pagination +r1 = hello_milvus.query(expr="random > 0.5", limit=4, output_fields=["random"]) +r2 = hello_milvus.query(expr="random > 0.5", offset=1, limit=3, output_fields=["random"]) +print(f"query pagination(limit=4):\n\t{r1}") +print(f"query pagination(offset=1, limit=3):\n\t{r2}") + + +# ----------------------------------------------------------------------------- +# hybrid search +print(fmt.format("Start hybrid searching with `random > 0.5`")) + +start_time = time.time() +result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, expr="random > 0.5", output_fields=["random"]) +end_time = time.time() + +for hits in result: + for hit in hits: + print(f"hit: {hit}, random field: {hit.entity.get('random')}") +print(search_latency_fmt.format(end_time - start_time)) + +############################################################################### +# 6. delete entities by PK +# You can delete entities by their PK values using boolean expressions. +ids = insert_result.primary_keys + +expr = f'pk in ["{ids[0]}" , "{ids[1]}"]' +print(fmt.format(f"Start deleting with expr `{expr}`")) + +result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"]) +print(f"query before delete by expr=`{expr}` -> result: \n-{result[0]}\n-{result[1]}\n") + +hello_milvus.delete(expr) + +result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"]) +print(f"query after delete by expr=`{expr}` -> result: {result}\n") + + +############################################################################### +# 7. 
drop collection +# Finally, drop the hello_milvus collection +print(fmt.format("Drop collection `hello_milvus`")) +utility.drop_collection("hello_milvus") diff --git a/examples/hello_milvus_array.py b/examples/hello_milvus_array.py new file mode 100644 index 0000000..5473cb9 --- /dev/null +++ b/examples/hello_milvus_array.py @@ -0,0 +1,88 @@ +from pymilvus import CollectionSchema, FieldSchema, Collection, connections, DataType, Partition, utility +import numpy as np +import random +import pandas as pd + +# connections.connect() + +from milvus.server_manager import server_manager_instance +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +connections.connect(uri=uri) + +dim = 128 +collection_name = "test_array" +arr_len = 100 +nb = 10 +if utility.has_collection(collection_name): + utility.drop_collection(collection_name) +# create collection +pk_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, description='pk') +vector_field = FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=dim) +int8_array = FieldSchema(name="int8_array", dtype=DataType.ARRAY, element_type=DataType.INT8, max_capacity=arr_len) +int16_array = FieldSchema(name="int16_array", dtype=DataType.ARRAY, element_type=DataType.INT16, max_capacity=arr_len) +int32_array = FieldSchema(name="int32_array", dtype=DataType.ARRAY, element_type=DataType.INT32, max_capacity=arr_len) +int64_array = FieldSchema(name="int64_array", dtype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=arr_len) +bool_array = FieldSchema(name="bool_array", dtype=DataType.ARRAY, element_type=DataType.BOOL, max_capacity=arr_len) +float_array = FieldSchema(name="float_array", dtype=DataType.ARRAY, element_type=DataType.FLOAT, max_capacity=arr_len) +double_array = FieldSchema(name="double_array", dtype=DataType.ARRAY, element_type=DataType.DOUBLE, max_capacity=arr_len) +string_array = FieldSchema(name="string_array", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=arr_len, + max_length=100) + +fields = [pk_field, vector_field, int8_array, int16_array, int32_array, int64_array, + bool_array, float_array, double_array, string_array] + +schema = CollectionSchema(fields=fields) +collection = Collection(collection_name, schema=schema) + +# insert data +pk_value = [i for i in range(nb)] +vector_value = [[random.random() for _ in range(dim)] for i in range(nb)] +int8_value = [[np.int8(j) for j in range(arr_len)] for i in range(nb)] +int16_value = [[np.int16(j) for j in range(arr_len)] for i in range(nb)] +int32_value = [[np.int32(j) for j in range(arr_len)] for i in range(nb)] +int64_value = [[np.int64(j) for j in range(arr_len)] for i in range(nb)] +bool_value = [[np.bool_(j) for j in range(arr_len)] for i in range(nb)] +float_value = [[np.float32(j) for j in range(arr_len)] for i in range(nb)] +double_value = [[np.double(j) for j in range(arr_len)] for i in range(nb)] +string_value = [[str(j) for j in range(arr_len)] for i in range(nb)] + +data = [pk_value, vector_value, + int8_value,int16_value, int32_value, int64_value, + bool_value, + float_value, + double_value, + string_value + ] + +#collection.insert(data) + +data = pd.DataFrame({ + 'int64': pk_value, + 'float_vector': vector_value, + "int8_array": int8_value, + "int16_array": int16_value, + "int32_array": int32_value, + "int64_array": int64_value, + "bool_array": bool_value, + "float_array": float_value, + "double_array": double_value, + "string_array": string_value +}) 
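# Either representation is accepted by insert(): the column-ordered `data` list
# assembled above (see the commented-out insert) or this pandas DataFrame keyed
# by field name. The DataFrame form is the one exercised below; row-wise dicts,
# as used in the MilvusClient examples, are assumed to work here as well but are
# not shown.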
+collection.insert(data) + +index = { + "index_type": "FLAT", + "metric_type": "L2", + "params": {"nlist": 128}, +} + +collection.create_index("float_vector", index) +collection.load() + +res = collection.query("int64 >= 0", output_fields=["int8_array"]) +for hits in res: + print(hits) diff --git a/examples/hello_milvus_delete.py b/examples/hello_milvus_delete.py new file mode 100644 index 0000000..9cfdc71 --- /dev/null +++ b/examples/hello_milvus_delete.py @@ -0,0 +1,80 @@ +import time +import numpy as np +from pymilvus import ( + MilvusClient, + exceptions +) + +from milvus.server_manager import server_manager_instance +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +fmt = "\n=== {:30} ===\n" +dim = 8 +collection_name = "hello_milvus" +milvus_client = MilvusClient(uri) +milvus_client.drop_collection(collection_name) +milvus_client.create_collection(collection_name, dim, consistency_level="Strong", metric_type="L2") + +print("collections:", milvus_client.list_collections()) +print(f"{collection_name} :", milvus_client.describe_collection(collection_name)) +rng = np.random.default_rng(seed=19530) + +rows = [ + {"id": 1, "vector": rng.random((1, dim))[0], "a": 1}, + {"id": 2, "vector": rng.random((1, dim))[0], "b": 2}, + {"id": 3, "vector": rng.random((1, dim))[0], "c": 3}, + {"id": 4, "vector": rng.random((1, dim))[0], "d": 4}, + {"id": 5, "vector": rng.random((1, dim))[0], "e": 5}, + {"id": 6, "vector": rng.random((1, dim))[0], "f": 6}, +] + +print(fmt.format("Start inserting entities")) +pks = milvus_client.insert(collection_name, rows, progress_bar=True)['ids'] +pks2 = milvus_client.insert(collection_name, {"id": 7, "vector": rng.random((1, dim))[0], "g": 1})['ids'] +pks.extend(pks2) + + +def fetch_data_by_pk(pk): + print(f"get primary key {pk} from {collection_name}") + pk_data = milvus_client.get(collection_name, pk) + + if pk_data: + print(f"data of primary key {pk} is", pk_data[0]) + else: + print(f"data of primary key {pk} is empty") + +fetch_data_by_pk(pks[2]) + +print(f"start to delete primary key {pks[2]} in collection {collection_name}") +milvus_client.delete(collection_name, pks = pks[2]) + +fetch_data_by_pk(pks[2]) + + +fetch_data_by_pk(pks[4]) +filter = "e == 5 or f == 6" +print(f"start to delete by expr {filter} in collection {collection_name}") +milvus_client.delete(collection_name, filter=filter) + +fetch_data_by_pk(pks[4]) + +print(f"start to delete by expr '{filter}' or by primary 4 in collection {collection_name}, expect get exception") +try: + milvus_client.delete(collection_name, pks = 4, filter=filter) +except Exception as e: + assert isinstance(e, exceptions.ParamError) + print("catch exception", e) + +print(f"start to delete without specify any expr '{filter}' or any primary key in collection {collection_name}, expect get exception") +try: + milvus_client.delete(collection_name) +except Exception as e: + print("catch exception", e) + +result = milvus_client.query(collection_name, "", output_fields = ["count(*)"]) +print(f"final entities in {collection_name} is {result[0]['count(*)']}") + +milvus_client.drop_collection(collection_name) diff --git a/examples/index.py b/examples/index.py new file mode 100644 index 0000000..481d6d9 --- /dev/null +++ b/examples/index.py @@ -0,0 +1,95 @@ +from milvus.server_manager import server_manager_instance + +import time +import numpy as np +from pymilvus import ( + MilvusClient, + DataType +) + +uri = 
server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +fmt = "\n=== {:30} ===\n" +dim = 8 +collection_name = "hello_milvus" + +milvus_client = MilvusClient(uri) + +has_collection = milvus_client.has_collection(collection_name, timeout=5) +if has_collection: + milvus_client.drop_collection(collection_name) + +schema = milvus_client.create_schema(enable_dynamic_field=True) +schema.add_field("id", DataType.INT64, is_primary=True) +schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim) +schema.add_field("title", DataType.VARCHAR, max_length=64) + +# collection is not loaded after creation +milvus_client.create_collection(collection_name, schema=schema, consistency_level="Strong") + +rng = np.random.default_rng(seed=19530) +rows = [ + {"id": 1, "embeddings": rng.random((1, dim))[0], "a": 100, "title": "t1"}, + {"id": 2, "embeddings": rng.random((1, dim))[0], "b": 200, "title": "t2"}, + {"id": 3, "embeddings": rng.random((1, dim))[0], "c": 300, "title": "t3"}, + {"id": 4, "embeddings": rng.random((1, dim))[0], "d": 400, "title": "t4"}, + {"id": 5, "embeddings": rng.random((1, dim))[0], "e": 500, "title": "t5"}, + {"id": 6, "embeddings": rng.random((1, dim))[0], "f": 600, "title": "t6"}, +] + +print(fmt.format("Start inserting entities")) +insert_result = milvus_client.insert(collection_name, rows) +print(fmt.format("Inserting entities done")) +print(insert_result) + +index_params = milvus_client.prepare_index_params() +index_params.add_index(field_name = "embeddings", metric_type="L2") +index_params.add_index(field_name = "title", index_type = "Trie", index_name="my_trie") + +print(fmt.format("Start create index")) +milvus_client.create_index(collection_name, index_params) + + +index_names = milvus_client.list_indexes(collection_name) +print(f"index names for {collection_name}:", index_names) +for index_name in index_names: + index_info = milvus_client.describe_index(collection_name, index_name=index_name) + print(f"index info for index {index_name} is:", index_info) + +print(fmt.format("Start load collection")) +milvus_client.load_collection(collection_name) + +print(fmt.format("Start query by specifying primary keys")) +query_results = milvus_client.query(collection_name, ids=[2]) +print(query_results[0]) + +print(fmt.format("Start query by specifying filtering expression")) +query_results = milvus_client.query(collection_name, filter= "f == 600 or title == 't2'") +for ret in query_results: + print(ret) + +vectors_to_search = rng.random((1, dim)) +print(fmt.format(f"Start search with retrieve serveral fields.")) +result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["title"]) +for hits in result: + for hit in hits: + print(f"hit: {hit}") + + + +field_index_names = milvus_client.list_indexes(collection_name, field_name = "embeddings") +print(f"index names for {collection_name}`s field embeddings:", field_index_names) + +try: + milvus_client.drop_index(collection_name, "my_trie") +except Exception as e: + print(f"cacthed {e}") + +milvus_client.release_collection(collection_name) + +milvus_client.drop_index(collection_name, "my_trie") + +milvus_client.drop_collection(collection_name) diff --git a/examples/non_ascii_encode.py b/examples/non_ascii_encode.py new file mode 100644 index 0000000..a240421 --- /dev/null +++ b/examples/non_ascii_encode.py @@ -0,0 +1,41 @@ +import numpy as np +from pymilvus import MilvusClient, DataType + +from milvus.server_manager import server_manager_instance 
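# This example stores non-ASCII keys (Chinese and Spanish) inside a JSON field
# and then filters on them, e.g. info['作者'] == 'J.D.塞林格', to check that
# expression parsing and JSON key lookup handle UTF-8 input end to end.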
+uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +dimension = 128 +collection_name = "books" +client = MilvusClient(uri=uri) +client.drop_collection(collection_name) + +schema = client.create_schema(auto_id=True) +schema.add_field("id", DataType.INT64, is_primary=True) +schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dimension) +schema.add_field("info", DataType.JSON) + +index_params = client.prepare_index_params("embeddings", metric_type="L2") +client.create_collection(collection_name, schema=schema, index_params=index_params) + +rng = np.random.default_rng(seed=19530) +rows = [ + {"embeddings": rng.random((1, dimension))[0], + "info": {"title": "Lord of the Flies", "author": "William Golding"}}, + + {"embeddings": rng.random((1, dimension))[0], + "info": {"作者": "J.D.塞林格", "title": "麦田里的守望者", }}, + + {"embeddings": rng.random((1, dimension))[0], + "info": {"Título": "Cien años de soledad", "autor": "Gabriel García Márquez"}}, +] + +client.insert(collection_name, rows) +result = client.query(collection_name, filter="info['作者'] == 'J.D.塞林格' or info['Título'] == 'Cien años de soledad'", + output_fields=["info"], + consistency_level="Strong") + +for hit in result: + print(f"hit: {hit}") diff --git a/examples/simple.py b/examples/simple.py new file mode 100644 index 0000000..82d8487 --- /dev/null +++ b/examples/simple.py @@ -0,0 +1,81 @@ +from milvus.server_manager import server_manager_instance +import time +import numpy as np +from pymilvus import ( + MilvusClient, +) + +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + + +fmt = "\n=== {:30} ===\n" +dim = 8 +collection_name = "hello_milvus" +milvus_client = MilvusClient(uri) + +has_collection = milvus_client.has_collection(collection_name, timeout=5) +if has_collection: + milvus_client.drop_collection(collection_name) +milvus_client.create_collection(collection_name, dim, consistency_level="Strong", metric_type="L2") + +print(fmt.format(" all collections ")) +print(milvus_client.list_collections()) + +print(fmt.format(f"schema of collection {collection_name}")) +print(milvus_client.describe_collection(collection_name)) + +rng = np.random.default_rng(seed=19530) +rows = [ + {"id": 1, "vector": rng.random((1, dim))[0], "a": 100}, + {"id": 2, "vector": rng.random((1, dim))[0], "b": 200}, + {"id": 3, "vector": rng.random((1, dim))[0], "c": 300}, + {"id": 4, "vector": rng.random((1, dim))[0], "d": 400}, + {"id": 5, "vector": rng.random((1, dim))[0], "e": 500}, + {"id": 6, "vector": rng.random((1, dim))[0], "f": 600}, +] + +print(fmt.format("Start inserting entities")) +insert_result = milvus_client.insert(collection_name, rows, progress_bar=True) +print(fmt.format("Inserting entities done")) +print(insert_result) + +print(fmt.format("Start query by specifying primary keys")) +query_results = milvus_client.query(collection_name, ids=[2]) +print(query_results[0]) + +# upsert_ret = milvus_client.upsert(collection_name, {"id": 2 , "vector": rng.random((1, dim))[0], "g": 100}) +# print(upsert_ret) + +print(fmt.format("Start query by specifying primary keys")) +query_results = milvus_client.query(collection_name, ids=[2]) +print(query_results[0]) + + + +print(fmt.format("Start query by specifying filtering expression")) +query_results = milvus_client.query(collection_name, filter= "f == 600") +for ret in query_results: + print(ret) + + +print(f"start to delete by specifying filter in 
collection {collection_name}") +delete_result = milvus_client.delete(collection_name, ids=[6]) +print(delete_result) + +print(fmt.format("Start query by specifying filtering expression")) +query_results = milvus_client.query(collection_name, filter= "f == 600") +assert len(query_results) == 0 + +rng = np.random.default_rng(seed=19530) +vectors_to_search = rng.random((1, dim)) + +print(fmt.format(f"Start search with retrieve serveral fields.")) +result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["pk", "a", "b"]) +for hits in result: + for hit in hits: + print(f"hit: {hit}") + +milvus_client.drop_collection(collection_name) diff --git a/examples/simple_auto_id.py b/examples/simple_auto_id.py new file mode 100644 index 0000000..6e37051 --- /dev/null +++ b/examples/simple_auto_id.py @@ -0,0 +1,66 @@ +from milvus.server_manager import server_manager_instance +import time +import numpy as np +from pymilvus import ( + MilvusClient, +) + +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +fmt = "\n=== {:30} ===\n" +dim = 8 +collection_name = "hello_milvus" + +milvus_client = MilvusClient(uri) + +has_collection = milvus_client.has_collection(collection_name, timeout=5) +if has_collection: + milvus_client.drop_collection(collection_name) +milvus_client.create_collection(collection_name, dim, consistency_level="Strong", metric_type="L2", auto_id=True) + +print(fmt.format(" all collections ")) +print(milvus_client.list_collections()) + +print(fmt.format(f"schema of collection {collection_name}")) +print(milvus_client.describe_collection(collection_name)) + +rng = np.random.default_rng(seed=19530) +rows = [ + {"vector": rng.random((1, dim))[0], "a": 100}, + {"vector": rng.random((1, dim))[0], "b": 200}, + {"vector": rng.random((1, dim))[0], "c": 300}, + {"vector": rng.random((1, dim))[0], "d": 400}, + {"vector": rng.random((1, dim))[0], "e": 500}, + {"vector": rng.random((1, dim))[0], "f": 600}, +] + +print(fmt.format("Start inserting entities")) +insert_result = milvus_client.insert(collection_name, rows, progress_bar=True) +print("insert done:", insert_result) + +print(fmt.format("Start query by specifying filter")) +query_results = milvus_client.query(collection_name, filter= "f == 600") +for ret in query_results: + print(ret) + +print(f"start to delete by specifying filter in collection {collection_name}") +delete_result = milvus_client.delete(collection_name, filter = "f == 600") +print(delete_result) + +print(fmt.format("Start query by specifying filtering expression")) +query_results = milvus_client.query(collection_name, filter= "f == 600") +assert len(query_results) == 0 + +rng = np.random.default_rng(seed=19530) +vectors_to_search = rng.random((1, dim)) + +print(fmt.format(f"Start search with retrieve serveral fields.")) +result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["pk", "a", "b"]) +for hits in result: + for hit in hits: + print(f"hit: {hit}") + +milvus_client.drop_collection(collection_name) diff --git a/examples/sparse.py b/examples/sparse.py new file mode 100644 index 0000000..78d44f5 --- /dev/null +++ b/examples/sparse.py @@ -0,0 +1,99 @@ +from pymilvus import ( + MilvusClient, + FieldSchema, CollectionSchema, DataType, +) +import random + +from milvus.server_manager import server_manager_instance +uri = server_manager_instance.start_and_get_uri("./local_test.db") +if uri is None: + print("Start milvus failed") + exit() + +def 
generate_sparse_vector(dimension: int, non_zero_count: int) -> dict: + indices = random.sample(range(dimension), non_zero_count) + values = [random.random() for _ in range(non_zero_count)] + sparse_vector = {index: value for index, value in zip(indices, values)} + return sparse_vector + + +fmt = "\n=== {:30} ===\n" +dim = 100 +non_zero_count = 20 +collection_name = "hello_sparse" +milvus_client = MilvusClient(uri) + +has_collection = milvus_client.has_collection(collection_name, timeout=5) +if has_collection: + milvus_client.drop_collection(collection_name) +fields = [ + FieldSchema(name="pk", dtype=DataType.VARCHAR, + is_primary=True, auto_id=True, max_length=100), + + # FieldSchema(name="pk", dtype=DataType.INT64, + # is_primary=True, auto_id=True), + FieldSchema(name="random", dtype=DataType.DOUBLE), + FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR), +] +schema = CollectionSchema( + fields, "demo for using sparse float vector with milvus client") +index_params = milvus_client.prepare_index_params() +index_params.add_index(field_name="embeddings", index_name="sparse_inverted_index", + index_type="SPARSE_INVERTED_INDEX", metric_type="IP", params={"drop_ratio_build": 0.2}) +milvus_client.create_collection(collection_name, schema=schema, + index_params=index_params, timeout=5, consistency_level="Strong") + +print(fmt.format(" all collections ")) +print(milvus_client.list_collections()) + +print(fmt.format(f"schema of collection {collection_name}")) +print(milvus_client.describe_collection(collection_name)) + +N = 6 +rows = [{"random": i, "embeddings": generate_sparse_vector( + dim, non_zero_count)} for i in range(N)] + +print(fmt.format("Start inserting entities")) +insert_result = milvus_client.insert(collection_name, rows, progress_bar=True) +print(fmt.format("Inserting entities done")) +print(insert_result) + +print(fmt.format(f"Start vector anns search.")) +vectors_to_search = [generate_sparse_vector(dim, non_zero_count)] +search_params = { + "metric_type": "IP", + "params": { + "drop_ratio_search": 0.2, + } +} +# no need to specify anns_field for collections with only 1 vector field +result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=[ + "pk", "random", "embeddings"], search_params=search_params) +for hits in result: + for hit in hits: + print(f"hit: {hit}") + +print(fmt.format("Start query by specifying filtering expression")) +query_results = milvus_client.query(collection_name, filter="random < 3") +pks = [ret['pk'] for ret in query_results] +for ret in query_results: + print(ret) + +# print(fmt.format("Start query by specifying primary keys")) +# query_results = milvus_client.query( +# collection_name, filter=f"pk == '{pks[0]}'") +# print(query_results[0]) + +print(f"start to delete by specifying filter in collection {collection_name}") +print(pks[:1], 'xxxxxxxxxxxxxxxxxxxxxxxx') +print(milvus_client.query(collection_name, ids=pks[:1])) +delete_result = milvus_client.delete(collection_name, ids=pks[:1]) +print(delete_result) +print(milvus_client.query(collection_name, ids=pks[:1])) + +# print(fmt.format("Start query by specifying primary keys")) +# query_results = milvus_client.query( +# collection_name, filter=f"pk == '{pks[0]}'") +# print(f'query result should be empty: {query_results}') + +milvus_client.drop_collection(collection_name) diff --git a/milvus_binary/.gitignore b/milvus_binary/.gitignore deleted file mode 100644 index 2cebab2..0000000 --- a/milvus_binary/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -milvus/ 
-output/ \ No newline at end of file diff --git a/milvus_binary/build.sh b/milvus_binary/build.sh deleted file mode 100644 index 27f6f8d..0000000 --- a/milvus_binary/build.sh +++ /dev/null @@ -1,243 +0,0 @@ -#!/bin/bash - -export LANG=en_US.utf-8 -set -e - -build_dir=$(cd $(dirname $0); pwd) -cd ${build_dir} - -# load envs -. env.sh - -# getopts -while getopts "fr:b:p:" arg; do - case $arg in - f) - BUILD_FORCE=YES - ;; - r) - MILVUS_REPO=$OPTARG - ;; - b) - MILVUS_VERSION=$OPTARG - ;; - p) - BUILD_PROXY=$OPTARG - ;; - *) - ;; - esac -done - -# proxy if needed -if [ ! -z "${BUILD_PROXY}" ] ; then - echo using proxy during build: $BUILD_PROXY - export http_proxy=${BUILD_PROXY} - export https_proxy=${BUILD_PROXY} -fi - -# remove milvus source if build force -if [[ ${BUILD_FORCE} == "YES" ]] ; then - rm -fr milvus -fi - - -# get host -OS=$(uname -s) -ARCH=$(uname -m) - -case $OS in - Linux) - osname=linux - ;; - MINGW*) - osname=msys - ;; - Darwin) - osname=macosx - ;; - *) - osname=none - ;; -esac - -# clone milvus -if [[ ! -d milvus ]] ; then - git clone ${MILVUS_REPO} milvus - cd milvus - git checkout ${MILVUS_VERSION} - # apply milvus patch later if needed - if [ -f ../patches/milvus-${MILVUS_VERSION}.patch ] ; then - patch -p1 < ../patches/milvus-${MILVUS_VERSION}.patch - fi - cd - -fi - -# patch Makefile -if [[ "${osname}" == "macosx" ]] ; then - sed -i '' 's/-ldflags="/-ldflags="-s -w /' milvus/Makefile - sed -i '' 's/-ldflags="-s -w -s -w /-ldflags="-s -w /' milvus/Makefile - sed -i '' 's/="-dev"/="-lite"/' milvus/Makefile -else - sed 's/-ldflags="/-ldflags="-s -w /' -i milvus/Makefile - sed 's/-ldflags="-s -w -s -w /-ldflags="-s -w /' -i milvus/Makefile - sed 's/="-dev"/="-lite"/' -i milvus/Makefile -fi - -# build for linux x86_64 -function build_linux_x86_64() { - cd milvus - # conan after 2.3 - pip3 install --user "conan<2.0" - export PATH=${HOME}/.local/bin:${PATH} - make -j $(nproc) milvus - cd bin - rm -fr lib* - - has_new_file=true - while ${has_new_file} ; do - has_new_file=false - for x in $(ldd milvus | awk '{print $1}') ; do - if [[ $x =~ libc.so.* ]] ; then - : - elif [[ $x =~ libdl.so.* ]] ; then - : - elif [[ $x =~ libm.so.* ]] ; then - : - elif [[ $x =~ librt.so.* ]] ; then - : - elif [[ $x =~ libpthread.so.* ]] ; then - : - elif test -f $x ; then - : - else - echo $x - for p in ../internal/core/output/lib ../internal/core/output/lib64 /lib64 /usr/lib64 /usr/lib /usr/local/lib64 /usr/local/lib ; do - if test -f $p/$x && ! test -f $x ; then - file=$p/$x - while test -L $file ; do - file=$(realpath $file) - done - cp -frv $file $x - has_new_file=true - fi - done - fi - done - done -} - -function install_deps_for_macosx() { - bash milvus/scripts/install_deps.sh - # need this for cache binary - brew install md5sha1sum -} - -# build for macos arm64/x86_64 -build_macosx_common() { - cd milvus - make -j $(sysctl -n hw.physicalcpu) milvus - - # resolve dependencies for milvus - cd bin - rm -fr lib* - files=("milvus") - while true ; do - new_files=() - for file in ${files[@]} ; do - for line in $(otool -L $file | grep -v ${file}: | grep -v /usr/lib | grep -v /System/Library | awk '{print $1}') ; do - filename=$(basename $line) - if [[ -f ${filename} ]] ; then - continue - fi - find_in_build_dir=$(find ../cmake_build -name $filename) - if [[ ! 
-z "$find_in_build_dir" ]] ; then - cp -frv ${find_in_build_dir} ${filename} - new_files+=( "${filename}" ) - continue - fi - if [[ -f $line ]] ; then - cp -frv $line $filename - new_files+=( "${filename}" ) - continue - fi - done - done - if [[ ${#new_files[@]} -eq 0 ]] ; then - break - fi - for file in ${new_files[@]} ; do - files+=( ${file} ) - done - done -} - - -function build_macosx_x86_64() { - install_deps_for_macosx - build_macosx_common -} - -function build_macosx_arm64() { - install_deps_for_macosx - build_macosx_common -} - -function build_msys() { - cd milvus - bash scripts/install_deps_msys.sh - source scripts/setenv.sh - - export GOROOT=/mingw64/lib/go - go version - - make -j $(nproc) milvus - - cd bin - mv milvus milvus.exe - - find .. -name \*.dll | xargs -I {} cp -frv {} . || : - for x in $(ldd milvus.exe | awk '{print $1}') ; do - if [ -f ${MINGW_PREFIX}/bin/$x ] ; then - cp -frv ${MINGW_PREFIX}/bin/$x . - fi - done -} - -function build_milvus() { - set -e - # prepare output - cd ${build_dir} - # check if prev build ok - if [ -f output/build.txt ] ; then - cp -fr env.sh output/env.sh.txt - cp -fr build.sh output/build.sh.txt - if md5sum -c output/build.txt ; then - echo already build success, if you need rebuild it use -f flag or remove file: ${build_dir}/output/build.txt - exit 0 - fi - fi - # build for os - case $OS in - Linux) - build_linux_${ARCH} - ;; - MINGW*) - build_msys - ;; - Darwin) - build_macosx_${ARCH} - ;; - *) - ;; - esac - - cd ${build_dir} - rm -fr output && mkdir output - cp -fr env.sh output/env.sh.txt - cp -fr build.sh output/build.sh.txt - cp -fr milvus/bin/* output - md5sum output/* | grep -v build.txt > output/build.txt -} - -build_milvus diff --git a/milvus_binary/env.sh b/milvus_binary/env.sh deleted file mode 100644 index 3da67de..0000000 --- a/milvus_binary/env.sh +++ /dev/null @@ -1,5 +0,0 @@ - -MILVUS_REPO="https://github.com/milvus-io/milvus.git" -MILVUS_VERSION="v2.3.0-beta" -BUILD_PROXY= -BUILD_FORCE=NO diff --git a/milvus_binary/patches/knowhere-v1.3.11.patch b/milvus_binary/patches/knowhere-v1.3.11.patch deleted file mode 100644 index f0eb5b1..0000000 --- a/milvus_binary/patches/knowhere-v1.3.11.patch +++ /dev/null @@ -1,26 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index 36f7332..6cb7e8b 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -27,13 +27,6 @@ add_definitions( -DAUTO_INITIALIZE_EASYLOGGINGPP ) - if ( APPLE ) - set ( CMAKE_CROSSCOMPILING TRUE ) - set ( RUN_HAVE_GNU_POSIX_REGEX 0 ) -- if ( DEFINED ENV{HOMEBREW_PREFIX} ) -- set( APPLE_LLVM_PREFIX $ENV{HOMEBREW_PREFIX} ) -- else() -- set( APPLE_LLVM_PREFIX "/usr/local" ) -- endif() -- set ( CMAKE_C_COMPILER "${APPLE_LLVM_PREFIX}/opt/llvm/bin/clang" ) -- set ( CMAKE_CXX_COMPILER "${APPLE_LLVM_PREFIX}/opt/llvm/bin/clang++" ) - endif () - - set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake" ) -@@ -80,7 +73,6 @@ include( CheckCXXCompilerFlag ) - if ( ${CMAKE_SYSTEM_NAME} MATCHES "Darwin" ) - message(STATUS "MacOS") - set ( MACOS TRUE ) -- set ( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -L${APPLE_LLVM_PREFIX}/opt/libomp/lib" ) - elseif ( "${CMAKE_SYSTEM}" MATCHES "Linux" ) - message( STATUS "Linux") - set ( LINUX TRUE ) diff --git a/milvus_binary/patches/macosx-v2.2.5.patch b/milvus_binary/patches/macosx-v2.2.5.patch deleted file mode 100644 index 9e25351..0000000 --- a/milvus_binary/patches/macosx-v2.2.5.patch +++ /dev/null @@ -1,106 +0,0 @@ -diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt -index eb1f727d2..bc47fe8a3 
100644 ---- a/internal/core/CMakeLists.txt -+++ b/internal/core/CMakeLists.txt -@@ -19,8 +19,6 @@ cmake_minimum_required( VERSION 3.18 ) - if ( APPLE ) - set( CMAKE_CROSSCOMPILING TRUE ) - set( RUN_HAVE_GNU_POSIX_REGEX 0 ) -- set( CMAKE_C_COMPILER "/usr/local/opt/llvm/bin/clang" ) -- set( CMAKE_CXX_COMPILER "/usr/local/opt/llvm/bin/clang++" ) - endif () - - add_definitions(-DELPP_THREAD_SAFE) -@@ -31,7 +29,6 @@ project(core) - include(CheckCXXCompilerFlag) - if ( APPLE ) - message(STATUS "==============Darwin Environment==============") -- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/opt/llvm/include -I/usr/local/include -I/usr/local/opt/libomp/include -L/usr/local/opt/libomp/lib") - elseif (${CMAKE_SYSTEM_NAME} MATCHES "Linux") - message(STATUS "==============Linux Environment===============") - set(LINUX TRUE) -diff --git a/scripts/core_build.sh b/scripts/core_build.sh -index 0c33d9439..57d343a39 100755 ---- a/scripts/core_build.sh -+++ b/scripts/core_build.sh -@@ -186,15 +186,31 @@ fi - - unameOut="$(uname -s)" - case "${unameOut}" in -- Darwin*) -- llvm_prefix="$(brew --prefix llvm)" -- export CLANG_TOOLS_PATH="${llvm_prefix}/bin" -- export CC="${llvm_prefix}/bin/clang" -- export CXX="${llvm_prefix}/bin/clang++" -- export LDFLAGS="-L${llvm_prefix}/lib -L/usr/local/opt/libomp/lib" -- export CXXFLAGS="-I${llvm_prefix}/include -I/usr/local/include -I/usr/local/opt/libomp/include" -- ;; -- *) echo "==System:${unameOut}"; -+ Darwin*) -+ # detect llvm version by valid list -+ for llvm_version in 15 14 NOT_FOUND ; do -+ if brew ls --versions llvm@${llvm_version} > /dev/null; then -+ break -+ fi -+ done -+ if [ "${llvm_version}" = "NOT_FOUND" ] ; then -+ echo "llvm@14~15 is not installed" -+ exit 1 -+ fi -+ llvm_prefix="$(brew --prefix llvm@${llvm_version})" -+ export CLANG_TOOLS_PATH="${llvm_prefix}/bin" -+ export PATH=${CLANG_TOOLS_PATH}:${PATH} -+ export CC="${llvm_prefix}/bin/clang" -+ export CXX="${llvm_prefix}/bin/clang++" -+ export CFLAGS="-Wno-deprecated-declarations -I${llvm_prefix}/include -I/usr/local/include -I$(brew --prefix libomp)/include -I$(brew --prefix boost)/include -I$(brew --prefix tbb)/include" -+ export CXXFLAGS=${CFLAGS} -+ export LDFLAGS="-L${llvm_prefix}/lib -L$(brew --prefix libomp)/lib -L$(brew --prefix boost)/lib -L$(brew --prefix tbb)/lib" -+ ;; -+ Linux*) -+ ;; -+ *) -+ echo "Cannot build on windows" -+ ;; - esac - - -diff --git a/scripts/install_deps.sh b/scripts/install_deps.sh -index 5d58bb3f5..bb7f00a65 100755 ---- a/scripts/install_deps.sh -+++ b/scripts/install_deps.sh -@@ -29,7 +29,7 @@ function install_linux_deps() { - sudo yum install -y git make lcov libtool m4 autoconf automake ccache openssl-devel zlib-devel libzstd-devel \ - libcurl-devel python3-devel \ - devtoolset-7-gcc devtoolset-7-gcc-c++ devtoolset-7-gcc-gfortran \ -- llvm-toolset-7.0-clang llvm-toolset-7.0-clang-tools-extra libuuid-devel pulseaudio-libs-devel -+ llvm-toolset-7.0-clang llvm-toolset-7.0-clang-tools-extra libuuid-devel pulseaudio-libs-devel - - echo "source scl_source enable devtoolset-7" | sudo tee -a /etc/profile.d/devtoolset-7.sh - echo "source scl_source enable llvm-toolset-7.0" | sudo tee -a /etc/profile.d/llvm-toolset-7.sh -@@ -57,25 +57,8 @@ function install_linux_deps() { - - function install_mac_deps() { - sudo xcode-select --install > /dev/null 2>&1 -- brew install boost libomp ninja tbb cmake llvm ccache zstd -- brew uninstall grep -+ brew install boost libomp ninja tbb openblas cmake llvm@15 ccache pkg-config zstd openssl librdkafka - brew install grep -- 
export PATH="/usr/local/opt/grep/libexec/gnubin:$PATH" -- brew update && brew upgrade && brew cleanup -- -- if [[ $(arch) == 'arm64' ]]; then -- brew install openssl -- brew install librdkafka -- brew install pkg-config -- sudo mkdir /usr/local/include -- sudo mkdir /usr/local/opt -- sudo ln -s "$(brew --prefix llvm)" "/usr/local/opt/llvm" -- sudo ln -s "$(brew --prefix libomp)/include/omp.h" "/usr/local/include/omp.h" -- sudo ln -s "$(brew --prefix libomp)" "/usr/local/opt/libomp" -- sudo ln -s "$(brew --prefix boost)/include/boost" "/usr/local/include/boost" -- sudo ln -s "$(brew --prefix tbb)/include/tbb" "/usr/local/include/tbb" -- sudo ln -s "$(brew --prefix tbb)/include/oneapi" "/usr/local/include/oneapi" -- fi - } - - if ! command -v go &> /dev/null diff --git a/milvus_binary/patches/milvus-v2.2.5.patch b/milvus_binary/patches/milvus-v2.2.5.patch deleted file mode 100644 index b9de6b5..0000000 --- a/milvus_binary/patches/milvus-v2.2.5.patch +++ /dev/null @@ -1,278 +0,0 @@ -diff --git a/go.mod b/go.mod -index 42e64c166..1035a64e6 100644 ---- a/go.mod -+++ b/go.mod -@@ -18,7 +18,6 @@ require ( - github.com/confluentinc/confluent-kafka-go v1.9.1 - github.com/containerd/cgroups v1.0.2 - github.com/gin-gonic/gin v1.7.7 -- github.com/go-basic/ipv4 v1.0.0 - github.com/gofrs/flock v0.8.1 - github.com/golang/mock v1.5.0 - github.com/golang/protobuf v1.5.2 -diff --git a/go.sum b/go.sum -index bc410f4ad..bd2f09e7a 100644 ---- a/go.sum -+++ b/go.sum -@@ -221,8 +221,6 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE - github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= - github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= - github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= --github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs= --github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg= - github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= - github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= - github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= -diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt -index eb1f727d2..d70ef40d4 100644 ---- a/internal/core/CMakeLists.txt -+++ b/internal/core/CMakeLists.txt -@@ -19,19 +19,21 @@ cmake_minimum_required( VERSION 3.18 ) - if ( APPLE ) - set( CMAKE_CROSSCOMPILING TRUE ) - set( RUN_HAVE_GNU_POSIX_REGEX 0 ) -- set( CMAKE_C_COMPILER "/usr/local/opt/llvm/bin/clang" ) -- set( CMAKE_CXX_COMPILER "/usr/local/opt/llvm/bin/clang++" ) - endif () - - add_definitions(-DELPP_THREAD_SAFE) - set(CMAKE_POSITION_INDEPENDENT_CODE ON) -+ -+if ( MSYS ) -+ add_definitions(-DPROTOBUF_USE_DLLS) -+endif () -+ - message( STATUS "Building using CMake version: ${CMAKE_VERSION}" ) - - project(core) - include(CheckCXXCompilerFlag) - if ( APPLE ) - message(STATUS "==============Darwin Environment==============") -- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/opt/llvm/include -I/usr/local/include -I/usr/local/opt/libomp/include -L/usr/local/opt/libomp/lib") - elseif (${CMAKE_SYSTEM_NAME} MATCHES "Linux") - message(STATUS "==============Linux Environment===============") - set(LINUX TRUE) -diff --git a/internal/core/src/indexbuilder/CMakeLists.txt b/internal/core/src/indexbuilder/CMakeLists.txt -index cae7415c9..29cef2aa0 100644 ---- 
a/internal/core/src/indexbuilder/CMakeLists.txt -+++ b/internal/core/src/indexbuilder/CMakeLists.txt -@@ -23,6 +23,7 @@ add_library(milvus_indexbuilder SHARED ${INDEXBUILDER_FILES}) - find_library(TBB NAMES tbb) - set(PLATFORM_LIBS dl) - if (MSYS) -+find_library(TBB NAMES tbb12) - set(PLATFORM_LIBS -Wl,--allow-multiple-definition) - endif () - -diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt -index 1b58ab37b..c85dffb50 100644 ---- a/internal/core/src/segcore/CMakeLists.txt -+++ b/internal/core/src/segcore/CMakeLists.txt -@@ -41,6 +41,7 @@ add_library(milvus_segcore SHARED ${SEGCORE_FILES}) - find_library(TBB NAMES tbb) - set(PLATFORM_LIBS dl) - if (MSYS) -+find_library(TBB NAMES tbb12) - set(PLATFORM_LIBS ) - endif () - -diff --git a/internal/core/thirdparty/knowhere/CMakeLists.txt b/internal/core/thirdparty/knowhere/CMakeLists.txt -index 05711722e..5bf6358f0 100644 ---- a/internal/core/thirdparty/knowhere/CMakeLists.txt -+++ b/internal/core/thirdparty/knowhere/CMakeLists.txt -@@ -52,6 +52,7 @@ macro(build_knowhere) - PREFIX ${CMAKE_BINARY_DIR}/3rdparty_download/knowhere-subbuild - BINARY_DIR knowhere-bin - INSTALL_DIR ${KNOWHERE_INSTALL_PREFIX} -+ PATCH_COMMAND patch -p1 < ${CMAKE_SOURCE_DIR}/../../../patches/knowhere-v1.3.11.patch - ) - - ExternalProject_Get_Property(knowhere_ep INSTALL_DIR) -diff --git a/internal/util/etcd/etcd_server.go b/internal/util/etcd/etcd_server.go -index 75f81c43e..ffd758812 100644 ---- a/internal/util/etcd/etcd_server.go -+++ b/internal/util/etcd/etcd_server.go -@@ -1,6 +1,9 @@ - package etcd - - import ( -+ "net/url" -+ "os" -+ "runtime" - "sync" - - "github.com/milvus-io/milvus/internal/log" -@@ -45,6 +48,12 @@ func InitEtcdServer( - } else { - cfg = embed.NewConfig() - } -+ if runtime.GOOS == "windows" { -+ err := zap.RegisterSink("winfile", newWinFileSink) -+ if err != nil { -+ initError = err -+ } -+ } - cfg.Dir = dataDir - cfg.LogOutputs = []string{logPath} - cfg.LogLevel = logLevel -@@ -73,3 +82,10 @@ func StopEtcdServer() { - }) - } - } -+ -+// special file sink for zap, as etcd using zap as Logger -+// See: https://github.com/uber-go/zap/issues/621 -+func newWinFileSink(u *url.URL) (zap.Sink, error) { -+ // e.g. 
winfile:///D:/test/ -> D:/test/ -+ return os.OpenFile(u.Path[1:], os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600) -+} -diff --git a/internal/util/funcutil/func.go b/internal/util/funcutil/func.go -index 061224ba2..71d91f099 100644 ---- a/internal/util/funcutil/func.go -+++ b/internal/util/funcutil/func.go -@@ -31,7 +31,6 @@ import ( - "strings" - "time" - -- "github.com/go-basic/ipv4" - "go.uber.org/zap" - grpcStatus "google.golang.org/grpc/status" - -@@ -56,7 +55,16 @@ func CheckGrpcReady(ctx context.Context, targetCh chan error) { - - // GetLocalIP return the local ip address - func GetLocalIP() string { -- return ipv4.LocalIP() -+ addrs, err := net.InterfaceAddrs() -+ if err == nil { -+ for _, addr := range addrs { -+ ipaddr, ok := addr.(*net.IPNet) -+ if ok && ipaddr.IP.IsGlobalUnicast() && ipaddr.IP.To4() != nil { -+ return ipaddr.IP.String() -+ } -+ } -+ } -+ return "127.0.0.1" - } - - // WaitForComponentStates wait for component's state to be one of the specific states -diff --git a/internal/util/paramtable/grpc_param.go b/internal/util/paramtable/grpc_param.go -index af684ed69..6736c4645 100644 ---- a/internal/util/paramtable/grpc_param.go -+++ b/internal/util/paramtable/grpc_param.go -@@ -17,8 +17,8 @@ import ( - "sync" - "time" - -- "github.com/go-basic/ipv4" - "github.com/milvus-io/milvus/internal/log" -+ "github.com/milvus-io/milvus/internal/util/funcutil" - "go.uber.org/zap" - ) - -@@ -81,7 +81,7 @@ func (p *grpcConfig) init(domain string) { - - // LoadFromEnv is used to initialize configuration items from env. - func (p *grpcConfig) LoadFromEnv() { -- p.IP = ipv4.LocalIP() -+ p.IP = funcutil.GetLocalIP() - } - - // LoadFromArgs is used to initialize configuration items from args. -diff --git a/scripts/core_build.sh b/scripts/core_build.sh -index 0c33d9439..57d343a39 100755 ---- a/scripts/core_build.sh -+++ b/scripts/core_build.sh -@@ -186,15 +186,31 @@ fi - - unameOut="$(uname -s)" - case "${unameOut}" in -- Darwin*) -- llvm_prefix="$(brew --prefix llvm)" -- export CLANG_TOOLS_PATH="${llvm_prefix}/bin" -- export CC="${llvm_prefix}/bin/clang" -- export CXX="${llvm_prefix}/bin/clang++" -- export LDFLAGS="-L${llvm_prefix}/lib -L/usr/local/opt/libomp/lib" -- export CXXFLAGS="-I${llvm_prefix}/include -I/usr/local/include -I/usr/local/opt/libomp/include" -- ;; -- *) echo "==System:${unameOut}"; -+ Darwin*) -+ # detect llvm version by valid list -+ for llvm_version in 15 14 NOT_FOUND ; do -+ if brew ls --versions llvm@${llvm_version} > /dev/null; then -+ break -+ fi -+ done -+ if [ "${llvm_version}" = "NOT_FOUND" ] ; then -+ echo "llvm@14~15 is not installed" -+ exit 1 -+ fi -+ llvm_prefix="$(brew --prefix llvm@${llvm_version})" -+ export CLANG_TOOLS_PATH="${llvm_prefix}/bin" -+ export PATH=${CLANG_TOOLS_PATH}:${PATH} -+ export CC="${llvm_prefix}/bin/clang" -+ export CXX="${llvm_prefix}/bin/clang++" -+ export CFLAGS="-Wno-deprecated-declarations -I${llvm_prefix}/include -I/usr/local/include -I$(brew --prefix libomp)/include -I$(brew --prefix boost)/include -I$(brew --prefix tbb)/include" -+ export CXXFLAGS=${CFLAGS} -+ export LDFLAGS="-L${llvm_prefix}/lib -L$(brew --prefix libomp)/lib -L$(brew --prefix boost)/lib -L$(brew --prefix tbb)/lib" -+ ;; -+ Linux*) -+ ;; -+ *) -+ echo "Cannot build on windows" -+ ;; - esac - - -diff --git a/scripts/install_deps.sh b/scripts/install_deps.sh -index 5d58bb3f5..bb7f00a65 100755 ---- a/scripts/install_deps.sh -+++ b/scripts/install_deps.sh -@@ -29,7 +29,7 @@ function install_linux_deps() { - sudo yum install -y git make lcov libtool m4 autoconf 
automake ccache openssl-devel zlib-devel libzstd-devel \ - libcurl-devel python3-devel \ - devtoolset-7-gcc devtoolset-7-gcc-c++ devtoolset-7-gcc-gfortran \ -- llvm-toolset-7.0-clang llvm-toolset-7.0-clang-tools-extra libuuid-devel pulseaudio-libs-devel -+ llvm-toolset-7.0-clang llvm-toolset-7.0-clang-tools-extra libuuid-devel pulseaudio-libs-devel - - echo "source scl_source enable devtoolset-7" | sudo tee -a /etc/profile.d/devtoolset-7.sh - echo "source scl_source enable llvm-toolset-7.0" | sudo tee -a /etc/profile.d/llvm-toolset-7.sh -@@ -57,25 +57,8 @@ function install_linux_deps() { - - function install_mac_deps() { - sudo xcode-select --install > /dev/null 2>&1 -- brew install boost libomp ninja tbb cmake llvm ccache zstd -- brew uninstall grep -+ brew install boost libomp ninja tbb openblas cmake llvm@15 ccache pkg-config zstd openssl librdkafka - brew install grep -- export PATH="/usr/local/opt/grep/libexec/gnubin:$PATH" -- brew update && brew upgrade && brew cleanup -- -- if [[ $(arch) == 'arm64' ]]; then -- brew install openssl -- brew install librdkafka -- brew install pkg-config -- sudo mkdir /usr/local/include -- sudo mkdir /usr/local/opt -- sudo ln -s "$(brew --prefix llvm)" "/usr/local/opt/llvm" -- sudo ln -s "$(brew --prefix libomp)/include/omp.h" "/usr/local/include/omp.h" -- sudo ln -s "$(brew --prefix libomp)" "/usr/local/opt/libomp" -- sudo ln -s "$(brew --prefix boost)/include/boost" "/usr/local/include/boost" -- sudo ln -s "$(brew --prefix tbb)/include/tbb" "/usr/local/include/tbb" -- sudo ln -s "$(brew --prefix tbb)/include/oneapi" "/usr/local/include/oneapi" -- fi - } - - if ! command -v go &> /dev/null -diff --git a/scripts/setenv.sh b/scripts/setenv.sh -index 577683dfb..08a80abe3 100644 ---- a/scripts/setenv.sh -+++ b/scripts/setenv.sh -@@ -42,6 +42,7 @@ case "${unameOut}" in - export RPATH=$LD_LIBRARY_PATH;; - Darwin*) - export PKG_CONFIG_PATH="${PKG_CONFIG_PATH}:$ROOT_DIR/internal/core/output/lib/pkgconfig" -+ export PKG_CONFIG_PATH="${PKG_CONFIG_PATH}:$(brew --prefix openssl)/lib/pkgconfig" - export DYLD_LIBRARY_PATH=$ROOT_DIR/internal/core/output/lib - export RPATH=$DYLD_LIBRARY_PATH;; - MINGW*) diff --git a/milvus_binary/patches/msys-v2.2.4.patch b/milvus_binary/patches/msys-v2.2.4.patch deleted file mode 100644 index ff55b1d..0000000 --- a/milvus_binary/patches/msys-v2.2.4.patch +++ /dev/null @@ -1,156 +0,0 @@ -diff --git a/go.mod b/go.mod -index 10b7f730f..6f29a1ae2 100644 ---- a/go.mod -+++ b/go.mod -@@ -18,7 +18,6 @@ require ( - github.com/confluentinc/confluent-kafka-go v1.9.1 - github.com/containerd/cgroups v1.0.2 - github.com/gin-gonic/gin v1.7.7 -- github.com/go-basic/ipv4 v1.0.0 - github.com/gofrs/flock v0.8.1 - github.com/golang/protobuf v1.5.2 - github.com/google/btree v1.0.1 -diff --git a/go.sum b/go.sum -index 0069189c1..01cc61310 100644 ---- a/go.sum -+++ b/go.sum -@@ -221,8 +221,6 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE - github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= - github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= - github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= --github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs= --github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg= - github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= - github.com/go-fonts/latin-modern v0.2.0/go.mod 
h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= - github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= -diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt -index eb1f727d2..ec6193d0d 100644 ---- a/internal/core/CMakeLists.txt -+++ b/internal/core/CMakeLists.txt -@@ -25,6 +25,11 @@ endif () - - add_definitions(-DELPP_THREAD_SAFE) - set(CMAKE_POSITION_INDEPENDENT_CODE ON) -+ -+if ( MSYS ) -+ add_definitions(-DPROTOBUF_USE_DLLS) -+endif () -+ - message( STATUS "Building using CMake version: ${CMAKE_VERSION}" ) - - project(core) -diff --git a/internal/core/src/indexbuilder/CMakeLists.txt b/internal/core/src/indexbuilder/CMakeLists.txt -index cae7415c9..29cef2aa0 100644 ---- a/internal/core/src/indexbuilder/CMakeLists.txt -+++ b/internal/core/src/indexbuilder/CMakeLists.txt -@@ -23,6 +23,7 @@ add_library(milvus_indexbuilder SHARED ${INDEXBUILDER_FILES}) - find_library(TBB NAMES tbb) - set(PLATFORM_LIBS dl) - if (MSYS) -+find_library(TBB NAMES tbb12) - set(PLATFORM_LIBS -Wl,--allow-multiple-definition) - endif () - -diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt -index 1b58ab37b..c85dffb50 100644 ---- a/internal/core/src/segcore/CMakeLists.txt -+++ b/internal/core/src/segcore/CMakeLists.txt -@@ -41,6 +41,7 @@ add_library(milvus_segcore SHARED ${SEGCORE_FILES}) - find_library(TBB NAMES tbb) - set(PLATFORM_LIBS dl) - if (MSYS) -+find_library(TBB NAMES tbb12) - set(PLATFORM_LIBS ) - endif () - -diff --git a/internal/util/etcd/etcd_server.go b/internal/util/etcd/etcd_server.go -index 75f81c43e..ffd758812 100644 ---- a/internal/util/etcd/etcd_server.go -+++ b/internal/util/etcd/etcd_server.go -@@ -1,6 +1,9 @@ - package etcd - - import ( -+ "net/url" -+ "os" -+ "runtime" - "sync" - - "github.com/milvus-io/milvus/internal/log" -@@ -45,6 +48,12 @@ func InitEtcdServer( - } else { - cfg = embed.NewConfig() - } -+ if runtime.GOOS == "windows" { -+ err := zap.RegisterSink("winfile", newWinFileSink) -+ if err != nil { -+ initError = err -+ } -+ } - cfg.Dir = dataDir - cfg.LogOutputs = []string{logPath} - cfg.LogLevel = logLevel -@@ -73,3 +82,10 @@ func StopEtcdServer() { - }) - } - } -+ -+// special file sink for zap, as etcd using zap as Logger -+// See: https://github.com/uber-go/zap/issues/621 -+func newWinFileSink(u *url.URL) (zap.Sink, error) { -+ // e.g. 
winfile:///D:/test/ -> D:/test/ -+ return os.OpenFile(u.Path[1:], os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600) -+} -diff --git a/internal/util/funcutil/func.go b/internal/util/funcutil/func.go -index 53edf7b0a..e12eb43f2 100644 ---- a/internal/util/funcutil/func.go -+++ b/internal/util/funcutil/func.go -@@ -31,7 +31,6 @@ import ( - "strings" - "time" - -- "github.com/go-basic/ipv4" - "go.uber.org/zap" - grpcStatus "google.golang.org/grpc/status" - -@@ -56,7 +55,16 @@ func CheckGrpcReady(ctx context.Context, targetCh chan error) { - - // GetLocalIP return the local ip address - func GetLocalIP() string { -- return ipv4.LocalIP() -+ addrs, err := net.InterfaceAddrs() -+ if err == nil { -+ for _, addr := range addrs { -+ ipaddr, ok := addr.(*net.IPNet) -+ if ok && ipaddr.IP.IsGlobalUnicast() && ipaddr.IP.To4() != nil { -+ return ipaddr.IP.String() -+ } -+ } -+ } -+ return "127.0.0.1" - } - - // WaitForComponentStates wait for component's state to be one of the specific states -diff --git a/internal/util/paramtable/grpc_param.go b/internal/util/paramtable/grpc_param.go -index af684ed69..6736c4645 100644 ---- a/internal/util/paramtable/grpc_param.go -+++ b/internal/util/paramtable/grpc_param.go -@@ -17,8 +17,8 @@ import ( - "sync" - "time" - -- "github.com/go-basic/ipv4" - "github.com/milvus-io/milvus/internal/log" -+ "github.com/milvus-io/milvus/internal/util/funcutil" - "go.uber.org/zap" - ) - -@@ -81,7 +81,7 @@ func (p *grpcConfig) init(domain string) { - - // LoadFromEnv is used to initialize configuration items from env. - func (p *grpcConfig) LoadFromEnv() { -- p.IP = ipv4.LocalIP() -+ p.IP = funcutil.GetLocalIP() - } - - // LoadFromArgs is used to initialize configuration items from args. diff --git a/milvus_binary/patches/msys-v2.2.5.patch b/milvus_binary/patches/msys-v2.2.5.patch deleted file mode 100644 index ff55b1d..0000000 --- a/milvus_binary/patches/msys-v2.2.5.patch +++ /dev/null @@ -1,156 +0,0 @@ -diff --git a/go.mod b/go.mod -index 10b7f730f..6f29a1ae2 100644 ---- a/go.mod -+++ b/go.mod -@@ -18,7 +18,6 @@ require ( - github.com/confluentinc/confluent-kafka-go v1.9.1 - github.com/containerd/cgroups v1.0.2 - github.com/gin-gonic/gin v1.7.7 -- github.com/go-basic/ipv4 v1.0.0 - github.com/gofrs/flock v0.8.1 - github.com/golang/protobuf v1.5.2 - github.com/google/btree v1.0.1 -diff --git a/go.sum b/go.sum -index 0069189c1..01cc61310 100644 ---- a/go.sum -+++ b/go.sum -@@ -221,8 +221,6 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE - github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= - github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= - github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= --github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs= --github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg= - github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= - github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= - github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= -diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt -index eb1f727d2..ec6193d0d 100644 ---- a/internal/core/CMakeLists.txt -+++ b/internal/core/CMakeLists.txt -@@ -25,6 +25,11 @@ endif () - - add_definitions(-DELPP_THREAD_SAFE) - set(CMAKE_POSITION_INDEPENDENT_CODE ON) -+ -+if ( MSYS ) -+ 
add_definitions(-DPROTOBUF_USE_DLLS) -+endif () -+ - message( STATUS "Building using CMake version: ${CMAKE_VERSION}" ) - - project(core) -diff --git a/internal/core/src/indexbuilder/CMakeLists.txt b/internal/core/src/indexbuilder/CMakeLists.txt -index cae7415c9..29cef2aa0 100644 ---- a/internal/core/src/indexbuilder/CMakeLists.txt -+++ b/internal/core/src/indexbuilder/CMakeLists.txt -@@ -23,6 +23,7 @@ add_library(milvus_indexbuilder SHARED ${INDEXBUILDER_FILES}) - find_library(TBB NAMES tbb) - set(PLATFORM_LIBS dl) - if (MSYS) -+find_library(TBB NAMES tbb12) - set(PLATFORM_LIBS -Wl,--allow-multiple-definition) - endif () - -diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt -index 1b58ab37b..c85dffb50 100644 ---- a/internal/core/src/segcore/CMakeLists.txt -+++ b/internal/core/src/segcore/CMakeLists.txt -@@ -41,6 +41,7 @@ add_library(milvus_segcore SHARED ${SEGCORE_FILES}) - find_library(TBB NAMES tbb) - set(PLATFORM_LIBS dl) - if (MSYS) -+find_library(TBB NAMES tbb12) - set(PLATFORM_LIBS ) - endif () - -diff --git a/internal/util/etcd/etcd_server.go b/internal/util/etcd/etcd_server.go -index 75f81c43e..ffd758812 100644 ---- a/internal/util/etcd/etcd_server.go -+++ b/internal/util/etcd/etcd_server.go -@@ -1,6 +1,9 @@ - package etcd - - import ( -+ "net/url" -+ "os" -+ "runtime" - "sync" - - "github.com/milvus-io/milvus/internal/log" -@@ -45,6 +48,12 @@ func InitEtcdServer( - } else { - cfg = embed.NewConfig() - } -+ if runtime.GOOS == "windows" { -+ err := zap.RegisterSink("winfile", newWinFileSink) -+ if err != nil { -+ initError = err -+ } -+ } - cfg.Dir = dataDir - cfg.LogOutputs = []string{logPath} - cfg.LogLevel = logLevel -@@ -73,3 +82,10 @@ func StopEtcdServer() { - }) - } - } -+ -+// special file sink for zap, as etcd using zap as Logger -+// See: https://github.com/uber-go/zap/issues/621 -+func newWinFileSink(u *url.URL) (zap.Sink, error) { -+ // e.g. winfile:///D:/test/ -> D:/test/ -+ return os.OpenFile(u.Path[1:], os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600) -+} -diff --git a/internal/util/funcutil/func.go b/internal/util/funcutil/func.go -index 53edf7b0a..e12eb43f2 100644 ---- a/internal/util/funcutil/func.go -+++ b/internal/util/funcutil/func.go -@@ -31,7 +31,6 @@ import ( - "strings" - "time" - -- "github.com/go-basic/ipv4" - "go.uber.org/zap" - grpcStatus "google.golang.org/grpc/status" - -@@ -56,7 +55,16 @@ func CheckGrpcReady(ctx context.Context, targetCh chan error) { - - // GetLocalIP return the local ip address - func GetLocalIP() string { -- return ipv4.LocalIP() -+ addrs, err := net.InterfaceAddrs() -+ if err == nil { -+ for _, addr := range addrs { -+ ipaddr, ok := addr.(*net.IPNet) -+ if ok && ipaddr.IP.IsGlobalUnicast() && ipaddr.IP.To4() != nil { -+ return ipaddr.IP.String() -+ } -+ } -+ } -+ return "127.0.0.1" - } - - // WaitForComponentStates wait for component's state to be one of the specific states -diff --git a/internal/util/paramtable/grpc_param.go b/internal/util/paramtable/grpc_param.go -index af684ed69..6736c4645 100644 ---- a/internal/util/paramtable/grpc_param.go -+++ b/internal/util/paramtable/grpc_param.go -@@ -17,8 +17,8 @@ import ( - "sync" - "time" - -- "github.com/go-basic/ipv4" - "github.com/milvus-io/milvus/internal/log" -+ "github.com/milvus-io/milvus/internal/util/funcutil" - "go.uber.org/zap" - ) - -@@ -81,7 +81,7 @@ func (p *grpcConfig) init(domain string) { - - // LoadFromEnv is used to initialize configuration items from env. 
- func (p *grpcConfig) LoadFromEnv() { -- p.IP = ipv4.LocalIP() -+ p.IP = funcutil.GetLocalIP() - } - - // LoadFromArgs is used to initialize configuration items from args. diff --git a/milvus_build_backend/backend.py b/milvus_build_backend/backend.py deleted file mode 100644 index 0f7ea8d..0000000 --- a/milvus_build_backend/backend.py +++ /dev/null @@ -1,57 +0,0 @@ -import os -import sys -import lzma -import platform -import setuptools.build_meta as _build - - -def _get_project_dir(): - return os.path.dirname(os.path.abspath(os.path.dirname(__file__))) - - -def _build_milvus_binary(): - project_dir = _get_project_dir() - status = os.system(f'bash {project_dir}/milvus_binary/build.sh') - if status != 0: - raise RuntimeError('Build milvus binary failed') - # install it to data/bin - bin_dir = os.path.join(project_dir, 'milvus_binary', 'output') - to_dir = os.path.join(project_dir, 'src', 'milvus', 'data', 'bin') - os.makedirs(to_dir, exist_ok=True) - for file in os.listdir(bin_dir): - if file.endswith('.txt'): - continue - file_from = os.path.join(bin_dir, file) - file_to = os.path.join(to_dir, f'{file}.lzma') - with lzma.open(file_to, 'wb') as lzma_file: - with open(file_from, 'rb') as orig_file: - print('writeing binary file: ', file) - lzma_file.write(orig_file.read()) - - -def _get_platform(): - machine_text = platform.machine().lower() - if sys.platform.lower() == 'darwin': - if machine_text == 'x86_64': - return 'macosx_10_9_x86_64' - elif machine_text == 'arm64': - return 'macosx_11_0_arm64' - if sys.platform.lower() == 'linux': - return f'manylinux2014_{machine_text}' - if sys.platform.lower() == 'win32': - return 'win_amd64' - - -get_requires_for_build_wheel = _build.get_requires_for_build_wheel - - -def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): - _build_milvus_binary() - name = _build.build_wheel( - wheel_directory, config_settings, metadata_directory) - if name.endswith('-none-any.whl'): - new_name = name.replace('-any.whl', f'-{_get_platform()}.whl') - - os.rename(os.path.join(wheel_directory, name), - os.path.join(wheel_directory, new_name)) - return new_name diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index c895fa3..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,9 +0,0 @@ -[build-system] -requires = ["setuptools>=64.0"] -build-backend = "backend" -backend-path = ["milvus_build_backend"] - -[tool.pytest.ini_options] -testpaths = ["tests"] -pythonpath = ["src"] -addopts = "--cov=milvus" \ No newline at end of file diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 0000000..18d11ba --- /dev/null +++ b/python/setup.py @@ -0,0 +1,138 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
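+#
+# Build overview (descriptive note, summarizing the code below): the custom
+# bdist_wheel command (CMakeBuild) drives the native build. It runs
+# `conan install` and `cmake` in a scratch build directory, compiles with
+# `cmake --build`, and then copies the `milvus` binary together with the
+# shared libraries it links against (resolved via `ldd` on Linux and
+# `otool -L` on macOS) into the wheel's `milvus/lib` directory. The wheel is
+# tagged manylinux2014_<arch> on Linux and macosx_<arch> on macOS.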
+ +import sys +import os +import pathlib +import unittest +from typing import List +import subprocess +import platform + +from setuptools import setup, find_namespace_packages, Extension +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import shutil + + +MILVUS_BIN = 'milvus' + + +class CMakeBuild(_bdist_wheel): + def finalize_options(self): + if sys.platform.lower() == 'linux': + self.plat_name = f"manylinux2014_{platform.machine().lower()}" + elif sys.platform.lower() == 'darwin': + self.plat_name = f"macosx_{platform.machine().lower()}" + return super().finalize_options() + + def copy_lib(self, lib_path, dst_dir, pick_libs): + name = pathlib.Path(lib_path).name + new_file = os.path.join(dst_dir, name) + for lib_prefix in pick_libs: + if name.startswith(lib_prefix): + shutil.copy(lib_path, new_file) + continue + + def _pack_macos(self, src_dir: str, dst_dir: str): + mac_pkg = ['libknowhere', 'libmilvus', + 'libgflags_nothreads', 'libglog', + 'libtbb', 'libgomp', + 'libdouble-conversion'] + milvus_bin = pathlib.Path(src_dir) / MILVUS_BIN + out_str = subprocess.check_output(['otool', '-L', str(milvus_bin)]) + lines = out_str.decode('utf-8').split('\n') + all_files = [] + for line in lines[1:]: + r = line.split(' ') + if not r[0].endswith('dylib'): + continue + self.copy_lib(r[1].strip().split(' ')[0].strip(), dst_dir, mac_pkg) + + def _pack_linux(self, src_dir: str, dst_dir: str): + linux_pkg = ['libknowhere', 'libmilvus', + 'libgflags_nothreads', 'libglog', + 'libtbb', 'libm', 'libgcc_s', + 'libgomp', 'libopenblas', + 'libdouble-conversion', 'libz', + 'libgfortran', 'libquadmath'] + milvus_bin = pathlib.Path(src_dir) / MILVUS_BIN + out_str = subprocess.check_output(['ldd', str(milvus_bin)]) + lines = out_str.decode('utf-8').split('\n') + all_files = [] + for line in lines: + r = line.split("=>") + if len(r) != 2: + continue + self.copy_lib(r[1].strip().split(' ')[0].strip(), dst_dir, linux_pkg) + + def run(self): + build_lib = self.bdist_dir + build_temp = os.path.abspath(os.path.join(os.path.dirname(build_lib), 'build_milvus')) + + if not os.path.exists(build_temp): + os.makedirs(build_temp) + #clean build temp + shutil.rmtree(os.path.join(build_temp, 'lib'), ignore_errors=True) + extdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + env = os.environ + env['LD_LIBRARY_PATH'] = os.path.join(build_temp, 'lib') + subprocess.check_call(['conan', 'install', extdir, '--build=missing', '-s', 'build_type=Release'], cwd=build_temp, env=env) + subprocess.check_call(['cmake', extdir, '-DENABLE_UNIT_TESTS=OFF'], cwd=build_temp, env=env) + subprocess.check_call(['cmake', '--build', '.', '--', '-j48'], + cwd=build_temp, + env=env, + ) + dst_lib_path = os.path.join(build_lib, 'milvus/lib') + shutil.rmtree(dst_lib_path, ignore_errors=True) + os.makedirs(dst_lib_path) + shutil.copy(os.path.join(build_temp, 'lib', MILVUS_BIN), os.path.join(dst_lib_path, MILVUS_BIN)) + if sys.platform.lower() == 'linux': + self._pack_linux(os.path.join(build_temp, 'lib'), dst_lib_path) + elif sys.platform.lower() == 'darwin': + self._pack_macos(os.path.join(build_temp, 'lib'), dst_lib_path) + else: + raise RuntimeError('Unsupport platform: %s', sys.platform) + + super().run() + + +def test_suite(): + test_loader = unittest.TestLoader() + tests = test_loader.discover('tests', pattern='test_*.py') + return tests + + +def parse_requirements(file_name: str) -> List[str]: + with open(file_name, encoding='utf-8') as f: + return [ + require.strip() for require in f + if require.strip() and not 
require.startswith('#') + ] + +setup(name='milvus', + version='2.4.0', + description='', + author='Milvus Team', + author_email='milvus-team@zilliz.com', + url='https://github.com/milvus-io/milvus-lite.git', + test_suite='setup.test_suite', + package_dir={'': 'src'}, + packages=find_namespace_packages('src'), + package_data={}, + include_package_data=True, + python_requires='>=3.7', + cmdclass={"bdist_wheel": CMakeBuild}, + long_description_content_type='text/markdown' + ) diff --git a/python/src/milvus/__init__.py b/python/src/milvus/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/src/milvus/server.py b/python/src/milvus/server.py new file mode 100644 index 0000000..447882b --- /dev/null +++ b/python/src/milvus/server.py @@ -0,0 +1,92 @@ +import os +import subprocess +import pathlib +import logging +import fcntl +import re + + +BIN_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lib') + + +logger = logging.getLogger() + + +class Server: + """ + """ + MILVUS_BIN = 'milvus' + + + def __init__(self, db_file: str, args=None): + if os.environ.get('BIN_PATH') is not None: + self._bin_path = pathlib.Path(os.environ['BIN_PATH']).absolute() + else: + self._bin_path = pathlib.Path(BIN_PATH).absolute() + self._db_file = pathlib.Path(db_file).absolute() + if not re.match(r'^[a-zA-Z0-9.\-_]+$', self._db_file.name): + raise RuntimeError(f"Unsupport db name {self._db_file.name}, the name must match ^[a-zA-Z0-9.\-_]+$") + self._work_dir = self._db_file.parent + self._args = args + self._p = None + self._uds_path = str(self._db_file.parent / f'.{self._db_file.name}.sock') + self._lock_path = str(self._db_file.parent / f'.{self._db_file.name}.lock') + self._lock_fd = None + + def init(self) -> bool: + if not self._bin_path.exists(): + logger.error("Bin path not exists") + return False + if not self._work_dir.exists(): + logger.error("Dir %s not exist", self._work_dir) + return True + + @property + def milvus_bin(self): + return str(self._bin_path / 'milvus') + + @property + def log_level(self): + return os.environ.get("LOG_LEVEL", "ERROR") + + @property + def uds_path(self): + return f'unix:{self._uds_path}' + + @property + def args(self): + if self._args is not None: + return self._args + return [self.milvus_bin, self._db_file, self.uds_path, self.log_level, self._lock_path] + + def start(self) -> bool: + assert self._p is None, "Server already started" + self._lock_fd = open(self._lock_path, 'a') + try: + fcntl.lockf(self._lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + self._p = subprocess.Popen( + args=self.args, + env={"LD_LIBRARY_PATH": str(self._bin_path)}, + cwd=str(self._work_dir), + ) + return True + except BlockingIOError: + logger.error("Open %s failed, the file has been opened by another program", self._db_file) + return False + + def stop(self): + if self._p is not None: + logger.info("Stop milvus...") + try: + self._p.terminate() + self._p.wait(timeout=2) + except subprocess.TimeoutExpired: + self._p.kill() + self._p.wait(timeout=3) + self._p = None + if self._lock_fd: + fcntl.flock(self._lock_fd, fcntl.LOCK_UN) + self._lock_fd.close() + self._lock_fd = None + pathlib.Path(self._uds_path).unlink(missing_ok=True) + pathlib.Path(self._lock_path).unlink(missing_ok=True) diff --git a/python/src/milvus/server_manager.py b/python/src/milvus/server_manager.py new file mode 100644 index 0000000..cdf2f84 --- /dev/null +++ b/python/src/milvus/server_manager.py @@ -0,0 +1,47 @@ +from typing import Optional +import threading +from milvus.server import Server 
+import logging +import pathlib + + +logger = logging.getLogger() + + +class ServerManager: + def __init__(self): + self._lock = threading.Lock() + self._servers = {} + + def start_and_get_uri(self, path: str, args=None) -> Optional[str]: + path = pathlib.Path(path).absolute().resolve() + with self._lock: + if str(path) not in self._servers: + s = Server(str(path), args) + if not s.init(): + return None + self._servers[str(path)] = s + if not self._servers[str(path)].start(): + logger.error("Start local milvus failed") + return None + return self._servers[str(path)].uds_path + + def release_server(self, path: str): + path = pathlib.Path(path).absolute().resolve() + with self._lock: + if str(path) not in self._servers: + logger.warning("No local milvus in path %s", str(path)) + return + self._servers[str(path)].stop() + del self._servers[str(path)] + + def release_all(self): + for s in self._servers.values(): + s.stop() + + def __del__(self): + with self._lock: + self.release_all() + + +server_manager_instance = ServerManager() diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 41c3120..0000000 --- a/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -# requirements for devel -build -wheel -setuptools>64.0 -pytest -pytest-cov -pyyaml diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 064c642..0000000 --- a/setup.cfg +++ /dev/null @@ -1,38 +0,0 @@ -[metadata] -name = milvus -version = attr: milvus.__version__ -author = Milvus Team -author_email = milvus-team@zilliz.com -maintainer = Ji Bin -maintainer_email = matrixji@live.com -description = Embeded Milvus -license = Apache-2.0 -license_files = LICENSE -long_description = file: README.md -long_description_content_type = text/markdown -home_page = https://github.com/milvus-io/milvus-lite -keywords = Milvus, Embeded Milvus, Milvus Server - -[options] -package_dir = - = src -include_package_data = True -packages = find_namespace: -python_requires = >=3.6 -install_requires = - -[options.packages.find] -where = src - -[options.package_data] -milvus.data = - *.template -milvus.data.bin = - *.lzma - -[options.entry_points] -console_scripts = - milvus-server = milvus:main - -[options.extras_require] -client = pymilvus>=2.3.0b1,<2.4.0 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..96f7340 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,56 @@ +include_directories(${CMAKE_CURRENT_LIST_DIR}) + +add_library(milite STATIC + ${CMAKE_CURRENT_LIST_DIR}/collection_meta.cpp + ${CMAKE_CURRENT_LIST_DIR}/collection_data.cpp + ${CMAKE_CURRENT_LIST_DIR}/storage.cpp + ${CMAKE_CURRENT_LIST_DIR}/index.cpp + ${CMAKE_CURRENT_LIST_DIR}/milvus_local.cpp + ${CMAKE_CURRENT_LIST_DIR}/segcore_wrapper.cpp + ${CMAKE_CURRENT_LIST_DIR}/milvus_proxy.cpp + ${CMAKE_CURRENT_LIST_DIR}/create_collection_task.cpp + ${CMAKE_CURRENT_LIST_DIR}/create_index_task.cpp + ${CMAKE_CURRENT_LIST_DIR}/search_task.cpp + ${CMAKE_CURRENT_LIST_DIR}/insert_task.cpp + ${CMAKE_CURRENT_LIST_DIR}/query_task.cpp + ${CMAKE_CURRENT_LIST_DIR}/delete_task.cpp + ${CMAKE_CURRENT_LIST_DIR}/schema_util.cpp +) + + +target_link_libraries( + milite + PUBLIC + parser + milvus_proto + milvus_segcore + ${CONAN_LIBS} + SQLiteCpp + ${antlr4-cppruntime_LIBRARIES} + marisa::marisa + TBB::tbb +) + +add_library( + milvus_service + STATIC + "${CMAKE_SOURCE_DIR}/src/milvus_service_impl.cpp" +) + +target_link_libraries( + milvus_service + PUBLIC + milvus_grpc_service +) + +add_executable(milvus server.cpp) + +target_link_libraries( + milvus + 
milvus_service + milite +) + +if(ENABLE_UNIT_TESTS) + add_subdirectory(unittest) +endif() diff --git a/src/collection_data.cpp b/src/collection_data.cpp new file mode 100644 index 0000000..4a4ba6b --- /dev/null +++ b/src/collection_data.cpp @@ -0,0 +1,131 @@ +#include "collection_data.h" +#include +// #include +#include +#include +#include +#include +#include "string_util.hpp" + +#include "log/Log.h" + +namespace milvus::local { + +CollectionData::CollectionData(const char* collection_name) + : collection_name_(collection_name), + col_id_("id"), + col_milvus_id_("milvus_id"), + col_data_("data") { +} +CollectionData::~CollectionData() { +} + +std::string +CollectionData::GetTableCreateSql() { + return string_util::SFormat( + "CREATE TABLE IF NOT EXISTS {} ({} INTEGER PRIMARY KEY, {} " + "VARCHAR(1024), {} BLOB);", + collection_name_, + col_id_, + col_milvus_id_, + col_data_); +} + +bool +CollectionData::CreateCollection(SQLite::Database* db) { + const std::string table_create_sql = GetTableCreateSql(); + if (db->tryExec(table_create_sql) != 0) { + const char* err = db->getErrorMsg(); + LOG_ERROR("Create table {} failed, errs: {}", collection_name_, err); + return false; + } + return true; +} + +bool +CollectionData::DropCollection(SQLite::Database* db) { + // DROP TABLE {collection_name_} + std::string drop_sql = + string_util::SFormat("DROP TABLE {}", collection_name_); + + if (db->tryExec(drop_sql) != 0) { + const char* err = db->getErrorMsg(); + LOG_ERROR("Drop collection {} failed, errs: {}", collection_name_, err); + return false; + } + return true; +} + +int +CollectionData::Insert(SQLite::Database* db, + const std::string& milvus_id, + const std::string& data) { + std::string insert_sql = string_util::SFormat( + "INSERT INTO {} VALUES (NULL, ?, ?)", collection_name_); + try { + SQLite::Statement query(*db, insert_sql); + SQLite::bind(query, milvus_id, data); + return query.exec(); + } catch (std::exception& e) { + LOG_ERROR("Insert data failed, errs: {}", e.what()); + return -1; + } +} + +void +CollectionData::Load(SQLite::Database* db, + int64_t start, + int64_t limit, + std::vector* output_rows) { + // SELECT {col_data_} from {collection_name_} LIMIT {limit} OFFSET {start} + std::string select_sql = + string_util::SFormat("SELECT {} from {} LIMIT {} OFFSET {}", + col_data_, + collection_name_, + limit, + start); + try { + SQLite::Statement query(*db, select_sql); + while (query.executeStep()) { + output_rows->push_back(query.getColumn(0).getString()); + } + + } catch (std::exception& e) { + LOG_ERROR("Load data failed, errs: {}", e.what()); + } +} + +int +CollectionData::Delete(SQLite::Database* db, + const std::vector& milvus_ids) { + // DELETE FROM {collection_name_} WHERE {col_milvus_id} in ({}) + std::string delete_sql = + string_util::SFormat("DELETE FROM {} WHERE {} IN ({})", + collection_name_, + col_milvus_id_, + string_util::Join(",", milvus_ids)); + try { + SQLite::Statement query(*db, delete_sql); + return query.exec(); + } catch (std::exception& e) { + LOG_ERROR("Delete data failed, errs: {}", e.what()); + return -1; + } +} + +int64_t +CollectionData::Count(SQLite::Database* db) { + // SELECT count(*) FROM {}; + std::string count_sql = + string_util::SFormat("SELECT count(*) FROM {}", collection_name_); + try { + SQLite::Statement query(*db, count_sql); + query.executeStep(); + return query.getColumn(0).getInt64(); + } catch (std::exception& e) { + LOG_ERROR("count data failed, errs: {}", e.what()); + return -1; + } +} + +} // namespace milvus::local diff --git 
a/src/collection_data.h b/src/collection_data.h new file mode 100644 index 0000000..92bd655 --- /dev/null +++ b/src/collection_data.h @@ -0,0 +1,92 @@ +/* collection data table + + ──────┬─────────────┬──────────── + id │ milvus_id │ data + │ │ + │ │ + ──────┼─────────────┼──────────── + 1 │ 1234 │ xxxx + │ │ + ──────┼─────────────┼──────────── + │ │ + 2 │ 1235 │ xxxx + │ │ + ──────┼─────────────┼──────────── + │ │ + 3 │ 1236 │ xxxx + │ │ + ───────┼─────────────┼──────────── + │ │ + 4 │ 1237 │ xxxx + │ │ + ───────┼─────────────┼──────────── + │ │ + 5 │ 1238 │ xxxx + │ │ + │ │ + ──────┴─────────────┴──────────── + + +*/ + +#pragma once +#include +#include +#include +#include +#include "type.h" +#include "common.h" + +namespace milvus::local { + +class CollectionData final : NonCopyableNonMovable { + public: + explicit CollectionData(const char*); + virtual ~CollectionData(); + + public: + int + Insert(SQLite::Database* db, + const std::string& milvus_id, + const std::string& data); + + int + Delete(SQLite::Database* db, const std::vector& milvus_ids); + // int + // upsert(SQLite::Database* db, std::string& milvus_id, const std::string& data); + // const char* + // get(SQLite::Database* db, const std::string& milvus_id); + + void + Load(SQLite::Database* db, + int64_t start, + int64_t limit, + std::vector* output_rows); + + bool + CreateCollection(SQLite::Database* db); + + bool + DropCollection(SQLite::Database* db); + + const std::string& + GetTableName() { + return collection_name_; + } + + int64_t + Count(SQLite::Database* db); + + private: + std::string + GetTableCreateSql(); + + private: + const std::string collection_name_; + + // table column name + const std::string col_id_; + const std::string col_milvus_id_; + const std::string col_data_; +}; +} // namespace milvus::local diff --git a/src/collection_meta.cpp b/src/collection_meta.cpp new file mode 100644 index 0000000..a4f93b0 --- /dev/null +++ b/src/collection_meta.cpp @@ -0,0 +1,203 @@ +#include "collection_meta.h" +#include +// #include +#include +#include +#include +#include "log/Log.h" +#include "string_util.hpp" + +namespace milvus::local { + +const std::string kSchemaStr = "schema"; +const std::string kIndexStr = "index"; +const std::string kPartitionStr = "PARTITION"; + +CollectionMeta::CollectionMeta() + : table_meta_name_("collection_meta"), + col_id_("id"), + col_collection_name_("collection_name"), + col_meta_type_("meta_type"), + col_blob_field_("blob_field"), + col_string_field_("string_field") { +} + +CollectionMeta::~CollectionMeta() { +} + +bool +CollectionMeta::LoadMeta(SQLite::Database* db) { + // SELECT * FROM {table_meta_name_} + std::string load_cmd = + string_util::SFormat("SELECT * FROM {}", table_meta_name_); + try { + SQLite::Statement query(*db, load_cmd); + while (query.executeStep()) { + auto collection_name = query.getColumn(1).getString(); + LOG_INFO("Load {}'s meta", collection_name); + if (collections_.find(collection_name) == collections_.end()) { + collections_.emplace(collection_name, + std::make_unique()); + } + auto meta_type = query.getColumn(2).getString(); + if (meta_type == kSchemaStr) { + auto info = + static_cast(query.getColumn(3).getBlob()); + auto pk_name = query.getColumn(4).getString(); + collections_[collection_name]->AddSchema(info, pk_name); + } else if (meta_type == kIndexStr) { + auto info = + static_cast(query.getColumn(3).getBlob()); + auto index_name = query.getColumn(4).getString(); + collections_[collection_name]->AddIndex(index_name, info); + } else { + LOG_ERROR("Unkown 
meta data"); + return false; + } + } + return true; + } catch (std::exception& e) { + LOG_ERROR("Load meta data failed, err: {}", e.what()); + return false; + } +} + +bool +CollectionMeta::CreateTable(SQLite::Database* db) { + const std::string table_create_cmd = string_util::SFormat( + "CREATE TABLE IF NOT EXISTS {} ({} INTEGER PRIMARY KEY, {} " + "VARCHAR(1024), {} VARCHAR(1024), {} BLOB, {} " + "VARCHAR(1024))", + table_meta_name_, + col_id_, + col_collection_name_, + col_meta_type_, + col_blob_field_, + col_string_field_); + + if (db->tryExec(table_create_cmd) != 0) { + const char* err = db->getErrorMsg(); + LOG_ERROR("Create table failed, errs: {}", err); + return false; + } + return true; +} + +bool +CollectionMeta::Init(SQLite::Database* db) { + return CreateTable(db) && LoadMeta(db); +} + +bool +CollectionMeta::CreateCollection(SQLite::Database* db, + const std::string& collection_name, + const std::string& pk_name, + const std::string& schema_proto) { + collections_.emplace(collection_name, std::make_unique()); + collections_[collection_name]->AddSchema(schema_proto.c_str(), pk_name); + + // INSERT INTO {table_name} VALUES (NULL, {collection_name}, "schema", {data}, NULL) + std::string insert_cmd = string_util::SFormat( + "INSERT INTO {} VALUES (NULL, ?, ?, ?, ?)", table_meta_name_); + try { + SQLite::Statement query(*db, insert_cmd); + SQLite::bind(query, collection_name, kSchemaStr, schema_proto, pk_name); + return query.exec(); + } catch (std::exception& e) { + LOG_ERROR("Insert data failed, errs: {}", e.what()); + return -1; + } +} + +const std::string& +CollectionMeta::GetCollectionSchema(const std::string& collection_name) { + return collections_[collection_name]->Schema(); +} + +bool +CollectionMeta::CreateIndex(SQLite::Database* db, + const std::string& collection_name, + const std::string& index_name, + const std::string& index_proto) { + // INSERT INTO {table_name} VALUES (NULL, {collection_name}, "schema", {data}, NULL) + collections_[collection_name]->AddIndex(index_name, index_proto.c_str()); + std::string insert_cmd = string_util::SFormat( + "INSERT INTO {} VALUES (NULL, \"{}\", \"index\", \"{}\", \"{}\")", + table_meta_name_, + collection_name, + index_proto, + index_name); + try { + db->exec(insert_cmd); + return true; + } catch (std::exception& e) { + LOG_ERROR("Add index failed, err: {}", e.what()); + return false; + } +} + +bool +CollectionMeta::GetCollectionIndex(const std::string& collection_name, + const std::string& index_name, + std::string* output_index_info) { + return collections_[collection_name]->GetIndex(index_name, + output_index_info); +} + +void +CollectionMeta::GetAllIndex(const std::string& collection_name, + const std::string& exclude, + std::vector* all_index) { + collections_[collection_name]->GetAllIndexs(all_index, exclude); +} + +bool +CollectionMeta::DropCollection(SQLite::Database* db, + const std::string& collection_name) { + // DELETE FROM {table_name} WHERE {col_collection_name_}={collection_name}; + std::string delete_cmd = string_util::SFormat( + "DELETE FROM {} WHERE " + "{}='{}'", + table_meta_name_, + col_collection_name_, + collection_name); + try { + collections_.erase(collection_name); + db->exec(delete_cmd); + return true; + } catch (std::exception& e) { + LOG_ERROR( + "Drop collection: {} failed, err: {}", collection_name, e.what()); + return false; + } +} + +bool +CollectionMeta::DropIndex(SQLite::Database* db, + const std::string& collection_name, + const std::string& index_name) { + // DELETE FROM {table_name} WHERE 
{col_collection_name_}={collection_name} and {col_meta_type_}={kIndexStr} and {col_string_field_}={index_name}; + std::string drop_index_cmd = string_util::SFormat( + "DELETE FROM {} WHERE " + "{}='{}' and {}='{}' and {}='{}'", + table_meta_name_, + col_collection_name_, + collection_name, + col_meta_type_, + kIndexStr, + col_string_field_, + index_name); + try { + collections_[collection_name]->DropIndex(index_name); + db->exec(drop_index_cmd); + return true; + } catch (std::exception& e) { + LOG_ERROR("Drop collection {}'s index:{} failed, err: {}", + collection_name, + index_name, + e.what()); + return false; + } +} + +} // namespace milvus::local diff --git a/src/collection_meta.h b/src/collection_meta.h new file mode 100644 index 0000000..fd305a8 --- /dev/null +++ b/src/collection_meta.h @@ -0,0 +1,219 @@ +/* meta table + + ──────┬───────────────────┬─────────────┬─────────────┬──────────────── + id │ collection_name │ meta_type │ blob_field │ string_field + │ │ │ │ + │ │ │ │ + ──────┼───────────────────┼─────────────┼─────────────┼──────────────── + 1 │ collection1 │ schema │ xxx │ pk_name + │ │ │ │ + ──────┼───────────────────┼─────────────┼─────────────┼──────────────── + │ │ │ │ + 2 │ collection1 │ index │ xxx │ index1 + │ │ │ │ + ──────┼───────────────────┼─────────────┼─────────────┼──────────────── + │ │ │ │ + 3 │ collection1 │ index │ xxx │ index2 + │ │ │ │ + ──────┼───────────────────┼─────────────┼─────────────┼──────────────── + │ │ │ │ + 4 │ collection1 │ partition │ null │ p1 + │ │ │ │ + ──────┼───────────────────┼─────────────┼─────────────┼──────────────── + │ │ │ │ + 5 │ collction2 │ schema │ xxx │ pk_name + │ │ │ │ + ──────┴───────────────────┴─────────────┴─────────────┴──────────────── +*/ + +#pragma once +#include +#include +#include +#include +#include +#include +#include "SQLiteCpp/Database.h" +#include "log/Log.h" + +namespace milvus::local { + +using SchemaInfo = std::string; +using IndexInfo = std::map; + +/* + * CollectionMeta 存储所有collection的元信息,数据写入sqlite3中,内存中保存副本。 + * Collection 是否存在等检查,都在storage中,CollectionMeta不再进行相关检查。 + */ + +class CollectionMeta final { + public: + class CollectionInfo { + public: + CollectionInfo() = default; + ~CollectionInfo() = default; + CollectionInfo(const CollectionInfo&) = delete; + CollectionInfo& + operator=(const CollectionInfo&) = delete; + CollectionInfo(const CollectionInfo&&) = delete; + CollectionInfo& + operator=(const CollectionInfo&&) = delete; + + public: + void + AddSchema(const char* info, const std::string& pk_name) { + schema_info_.assign(info); + pk_name_ = pk_name; + } + + const std::string& + Schema() { + return schema_info_; + } + bool + AddIndex(const std::string& index_name, const char* index) { + if (index_info_.find(index_name) != index_info_.end()) { + LOG_ERROR("Index: {} already exist", index_name); + return false; + } + index_info_.emplace(index_name, index); + return true; + } + + bool + GetIndex(const std::string& index_name, std::string* index) { + if (index_info_.find(index_name) == index_info_.end()) { + return false; + } + index->assign(index_info_[index_name].c_str()); + return true; + } + + void + GetAllIndexs(std::vector* all_index, + const std::string& exclude) { + for (const auto& pair : index_info_) { + if (pair.first != exclude) + all_index->push_back(pair.second); + } + } + + bool + HasIndex(const std::string& index_name) { + return index_info_.find(index_name) != index_info_.end(); + } + + bool + DropIndex(const std::string& index_name) { + if (!HasIndex(index_name)) { + return true; 
+ } + index_info_.erase(index_name); + return true; + } + + const std::string& + GetPkName() { + return pk_name_; + } + + private: + IndexInfo index_info_; + SchemaInfo schema_info_; + std::string pk_name_; + }; + + public: + CollectionMeta(); + ~CollectionMeta(); + + public: + CollectionMeta(const CollectionMeta&) = delete; + CollectionMeta& + operator=(const CollectionMeta&) = delete; + CollectionMeta(const CollectionMeta&&) = delete; + CollectionMeta& + operator=(const CollectionMeta&&) = delete; + + public: + bool + Init(SQLite::Database* db); + + bool + CreateCollection(SQLite::Database* db, + const std::string& collection_name, + const std::string& pk_name, + const std::string& schema_proto); + + const std::string& + GetCollectionSchema(const std::string& collection_name); + + bool + CreateIndex(SQLite::Database* db, + const std::string& collection_name, + const std::string& index_name, + const std::string& index_proto); + + void + GetAllIndex(const std::string& collection_name, + const std::string& exclude, + std::vector* all_index); + + bool + HasIndex(const std::string& collection_name, + const std::string& index_name) { + return collections_[collection_name]->HasIndex(index_name); + } + + bool + DropIndex(SQLite::Database* db, + const std::string& collection_name, + const std::string& index_name); + + bool + GetCollectionIndex(const std::string& collection_name, + const std::string& index_name, + std::string* ouput_index_info); + + std::string + GetPkName(const std::string& collection_name) { + return collections_[collection_name]->GetPkName(); + } + + void + CollectionNames(std::vector* collection_names) { + for (const auto& pair : collections_) { + collection_names->push_back(pair.first); + } + } + + bool + DropCollection(SQLite::Database* db, const std::string& collection_name); + + private: + bool + CreateTable(SQLite::Database* db); + + bool + LoadMeta(SQLite::Database* db); + + // std::string + // schema_info(const std::string& collection_name); + + private: + // collection meta + std::map> collections_; + + private: + // sqlite3 table info + const std::string table_meta_name_; + + // table column name + const std::string col_id_; + const std::string col_collection_name_; + const std::string col_meta_type_; + const std::string col_blob_field_; + const std::string col_string_field_; +}; + +} // namespace milvus::local diff --git a/src/common.h b/src/common.h new file mode 100644 index 0000000..d565762 --- /dev/null +++ b/src/common.h @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include + +namespace milvus::local { + +#define CHECK_STATUS(status, err) \ + do { \ + if (!(status).IsOk()) { \ + return status; \ + } \ + } while (0) + +#define DELETE_AND_SET_NULL(ptr, deleter) \ + do { \ + if ((ptr) != nullptr) { \ + deleter(ptr); \ + (ptr) = nullptr; \ + } \ + } while (0) + +using KVMap = std::map; + +// system field id: +// 0: unique row id +// 1: timestamp +// 100: first user field id +// 101: second user field id +// 102: ... 
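+// For example, a collection declared with two user fields (pk, vector) ends
+// up with pk -> 100 and vector -> 101 (kStartOfUserFieldId + i, assigned in
+// CreateCollectionTask::AssignFieldId, with the dynamic $meta field, when
+// enabled, taking the next id); the RowID (0) and Timestamp (1) system
+// fields are appended afterwards by AppendSysFields.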
+const int64_t kStartOfUserFieldId = 100; + +const int64_t kRowIdField = 0; + +const int64_t kTimeStampField = 1; + +const std::string kRowIdFieldName("RowID"); + +const std::string kTimeStampFieldName("Timestamp"); + +const std::string kMetaFieldName("$meta"); + +const std::string kPlaceholderTag("$0"); + +const int64_t kTopkLimit = 16384; + +// scalar index type +const std::string kDefaultStringIndexType("Trie"); +const std::string kInvertedIndexType("INVERTED"); +const std::string kDefaultArithmeticIndexType = ("STL_SORT"); + +const int64_t kMaxIndexRow = 1000000; + +// Search, Index parameter keys +const std::string kTopkKey("topk"); +// const std::string kSearchParamKey("search_param"); +const std::string kSearchParamKey("params"); +const std::string kOffsetKey("offset"); +const std::string kRoundDecimalKey("round_decimal"); +const std::string kGroupByFieldKey("group_by_field"); +const std::string kAnnFieldKey("anns_field"); +const std::string kSegmentNumKey("segment_num"); +const std::string kWithFilterKey("with_filter"); +const std::string kWithOptimizeKey("with_optimize"); +const std::string kCollectionKey("collection"); +const std::string kIndexParamsKey("params"); +const std::string kIndexTypeKey("index_type"); +const std::string kMetricTypeKey("metric_type"); +const std::string kDimKey("dim"); +const std::string kMaxLengthKey("max_length"); +const std::string kMaxCapacityKey("max_capacity"); +const std::string kReduceStopForBestKey("reduce_stop_for_best"); +const std::string kLimitKey("limit"); +const std::string KMetricsIPName("IP"); +const std::string kMetricsCosineName("COSINE"); +const std::string kMetricsL2Name("L2"); + +const std::string kCountStr("count(*)"); + +inline int64_t +GetCollectionId(const std::string& collection_name) { + std::hash hasher; + size_t hash_value = hasher(collection_name); + return static_cast(hash_value); +} + +inline int64_t +GetIndexId(const std::string& index_name) { + std::hash hasher; + size_t hash_value = hasher(index_name); + return static_cast(hash_value); +} + +struct NonCopyableNonMovable { + constexpr NonCopyableNonMovable() noexcept = default; + virtual ~NonCopyableNonMovable() noexcept = default; + + NonCopyableNonMovable(NonCopyableNonMovable&&) = delete; + NonCopyableNonMovable& + operator=(NonCopyableNonMovable&&) = delete; + NonCopyableNonMovable(const NonCopyableNonMovable&) = delete; + NonCopyableNonMovable& + operator=(const NonCopyableNonMovable&) = delete; +}; + +} // namespace milvus::local diff --git a/src/create_collection_task.cpp b/src/create_collection_task.cpp new file mode 100644 index 0000000..a1157a8 --- /dev/null +++ b/src/create_collection_task.cpp @@ -0,0 +1,323 @@ +#include "create_collection_task.h" + +#include +#include +#include +#include +#include +#include "common.h" +#include "log/Log.h" +#include "pb/schema.pb.h" +#include "status.h" +#include "string_util.hpp" + +namespace milvus::local { + +using DType = ::milvus::proto::schema::DataType; +using DCase = ::milvus::proto::schema::ValueField::DataCase; + +bool +CreateCollectionTask::HasSystemFields( + const ::milvus::proto::schema::CollectionSchema& schema) { + for (const auto& f : schema.fields()) { + if (f.name() == kRowIdFieldName || f.name() == kTimeStampFieldName || + f.name() == kMetaFieldName) { + return true; + } + } + return false; +} + +Status +CreateCollectionTask::GetVarcharFieldMaxLength( + const ::milvus::proto::schema::FieldSchema& field, uint64_t* max_len) { + if (field.data_type() != DType::VarChar && + field.element_type() != 
DType::VarChar) { + return Status::ParameterInvalid("{} is not varchar field", + field.name()); + } + + for (const auto& kv_pair : field.type_params()) { + if (kv_pair.key() == kMaxLengthKey) { + try { + auto length = std::stoll(kv_pair.value()); + if (length <= 0) { + return Status::ParameterInvalid( + "the maximum length specified for a VarChar should be " + "in (0, 65535])"); + } else { + *max_len = static_cast(length); + return Status::Ok(); + } + } catch (std::exception& e) { + return Status::ParameterInvalid("Invalid max length {}", + kv_pair.value()); + } + } + } + + for (const auto& kv_pair : field.index_params()) { + if (kv_pair.key() == kMaxLengthKey) { + try { + *max_len = std::stoll(kv_pair.value()); + return Status::Ok(); + } catch (std::exception& e) { + return Status::ParameterInvalid("Invalid max length {}", + kv_pair.value()); + } + } + } + return Status::ParameterInvalid( + "type param(max_length) should be specified for varChar field of " + "collection"); +} + +bool +CreateCollectionTask::CheckDefaultValue( + const ::milvus::proto::schema::CollectionSchema& schema) { + for (const auto& f : schema.fields()) { + if (!f.has_default_value() || !f.has_default_value()) + continue; + switch (f.default_value().data_case()) { + case DCase::kBoolData: + if (f.data_type() != DType::Bool) { + LOG_ERROR( + "{} field's default value is Bool type, mismatches " + "field type", + f.name()); + return false; + } + break; + case DCase::kIntData: { + if (f.data_type() != DType::Int16 && + f.data_type() != DType::Int32 && + f.data_type() != DType::Int8) { + LOG_ERROR( + "{} field's default value is Int type, mismatches " + "field type", + f.name()); + return false; + } + auto default_value = f.default_value().int_data(); + if (f.data_type() == DType::Int16) { + if (default_value < std::numeric_limits::min() || + default_value > std::numeric_limits::max()) { + LOG_ERROR("{} field's default value out of range.", + f.name()); + return false; + } + } + if (f.data_type() == DType::Int8) { + if (default_value < std::numeric_limits::min() || + default_value > std::numeric_limits::max()) { + LOG_ERROR("{} field's default value out of range.", + f.name()); + return false; + } + } + } break; + case DCase::kLongData: + if (f.data_type() != DType::Int64) { + LOG_ERROR( + "{} field's default value is Long type, mismatches " + "field type", + f.name()); + return false; + } + break; + case DCase::kFloatData: + if (f.data_type() != DType::Float) { + LOG_ERROR( + "{} field's default value is Float type, mismatches " + "field type", + f.name()); + return false; + } + break; + case DCase::kDoubleData: + if (f.data_type() != DType::Double) { + LOG_ERROR( + "{} field's default value is Double type, mismatches " + "field type", + f.name()); + return false; + } + break; + case DCase::kStringData: { + if (f.data_type() != DType::VarChar) { + LOG_ERROR( + "{} field's default value is VarChar type, " + "mismatches field type", + f.name()); + return false; + } + auto string_len = f.default_value().string_data().size(); + uint64_t max_length = 0; + auto s = GetVarcharFieldMaxLength(f, &max_length); + if (s.IsErr()) { + LOG_ERROR(s.Detail()); + return false; + } + + if (string_len > max_length) { + return false; + } + + } break; + // case DCase::kBytesData: // not used + // break; + default: + return false; + break; + } + } + return true; +} + +void +CreateCollectionTask::AssignFieldId( + ::milvus::proto::schema::CollectionSchema* schema) { + for (int i = 0; i < schema->fields_size(); i++) { + 
schema->mutable_fields(i)->set_fieldid(kStartOfUserFieldId + i); + } +} + +void +CreateCollectionTask::AppendDynamicField( + ::milvus::proto::schema::CollectionSchema* schema) { + if (schema->enable_dynamic_field()) { + auto dynamice_field = schema->add_fields(); + dynamice_field->set_name(kMetaFieldName); + dynamice_field->set_description("dynamic schema"); + dynamice_field->set_data_type(DType::JSON); + dynamice_field->set_is_dynamic(true); + } +} + +void +CreateCollectionTask::AppendSysFields( + ::milvus::proto::schema::CollectionSchema* schema) { + auto row_id_field = schema->add_fields(); + row_id_field->set_fieldid(kRowIdField); + row_id_field->set_name(kRowIdFieldName); + row_id_field->set_is_primary_key(false); + row_id_field->set_description("row id"); + row_id_field->set_data_type(DType::Int64); + + auto ts_field = schema->add_fields(); + ts_field->set_fieldid(kTimeStampField); + ts_field->set_name(kTimeStampFieldName); + ts_field->set_is_primary_key(false); + ts_field->set_description("time stamp"); + ts_field->set_data_type(DType::Int64); +} + +Status +CreateCollectionTask::ValidateSchema( + const ::milvus::proto::schema::CollectionSchema& schema) { + std::set field_names; + std::string pk_name; + for (const auto& field_schema : schema.fields()) { + if (field_names.find(field_schema.name()) != field_names.end()) { + return Status::ParameterInvalid("Duplicated field name: {}", + field_schema.name()); + } + if (field_schema.is_primary_key()) { + if (!pk_name.empty()) { + return Status::ParameterInvalid( + "there are more than one primary key, field_name = {}, {}", + pk_name, + field_schema.name()); + } + } else { + pk_name = field_schema.name(); + } + if (field_schema.is_dynamic()) { + return Status::ParameterInvalid( + "cannot explicitly set a field as a dynamic field"); + } + CHECK_STATUS(CheckFieldName(field_schema.name()), ""); + if (field_schema.data_type() == DType::VarChar) { + uint64_t max_length = 0; + CHECK_STATUS(GetVarcharFieldMaxLength(field_schema, &max_length), + ""); + } + } + return Status::Ok(); +} + +Status +CreateCollectionTask::Process( + ::milvus::proto::schema::CollectionSchema* schema) { + if (!schema->ParseFromString(create_collection_request_->schema())) { + LOG_ERROR("Failed parse schema"); + return Status::ParameterInvalid("Failed parse schema"); + } + + if (create_collection_request_->collection_name() != schema->name()) { + auto err = string_util::SFormat( + "collection name [{}] not matches schema name [{}]", + create_collection_request_->collection_name(), + schema->name()); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + + CHECK_STATUS(ValidateSchema(*schema), ""); + + if (HasSystemFields(*schema)) { + auto err_msg = + string_util::SFormat("Schema contains system field {}, {}, {}", + kRowIdFieldName, + kTimeStampFieldName, + kMetaFieldName); + LOG_ERROR(err_msg); + return Status::ParameterInvalid(err_msg); + } + + if (!CheckDefaultValue(*schema)) { + return Status::ParameterInvalid(); + } + + AppendDynamicField(schema); + AssignFieldId(schema); + AppendSysFields(schema); + return Status::Ok(); +} + +Status +CreateCollectionTask::CheckFieldName(const std::string& field_name) { + std::string name = string_util::Trim(field_name); + if (name.empty()) { + return Status::ParameterInvalid("field {} should not be empty", name); + } + std::string invalid_msg = + string_util::SFormat("Invalid field name {}. 
", name); + if (name.size() > 255) { + return Status::ParameterInvalid( + "{}, the length of a field name must " + "be less than 255 characters", + invalid_msg); + } + + char first = name[0]; + if (first != '_' && !string_util::IsAlpha(first)) { + return Status::ParameterInvalid( + "{} the first character of a field {} must be an underscore " + "or letter", + invalid_msg, + name); + } + std::regex pattern("^[a-zA-Z_][a-zA-Z0-9_]*$"); + if (!std::regex_match(name, pattern)) { + auto err = string_util::SFormat( + "{}, field name can only contain " + "numbers, letters and underscores", + invalid_msg); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + return Status::Ok(); +} + +} // namespace milvus::local diff --git a/src/create_collection_task.h b/src/create_collection_task.h new file mode 100644 index 0000000..6b5bd44 --- /dev/null +++ b/src/create_collection_task.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include "pb/milvus.pb.h" +#include "pb/schema.pb.h" +#include "pb/segcore.pb.h" +#include "common.h" +#include "status.h" + +namespace milvus::local { + +class CreateCollectionTask final : NonCopyableNonMovable { + public: + explicit CreateCollectionTask( + const ::milvus::proto::milvus::CreateCollectionRequest* + create_collection_request) + : create_collection_request_(create_collection_request) { + } + virtual ~CreateCollectionTask() = default; + + public: + Status + Process(::milvus::proto::schema::CollectionSchema* schema); + + private: + bool + CheckDefaultValue(const ::milvus::proto::schema::CollectionSchema& schema); + + bool + HasSystemFields(const ::milvus::proto::schema::CollectionSchema& schema); + + void + AssignFieldId(::milvus::proto::schema::CollectionSchema* schema); + + void + AppendDynamicField(::milvus::proto::schema::CollectionSchema* schema); + + void + AppendSysFields(::milvus::proto::schema::CollectionSchema* schema); + + Status + GetVarcharFieldMaxLength(const ::milvus::proto::schema::FieldSchema& field, + uint64_t* max_len); + + Status + ValidateSchema(const ::milvus::proto::schema::CollectionSchema& schema); + + Status + CheckFieldName(const std::string& field_name); + + private: + const ::milvus::proto::milvus::CreateCollectionRequest* + create_collection_request_; +}; + +} // namespace milvus::local diff --git a/src/create_index_task.cpp b/src/create_index_task.cpp new file mode 100644 index 0000000..e70d8fb --- /dev/null +++ b/src/create_index_task.cpp @@ -0,0 +1,507 @@ +#include "create_index_task.h" +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "nlohmann/json.hpp" +#include "pb/common.pb.h" +#include "pb/schema.pb.h" +#include "pb/segcore.pb.h" +#include "schema_util.h" +#include "log/Log.h" +#include "status.h" +#include "string_util.hpp" + +namespace milvus::local { + +/* dtype index_type metric_type + + * FloatVector: FLAT, HNWS L2, IP, COSINE, AUTOINDEX + * BinaryVector: BIN_FLAT HAMMING, JACCARD, SUBSTRUCTURE, SUPERSTRUCTURE + * Float16Vector: FLAT L2, IP, COSINE + * BFloat16Vector: FLAT L2, IP, COSINE + * SparseFloatVector: SPARSE_INVERTED_INDEX, SPARSE_WAND IP + */ + +// metrics type +const char* kL2 = "L2"; +const char* kIP = "IP"; +const char* kCosine = "COSINE"; +const char* kHamming = "HAMMING"; +const char* kJaccard = "JACCARD"; +const char* kSubStructure = "SUBSTRUCTURE"; +const char* kSuperStructure = "SUPERSTRUCTURE"; + +// index_type +const char* kAutoIndex = "AUTOINDEX"; +const char* kFlat = "FLAT"; +const char* kBin_Flat = "BIN_FLAT"; +const char* kHNSW = "HNSW"; +const char* 
kSparseInvertedIndex = "SPARSE_INVERTED_INDEX"; +const char* kSparseWand = "SPARSE_WAND"; + +// default metric +const char* kFloatVectorDefaultMetricType = kIP; +const char* kSparseFloatVectorDefaultMetricType = kIP; +const char* kBinaryVectorDefaultMetricType = kJaccard; + +class AutoIndexConfig final : NonCopyableNonMovable { + public: + AutoIndexConfig() + : index_param({{"M", "18"}, + {"efConstruction", "240"}, + {"index_type", kHNSW}, + {"metric_type", kIP}}) { + } + ~AutoIndexConfig() = default; + + public: + const KVMap index_param; +}; + +using DType = ::milvus::proto::schema::DataType; + +static const AutoIndexConfig kAutoIndexConfig; + +class IndexChecker : NonCopyableNonMovable { + public: + IndexChecker(const std::string& index_type, + const std::string& metric, + int64_t dim) + : index_type_(index_type), + metric_(metric), + dim_(dim), + need_check_dim_(true) { + } + virtual ~IndexChecker() = default; + + Status + Check(); + + protected: + std::string index_type_; + std::string metric_; + int64_t dim_; + + bool need_check_dim_; + int64_t min_dim_; + int64_t max_dim_; + std::vector supported_index_; + std::vector supported_metric_; +}; + +Status +IndexChecker::Check() { + if (need_check_dim_ && (dim_ < min_dim_ || dim_ > max_dim_)) { + auto err = string_util::SFormat( + "invalid dimension: {}. should be in range {} ~ {}", + dim_, + min_dim_, + max_dim_); + return Status::Undefined(err); + } + if (std::find(supported_index_.begin(), + supported_index_.end(), + index_type_) == supported_index_.end()) { + auto err = string_util::SFormat( + "invalid index type: {}, local mode only support {}", + index_type_, + string_util::Join(" ", supported_index_)); + return Status::Undefined(err); + } + if (std::find(supported_metric_.begin(), + supported_metric_.end(), + metric_) == supported_metric_.end()) { + auto err = string_util::SFormat( + "metric type {} not found or not supported, supported: {}", + metric_, + string_util::Join(" ", supported_metric_)); + return Status::Undefined(err); + } + return Status::Ok(); +} + +class FloatVectorIndexChecker : public virtual IndexChecker { + public: + FloatVectorIndexChecker(const std::string& index_type, + const std::string& metric, + int64_t dim) + : IndexChecker(index_type, metric, dim) { + min_dim_ = 2; + max_dim_ = 32768; + supported_index_ = {kFlat, kHNSW, kAutoIndex}; + supported_metric_ = {kL2, kIP, kCosine}; + } + + virtual ~FloatVectorIndexChecker() = default; +}; + +class BinaryVectorChecker : public virtual IndexChecker { + public: + BinaryVectorChecker(const std::string& index_type, + const std::string& metric, + int64_t dim) + : IndexChecker(index_type, metric, dim) { + min_dim_ = 2; + max_dim_ = 32768; + supported_index_ = {kBin_Flat}; + supported_metric_ = { + kHamming, kJaccard, kSubStructure, kSuperStructure}; + } + + virtual ~BinaryVectorChecker() = default; +}; + +class Float16VectorChecker : public virtual IndexChecker { + public: + Float16VectorChecker(const std::string& index_type, + const std::string& metric, + int64_t dim) + : IndexChecker(index_type, metric, dim) { + min_dim_ = 2; + max_dim_ = 32768; + supported_index_ = {kFlat}; + supported_metric_ = {kL2, kIP, kCosine}; + } + + virtual ~Float16VectorChecker() = default; +}; + +class BFloat16VectorChecker : public virtual IndexChecker { + public: + BFloat16VectorChecker(const std::string& index_type, + const std::string& metric, + int64_t dim) + : IndexChecker(index_type, metric, dim) { + min_dim_ = 2; + max_dim_ = 32768; + supported_index_ = {kFlat}; + supported_metric_ 
= {kL2, kIP, kCosine}; + } + + virtual ~BFloat16VectorChecker() = default; +}; + +class SparseFloatVectorChecker : public virtual IndexChecker { + public: + SparseFloatVectorChecker(const std::string& index_type, + const std::string& metric, + int64_t dim) + : IndexChecker(index_type, metric, dim) { + min_dim_ = -1; + max_dim_ = -1; + need_check_dim_ = false; + supported_index_ = {kSparseInvertedIndex, kSparseWand}; + supported_metric_ = {kIP}; + } + + virtual ~SparseFloatVectorChecker() = default; +}; + +Status +Check(DType field_type, + const std::string& index_type, + const std::string& metric, + int64_t dim) { + if (field_type == DType::FloatVector) { + return FloatVectorIndexChecker(index_type, metric, dim).Check(); + } else if (field_type == DType::Float16Vector) { + return Float16VectorChecker(index_type, metric, dim).Check(); + } else if (field_type == DType::BFloat16Vector) { + return BFloat16VectorChecker(index_type, metric, dim).Check(); + } else if (field_type == DType::BinaryVector) { + return BinaryVectorChecker(index_type, metric, dim).Check(); + } else if (field_type == DType::SparseFloatVector) { + return SparseFloatVectorChecker(index_type, metric, dim).Check(); + } else { + return Status::ParameterInvalid("Unknow data type"); + } +} + +void +CreateIndexTask::WrapUserIndexParams(const std::string& metrics_type) { + ::milvus::proto::common::KeyValuePair p1; + p1.set_key(kIndexTypeKey); + p1.set_value(kAutoIndex); + new_extra_params_.push_back(p1); + + ::milvus::proto::common::KeyValuePair p2; + p2.set_key(kMetricTypeKey); + p2.set_value(metrics_type); + new_extra_params_.push_back(p2); +} + +bool +CreateIndexTask::AddAutoIndexParams(size_t number_params, KVMap* index_params) { + is_auto_index_ = true; + if (index_params->size() == number_params) { + const auto metrics_type = + kAutoIndexConfig.index_param.at(kMetricTypeKey); + WrapUserIndexParams(metrics_type); + index_params->insert(kAutoIndexConfig.index_param.begin(), + kAutoIndexConfig.index_param.end()); + return true; + } + + if (index_params->size() > number_params + 1) { + LOG_ERROR("Only metric type can be passed when use AutoIndex"); + return false; + } + + if (index_params->size() == (number_params + 1)) { + auto it = kAutoIndexConfig.index_param.find(kMetricTypeKey); + if (it == kAutoIndexConfig.index_param.end()) { + LOG_ERROR("Only metric type can be passed when use AutoIndex"); + return false; + } + WrapUserIndexParams(it->second); + index_params->insert(kAutoIndexConfig.index_param.begin(), + kAutoIndexConfig.index_param.end()); + (*index_params)[kMetricTypeKey] = it->second; + return true; + } + return true; +} + +Status +CreateIndexTask::CheckTrain(const ::milvus::proto::schema::FieldSchema& field, + KVMap& index_params) { + auto index_type = index_params.at(kIndexTypeKey); + if (!IsVectorIndex(field.data_type())) { + return Status::Ok(); + } + if (!schema_util::IsSparseVectorType(field.data_type())) { + if (!FillDimension(field, &index_params)) { + return Status::ParameterInvalid(); + } + } + int64_t dim = -1; + if (!schema_util::IsSparseVectorType(field.data_type())) { + dim = std::stoll(index_params.at(kDimKey)); + } + + std::string metric = index_params.at(kMetricTypeKey); + return Check(field.data_type(), index_type, metric, dim); +} + +Status +CreateIndexTask::ParseIndexParams() { + const milvus::proto::schema::FieldSchema* field_ptr = nullptr; + for (const auto& field : schema_->fields()) { + if (field.name() == create_index_request_->field_name()) { + field_ptr = &field; + } + } + if (field_ptr == 
nullptr) { + auto err = string_util::SFormat("Can not found field {}", + create_index_request_->field_name()); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + + if (!create_index_request_->index_name().empty()) { + index_name_ = create_index_request_->index_name(); + } else { + index_name_ = field_ptr->name(); + } + + field_id_ = field_ptr->fieldid(); + collectionid_ = GetCollectionId(schema_->name()); + + KVMap index_params; + for (const auto& param : create_index_request_->extra_params()) { + if (param.key() == kIndexParamsKey) { + try { + nlohmann::json data = nlohmann::json::parse(param.value()); + for (auto& [key, value] : data.items()) { + if (!value.is_string()) { + index_params[key] = value.dump(); + } + } + } catch (nlohmann::json::parse_error& e) { + auto err = + string_util::SFormat("Index params err: {}", e.what()); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + + } else { + index_params[param.key()] = param.value(); + } + } + + if (IsVectorIndex(field_ptr->data_type())) { + auto it = index_params.find(kIndexTypeKey); + if (it == index_params.end() && !AddAutoIndexParams(0, &index_params)) { + return Status::SegcoreErr(); + } else if (it->second == kAutoIndex && + !AddAutoIndexParams(1, &index_params)) { + return Status::SegcoreErr(); + } + + auto metric_it = index_params.find(kMetricTypeKey); + if (metric_it == index_params.end()) { + if (field_ptr->data_type() == DType::FloatVector || + field_ptr->data_type() == DType::BFloat16Vector || + field_ptr->data_type() == DType::Float16Vector) { + index_params[kMetricTypeKey] = kFloatVectorDefaultMetricType; + } else if (field_ptr->data_type() == DType::BinaryVector) { + index_params[kMetricTypeKey] = kBinaryVectorDefaultMetricType; + } else if (field_ptr->data_type() == DType::SparseFloatVector) { + index_params[kMetricTypeKey] = + kSparseFloatVectorDefaultMetricType; + } else { + LOG_ERROR("Unkwon index data type: {}", field_ptr->data_type()); + return Status::ParameterInvalid(); + } + } + + } else { + // scalar index + auto it = index_params.find(kIndexTypeKey); + if (field_ptr->data_type() == DType::VarChar) { + if (it == index_params.end()) { + index_params[kIndexTypeKey] = kDefaultStringIndexType; + } else if (!ValidateStringIndexType(it->second)) { + auto err = + string_util::SFormat("Unkown index type {}", it->second); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + } else if (field_ptr->data_type() == DType::Float || + field_ptr->data_type() == DType::Double || + field_ptr->data_type() == DType::Int16 || + field_ptr->data_type() == DType::Int8 || + field_ptr->data_type() == DType::Int32 || + field_ptr->data_type() == DType::Int64) { + if (it == index_params.end()) { + index_params[kIndexTypeKey] = kDefaultArithmeticIndexType; + } else if (!ValidateArithmeticIndexType(it->second)) { + auto err = + string_util::SFormat("Unkown index type {}", it->second); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + } else if (field_ptr->data_type() == DType::Bool) { + if (it == index_params.end()) { + LOG_ERROR("no index type specified"); + return Status::ParameterInvalid("no index type specified"); + } + if (it->second != kInvertedIndexType) { + auto err = string_util::SFormat( + "index type {} not supported for boolean, supported: {}", + it->second, + kInvertedIndexType); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + } else { + LOG_ERROR( + "Only int, varchar, float, double and bool fields support " + "scalar index."); + return Status::ParameterInvalid(); 
+ } + } + auto it = index_params.find(kIndexTypeKey); + if (it == index_params.end()) { + LOG_ERROR("IndexType not specified"); + return Status::ParameterInvalid(); + } + + CHECK_STATUS(CheckTrain(*field_ptr, index_params), ""); + + index_params.erase(kDimKey); + index_params.erase(kMaxLengthKey); + + for (const auto& param : index_params) { + ::milvus::proto::common::KeyValuePair p; + p.set_key(param.first); + p.set_value(param.second); + new_index_params_.push_back(p); + } + + auto type_params = field_ptr->type_params(); + KVMap type_params_map; + for (const auto& param : type_params) { + ::milvus::proto::common::KeyValuePair p; + p.set_key(param.key()); + p.set_value(param.value()); + new_type_params_.push_back(p); + } + return Status::Ok(); +} + +bool +CreateIndexTask::IsVectorIndex(::milvus::proto::schema::DataType dtype) { + return schema_util::IsVectorField(dtype); +} + +bool +CreateIndexTask::FillDimension( + const ::milvus::proto::schema::FieldSchema& field, KVMap* index_params) { + if (!IsVectorIndex(field.data_type())) { + return true; + } + + std::string dim; + if (!schema_util::FindDimFromFieldParams(field, &dim)) { + LOG_ERROR("Dimension not found in schema"); + return false; + } + + auto it = index_params->find(kDimKey); + if (it != index_params->end() && it->second != dim) { + LOG_ERROR("dimension mismatch, dimension in schema: {}, dimension: {}", + dim, + it->second); + return false; + } else { + (*index_params)[kDimKey] = dim; + } + return true; +} + +Status +CreateIndexTask::Process(milvus::proto::segcore::FieldIndexMeta* field_meta) { + CHECK_STATUS(ParseIndexParams(), ""); + + field_meta->set_index_name(index_name_); + field_meta->set_fieldid(field_id_); + field_meta->set_collectionid(collectionid_); + field_meta->set_is_auto_index(is_auto_index_); + + std::set kset; + + for (const auto& param : new_index_params_) { + if (kset.find(param.key()) == kset.end()) { + auto pair = field_meta->add_index_params(); + pair->set_key(param.key()); + pair->set_value(param.value()); + kset.insert(param.key()); + } + } + + for (const auto& param : new_type_params_) { + if (kset.find(param.key()) == kset.end()) { + auto pair = field_meta->add_index_params(); + pair->set_key(param.key()); + pair->set_value(param.value()); + kset.insert(param.key()); + } + } + + for (const auto& param : new_extra_params_) { + if (kset.find(param.key()) == kset.end()) { + auto pair = field_meta->add_index_params(); + pair->set_key(param.key()); + pair->set_value(param.value()); + kset.insert(param.key()); + } + } + return Status::Ok(); +} + +} // namespace milvus::local diff --git a/src/create_index_task.h b/src/create_index_task.h new file mode 100644 index 0000000..020ed29 --- /dev/null +++ b/src/create_index_task.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include +#include "common.h" +#include "pb/common.pb.h" +#include "pb/milvus.pb.h" +#include "pb/schema.pb.h" +#include "pb/segcore.pb.h" +#include "status.h" + +namespace milvus::local { + +class CreateIndexTask final : NonCopyableNonMovable { + public: + CreateIndexTask( + const ::milvus::proto::milvus::CreateIndexRequest* create_index_request, + const ::milvus::proto::schema::CollectionSchema* schema) + : create_index_request_(create_index_request), schema_(schema) { + field_id_ = 0; + collectionid_ = 0; + is_auto_index_ = false; + } + virtual ~CreateIndexTask() = default; + + public: + Status + Process(milvus::proto::segcore::FieldIndexMeta* field_meta); + + private: + Status + ParseIndexParams(); + + void + 
WrapUserIndexParams(const std::string& metrics_type); + + bool + AddAutoIndexParams(size_t numberParams, KVMap* index_params); + + Status + CheckTrain(const ::milvus::proto::schema::FieldSchema& field, + KVMap& index_params); + + bool + ValidateStringIndexType(const std::string& index_type) { + return index_type == kDefaultStringIndexType || + index_type == "marisa-trie" || index_type == kInvertedIndexType; + } + + bool + ValidateArithmeticIndexType(const std::string& index_type) { + return index_type == kDefaultStringIndexType || + index_type == "Asceneding" || index_type == kInvertedIndexType; + } + + bool + FillDimension(const ::milvus::proto::schema::FieldSchema& field, + KVMap* index_params); + + bool + IsVectorIndex(::milvus::proto::schema::DataType dtype); + + private: + // string of ::milvus::proto::milvus::CreateIndexRequest; + const ::milvus::proto::milvus::CreateIndexRequest* create_index_request_; + const ::milvus::proto::schema::CollectionSchema* schema_; + + std::vector<::milvus::proto::common::KeyValuePair> new_index_params_; + std::vector<::milvus::proto::common::KeyValuePair> new_type_params_; + std::vector<::milvus::proto::common::KeyValuePair> new_extra_params_; + + std::string index_name_; + int64_t field_id_; + int64_t collectionid_; + bool is_auto_index_; +}; + +} // namespace milvus::local diff --git a/src/delete_task.cpp b/src/delete_task.cpp new file mode 100644 index 0000000..c1749a2 --- /dev/null +++ b/src/delete_task.cpp @@ -0,0 +1,32 @@ +#include "delete_task.h" +#include "antlr4-runtime.h" +#include "log/Log.h" +#include "parser/parser.h" +#include "parser/utils.h" +#include "schema_util.h" +#include "status.h" +#include "string_util.hpp" + +namespace milvus::local { + +Status +DeleteTask::Process(::milvus::proto::plan::PlanNode* plan) { + if (string_util::Trim(delete_request_->expr()) == "") { + return Status::ParameterInvalid("expr cannot be empty"); + } + CHECK_STATUS( + schema_util::ParseExpr(delete_request_->expr(), + *schema_, + plan->mutable_query()->mutable_predicates()), + ""); + + auto pk_id = schema_util::GetPkId(*schema_); + if (!pk_id.has_value()) { + LOG_ERROR("Can not found {}'s primary key", schema_->name()); + return Status::CollectionIllegalSchema(); + } + plan->add_output_field_ids(*pk_id); + return Status::Ok(); +} + +} // namespace milvus::local diff --git a/src/delete_task.h b/src/delete_task.h new file mode 100644 index 0000000..ca614c8 --- /dev/null +++ b/src/delete_task.h @@ -0,0 +1,27 @@ +#pragma once + +#include "pb/milvus.pb.h" +#include "pb/schema.pb.h" +#include "pb/plan.pb.h" +#include "common.h" +#include "status.h" + +namespace milvus::local { + +class DeleteTask : NonCopyableNonMovable { + public: + DeleteTask(const ::milvus::proto::milvus::DeleteRequest* delete_request, + const ::milvus::proto::schema::CollectionSchema* schema) + : delete_request_(delete_request), schema_(schema) { + } + virtual ~DeleteTask() = default; + + Status + Process(::milvus::proto::plan::PlanNode* plan); + + private: + const ::milvus::proto::milvus::DeleteRequest* delete_request_; + const ::milvus::proto::schema::CollectionSchema* schema_; +}; + +} // namespace milvus::local diff --git a/src/index.cpp b/src/index.cpp new file mode 100644 index 0000000..5158277 --- /dev/null +++ b/src/index.cpp @@ -0,0 +1,90 @@ +#include "index.h" +#include "common.h" +#include "log/Log.h" +#include "status.h" +#include + +namespace milvus::local { + +Status +Index::CreateCollection(const std::string& collection_name, + const std::string& schema_proto) { + if 
(HasLoaded(collection_name)) { + LOG_INFO("Collection {} alread load", collection_name); + return Status::Ok(); + } + + auto c = std::make_unique(); + CHECK_STATUS(c->SetCollectionInfo(collection_name, schema_proto), + "Create collection failed: "); + collections_[collection_name] = std::move(c); + return Status::Ok(); +} + +bool +Index::DropCollection(const std::string& collection_name) { + if (collections_.find(collection_name) != collections_.end()) { + collections_.erase(collection_name); + } + return true; +} + +Status +Index::CreateIndex(const std::string& collection_name, + const std::string& index_proto) { + if (collections_.find(collection_name) == collections_.end()) { + LOG_ERROR("Collecton {} not existed", collection_name); + return Status::CollectionNotFound(); + } + CHECK_STATUS(collections_[collection_name]->SetIndexMeta(index_proto), + "Create index failed:"); + return Status::Ok(); +} + +Status +Index::Insert(const std::string& collection_name, + int64_t size, + const std::string& insert_record_prot) { + if (collections_.find(collection_name) == collections_.end()) { + LOG_ERROR("Collecton {} not existed", collection_name); + return Status::CollectionNotFound(); + } + return collections_[collection_name]->Insert(size, insert_record_prot); +} + +Status +Index::Retrieve(const std::string& collection_name, + const std::string& plan, + RetrieveResult* result) { + if (collections_.find(collection_name) == collections_.end()) { + LOG_ERROR("Collecton {} not existed", collection_name); + return Status::CollectionNotFound(); + } + return collections_[collection_name]->Retrieve(plan, result); +} + +Status +Index::Search(const std::string& collection_name, + const std::string& plan, + const std::string& placeholder_group, + SearchResult* result) { + if (collections_.find(collection_name) == collections_.end()) { + LOG_ERROR("Collecton {} not existed", collection_name); + return Status::CollectionNotFound(); + } + return collections_[collection_name]->Search( + plan, placeholder_group, result); +} + +Status +Index::DeleteByIds(const std::string& collection_name, + const std::string& ids, + int64_t size) { + if (collections_.find(collection_name) == collections_.end()) { + LOG_ERROR("Collecton {} not existed", collection_name); + return Status::CollectionNotFound(); + } + return collections_[collection_name]->DeleteByIds(ids, size); +} + +} // namespace milvus::local diff --git a/src/index.h b/src/index.h new file mode 100644 index 0000000..6e7b26a --- /dev/null +++ b/src/index.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include +#include +#include "segcore_wrapper.h" +#include "common.h" +#include "status.h" + +namespace milvus::local { + +class Index final : NonCopyableNonMovable { + public: + Index() = default; + virtual ~Index() = default; + + public: + // meta interface + Status + CreateCollection(const std::string& collection_name, + const std::string& schema_proto); + bool + DropCollection(const std::string& collection_name); + + bool + HasLoaded(const std::string& collection_name) { + return collections_.find(collection_name) != collections_.end(); + } + + Status + CreateIndex(const std::string& collection_name, + const std::string& index_proto); + + Status + Insert(const std::string& collection_name, + int64_t size, + const std::string& insert_record_proto); + + Status + Retrieve(const std::string& collection_name, + const std::string& expr, + RetrieveResult* result); + + Status + Search(const std::string& collection_name, + const std::string& plan, + const 
std::string& placeholder_group, + SearchResult* result); + + Status + DeleteByIds(const std::string& collection_name, + const std::string& ids, + int64_t size); + + private: + std::map> collections_; +}; + +} // namespace milvus::local diff --git a/src/insert_task.cpp b/src/insert_task.cpp new file mode 100644 index 0000000..d8ccb71 --- /dev/null +++ b/src/insert_task.cpp @@ -0,0 +1,238 @@ +#include "insert_task.h" +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "log/Log.h" +#include "pb/schema.pb.h" +#include "pb/segcore.pb.h" +#include "schema_util.h" +#include "status.h" +#include "string_util.hpp" + +namespace milvus::local { + +using DType = ::milvus::proto::schema::DataType; +int64_t InsertTask::cur_id_ = 0; + +InsertTask::InsertTask(::milvus::proto::milvus::InsertRequest* r, + const ::milvus::proto::schema::CollectionSchema* schema) + : insert_request_(r), schema_(schema), num_rows_(0) { +} + +bool +InsertTask::AddSystemField() { + num_rows_ = insert_request_->num_rows(); + if (num_rows_ <= 0) { + LOG_ERROR("Error rows nums {}", num_rows_); + return false; + } + + auto timestamps = GetTimestamps(num_rows_); + auto row_ids = GetRowIds(timestamps); + + auto row_field = insert_request_->add_fields_data(); + row_field->set_type(DType::Int64); + row_field->set_field_name(kRowIdFieldName); + row_field->set_field_id(kRowIdField); + for (auto id : row_ids) { + row_field->mutable_scalars()->mutable_long_data()->add_data(id); + } + + auto time_field = insert_request_->add_fields_data(); + time_field->set_type(DType::Int64); + time_field->set_field_name(kTimeStampFieldName); + time_field->set_field_id(kTimeStampField); + for (auto t : timestamps) { + time_field->mutable_scalars()->mutable_long_data()->add_data(t); + } + return true; +} + +bool +InsertTask::GenFieldMap() { + for (const auto& field : insert_request_->fields_data()) { + field_data_map_.emplace(field.field_name(), &field); + } + + for (const auto& field : schema_->fields()) { + if (field_data_map_.find(field.name()) == field_data_map_.end()) { + if (field.is_primary_key() && field.autoid()) { + auto row_id_field = field_data_map_.at(kRowIdFieldName); + auto pk_field = insert_request_->add_fields_data(); + pk_field->set_field_name(field.name()); + pk_field->set_field_id(field.fieldid()); + pk_field->set_type(field.data_type()); + if (field.data_type() == DType::VarChar) { + for (uint32_t i = 0; i < num_rows_; i++) { + pk_field->mutable_scalars() + ->mutable_string_data() + ->add_data(std::to_string(std::any_cast( + schema_util::GetField(*row_id_field, i)))); + } + } else { + for (uint32_t i = 0; i < num_rows_; i++) { + pk_field->mutable_scalars() + ->mutable_long_data() + ->add_data(std::any_cast( + schema_util::GetField(*row_id_field, i))); + } + } + field_data_map_.emplace(field.name(), pk_field); + } else { + LOG_ERROR("Lost field {}", field.name()); + return false; + } + } + } + + return true; +} + +bool +InsertTask::CheckDynamicFieldData() { + if (!schema_->enable_dynamic_field()) { + return true; + } + for (int i = 0; i < insert_request_->fields_data_size(); i++) { + auto field = insert_request_->mutable_fields_data(i); + if (field->is_dynamic()) { + field->set_field_name(kMetaFieldName); + //TODO check json + return true; + } + } + // no dynamic field found, and default value + auto dy_field = insert_request_->add_fields_data(); + dy_field->set_field_name(kMetaFieldName); + dy_field->set_is_dynamic(true); + for (const auto& schema_field : schema_->fields()) { + if 
(schema_field.is_dynamic()) { + dy_field->set_field_id(schema_field.fieldid()); + break; + } + } + dy_field->set_type(::milvus::proto::schema::DataType::JSON); + dy_field->mutable_scalars()->mutable_json_data()->add_data()->assign("{}"); + return true; +} + +Status +InsertTask::Process(Rows* rows) { + if (!(AddSystemField() && CheckDynamicFieldData() && GenFieldMap())) { + return Status::ParameterInvalid(); + } + + CHECK_STATUS(CheckVectorDim(), ""); + + auto pk_field_name = schema_util::GetPkName(*schema_); + if (!pk_field_name.has_value()) { + auto err = + string_util::SFormat("Collection {} has no pk", schema_->name()); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + + pk_type_ = field_data_map_.at(pk_field_name.value())->type(); + + for (uint32_t i = 0; i < num_rows_; i++) { + ::milvus::proto::segcore::InsertRecord record; + std::string pk; + record.set_num_rows(1); + for (const auto& field : schema_->fields()) { + auto field_data = record.add_fields_data(); + field_data->set_field_id(field.fieldid()); + field_data->set_field_name(field.name()); + field_data->set_type(field.data_type()); + if (!schema_util::SliceFieldData( + *field_data_map_.at(field.name()), + std::vector>{{i, 1}}, + field_data)) { + LOG_ERROR("Parse field data failed"); + return Status::FieldNotFound(); + } + if (field.name() == pk_field_name.value()) { + if (field.data_type() == DType::Int64) { + pk = std::to_string( + field_data->scalars().long_data().data(0)); + } else { + pk = field_data->scalars().string_data().data(0); + } + } + } + rows->push_back(std::make_tuple(pk, record.SerializeAsString())); + } + return Status::Ok(); +} + +Status +InsertTask::CheckVectorDim() { + int64_t num_rows = insert_request_->num_rows(); + if (num_rows <= 0) { + return Status::ParameterInvalid("Err num_rows: {}", num_rows); + } + for (const auto& field_schema : schema_->fields()) { + if (field_schema.data_type() == DType::FloatVector) { + // int64_t dim = field_data.vectors().dim(); + auto field_data = field_data_map_.at(field_schema.name()); + int64_t dim = schema_util::GetDim(field_schema); + if (dim <= 0) { + return Status::ParameterInvalid("Can not found dim info"); + } + int vect_size = field_data->vectors().float_vector().data_size(); + if (vect_size % dim != 0) { + return Status::Undefined( + "the length({}) of float data should divide the dim({})", + vect_size, + dim); + } + + int32_t vec_rows = vect_size / field_data->vectors().dim(); + + if (vec_rows != num_rows) { + return Status::ParameterInvalid( + "the num_rows ({}) of field ({}) is not equal to passed " + "num_rows ({}): [expected={}][actual={}]", + vec_rows, + field_data->field_name(), + num_rows, + num_rows, + vec_rows); + } + } + } + return Status::Ok(); +} + +std::vector +InsertTask::GetTimestamps(int64_t size) { + auto ts = GetTimestamp(); + return std::vector(size, ts); +} + +uint64_t +InsertTask::GetTimestamp() { + // https://github.com/milvus-io/milvus/blob/master/docs/design_docs/20211214-milvus_hybrid_ts.md + auto now = std::chrono::system_clock::now(); + auto duration = now.time_since_epoch(); + auto ms = + std::chrono::duration_cast(duration).count(); + return (ms << 18) + cur_id_; +} + +std::vector +InsertTask::GetRowIds(std::vector& timestamps) { + size_t size = timestamps.size(); + std::vector row_ids; + for (size_t i = 0; i < size; i++) { + row_ids.push_back(i + cur_id_ + timestamps[i]); + } + cur_id_ += size; + return row_ids; +} + +} // namespace milvus::local diff --git a/src/insert_task.h b/src/insert_task.h new file mode 100644 
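
A note on GetTimestamp() above: following the hybrid-timestamp design doc linked in the code, the wall-clock part in milliseconds occupies the upper bits and an 18-bit logical counter (cur_id_ in the task) occupies the lower bits. The sketch below, which is not part of the patch, builds such a timestamp and splits it back into its physical and logical components; it can help when reasoning about the row ids that GetRowIds() derives from these values. The logical value 42 is an arbitrary placeholder.

#include <chrono>
#include <cstdint>
#include <iostream>

int main() {
    // Physical component: milliseconds since the Unix epoch.
    auto now = std::chrono::system_clock::now();
    auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                  now.time_since_epoch())
                  .count();

    uint64_t logical = 42;  // hypothetical per-batch counter (cur_id_ in InsertTask)
    uint64_t hybrid_ts = (static_cast<uint64_t>(ms) << 18) + logical;

    // Decompose: the low 18 bits hold the logical counter, the rest the physical time.
    uint64_t physical_ms = hybrid_ts >> 18;
    uint64_t logical_part = hybrid_ts & ((1ULL << 18) - 1);

    std::cout << "hybrid=" << hybrid_ts
              << " physical_ms=" << physical_ms
              << " logical=" << logical_part << std::endl;
    return 0;
}
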
index 0000000..ff00fce --- /dev/null +++ b/src/insert_task.h @@ -0,0 +1,61 @@ +#pragma once +#include +#include +#include +#include "pb/milvus.pb.h" +#include "pb/schema.pb.h" +#include "pb/segcore.pb.h" +#include "status.h" +#include "type.h" +#include "common.h" + +namespace milvus::local { + +class InsertTask : NonCopyableNonMovable { + public: + InsertTask(::milvus::proto::milvus::InsertRequest* r, + const ::milvus::proto::schema::CollectionSchema* schema); + virtual ~InsertTask() = default; + + public: + Status + Process(Rows* insert_records); + + ::milvus::proto::schema::DataType + PkType() { + return pk_type_; + } + + private: + bool + GenFieldMap(); + + Status + CheckVectorDim(); + + std::vector + GetTimestamps(int64_t size); + + uint64_t + GetTimestamp(); + + std::vector + GetRowIds(std::vector& timestamps); + + bool + AddSystemField(); + + bool + CheckDynamicFieldData(); + + private: + static int64_t cur_id_; + ::milvus::proto::milvus::InsertRequest* insert_request_; + const ::milvus::proto::schema::CollectionSchema* schema_; + std::map + field_data_map_; + uint32_t num_rows_; + ::milvus::proto::schema::DataType pk_type_; +}; + +} // namespace milvus::local diff --git a/src/milvus/__init__.py b/src/milvus/__init__.py deleted file mode 100644 index dadd145..0000000 --- a/src/milvus/__init__.py +++ /dev/null @@ -1,564 +0,0 @@ -"""Milvus Server -""" - -from argparse import ArgumentParser, Action -import logging -import os -import shutil -import signal -import sys -import lzma -from os import makedirs -from os.path import join, abspath, dirname, expandvars, isfile -import re -import subprocess -import socket -from time import sleep -import datetime -from typing import Any, List -import urllib.error -import urllib.request -import json -import hashlib - -__version__ = '2.3.0-beta.1' - -LOGGERS = {} - - -def _initialize_data_files(base_dir) -> None: - bin_dir = join(base_dir, 'bin') - os.makedirs(bin_dir, exist_ok=True) - lzma_dir = join(dirname(abspath(__file__)), 'data', 'bin') - files = filter(lambda x: x.endswith('.lzma'), os.listdir(lzma_dir)) - files = map(lambda x: x[:-5], files) - for filename in files: - orig_file = join(bin_dir, filename) - lzma_md5_file = orig_file + '.lzma.md5' - lzma_file = join(lzma_dir, filename) + '.lzma' - with open(lzma_file, 'rb') as raw: - md5sum_text = hashlib.md5(raw.read()).hexdigest() - if isfile(lzma_md5_file): - with open(lzma_md5_file, 'r', encoding='utf-8') as lzma_md5_fp: - md5sum_text_pre = lzma_md5_fp.read().strip() - if md5sum_text == md5sum_text_pre: - continue - with lzma.LZMAFile(lzma_file, mode='r') as lzma_fp: - with open(orig_file, 'wb') as raw: - raw.write(lzma_fp.read()) - os.chmod(orig_file, 0o755) - with open(lzma_md5_file, 'w', encoding='utf-8') as lzma_md5_fp: - lzma_md5_fp.write(md5sum_text) - - -def _create_logger(usage: str = 'null') -> logging.Logger: - usage = usage.lower() - if usage in LOGGERS: - return LOGGERS[usage] - logger = logging.Logger(name=f'python_milvus_server_{usage}') - if usage != 'debug': - logger.setLevel(logging.FATAL) - else: - logger.setLevel(logging.DEBUG) - handler = logging.StreamHandler(sys.stderr) - formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d: %(message)s') - logger.addHandler(handler) - handler.setFormatter(formatter) - LOGGERS[usage] = logger - return logger - - -class MilvusServerConfig: - """ Milvus server config - """ - - RANDOM_PORT_START = 40000 - - def __init__(self, **kwargs): - """create new configuration for milvus server - - Kwargs: - 
template(str, optional): template file path - - data_dir(str, optional): base data directory for log and data - """ - self.base_data_dir = '' - self.configs: dict = kwargs - self.logger = _create_logger('debug' if kwargs.get('debug', False) else 'null') - - self.template_file: str = kwargs.get('template', None) - self.template_text: str = '' - self.config_key_maps = {} - self.configurable_items = {} - self.extra_configs = {} - self.load_template() - self.parse_template() - self.listen_ports = {} - - def update(self, **kwargs): - """ update configs - """ - self.configs.update(kwargs) - - def load_template(self): - """ load config template for milvus server - """ - if not self.template_file: - self.template_file = join(dirname(abspath(__file__)), 'data', 'config.yaml.template') - with open(self.template_file, 'r', encoding='utf-8') as template: - self.template_text = template.read() - - def parse_template(self): - """ parse template, lightweight template engine - for avoid introducing dependencies like: yaml/Jinja2 - - We using: - - {{ foo }} for variable - - {{ bar: value }} for variable with default values - - {{ bar(type) }} and {{ bar(type): value }} for type hint - """ - type_mappings = { - 'int': int, - 'bool': bool, - 'str': str, - 'string': str - } - for line in self.template_text.split('\n'): - matches = re.match(r'.*\{\{(.*)}}.*', line) - if matches: - text = matches.group(1) - original_key = '{{' + text + '}}' - text = text.strip() - value_type = str - if ':' in text: - key, val = text.split(':', maxsplit=2) - key, val = key.strip(), val.strip() - else: - key, val = text.strip(), None - if '(' in key: - key, type_str = key.split('(') - key, type_str = key.strip(), type_str.strip() - type_str = type_str.replace(')', '') - value_type = type_mappings[type_str] - self.config_key_maps[original_key] = key - self.configurable_items[key] = [value_type, self.get_value(val, value_type)] - self.verbose_configurable_items() - - def verbose_configurable_items(self): - for key, val in self.configurable_items.items(): - self.logger.debug( - 'Config item %s(%s) with default: %s', key, val[0], val[1]) - - def resolve(self): - self.cleanup_listen_ports() - self.resolve_all_listen_ports() - self.resolve_storage() - for key, value in self.configurable_items.items(): - if value[1] is None: - raise RuntimeError(f'{key} is still not resolved, please try specify one.') - # ready to start - self.cleanup_listen_ports() - self.write_config() - self.verbose_configurable_items() - - def resolve_port(self, port_start: int): - used_ports = self.listen_ports.values() - used_ports = [x[0] for x in used_ports if len(x) == 2] - used_ports = set(used_ports) - for i in range(10000): - port = port_start + i - if port not in used_ports: - sock = self.try_bind_port(port) - if sock: - return port, sock - return None, None - - def resolve_all_listen_ports(self): - port_keys = list(filter(lambda x: x.endswith('_port'), self.configurable_items.keys())) - for port_key in port_keys: - if port_key in self.configs: - port = int(self.configs.get(port_key)) - sock = self.try_bind_port(port) - if not sock: - raise RuntimeError(f'set {port_key}={port}, but bind failed') - else: - port_start = self.configurable_items[port_key][1] - port_start = port_start or self.RANDOM_PORT_START - port_start = int(port_start) - port, sock = self.resolve_port(port_start) - self.listen_ports[port_key] = (port, sock) - self.logger.debug('bind port %d for %s success', port, port_key) - for port_key, data in self.listen_ports.items(): - 
self.configurable_items[port_key][1] = data[0] - - def try_bind_port(self, port): - """ return a socket if bind success, else None - """ - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - sock.bind(('127.0.0.1', port)) - sock.listen() - return sock - except (OSError, socket.error, socket.gaierror, socket.timeout, ValueError, OverflowError) as ex: - self.logger.debug('try bind port:%d failed, %s', port, ex) - return None - - @classmethod - def get_default_data_dir(cls): - if sys.platform.lower() == 'win32': - default_dir = expandvars('%APPDATA%') - return join(default_dir, 'milvus.io', 'milvus-server', __version__) - default_dir = expandvars('${HOME}') - return join(default_dir, '.milvus.io', 'milvus-server', __version__) - - @classmethod - def get_value_text(cls, val) -> str: - if isinstance(val, bool): - return 'true' if val else 'false' - return str(val) - - @classmethod - def get_value(cls, text, val_type) -> Any: - if val_type == bool: - return text == 'true' - if val_type == int: - if not text: - return 0 - return int(text) - return text - - def resolve_storage(self): - self.base_data_dir = self.configs.get('data_dir', self.get_default_data_dir()) - self.base_data_dir = abspath(self.base_data_dir) - makedirs(self.base_data_dir, exist_ok=True) - config_dir = join(self.base_data_dir, 'configs') - logs_dir = join(self.base_data_dir, 'logs') - storage_dir = join(self.base_data_dir, 'data') - for subdir in (config_dir, logs_dir, storage_dir): - makedirs(subdir, exist_ok=True) - - # logs - if sys.platform.lower() == 'win32': - self.set('etcd_log_path', 'winfile:///' + join(logs_dir, 'etcd.log').replace('\\', '/')) - else: - self.set('etcd_log_path', join(logs_dir, 'etcd.log')) - self.set('system_log_path', logs_dir) - - # data - self.set('etcd_data_dir', join(storage_dir, 'etcd.data')) - self.set('local_storage_dir', join(storage_dir, 'storage')) - self.set('rocketmq_data_dir', join(storage_dir, 'rocketmq')) - - def get(self, attr) -> Any: - return self.configurable_items[attr][1] - - def get_type(self, attr) -> Any: - return self.configurable_items[attr][0] - - def set(self, attr, val) -> None: - if attr in self.configurable_items: - if isinstance(val, self.configurable_items[attr][0]): - self.configurable_items[attr][1] = val - else: - self.extra_configs[attr] = val - - def cleanup_listen_ports(self): - for data in self.listen_ports.values(): - if data[1]: - data[1].close() - self.listen_ports.clear() - - def write_config(self): - config_file = join(self.base_data_dir, 'configs', 'milvus.yaml') - os.makedirs(dirname(config_file), exist_ok=True) - content = self.template_text - for key, val in self.config_key_maps.items(): - value = self.configurable_items[val][1] - value_text = self.get_value_text(value) - content = content.replace(key, value_text) - content = self.update_extra_configs(content) - with open(config_file, 'w', encoding='utf-8') as config: - config.write(content) - - def update_extra_configs(self, content): - current_key = [] - new_content = '' - for line in content.splitlines(): - if line.strip().startswith('#'): - new_content += line + os.linesep - continue - matches = re.match( - r'^( *[a-zA-Z0-9_]+):([^#]*)(#.*)?$', line.rstrip()) - if not matches: - new_content += line + os.linesep - continue - key_with_prefix = matches.group(1).rstrip() - comment = matches.group(3) or '' - key = key_with_prefix.strip() - level = (len(key_with_prefix) - len(key)) // 2 - current_key = current_key[:level] - current_key.append(key) - current_key_text = '.'.join(current_key) 
- for extra_key, extra_val in self.extra_configs.items(): - if extra_key == current_key_text: - if comment.strip(): - line = f'{key_with_prefix}: {extra_val} #{comment.strip()[1:]}' - else: - line = f'{key_with_prefix}: {extra_val}' - new_content += line + os.linesep - return new_content - - -class MilvusServer: - """ Milvus server - """ - - def __init__(self, config: MilvusServerConfig = None, wait_for_started=True, **kwargs): - """_summary_ - - Args: - config (MilvusServerConfig, optional): the server config. - Defaults to default_server_config. - wait_for_started (bool, optional): wait for server started. Defaults to True. - - Kwargs: - """ - if not config: - self.config = MilvusServerConfig() - else: - self.config = config - self.config.update(**kwargs) - self.server_proc = None - self.proc_fds = {} - self._debug = kwargs.get('debug', False) - self.logger = _create_logger('debug' if self._debug else 'null') - self.webservice_port = 9091 - self.wait_for_started = wait_for_started - - def get_milvus_executable_path(self): - """ get where milvus - """ - if sys.platform.lower() == 'win32': - join(self.config.base_data_dir, 'bin', 'milvus.exe') - return join(self.config.base_data_dir, 'bin', 'milvus') - - def __enter__(self): - self.start() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - - def __del__(self): - self.stop() - - @classmethod - def prepend_path_to_envs(cls, envs, name, val): - envs.update({name: ':'.join([val, os.environ.get(name, '')])}) - - def cleanup(self): - if self.running: - raise RuntimeError('Server is running') - shutil.rmtree(self.config.base_data_dir, ignore_errors=True) - - def wait(self): - while self.running: - sleep(0.1) - - def wait_started(self, timeout=30000): - """ wait server started - - Args: - timeout: timeout in milliseconds, default 30,000 - - use http client to visit the health api to check if server ready - """ - start_time = datetime.datetime.now() - health_url = f'http://127.0.0.1:{self.webservice_port}/api/v1/health' - while (datetime.datetime.now() - start_time).total_seconds() < (timeout / 1000) and self.running: - try: - with urllib.request.urlopen(health_url, timeout=100) as resp: - json.loads(resp.read().decode('utf-8')) - self.logger.info('Milvus server is started') - # still wait 1 seconds to make sure server is ready - sleep(1) - return - except (urllib.error.URLError, urllib.error.HTTPError, json.JSONDecodeError): - sleep(0.1) - if self.running: - raise TimeoutError(f'Milvus not startd in {timeout/1000} seconds') - else: - raise RuntimeError('Milvus server already stopped') - - def start(self): - self.config.resolve() - _initialize_data_files(self.config.base_data_dir) - - milvus_exe = self.get_milvus_executable_path() - old_pwd = os.getcwd() - os.chdir(self.config.base_data_dir) - envs = os.environ.copy() - # resolve listen port for METRICS_PORT (restful service), default 9091 - self.webservice_port, sock = self.config.resolve_port(self.webservice_port) - sock.close() - envs.update({ - 'DEPLOY_MODE': 'STANDALONE', - 'METRICS_PORT': str(self.webservice_port) - }) - if sys.platform.lower() == 'linux': - self.prepend_path_to_envs(envs, 'LD_LIBRARY_PATH', dirname(milvus_exe)) - if sys.platform.lower() == 'darwin': - self.prepend_path_to_envs(envs, 'DYLD_LIBRARY_PATH', dirname(milvus_exe)) - for name in ('stdout', 'stderr'): - run_log = join(self.config.base_data_dir, 'logs', f'milvus-{name}.log') - # pylint: disable=consider-using-with - self.proc_fds[name] = open(run_log, 'w', encoding='utf-8') - cmds = [milvus_exe, 'run', 
'standalone'] - proc_fds = self.proc_fds - if self._debug: - self.server_proc = subprocess.Popen(cmds, env=envs) - else: - # pylint: disable=consider-using-with - self.server_proc = subprocess.Popen(cmds, stdout=proc_fds['stdout'], stderr=proc_fds['stderr'], env=envs) - os.chdir(old_pwd) - if self.wait_for_started: - self.wait_started() - if not self._debug: - self.show_banner() - - def show_banner(self): - print(r""" - - __ _________ _ ____ ______ - / |/ / _/ /| | / / / / / __/ - / /|_/ // // /_| |/ / /_/ /\ \ - /_/ /_/___/____/___/\____/___/ {Lite} - - Welcome to use Milvus! -""") - print(f' Version: v{__version__}-lite') - print(f' Process: {self.server_proc.pid}') - print(f' Started: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}') - print(f' Config: {join(self.config.base_data_dir, "configs", "milvus.yaml")}') - print(f' Logs: {join(self.config.base_data_dir, "logs")}') - print('\n Ctrl+C to exit ...') - - def stop(self): - if self.server_proc: - self.server_proc.terminate() - self.server_proc.wait() - self.server_proc = None - for fd in self.proc_fds.values(): - fd.close() - self.proc_fds.clear() - - def set_base_dir(self, dir_path): - self.config.configs.update(data_dir=dir_path) - self.config.resolve_storage() - - @property - def running(self) -> bool: - return self.server_proc is not None - - @property - def server_address(self) -> str: - return '127.0.0.1' - - @property - def config_keys(self) -> List[str]: - return self.config.configurable_items.keys() - - @property - def listen_port(self) -> int: - return int(self.config.get('proxy_port')) - - @listen_port.setter - def listen_port(self, val: int): - self.config.set('proxy_port', val) - - @property - def debug(self): - return self._debug - - @debug.setter - def debug(self, val: bool): - self._debug = val - self.logger = _create_logger('debug' if val else 'null') - self.config.logger = self.logger - - -default_server = MilvusServer() -debug_server = MilvusServer(MilvusServerConfig(), debug=True) - - -# pylint: disable=unused-argument -class ExtraConfigAcxtion(Action): - """ action class for extra config - - the extra config is in format of key=value, the value will be converted to int or float if possible - for setting a value to subkey, use key.subkey=value - """ - def __init__(self, option_strings, dest, **kwargs): - super().__init__(option_strings, dest, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - if '=' not in values: - raise ValueError(f'Invalid extra config: {values}') - key, val = values.split('=', 1) - if val.isdigit(): - val = int(val) - elif val.replace('.', '', 1).isdigit(): - val = float(val) - elif val.lower() in ('true', 'false'): - val = val.lower() == 'true' - obj = getattr(namespace, self.dest) - obj[key] = val - setattr(namespace, self.dest, obj) - - -def main(): - parser = ArgumentParser() - parser.add_argument('--debug', action='store_true', dest='debug', default=False, help='enable debug') - parser.add_argument('--data', dest='data_dir', default='', help='set base data dir for milvus') - parser.add_argument('--extra-config', dest='extra_config', default={}, help='set extra config for milvus', - action=ExtraConfigAcxtion) - - # dynamic configurations - for key in default_server.config_keys: - val = default_server.config.get(key) - if val is not None: - val_type = default_server.config.get_type(key) - name = '--' + key.replace('_', '-') - parser.add_argument(name, type=val_type, default=val, dest=f'x_{key}', - help=f'set value for {key} ({val_type.__name__})') - - 
args = parser.parse_args() - - # select server - server = debug_server if args.debug else default_server - - # set base dir if configured - if args.data_dir: - server.set_base_dir(args.data_dir) - - # apply configs - # pylint: disable=protected-access - for name, value in args._get_kwargs(): - if name.startswith('x_'): - server.config.set(name[2:], value) - for key, value in args.extra_config.items(): - server.config.set(key, value) - - signal.signal(signal.SIGINT, lambda sig, h: server.stop()) - - try: - server.start() - except TimeoutError: - print('Wait for milvus server started timeout.') - except RuntimeError: - print('Milvus server already stopped.') - - server.wait() - - -if __name__ == '__main__': - main() diff --git a/src/milvus/data/.gitignore b/src/milvus/data/.gitignore deleted file mode 100644 index 6dd29b7..0000000 --- a/src/milvus/data/.gitignore +++ /dev/null @@ -1 +0,0 @@ -bin/ \ No newline at end of file diff --git a/src/milvus/data/config.yaml.template b/src/milvus/data/config.yaml.template deleted file mode 100644 index a407389..0000000 --- a/src/milvus/data/config.yaml.template +++ /dev/null @@ -1,541 +0,0 @@ -# Licensed to the LF AI & Data foundation under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Related configuration of etcd, used to store Milvus metadata & service discovery. -etcd: - endpoints: - - localhost:{{ etcd_port(int): 2379 }} - rootPath: by-dev # The root path where data is stored in etcd - metaSubPath: meta # metaRootPath = rootPath + '/' + metaSubPath - kvSubPath: kv # kvRootPath = rootPath + '/' + kvSubPath - log: - # path is one of: - # - "default" as os.Stderr, - # - "stderr" as os.Stderr, - # - "stdout" as os.Stdout, - # - file path to append server logs to. - # please adjust in embedded Milvus: /tmp/milvus/logs/etcd.log - path: {{ etcd_log_path }} - level: {{ etcd_log_level: info }} # Only supports debug, info, warn, error, panic, or fatal. Default 'info'. - use: - # please adjust in embedded Milvus: true - embed: true # Whether to enable embedded Etcd (an in-process EtcdServer). - data: - # Embedded Etcd only. - # please adjust in embedded Milvus: /tmp/milvus/etcdData/ - dir: {{ etcd_data_dir }} - ssl: - enabled: false # Whether to support ETCD secure connection mode - tlsCert: /path/to/etcd-client.pem # path to your cert file - tlsKey: /path/to/etcd-client-key.pem # path to your key file - tlsCACert: /path/to/ca.pem # path to your CACert file - # TLS min version - # Optional values: 1.0, 1.1, 1.2, 1.3。 - # We recommend using version 1.2 and above - tlsMinVersion: 1.3 - -# Default value: etcd -# Valid values: [etcd, mysql] -metastore: - type: etcd - -# Related configuration of mysql, used to store Milvus metadata. 
-mysql: - username: root - password: 123456 - address: localhost - port: 3306 - dbName: milvus_meta - driverName: mysql - maxOpenConns: 20 - maxIdleConns: 5 - -# please adjust in embedded Milvus: /tmp/milvus/data/ -localStorage: - path: {{ local_storage_dir }} - -# Related configuration of MinIO/S3/GCS or any other service supports S3 API, which is responsible for data persistence for Milvus. -# We refer to the storage service as MinIO/S3 in the following description for simplicity. -minio: - address: localhost # Address of MinIO/S3 - port: 9000 # Port of MinIO/S3 - accessKeyID: minioadmin # accessKeyID of MinIO/S3 - secretAccessKey: minioadmin # MinIO/S3 encryption string - useSSL: false # Access to MinIO/S3 with SSL - bucketName: "a-bucket" # Bucket name in MinIO/S3 - rootPath: files # The root path where the message is stored in MinIO/S3 - # Whether to use IAM role to access S3/GCS instead of access/secret keys - # For more infomation, refer to - # aws: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_use.html - # gcp: https://cloud.google.com/storage/docs/access-control/iam - useIAM: false - # Cloud Provider of S3. Supports: "aws", "gcp". - # You can use "aws" for other cloud provider supports S3 API with signature v4, e.g.: minio - # You can use "gcp" for other cloud provider supports S3 API with signature v2 - # When `useIAM` enabled, only "aws" & "gcp" is supported for now - cloudProvider: "aws" - # Custom endpoint for fetch IAM role credentials. when useIAM is true & cloudProvider is "aws". - # Leave it empty if you want to use AWS default endpoint - iamEndpoint: "" - -# Milvus supports three MQ: rocksmq(based on RockDB), Pulsar and Kafka, which should be reserved in config what you use. -# There is a note about enabling priority if we config multiple mq in this file -# 1. standalone(local) mode: rockskmq(default) > Pulsar > Kafka -# 2. cluster mode: Pulsar(default) > Kafka (rocksmq is unsupported) - -# Related configuration of pulsar, used to manage Milvus logs of recent mutation operations, output streaming log, and provide log publish-subscribe services. -pulsar: - address: localhost # Address of pulsar - port: 6650 # Port of pulsar - webport: 80 # Web port of pulsar, if you connect direcly without proxy, should use 8080 - maxMessageSize: 5242880 # 5 * 1024 * 1024 Bytes, Maximum size of each message in pulsar. - tenant: public - namespace: default - -# If you want to enable kafka, needs to comment the pulsar configs -kafka: - producer: - client.id: dc - consumer: - client.id: dc1 -# brokerList: localhost1:9092,localhost2:9092,localhost3:9092 -# saslUsername: username -# saslPassword: password -# saslMechanisms: PLAIN -# securityProtocol: SASL_SSL - -rocksmq: - # please adjust in embedded Milvus: /tmp/milvus/rdb_data - path: {{ rocketmq_data_dir }} # The path where the message is stored in rocksmq - rocksmqPageSize: 268435456 # 256 MB, 256 * 1024 * 1024 bytes, The size of each page of messages in rocksmq - retentionTimeInMinutes: 7200 # 5 days, 5 * 24 * 60 minutes, The retention time of the message in rocksmq. - retentionSizeInMB: 8192 # 8 GB, 8 * 1024 MB, The retention size of the message in rocksmq. 
- compactionInterval: 86400 # 1 day, trigger rocksdb compaction every day to remove deleted data - lrucacheratio: 0.06 # rocksdb cache memory ratio - -# Related configuration of rootCoord, used to handle data definition language (DDL) and data control language (DCL) requests -rootCoord: - address: localhost - port: {{ root_coord_port(int) }} - enableActiveStandby: false # Enable active-standby - - dmlChannelNum: 256 # The number of dml channels created at system startup - maxPartitionNum: 4096 # Maximum number of partitions in a collection - minSegmentSizeToEnableIndex: 1024 # It's a threshold. When the segment size is less than this value, the segment will not be indexed - - # (in seconds) Duration after which an import task will expire (be killed). Default 900 seconds (15 minutes). - # Note: If default value is to be changed, change also the default in: internal/util/paramtable/component_param.go - importTaskExpiration: 900 - # (in seconds) Milvus will keep the record of import tasks for at least `importTaskRetention` seconds. Default 86400 - # seconds (24 hours). - # Note: If default value is to be changed, change also the default in: internal/util/paramtable/component_param.go - importTaskRetention: 86400 - -# Related configuration of proxy, used to validate client requests and reduce the returned results. -proxy: - port: {{ proxy_port(int): 19530 }} - internalPort: {{ proxy_internal_port(int): 19529 }} - http: - enabled: true # Whether to enable the http server - debug_mode: false # Whether to enable http server debug mode - - timeTickInterval: 200 # ms, the interval that proxy synchronize the time tick - msgStream: - timeTick: - bufSize: 512 - maxNameLength: 255 # Maximum length of name for a collection or alias - maxFieldNum: 64 # Maximum number of fields in a collection. - # As of today (2.2.0 and after) it is strongly DISCOURAGED to set maxFieldNum >= 64. - # So adjust at your risk! - maxDimension: 32768 # Maximum dimension of a vector - # It's strongly DISCOURAGED to set `maxShardNum` > 64. - maxShardNum: 64 # Maximum number of shards in a collection - maxTaskNum: 1024 # max task number of proxy task queue - # please adjust in embedded Milvus: false - ginLogging: false # Whether to produce gin logs. - grpc: - serverMaxRecvSize: 67108864 # 64M - serverMaxSendSize: 67108864 # 64M - clientMaxRecvSize: 104857600 # 100 MB, 100 * 1024 * 1024 - clientMaxSendSize: 104857600 # 100 MB, 100 * 1024 * 1024 - - -# Related configuration of queryCoord, used to manage topology and load balancing for the query nodes, and handoff from growing segments to sealed segments. -queryCoord: - address: localhost - port: {{ query_coord_port(int) }} - autoHandoff: true # Enable auto handoff - autoBalance: true # Enable auto balance - overloadedMemoryThresholdPercentage: 90 # The threshold percentage that memory overload - balanceIntervalSeconds: 60 - memoryUsageMaxDifferencePercentage: 30 - checkInterval: 1000 - channelTaskTimeout: 60000 # 1 minute - segmentTaskTimeout: 120000 # 2 minute - distPullInterval: 500 - loadTimeoutSeconds: 600 - checkHandoffInterval: 5000 - taskMergeCap: 16 - taskExecutionCap: 256 - enableActiveStandby: false # Enable active-standby - -# Related configuration of queryNode, used to run hybrid search between vector and scalar data. -queryNode: - cacheSize: 32 # GB, default 32 GB, `cacheSize` is the memory used for caching data for faster query. The `cacheSize` must be less than system memory size. 
- port: {{ query_node_port(int) }} - loadMemoryUsageFactor: 3 # The multiply factor of calculating the memory usage while loading segments - enableDisk: true # enable querynode load disk index, and search on disk index - maxDiskUsagePercentage: 95 - gracefulStopTimeout: 30 - - stats: - publishInterval: 1000 # Interval for querynode to report node information (milliseconds) - dataSync: - flowGraph: - maxQueueLength: 1024 # Maximum length of task queue in flowgraph - maxParallelism: 1024 # Maximum number of tasks executed in parallel in the flowgraph - # Segcore will divide a segment into multiple chunks to enbale small index - segcore: - chunkRows: 1024 # The number of vectors in a chunk. - # Note: we have disabled segment small index since @2022.05.12. So below related configurations won't work. - # We won't create small index for growing segments and search on these segments will directly use bruteforce scan. - smallIndex: - nlist: 128 # small index nlist, recommend to set sqrt(chunkRows), must smaller than chunkRows/8 - nprobe: 16 # nprobe to search small index, based on your accuracy requirement, must smaller than nlist - cache: - enabled: true - memoryLimit: 2147483648 # 2 GB, 2 * 1024 *1024 *1024 - - scheduler: - receiveChanSize: 10240 - unsolvedQueueSize: 10240 - # maxReadConcurrentRatio is the concurrency ratio of read task (search task and query task). - # Max read concurrency would be the value of `runtime.NumCPU * maxReadConcurrentRatio`. - # It defaults to 2.0, which means max read concurrency would be the value of runtime.NumCPU * 2. - # Max read concurrency must greater than or equal to 1, and less than or equal to runtime.NumCPU * 100. - maxReadConcurrentRatio: 2.0 # (0, 100] - cpuRatio: 10.0 # ratio used to estimate read task cpu usage. - # maxTimestampLag is the max ts lag between serviceable and guarantee timestamp. - # if the lag is larger than this config, scheduler will return error without waiting. - # the valid value is [3600, infinite) - maxTimestampLag: 86400 - - grouping: - enabled: true - maxNQ: 1000 - topKMergeRatio: 10.0 - -indexCoord: - address: localhost - port: {{ index_coord_port(int) }} - enableActiveStandby: false # Enable active-standby - - minSegmentNumRowsToEnableIndex: 1024 # It's a threshold. When the segment num rows is less than this value, the segment will not be indexed - - bindIndexNodeMode: - enable: false - address: "localhost:22930" - withCred: false - nodeID: 0 - - gc: - interval: 600 # gc interval in seconds - - scheduler: - interval: 1000 # scheduler interval in Millisecond - -indexNode: - port: {{ index_node_port(int) }} - enableDisk: true # enable index node build disk vector index - maxDiskUsagePercentage: 95 - gracefulStopTimeout: 30 - - scheduler: - buildParallel: 1 - -dataCoord: - address: localhost - port: {{ data_coord_port(int) }} - enableCompaction: true # Enable data segment compaction - enableGarbageCollection: true - enableActiveStandby: false # Enable active-standby - - channel: - watchTimeoutInterval: 30 # Timeout on watching channels (in seconds). Datanode tickler update watch progress will reset timeout timer. - balanceSilentDuration: 300 # The duration before the channelBalancer on datacoord to run - balanceInterval: 360 #The interval for the channelBalancer on datacoord to check balance status - - segment: - maxSize: 512 # Maximum size of a segment in MB - diskSegmentMaxSize: 2048 # Maximun size of a segment in MB for collection which has Disk index - # Minimum proportion for a segment which can be sealed. 
- # Sealing early can prevent producing large growing segments in case these segments might slow down our search/query. - # Segments that sealed early will be compacted into a larger segment (within maxSize) eventually. - sealProportion: 0.23 - assignmentExpiration: 2000 # The time of the assignment expiration in ms - maxLife: 86400 # The max lifetime of segment in seconds, 24*60*60 - # If a segment didn't accept dml records in `maxIdleTime` and the size of segment is greater than - # `minSizeFromIdleToSealed`, Milvus will automatically seal it. - maxIdleTime: 600 # The max idle time of segment in seconds, 10*60. - minSizeFromIdleToSealed: 16 # The min size in MB of segment which can be idle from sealed. - # The max number of binlog file for one segment, the segment will be sealed if - # the number of binlog file reaches to max value. - maxBinlogFileNumber: 32 - smallProportion: 0.5 # The segment is considered as "small segment" when its # of rows is smaller than - # (smallProportion * segment max # of rows). - compactableProportion: 0.85 # A compaction will happen on small segments if the segment after compaction will have - # over (compactableProportion * segment max # of rows) rows. - # MUST BE GREATER THAN OR EQUAL TO !!! - - compaction: - enableAutoCompaction: true - - gc: - interval: 3600 # gc interval in seconds - missingTolerance: 86400 # file meta missing tolerance duration in seconds, 60*24 - dropTolerance: 3600 # file belongs to dropped entity tolerance duration in seconds - - -dataNode: - port: {{ data_node_port(int) }} - - dataSync: - flowGraph: - maxQueueLength: 1024 # Maximum length of task queue in flowgraph - maxParallelism: 1024 # Maximum number of tasks executed in parallel in the flowgraph - segment: - # Max buffer size to flush for a single segment. - insertBufSize: 16777216 # Bytes, 16 MB - # Max buffer size to flush del for a single channel - deleteBufBytes: 67108864 # Bytes, 64MB - # The period to sync segments if buffer is not empty. - syncPeriod: 600 # Seconds, 10min - - memory: - forceSyncEnable: false # `true` to force sync if memory usage is too high - forceSyncThreshold: 0.6 # forceSync only take effects when memory usage ratio > forceSyncThreshold - forceSyncSegmentRatio: 0.3 # ratio of segments to sync, top largest forceSyncSegmentRatio segments will be synced - -# Configures the system log output. -log: - level: {{ system_log_level: debug }} # Only supports debug, info, warn, error, panic, or fatal. Default 'info'. - stdout: {{ log_stdout:false }} # default true, print log to stdout - file: - # please adjust in embedded Milvus: /tmp/milvus/logs - rootPath: {{ system_log_path }} # default to stdout, stderr - maxSize: 300 # MB - maxAge: 10 # Maximum time for log retention in day. - maxBackups: 20 - format: text # text/json - -grpc: - log: - level: WARNING - - serverMaxRecvSize: 536870912 # 512MB - serverMaxSendSize: 536870912 # 512MB - clientMaxRecvSize: 104857600 # 100 MB, 100 * 1024 * 1024 - clientMaxSendSize: 104857600 # 100 MB, 100 * 1024 * 1024 - - client: - dialTimeout: 5000 - keepAliveTime: 10000 - keepAliveTimeout: 20000 - maxMaxAttempts: 5 - initialBackOff: 1.0 - maxBackoff: 60.0 - backoffMultiplier: 2.0 - server: - retryTimes: 5 # retry times when receiving a grpc return value with a failure and retryable state code - -# Configure the proxy tls enable. 
-tls: - serverPemPath: {{ server_pem_path: server.pem }} - serverKeyPath: {{ server_key_path: server.key }} - caPemPath: {{ ca_pem_path: ca.pem }} - - -common: - # Channel name generation rule: ${namePrefix}-${ChannelIdx} - chanNamePrefix: - cluster: "by-dev" - rootCoordTimeTick: "rootcoord-timetick" - rootCoordStatistics: "rootcoord-statistics" - rootCoordDml: "rootcoord-dml" - rootCoordDelta: "rootcoord-delta" - search: "search" - searchResult: "searchResult" - queryTimeTick: "queryTimeTick" - queryNodeStats: "query-node-stats" - # Cmd for loadIndex, flush, etc... - cmd: "cmd" - dataCoordStatistic: "datacoord-statistics-channel" - dataCoordTimeTick: "datacoord-timetick-channel" - dataCoordSegmentInfo: "segment-info-channel" - - # Sub name generation rule: ${subNamePrefix}-${NodeID} - subNamePrefix: - rootCoordSubNamePrefix: "rootCoord" - proxySubNamePrefix: "proxy" - queryNodeSubNamePrefix: "queryNode" - dataNodeSubNamePrefix: "dataNode" - dataCoordSubNamePrefix: "dataCoord" - - defaultPartitionName: "_default" # default partition name for a collection - defaultIndexName: "_default_idx" # default index name - retentionDuration: 0 # time travel reserved time, insert/delete will not be cleaned in this period. disable it by default - entityExpiration: -1 # Entity expiration in seconds, CAUTION make sure entityExpiration >= retentionDuration and -1 means never expire - - gracefulTime: 5000 # milliseconds. it represents the interval (in ms) by which the request arrival time needs to be subtracted in the case of Bounded Consistency. - gracefulStopTimeout: 30 # seconds. it will force quit the server if the graceful stop process is not completed during this time. - - # Default value: auto - # Valid values: [auto, avx512, avx2, avx, sse4_2] - # This configuration is only used by querynode and indexnode, it selects CPU instruction set for Searching and Index-building. - simdType: auto - indexSliceSize: 16 # MB - DiskIndex: - MaxDegree: 56 - SearchListSize: 100 - PQCodeBudgetGBRatio: 0.125 - BuildNumThreadsRatio: 1.0 - SearchCacheBudgetGBRatio: 0.10 - LoadNumThreadRatio: 8.0 - BeamWidthRatio: 4.0 - # This parameter specify how many times the number of threads is the number of cores - threadCoreCoefficient : 10 - - # please adjust in embedded Milvus: local - storageType: local - - security: - authorizationEnabled: {{ authorization_enabled(bool): false }} - # The superusers will ignore some system check processes, - # like the old password verification when updating the credential - # superUsers: - # - "root" - # tls mode values [0, 1, 2] - # 0 is close, 1 is one-way authentication, 2 is two-way authentication. - tlsMode: {{ tls_mode(int): 0 }} - - session: - ttl: 60 # ttl value when session granting a lease to register service - retryTimes: 30 # retry times when session sending etcd requests - -# QuotaConfig, configurations of Milvus quota and limits. -# By default, we enable: -# 1. TT protection; -# 2. Memory protection. -# 3. Disk quota protection. -# You can enable: -# 1. DML throughput limitation; -# 2. DDL, DQL qps/rps limitation; -# 3. DQL Queue length/latency protection; -# 4. DQL result rate protection; -# If necessary, you can also manually force to deny RW requests. -quotaAndLimits: - enabled: true # `true` to enable quota and limits, `false` to disable. - - # quotaCenterCollectInterval is the time interval that quotaCenter - # collects metrics from Proxies, Query cluster and Data cluster. - quotaCenterCollectInterval: 3 # seconds, (0 ~ 65536) - - ddl: # ddl limit rates, default no limit. 
- enabled: false - collectionRate: -1 # qps, default no limit, rate for CreateCollection, DropCollection, LoadCollection, ReleaseCollection - partitionRate: -1 # qps, default no limit, rate for CreatePartition, DropPartition, LoadPartition, ReleasePartition - - indexRate: - enabled: false - max: -1 # qps, default no limit, rate for CreateIndex, DropIndex - flushRate: - enabled: false - max: -1 # qps, default no limit, rate for flush - compactionRate: - enabled: false - max: -1 # qps, default no limit, rate for manualCompaction - - # dml limit rates, default no limit. - # The maximum rate will not be greater than `max`. - dml: - enabled: false - insertRate: - max: -1 # MB/s, default no limit - deleteRate: - max: -1 # MB/s, default no limit - bulkLoadRate: # not support yet. TODO: limit bulkLoad rate - max: -1 # MB/s, default no limit - - # dql limit rates, default no limit. - # The maximum rate will not be greater than `max`. - dql: - enabled: false - searchRate: - max: -1 # vps (vectors per second), default no limit - queryRate: - max: -1 # qps, default no limit - - # limitWriting decides whether dml requests are allowed. - limitWriting: - # forceDeny `false` means dml requests are allowed (except for some - # specific conditions, such as memory of nodes to water marker), `true` means always reject all dml requests. - forceDeny: false - ttProtection: - enabled: false - # maxTimeTickDelay indicates the backpressure for DML Operations. - # DML rates would be reduced according to the ratio of time tick delay to maxTimeTickDelay, - # if time tick delay is greater than maxTimeTickDelay, all DML requests would be rejected. - maxTimeTickDelay: 300 # in seconds - memProtection: - enabled: true - # When memory usage > memoryHighWaterLevel, all dml requests would be rejected; - # When memoryLowWaterLevel < memory usage < memoryHighWaterLevel, reduce the dml rate; - # When memory usage < memoryLowWaterLevel, no action. - # memoryLowWaterLevel should be less than memoryHighWaterLevel. - dataNodeMemoryLowWaterLevel: 0.85 # (0, 1], memoryLowWaterLevel in DataNodes - dataNodeMemoryHighWaterLevel: 0.95 # (0, 1], memoryHighWaterLevel in DataNodes - queryNodeMemoryLowWaterLevel: 0.85 # (0, 1], memoryLowWaterLevel in QueryNodes - queryNodeMemoryHighWaterLevel: 0.95 # (0, 1], memoryHighWaterLevel in QueryNodes - diskProtection: - # When the total file size of object storage is greater than `diskQuota`, all dml requests would be rejected; - enabled: true - diskQuota: -1 # MB, (0, +inf), default no limit - - # limitReading decides whether dql requests are allowed. - limitReading: - # forceDeny `false` means dql requests are allowed (except for some - # specific conditions, such as collection has been dropped), `true` means always reject all dql requests. - forceDeny: false - queueProtection: - enabled: false - # nqInQueueThreshold indicated that the system was under backpressure for Search/Query path. - # If NQ in any QueryNode's queue is greater than nqInQueueThreshold, search&query rates would gradually cool off - # until the NQ in queue no longer exceeds nqInQueueThreshold. We think of the NQ of query request as 1. - nqInQueueThreshold: -1 # int, default no limit - - # queueLatencyThreshold indicated that the system was under backpressure for Search/Query path. - # If dql latency of queuing is greater than queueLatencyThreshold, search&query rates would gradually cool off - # until the latency of queuing no longer exceeds queueLatencyThreshold. 
- # The latency here refers to the averaged latency over a period of time. - queueLatencyThreshold: -1 # milliseconds, default no limit - resultProtection: - enabled: false - # maxReadResultRate indicated that the system was under backpressure for Search/Query path. - # If dql result rate is greater than maxReadResultRate, search&query rates would gradually cool off - # until the read result rate no longer exceeds maxReadResultRate. - maxReadResultRate: -1 # MB/s, default no limit - # coolOffSpeed is the speed of search&query rates cool off. - coolOffSpeed: 0.9 # (0, 1] diff --git a/src/milvus_local.cpp b/src/milvus_local.cpp new file mode 100644 index 0000000..3bf71ce --- /dev/null +++ b/src/milvus_local.cpp @@ -0,0 +1,331 @@ +#include "milvus_local.h" +#include +#include +#include +#include +#include +#include "common.h" +#include "pb/schema.pb.h" +#include "pb/segcore.pb.h" +#include "schema_util.h" +#include "status.h" +#include "storage.h" +#include "log/Log.h" +#include "string_util.hpp" + +namespace milvus::local { + +#define CHECK_COLLECTION_EXIST(collection_name) \ + do { \ + CHECK_STATUS(CheckCollectionName(string_util::Trim(collection_name)), \ + ""); \ + if (!storage_.CollectionExist(collection_name)) { \ + LOG_ERROR("Collecton {} not existed", collection_name); \ + return Status::CollectionNotFound(); \ + } \ + } while (0) + +#define CHECK_COLLECTION_NOT_EXIST(collection_name) \ + do { \ + CHECK_STATUS(CheckCollectionName(string_util::Trim(collection_name)), \ + ""); \ + if (storage_.CollectionExist(collection_name)) { \ + LOG_ERROR("Collecton {} already existed", collection_name); \ + return Status::CollectionAlreadExist(); \ + } \ + } while (0) + +MilvusLocal::MilvusLocal(const char* db_file) + : db_file_(db_file), storage_(db_file), initialized(false) { +} + +MilvusLocal::~MilvusLocal() { +} + +Status +MilvusLocal::CheckCollectionName(const std::string& collection_name) { + if (collection_name.empty()) { + return Status::ParameterInvalid("collection name should not be empty"); + } + std::string invalid_msg = + string_util::SFormat("Invalid collection {}. 
", collection_name); + if (collection_name.size() > 255) { + return Status::ParameterInvalid( + "{}, the length of a collection name must " + "be less than 255 characters", + invalid_msg); + } + + char first = collection_name[0]; + if (first != '_' && !string_util::IsAlpha(first)) { + return Status::ParameterInvalid( + "{} the first character of a collection {} must be an underscore " + "or letter", + invalid_msg, + collection_name); + } + std::regex pattern("^[a-zA-Z_][a-zA-Z0-9_]*$"); + if (!std::regex_match(collection_name, pattern)) { + auto err = string_util::SFormat( + "{}, collection name can only contain " + "numbers, letters and underscores", + invalid_msg); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + return Status::Ok(); +} + +bool +MilvusLocal::Init() { + std::lock_guard lock(mutex_); + if (initialized) { + LOG_WARN("Milvus has already initialized"); + return false; + } + + if (!storage_.Open()) { + return false; + } + initialized = true; + return true; +} + +Status +MilvusLocal::LoadCollection(const std::string& collection_name) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + if (index_.HasLoaded(collection_name)) { + return Status::Ok(); + } + + std::string schema_proto, index_proto; + if (!storage_.GetCollectionSchema(collection_name, &schema_proto)) { + LOG_ERROR("Can not find {}'s schema", collection_name); + return Status::ServiceInternal("Schema not found"); + } + + CHECK_STATUS(index_.CreateCollection(collection_name, schema_proto), ""); + std::vector all_index_proto; + storage_.GetAllIndex(collection_name, "", &all_index_proto); + auto index_meta_proto = schema_util::MergeIndexs(all_index_proto); + CHECK_STATUS(index_.CreateIndex(collection_name, index_meta_proto), ""); + std::vector rows; + int64_t start = 0; + while (true) { + storage_.LoadCollecton(collection_name, start, 200000, &rows); + if (rows.size() == 0) { + LOG_INFO("Success load {} rows", start); + return Status::Ok(); + } + for (const auto& row : rows) { + CHECK_STATUS(index_.Insert(collection_name, 1, row), + "Load data failed: "); + } + start += rows.size(); + rows.clear(); + } + + return Status::Ok(); +} + +Status +MilvusLocal::ReleaseCollection(const std::string& collection_name) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + if (index_.DropCollection(collection_name)) { + return Status::Ok(); + } + return Status::SegcoreErr(); +} + +Status +MilvusLocal::CreateCollection(const std::string& collection_name, + const std::string& pk_name, + const std::string& schema_proto) { + std::lock_guard lock(mutex_); + CHECK_STATUS(CheckCollectionName(string_util::Trim(collection_name)), ""); + + if (storage_.CollectionExist(collection_name)) { + std::string db_schema_proto; + if (!storage_.GetCollectionSchema(collection_name, &db_schema_proto)) { + return Status::ServiceInternal(); + } + if (!schema_util::SchemaEquals(schema_proto, db_schema_proto)) { + return Status::ParameterInvalid( + "create duplicate collection with different parameters, " + "collection {}", + collection_name); + } + return Status::Ok(); + } + + // CHECK_COLLECTION_NOT_EXIST(collection_name); + CHECK_STATUS(index_.CreateCollection(collection_name, schema_proto), ""); + if (!storage_.CreateCollection(collection_name, pk_name, schema_proto)) { + return Status::ServiceInternal(); + } + return Status::Ok(); +} + +Status +MilvusLocal::GetCollection(const std::string& collection_name, + std::string* schema_proto) { + std::lock_guard lock(mutex_); + 
CHECK_COLLECTION_EXIST(collection_name); + if (!storage_.GetCollectionSchema(collection_name, schema_proto)) { + return Status::ServiceInternal(); + } + return Status::Ok(); +} + +bool +MilvusLocal::DropCollection(const std::string& collection_name) { + std::lock_guard lock(mutex_); + if (!storage_.CollectionExist(collection_name)) { + LOG_WARN("Collection {} not existed", collection_name); + return true; + } + return index_.DropCollection(collection_name) && + storage_.DropCollection(collection_name); +} + +void +MilvusLocal::GetAllCollections(std::vector* collection_names) { + std::lock_guard lock(mutex_); + storage_.ListCollections(collection_names); +} + +Status +MilvusLocal::CreateIndex(const std::string& collection_name, + const std::string& index_name, + const std::string& index_proto) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + if (storage_.HasIndex(collection_name, index_name)) { + // TODO add index info check + LOG_WARN("Collection {}'s index {} alread existed", + collection_name, + index_name); + return Status::Ok(); + } + // get existed index + std::vector all_index_proto; + storage_.GetAllIndex(collection_name, "", &all_index_proto); + all_index_proto.push_back(index_proto); + auto index_meta_proto = schema_util::MergeIndexs(all_index_proto); + + CHECK_STATUS(index_.CreateIndex(collection_name, index_meta_proto), ""); + if (!storage_.CreateIndex(collection_name, index_name, index_proto)) { + return Status::ServiceInternal(); + } + return Status::Ok(); +} + +Status +MilvusLocal::DropIndex(const std::string& collection_name, + const std::string& index_name) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + if (!storage_.HasIndex(collection_name, index_name)) { + return Status::Ok(); + } + + std::vector all_index_proto; + storage_.GetAllIndex(collection_name, index_name, &all_index_proto); + auto index_meta_proto = schema_util::MergeIndexs(all_index_proto); + + CHECK_STATUS(index_.CreateIndex(collection_name, index_meta_proto), ""); + if (!storage_.DropIndex(collection_name, index_name)) { + return Status::ServiceInternal(); + } + return Status::Ok(); +} + +Status +MilvusLocal::Insert(const std::string& collection_name, + const Rows& rows, + std::vector* ids) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + int64_t count = 0; + for (const auto& row : rows) { + if (index_.Insert(collection_name, 1, std::get<1>(row)).IsErr()) { + break; + } + ids->push_back(std::get<0>(row)); + count += 1; + } + auto start = rows.begin(); + auto end = rows.begin() + count; + std::vector rows_need_insert(start, end); + storage_.Insert(collection_name, rows_need_insert); + return Status::Ok(); +} + +Status +MilvusLocal::Retrieve(const std::string& collection_name, + const std::string& plan, + RetrieveResult* result) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + return index_.Retrieve(collection_name, plan, result); +} + +Status +MilvusLocal::Search(const std::string& collection_name, + const std::string& plan, + const std::string& placeholder_group, + SearchResult* result) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + return index_.Search(collection_name, plan, placeholder_group, result); +} + +Status +MilvusLocal::DeleteByIds(const std::string& collection_name, + const std::string& ids, + int64_t size, + const std::vector& storage_ids) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + if 
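CheckCollectionName above enforces three rules: the name must be non-empty, at most 255 characters, and match ^[a-zA-Z_][a-zA-Z0-9_]*$ (a leading underscore or letter, then only letters, digits and underscores). A small standalone sketch of the same checks, useful for seeing which names pass; the helper name is illustrative and the rules are copied from the code above:

#include <iostream>
#include <regex>
#include <string>

// Mirrors the checks in MilvusLocal::CheckCollectionName (illustration only).
bool IsValidCollectionName(const std::string& name) {
    if (name.empty() || name.size() > 255) {
        return false;
    }
    static const std::regex pattern("^[a-zA-Z_][a-zA-Z0-9_]*$");
    return std::regex_match(name, pattern);
}

int main() {
    for (const std::string& name : {"hello_milvus", "_tmp_1", "1st", "a-b", ""}) {
        std::cout << "'" << name << "' -> "
                  << (IsValidCollectionName(name) ? "valid" : "invalid") << "\n";
    }
    return 0;
}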
(!storage_.Delete(collection_name, storage_ids)) { + return Status::ServiceInternal(); + } + CHECK_STATUS(index_.DeleteByIds(collection_name, ids, size), ""); + return Status::Ok(); +} + +Status +MilvusLocal::GetIndex(const std::string& collection_name, + const std::string& index_name, + std::string* index_proto) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + if (!storage_.GetIndex(collection_name, index_name, index_proto)) { + return Status::IndexNotFound(); + } + return Status::Ok(); +} + +Status +MilvusLocal::GetAllIndexs(const std::string& collection_name, + std::vector* all_index_proto) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + storage_.GetAllIndex(collection_name, "", all_index_proto); + return Status::Ok(); +} + +Status +MilvusLocal::GetNumRowsOfCollection(const std::string& collection_name, + int64_t* num) { + std::lock_guard lock(mutex_); + CHECK_COLLECTION_EXIST(collection_name); + *num = storage_.Count(collection_name); + if (*num < 0) { + return Status::ServiceInternal(); + } + return Status::Ok(); +} +} // namespace milvus::local diff --git a/src/milvus_local.h b/src/milvus_local.h new file mode 100644 index 0000000..0159dce --- /dev/null +++ b/src/milvus_local.h @@ -0,0 +1,105 @@ +#pragma once + +#include +#include +#include +#include +#include "status.h" +#include "common.h" +#include "index.h" +#include "retrieve_result.h" +#include "search_result.h" +#include "storage.h" +#include "type.h" + +namespace milvus::local { + +class MilvusLocal final : NonCopyableNonMovable { + public: + explicit MilvusLocal(const char* db_file); + ~MilvusLocal(); + + public: + // load all meta info + bool + Init(); + + Status + LoadCollection(const std::string& collection_name); + + Status + ReleaseCollection(const std::string& collection_name); + + Status + CreateCollection(const std::string& collection_name, + const std::string& pk_name, + const std::string& schema_proto); + + Status + GetCollection(const std::string& collection_name, + std::string* schema_proto); + + bool + DropCollection(const std::string& collection_name); + + void + GetAllCollections(std::vector* collection_names); + + Status + CreateIndex(const std::string& collection_name, + const std::string& index_name, + const std::string& index_proto); + + Status + GetIndex(const std::string& collection_name, + const std::string& index_name, + std::string* index_proto); + + Status + GetAllIndexs(const std::string& collection_name, + std::vector* all_index_proto); + + Status + DropIndex(const std::string& collection_name, + const std::string& index_name); + + /* + * Row 为InsertRecord proto数据. 
+ */ + Status + Insert(const std::string& collection_name, + const Rows& rows, + std::vector* ids); + Status + Retrieve(const std::string& collection_name, + const std::string& expr, + RetrieveResult* result); + + Status + Search(const std::string& collection_name, + const std::string& plan, + const std::string& placeholder_group, + SearchResult* result); + + Status + DeleteByIds(const std::string& collection_name, + const std::string& ids, + int64_t size, + const std::vector& storage_id); + + Status + GetNumRowsOfCollection(const std::string& collection_name, int64_t* num); + + private: + Status + CheckCollectionName(const std::string& collection_name); + + private: + std::mutex mutex_; + std::string db_file_; + Storage storage_; + Index index_; + bool initialized; +}; + +} // namespace milvus::local diff --git a/src/milvus_proxy.cpp b/src/milvus_proxy.cpp new file mode 100644 index 0000000..bafdf11 --- /dev/null +++ b/src/milvus_proxy.cpp @@ -0,0 +1,356 @@ +#include "milvus_proxy.h" +#include +#include +#include +#include "common.h" +#include "common/Types.h" +#include "log/Log.h" +#include "create_collection_task.h" +#include "create_index_task.h" +#include "delete_task.h" +#include "insert_task.h" +#include "milvus_local.h" +#include "pb/common.pb.h" +#include "pb/milvus.pb.h" +#include "pb/plan.pb.h" +#include "pb/schema.pb.h" +#include "pb/segcore.pb.h" +#include "query_task.h" +#include "retrieve_result.h" +#include "search_result.h" +#include "search_task.h" +#include "status.h" +#include "string_util.hpp" +#include "type.h" + +namespace milvus::local { + +MilvusProxy::MilvusProxy(const char* work_dir) : milvus_local_(work_dir) { +} + +MilvusProxy::~MilvusProxy() { +} + +bool +MilvusProxy::Init() { + return milvus_local_.Init(); +} + +Status +MilvusProxy::LoadCollection(const std::string& collection_name) { + return milvus_local_.LoadCollection(collection_name); +} + +Status +MilvusProxy::ReleaseCollection(const std::string& collection_name) { + // Alignment error code with milvus + auto s = milvus_local_.ReleaseCollection(collection_name); + if (s.Code() == ErrCollectionNotFound) { + return Status::ParameterInvalid("collection not found[collection={}]", + collection_name); + } + return s; +} + +Status +MilvusProxy::HasCollection(const std::string& collection_name) { + std::string tmp; + return milvus_local_.GetCollection(collection_name, &tmp); +} + +Status +MilvusProxy::CreateCollection( + const ::milvus::proto::milvus::CreateCollectionRequest* r) { + ::milvus::proto::schema::CollectionSchema schema; + + CHECK_STATUS(CreateCollectionTask(r).Process(&schema), ""); + + for (const auto& field : schema.fields()) { + if (field.is_primary_key()) { + return milvus_local_.CreateCollection( + schema.name(), field.name(), schema.SerializeAsString()); + } + } + return Status::FieldNotFound("Lost primary key field"); +} + +Status +MilvusProxy::CreateIndex(const ::milvus::proto::milvus::CreateIndexRequest* r) { + ::milvus::proto::schema::CollectionSchema schema; + if (!GetSchemaInfo(r->collection_name(), &schema).IsOk()) { + auto err = string_util::SFormat("Can not find collection {}", + r->collection_name()); + LOG_ERROR(err); + return Status::CollectionNotFound(err); + } + CHECK_STATUS(milvus_local_.LoadCollection(r->collection_name()), ""); + + // get all index + milvus::proto::segcore::FieldIndexMeta field_meta; + CHECK_STATUS(CreateIndexTask(r, &schema).Process(&field_meta), ""); + return milvus_local_.CreateIndex(r->collection_name(), + field_meta.index_name(), + 
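MilvusLocal::Insert above stops at the first row segcore rejects, persists only the rows accepted so far, and reports their primary keys through the ids output parameter. Assuming Rows holds (primary key, serialized InsertRecord) pairs, which is what the std::get<0>/std::get<1> usage and the comment above suggest (type.h itself is not shown in this hunk), a caller can detect a partial insert by comparing sizes. A rough sketch under that assumption:

#include <cstddef>
#include <string>
#include <vector>

#include "milvus_local.h"  // MilvusLocal, Status; Rows is assumed to be declared in type.h

// Illustrative wrapper: insert a batch and report how many rows were accepted.
milvus::local::Status
InsertBatch(milvus::local::MilvusLocal& db,
            const std::string& collection,
            const milvus::local::Rows& rows,
            std::size_t* accepted) {
    std::vector<std::string> inserted_pks;
    auto s = db.Insert(collection, rows, &inserted_pks);
    *accepted = inserted_pks.size();
    // A partial insert shows up as *accepted < rows.size(); only those rows
    // were written to segcore and to the on-disk storage.
    return s;
}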
field_meta.SerializeAsString()); +} + +Status +MilvusProxy::Insert(const ::milvus::proto::milvus::InsertRequest* r, + ::milvus::proto::schema::IDs* ids) { + ::milvus::proto::schema::CollectionSchema schema; + if (!GetSchemaInfo(r->collection_name(), &schema).IsOk()) { + auto err = string_util::SFormat("Collection {} not found", + r->collection_name()); + return Status::CollectionNotFound(); + } + CHECK_STATUS(milvus_local_.LoadCollection(r->collection_name()), ""); + Rows rows; + auto insert_task = InsertTask( + const_cast<::milvus::proto::milvus::InsertRequest*>(r), &schema); + CHECK_STATUS(insert_task.Process(&rows), ""); + std::vector insert_ids; + milvus_local_.Insert(r->collection_name(), rows, &insert_ids); + + if (insert_task.PkType() == ::milvus::proto::schema::DataType::Int64) { + for (const auto& id : insert_ids) { + ids->mutable_int_id()->add_data(std::stoll(id)); + } + } else { + for (const auto& id : insert_ids) { + ids->mutable_str_id()->add_data(id); + } + } + return Status::Ok(); +} + +Status +MilvusProxy::Search(const ::milvus::proto::milvus::SearchRequest* r, + ::milvus::proto::milvus::SearchResults* search_result) { + ::milvus::proto::schema::CollectionSchema schema; + if (!GetSchemaInfo(r->collection_name(), &schema).IsOk()) { + auto err = string_util::SFormat("Can not find {}'s schema", + r->collection_name()); + LOG_ERROR(err); + return Status::CollectionNotFound(err); + } + CHECK_STATUS(milvus_local_.LoadCollection(r->collection_name()), ""); + + std::string placeholder_group; + ::milvus::proto::plan::PlanNode plan; + std::vector nqs, topks; + + SearchTask task(const_cast<::milvus::proto::milvus::SearchRequest*>(r), + &schema); + CHECK_STATUS(task.Process(&plan, &placeholder_group, &nqs, &topks), ""); + SearchResult result(nqs, topks); + CHECK_STATUS(milvus_local_.Search(r->collection_name(), + plan.SerializeAsString(), + placeholder_group, + &result), + ""); + search_result->set_collection_name(r->collection_name()); + task.PostProcess(result, search_result); + if (search_result->results().has_ids()) { + return Status::Ok(); + } else { + return Status::Ok("search result is empty"); + } +} + +Status +MilvusProxy::Query(const ::milvus::proto::milvus::QueryRequest* r, + ::milvus::proto::milvus::QueryResults* query_result) { + ::milvus::proto::schema::CollectionSchema schema; + if (!GetSchemaInfo(r->collection_name(), &schema).IsOk()) { + return Status::CollectionNotFound(); + } + CHECK_STATUS(milvus_local_.LoadCollection(r->collection_name()), ""); + + ::milvus::proto::plan::PlanNode plan; + QueryTask task(r, &schema); + CHECK_STATUS(task.Process(&plan), ""); + + RetrieveResult result; + CHECK_STATUS(milvus_local_.Retrieve( + r->collection_name(), plan.SerializeAsString(), &result), + ""); + + query_result->set_collection_name(r->collection_name()); + task.PostProcess(result, query_result); + return Status::Ok(); +} + +Status +MilvusProxy::Delete(const ::milvus::proto::milvus::DeleteRequest* r, + ::milvus::proto::milvus::MutationResult* response) { + ::milvus::proto::schema::CollectionSchema schema; + if (!GetSchemaInfo(r->collection_name(), &schema).IsOk()) { + return Status::CollectionNotFound(); + } + CHECK_STATUS(milvus_local_.LoadCollection(r->collection_name()), ""); + + ::milvus::proto::plan::PlanNode plan; + CHECK_STATUS(DeleteTask(r, &schema).Process(&plan), ""); + + RetrieveResult result; + CHECK_STATUS(milvus_local_.Retrieve( + r->collection_name(), plan.SerializeAsString(), &result), + ""); + + ::milvus::proto::segcore::RetrieveResults seg_result; + 
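MilvusProxy::CreateCollection above requires the translated CollectionSchema to carry a field marked is_primary_key, otherwise it fails with "Lost primary key field". For reference, a sketch of a minimal schema proto that satisfies it, built with the generated classes from pb/schema.pb.h; the field ids, names and dimension are illustrative (in the normal flow the schema arrives serialized inside a CreateCollectionRequest from the client):

#include "pb/schema.pb.h"  // generated from src/proto/schema.proto

// Minimal schema: an Int64 primary key plus one float vector field.
::milvus::proto::schema::CollectionSchema MakeDemoSchema() {
    ::milvus::proto::schema::CollectionSchema schema;
    schema.set_name("hello_milvus");

    auto* pk = schema.add_fields();
    pk->set_fieldid(100);  // illustrative field id
    pk->set_name("id");
    pk->set_data_type(::milvus::proto::schema::DataType::Int64);
    pk->set_is_primary_key(true);

    auto* vec = schema.add_fields();
    vec->set_fieldid(101);
    vec->set_name("embeddings");
    vec->set_data_type(::milvus::proto::schema::DataType::FloatVector);
    auto* dim = vec->add_type_params();
    dim->set_key("dim");
    dim->set_value("8");

    return schema;
}
// MilvusLocal::CreateCollection then receives schema.name(), the primary key
// field name, and schema.SerializeAsString(), as the proxy code above does.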
seg_result.ParseFromArray(result.retrieve_result_.proto_blob, + result.retrieve_result_.proto_size); + auto ids_str = seg_result.ids().SerializeAsString(); + std::vector storage_ids; + if (seg_result.ids().has_int_id()) { + for (const auto& id : seg_result.ids().int_id().data()) { + storage_ids.push_back(std::to_string(id)); + response->mutable_ids()->mutable_int_id()->add_data(id); + } + + } else { + for (const auto& id : seg_result.ids().str_id().data()) { + storage_ids.push_back(id); + response->mutable_ids()->mutable_str_id()->add_data(id); + } + } + if (storage_ids.size() != 0) { + CHECK_STATUS( + milvus_local_.DeleteByIds( + r->collection_name(), ids_str, storage_ids.size(), storage_ids), + ""); + } + response->set_delete_cnt(storage_ids.size()); + return Status::Ok(); +} + +Status +MilvusProxy::GetSchemaInfo(const std::string& collection_name, + ::milvus::proto::schema::CollectionSchema* schema) { + std::string schema_proto; + CHECK_STATUS(milvus_local_.GetCollection(collection_name, &schema_proto), + ""); + if (!schema->ParseFromString(schema_proto)) { + LOG_ERROR("Failed to parse schema info: {}", schema_proto); + return Status::ServiceInternal(); + ; + } + return Status::Ok(); +} + +Status +MilvusProxy::DescribeCollection( + const ::milvus::proto::milvus::DescribeCollectionRequest* request, + ::milvus::proto::milvus::DescribeCollectionResponse* response) { + ::milvus::proto::schema::CollectionSchema schema; + CHECK_STATUS(GetSchemaInfo(request->collection_name(), &schema), ""); + auto mutable_schema = response->mutable_schema(); + mutable_schema->set_name(schema.name()); + mutable_schema->set_description(schema.description()); + mutable_schema->set_enable_dynamic_field(schema.enable_dynamic_field()); + mutable_schema->mutable_properties()->CopyFrom(schema.properties()); + for (const auto& field : schema.fields()) { + if (field.name() == kTimeStampFieldName || + field.name() == kRowIdFieldName || field.is_dynamic()) + continue; + mutable_schema->add_fields()->CopyFrom(field); + } + return Status::Ok(); +} + +Status +MilvusProxy::GetIndex( + const std::string& collection_name, + const std::string& index_name, + ::milvus::proto::milvus::DescribeIndexResponse* response) { + ::milvus::proto::schema::CollectionSchema schema; + if (!GetSchemaInfo(collection_name, &schema).IsOk()) { + auto err = + string_util::SFormat("Can not find collection {}", collection_name); + LOG_ERROR(err); + return Status::CollectionNotFound(err); + } + + if (index_name.empty()) { + std::vector all_index; + CHECK_STATUS(milvus_local_.GetAllIndexs(collection_name, &all_index), + ""); + for (const auto& index : all_index) { + CHECK_STATUS( + ParseIndex(index, schema, response->add_index_descriptions()), + ""); + } + return Status::Ok(); + } else { + std::string index_proto; + CHECK_STATUS( + milvus_local_.GetIndex(collection_name, index_name, &index_proto), + ""); + return ParseIndex( + index_proto, schema, response->add_index_descriptions()); + } +} + +Status +MilvusProxy::DropIndex(const std::string& collection_name, + const std::string& index_name) { + CHECK_STATUS(milvus_local_.LoadCollection(collection_name), ""); + return milvus_local_.DropIndex(collection_name, index_name); +} + +Status +MilvusProxy::ParseIndex(const std::string& index_proto, + const ::milvus::proto::schema::CollectionSchema& schema, + ::milvus::proto::milvus::IndexDescription* index) { + milvus::proto::segcore::FieldIndexMeta field_index; + if (!field_index.ParseFromString(index_proto)) { + return Status::ServiceInternal("Error index info 
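Both the Insert and Delete paths above branch on the primary-key type: an Int64 pk travels as ids.int_id(), a VarChar pk as ids.str_id(). A small helper sketch that reads either variant back out of a ::milvus::proto::schema::IDs message, mirroring those branches; the helper name is illustrative:

#include <string>
#include <vector>

#include "pb/schema.pb.h"

// Collect the primary keys as strings, whichever oneof member is set.
std::vector<std::string>
CollectPks(const ::milvus::proto::schema::IDs& ids) {
    std::vector<std::string> pks;
    if (ids.has_int_id()) {
        for (const auto& id : ids.int_id().data()) {
            pks.push_back(std::to_string(id));
        }
    } else {
        for (const auto& id : ids.str_id().data()) {
            pks.push_back(id);
        }
    }
    return pks;
}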
in db"); + } + + auto field_id = field_index.fieldid(); + for (const auto& field : schema.fields()) { + if (field.fieldid() == field_id) { + index->set_field_name(field.name()); + } + } + index->set_index_name(field_index.index_name()); + index->set_indexid(GetIndexId(field_index.index_name())); + index->set_state(::milvus::proto::common::IndexState::Finished); + for (const auto& param : field_index.type_params()) { + auto new_param = index->add_params(); + new_param->set_key(param.key()); + new_param->set_value(param.value()); + } + + for (const auto& param : field_index.index_params()) { + auto new_param = index->add_params(); + new_param->set_key(param.key()); + new_param->set_value(param.value()); + } + + for (const auto& param : field_index.user_index_params()) { + auto new_param = index->add_params(); + new_param->set_key(param.key()); + new_param->set_value(param.value()); + } + return Status::Ok(); +} + +bool +MilvusProxy::DropCollection(const std::string& collection_name) { + return milvus_local_.DropCollection(collection_name); +} + +Status +MilvusProxy::GetCollectionStatistics( + const std::string& collection_name, + ::milvus::proto::milvus::GetCollectionStatisticsResponse* r) { + int64_t num = -1; + CHECK_STATUS(milvus_local_.GetNumRowsOfCollection(collection_name, &num), + ""); + auto s = r->add_stats(); + s->set_key("row_count"); + s->set_value(std::to_string(num)); + return Status::Ok(); +} + +} // namespace milvus::local diff --git a/src/milvus_proxy.h b/src/milvus_proxy.h new file mode 100644 index 0000000..192a3dd --- /dev/null +++ b/src/milvus_proxy.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include "common.h" +#include "milvus_local.h" +#include "pb/schema.pb.h" +#include "pb/milvus.pb.h" +#include "status.h" + +namespace milvus::local { + +class MilvusProxy : NonCopyableNonMovable { + public: + explicit MilvusProxy(const char* work_dir); + virtual ~MilvusProxy(); + + public: + bool + Init(); + + Status + LoadCollection(const std::string& collection_name); + + Status + ReleaseCollection(const std::string& collection_name); + + Status + CreateCollection( + const ::milvus::proto::milvus::CreateCollectionRequest* request); + + Status + HasCollection(const std::string& collection_name); + + bool + DropCollection(const std::string& collection_name); + + void + ListCollection(std::vector* collections) { + milvus_local_.GetAllCollections(collections); + } + + Status + CreateIndex(const ::milvus::proto::milvus::CreateIndexRequest* request); + + Status + GetIndex(const std::string& collection_name, + const std::string& index_name, + ::milvus::proto::milvus::DescribeIndexResponse* response); + + Status + DropIndex(const std::string& collection_name, + const std::string& index_name); + + Status + Insert(const ::milvus::proto::milvus::InsertRequest* request, + ::milvus::proto::schema::IDs* ids); + + Status + Search(const ::milvus::proto::milvus::SearchRequest* request, + ::milvus::proto::milvus::SearchResults* search_result); + + Status + Query(const ::milvus::proto::milvus::QueryRequest* request, + ::milvus::proto::milvus::QueryResults* response); + + Status + Delete(const ::milvus::proto::milvus::DeleteRequest* request, + ::milvus::proto::milvus::MutationResult* response); + + Status + DescribeCollection( + const ::milvus::proto::milvus::DescribeCollectionRequest* request, + ::milvus::proto::milvus::DescribeCollectionResponse* response); + + Status + GetCollectionStatistics( + const std::string& collection_name, + ::milvus::proto::milvus::GetCollectionStatisticsResponse* r); 
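GetCollectionStatistics above returns a single "row_count" entry in the stats list. A caller-side sketch of reading it back; the helper name is illustrative:

#include <cstdint>
#include <string>

#include "pb/milvus.pb.h"

// Returns the reported row count, or -1 if the stat is missing.
int64_t RowCountFromStats(
    const ::milvus::proto::milvus::GetCollectionStatisticsResponse& resp) {
    for (const auto& kv : resp.stats()) {
        if (kv.key() == "row_count") {
            return std::stoll(kv.value());
        }
    }
    return -1;
}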
+ + private: + Status + GetSchemaInfo(const std::string& collection_name, + ::milvus::proto::schema::CollectionSchema* schema); + + Status + ParseIndex(const std::string& index_proto, + const ::milvus::proto::schema::CollectionSchema& schema, + ::milvus::proto::milvus::IndexDescription* index); + + private: + MilvusLocal milvus_local_; +}; + +} // namespace milvus::local diff --git a/src/milvus_service_impl.cpp b/src/milvus_service_impl.cpp new file mode 100644 index 0000000..1444a5a --- /dev/null +++ b/src/milvus_service_impl.cpp @@ -0,0 +1,252 @@ +#include "milvus_service_impl.h" +#include +#include +#include "status.h" + +namespace milvus::local { + +void +Status2Response(Status& s, ::milvus::proto::common::Status* response) { + response->set_code(s.Code()); + response->set_retriable(false); + response->set_detail(s.Detail()); + response->set_reason(s.Detail() + ": " + s.Msg()); +} + +::grpc::Status +MilvusServiceImpl::CreateCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::CreateCollectionRequest* request, + ::milvus::proto::common::Status* response) { + Status s = proxy_.CreateCollection(request); + Status2Response(s, response); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::LoadCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::LoadCollectionRequest* request, + ::milvus::proto::common::Status* response) { + Status s = proxy_.LoadCollection(request->collection_name()); + Status2Response(s, response); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::HasCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::HasCollectionRequest* request, + ::milvus::proto::milvus::BoolResponse* response) { + Status s = proxy_.HasCollection(request->collection_name()); + Status2Response(s, response->mutable_status()); + response->set_value(s.IsOk()); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::CreateIndex( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::CreateIndexRequest* request, + ::milvus::proto::common::Status* response) { + Status s = proxy_.CreateIndex(request); + Status2Response(s, response); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::Insert(::grpc::ServerContext* context, + const ::milvus::proto::milvus::InsertRequest* request, + ::milvus::proto::milvus::MutationResult* response) { + Status s = proxy_.Insert(request, response->mutable_ids()); + Status2Response(s, response->mutable_status()); + auto num_rows = request->num_rows(); + auto succ_size = std::max(response->ids().int_id().data_size(), + response->ids().str_id().data_size()); + response->set_insert_cnt(succ_size); + for (int64_t i = 0; i < succ_size; ++i) { + response->mutable_succ_index()->Add(i); + } + + for (int64_t i = succ_size; i < num_rows; ++i) { + response->mutable_err_index()->Add(i); + } + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::Search(::grpc::ServerContext* context, + const ::milvus::proto::milvus::SearchRequest* request, + ::milvus::proto::milvus::SearchResults* response) { + Status s = proxy_.Search(request, response); + Status2Response(s, response->mutable_status()); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::Query(::grpc::ServerContext* context, + const ::milvus::proto::milvus::QueryRequest* request, + ::milvus::proto::milvus::QueryResults* response) { + Status s = proxy_.Query(request, response); + Status2Response(s, response->mutable_status()); + return 
::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::Delete(::grpc::ServerContext* context, + const ::milvus::proto::milvus::DeleteRequest* request, + ::milvus::proto::milvus::MutationResult* response) { + Status s = proxy_.Delete(request, response); + Status2Response(s, response->mutable_status()); + if (s.IsErr()) { + response->clear_delete_cnt(); + response->clear_ids(); + } + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::DescribeCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::DescribeCollectionRequest* request, + ::milvus::proto::milvus::DescribeCollectionResponse* response) { + Status s = proxy_.DescribeCollection(request, response); + Status2Response(s, response->mutable_status()); + return ::grpc::Status::OK; +} + +/* + * Useless interface, just to make the process run. + */ +::grpc::Status +MilvusServiceImpl::Connect( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::ConnectRequest* request, + ::milvus::proto::milvus::ConnectResponse* response) { + Status s = Status::Ok(); + Status2Response(s, response->mutable_status()); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::DescribeIndex( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::DescribeIndexRequest* request, + ::milvus::proto::milvus::DescribeIndexResponse* response) { + auto s = proxy_.GetIndex( + request->collection_name(), request->index_name(), response); + Status2Response(s, response->mutable_status()); + return ::grpc::Status::OK; +} + +/* + * Useless interface, just to make the process run. + */ +::grpc::Status +MilvusServiceImpl::AllocTimestamp( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::AllocTimestampRequest* request, + ::milvus::proto::milvus::AllocTimestampResponse* response) { + Status s = Status::Ok(); + Status2Response(s, response->mutable_status()); + response->set_timestamp(0); + return ::grpc::Status::OK; +} + +/* + * Useless interface, just to make the process run. 
+ */ +::grpc::Status +MilvusServiceImpl::GetLoadingProgress( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::GetLoadingProgressRequest* request, + ::milvus::proto::milvus::GetLoadingProgressResponse* response) { + response->set_progress(100); + response->set_refresh_progress(100); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::DropCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::DropCollectionRequest* request, + ::milvus::proto::common::Status* response) { + if (proxy_.DropCollection(request->collection_name())) { + Status s = Status::Ok(); + Status2Response(s, response); + } else { + Status s = Status::ServiceInternal(); + Status2Response(s, response); + } + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::ReleaseCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::ReleaseCollectionRequest* request, + ::milvus::proto::common::Status* response) { + auto s = proxy_.ReleaseCollection(request->collection_name()); + Status2Response(s, response); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::ShowCollections( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::ShowCollectionsRequest* request, + ::milvus::proto::milvus::ShowCollectionsResponse* response) { + std::vector collections; + proxy_.ListCollection(&collections); + for (const auto& name : collections) { + response->add_collection_names(name); + } + Status s = Status::Ok(); + Status2Response(s, response->mutable_status()); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::DropIndex( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::DropIndexRequest* request, + ::milvus::proto::common::Status* response) { + auto s = + proxy_.DropIndex(request->collection_name(), request->index_name()); + Status2Response(s, response); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::Flush(::grpc::ServerContext* context, + const ::milvus::proto::milvus::FlushRequest* request, + ::milvus::proto::milvus::FlushResponse* response) { + // do nothing + Status s = Status::Ok(); + Status2Response(s, response->mutable_status()); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::GetFlushState( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::GetFlushStateRequest* request, + ::milvus::proto::milvus::GetFlushStateResponse* response) { + Status s = Status::Ok(); + Status2Response(s, response->mutable_status()); + response->set_flushed(true); + return ::grpc::Status::OK; +} + +::grpc::Status +MilvusServiceImpl::GetCollectionStatistics( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::GetCollectionStatisticsRequest* request, + ::milvus::proto::milvus::GetCollectionStatisticsResponse* response) { + Status s = + proxy_.GetCollectionStatistics(request->collection_name(), response); + Status2Response(s, response->mutable_status()); + return ::grpc::Status::OK; +} + +} // namespace milvus::local diff --git a/src/milvus_service_impl.h b/src/milvus_service_impl.h new file mode 100644 index 0000000..c3c0222 --- /dev/null +++ b/src/milvus_service_impl.h @@ -0,0 +1,139 @@ +#pragma once + +#include +#include "pb/milvus.grpc.pb.h" +#include "milvus_proxy.h" + +namespace milvus::local { +class MilvusServiceImpl final + : public milvus::proto::milvus::MilvusService::Service { + public: + MilvusServiceImpl(const std::string& work_dir) : proxy_(work_dir.c_str()) { + } + virtual ~MilvusServiceImpl() = default; + + public: + 
bool + Init() { + return proxy_.Init(); + } + + public: + ::grpc::Status + CreateCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::CreateCollectionRequest* request, + ::milvus::proto::common::Status* response) override; + + ::grpc::Status + CreateIndex(::grpc::ServerContext* context, + const ::milvus::proto::milvus::CreateIndexRequest* request, + ::milvus::proto::common::Status* response) override; + + ::grpc::Status + LoadCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::LoadCollectionRequest* request, + ::milvus::proto::common::Status* response) override; + + ::grpc::Status + Insert(::grpc::ServerContext* context, + const ::milvus::proto::milvus::InsertRequest* request, + ::milvus::proto::milvus::MutationResult* response) override; + + ::grpc::Status + Search(::grpc::ServerContext* context, + const ::milvus::proto::milvus::SearchRequest* request, + ::milvus::proto::milvus::SearchResults* response) override; + + ::grpc::Status + Query(::grpc::ServerContext* context, + const ::milvus::proto::milvus::QueryRequest* request, + ::milvus::proto::milvus::QueryResults* response) override; + + ::grpc::Status + Delete(::grpc::ServerContext* context, + const ::milvus::proto::milvus::DeleteRequest* request, + ::milvus::proto::milvus::MutationResult* response) override; + + ::grpc::Status + DescribeCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::DescribeCollectionRequest* request, + ::milvus::proto::milvus::DescribeCollectionResponse* response) override; + + ::grpc::Status + Connect(::grpc::ServerContext* context, + const ::milvus::proto::milvus::ConnectRequest* request, + ::milvus::proto::milvus::ConnectResponse* response) override; + + ::grpc::Status + DescribeIndex( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::DescribeIndexRequest* request, + ::milvus::proto::milvus::DescribeIndexResponse* response) override; + + ::grpc::Status + AllocTimestamp( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::AllocTimestampRequest* request, + ::milvus::proto::milvus::AllocTimestampResponse* response) override; + + ::grpc::Status + GetLoadingProgress( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::GetLoadingProgressRequest* request, + ::milvus::proto::milvus::GetLoadingProgressResponse* response) override; + + ::grpc::Status + DropCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::DropCollectionRequest* request, + ::milvus::proto::common::Status* response) override; + + ::grpc::Status + ReleaseCollection( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::ReleaseCollectionRequest* request, + ::milvus::proto::common::Status* response) override; + + ::grpc::Status + ShowCollections( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::ShowCollectionsRequest* request, + ::milvus::proto::milvus::ShowCollectionsResponse* response) override; + + ::grpc::Status + HasCollection(::grpc::ServerContext* context, + const ::milvus::proto::milvus::HasCollectionRequest* request, + ::milvus::proto::milvus::BoolResponse* response) override; + + ::grpc::Status + DropIndex(::grpc::ServerContext* context, + const ::milvus::proto::milvus::DropIndexRequest* request, + ::milvus::proto::common::Status* response) override; + + ::grpc::Status + Flush(::grpc::ServerContext* context, + const ::milvus::proto::milvus::FlushRequest* request, + ::milvus::proto::milvus::FlushResponse* response) override; + + ::grpc::Status + GetFlushState( + 
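MilvusServiceImpl is a plain synchronous gRPC service: construct it with the work directory, call Init(), and host it with the usual grpc::ServerBuilder. The patch's actual entry point is src/server.cpp (not shown in this hunk), so the following is only a rough sketch of that wiring; the listen address and the work directory are illustrative:

#include <memory>

#include <grpcpp/grpcpp.h>

#include "milvus_service_impl.h"

int main() {
    milvus::local::MilvusServiceImpl service("/tmp/milvus_lite");  // work dir (illustrative)
    if (!service.Init()) {
        return 1;  // opening the local storage failed
    }

    grpc::ServerBuilder builder;
    builder.AddListeningPort("127.0.0.1:19530",
                             grpc::InsecureServerCredentials());
    builder.RegisterService(&service);

    std::unique_ptr<grpc::Server> server = builder.BuildAndStart();
    if (!server) {
        return 1;
    }
    server->Wait();  // block and serve until the process is stopped
    return 0;
}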
::grpc::ServerContext* context, + const ::milvus::proto::milvus::GetFlushStateRequest* request, + ::milvus::proto::milvus::GetFlushStateResponse* response) override; + + // for now only row count is returned + ::grpc::Status + GetCollectionStatistics( + ::grpc::ServerContext* context, + const ::milvus::proto::milvus::GetCollectionStatisticsRequest* request, + ::milvus::proto::milvus::GetCollectionStatisticsResponse* response) + override; + + private: + ::milvus::local::MilvusProxy proxy_; +}; + +} // namespace milvus::local diff --git a/src/parser/Plan.g4 b/src/parser/Plan.g4 new file mode 100644 index 0000000..cc8a479 --- /dev/null +++ b/src/parser/Plan.g4 @@ -0,0 +1,151 @@ +grammar Plan; + +expr: + IntegerConstant # Integer + | FloatingConstant # Floating + | BooleanConstant # Boolean + | StringLiteral # String + | Identifier # Identifier + | JSONIdentifier # JSONIdentifier + | '(' expr ')' # Parens + | '[' expr (',' expr)* ','? ']' # Array + | expr LIKE StringLiteral # Like + | expr POW expr # Power + | op = (ADD | SUB | BNOT | NOT) expr # Unary +// | '(' typeName ')' expr # Cast + | expr op = (MUL | DIV | MOD) expr # MulDivMod + | expr op = (ADD | SUB) expr # AddSub + | expr op = (SHL | SHR) expr # Shift + | expr op = (IN | NIN) ('[' expr (',' expr)* ','? ']') # Term + | expr op = (IN | NIN) EmptyTerm # EmptyTerm + | (JSONContains | ArrayContains)'('expr',' expr')' # JSONContains + | (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll + | (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny + | ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength + | expr op1 = (LT | LE) (Identifier | JSONIdentifier) op2 = (LT | LE) expr # Range + | expr op1 = (GT | GE) (Identifier | JSONIdentifier) op2 = (GT | GE) expr # ReverseRange + | expr op = (LT | LE | GT | GE) expr # Relational + | expr op = (EQ | NE) expr # Equality + | expr BAND expr # BitAnd + | expr BXOR expr # BitXor + | expr BOR expr # BitOr + | expr AND expr # LogicalAnd + | expr OR expr # LogicalOr + | EXISTS expr # Exists; + +// typeName: ty = (BOOL | INT8 | INT16 | INT32 | INT64 | FLOAT | DOUBLE); + +// BOOL: 'bool'; +// INT8: 'int8'; +// INT16: 'int16'; +// INT32: 'int32'; +// INT64: 'int64'; +// FLOAT: 'float'; +// DOUBLE: 'double'; + +LT: '<'; +LE: '<='; +GT: '>'; +GE: '>='; +EQ: '=='; +NE: '!='; + +LIKE: 'like' | 'LIKE'; +EXISTS: 'exists' | 'EXISTS'; + +ADD: '+'; +SUB: '-'; +MUL: '*'; +DIV: '/'; +MOD: '%'; +POW: '**'; +SHL: '<<'; +SHR: '>>'; +BAND: '&'; +BOR: '|'; +BXOR: '^'; + +AND: '&&' | 'and'; +OR: '||' | 'or'; + +BNOT: '~'; +NOT: '!' | 'not'; + +IN: 'in'; +NIN: 'not in'; +EmptyTerm: '[' (Whitespace | Newline)* ']'; + +JSONContains: 'json_contains' | 'JSON_CONTAINS'; +JSONContainsAll: 'json_contains_all' | 'JSON_CONTAINS_ALL'; +JSONContainsAny: 'json_contains_any' | 'JSON_CONTAINS_ANY'; + +ArrayContains: 'array_contains' | 'ARRAY_CONTAINS'; +ArrayContainsAll: 'array_contains_all' | 'ARRAY_CONTAINS_ALL'; +ArrayContainsAny: 'array_contains_any' | 'ARRAY_CONTAINS_ANY'; +ArrayLength: 'array_length' | 'ARRAY_LENGTH'; + +BooleanConstant: 'true' | 'True' | 'TRUE' | 'false' | 'False' | 'FALSE'; + +IntegerConstant: + DecimalConstant + | OctalConstant + | HexadecimalConstant + | BinaryConstant; + +FloatingConstant: + DecimalFloatingConstant + | HexadecimalFloatingConstant; + +Identifier: Nondigit (Nondigit | Digit)* | '$meta'; + +StringLiteral: EncodingPrefix? ('"' DoubleSCharSequence? '"' | '\'' SingleSCharSequence? 
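Plan.g4 above is the filter-expression grammar: term lists via in / not in, chained range comparisons, LIKE on string literals, the json_contains / array_contains family, array_length, and the usual boolean and arithmetic operators. A few expression strings the grammar accepts, collected here as plain C++ literals for reference; the field names are illustrative:

#include <string>
#include <vector>

// Sample filter expressions covered by the rules above.
const std::vector<std::string> kSampleExprs = {
    "id in [1, 2, 3]",                   // Term
    "10 < price <= 99.5",                // Range (op1 / op2 chained comparison)
    "name like \"prefix%\"",             // Like
    "json_contains(meta[\"tags\"], 1)",  // JSONContains on a JSONIdentifier
    "array_length(scores) >= 3",         // ArrayLength + Relational
    "(a + b) % 2 == 0 && not deleted",   // Parens, MulDivMod, Equality, LogicalAnd, Unary
};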
'\''); +JSONIdentifier: Identifier('[' (StringLiteral | DecimalConstant) ']')+; + +fragment EncodingPrefix: 'u8' | 'u' | 'U' | 'L'; + +fragment DoubleSCharSequence: DoubleSChar+; +fragment SingleSCharSequence: SingleSChar+; + +fragment DoubleSChar: ~["\\\r\n] | EscapeSequence | '\\\n' | '\\\r\n'; +fragment SingleSChar: ~['\\\r\n] | EscapeSequence | '\\\n' | '\\\r\n'; +fragment Nondigit: [a-zA-Z_]; +fragment Digit: [0-9]; +fragment BinaryConstant: '0' [bB] [0-1]+; +fragment DecimalConstant: NonzeroDigit Digit* | '0'; +fragment OctalConstant: '0' OctalDigit*; +fragment HexadecimalConstant: '0' [xX] HexadecimalDigitSequence; +fragment NonzeroDigit: [1-9]; +fragment OctalDigit: [0-7]; +fragment HexadecimalDigit: [0-9a-fA-F]; +fragment HexQuad: + HexadecimalDigit HexadecimalDigit HexadecimalDigit HexadecimalDigit; +fragment UniversalCharacterName: + '\\u' HexQuad + | '\\U' HexQuad HexQuad; +fragment DecimalFloatingConstant: + FractionalConstant ExponentPart? + | DigitSequence ExponentPart; +fragment HexadecimalFloatingConstant: + '0' [xX] ( + HexadecimalFractionalConstant + | HexadecimalDigitSequence + ) BinaryExponentPart; +fragment FractionalConstant: + DigitSequence? '.' DigitSequence + | DigitSequence '.'; +fragment ExponentPart: [eE] [+-]? DigitSequence; +fragment DigitSequence: Digit+; +fragment HexadecimalFractionalConstant: + HexadecimalDigitSequence? '.' HexadecimalDigitSequence + | HexadecimalDigitSequence '.'; +fragment HexadecimalDigitSequence: HexadecimalDigit+; +fragment BinaryExponentPart: [pP] [+-]? DigitSequence; +fragment EscapeSequence: + '\\' ['"?abfnrtv\\] + | '\\' OctalDigit OctalDigit? OctalDigit? + | '\\x' HexadecimalDigitSequence + | UniversalCharacterName; + +Whitespace: [ \t]+ -> skip; + +Newline: ( '\r' '\n'? | '\n') -> skip; diff --git a/src/parser/antlr/PlanBaseVisitor.cpp b/src/parser/antlr/PlanBaseVisitor.cpp new file mode 100644 index 0000000..ad98630 --- /dev/null +++ b/src/parser/antlr/PlanBaseVisitor.cpp @@ -0,0 +1,7 @@ + +// Generated from Plan.g4 by ANTLR 4.13.1 + + +#include "PlanBaseVisitor.h" + + diff --git a/src/parser/antlr/PlanBaseVisitor.h b/src/parser/antlr/PlanBaseVisitor.h new file mode 100644 index 0000000..d4503d3 --- /dev/null +++ b/src/parser/antlr/PlanBaseVisitor.h @@ -0,0 +1,140 @@ + +// Generated from Plan.g4 by ANTLR 4.13.1 + +#pragma once + + +#include "antlr4-runtime.h" +#include "PlanVisitor.h" + + +/** + * This class provides an empty implementation of PlanVisitor, which can be + * extended to create a visitor which only needs to handle a subset of the available methods. 
+ */ +class PlanBaseVisitor : public PlanVisitor { +public: + + virtual std::any visitJSONIdentifier(PlanParser::JSONIdentifierContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitParens(PlanParser::ParensContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitString(PlanParser::StringContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitFloating(PlanParser::FloatingContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitJSONContainsAll(PlanParser::JSONContainsAllContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitLogicalOr(PlanParser::LogicalOrContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitMulDivMod(PlanParser::MulDivModContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitIdentifier(PlanParser::IdentifierContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitLike(PlanParser::LikeContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitLogicalAnd(PlanParser::LogicalAndContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitEquality(PlanParser::EqualityContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitBoolean(PlanParser::BooleanContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitShift(PlanParser::ShiftContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitReverseRange(PlanParser::ReverseRangeContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitBitOr(PlanParser::BitOrContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitAddSub(PlanParser::AddSubContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitRelational(PlanParser::RelationalContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitArrayLength(PlanParser::ArrayLengthContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitTerm(PlanParser::TermContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitJSONContains(PlanParser::JSONContainsContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitRange(PlanParser::RangeContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitUnary(PlanParser::UnaryContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitInteger(PlanParser::IntegerContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitArray(PlanParser::ArrayContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitJSONContainsAny(PlanParser::JSONContainsAnyContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitBitXor(PlanParser::BitXorContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitExists(PlanParser::ExistsContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitBitAnd(PlanParser::BitAndContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitEmptyTerm(PlanParser::EmptyTermContext *ctx) override { + return visitChildren(ctx); + } + + virtual std::any visitPower(PlanParser::PowerContext *ctx) override { + return visitChildren(ctx); + } + + +}; + diff --git a/src/parser/antlr/PlanLexer.cpp b/src/parser/antlr/PlanLexer.cpp new file mode 
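PlanBaseVisitor above is the generated default visitor: every visitXxx simply forwards to visitChildren, so handwritten code only needs to override the alternatives it cares about (the patch's real visitor lives in src/parser/utils.h and parser.h). A minimal sketch of that pattern; the AndCounter class is illustrative and assumes the generated headers are on the include path:

#include <any>

#include "PlanBaseVisitor.h"

// Counts logical-AND nodes in a parsed filter expression; every other rule
// keeps the default visitChildren behaviour inherited from PlanBaseVisitor.
class AndCounter : public PlanBaseVisitor {
 public:
    std::any visitLogicalAnd(PlanParser::LogicalAndContext* ctx) override {
        ++and_nodes_;
        return visitChildren(ctx);
    }

    int Count() const { return and_nodes_; }

 private:
    int and_nodes_ = 0;
};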
100644 index 0000000..798c211 --- /dev/null +++ b/src/parser/antlr/PlanLexer.cpp @@ -0,0 +1,417 @@ + +// Generated from Plan.g4 by ANTLR 4.13.1 + + +#include "PlanLexer.h" + + +using namespace antlr4; + + + +using namespace antlr4; + +namespace { + +struct PlanLexerStaticData final { + PlanLexerStaticData(std::vector ruleNames, + std::vector channelNames, + std::vector modeNames, + std::vector literalNames, + std::vector symbolicNames) + : ruleNames(std::move(ruleNames)), channelNames(std::move(channelNames)), + modeNames(std::move(modeNames)), literalNames(std::move(literalNames)), + symbolicNames(std::move(symbolicNames)), + vocabulary(this->literalNames, this->symbolicNames) {} + + PlanLexerStaticData(const PlanLexerStaticData&) = delete; + PlanLexerStaticData(PlanLexerStaticData&&) = delete; + PlanLexerStaticData& operator=(const PlanLexerStaticData&) = delete; + PlanLexerStaticData& operator=(PlanLexerStaticData&&) = delete; + + std::vector decisionToDFA; + antlr4::atn::PredictionContextCache sharedContextCache; + const std::vector ruleNames; + const std::vector channelNames; + const std::vector modeNames; + const std::vector literalNames; + const std::vector symbolicNames; + const antlr4::dfa::Vocabulary vocabulary; + antlr4::atn::SerializedATNView serializedATN; + std::unique_ptr atn; +}; + +::antlr4::internal::OnceFlag planlexerLexerOnceFlag; +#if ANTLR4_USE_THREAD_LOCAL_CACHE +static thread_local +#endif +PlanLexerStaticData *planlexerLexerStaticData = nullptr; + +void planlexerLexerInitialize() { +#if ANTLR4_USE_THREAD_LOCAL_CACHE + if (planlexerLexerStaticData != nullptr) { + return; + } +#else + assert(planlexerLexerStaticData == nullptr); +#endif + auto staticData = std::make_unique( + std::vector{ + "T__0", "T__1", "T__2", "T__3", "T__4", "LT", "LE", "GT", "GE", "EQ", + "NE", "LIKE", "EXISTS", "ADD", "SUB", "MUL", "DIV", "MOD", "POW", + "SHL", "SHR", "BAND", "BOR", "BXOR", "AND", "OR", "BNOT", "NOT", "IN", + "NIN", "EmptyTerm", "JSONContains", "JSONContainsAll", "JSONContainsAny", + "ArrayContains", "ArrayContainsAll", "ArrayContainsAny", "ArrayLength", + "BooleanConstant", "IntegerConstant", "FloatingConstant", "Identifier", + "StringLiteral", "JSONIdentifier", "EncodingPrefix", "DoubleSCharSequence", + "SingleSCharSequence", "DoubleSChar", "SingleSChar", "Nondigit", "Digit", + "BinaryConstant", "DecimalConstant", "OctalConstant", "HexadecimalConstant", + "NonzeroDigit", "OctalDigit", "HexadecimalDigit", "HexQuad", "UniversalCharacterName", + "DecimalFloatingConstant", "HexadecimalFloatingConstant", "FractionalConstant", + "ExponentPart", "DigitSequence", "HexadecimalFractionalConstant", + "HexadecimalDigitSequence", "BinaryExponentPart", "EscapeSequence", + "Whitespace", "Newline" + }, + std::vector{ + "DEFAULT_TOKEN_CHANNEL", "HIDDEN" + }, + std::vector{ + "DEFAULT_MODE" + }, + std::vector{ + "", "'('", "')'", "'['", "','", "']'", "'<'", "'<='", "'>'", "'>='", + "'=='", "'!='", "", "", "'+'", "'-'", "'*'", "'/'", "'%'", "'**'", + "'<<'", "'>>'", "'&'", "'|'", "'^'", "", "", "'~'", "", "'in'", "'not in'" + }, + std::vector{ + "", "", "", "", "", "", "LT", "LE", "GT", "GE", "EQ", "NE", "LIKE", + "EXISTS", "ADD", "SUB", "MUL", "DIV", "MOD", "POW", "SHL", "SHR", + "BAND", "BOR", "BXOR", "AND", "OR", "BNOT", "NOT", "IN", "NIN", "EmptyTerm", + "JSONContains", "JSONContainsAll", "JSONContainsAny", "ArrayContains", + "ArrayContainsAll", "ArrayContainsAny", "ArrayLength", "BooleanConstant", + "IntegerConstant", "FloatingConstant", "Identifier", "StringLiteral", + "JSONIdentifier", 
"Whitespace", "Newline" + } + ); + static const int32_t serializedATNSegment[] = { + 4,0,46,752,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13,2,14, + 7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20,7,20,2,21, + 7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26,2,27,7,27,2,28, + 7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7,32,2,33,7,33,2,34,7,34,2,35, + 7,35,2,36,7,36,2,37,7,37,2,38,7,38,2,39,7,39,2,40,7,40,2,41,7,41,2,42, + 7,42,2,43,7,43,2,44,7,44,2,45,7,45,2,46,7,46,2,47,7,47,2,48,7,48,2,49, + 7,49,2,50,7,50,2,51,7,51,2,52,7,52,2,53,7,53,2,54,7,54,2,55,7,55,2,56, + 7,56,2,57,7,57,2,58,7,58,2,59,7,59,2,60,7,60,2,61,7,61,2,62,7,62,2,63, + 7,63,2,64,7,64,2,65,7,65,2,66,7,66,2,67,7,67,2,68,7,68,2,69,7,69,2,70, + 7,70,1,0,1,0,1,1,1,1,1,2,1,2,1,3,1,3,1,4,1,4,1,5,1,5,1,6,1,6,1,6,1,7, + 1,7,1,8,1,8,1,8,1,9,1,9,1,9,1,10,1,10,1,10,1,11,1,11,1,11,1,11,1,11,1, + 11,1,11,1,11,3,11,178,8,11,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1, + 12,1,12,1,12,1,12,3,12,192,8,12,1,13,1,13,1,14,1,14,1,15,1,15,1,16,1, + 16,1,17,1,17,1,18,1,18,1,18,1,19,1,19,1,19,1,20,1,20,1,20,1,21,1,21,1, + 22,1,22,1,23,1,23,1,24,1,24,1,24,1,24,1,24,3,24,224,8,24,1,25,1,25,1, + 25,1,25,3,25,230,8,25,1,26,1,26,1,27,1,27,1,27,1,27,3,27,238,8,27,1,28, + 1,28,1,28,1,29,1,29,1,29,1,29,1,29,1,29,1,29,1,30,1,30,1,30,5,30,253, + 8,30,10,30,12,30,256,9,30,1,30,1,30,1,31,1,31,1,31,1,31,1,31,1,31,1,31, + 1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,31, + 1,31,1,31,1,31,1,31,1,31,3,31,286,8,31,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 3,32,322,8,32,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33, + 1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33, + 1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,3,33,358,8,33,1,34,1,34, + 1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34, + 1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,1,34,3,34,388, + 8,34,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35, + 1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35, + 1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,1,35,3,35,426,8,35,1,36,1,36, + 1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36, + 1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36, + 1,36,1,36,1,36,1,36,1,36,1,36,3,36,464,8,36,1,37,1,37,1,37,1,37,1,37, + 1,37,1,37,1,37,1,37,1,37,1,37,1,37,1,37,1,37,1,37,1,37,1,37,1,37,1,37, + 1,37,1,37,1,37,1,37,1,37,3,37,490,8,37,1,38,1,38,1,38,1,38,1,38,1,38, + 1,38,1,38,1,38,1,38,1,38,1,38,1,38,1,38,1,38,1,38,1,38,1,38,1,38,1,38, + 1,38,1,38,1,38,1,38,1,38,1,38,1,38,3,38,519,8,38,1,39,1,39,1,39,1,39, + 3,39,525,8,39,1,40,1,40,3,40,529,8,40,1,41,1,41,1,41,5,41,534,8,41,10, + 41,12,41,537,9,41,1,41,1,41,1,41,1,41,1,41,3,41,544,8,41,1,42,3,42,547, + 8,42,1,42,1,42,3,42,551,8,42,1,42,1,42,1,42,3,42,556,8,42,1,42,3,42,559, + 8,42,1,43,1,43,1,43,1,43,3,43,565,8,43,1,43,1,43,4,43,569,8,43,11,43, + 12,43,570,1,44,1,44,1,44,3,44,576,8,44,1,45,4,45,579,8,45,11,45,12,45, + 580,1,46,4,46,584,8,46,11,46,12,46,585,1,47,1,47,1,47,1,47,1,47,1,47, + 1,47,3,47,595,8,47,1,48,1,48,1,48,1,48,1,48,1,48,1,48,3,48,604,8,48,1, + 49,1,49,1,50,1,50,1,51,1,51,1,51,4,51,613,8,51,11,51,12,51,614,1,52,1, + 52,5,52,619,8,52,10,52,12,52,622,9,52,1,52,3,52,625,8,52,1,53,1,53,5, + 
53,629,8,53,10,53,12,53,632,9,53,1,54,1,54,1,54,1,54,1,55,1,55,1,56,1, + 56,1,57,1,57,1,58,1,58,1,58,1,58,1,58,1,59,1,59,1,59,1,59,1,59,1,59,1, + 59,1,59,1,59,1,59,3,59,659,8,59,1,60,1,60,3,60,663,8,60,1,60,1,60,1,60, + 3,60,668,8,60,1,61,1,61,1,61,1,61,3,61,674,8,61,1,61,1,61,1,62,3,62,679, + 8,62,1,62,1,62,1,62,1,62,1,62,3,62,686,8,62,1,63,1,63,3,63,690,8,63,1, + 63,1,63,1,64,4,64,695,8,64,11,64,12,64,696,1,65,3,65,700,8,65,1,65,1, + 65,1,65,1,65,1,65,3,65,707,8,65,1,66,4,66,710,8,66,11,66,12,66,711,1, + 67,1,67,3,67,716,8,67,1,67,1,67,1,68,1,68,1,68,1,68,1,68,3,68,725,8,68, + 1,68,3,68,728,8,68,1,68,1,68,1,68,1,68,1,68,3,68,735,8,68,1,69,4,69,738, + 8,69,11,69,12,69,739,1,69,1,69,1,70,1,70,3,70,746,8,70,1,70,3,70,749, + 8,70,1,70,1,70,0,0,71,1,1,3,2,5,3,7,4,9,5,11,6,13,7,15,8,17,9,19,10,21, + 11,23,12,25,13,27,14,29,15,31,16,33,17,35,18,37,19,39,20,41,21,43,22, + 45,23,47,24,49,25,51,26,53,27,55,28,57,29,59,30,61,31,63,32,65,33,67, + 34,69,35,71,36,73,37,75,38,77,39,79,40,81,41,83,42,85,43,87,44,89,0,91, + 0,93,0,95,0,97,0,99,0,101,0,103,0,105,0,107,0,109,0,111,0,113,0,115,0, + 117,0,119,0,121,0,123,0,125,0,127,0,129,0,131,0,133,0,135,0,137,0,139, + 45,141,46,1,0,16,3,0,76,76,85,85,117,117,4,0,10,10,13,13,34,34,92,92, + 4,0,10,10,13,13,39,39,92,92,3,0,65,90,95,95,97,122,1,0,48,57,2,0,66,66, + 98,98,1,0,48,49,2,0,88,88,120,120,1,0,49,57,1,0,48,55,3,0,48,57,65,70, + 97,102,2,0,69,69,101,101,2,0,43,43,45,45,2,0,80,80,112,112,10,0,34,34, + 39,39,63,63,92,92,97,98,102,102,110,110,114,114,116,116,118,118,2,0,9, + 9,32,32,791,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0,0,7,1,0,0,0,0,9,1,0,0, + 0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0,17,1,0,0,0,0,19,1,0,0,0,0, + 21,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0,0,0,27,1,0,0,0,0,29,1,0,0,0,0,31,1, + 0,0,0,0,33,1,0,0,0,0,35,1,0,0,0,0,37,1,0,0,0,0,39,1,0,0,0,0,41,1,0,0, + 0,0,43,1,0,0,0,0,45,1,0,0,0,0,47,1,0,0,0,0,49,1,0,0,0,0,51,1,0,0,0,0, + 53,1,0,0,0,0,55,1,0,0,0,0,57,1,0,0,0,0,59,1,0,0,0,0,61,1,0,0,0,0,63,1, + 0,0,0,0,65,1,0,0,0,0,67,1,0,0,0,0,69,1,0,0,0,0,71,1,0,0,0,0,73,1,0,0, + 0,0,75,1,0,0,0,0,77,1,0,0,0,0,79,1,0,0,0,0,81,1,0,0,0,0,83,1,0,0,0,0, + 85,1,0,0,0,0,87,1,0,0,0,0,139,1,0,0,0,0,141,1,0,0,0,1,143,1,0,0,0,3,145, + 1,0,0,0,5,147,1,0,0,0,7,149,1,0,0,0,9,151,1,0,0,0,11,153,1,0,0,0,13,155, + 1,0,0,0,15,158,1,0,0,0,17,160,1,0,0,0,19,163,1,0,0,0,21,166,1,0,0,0,23, + 177,1,0,0,0,25,191,1,0,0,0,27,193,1,0,0,0,29,195,1,0,0,0,31,197,1,0,0, + 0,33,199,1,0,0,0,35,201,1,0,0,0,37,203,1,0,0,0,39,206,1,0,0,0,41,209, + 1,0,0,0,43,212,1,0,0,0,45,214,1,0,0,0,47,216,1,0,0,0,49,223,1,0,0,0,51, + 229,1,0,0,0,53,231,1,0,0,0,55,237,1,0,0,0,57,239,1,0,0,0,59,242,1,0,0, + 0,61,249,1,0,0,0,63,285,1,0,0,0,65,321,1,0,0,0,67,357,1,0,0,0,69,387, + 1,0,0,0,71,425,1,0,0,0,73,463,1,0,0,0,75,489,1,0,0,0,77,518,1,0,0,0,79, + 524,1,0,0,0,81,528,1,0,0,0,83,543,1,0,0,0,85,546,1,0,0,0,87,560,1,0,0, + 0,89,575,1,0,0,0,91,578,1,0,0,0,93,583,1,0,0,0,95,594,1,0,0,0,97,603, + 1,0,0,0,99,605,1,0,0,0,101,607,1,0,0,0,103,609,1,0,0,0,105,624,1,0,0, + 0,107,626,1,0,0,0,109,633,1,0,0,0,111,637,1,0,0,0,113,639,1,0,0,0,115, + 641,1,0,0,0,117,643,1,0,0,0,119,658,1,0,0,0,121,667,1,0,0,0,123,669,1, + 0,0,0,125,685,1,0,0,0,127,687,1,0,0,0,129,694,1,0,0,0,131,706,1,0,0,0, + 133,709,1,0,0,0,135,713,1,0,0,0,137,734,1,0,0,0,139,737,1,0,0,0,141,748, + 1,0,0,0,143,144,5,40,0,0,144,2,1,0,0,0,145,146,5,41,0,0,146,4,1,0,0,0, + 147,148,5,91,0,0,148,6,1,0,0,0,149,150,5,44,0,0,150,8,1,0,0,0,151,152, + 5,93,0,0,152,10,1,0,0,0,153,154,5,60,0,0,154,12,1,0,0,0,155,156,5,60, + 
0,0,156,157,5,61,0,0,157,14,1,0,0,0,158,159,5,62,0,0,159,16,1,0,0,0,160, + 161,5,62,0,0,161,162,5,61,0,0,162,18,1,0,0,0,163,164,5,61,0,0,164,165, + 5,61,0,0,165,20,1,0,0,0,166,167,5,33,0,0,167,168,5,61,0,0,168,22,1,0, + 0,0,169,170,5,108,0,0,170,171,5,105,0,0,171,172,5,107,0,0,172,178,5,101, + 0,0,173,174,5,76,0,0,174,175,5,73,0,0,175,176,5,75,0,0,176,178,5,69,0, + 0,177,169,1,0,0,0,177,173,1,0,0,0,178,24,1,0,0,0,179,180,5,101,0,0,180, + 181,5,120,0,0,181,182,5,105,0,0,182,183,5,115,0,0,183,184,5,116,0,0,184, + 192,5,115,0,0,185,186,5,69,0,0,186,187,5,88,0,0,187,188,5,73,0,0,188, + 189,5,83,0,0,189,190,5,84,0,0,190,192,5,83,0,0,191,179,1,0,0,0,191,185, + 1,0,0,0,192,26,1,0,0,0,193,194,5,43,0,0,194,28,1,0,0,0,195,196,5,45,0, + 0,196,30,1,0,0,0,197,198,5,42,0,0,198,32,1,0,0,0,199,200,5,47,0,0,200, + 34,1,0,0,0,201,202,5,37,0,0,202,36,1,0,0,0,203,204,5,42,0,0,204,205,5, + 42,0,0,205,38,1,0,0,0,206,207,5,60,0,0,207,208,5,60,0,0,208,40,1,0,0, + 0,209,210,5,62,0,0,210,211,5,62,0,0,211,42,1,0,0,0,212,213,5,38,0,0,213, + 44,1,0,0,0,214,215,5,124,0,0,215,46,1,0,0,0,216,217,5,94,0,0,217,48,1, + 0,0,0,218,219,5,38,0,0,219,224,5,38,0,0,220,221,5,97,0,0,221,222,5,110, + 0,0,222,224,5,100,0,0,223,218,1,0,0,0,223,220,1,0,0,0,224,50,1,0,0,0, + 225,226,5,124,0,0,226,230,5,124,0,0,227,228,5,111,0,0,228,230,5,114,0, + 0,229,225,1,0,0,0,229,227,1,0,0,0,230,52,1,0,0,0,231,232,5,126,0,0,232, + 54,1,0,0,0,233,238,5,33,0,0,234,235,5,110,0,0,235,236,5,111,0,0,236,238, + 5,116,0,0,237,233,1,0,0,0,237,234,1,0,0,0,238,56,1,0,0,0,239,240,5,105, + 0,0,240,241,5,110,0,0,241,58,1,0,0,0,242,243,5,110,0,0,243,244,5,111, + 0,0,244,245,5,116,0,0,245,246,5,32,0,0,246,247,5,105,0,0,247,248,5,110, + 0,0,248,60,1,0,0,0,249,254,5,91,0,0,250,253,3,139,69,0,251,253,3,141, + 70,0,252,250,1,0,0,0,252,251,1,0,0,0,253,256,1,0,0,0,254,252,1,0,0,0, + 254,255,1,0,0,0,255,257,1,0,0,0,256,254,1,0,0,0,257,258,5,93,0,0,258, + 62,1,0,0,0,259,260,5,106,0,0,260,261,5,115,0,0,261,262,5,111,0,0,262, + 263,5,110,0,0,263,264,5,95,0,0,264,265,5,99,0,0,265,266,5,111,0,0,266, + 267,5,110,0,0,267,268,5,116,0,0,268,269,5,97,0,0,269,270,5,105,0,0,270, + 271,5,110,0,0,271,286,5,115,0,0,272,273,5,74,0,0,273,274,5,83,0,0,274, + 275,5,79,0,0,275,276,5,78,0,0,276,277,5,95,0,0,277,278,5,67,0,0,278,279, + 5,79,0,0,279,280,5,78,0,0,280,281,5,84,0,0,281,282,5,65,0,0,282,283,5, + 73,0,0,283,284,5,78,0,0,284,286,5,83,0,0,285,259,1,0,0,0,285,272,1,0, + 0,0,286,64,1,0,0,0,287,288,5,106,0,0,288,289,5,115,0,0,289,290,5,111, + 0,0,290,291,5,110,0,0,291,292,5,95,0,0,292,293,5,99,0,0,293,294,5,111, + 0,0,294,295,5,110,0,0,295,296,5,116,0,0,296,297,5,97,0,0,297,298,5,105, + 0,0,298,299,5,110,0,0,299,300,5,115,0,0,300,301,5,95,0,0,301,302,5,97, + 0,0,302,303,5,108,0,0,303,322,5,108,0,0,304,305,5,74,0,0,305,306,5,83, + 0,0,306,307,5,79,0,0,307,308,5,78,0,0,308,309,5,95,0,0,309,310,5,67,0, + 0,310,311,5,79,0,0,311,312,5,78,0,0,312,313,5,84,0,0,313,314,5,65,0,0, + 314,315,5,73,0,0,315,316,5,78,0,0,316,317,5,83,0,0,317,318,5,95,0,0,318, + 319,5,65,0,0,319,320,5,76,0,0,320,322,5,76,0,0,321,287,1,0,0,0,321,304, + 1,0,0,0,322,66,1,0,0,0,323,324,5,106,0,0,324,325,5,115,0,0,325,326,5, + 111,0,0,326,327,5,110,0,0,327,328,5,95,0,0,328,329,5,99,0,0,329,330,5, + 111,0,0,330,331,5,110,0,0,331,332,5,116,0,0,332,333,5,97,0,0,333,334, + 5,105,0,0,334,335,5,110,0,0,335,336,5,115,0,0,336,337,5,95,0,0,337,338, + 5,97,0,0,338,339,5,110,0,0,339,358,5,121,0,0,340,341,5,74,0,0,341,342, + 5,83,0,0,342,343,5,79,0,0,343,344,5,78,0,0,344,345,5,95,0,0,345,346,5, + 
67,0,0,346,347,5,79,0,0,347,348,5,78,0,0,348,349,5,84,0,0,349,350,5,65, + 0,0,350,351,5,73,0,0,351,352,5,78,0,0,352,353,5,83,0,0,353,354,5,95,0, + 0,354,355,5,65,0,0,355,356,5,78,0,0,356,358,5,89,0,0,357,323,1,0,0,0, + 357,340,1,0,0,0,358,68,1,0,0,0,359,360,5,97,0,0,360,361,5,114,0,0,361, + 362,5,114,0,0,362,363,5,97,0,0,363,364,5,121,0,0,364,365,5,95,0,0,365, + 366,5,99,0,0,366,367,5,111,0,0,367,368,5,110,0,0,368,369,5,116,0,0,369, + 370,5,97,0,0,370,371,5,105,0,0,371,372,5,110,0,0,372,388,5,115,0,0,373, + 374,5,65,0,0,374,375,5,82,0,0,375,376,5,82,0,0,376,377,5,65,0,0,377,378, + 5,89,0,0,378,379,5,95,0,0,379,380,5,67,0,0,380,381,5,79,0,0,381,382,5, + 78,0,0,382,383,5,84,0,0,383,384,5,65,0,0,384,385,5,73,0,0,385,386,5,78, + 0,0,386,388,5,83,0,0,387,359,1,0,0,0,387,373,1,0,0,0,388,70,1,0,0,0,389, + 390,5,97,0,0,390,391,5,114,0,0,391,392,5,114,0,0,392,393,5,97,0,0,393, + 394,5,121,0,0,394,395,5,95,0,0,395,396,5,99,0,0,396,397,5,111,0,0,397, + 398,5,110,0,0,398,399,5,116,0,0,399,400,5,97,0,0,400,401,5,105,0,0,401, + 402,5,110,0,0,402,403,5,115,0,0,403,404,5,95,0,0,404,405,5,97,0,0,405, + 406,5,108,0,0,406,426,5,108,0,0,407,408,5,65,0,0,408,409,5,82,0,0,409, + 410,5,82,0,0,410,411,5,65,0,0,411,412,5,89,0,0,412,413,5,95,0,0,413,414, + 5,67,0,0,414,415,5,79,0,0,415,416,5,78,0,0,416,417,5,84,0,0,417,418,5, + 65,0,0,418,419,5,73,0,0,419,420,5,78,0,0,420,421,5,83,0,0,421,422,5,95, + 0,0,422,423,5,65,0,0,423,424,5,76,0,0,424,426,5,76,0,0,425,389,1,0,0, + 0,425,407,1,0,0,0,426,72,1,0,0,0,427,428,5,97,0,0,428,429,5,114,0,0,429, + 430,5,114,0,0,430,431,5,97,0,0,431,432,5,121,0,0,432,433,5,95,0,0,433, + 434,5,99,0,0,434,435,5,111,0,0,435,436,5,110,0,0,436,437,5,116,0,0,437, + 438,5,97,0,0,438,439,5,105,0,0,439,440,5,110,0,0,440,441,5,115,0,0,441, + 442,5,95,0,0,442,443,5,97,0,0,443,444,5,110,0,0,444,464,5,121,0,0,445, + 446,5,65,0,0,446,447,5,82,0,0,447,448,5,82,0,0,448,449,5,65,0,0,449,450, + 5,89,0,0,450,451,5,95,0,0,451,452,5,67,0,0,452,453,5,79,0,0,453,454,5, + 78,0,0,454,455,5,84,0,0,455,456,5,65,0,0,456,457,5,73,0,0,457,458,5,78, + 0,0,458,459,5,83,0,0,459,460,5,95,0,0,460,461,5,65,0,0,461,462,5,78,0, + 0,462,464,5,89,0,0,463,427,1,0,0,0,463,445,1,0,0,0,464,74,1,0,0,0,465, + 466,5,97,0,0,466,467,5,114,0,0,467,468,5,114,0,0,468,469,5,97,0,0,469, + 470,5,121,0,0,470,471,5,95,0,0,471,472,5,108,0,0,472,473,5,101,0,0,473, + 474,5,110,0,0,474,475,5,103,0,0,475,476,5,116,0,0,476,490,5,104,0,0,477, + 478,5,65,0,0,478,479,5,82,0,0,479,480,5,82,0,0,480,481,5,65,0,0,481,482, + 5,89,0,0,482,483,5,95,0,0,483,484,5,76,0,0,484,485,5,69,0,0,485,486,5, + 78,0,0,486,487,5,71,0,0,487,488,5,84,0,0,488,490,5,72,0,0,489,465,1,0, + 0,0,489,477,1,0,0,0,490,76,1,0,0,0,491,492,5,116,0,0,492,493,5,114,0, + 0,493,494,5,117,0,0,494,519,5,101,0,0,495,496,5,84,0,0,496,497,5,114, + 0,0,497,498,5,117,0,0,498,519,5,101,0,0,499,500,5,84,0,0,500,501,5,82, + 0,0,501,502,5,85,0,0,502,519,5,69,0,0,503,504,5,102,0,0,504,505,5,97, + 0,0,505,506,5,108,0,0,506,507,5,115,0,0,507,519,5,101,0,0,508,509,5,70, + 0,0,509,510,5,97,0,0,510,511,5,108,0,0,511,512,5,115,0,0,512,519,5,101, + 0,0,513,514,5,70,0,0,514,515,5,65,0,0,515,516,5,76,0,0,516,517,5,83,0, + 0,517,519,5,69,0,0,518,491,1,0,0,0,518,495,1,0,0,0,518,499,1,0,0,0,518, + 503,1,0,0,0,518,508,1,0,0,0,518,513,1,0,0,0,519,78,1,0,0,0,520,525,3, + 105,52,0,521,525,3,107,53,0,522,525,3,109,54,0,523,525,3,103,51,0,524, + 520,1,0,0,0,524,521,1,0,0,0,524,522,1,0,0,0,524,523,1,0,0,0,525,80,1, + 0,0,0,526,529,3,121,60,0,527,529,3,123,61,0,528,526,1,0,0,0,528,527,1, + 
0,0,0,529,82,1,0,0,0,530,535,3,99,49,0,531,534,3,99,49,0,532,534,3,101, + 50,0,533,531,1,0,0,0,533,532,1,0,0,0,534,537,1,0,0,0,535,533,1,0,0,0, + 535,536,1,0,0,0,536,544,1,0,0,0,537,535,1,0,0,0,538,539,5,36,0,0,539, + 540,5,109,0,0,540,541,5,101,0,0,541,542,5,116,0,0,542,544,5,97,0,0,543, + 530,1,0,0,0,543,538,1,0,0,0,544,84,1,0,0,0,545,547,3,89,44,0,546,545, + 1,0,0,0,546,547,1,0,0,0,547,558,1,0,0,0,548,550,5,34,0,0,549,551,3,91, + 45,0,550,549,1,0,0,0,550,551,1,0,0,0,551,552,1,0,0,0,552,559,5,34,0,0, + 553,555,5,39,0,0,554,556,3,93,46,0,555,554,1,0,0,0,555,556,1,0,0,0,556, + 557,1,0,0,0,557,559,5,39,0,0,558,548,1,0,0,0,558,553,1,0,0,0,559,86,1, + 0,0,0,560,568,3,83,41,0,561,564,5,91,0,0,562,565,3,85,42,0,563,565,3, + 105,52,0,564,562,1,0,0,0,564,563,1,0,0,0,565,566,1,0,0,0,566,567,5,93, + 0,0,567,569,1,0,0,0,568,561,1,0,0,0,569,570,1,0,0,0,570,568,1,0,0,0,570, + 571,1,0,0,0,571,88,1,0,0,0,572,573,5,117,0,0,573,576,5,56,0,0,574,576, + 7,0,0,0,575,572,1,0,0,0,575,574,1,0,0,0,576,90,1,0,0,0,577,579,3,95,47, + 0,578,577,1,0,0,0,579,580,1,0,0,0,580,578,1,0,0,0,580,581,1,0,0,0,581, + 92,1,0,0,0,582,584,3,97,48,0,583,582,1,0,0,0,584,585,1,0,0,0,585,583, + 1,0,0,0,585,586,1,0,0,0,586,94,1,0,0,0,587,595,8,1,0,0,588,595,3,137, + 68,0,589,590,5,92,0,0,590,595,5,10,0,0,591,592,5,92,0,0,592,593,5,13, + 0,0,593,595,5,10,0,0,594,587,1,0,0,0,594,588,1,0,0,0,594,589,1,0,0,0, + 594,591,1,0,0,0,595,96,1,0,0,0,596,604,8,2,0,0,597,604,3,137,68,0,598, + 599,5,92,0,0,599,604,5,10,0,0,600,601,5,92,0,0,601,602,5,13,0,0,602,604, + 5,10,0,0,603,596,1,0,0,0,603,597,1,0,0,0,603,598,1,0,0,0,603,600,1,0, + 0,0,604,98,1,0,0,0,605,606,7,3,0,0,606,100,1,0,0,0,607,608,7,4,0,0,608, + 102,1,0,0,0,609,610,5,48,0,0,610,612,7,5,0,0,611,613,7,6,0,0,612,611, + 1,0,0,0,613,614,1,0,0,0,614,612,1,0,0,0,614,615,1,0,0,0,615,104,1,0,0, + 0,616,620,3,111,55,0,617,619,3,101,50,0,618,617,1,0,0,0,619,622,1,0,0, + 0,620,618,1,0,0,0,620,621,1,0,0,0,621,625,1,0,0,0,622,620,1,0,0,0,623, + 625,5,48,0,0,624,616,1,0,0,0,624,623,1,0,0,0,625,106,1,0,0,0,626,630, + 5,48,0,0,627,629,3,113,56,0,628,627,1,0,0,0,629,632,1,0,0,0,630,628,1, + 0,0,0,630,631,1,0,0,0,631,108,1,0,0,0,632,630,1,0,0,0,633,634,5,48,0, + 0,634,635,7,7,0,0,635,636,3,133,66,0,636,110,1,0,0,0,637,638,7,8,0,0, + 638,112,1,0,0,0,639,640,7,9,0,0,640,114,1,0,0,0,641,642,7,10,0,0,642, + 116,1,0,0,0,643,644,3,115,57,0,644,645,3,115,57,0,645,646,3,115,57,0, + 646,647,3,115,57,0,647,118,1,0,0,0,648,649,5,92,0,0,649,650,5,117,0,0, + 650,651,1,0,0,0,651,659,3,117,58,0,652,653,5,92,0,0,653,654,5,85,0,0, + 654,655,1,0,0,0,655,656,3,117,58,0,656,657,3,117,58,0,657,659,1,0,0,0, + 658,648,1,0,0,0,658,652,1,0,0,0,659,120,1,0,0,0,660,662,3,125,62,0,661, + 663,3,127,63,0,662,661,1,0,0,0,662,663,1,0,0,0,663,668,1,0,0,0,664,665, + 3,129,64,0,665,666,3,127,63,0,666,668,1,0,0,0,667,660,1,0,0,0,667,664, + 1,0,0,0,668,122,1,0,0,0,669,670,5,48,0,0,670,673,7,7,0,0,671,674,3,131, + 65,0,672,674,3,133,66,0,673,671,1,0,0,0,673,672,1,0,0,0,674,675,1,0,0, + 0,675,676,3,135,67,0,676,124,1,0,0,0,677,679,3,129,64,0,678,677,1,0,0, + 0,678,679,1,0,0,0,679,680,1,0,0,0,680,681,5,46,0,0,681,686,3,129,64,0, + 682,683,3,129,64,0,683,684,5,46,0,0,684,686,1,0,0,0,685,678,1,0,0,0,685, + 682,1,0,0,0,686,126,1,0,0,0,687,689,7,11,0,0,688,690,7,12,0,0,689,688, + 1,0,0,0,689,690,1,0,0,0,690,691,1,0,0,0,691,692,3,129,64,0,692,128,1, + 0,0,0,693,695,3,101,50,0,694,693,1,0,0,0,695,696,1,0,0,0,696,694,1,0, + 0,0,696,697,1,0,0,0,697,130,1,0,0,0,698,700,3,133,66,0,699,698,1,0,0, + 
0,699,700,1,0,0,0,700,701,1,0,0,0,701,702,5,46,0,0,702,707,3,133,66,0, + 703,704,3,133,66,0,704,705,5,46,0,0,705,707,1,0,0,0,706,699,1,0,0,0,706, + 703,1,0,0,0,707,132,1,0,0,0,708,710,3,115,57,0,709,708,1,0,0,0,710,711, + 1,0,0,0,711,709,1,0,0,0,711,712,1,0,0,0,712,134,1,0,0,0,713,715,7,13, + 0,0,714,716,7,12,0,0,715,714,1,0,0,0,715,716,1,0,0,0,716,717,1,0,0,0, + 717,718,3,129,64,0,718,136,1,0,0,0,719,720,5,92,0,0,720,735,7,14,0,0, + 721,722,5,92,0,0,722,724,3,113,56,0,723,725,3,113,56,0,724,723,1,0,0, + 0,724,725,1,0,0,0,725,727,1,0,0,0,726,728,3,113,56,0,727,726,1,0,0,0, + 727,728,1,0,0,0,728,735,1,0,0,0,729,730,5,92,0,0,730,731,5,120,0,0,731, + 732,1,0,0,0,732,735,3,133,66,0,733,735,3,119,59,0,734,719,1,0,0,0,734, + 721,1,0,0,0,734,729,1,0,0,0,734,733,1,0,0,0,735,138,1,0,0,0,736,738,7, + 15,0,0,737,736,1,0,0,0,738,739,1,0,0,0,739,737,1,0,0,0,739,740,1,0,0, + 0,740,741,1,0,0,0,741,742,6,69,0,0,742,140,1,0,0,0,743,745,5,13,0,0,744, + 746,5,10,0,0,745,744,1,0,0,0,745,746,1,0,0,0,746,749,1,0,0,0,747,749, + 5,10,0,0,748,743,1,0,0,0,748,747,1,0,0,0,749,750,1,0,0,0,750,751,6,70, + 0,0,751,142,1,0,0,0,54,0,177,191,223,229,237,252,254,285,321,357,387, + 425,463,489,518,524,528,533,535,543,546,550,555,558,564,570,575,580,585, + 594,603,614,620,624,630,658,662,667,673,678,685,689,696,699,706,711,715, + 724,727,734,739,745,748,1,6,0,0 + }; + staticData->serializedATN = antlr4::atn::SerializedATNView(serializedATNSegment, sizeof(serializedATNSegment) / sizeof(serializedATNSegment[0])); + + antlr4::atn::ATNDeserializer deserializer; + staticData->atn = deserializer.deserialize(staticData->serializedATN); + + const size_t count = staticData->atn->getNumberOfDecisions(); + staticData->decisionToDFA.reserve(count); + for (size_t i = 0; i < count; i++) { + staticData->decisionToDFA.emplace_back(staticData->atn->getDecisionState(i), i); + } + planlexerLexerStaticData = staticData.release(); +} + +} + +PlanLexer::PlanLexer(CharStream *input) : Lexer(input) { + PlanLexer::initialize(); + _interpreter = new atn::LexerATNSimulator(this, *planlexerLexerStaticData->atn, planlexerLexerStaticData->decisionToDFA, planlexerLexerStaticData->sharedContextCache); +} + +PlanLexer::~PlanLexer() { + delete _interpreter; +} + +std::string PlanLexer::getGrammarFileName() const { + return "Plan.g4"; +} + +const std::vector& PlanLexer::getRuleNames() const { + return planlexerLexerStaticData->ruleNames; +} + +const std::vector& PlanLexer::getChannelNames() const { + return planlexerLexerStaticData->channelNames; +} + +const std::vector& PlanLexer::getModeNames() const { + return planlexerLexerStaticData->modeNames; +} + +const dfa::Vocabulary& PlanLexer::getVocabulary() const { + return planlexerLexerStaticData->vocabulary; +} + +antlr4::atn::SerializedATNView PlanLexer::getSerializedATN() const { + return planlexerLexerStaticData->serializedATN; +} + +const atn::ATN& PlanLexer::getATN() const { + return *planlexerLexerStaticData->atn; +} + + + + +void PlanLexer::initialize() { +#if ANTLR4_USE_THREAD_LOCAL_CACHE + planlexerLexerInitialize(); +#else + ::antlr4::internal::call_once(planlexerLexerOnceFlag, planlexerLexerInitialize); +#endif +} diff --git a/src/parser/antlr/PlanLexer.h b/src/parser/antlr/PlanLexer.h new file mode 100644 index 0000000..820ac68 --- /dev/null +++ b/src/parser/antlr/PlanLexer.h @@ -0,0 +1,56 @@ + +// Generated from Plan.g4 by ANTLR 4.13.1 + +#pragma once + + +#include "antlr4-runtime.h" + + + + +class PlanLexer : public antlr4::Lexer { +public: + enum { + T__0 = 1, T__1 = 2, T__2 = 3, T__3 = 4, T__4 
= 5, LT = 6, LE = 7, GT = 8, + GE = 9, EQ = 10, NE = 11, LIKE = 12, EXISTS = 13, ADD = 14, SUB = 15, + MUL = 16, DIV = 17, MOD = 18, POW = 19, SHL = 20, SHR = 21, BAND = 22, + BOR = 23, BXOR = 24, AND = 25, OR = 26, BNOT = 27, NOT = 28, IN = 29, + NIN = 30, EmptyTerm = 31, JSONContains = 32, JSONContainsAll = 33, JSONContainsAny = 34, + ArrayContains = 35, ArrayContainsAll = 36, ArrayContainsAny = 37, ArrayLength = 38, + BooleanConstant = 39, IntegerConstant = 40, FloatingConstant = 41, Identifier = 42, + StringLiteral = 43, JSONIdentifier = 44, Whitespace = 45, Newline = 46 + }; + + explicit PlanLexer(antlr4::CharStream *input); + + ~PlanLexer() override; + + + std::string getGrammarFileName() const override; + + const std::vector& getRuleNames() const override; + + const std::vector& getChannelNames() const override; + + const std::vector& getModeNames() const override; + + const antlr4::dfa::Vocabulary& getVocabulary() const override; + + antlr4::atn::SerializedATNView getSerializedATN() const override; + + const antlr4::atn::ATN& getATN() const override; + + // By default the static state used to implement the lexer is lazily initialized during the first + // call to the constructor. You can call this function if you wish to initialize the static state + // ahead of time. + static void initialize(); + +private: + + // Individual action functions triggered by action() above. + + // Individual semantic predicate functions triggered by sempred() above. + +}; + diff --git a/src/parser/antlr/PlanParser.cpp b/src/parser/antlr/PlanParser.cpp new file mode 100644 index 0000000..df85e54 --- /dev/null +++ b/src/parser/antlr/PlanParser.cpp @@ -0,0 +1,1642 @@ + +// Generated from Plan.g4 by ANTLR 4.13.1 + + +#include "PlanVisitor.h" + +#include "PlanParser.h" + + +using namespace antlrcpp; + +using namespace antlr4; + +namespace { + +struct PlanParserStaticData final { + PlanParserStaticData(std::vector ruleNames, + std::vector literalNames, + std::vector symbolicNames) + : ruleNames(std::move(ruleNames)), literalNames(std::move(literalNames)), + symbolicNames(std::move(symbolicNames)), + vocabulary(this->literalNames, this->symbolicNames) {} + + PlanParserStaticData(const PlanParserStaticData&) = delete; + PlanParserStaticData(PlanParserStaticData&&) = delete; + PlanParserStaticData& operator=(const PlanParserStaticData&) = delete; + PlanParserStaticData& operator=(PlanParserStaticData&&) = delete; + + std::vector decisionToDFA; + antlr4::atn::PredictionContextCache sharedContextCache; + const std::vector ruleNames; + const std::vector literalNames; + const std::vector symbolicNames; + const antlr4::dfa::Vocabulary vocabulary; + antlr4::atn::SerializedATNView serializedATN; + std::unique_ptr atn; +}; + +::antlr4::internal::OnceFlag planParserOnceFlag; +#if ANTLR4_USE_THREAD_LOCAL_CACHE +static thread_local +#endif +PlanParserStaticData *planParserStaticData = nullptr; + +void planParserInitialize() { +#if ANTLR4_USE_THREAD_LOCAL_CACHE + if (planParserStaticData != nullptr) { + return; + } +#else + assert(planParserStaticData == nullptr); +#endif + auto staticData = std::make_unique( + std::vector{ + "expr" + }, + std::vector{ + "", "'('", "')'", "'['", "','", "']'", "'<'", "'<='", "'>'", "'>='", + "'=='", "'!='", "", "", "'+'", "'-'", "'*'", "'/'", "'%'", "'**'", + "'<<'", "'>>'", "'&'", "'|'", "'^'", "", "", "'~'", "", "'in'", "'not in'" + }, + std::vector{ + "", "", "", "", "", "", "LT", "LE", "GT", "GE", "EQ", "NE", "LIKE", + "EXISTS", "ADD", "SUB", "MUL", "DIV", "MOD", "POW", "SHL", "SHR", + 
"BAND", "BOR", "BXOR", "AND", "OR", "BNOT", "NOT", "IN", "NIN", "EmptyTerm", + "JSONContains", "JSONContainsAll", "JSONContainsAny", "ArrayContains", + "ArrayContainsAll", "ArrayContainsAny", "ArrayLength", "BooleanConstant", + "IntegerConstant", "FloatingConstant", "Identifier", "StringLiteral", + "JSONIdentifier", "Whitespace", "Newline" + } + ); + static const int32_t serializedATNSegment[] = { + 4,1,46,129,2,0,7,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1, + 0,1,0,1,0,5,0,18,8,0,10,0,12,0,21,9,0,1,0,3,0,24,8,0,1,0,1,0,1,0,1,0, + 1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1, + 0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,3,0,57,8,0,1,0,1,0,1,0,1,0,1,0, + 1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1, + 0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0, + 1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,5,0,111,8,0,10,0,12,0, + 114,9,0,1,0,3,0,117,8,0,1,0,1,0,1,0,1,0,1,0,5,0,124,8,0,10,0,12,0,127, + 9,0,1,0,0,1,0,1,0,0,13,2,0,14,15,27,28,2,0,32,32,35,35,2,0,33,33,36,36, + 2,0,34,34,37,37,2,0,42,42,44,44,1,0,16,18,1,0,14,15,1,0,20,21,1,0,6,7, + 1,0,8,9,1,0,6,9,1,0,10,11,1,0,29,30,160,0,56,1,0,0,0,2,3,6,0,-1,0,3,57, + 5,40,0,0,4,57,5,41,0,0,5,57,5,39,0,0,6,57,5,43,0,0,7,57,5,42,0,0,8,57, + 5,44,0,0,9,10,5,1,0,0,10,11,3,0,0,0,11,12,5,2,0,0,12,57,1,0,0,0,13,14, + 5,3,0,0,14,19,3,0,0,0,15,16,5,4,0,0,16,18,3,0,0,0,17,15,1,0,0,0,18,21, + 1,0,0,0,19,17,1,0,0,0,19,20,1,0,0,0,20,23,1,0,0,0,21,19,1,0,0,0,22,24, + 5,4,0,0,23,22,1,0,0,0,23,24,1,0,0,0,24,25,1,0,0,0,25,26,5,5,0,0,26,57, + 1,0,0,0,27,28,7,0,0,0,28,57,3,0,0,20,29,30,7,1,0,0,30,31,5,1,0,0,31,32, + 3,0,0,0,32,33,5,4,0,0,33,34,3,0,0,0,34,35,5,2,0,0,35,57,1,0,0,0,36,37, + 7,2,0,0,37,38,5,1,0,0,38,39,3,0,0,0,39,40,5,4,0,0,40,41,3,0,0,0,41,42, + 5,2,0,0,42,57,1,0,0,0,43,44,7,3,0,0,44,45,5,1,0,0,45,46,3,0,0,0,46,47, + 5,4,0,0,47,48,3,0,0,0,48,49,5,2,0,0,49,57,1,0,0,0,50,51,5,38,0,0,51,52, + 5,1,0,0,52,53,7,4,0,0,53,57,5,2,0,0,54,55,5,13,0,0,55,57,3,0,0,1,56,2, + 1,0,0,0,56,4,1,0,0,0,56,5,1,0,0,0,56,6,1,0,0,0,56,7,1,0,0,0,56,8,1,0, + 0,0,56,9,1,0,0,0,56,13,1,0,0,0,56,27,1,0,0,0,56,29,1,0,0,0,56,36,1,0, + 0,0,56,43,1,0,0,0,56,50,1,0,0,0,56,54,1,0,0,0,57,125,1,0,0,0,58,59,10, + 21,0,0,59,60,5,19,0,0,60,124,3,0,0,22,61,62,10,19,0,0,62,63,7,5,0,0,63, + 124,3,0,0,20,64,65,10,18,0,0,65,66,7,6,0,0,66,124,3,0,0,19,67,68,10,17, + 0,0,68,69,7,7,0,0,69,124,3,0,0,18,70,71,10,10,0,0,71,72,7,8,0,0,72,73, + 7,4,0,0,73,74,7,8,0,0,74,124,3,0,0,11,75,76,10,9,0,0,76,77,7,9,0,0,77, + 78,7,4,0,0,78,79,7,9,0,0,79,124,3,0,0,10,80,81,10,8,0,0,81,82,7,10,0, + 0,82,124,3,0,0,9,83,84,10,7,0,0,84,85,7,11,0,0,85,124,3,0,0,8,86,87,10, + 6,0,0,87,88,5,22,0,0,88,124,3,0,0,7,89,90,10,5,0,0,90,91,5,24,0,0,91, + 124,3,0,0,6,92,93,10,4,0,0,93,94,5,23,0,0,94,124,3,0,0,5,95,96,10,3,0, + 0,96,97,5,25,0,0,97,124,3,0,0,4,98,99,10,2,0,0,99,100,5,26,0,0,100,124, + 3,0,0,3,101,102,10,22,0,0,102,103,5,12,0,0,103,124,5,43,0,0,104,105,10, + 16,0,0,105,106,7,12,0,0,106,107,5,3,0,0,107,112,3,0,0,0,108,109,5,4,0, + 0,109,111,3,0,0,0,110,108,1,0,0,0,111,114,1,0,0,0,112,110,1,0,0,0,112, + 113,1,0,0,0,113,116,1,0,0,0,114,112,1,0,0,0,115,117,5,4,0,0,116,115,1, + 0,0,0,116,117,1,0,0,0,117,118,1,0,0,0,118,119,5,5,0,0,119,124,1,0,0,0, + 120,121,10,15,0,0,121,122,7,12,0,0,122,124,5,31,0,0,123,58,1,0,0,0,123, + 61,1,0,0,0,123,64,1,0,0,0,123,67,1,0,0,0,123,70,1,0,0,0,123,75,1,0,0, + 0,123,80,1,0,0,0,123,83,1,0,0,0,123,86,1,0,0,0,123,89,1,0,0,0,123,92, + 1,0,0,0,123,95,1,0,0,0,123,98,1,0,0,0,123,101,1,0,0,0,123,104,1,0,0,0, + 
123,120,1,0,0,0,124,127,1,0,0,0,125,123,1,0,0,0,125,126,1,0,0,0,126,1, + 1,0,0,0,127,125,1,0,0,0,7,19,23,56,112,116,123,125 + }; + staticData->serializedATN = antlr4::atn::SerializedATNView(serializedATNSegment, sizeof(serializedATNSegment) / sizeof(serializedATNSegment[0])); + + antlr4::atn::ATNDeserializer deserializer; + staticData->atn = deserializer.deserialize(staticData->serializedATN); + + const size_t count = staticData->atn->getNumberOfDecisions(); + staticData->decisionToDFA.reserve(count); + for (size_t i = 0; i < count; i++) { + staticData->decisionToDFA.emplace_back(staticData->atn->getDecisionState(i), i); + } + planParserStaticData = staticData.release(); +} + +} + +PlanParser::PlanParser(TokenStream *input) : PlanParser(input, antlr4::atn::ParserATNSimulatorOptions()) {} + +PlanParser::PlanParser(TokenStream *input, const antlr4::atn::ParserATNSimulatorOptions &options) : Parser(input) { + PlanParser::initialize(); + _interpreter = new atn::ParserATNSimulator(this, *planParserStaticData->atn, planParserStaticData->decisionToDFA, planParserStaticData->sharedContextCache, options); +} + +PlanParser::~PlanParser() { + delete _interpreter; +} + +const atn::ATN& PlanParser::getATN() const { + return *planParserStaticData->atn; +} + +std::string PlanParser::getGrammarFileName() const { + return "Plan.g4"; +} + +const std::vector& PlanParser::getRuleNames() const { + return planParserStaticData->ruleNames; +} + +const dfa::Vocabulary& PlanParser::getVocabulary() const { + return planParserStaticData->vocabulary; +} + +antlr4::atn::SerializedATNView PlanParser::getSerializedATN() const { + return planParserStaticData->serializedATN; +} + + +//----------------- ExprContext ------------------------------------------------------------------ + +PlanParser::ExprContext::ExprContext(ParserRuleContext *parent, size_t invokingState) + : ParserRuleContext(parent, invokingState) { +} + + +size_t PlanParser::ExprContext::getRuleIndex() const { + return PlanParser::RuleExpr; +} + +void PlanParser::ExprContext::copyFrom(ExprContext *ctx) { + ParserRuleContext::copyFrom(ctx); +} + +//----------------- JSONIdentifierContext ------------------------------------------------------------------ + +tree::TerminalNode* PlanParser::JSONIdentifierContext::JSONIdentifier() { + return getToken(PlanParser::JSONIdentifier, 0); +} + +PlanParser::JSONIdentifierContext::JSONIdentifierContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::JSONIdentifierContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitJSONIdentifier(this); + else + return visitor->visitChildren(this); +} +//----------------- ParensContext ------------------------------------------------------------------ + +PlanParser::ExprContext* PlanParser::ParensContext::expr() { + return getRuleContext(0); +} + +PlanParser::ParensContext::ParensContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::ParensContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitParens(this); + else + return visitor->visitChildren(this); +} +//----------------- StringContext ------------------------------------------------------------------ + +tree::TerminalNode* PlanParser::StringContext::StringLiteral() { + return getToken(PlanParser::StringLiteral, 0); +} + +PlanParser::StringContext::StringContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any 
PlanParser::StringContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitString(this); + else + return visitor->visitChildren(this); +} +//----------------- FloatingContext ------------------------------------------------------------------ + +tree::TerminalNode* PlanParser::FloatingContext::FloatingConstant() { + return getToken(PlanParser::FloatingConstant, 0); +} + +PlanParser::FloatingContext::FloatingContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::FloatingContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitFloating(this); + else + return visitor->visitChildren(this); +} +//----------------- JSONContainsAllContext ------------------------------------------------------------------ + +std::vector PlanParser::JSONContainsAllContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::JSONContainsAllContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::JSONContainsAllContext::JSONContainsAll() { + return getToken(PlanParser::JSONContainsAll, 0); +} + +tree::TerminalNode* PlanParser::JSONContainsAllContext::ArrayContainsAll() { + return getToken(PlanParser::ArrayContainsAll, 0); +} + +PlanParser::JSONContainsAllContext::JSONContainsAllContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::JSONContainsAllContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitJSONContainsAll(this); + else + return visitor->visitChildren(this); +} +//----------------- LogicalOrContext ------------------------------------------------------------------ + +std::vector PlanParser::LogicalOrContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::LogicalOrContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::LogicalOrContext::OR() { + return getToken(PlanParser::OR, 0); +} + +PlanParser::LogicalOrContext::LogicalOrContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::LogicalOrContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitLogicalOr(this); + else + return visitor->visitChildren(this); +} +//----------------- MulDivModContext ------------------------------------------------------------------ + +std::vector PlanParser::MulDivModContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::MulDivModContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::MulDivModContext::MUL() { + return getToken(PlanParser::MUL, 0); +} + +tree::TerminalNode* PlanParser::MulDivModContext::DIV() { + return getToken(PlanParser::DIV, 0); +} + +tree::TerminalNode* PlanParser::MulDivModContext::MOD() { + return getToken(PlanParser::MOD, 0); +} + +PlanParser::MulDivModContext::MulDivModContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::MulDivModContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitMulDivMod(this); + else + return visitor->visitChildren(this); +} +//----------------- IdentifierContext ------------------------------------------------------------------ + +tree::TerminalNode* PlanParser::IdentifierContext::Identifier() { + return getToken(PlanParser::Identifier, 0); +} + 
+PlanParser::IdentifierContext::IdentifierContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::IdentifierContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitIdentifier(this); + else + return visitor->visitChildren(this); +} +//----------------- LikeContext ------------------------------------------------------------------ + +PlanParser::ExprContext* PlanParser::LikeContext::expr() { + return getRuleContext(0); +} + +tree::TerminalNode* PlanParser::LikeContext::LIKE() { + return getToken(PlanParser::LIKE, 0); +} + +tree::TerminalNode* PlanParser::LikeContext::StringLiteral() { + return getToken(PlanParser::StringLiteral, 0); +} + +PlanParser::LikeContext::LikeContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::LikeContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitLike(this); + else + return visitor->visitChildren(this); +} +//----------------- LogicalAndContext ------------------------------------------------------------------ + +std::vector PlanParser::LogicalAndContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::LogicalAndContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::LogicalAndContext::AND() { + return getToken(PlanParser::AND, 0); +} + +PlanParser::LogicalAndContext::LogicalAndContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::LogicalAndContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitLogicalAnd(this); + else + return visitor->visitChildren(this); +} +//----------------- EqualityContext ------------------------------------------------------------------ + +std::vector PlanParser::EqualityContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::EqualityContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::EqualityContext::EQ() { + return getToken(PlanParser::EQ, 0); +} + +tree::TerminalNode* PlanParser::EqualityContext::NE() { + return getToken(PlanParser::NE, 0); +} + +PlanParser::EqualityContext::EqualityContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::EqualityContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitEquality(this); + else + return visitor->visitChildren(this); +} +//----------------- BooleanContext ------------------------------------------------------------------ + +tree::TerminalNode* PlanParser::BooleanContext::BooleanConstant() { + return getToken(PlanParser::BooleanConstant, 0); +} + +PlanParser::BooleanContext::BooleanContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::BooleanContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitBoolean(this); + else + return visitor->visitChildren(this); +} +//----------------- ShiftContext ------------------------------------------------------------------ + +std::vector PlanParser::ShiftContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::ShiftContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::ShiftContext::SHL() { + return getToken(PlanParser::SHL, 0); +} + +tree::TerminalNode* PlanParser::ShiftContext::SHR() { + return 
getToken(PlanParser::SHR, 0); +} + +PlanParser::ShiftContext::ShiftContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::ShiftContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitShift(this); + else + return visitor->visitChildren(this); +} +//----------------- ReverseRangeContext ------------------------------------------------------------------ + +std::vector PlanParser::ReverseRangeContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::ReverseRangeContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::ReverseRangeContext::Identifier() { + return getToken(PlanParser::Identifier, 0); +} + +tree::TerminalNode* PlanParser::ReverseRangeContext::JSONIdentifier() { + return getToken(PlanParser::JSONIdentifier, 0); +} + +std::vector PlanParser::ReverseRangeContext::GT() { + return getTokens(PlanParser::GT); +} + +tree::TerminalNode* PlanParser::ReverseRangeContext::GT(size_t i) { + return getToken(PlanParser::GT, i); +} + +std::vector PlanParser::ReverseRangeContext::GE() { + return getTokens(PlanParser::GE); +} + +tree::TerminalNode* PlanParser::ReverseRangeContext::GE(size_t i) { + return getToken(PlanParser::GE, i); +} + +PlanParser::ReverseRangeContext::ReverseRangeContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::ReverseRangeContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitReverseRange(this); + else + return visitor->visitChildren(this); +} +//----------------- BitOrContext ------------------------------------------------------------------ + +std::vector PlanParser::BitOrContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::BitOrContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::BitOrContext::BOR() { + return getToken(PlanParser::BOR, 0); +} + +PlanParser::BitOrContext::BitOrContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::BitOrContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitBitOr(this); + else + return visitor->visitChildren(this); +} +//----------------- AddSubContext ------------------------------------------------------------------ + +std::vector PlanParser::AddSubContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::AddSubContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::AddSubContext::ADD() { + return getToken(PlanParser::ADD, 0); +} + +tree::TerminalNode* PlanParser::AddSubContext::SUB() { + return getToken(PlanParser::SUB, 0); +} + +PlanParser::AddSubContext::AddSubContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::AddSubContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitAddSub(this); + else + return visitor->visitChildren(this); +} +//----------------- RelationalContext ------------------------------------------------------------------ + +std::vector PlanParser::RelationalContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::RelationalContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::RelationalContext::LT() { + return getToken(PlanParser::LT, 0); +} + +tree::TerminalNode* 
PlanParser::RelationalContext::LE() { + return getToken(PlanParser::LE, 0); +} + +tree::TerminalNode* PlanParser::RelationalContext::GT() { + return getToken(PlanParser::GT, 0); +} + +tree::TerminalNode* PlanParser::RelationalContext::GE() { + return getToken(PlanParser::GE, 0); +} + +PlanParser::RelationalContext::RelationalContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::RelationalContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitRelational(this); + else + return visitor->visitChildren(this); +} +//----------------- ArrayLengthContext ------------------------------------------------------------------ + +tree::TerminalNode* PlanParser::ArrayLengthContext::ArrayLength() { + return getToken(PlanParser::ArrayLength, 0); +} + +tree::TerminalNode* PlanParser::ArrayLengthContext::Identifier() { + return getToken(PlanParser::Identifier, 0); +} + +tree::TerminalNode* PlanParser::ArrayLengthContext::JSONIdentifier() { + return getToken(PlanParser::JSONIdentifier, 0); +} + +PlanParser::ArrayLengthContext::ArrayLengthContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::ArrayLengthContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitArrayLength(this); + else + return visitor->visitChildren(this); +} +//----------------- TermContext ------------------------------------------------------------------ + +std::vector PlanParser::TermContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::TermContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::TermContext::IN() { + return getToken(PlanParser::IN, 0); +} + +tree::TerminalNode* PlanParser::TermContext::NIN() { + return getToken(PlanParser::NIN, 0); +} + +PlanParser::TermContext::TermContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::TermContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitTerm(this); + else + return visitor->visitChildren(this); +} +//----------------- JSONContainsContext ------------------------------------------------------------------ + +std::vector PlanParser::JSONContainsContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::JSONContainsContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::JSONContainsContext::JSONContains() { + return getToken(PlanParser::JSONContains, 0); +} + +tree::TerminalNode* PlanParser::JSONContainsContext::ArrayContains() { + return getToken(PlanParser::ArrayContains, 0); +} + +PlanParser::JSONContainsContext::JSONContainsContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::JSONContainsContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitJSONContains(this); + else + return visitor->visitChildren(this); +} +//----------------- RangeContext ------------------------------------------------------------------ + +std::vector PlanParser::RangeContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::RangeContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::RangeContext::Identifier() { + return getToken(PlanParser::Identifier, 0); +} + +tree::TerminalNode* PlanParser::RangeContext::JSONIdentifier() { + return 
getToken(PlanParser::JSONIdentifier, 0); +} + +std::vector PlanParser::RangeContext::LT() { + return getTokens(PlanParser::LT); +} + +tree::TerminalNode* PlanParser::RangeContext::LT(size_t i) { + return getToken(PlanParser::LT, i); +} + +std::vector PlanParser::RangeContext::LE() { + return getTokens(PlanParser::LE); +} + +tree::TerminalNode* PlanParser::RangeContext::LE(size_t i) { + return getToken(PlanParser::LE, i); +} + +PlanParser::RangeContext::RangeContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::RangeContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitRange(this); + else + return visitor->visitChildren(this); +} +//----------------- UnaryContext ------------------------------------------------------------------ + +PlanParser::ExprContext* PlanParser::UnaryContext::expr() { + return getRuleContext(0); +} + +tree::TerminalNode* PlanParser::UnaryContext::ADD() { + return getToken(PlanParser::ADD, 0); +} + +tree::TerminalNode* PlanParser::UnaryContext::SUB() { + return getToken(PlanParser::SUB, 0); +} + +tree::TerminalNode* PlanParser::UnaryContext::BNOT() { + return getToken(PlanParser::BNOT, 0); +} + +tree::TerminalNode* PlanParser::UnaryContext::NOT() { + return getToken(PlanParser::NOT, 0); +} + +PlanParser::UnaryContext::UnaryContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::UnaryContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitUnary(this); + else + return visitor->visitChildren(this); +} +//----------------- IntegerContext ------------------------------------------------------------------ + +tree::TerminalNode* PlanParser::IntegerContext::IntegerConstant() { + return getToken(PlanParser::IntegerConstant, 0); +} + +PlanParser::IntegerContext::IntegerContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::IntegerContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitInteger(this); + else + return visitor->visitChildren(this); +} +//----------------- ArrayContext ------------------------------------------------------------------ + +std::vector PlanParser::ArrayContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::ArrayContext::expr(size_t i) { + return getRuleContext(i); +} + +PlanParser::ArrayContext::ArrayContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::ArrayContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitArray(this); + else + return visitor->visitChildren(this); +} +//----------------- JSONContainsAnyContext ------------------------------------------------------------------ + +std::vector PlanParser::JSONContainsAnyContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::JSONContainsAnyContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::JSONContainsAnyContext::JSONContainsAny() { + return getToken(PlanParser::JSONContainsAny, 0); +} + +tree::TerminalNode* PlanParser::JSONContainsAnyContext::ArrayContainsAny() { + return getToken(PlanParser::ArrayContainsAny, 0); +} + +PlanParser::JSONContainsAnyContext::JSONContainsAnyContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::JSONContainsAnyContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = 
dynamic_cast(visitor)) + return parserVisitor->visitJSONContainsAny(this); + else + return visitor->visitChildren(this); +} +//----------------- BitXorContext ------------------------------------------------------------------ + +std::vector PlanParser::BitXorContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::BitXorContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::BitXorContext::BXOR() { + return getToken(PlanParser::BXOR, 0); +} + +PlanParser::BitXorContext::BitXorContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::BitXorContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitBitXor(this); + else + return visitor->visitChildren(this); +} +//----------------- ExistsContext ------------------------------------------------------------------ + +tree::TerminalNode* PlanParser::ExistsContext::EXISTS() { + return getToken(PlanParser::EXISTS, 0); +} + +PlanParser::ExprContext* PlanParser::ExistsContext::expr() { + return getRuleContext(0); +} + +PlanParser::ExistsContext::ExistsContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::ExistsContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitExists(this); + else + return visitor->visitChildren(this); +} +//----------------- BitAndContext ------------------------------------------------------------------ + +std::vector PlanParser::BitAndContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::BitAndContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::BitAndContext::BAND() { + return getToken(PlanParser::BAND, 0); +} + +PlanParser::BitAndContext::BitAndContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::BitAndContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitBitAnd(this); + else + return visitor->visitChildren(this); +} +//----------------- EmptyTermContext ------------------------------------------------------------------ + +PlanParser::ExprContext* PlanParser::EmptyTermContext::expr() { + return getRuleContext(0); +} + +tree::TerminalNode* PlanParser::EmptyTermContext::EmptyTerm() { + return getToken(PlanParser::EmptyTerm, 0); +} + +tree::TerminalNode* PlanParser::EmptyTermContext::IN() { + return getToken(PlanParser::IN, 0); +} + +tree::TerminalNode* PlanParser::EmptyTermContext::NIN() { + return getToken(PlanParser::NIN, 0); +} + +PlanParser::EmptyTermContext::EmptyTermContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::EmptyTermContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = dynamic_cast(visitor)) + return parserVisitor->visitEmptyTerm(this); + else + return visitor->visitChildren(this); +} +//----------------- PowerContext ------------------------------------------------------------------ + +std::vector PlanParser::PowerContext::expr() { + return getRuleContexts(); +} + +PlanParser::ExprContext* PlanParser::PowerContext::expr(size_t i) { + return getRuleContext(i); +} + +tree::TerminalNode* PlanParser::PowerContext::POW() { + return getToken(PlanParser::POW, 0); +} + +PlanParser::PowerContext::PowerContext(ExprContext *ctx) { copyFrom(ctx); } + + +std::any PlanParser::PowerContext::accept(tree::ParseTreeVisitor *visitor) { + if (auto parserVisitor = 
dynamic_cast(visitor)) + return parserVisitor->visitPower(this); + else + return visitor->visitChildren(this); +} + +PlanParser::ExprContext* PlanParser::expr() { + return expr(0); +} + +PlanParser::ExprContext* PlanParser::expr(int precedence) { + ParserRuleContext *parentContext = _ctx; + size_t parentState = getState(); + PlanParser::ExprContext *_localctx = _tracker.createInstance(_ctx, parentState); + PlanParser::ExprContext *previousContext = _localctx; + (void)previousContext; // Silence compiler, in case the context is not used by generated code. + size_t startState = 0; + enterRecursionRule(_localctx, 0, PlanParser::RuleExpr, precedence); + + size_t _la = 0; + +#if __cplusplus > 201703L + auto onExit = finally([=, this] { +#else + auto onExit = finally([=] { +#endif + unrollRecursionContexts(parentContext); + }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(56); + _errHandler->sync(this); + switch (_input->LA(1)) { + case PlanParser::IntegerConstant: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + + setState(3); + match(PlanParser::IntegerConstant); + break; + } + + case PlanParser::FloatingConstant: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(4); + match(PlanParser::FloatingConstant); + break; + } + + case PlanParser::BooleanConstant: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(5); + match(PlanParser::BooleanConstant); + break; + } + + case PlanParser::StringLiteral: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(6); + match(PlanParser::StringLiteral); + break; + } + + case PlanParser::Identifier: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(7); + match(PlanParser::Identifier); + break; + } + + case PlanParser::JSONIdentifier: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(8); + match(PlanParser::JSONIdentifier); + break; + } + + case PlanParser::T__0: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(9); + match(PlanParser::T__0); + setState(10); + expr(0); + setState(11); + match(PlanParser::T__1); + break; + } + + case PlanParser::T__2: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(13); + match(PlanParser::T__2); + setState(14); + expr(0); + setState(19); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 0, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(15); + match(PlanParser::T__3); + setState(16); + expr(0); + } + setState(21); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 0, _ctx); + } + setState(23); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == PlanParser::T__3) { + setState(22); + match(PlanParser::T__3); + } + setState(25); + match(PlanParser::T__4); + break; + } + + case PlanParser::ADD: + case PlanParser::SUB: + case PlanParser::BNOT: + case PlanParser::NOT: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(27); + antlrcpp::downCast(_localctx)->op = _input->LT(1); + _la = _input->LA(1); + if (!((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) 
& 402702336) != 0))) { + antlrcpp::downCast(_localctx)->op = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(28); + expr(20); + break; + } + + case PlanParser::JSONContains: + case PlanParser::ArrayContains: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(29); + _la = _input->LA(1); + if (!(_la == PlanParser::JSONContains + + || _la == PlanParser::ArrayContains)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(30); + match(PlanParser::T__0); + setState(31); + expr(0); + setState(32); + match(PlanParser::T__3); + setState(33); + expr(0); + setState(34); + match(PlanParser::T__1); + break; + } + + case PlanParser::JSONContainsAll: + case PlanParser::ArrayContainsAll: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(36); + _la = _input->LA(1); + if (!(_la == PlanParser::JSONContainsAll + + || _la == PlanParser::ArrayContainsAll)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(37); + match(PlanParser::T__0); + setState(38); + expr(0); + setState(39); + match(PlanParser::T__3); + setState(40); + expr(0); + setState(41); + match(PlanParser::T__1); + break; + } + + case PlanParser::JSONContainsAny: + case PlanParser::ArrayContainsAny: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(43); + _la = _input->LA(1); + if (!(_la == PlanParser::JSONContainsAny + + || _la == PlanParser::ArrayContainsAny)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(44); + match(PlanParser::T__0); + setState(45); + expr(0); + setState(46); + match(PlanParser::T__3); + setState(47); + expr(0); + setState(48); + match(PlanParser::T__1); + break; + } + + case PlanParser::ArrayLength: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(50); + match(PlanParser::ArrayLength); + setState(51); + match(PlanParser::T__0); + setState(52); + _la = _input->LA(1); + if (!(_la == PlanParser::Identifier + + || _la == PlanParser::JSONIdentifier)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(53); + match(PlanParser::T__1); + break; + } + + case PlanParser::EXISTS: { + _localctx = _tracker.createInstance(_localctx); + _ctx = _localctx; + previousContext = _localctx; + setState(54); + match(PlanParser::EXISTS); + setState(55); + expr(1); + break; + } + + default: + throw NoViableAltException(this); + } + _ctx->stop = _input->LT(-1); + setState(125); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 6, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + if (!_parseListeners.empty()) + triggerExitRuleEvent(); + previousContext = _localctx; + setState(123); + _errHandler->sync(this); + switch (getInterpreter()->adaptivePredict(_input, 5, _ctx)) { + case 1: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(58); + + if (!(precpred(_ctx, 21))) throw FailedPredicateException(this, "precpred(_ctx, 21)"); + setState(59); + match(PlanParser::POW); + setState(60); + 
expr(22); + break; + } + + case 2: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(61); + + if (!(precpred(_ctx, 19))) throw FailedPredicateException(this, "precpred(_ctx, 19)"); + setState(62); + antlrcpp::downCast(_localctx)->op = _input->LT(1); + _la = _input->LA(1); + if (!((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & 458752) != 0))) { + antlrcpp::downCast(_localctx)->op = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(63); + expr(20); + break; + } + + case 3: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(64); + + if (!(precpred(_ctx, 18))) throw FailedPredicateException(this, "precpred(_ctx, 18)"); + setState(65); + antlrcpp::downCast(_localctx)->op = _input->LT(1); + _la = _input->LA(1); + if (!(_la == PlanParser::ADD + + || _la == PlanParser::SUB)) { + antlrcpp::downCast(_localctx)->op = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(66); + expr(19); + break; + } + + case 4: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(67); + + if (!(precpred(_ctx, 17))) throw FailedPredicateException(this, "precpred(_ctx, 17)"); + setState(68); + antlrcpp::downCast(_localctx)->op = _input->LT(1); + _la = _input->LA(1); + if (!(_la == PlanParser::SHL + + || _la == PlanParser::SHR)) { + antlrcpp::downCast(_localctx)->op = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(69); + expr(18); + break; + } + + case 5: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(70); + + if (!(precpred(_ctx, 10))) throw FailedPredicateException(this, "precpred(_ctx, 10)"); + setState(71); + antlrcpp::downCast(_localctx)->op1 = _input->LT(1); + _la = _input->LA(1); + if (!(_la == PlanParser::LT + + || _la == PlanParser::LE)) { + antlrcpp::downCast(_localctx)->op1 = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(72); + _la = _input->LA(1); + if (!(_la == PlanParser::Identifier + + || _la == PlanParser::JSONIdentifier)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(73); + antlrcpp::downCast(_localctx)->op2 = _input->LT(1); + _la = _input->LA(1); + if (!(_la == PlanParser::LT + + || _la == PlanParser::LE)) { + antlrcpp::downCast(_localctx)->op2 = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(74); + expr(11); + break; + } + + case 6: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(75); + + if (!(precpred(_ctx, 9))) throw FailedPredicateException(this, "precpred(_ctx, 9)"); + setState(76); + antlrcpp::downCast(_localctx)->op1 = _input->LT(1); + _la = _input->LA(1); + if (!(_la == PlanParser::GT + + || _la == 
PlanParser::GE)) { + antlrcpp::downCast(_localctx)->op1 = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(77); + _la = _input->LA(1); + if (!(_la == PlanParser::Identifier + + || _la == PlanParser::JSONIdentifier)) { + _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(78); + antlrcpp::downCast(_localctx)->op2 = _input->LT(1); + _la = _input->LA(1); + if (!(_la == PlanParser::GT + + || _la == PlanParser::GE)) { + antlrcpp::downCast(_localctx)->op2 = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(79); + expr(10); + break; + } + + case 7: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(80); + + if (!(precpred(_ctx, 8))) throw FailedPredicateException(this, "precpred(_ctx, 8)"); + setState(81); + antlrcpp::downCast(_localctx)->op = _input->LT(1); + _la = _input->LA(1); + if (!((((_la & ~ 0x3fULL) == 0) && + ((1ULL << _la) & 960) != 0))) { + antlrcpp::downCast(_localctx)->op = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(82); + expr(9); + break; + } + + case 8: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(83); + + if (!(precpred(_ctx, 7))) throw FailedPredicateException(this, "precpred(_ctx, 7)"); + setState(84); + antlrcpp::downCast(_localctx)->op = _input->LT(1); + _la = _input->LA(1); + if (!(_la == PlanParser::EQ + + || _la == PlanParser::NE)) { + antlrcpp::downCast(_localctx)->op = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(85); + expr(8); + break; + } + + case 9: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(86); + + if (!(precpred(_ctx, 6))) throw FailedPredicateException(this, "precpred(_ctx, 6)"); + setState(87); + match(PlanParser::BAND); + setState(88); + expr(7); + break; + } + + case 10: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(89); + + if (!(precpred(_ctx, 5))) throw FailedPredicateException(this, "precpred(_ctx, 5)"); + setState(90); + match(PlanParser::BXOR); + setState(91); + expr(6); + break; + } + + case 11: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(92); + + if (!(precpred(_ctx, 4))) throw FailedPredicateException(this, "precpred(_ctx, 4)"); + setState(93); + match(PlanParser::BOR); + setState(94); + expr(5); + break; + } + + case 12: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(95); + + if (!(precpred(_ctx, 3))) throw FailedPredicateException(this, "precpred(_ctx, 3)"); + setState(96); + match(PlanParser::AND); + setState(97); + expr(4); + break; + } + + case 13: { 
+ auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(98); + + if (!(precpred(_ctx, 2))) throw FailedPredicateException(this, "precpred(_ctx, 2)"); + setState(99); + match(PlanParser::OR); + setState(100); + expr(3); + break; + } + + case 14: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(101); + + if (!(precpred(_ctx, 22))) throw FailedPredicateException(this, "precpred(_ctx, 22)"); + setState(102); + match(PlanParser::LIKE); + setState(103); + match(PlanParser::StringLiteral); + break; + } + + case 15: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(104); + + if (!(precpred(_ctx, 16))) throw FailedPredicateException(this, "precpred(_ctx, 16)"); + setState(105); + antlrcpp::downCast(_localctx)->op = _input->LT(1); + _la = _input->LA(1); + if (!(_la == PlanParser::IN + + || _la == PlanParser::NIN)) { + antlrcpp::downCast(_localctx)->op = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + + setState(106); + match(PlanParser::T__2); + setState(107); + expr(0); + setState(112); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 3, _ctx); + while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER) { + if (alt == 1) { + setState(108); + match(PlanParser::T__3); + setState(109); + expr(0); + } + setState(114); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 3, _ctx); + } + setState(116); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == PlanParser::T__3) { + setState(115); + match(PlanParser::T__3); + } + setState(118); + match(PlanParser::T__4); + break; + } + + case 16: { + auto newContext = _tracker.createInstance(_tracker.createInstance(parentContext, parentState)); + _localctx = newContext; + pushNewRecursionContext(newContext, startState, RuleExpr); + setState(120); + + if (!(precpred(_ctx, 15))) throw FailedPredicateException(this, "precpred(_ctx, 15)"); + setState(121); + antlrcpp::downCast(_localctx)->op = _input->LT(1); + _la = _input->LA(1); + if (!(_la == PlanParser::IN + + || _la == PlanParser::NIN)) { + antlrcpp::downCast(_localctx)->op = _errHandler->recoverInline(this); + } + else { + _errHandler->reportMatch(this); + consume(); + } + setState(122); + match(PlanParser::EmptyTerm); + break; + } + + default: + break; + } + } + setState(127); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, 6, _ctx); + } + } + catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + return _localctx; +} + +bool PlanParser::sempred(RuleContext *context, size_t ruleIndex, size_t predicateIndex) { + switch (ruleIndex) { + case 0: return exprSempred(antlrcpp::downCast(context), predicateIndex); + + default: + break; + } + return true; +} + +bool PlanParser::exprSempred(ExprContext *_localctx, size_t predicateIndex) { + switch (predicateIndex) { + case 0: return precpred(_ctx, 21); + case 1: return precpred(_ctx, 19); + case 2: return precpred(_ctx, 18); + case 3: return precpred(_ctx, 17); + case 4: return 
precpred(_ctx, 10); + case 5: return precpred(_ctx, 9); + case 6: return precpred(_ctx, 8); + case 7: return precpred(_ctx, 7); + case 8: return precpred(_ctx, 6); + case 9: return precpred(_ctx, 5); + case 10: return precpred(_ctx, 4); + case 11: return precpred(_ctx, 3); + case 12: return precpred(_ctx, 2); + case 13: return precpred(_ctx, 22); + case 14: return precpred(_ctx, 16); + case 15: return precpred(_ctx, 15); + + default: + break; + } + return true; +} + +void PlanParser::initialize() { +#if ANTLR4_USE_THREAD_LOCAL_CACHE + planParserInitialize(); +#else + ::antlr4::internal::call_once(planParserOnceFlag, planParserInitialize); +#endif +} diff --git a/src/parser/antlr/PlanParser.h b/src/parser/antlr/PlanParser.h new file mode 100644 index 0000000..2236833 --- /dev/null +++ b/src/parser/antlr/PlanParser.h @@ -0,0 +1,426 @@ + +// Generated from Plan.g4 by ANTLR 4.13.1 + +#pragma once + + +#include "antlr4-runtime.h" + + + + +class PlanParser : public antlr4::Parser { +public: + enum { + T__0 = 1, T__1 = 2, T__2 = 3, T__3 = 4, T__4 = 5, LT = 6, LE = 7, GT = 8, + GE = 9, EQ = 10, NE = 11, LIKE = 12, EXISTS = 13, ADD = 14, SUB = 15, + MUL = 16, DIV = 17, MOD = 18, POW = 19, SHL = 20, SHR = 21, BAND = 22, + BOR = 23, BXOR = 24, AND = 25, OR = 26, BNOT = 27, NOT = 28, IN = 29, + NIN = 30, EmptyTerm = 31, JSONContains = 32, JSONContainsAll = 33, JSONContainsAny = 34, + ArrayContains = 35, ArrayContainsAll = 36, ArrayContainsAny = 37, ArrayLength = 38, + BooleanConstant = 39, IntegerConstant = 40, FloatingConstant = 41, Identifier = 42, + StringLiteral = 43, JSONIdentifier = 44, Whitespace = 45, Newline = 46 + }; + + enum { + RuleExpr = 0 + }; + + explicit PlanParser(antlr4::TokenStream *input); + + PlanParser(antlr4::TokenStream *input, const antlr4::atn::ParserATNSimulatorOptions &options); + + ~PlanParser() override; + + std::string getGrammarFileName() const override; + + const antlr4::atn::ATN& getATN() const override; + + const std::vector& getRuleNames() const override; + + const antlr4::dfa::Vocabulary& getVocabulary() const override; + + antlr4::atn::SerializedATNView getSerializedATN() const override; + + + class ExprContext; + + class ExprContext : public antlr4::ParserRuleContext { + public: + ExprContext(antlr4::ParserRuleContext *parent, size_t invokingState); + + ExprContext() = default; + void copyFrom(ExprContext *context); + using antlr4::ParserRuleContext::copyFrom; + + virtual size_t getRuleIndex() const override; + + + }; + + class JSONIdentifierContext : public ExprContext { + public: + JSONIdentifierContext(ExprContext *ctx); + + antlr4::tree::TerminalNode *JSONIdentifier(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class ParensContext : public ExprContext { + public: + ParensContext(ExprContext *ctx); + + ExprContext *expr(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class StringContext : public ExprContext { + public: + StringContext(ExprContext *ctx); + + antlr4::tree::TerminalNode *StringLiteral(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class FloatingContext : public ExprContext { + public: + FloatingContext(ExprContext *ctx); + + antlr4::tree::TerminalNode *FloatingConstant(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class JSONContainsAllContext : public ExprContext { + public: + JSONContainsAllContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); 
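As a usage note for the generated parser: PlanParser has a single entry rule, expr(), and the parser.cc hunk further down drives it through the standard ANTLR pipeline (input stream -> lexer -> token stream -> parser -> visitor). The sketch below is illustrative only; the include paths and the sample filter string are assumptions, and PlanParser::initialize() is merely the optional eager initialization of the static parser state described in this header.

#include <iostream>
#include <string>
#include "antlr4-runtime.h"
#include "antlr/PlanLexer.h"    // assumed include paths, following this patch's layout
#include "antlr/PlanParser.h"

// Sketch: run the generated lexer/parser over a filter string and print the
// parse tree (mirrors the pipeline used in src/parser/parser.cc below).
void DumpPlanTree(const std::string& exprstr) {
    // Optional: build the parser's static state ahead of the first use
    // instead of lazily in the constructor.
    PlanParser::initialize();

    antlr4::ANTLRInputStream input(exprstr);   // raw filter text
    PlanLexer lexer(&input);                   // generated lexer
    antlr4::CommonTokenStream tokens(&lexer);  // token buffer
    PlanParser parser(&tokens);                // generated parser

    PlanParser::ExprContext* tree = parser.expr();  // the only parser rule

    // The context objects are owned by the parser, so consume the tree
    // (visit it or print it) before `parser` goes out of scope.
    std::cout << tree->toStringTree(&parser) << std::endl;
}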
+ antlr4::tree::TerminalNode *JSONContainsAll(); + antlr4::tree::TerminalNode *ArrayContainsAll(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class LogicalOrContext : public ExprContext { + public: + LogicalOrContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *OR(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class MulDivModContext : public ExprContext { + public: + MulDivModContext(ExprContext *ctx); + + antlr4::Token *op = nullptr; + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *MUL(); + antlr4::tree::TerminalNode *DIV(); + antlr4::tree::TerminalNode *MOD(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class IdentifierContext : public ExprContext { + public: + IdentifierContext(ExprContext *ctx); + + antlr4::tree::TerminalNode *Identifier(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class LikeContext : public ExprContext { + public: + LikeContext(ExprContext *ctx); + + ExprContext *expr(); + antlr4::tree::TerminalNode *LIKE(); + antlr4::tree::TerminalNode *StringLiteral(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class LogicalAndContext : public ExprContext { + public: + LogicalAndContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *AND(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class EqualityContext : public ExprContext { + public: + EqualityContext(ExprContext *ctx); + + antlr4::Token *op = nullptr; + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *EQ(); + antlr4::tree::TerminalNode *NE(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class BooleanContext : public ExprContext { + public: + BooleanContext(ExprContext *ctx); + + antlr4::tree::TerminalNode *BooleanConstant(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class ShiftContext : public ExprContext { + public: + ShiftContext(ExprContext *ctx); + + antlr4::Token *op = nullptr; + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *SHL(); + antlr4::tree::TerminalNode *SHR(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class ReverseRangeContext : public ExprContext { + public: + ReverseRangeContext(ExprContext *ctx); + + antlr4::Token *op1 = nullptr; + antlr4::Token *op2 = nullptr; + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *Identifier(); + antlr4::tree::TerminalNode *JSONIdentifier(); + std::vector GT(); + antlr4::tree::TerminalNode* GT(size_t i); + std::vector GE(); + antlr4::tree::TerminalNode* GE(size_t i); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class BitOrContext : public ExprContext { + public: + BitOrContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *BOR(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class AddSubContext : public ExprContext { + public: + AddSubContext(ExprContext *ctx); + + antlr4::Token *op = nullptr; + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *ADD(); + 
antlr4::tree::TerminalNode *SUB(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class RelationalContext : public ExprContext { + public: + RelationalContext(ExprContext *ctx); + + antlr4::Token *op = nullptr; + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *LT(); + antlr4::tree::TerminalNode *LE(); + antlr4::tree::TerminalNode *GT(); + antlr4::tree::TerminalNode *GE(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class ArrayLengthContext : public ExprContext { + public: + ArrayLengthContext(ExprContext *ctx); + + antlr4::tree::TerminalNode *ArrayLength(); + antlr4::tree::TerminalNode *Identifier(); + antlr4::tree::TerminalNode *JSONIdentifier(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class TermContext : public ExprContext { + public: + TermContext(ExprContext *ctx); + + antlr4::Token *op = nullptr; + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *IN(); + antlr4::tree::TerminalNode *NIN(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class JSONContainsContext : public ExprContext { + public: + JSONContainsContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *JSONContains(); + antlr4::tree::TerminalNode *ArrayContains(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class RangeContext : public ExprContext { + public: + RangeContext(ExprContext *ctx); + + antlr4::Token *op1 = nullptr; + antlr4::Token *op2 = nullptr; + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *Identifier(); + antlr4::tree::TerminalNode *JSONIdentifier(); + std::vector LT(); + antlr4::tree::TerminalNode* LT(size_t i); + std::vector LE(); + antlr4::tree::TerminalNode* LE(size_t i); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class UnaryContext : public ExprContext { + public: + UnaryContext(ExprContext *ctx); + + antlr4::Token *op = nullptr; + ExprContext *expr(); + antlr4::tree::TerminalNode *ADD(); + antlr4::tree::TerminalNode *SUB(); + antlr4::tree::TerminalNode *BNOT(); + antlr4::tree::TerminalNode *NOT(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class IntegerContext : public ExprContext { + public: + IntegerContext(ExprContext *ctx); + + antlr4::tree::TerminalNode *IntegerConstant(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class ArrayContext : public ExprContext { + public: + ArrayContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class JSONContainsAnyContext : public ExprContext { + public: + JSONContainsAnyContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *JSONContainsAny(); + antlr4::tree::TerminalNode *ArrayContainsAny(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class BitXorContext : public ExprContext { + public: + BitXorContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *BXOR(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class ExistsContext : public ExprContext { + public: + 
ExistsContext(ExprContext *ctx); + + antlr4::tree::TerminalNode *EXISTS(); + ExprContext *expr(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class BitAndContext : public ExprContext { + public: + BitAndContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *BAND(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class EmptyTermContext : public ExprContext { + public: + EmptyTermContext(ExprContext *ctx); + + antlr4::Token *op = nullptr; + ExprContext *expr(); + antlr4::tree::TerminalNode *EmptyTerm(); + antlr4::tree::TerminalNode *IN(); + antlr4::tree::TerminalNode *NIN(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + class PowerContext : public ExprContext { + public: + PowerContext(ExprContext *ctx); + + std::vector expr(); + ExprContext* expr(size_t i); + antlr4::tree::TerminalNode *POW(); + + virtual std::any accept(antlr4::tree::ParseTreeVisitor *visitor) override; + }; + + ExprContext* expr(); + ExprContext* expr(int precedence); + + bool sempred(antlr4::RuleContext *_localctx, size_t ruleIndex, size_t predicateIndex) override; + + bool exprSempred(ExprContext *_localctx, size_t predicateIndex); + + // By default the static state used to implement the parser is lazily initialized during the first + // call to the constructor. You can call this function if you wish to initialize the static state + // ahead of time. + static void initialize(); + +private: +}; + diff --git a/src/parser/antlr/PlanVisitor.cpp b/src/parser/antlr/PlanVisitor.cpp new file mode 100644 index 0000000..9a36f3f --- /dev/null +++ b/src/parser/antlr/PlanVisitor.cpp @@ -0,0 +1,7 @@ + +// Generated from Plan.g4 by ANTLR 4.13.1 + + +#include "PlanVisitor.h" + + diff --git a/src/parser/antlr/PlanVisitor.h b/src/parser/antlr/PlanVisitor.h new file mode 100644 index 0000000..b7af125 --- /dev/null +++ b/src/parser/antlr/PlanVisitor.h @@ -0,0 +1,84 @@ + +// Generated from Plan.g4 by ANTLR 4.13.1 + +#pragma once + + +#include "antlr4-runtime.h" +#include "PlanParser.h" + + + +/** + * This class defines an abstract visitor for a parse tree + * produced by PlanParser. + */ +class PlanVisitor : public antlr4::tree::AbstractParseTreeVisitor { +public: + + /** + * Visit parse trees produced by PlanParser. 
+ */ + virtual std::any visitJSONIdentifier(PlanParser::JSONIdentifierContext *context) = 0; + + virtual std::any visitParens(PlanParser::ParensContext *context) = 0; + + virtual std::any visitString(PlanParser::StringContext *context) = 0; + + virtual std::any visitFloating(PlanParser::FloatingContext *context) = 0; + + virtual std::any visitJSONContainsAll(PlanParser::JSONContainsAllContext *context) = 0; + + virtual std::any visitLogicalOr(PlanParser::LogicalOrContext *context) = 0; + + virtual std::any visitMulDivMod(PlanParser::MulDivModContext *context) = 0; + + virtual std::any visitIdentifier(PlanParser::IdentifierContext *context) = 0; + + virtual std::any visitLike(PlanParser::LikeContext *context) = 0; + + virtual std::any visitLogicalAnd(PlanParser::LogicalAndContext *context) = 0; + + virtual std::any visitEquality(PlanParser::EqualityContext *context) = 0; + + virtual std::any visitBoolean(PlanParser::BooleanContext *context) = 0; + + virtual std::any visitShift(PlanParser::ShiftContext *context) = 0; + + virtual std::any visitReverseRange(PlanParser::ReverseRangeContext *context) = 0; + + virtual std::any visitBitOr(PlanParser::BitOrContext *context) = 0; + + virtual std::any visitAddSub(PlanParser::AddSubContext *context) = 0; + + virtual std::any visitRelational(PlanParser::RelationalContext *context) = 0; + + virtual std::any visitArrayLength(PlanParser::ArrayLengthContext *context) = 0; + + virtual std::any visitTerm(PlanParser::TermContext *context) = 0; + + virtual std::any visitJSONContains(PlanParser::JSONContainsContext *context) = 0; + + virtual std::any visitRange(PlanParser::RangeContext *context) = 0; + + virtual std::any visitUnary(PlanParser::UnaryContext *context) = 0; + + virtual std::any visitInteger(PlanParser::IntegerContext *context) = 0; + + virtual std::any visitArray(PlanParser::ArrayContext *context) = 0; + + virtual std::any visitJSONContainsAny(PlanParser::JSONContainsAnyContext *context) = 0; + + virtual std::any visitBitXor(PlanParser::BitXorContext *context) = 0; + + virtual std::any visitExists(PlanParser::ExistsContext *context) = 0; + + virtual std::any visitBitAnd(PlanParser::BitAndContext *context) = 0; + + virtual std::any visitEmptyTerm(PlanParser::EmptyTermContext *context) = 0; + + virtual std::any visitPower(PlanParser::PowerContext *context) = 0; + + +}; + diff --git a/src/parser/parser.cc b/src/parser/parser.cc new file mode 100644 index 0000000..729bbb7 --- /dev/null +++ b/src/parser/parser.cc @@ -0,0 +1,35 @@ +#include "parser.h" +namespace milvus::local { + +std::string +ParserToMessage(milvus::proto::schema::CollectionSchema& schema, + const std::string& exprstr) { + antlr4::ANTLRInputStream input(exprstr); + PlanLexer lexer(&input); + antlr4::CommonTokenStream tokens(&lexer); + PlanParser parser(&tokens); + + PlanParser::ExprContext* tree = parser.expr(); + + auto helper = milvus::local::CreateSchemaHelper(&schema); + + milvus::local::PlanCCVisitor visitor(&helper); + auto res = std::any_cast(visitor.visit(tree)); + return res.expr->SerializeAsString(); +} + +std::shared_ptr +ParseIdentifier(milvus::local::SchemaHelper helper, + const std::string& identifier) { + auto expr = + google::protobuf::Arena::CreateMessage(NULL); + + assert(expr->column_expr().has_info()); + expr->ParseFromString(ParserToMessage(*(helper.schema), identifier)); + + auto ret = std::make_shared(); + ret.reset(expr); + return ret; +}; + +} // namespace milvus::local diff --git a/src/parser/parser.h b/src/parser/parser.h new file mode 100644 index 
0000000..ac59dd4 --- /dev/null +++ b/src/parser/parser.h @@ -0,0 +1,1447 @@ +#pragma once + +#include +#include +#include +#include +#include "antlr/PlanBaseVisitor.h" +#include "antlr/PlanLexer.h" +#include "antlr/PlanParser.h" +#include "pb/plan.pb.h" +#include "utils.h" + +namespace milvus::local { + +class PlanCCVisitor : public PlanVisitor { + public: + // ok + virtual std::any + visitShift(PlanParser::ShiftContext*) override { + assert(false); + return nullptr; + } + // ok + virtual std::any + visitBitOr(PlanParser::BitOrContext*) override { + assert(false); + return nullptr; + } + // ok + virtual std::any + visitBitXor(PlanParser::BitXorContext*) override { + assert(false); + return nullptr; + } + // ok + virtual std::any + visitBitAnd(PlanParser::BitAndContext*) override { + assert(false); + return nullptr; + } + + // ok + virtual std::any + visitParens(PlanParser::ParensContext* ctx) override { + return visitChildren(ctx); + } + // ok + virtual std::any + visitString(PlanParser::StringContext* ctx) override { + auto val = ctx->getText(); + return ExprWithDtype(createValueExpr( + convertEscapeSingle(val), this->arena.get()), + proto::schema::DataType::VarChar, + true); + } + // ok + virtual std::any + visitFloating(PlanParser::FloatingContext* ctx) override { + auto text = ctx->getText(); + auto val = std::strtod(text.c_str(), NULL); + return ExprWithDtype(createValueExpr(val, this->arena.get()), + proto::schema::DataType::Float, + true); + } + // ok + virtual std::any + visitInteger(PlanParser::IntegerContext* ctx) override { + auto text = ctx->getText(); + int64_t val = std::strtoll(text.c_str(), NULL, 10); + return ExprWithDtype(createValueExpr(val, this->arena.get()), + proto::schema::DataType::Int64, + true); + } + // ok + virtual std::any + visitBoolean(PlanParser::BooleanContext* ctx) override { + auto text = ctx->getText(); + bool val; + std::istringstream(text) >> std::boolalpha >> val; + return ExprWithDtype(createValueExpr(val, this->arena.get()), + proto::schema::DataType::Bool, + true); + } + + virtual std::any + visitPower(PlanParser::PowerContext* ctx) override { + auto left_expr = + std::any_cast(ctx->expr()[0]->accept(this)).expr; + auto right_expr = + std::any_cast(ctx->expr()[1]->accept(this)).expr; + + auto left = extractValue(left_expr); + auto right = extractValue(right_expr); + + assert(left.has_value() && right.has_value()); + assert(left.type() == typeid(double) || left.type() == typeid(int64_t)); + assert(right.type() == typeid(double) || + right.type() == typeid(int64_t)); + float left_value, right_value; + if (left.type() == typeid(int64_t)) + left_value = float(std::any_cast(left)); + if (left.type() == typeid(double)) + left_value = float(std::any_cast(left)); + if (right.type() == typeid(int64_t)) + right_value = float(std::any_cast(right)); + if (right.type() == typeid(double)) + right_value = float(std::any_cast(right)); + + return ExprWithDtype( + createValueExpr(powf(left_value, right_value), + this->arena.get()), + proto::schema::DataType::Double, + false); + } + + virtual std::any + visitLogicalOr(PlanParser::LogicalOrContext* ctx) override { + auto left_expr_with_type = + std::any_cast(ctx->expr()[0]->accept(this)); + auto right_expr_with_type = + std::any_cast(ctx->expr()[1]->accept(this)); + + auto left_expr = left_expr_with_type.expr; + auto right_expr = right_expr_with_type.expr; + + auto left_value = extractValue(left_expr); + auto right_value = extractValue(right_expr); + + if (left_value.has_value() && right_value.has_value() && + 
left_value.type() == typeid(bool) && + right_value.type() == typeid(bool)) { + return ExprWithDtype( + createValueExpr(std::any_cast(left_value) || + std::any_cast(right_value), + this->arena.get() + + ), + proto::schema::DataType::Bool, + false + + ); + } + + assert(!left_expr_with_type.dependent); + assert(!right_expr_with_type.dependent); + assert(left_expr_with_type.dtype == proto::schema::DataType::Bool); + assert(right_expr_with_type.dtype == proto::schema::DataType::Bool); + return ExprWithDtype( + createBinExpr( + left_expr, right_expr), + proto::schema::DataType::Bool, + false); + } + + virtual std::any + visitLogicalAnd(PlanParser::LogicalAndContext* ctx) override { + auto left_expr_with_type = + std::any_cast(ctx->expr()[0]->accept(this)); + auto right_expr_with_type = + std::any_cast(ctx->expr()[1]->accept(this)); + + auto left_expr = left_expr_with_type.expr; + auto right_expr = right_expr_with_type.expr; + + auto left_value = extractValue(left_expr); + auto right_value = extractValue(right_expr); + + if (left_value.has_value() && right_value.has_value() && + left_value.type() == typeid(bool) && + right_value.type() == typeid(bool)) { + return ExprWithDtype( + createValueExpr(std::any_cast(left_value) || + std::any_cast(right_value), + this->arena.get() + + ), + proto::schema::DataType::Bool, + false + + ); + } + + assert(!left_expr_with_type.dependent); + assert(!right_expr_with_type.dependent); + assert(left_expr_with_type.dtype == proto::schema::DataType::Bool); + assert(right_expr_with_type.dtype == proto::schema::DataType::Bool); + return ExprWithDtype( + createBinExpr( + left_expr, right_expr, this->arena.get()), + proto::schema::DataType::Bool, + false); + } + + virtual std::any + visitJSONIdentifier(PlanParser::JSONIdentifierContext* ctx) override { + auto info = getChildColumnInfo(nullptr, ctx->JSONIdentifier()); + + assert(info); + + auto expr = google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto col_expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + + col_expr->unsafe_arena_set_allocated_info(info); + + expr->unsafe_arena_set_allocated_column_expr(col_expr); + return ExprWithDtype(expr, info->data_type(), true); + } + + virtual std::any + visitJSONContainsAll(PlanParser::JSONContainsAllContext* ctx) override { + auto field = std::any_cast(ctx->expr()[0]->accept(this)); + auto info = field.expr->column_expr().info(); + assert(info.data_type() == proto::schema::DataType::Array || + info.data_type() == proto::schema::DataType::JSON); + auto elem = std::any_cast(ctx->expr()[1]->accept(this)); + if (info.data_type() == proto::schema::DataType::Array) { + proto::plan::GenericValue expr = + proto::plan::GenericValue(elem.expr->value_expr().value()); + assert(canBeCompared(field, toValueExpr(&expr, this->arena.get()))); + } + + auto expr = google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto json_contain_expr = google::protobuf::Arena::CreateMessage< + proto::plan::JSONContainsExpr>(this->arena.get()); + auto value = json_contain_expr->add_elements(); // MayBe BUG + value->unsafe_arena_set_allocated_array_val( + CreateMessageWithCopy( + this->arena.get(), + elem.expr->value_expr().value().array_val())); + json_contain_expr->set_elements_same_type( + elem.expr->value_expr().value().array_val().same_type()); + json_contain_expr->unsafe_arena_set_allocated_column_info( + CreateMessageWithCopy(this->arena.get(), + info)); + + json_contain_expr->set_op( + proto::plan::JSONContainsExpr_JSONOp_ContainsAll); + 
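// Example of the lowering done here, on a hypothetical filter such as
//     json_contains_all(meta["tags"], ["a", "b"])
// the result is a JSONContainsExpr whose op is ContainsAll, whose
// column_info is a copy of the JSON/Array field info (field "meta",
// nested path ["tags"]), and whose single element is the literal array
// ["a", "b"], with elements_same_type taken from that array literal;
// the node is then returned as a Bool-typed ExprWithDtype.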
expr->unsafe_arena_set_allocated_json_contains_expr(json_contain_expr); + return ExprWithDtype(expr, proto::schema::Bool, false); + } + + virtual std::any + visitMulDivMod(PlanParser::MulDivModContext* ctx) override { + auto left_expr_with_type = + std::any_cast(ctx->expr()[0]->accept(this)); + auto right_expr_with_type = + std::any_cast(ctx->expr()[1]->accept(this)); + auto left_expr = left_expr_with_type.expr; + auto right_expr = right_expr_with_type.expr; + + auto left_value = extractValue(left_expr); + auto right_value = extractValue(right_expr); + if (left_value.has_value() && right_value.has_value()) { + if (left_value.type() == typeid(double) && + right_value.type() == typeid(double)) { + switch (ctx->op->getType()) { + case PlanParser::MUL: + return ExprWithDtype( + createValueExpr( + std::any_cast(left_value) * + std::any_cast(right_value), + this->arena.get()), + proto::schema::DataType::Double, + false); + case PlanParser::DIV: + return ExprWithDtype( + createValueExpr( + std::any_cast(left_value) / + std::any_cast(right_value), + this->arena.get()), + proto::schema::DataType::Double, + false); + default: + assert(false); + } + } + + if (left_value.type() == typeid(int64_t) && + right_value.type() == typeid(int64_t)) { + switch (ctx->op->getType()) { + case PlanParser::MUL: + return ExprWithDtype( + createValueExpr( + std::any_cast(left_value) * + std::any_cast(right_value), + this->arena.get()), + proto::schema::DataType::Int64, + false); + case PlanParser::DIV: + return ExprWithDtype( + createValueExpr( + std::any_cast(left_value) / + std::any_cast(right_value), + this->arena.get()), + proto::schema::DataType::Int64, + false); + case PlanParser::MOD: + return ExprWithDtype( + createValueExpr( + std::any_cast(left_value) % + std::any_cast(right_value), + this->arena.get()), + proto::schema::DataType::Int64, + false); + default: + assert(false); + } + } + + if (left_value.type() == typeid(double) && + right_value.type() == typeid(int64_t)) { + switch (ctx->op->getType()) { + case PlanParser::MUL: + return ExprWithDtype( + createValueExpr( + std::any_cast(left_value) * + std::any_cast(right_value), + this->arena.get()), + proto::schema::DataType::Double, + false); + case PlanParser::DIV: + return ExprWithDtype( + createValueExpr( + std::any_cast(left_value) / + std::any_cast(right_value), + this->arena.get()), + proto::schema::DataType::Double, + false); + default: + assert(false); + } + } + + if (left_value.type() == typeid(int64_t) && + right_value.type() == typeid(double)) { + switch (ctx->op->getType()) { + case PlanParser::MUL: + return ExprWithDtype( + createValueExpr( + double(std::any_cast(left_value)) * + std::any_cast(right_value), + this->arena.get()), + proto::schema::DataType::Double, + false); + case PlanParser::DIV: + return ExprWithDtype( + createValueExpr( + double(std::any_cast(left_value)) * + std::any_cast(right_value), + this->arena.get()), + proto::schema::DataType::Double, + false); + default: + assert(false); + } + } + + if (left_expr->has_column_expr()) { + assert(left_expr->column_expr().info().data_type() != + proto::schema::DataType::Array); + assert(left_expr->column_expr().info().nested_path_size() == 0); + } + + if (right_expr->has_column_expr()) { + assert(right_expr->column_expr().info().data_type() != + proto::schema::DataType::Array); + assert(right_expr->column_expr().info().nested_path_size() == + 0); + } + + if (left_expr_with_type.dtype == proto::schema::DataType::Array) { + if (right_expr_with_type.dtype == + proto::schema::DataType::Array) 
+ assert(canArithmeticDtype(getArrayElementType(left_expr), + getArrayElementType(right_expr))); + else if (arithmeticDtype(left_expr_with_type.dtype)) + assert(canArithmeticDtype(getArrayElementType(left_expr), + right_expr_with_type.dtype)); + else + assert(false); + } + + if (right_expr_with_type.dtype == proto::schema::DataType::Array) { + if (arithmeticDtype(left_expr_with_type.dtype)) + assert(canArithmeticDtype(left_expr_with_type.dtype, + getArrayElementType(right_expr))); + else + assert(false); + } + + if (arithmeticDtype(left_expr_with_type.dtype) && + arithmeticDtype(right_expr_with_type.dtype)) { + assert(canArithmeticDtype(left_expr_with_type.dtype, + right_expr_with_type.dtype)); + } else { + assert(false); + } + + switch (ctx->op->getType()) { + case PlanParser::MUL: + return ExprWithDtype( + createBinArithExpr( + left_expr, right_expr, this->arena.get()), + calDataType(&left_expr_with_type, + &right_expr_with_type), + false); + case PlanParser::DIV: + return ExprWithDtype( + createBinArithExpr( + left_expr, right_expr, this->arena.get()), + calDataType(&left_expr_with_type, + &right_expr_with_type), + false); + case PlanParser::MOD: + return ExprWithDtype( + createBinArithExpr( + left_expr, right_expr, this->arena.get()), + calDataType(&left_expr_with_type, + &right_expr_with_type), + false); + + default: + assert(false); + } + } + return nullptr; + } + + virtual std::any + visitIdentifier(PlanParser::IdentifierContext* ctx) override { + auto identifier = ctx->getText(); + auto& field = helper->GetFieldFromNameDefaultJSON(identifier); + std::vector nested_path; + if (field.name() != identifier) { + nested_path.push_back(identifier); + } + assert(!(field.data_type() == proto::schema::DataType::JSON && + nested_path.empty())); + auto expr = google::protobuf::Arena::CreateMessage( + arena.get()); + auto col_expr = + google::protobuf::Arena::CreateMessage( + arena.get()); + auto info = + google::protobuf::Arena::CreateMessage( + arena.get()); + info->set_field_id(field.fieldid()); + info->set_data_type(field.data_type()); + info->set_is_primary_key(field.is_primary_key()); + info->set_is_autoid(field.autoid()); + for (int i = 0; i < (int)nested_path.size(); ++i) { + auto path_added = info->add_nested_path(); + *path_added = nested_path[i]; + } + info->set_is_primary_key(field.is_primary_key()); + info->set_element_type(field.element_type()); + col_expr->set_allocated_info(info); + expr->set_allocated_column_expr(col_expr); + return ExprWithDtype(expr, field.data_type(), false); + } + + virtual std::any + visitLike(PlanParser::LikeContext* ctx) override { + auto child_expr_with_type = + std::any_cast(ctx->expr()->accept(this)); + auto child_expr = child_expr_with_type.expr; + assert(child_expr); + auto info = child_expr->column_expr().info(); + assert(!(info.data_type() == proto::schema::DataType::JSON && + info.nested_path_size() == 0)); + assert( + (child_expr_with_type.dtype == proto::schema::DataType::VarChar || + child_expr_with_type.dtype == proto::schema::DataType::JSON) || + (child_expr_with_type.dtype == proto::schema::DataType::Array && + info.element_type() == proto::schema::DataType::VarChar)); + + auto str = ctx->StringLiteral()->getText(); + auto pattern = convertEscapeSingle(str); + + auto res = translatePatternMatch(pattern); + auto expr = google::protobuf::Arena::CreateMessage( + arena.get()); + auto unaryrange_expr = + google::protobuf::Arena::CreateMessage( + arena.get()); + unaryrange_expr->set_op(res.first); + + auto value = + 
google::protobuf::Arena::CreateMessage( + this->arena.get()); + value->set_string_val(res.second); + unaryrange_expr->unsafe_arena_set_allocated_value(value); + unaryrange_expr->unsafe_arena_set_allocated_column_info( + CreateMessageWithCopy(this->arena.get(), + info)); + expr->set_allocated_unary_range_expr(unaryrange_expr); + return ExprWithDtype(expr, proto::schema::DataType::Bool, false); + } + + virtual std::any + visitEquality(PlanParser::EqualityContext* ctx) override { + auto left_expr_with_type = + std::any_cast(ctx->expr()[0]->accept(this)); + auto right_expr_with_type = + std::any_cast(ctx->expr()[1]->accept(this)); + + auto left_value = extractValue(left_expr_with_type.expr); + auto right_value = extractValue(right_expr_with_type.expr); + + if (left_value.has_value() && right_value.has_value()) { +#define PROCESS_EQALITY(left_type, right_type) \ + if (left_value.type() == typeid(left_type) && \ + right_value.type() == typeid(right_type)) { \ + switch (ctx->op->getType()) { \ + case PlanParser::EQ: \ + return ExprWithDtype( \ + createValueExpr( \ + std::any_cast(left_value) == \ + std::any_cast(right_value), \ + this->arena.get()), \ + proto::schema::DataType::Bool, \ + false); \ + case PlanParser::NE: \ + return ExprWithDtype( \ + createValueExpr( \ + std::any_cast(left_value) != \ + std::any_cast(right_value), \ + this->arena.get()), \ + proto::schema::DataType::Bool, \ + false); \ + } \ + } + + PROCESS_EQALITY(bool, bool); + PROCESS_EQALITY(std::string, std::string); + PROCESS_EQALITY(double, double); + PROCESS_EQALITY(double, int64_t); + PROCESS_EQALITY(int64_t, double); + PROCESS_EQALITY(int32_t, int32_t); + PROCESS_EQALITY(int64_t, int64_t); + PROCESS_EQALITY(double, int32_t); + PROCESS_EQALITY(int32_t, double); + PROCESS_EQALITY(float, float); + PROCESS_EQALITY(int32_t, float); + PROCESS_EQALITY(float, int32_t); + PROCESS_EQALITY(float, double); + PROCESS_EQALITY(double, float); + } + + if (left_expr_with_type.expr->has_value_expr() && + !right_expr_with_type.expr->has_value_expr()) { + ExprWithDtype left = + toValueExpr(CreateMessageWithCopy( + this->arena.get(), + left_expr_with_type.expr->value_expr().value()), + this->arena.get()); + ExprWithDtype right = right_expr_with_type; + + return ExprWithDtype( + HandleCompare( + ctx->op->getType(), left, right, this->arena.get()), + proto::schema::DataType::Bool, + false); + } + + if (!left_expr_with_type.expr->has_value_expr() && + right_expr_with_type.expr->has_value_expr()) { + ExprWithDtype left = left_expr_with_type; + ExprWithDtype right = toValueExpr( + CreateMessageWithCopy( + this->arena.get(), + right_expr_with_type.expr->value_expr().value()), + this->arena.get()); + + return ExprWithDtype( + HandleCompare( + ctx->op->getType(), left, right, this->arena.get()), + proto::schema::DataType::Bool, + false); + } + + if (!left_expr_with_type.expr->has_value_expr() && + !right_expr_with_type.expr->has_value_expr()) { + return ExprWithDtype(HandleCompare(ctx->op->getType(), + left_expr_with_type, + right_expr_with_type, + this->arena.get()), + proto::schema::DataType::Bool, + false); + } + + return nullptr; + } + + proto::plan::ColumnInfo* + getChildColumnInfo(antlr4::tree::TerminalNode* identifier, + antlr4::tree::TerminalNode* child) { + if (identifier) { + auto text = identifier->getText(); + auto field = helper->GetFieldFromNameDefaultJSON(text); + std::vector nested_path; + if (field.name() != text) { + nested_path.push_back(text); + } + assert(!(field.data_type() == proto::schema::DataType::JSON && + 
nested_path.empty())); + auto info = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + info->set_field_id(field.fieldid()); + info->set_data_type(field.data_type()); + info->set_is_primary_key(field.is_primary_key()); + info->set_is_autoid(field.autoid()); + for (int i = 0; i < (int)nested_path.size(); ++i) { + auto path_added = info->add_nested_path(); + *path_added = nested_path[i]; + } + info->set_is_primary_key(field.is_primary_key()); + info->set_element_type(field.element_type()); + return info; + } + + auto childtext = child->getText(); + std::string fieldname = childtext.substr(0, childtext.find("[", 0)); + + std::vector nested_path; + auto field = helper->GetFieldFromNameDefaultJSON(fieldname); + assert(field.data_type() == proto::schema::DataType::JSON || + field.data_type() == proto::schema::DataType::Array); + if (fieldname != field.name()) + nested_path.push_back(fieldname); + auto jsonkey = childtext.substr( + fieldname.length(), childtext.length() - fieldname.length()); + auto ss = tokenize(jsonkey, "]["); + for (size_t i = 0; i < ss.size(); ++i) { + std::string path_ = ss[i]; + + if (path_[0] == '[') + path_ = path_.substr(1, path_.length() - 1); + + if (path_[path_.length() - 1] == ']') + path_ = path_.substr(0, path_.length() - 1); + assert(path_ != ""); + + if ((path_[0] == '\"' && path_[path_.length() - 1] == '\"') || + (path_[0] == '\'' && path_[path_.length() - 1] == '\'')) { + path_ = path_.substr(1, path_.length() - 2); + assert(path_ != ""); + } + nested_path.push_back(path_); + } + + auto info = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + + info->set_field_id(field.fieldid()); + info->set_data_type(field.data_type()); + info->set_is_primary_key(field.is_primary_key()); + info->set_is_autoid(field.autoid()); + for (int i = 0; i < (int)nested_path.size(); ++i) { + auto path_added = info->add_nested_path(); + *path_added = nested_path[i]; + } + info->set_is_primary_key(field.is_primary_key()); + info->set_element_type(field.element_type()); + return info; + } + + virtual std::any + visitReverseRange(PlanParser::ReverseRangeContext* ctx) override { + auto info = + getChildColumnInfo(ctx->Identifier(), ctx->JSONIdentifier()); + assert(info != nullptr); + assert(checkDirectComparisonBinaryField(info)); + auto lower = std::any_cast(ctx->expr()[1]->accept(this)); + auto upper = std::any_cast(ctx->expr()[0]->accept(this)); + + if (info->data_type() == proto::schema::DataType::Int8 || + info->data_type() == proto::schema::DataType::Int16 || + info->data_type() == proto::schema::DataType::Int32 || + info->data_type() == proto::schema::DataType::Int64 || + info->data_type() == proto::schema::DataType::Float || + info->data_type() == proto::schema::DataType::Double || + info->data_type() == proto::schema::DataType::VarChar) { + auto a = extractValue(lower.expr); + auto b = extractValue(upper.expr); + if (a.has_value() && b.has_value()) { + bool lowerinclusive = ctx->op1->getType() == PlanParser::GE; + bool upperinclusive = ctx->op2->getType() == PlanParser::GE; + auto expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto binary_range_expr = google::protobuf::Arena::CreateMessage< + proto::plan::BinaryRangeExpr>(this->arena.get()); + auto lower_value = google::protobuf::Arena::CreateMessage< + proto::plan::GenericValue>(this->arena.get()); + auto upper_value = google::protobuf::Arena::CreateMessage< + proto::plan::GenericValue>(this->arena.get()); + if (a.type() == typeid(int8_t)) + lower_value->set_int64_val( + 
int64_t(std::any_cast(a))); + if (a.type() == typeid(int16_t)) + lower_value->set_int64_val( + int64_t(std::any_cast(a))); + if (a.type() == typeid(int32_t)) + lower_value->set_int64_val( + int64_t(std::any_cast(a))); + if (a.type() == typeid(int64_t)) + lower_value->set_int64_val(std::any_cast(a)); + if (a.type() == typeid(double)) + lower_value->set_float_val(std::any_cast(a)); + if (a.type() == typeid(float)) + lower_value->set_float_val(double(std::any_cast(a))); + if (a.type() == typeid(std::string)) + lower_value->set_string_val(std::any_cast(a)); + + if (b.type() == typeid(int8_t)) + upper_value->set_int64_val( + int64_t(std::any_cast(b))); + if (b.type() == typeid(int16_t)) + upper_value->set_int64_val( + int64_t(std::any_cast(b))); + if (b.type() == typeid(int32_t)) + upper_value->set_int64_val( + int64_t(std::any_cast(b))); + if (b.type() == typeid(int64_t)) + upper_value->set_int64_val( + int64_t(std::any_cast(b))); + if (b.type() == typeid(double)) + upper_value->set_float_val(std::any_cast(b)); + if (b.type() == typeid(float)) + upper_value->set_float_val(double(std::any_cast(b))); + if (b.type() == typeid(std::string)) + upper_value->set_string_val(std::any_cast(b)); + + binary_range_expr->set_lower_inclusive(lowerinclusive); + binary_range_expr->set_upper_inclusive(upperinclusive); + binary_range_expr->unsafe_arena_set_allocated_column_info(info); + + binary_range_expr->unsafe_arena_set_allocated_lower_value( + lower_value); + binary_range_expr->unsafe_arena_set_allocated_upper_value( + upper_value); + expr->set_allocated_binary_range_expr(binary_range_expr); + return ExprWithDtype( + expr, proto::schema::DataType::Bool, false); + } + } + + return nullptr; + } + + virtual std::any + visitAddSub(PlanParser::AddSubContext* ctx) override { + auto left_expr_with_type = + std::any_cast(ctx->expr()[0]->accept(this)); + auto right_expr_with_type = + std::any_cast(ctx->expr()[1]->accept(this)); + auto left_value = extractValue(left_expr_with_type.expr); + auto right_value = extractValue(right_expr_with_type.expr); + + if (left_value.has_value() && right_value.has_value()) { +#define PROCESS_ADDSUB(left_type, right_type, target_type, datatype) \ + if (left_value.type() == typeid(left_type) && \ + right_value.type() == typeid(right_type)) { \ + switch (ctx->op->getType()) { \ + case PlanParser::ADD: \ + return ExprWithDtype( \ + createValueExpr( \ + std::any_cast(left_value) + \ + std::any_cast(right_value), \ + this->arena.get()), \ + datatype, \ + false); \ + case PlanParser::SUB: \ + return ExprWithDtype( \ + createValueExpr( \ + std::any_cast(left_value) - \ + std::any_cast(right_value), \ + this->arena.get()), \ + datatype, \ + false); \ + default: \ + assert(false); \ + } \ + } + + PROCESS_ADDSUB( + double, double, double, proto::schema::DataType::Double) + PROCESS_ADDSUB( + double, int64_t, double, proto::schema::DataType::Double) + PROCESS_ADDSUB( + int64_t, double, double, proto::schema::DataType::Double) + PROCESS_ADDSUB( + int64_t, int64_t, int64_t, proto::schema::DataType::Int64) + PROCESS_ADDSUB( + int32_t, int32_t, int32_t, proto::schema::DataType::Int32) + PROCESS_ADDSUB( + float, float, double, proto::schema::DataType::Double) + } + + auto left_expr = left_expr_with_type.expr; + auto right_expr = right_expr_with_type.expr; + if (left_expr->has_column_expr()) { + assert(left_expr->column_expr().info().data_type() != + proto::schema::DataType::Array); + assert(left_expr->column_expr().info().nested_path_size() == 0); + } + + if (right_expr->has_column_expr()) { + 
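// The column operand of a binary add/sub may not be a whole Array field and
// may not carry a nested path (e.g. a JSON sub-key); the asserts below
// reject those shapes, mirroring the check on the left operand above.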
assert(right_expr->column_expr().info().data_type() != + proto::schema::DataType::Array); + assert(right_expr->column_expr().info().nested_path_size() == 0); + } + + if (left_expr_with_type.dtype == proto::schema::DataType::Array) { + if (right_expr_with_type.dtype == proto::schema::DataType::Array) + assert(canArithmeticDtype(getArrayElementType(left_expr), + getArrayElementType(right_expr))); + else if (arithmeticDtype(right_expr_with_type.dtype)) + assert(canArithmeticDtype(getArrayElementType(left_expr), + right_expr_with_type.dtype)); + else + assert(false); + } + + if (right_expr_with_type.dtype == proto::schema::DataType::Array) { + if (arithmeticDtype(left_expr_with_type.dtype)) + assert(canArithmeticDtype(left_expr_with_type.dtype, + getArrayElementType(right_expr))); + else + assert(false); + } + + if (arithmeticDtype(left_expr_with_type.dtype) && + arithmeticDtype(right_expr_with_type.dtype)) { + assert(canArithmeticDtype(left_expr_with_type.dtype, + right_expr_with_type.dtype)); + } else { + assert(false); + } + + switch (ctx->op->getType()) { + case PlanParser::ADD: + return ExprWithDtype( + createBinArithExpr( + left_expr, right_expr, this->arena.get()), + calDataType(&left_expr_with_type, &right_expr_with_type), + false); + case PlanParser::SUB: + return ExprWithDtype( + createBinArithExpr( + left_expr, right_expr, this->arena.get()), + calDataType(&left_expr_with_type, &right_expr_with_type), + false); + + default: + assert(false); + } + } + + virtual std::any + visitRelational(PlanParser::RelationalContext* ctx) override { + auto left_expr_with_type = + std::any_cast(ctx->expr()[0]->accept(this)); + auto right_expr_with_type = + std::any_cast(ctx->expr()[1]->accept(this)); + auto left_value = extractValue(left_expr_with_type.expr); + auto right_value = extractValue(right_expr_with_type.expr); + if (left_value.has_value() && right_value.has_value()) { +#define PROCESS_RELATIONAL(left_type, right_type) \ + if (left_value.type() == typeid(left_type) && \ + right_value.type() == typeid(right_type)) { \ + switch (ctx->op->getType()) { \ + case PlanParser::LT: \ + return ExprWithDtype( \ + createValueExpr( \ + std::any_cast(left_value) < \ + std::any_cast(right_value), \ + this->arena.get()), \ + proto::schema::DataType::Bool, \ + false); \ + case PlanParser::LE: \ + return ExprWithDtype( \ + createValueExpr( \ + std::any_cast(left_value) <= \ + std::any_cast(right_value), \ + this->arena.get()), \ + proto::schema::DataType::Bool, \ + false); \ + case PlanParser::GT: \ + return ExprWithDtype( \ + createValueExpr( \ + std::any_cast(left_value) > \ + std::any_cast(right_value), \ + this->arena.get()), \ + proto::schema::DataType::Bool, \ + false); \ + case PlanParser::GE: \ + return ExprWithDtype( \ + createValueExpr( \ + std::any_cast(left_value) >= \ + std::any_cast(right_value), \ + this->arena.get()), \ + proto::schema::DataType::Bool, \ + false); \ + default: \ + assert(false); \ + } \ + } + + PROCESS_RELATIONAL(double, double) + PROCESS_RELATIONAL(double, int64_t) + PROCESS_RELATIONAL(int64_t, double) + PROCESS_RELATIONAL(std::string, std::string) + PROCESS_RELATIONAL(int64_t, int64_t) + PROCESS_RELATIONAL(int32_t, int32_t) + } + + if (left_expr_with_type.expr->has_value_expr() && + !right_expr_with_type.expr->has_value_expr()) { + ExprWithDtype left = + toValueExpr(CreateMessageWithCopy( + this->arena.get(), + left_expr_with_type.expr->value_expr().value()), + this->arena.get()); + ExprWithDtype right = right_expr_with_type; + + return ExprWithDtype( + HandleCompare( + 
ctx->op->getType(), left, right, this->arena.get()), + proto::schema::DataType::Bool, + false); + } + + if (!left_expr_with_type.expr->has_value_expr() && + right_expr_with_type.expr->has_value_expr()) { + ExprWithDtype left = left_expr_with_type; + ExprWithDtype right = toValueExpr( + CreateMessageWithCopy( + this->arena.get(), + right_expr_with_type.expr->value_expr().value()), + this->arena.get()); + + return ExprWithDtype( + HandleCompare( + ctx->op->getType(), left, right, this->arena.get()), + proto::schema::DataType::Bool, + false); + } + + if (!left_expr_with_type.expr->has_value_expr() && + !right_expr_with_type.expr->has_value_expr()) { + return ExprWithDtype(HandleCompare(ctx->op->getType(), + left_expr_with_type, + right_expr_with_type, + this->arena.get()), + proto::schema::DataType::Bool, + false); + } + return nullptr; + } + + virtual std::any + visitArrayLength(PlanParser::ArrayLengthContext* ctx) override { + auto info = + getChildColumnInfo(ctx->Identifier(), ctx->JSONIdentifier()); + assert(info); + assert(info->data_type() == proto::schema::Array || + info->data_type() == proto::schema::JSON); + auto expr = google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto bin_arith_expr = google::protobuf::Arena::CreateMessage< + proto::plan::BinaryArithExpr>(this->arena.get()); + auto column_expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + column_expr->unsafe_arena_set_allocated_info(info); + auto left_expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + left_expr->unsafe_arena_set_allocated_column_expr(column_expr); + bin_arith_expr->unsafe_arena_set_allocated_left(left_expr); + bin_arith_expr->set_op(proto::plan::ArithOpType::ArrayLength); + expr->unsafe_arena_set_allocated_binary_arith_expr(bin_arith_expr); + return ExprWithDtype(expr, proto::schema::DataType::Int64, false); + } + + virtual std::any + visitTerm(PlanParser::TermContext* ctx) override { + auto first_expr_with_type = + std::any_cast(ctx->expr()[0]->accept(this)); + auto info = + first_expr_with_type.expr->unsafe_arena_release_column_expr() + ->unsafe_arena_release_info(); + + auto term_expr = + google::protobuf::Arena::CreateMessage( + arena.get()); + for (size_t i = 1; i < ctx->expr().size(); ++i) { + auto elem = ctx->expr()[i]; + auto expr_ = std::any_cast(elem->accept(this)).expr; + auto v = google::protobuf::Arena::CreateMessage< + proto::plan::GenericValue>(arena.get()); + auto value = extractValue(expr_); + if (value.type() == typeid(int8_t)) { + v->set_int64_val(int64_t(std::any_cast(value))); + term_expr->mutable_values()->UnsafeArenaAddAllocated(v); + continue; + } + + if (value.type() == typeid(int64_t)) { + v->set_int64_val(int64_t(std::any_cast(value))); + term_expr->mutable_values()->UnsafeArenaAddAllocated(v); + continue; + } + + if (value.type() == typeid(int32_t)) { + v->set_int64_val(int64_t(std::any_cast(value))); + term_expr->mutable_values()->UnsafeArenaAddAllocated(v); + continue; + } + + if (value.type() == typeid(double)) { + v->set_float_val(double(std::any_cast(value))); + term_expr->mutable_values()->UnsafeArenaAddAllocated(v); + continue; + } + + if (value.type() == typeid(float)) { + v->set_float_val(double(std::any_cast(value))); + term_expr->mutable_values()->UnsafeArenaAddAllocated(v); + continue; + } + + if (value.type() == typeid(bool)) { + v->set_bool_val(std::any_cast(value)); + term_expr->mutable_values()->UnsafeArenaAddAllocated(v); + continue; + } + + if (value.type() == typeid(std::string)) { + 
v->set_string_val(std::any_cast(value)); + term_expr->mutable_values()->UnsafeArenaAddAllocated(v); + continue; + } + + assert(false); + } + auto expr = google::protobuf::Arena::CreateMessage( + arena.get()); + + term_expr->unsafe_arena_set_allocated_column_info(info); + expr->unsafe_arena_set_allocated_term_expr(term_expr); + if (ctx->op->getType() == PlanParser::NIN) { + auto root_expr = + google::protobuf::Arena::CreateMessage( + arena.get()); + auto unary_expr = + google::protobuf::Arena::CreateMessage( + arena.get()); + unary_expr->set_op(proto::plan::UnaryExpr_UnaryOp_Not); + unary_expr->set_allocated_child(expr); + return ExprWithDtype( + root_expr, proto::schema::DataType::Bool, false); + } + return ExprWithDtype(expr, proto::schema::DataType::Bool, false); + } + + virtual std::any + visitJSONContains(PlanParser::JSONContainsContext* ctx) override { + auto field = std::any_cast(ctx->expr()[0]->accept(this)); + auto info = field.expr->column_expr().info(); + assert(info.data_type() == proto::schema::DataType::Array || + info.data_type() == proto::schema::DataType::JSON); + auto elem = std::any_cast(ctx->expr()[1]->accept(this)); + if (info.data_type() == proto::schema::DataType::Array) { + proto::plan::GenericValue expr = + proto::plan::GenericValue(elem.expr->value_expr().value()); + assert(canBeCompared(field, toValueExpr(&expr, this->arena.get()))); + } + + auto expr = google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto json_contain_expr = google::protobuf::Arena::CreateMessage< + proto::plan::JSONContainsExpr>(this->arena.get()); + auto value = json_contain_expr->add_elements(); + value->unsafe_arena_set_allocated_array_val( + CreateMessageWithCopy( + this->arena.get(), + elem.expr->value_expr().value().array_val())); + json_contain_expr->set_elements_same_type(true); + json_contain_expr->set_allocated_column_info( + CreateMessageWithCopy(this->arena.get(), info)); + json_contain_expr->unsafe_arena_set_allocated_column_info( + CreateMessageWithCopy(this->arena.get(), + info)); + json_contain_expr->set_op( + proto::plan::JSONContainsExpr_JSONOp_Contains); + expr->set_allocated_json_contains_expr(json_contain_expr); + return ExprWithDtype(expr, proto::schema::Bool, false); + } + + virtual std::any + visitRange(PlanParser::RangeContext* ctx) override { + auto info = + getChildColumnInfo(ctx->Identifier(), ctx->JSONIdentifier()); + assert(info != nullptr); + assert(checkDirectComparisonBinaryField(info)); + auto lower = std::any_cast(ctx->expr()[0]->accept(this)); + auto upper = std::any_cast(ctx->expr()[1]->accept(this)); + + if (info->data_type() == proto::schema::DataType::Int8 || + info->data_type() == proto::schema::DataType::Int16 || + info->data_type() == proto::schema::DataType::Int32 || + info->data_type() == proto::schema::DataType::Int64 || + info->data_type() == proto::schema::DataType::Float || + info->data_type() == proto::schema::DataType::Double || + info->data_type() == proto::schema::DataType::VarChar) { + auto a = extractValue(lower.expr); + auto b = extractValue(upper.expr); + if (a.has_value() && b.has_value()) { + bool lowerinclusive = ctx->op1->getType() == PlanParser::LE; + bool upperinclusive = ctx->op2->getType() == PlanParser::LE; + auto expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto binary_range_expr = google::protobuf::Arena::CreateMessage< + proto::plan::BinaryRangeExpr>(this->arena.get()); + auto lower_value = google::protobuf::Arena::CreateMessage< + proto::plan::GenericValue>(this->arena.get()); + auto 
upper_value = google::protobuf::Arena::CreateMessage< + proto::plan::GenericValue>(this->arena.get()); + if (a.type() == typeid(int8_t)) + lower_value->set_int64_val( + int64_t(std::any_cast(a))); + if (a.type() == typeid(int16_t)) + lower_value->set_int64_val( + int64_t(std::any_cast(a))); + if (a.type() == typeid(int32_t)) + lower_value->set_int64_val( + int64_t(std::any_cast(a))); + if (a.type() == typeid(int64_t)) + lower_value->set_int64_val(std::any_cast(a)); + if (a.type() == typeid(double)) + lower_value->set_float_val(std::any_cast(a)); + if (a.type() == typeid(float)) + lower_value->set_float_val(double(std::any_cast(a))); + if (a.type() == typeid(std::string)) + lower_value->set_string_val(std::any_cast(a)); + + if (b.type() == typeid(int8_t)) + upper_value->set_int64_val( + int64_t(std::any_cast(b))); + if (b.type() == typeid(int16_t)) + upper_value->set_int64_val( + int64_t(std::any_cast(b))); + if (b.type() == typeid(int32_t)) + upper_value->set_int64_val( + int64_t(std::any_cast(b))); + if (b.type() == typeid(int64_t)) + upper_value->set_int64_val( + int64_t(std::any_cast(b))); + if (b.type() == typeid(double)) + upper_value->set_float_val(std::any_cast(b)); + if (b.type() == typeid(float)) + upper_value->set_float_val(double(std::any_cast(b))); + if (b.type() == typeid(std::string)) + upper_value->set_string_val(std::any_cast(b)); + + binary_range_expr->set_lower_inclusive(lowerinclusive); + binary_range_expr->set_upper_inclusive(upperinclusive); + binary_range_expr->unsafe_arena_set_allocated_column_info(info); + + binary_range_expr->unsafe_arena_set_allocated_lower_value( + lower_value); + binary_range_expr->unsafe_arena_set_allocated_upper_value( + upper_value); + expr->set_allocated_binary_range_expr(binary_range_expr); + return ExprWithDtype( + expr, proto::schema::DataType::Bool, false); + } + } + + return nullptr; + } + + virtual std::any + visitUnary(PlanParser::UnaryContext* ctx) override { + auto expr_with_dtype = + std::any_cast(ctx->expr()->accept(this)); + auto value = extractValue(expr_with_dtype.expr); + if (value.has_value()) { +#define PROCESS_UNARY(dtype, schema_dtype) \ + if (value.type() == typeid(dtype)) { \ + switch (ctx->op->getType()) { \ + case PlanParser::ADD: \ + return expr_with_dtype; \ + case PlanParser::SUB: \ + return ExprWithDtype( \ + createValueExpr(-std::any_cast(value), \ + this->arena.get()), \ + schema_dtype, \ + false); \ + case PlanParser::NOT: \ + return ExprWithDtype( \ + createValueExpr(!std::any_cast(value), \ + this->arena.get()), \ + proto::schema::DataType::Bool, \ + false); \ + default: \ + assert(false); \ + } \ + } + + PROCESS_UNARY(double, proto::schema::DataType::Float); + PROCESS_UNARY(float, proto::schema::DataType::Float); + PROCESS_UNARY(int8_t, proto::schema::DataType::Int64); + PROCESS_UNARY(int32_t, proto::schema::DataType::Int64); + PROCESS_UNARY(int64_t, proto::schema::DataType::Int64); + PROCESS_UNARY(bool, proto::schema::DataType::Bool); + } + + assert(checkDirectComparisonBinaryField( + CreateMessageWithCopy( + this->arena.get(), + expr_with_dtype.expr->column_expr().info()))); + + switch (ctx->op->getType()) { + case PlanParser::ADD: + return expr_with_dtype.expr; + case PlanParser::NOT: + assert(!expr_with_dtype.dependent && + expr_with_dtype.dtype == proto::schema::DataType::Bool); + auto expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto unary_expr = google::protobuf::Arena::CreateMessage< + proto::plan::UnaryExpr>(this->arena.get()); + 
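// For a non-constant operand (e.g. a hypothetical filter `not (age > 10)`),
// the NOT stays symbolic: the already-built child expression is attached to
// a UnaryExpr with op = Not, whereas literal boolean values were folded
// directly by the PROCESS_UNARY block above.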
unary_expr->unsafe_arena_set_allocated_child( + expr_with_dtype.expr); + unary_expr->set_op(proto::plan::UnaryExpr_UnaryOp_Not); + return ExprWithDtype(expr, proto::schema::Bool, false); + } + return nullptr; + } + + virtual std::any + visitArray(PlanParser::ArrayContext* ctx) override { + auto expr = google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto array_expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto dtype = proto::schema::DataType::None; + + auto is_same = true; + for (auto&& elem : ctx->expr()) { + auto expr_ = std::any_cast(elem->accept(this)).expr; + auto v = array_expr->add_array(); + auto value = extractValue(expr_); + if (value.has_value()) { + if (value.type() == typeid(int8_t)) { + v->set_int64_val(int64_t(std::any_cast(value))); + if (dtype != proto::schema::DataType::None && + dtype != proto::schema::DataType::Int8) { + is_same = false; + } + if (dtype == proto::schema::DataType::None) { + dtype = proto::schema::DataType::Int8; + } + continue; + } + if (value.type() == typeid(int16_t)) { + v->set_int64_val(int64_t(std::any_cast(value))); + if (dtype != proto::schema::DataType::None && + dtype != proto::schema::DataType::Int16) { + is_same = false; + } + if (dtype == proto::schema::DataType::None) { + dtype = proto::schema::DataType::Int16; + } + continue; + } + if (value.type() == typeid(int32_t)) { + v->set_int64_val(int64_t(std::any_cast(value))); + if (dtype != proto::schema::DataType::None && + dtype != proto::schema::DataType::Int32) { + is_same = false; + } + if (dtype == proto::schema::DataType::None) { + dtype = proto::schema::DataType::Int32; + } + continue; + } + + if (value.type() == typeid(int64_t)) { + v->set_int64_val(std::any_cast(value)); + if (dtype != proto::schema::DataType::None && + dtype != proto::schema::DataType::Int64) { + is_same = false; + } + if (dtype == proto::schema::DataType::None) { + dtype = proto::schema::DataType::Int64; + } + continue; + } + + if (value.type() == typeid(double)) { + v->set_float_val(std::any_cast(value)); + if (dtype != proto::schema::DataType::None && + dtype != proto::schema::DataType::Double) { + is_same = false; + } + if (dtype == proto::schema::DataType::None) { + dtype = proto::schema::DataType::Double; + } + continue; + } + + if (value.type() == typeid(float)) { + v->set_float_val(std::any_cast(value)); + if (dtype != proto::schema::DataType::None && + dtype != proto::schema::DataType::Float) { + is_same = false; + } + if (dtype == proto::schema::DataType::None) { + dtype = proto::schema::DataType::Float; + } + continue; + } + } + } + + auto generic_value = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + + generic_value->unsafe_arena_set_allocated_array_val(array_expr); + + auto value_expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + value_expr->unsafe_arena_set_allocated_value(generic_value); + expr->unsafe_arena_set_allocated_value_expr(value_expr); + return ExprWithDtype( + expr, is_same ? 
dtype : proto::schema::DataType::None, true); + } + virtual std::any + visitJSONContainsAny(PlanParser::JSONContainsAnyContext* ctx) override { + auto field = std::any_cast(ctx->expr()[0]->accept(this)); + auto info = field.expr->column_expr().info(); + assert(info.data_type() == proto::schema::DataType::Array || + info.data_type() == proto::schema::DataType::JSON); + auto elem = std::any_cast(ctx->expr()[1]->accept(this)); + if (info.data_type() == proto::schema::DataType::Array) { + proto::plan::GenericValue expr = + proto::plan::GenericValue(elem.expr->value_expr().value()); + assert(canBeCompared(field, toValueExpr(&expr, this->arena.get()))); + } + + auto expr = google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto json_contain_expr = google::protobuf::Arena::CreateMessage< + proto::plan::JSONContainsExpr>(this->arena.get()); + + auto value = json_contain_expr->add_elements(); + value->unsafe_arena_set_allocated_array_val( + CreateMessageWithCopy( + this->arena.get(), + elem.expr->value_expr().value().array_val())); + + json_contain_expr->set_elements_same_type( + elem.expr->value_expr().value().array_val().same_type()); + json_contain_expr->unsafe_arena_set_allocated_column_info( + CreateMessageWithCopy(this->arena.get(), + info)); + json_contain_expr->set_op( + proto::plan::JSONContainsExpr_JSONOp_ContainsAny); + expr->unsafe_arena_set_allocated_json_contains_expr(json_contain_expr); + return ExprWithDtype(expr, proto::schema::Bool, false); + } + + virtual std::any + visitExists(PlanParser::ExistsContext* ctx) override { + auto a = std::any_cast(ctx->expr()); + auto info = a.expr->column_expr().info(); + assert(info.data_type() == proto::schema::DataType::Array); + auto expr = google::protobuf::Arena::CreateMessage( + this->arena.get()); + + auto col_expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + col_expr->unsafe_arena_set_allocated_info( + CreateMessageWithCopy(this->arena.get(), + info)); + expr->unsafe_arena_set_allocated_column_expr(col_expr); + return ExprWithDtype(expr, proto::schema::DataType::Bool, false); + } + + virtual std::any + visitEmptyTerm(PlanParser::EmptyTermContext* ctx) override { + auto first = std::any_cast(ctx->expr()->accept(this)); + auto info = first.expr->column_expr().info(); + + auto expr = google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto col_expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + auto term_expr = + google::protobuf::Arena::CreateMessage( + this->arena.get()); + + expr->unsafe_arena_set_allocated_term_expr(term_expr); + col_expr->unsafe_arena_set_allocated_info( + CreateMessageWithCopy(this->arena.get(), + info)); + expr->unsafe_arena_set_allocated_column_expr(col_expr); + expr->unsafe_arena_set_allocated_term_expr(term_expr); + return ExprWithDtype(expr, proto::schema::DataType::Bool, false); + } + + PlanCCVisitor(SchemaHelper* const helper) + : helper(helper), arena(std::make_shared()) { + } + + private: + SchemaHelper* helper; + std::shared_ptr arena; +}; + +std::string +ParserToMessage(milvus::proto::schema::CollectionSchema& schema, + const std::string& exprstr); + +std::shared_ptr +ParseIdentifier(milvus::local::SchemaHelper helper, + const std::string& identifier); + +} // namespace milvus::local diff --git a/src/parser/utils.h b/src/parser/utils.h new file mode 100644 index 0000000..5759e32 --- /dev/null +++ b/src/parser/utils.h @@ -0,0 +1,789 @@ +#pragma once +#include +#include +#include +#include +#include +#include "antlr/PlanBaseVisitor.h" 
+#include "antlr/PlanLexer.h" +#include "antlr/PlanParser.h" +#include "pb/plan.pb.h" + +namespace milvus::local { + +template +inline T* +CreateMessageWithCopy(google::protobuf::Arena* arena, const T& val) { + T* ret = google::protobuf::Arena::CreateMessage(arena); + ret->CopyFrom(val); + return ret; +} + +struct ExprWithDtype { + proto::plan::Expr* expr; + proto::schema::DataType dtype; + bool dependent; + ExprWithDtype(proto::plan::Expr* const expr, + proto::schema::DataType dtype, + bool dependent) + : expr(expr), dtype(dtype), dependent(dependent) { + } +}; + +inline std::any +extractValue(proto::plan::Expr* expr) { + if (!expr->has_value_expr()) + return nullptr; + + if (!expr->value_expr().has_value()) + return nullptr; + if (expr->value_expr().value().has_bool_val()) + return expr->value_expr().value().bool_val(); + if (expr->value_expr().value().has_float_val()) + return expr->value_expr().value().float_val(); + if (expr->value_expr().value().has_string_val()) + return expr->value_expr().value().string_val(); + if (expr->value_expr().value().has_int64_val()) + return expr->value_expr().value().int64_val(); + + return nullptr; +} + +template +inline proto::plan::Expr* +createValueExpr(const T val, google::protobuf::Arena* arena = nullptr) { + auto expr = + google::protobuf::Arena::CreateMessage(arena); + auto val_expr = + google::protobuf::Arena::CreateMessage(arena); + auto value = + google::protobuf::Arena::CreateMessage( + arena); + if constexpr (std::is_same_v) + value->set_int64_val(val); + else if constexpr (std::is_same_v) + value->set_int64_val(int64_t(val)); + else if constexpr (std::is_same_v) + value->set_int64_val(int64_t(val)); + else if constexpr (std::is_same_v) + value->set_int64_val(int16_t(val)); + else if constexpr (std::is_same_v) + value->set_string_val(val); + else if constexpr (std::is_same_v || std::is_same_v) + value->set_float_val(val); + else if constexpr (std::is_same_v) + value->set_bool_val(val); + else + assert(false); + + val_expr->unsafe_arena_set_allocated_value(value); + expr->unsafe_arena_set_allocated_value_expr(val_expr); + return expr; +} + +inline std::vector +tokenize(std::string s, std::string del = " ") { + std::vector results; + int start, end = -1 * del.size(); + do { + start = end + del.size(); + end = s.find(del, start); + results.push_back(s.substr(start, end - start)); + } while (end != -1); + return results; +} + +template +inline proto::plan::Expr* +createBinExpr(proto::plan::Expr* left, + proto::plan::Expr* right, + google::protobuf::Arena* arena = nullptr) { + auto expr = + google::protobuf::Arena::CreateMessage(arena); + auto bin_expr = + google::protobuf::Arena::CreateMessage(arena); + bin_expr->set_op(T); + bin_expr->unsafe_arena_set_allocated_left(left); + bin_expr->unsafe_arena_set_allocated_right(right); + expr->unsafe_arena_set_allocated_binary_expr(bin_expr); + return expr; +} + +template +inline proto::plan::Expr* +createBinArithExpr(proto::plan::Expr* left, + proto::plan::Expr* right, + google::protobuf::Arena* arena = nullptr) { + auto expr = + google::protobuf::Arena::CreateMessage(arena); + auto bin_expr = + google::protobuf::Arena::CreateMessage( + arena); + bin_expr->set_op(T); + bin_expr->unsafe_arena_set_allocated_left(left); + bin_expr->unsafe_arena_set_allocated_right(right); + expr->unsafe_arena_set_allocated_binary_arith_expr(bin_expr); + return expr; +} + +inline bool +arithmeticDtype(proto::schema::DataType type) { + switch (type) { + case proto::schema::DataType::Float: + return true; + case 
proto::schema::DataType::Double: + return true; + case proto::schema::DataType::Int8: + return true; + case proto::schema::DataType::Int16: + return true; + case proto::schema::DataType::Int32: + return true; + case proto::schema::DataType::Int64: + return true; + default: + return false; + } +} + +inline proto::schema::DataType +getArrayElementType(proto::plan::Expr* expr) { + if (expr->has_column_expr()) { + return expr->column_expr().info().data_type(); + } + if (expr->has_value_expr() && expr->value_expr().has_value() && + expr->value_expr().value().has_array_val()) { + return expr->value_expr().value().array_val().element_type(); + } + + return proto::schema::DataType::None; +} + +inline bool +canArithmeticDtype(proto::schema::DataType left_type, + proto::schema::DataType right_type) { + if (left_type == proto::schema::DataType::JSON && + right_type == proto::schema::DataType::JSON) + return false; + if (left_type == proto::schema::DataType::JSON && + arithmeticDtype(right_type)) + return true; + if (arithmeticDtype(left_type) && + right_type == proto::schema::DataType::JSON) + return true; + if (arithmeticDtype(left_type) && arithmeticDtype(right_type)) + return true; + return false; +} + +inline proto::schema::DataType +calDataType(ExprWithDtype* a, ExprWithDtype* b) { + auto a_dtype = a->dtype; + auto b_dtype = b->dtype; + if (a->dtype == proto::schema::DataType::Array) { + a_dtype = getArrayElementType(a->expr); + } + if (b->dtype == proto::schema::DataType::Array) { + b_dtype = getArrayElementType(b->expr); + } + if (a_dtype == proto::schema::DataType::JSON) { + if (b_dtype == proto::schema::DataType::Float || + b_dtype == proto::schema::DataType::Double) + return proto::schema::DataType::Float; + if (b_dtype == proto::schema::DataType::Int8 || + b_dtype == proto::schema::DataType::Int16 || + b_dtype == proto::schema::DataType::Int32 || + b_dtype == proto::schema::DataType::Int64) + return proto::schema::DataType::Int64; + if (b_dtype == proto::schema::DataType::JSON) + return proto::schema::DataType::JSON; + } + + if (a_dtype == proto::schema::DataType::Float || + a_dtype == proto::schema::DataType::Double) { + if (b_dtype == proto::schema::DataType::JSON) + return proto::schema::DataType::Double; + if (arithmeticDtype(b_dtype)) + return proto::schema::DataType::Double; + } + + if (a_dtype == proto::schema::DataType::Int8 || + a_dtype == proto::schema::DataType::Int16 || + a_dtype == proto::schema::DataType::Int32 || + a_dtype == proto::schema::DataType::Int64) { + if (b_dtype == proto::schema::DataType::Float || + b_dtype == proto::schema::DataType::Double) + return proto::schema::DataType::Double; + if (b_dtype == proto::schema::DataType::Int8 || + b_dtype == proto::schema::DataType::Int16 || + b_dtype == proto::schema::DataType::Int32 || + b_dtype == proto::schema::DataType::Int64 || + b_dtype == proto::schema::DataType::JSON) + return proto::schema::DataType::Int64; + } + assert(false); +} + +struct SchemaHelper { + SchemaHelper() = default; + + proto::schema::CollectionSchema* schema = nullptr; + std::map name_offset; + std::map id_offset; + + int primary_key_offset = -1; + int partition_key_offset = -1; + + const proto::schema::FieldSchema& + GetPrimaryKeyField() { + assert(primary_key_offset != -1); + return schema->fields(primary_key_offset); + } + + const proto::schema::FieldSchema& + GetPartitionKeyField() { + assert(partition_key_offset != -1); + return schema->fields(partition_key_offset); + } + + const proto::schema::FieldSchema& + GetFieldFromName(const 
std::string& name) { + auto it = name_offset.find(name); + assert(it != name_offset.end()); + return schema->fields(it->second); + } + + const proto::schema::FieldSchema& + GetFieldFromNameDefaultJSON(const std::string& name) { + auto it = name_offset.find(name); + if (it == name_offset.end()) { + return GetDefaultJSONField(); + } + return schema->fields(it->second); + } + + const proto::schema::FieldSchema& + GetDefaultJSONField() { + for (int i = 0; i < schema->fields_size(); ++i) { + auto& field = schema->fields(i); + if (field.data_type() == proto::schema::DataType::JSON && + field.is_dynamic()) + return field; + } + + assert(false); + } + + const proto::schema::FieldSchema& + GetFieldFromID(int64_t id) { + auto it = id_offset.find(id); + assert(it != id_offset.end()); + return schema->fields(it->second); + } + + int + GetVectorDimFromID(int64_t id) { + auto& field = GetFieldFromID(id); + if (field.data_type() != proto::schema::DataType::FloatVector && + field.data_type() != proto::schema::DataType::Float16Vector && + field.data_type() != proto::schema::DataType::BinaryVector && + field.data_type() != proto::schema::DataType::BFloat16Vector) { + assert(false); + } + for (int i = 0; i < field.type_params_size(); ++i) { + if (field.type_params(i).key() == "dim") + return std::stoi( + field.type_params(i).value().c_str(), NULL, 10); + } + assert(false); + } +}; + +inline SchemaHelper +CreateSchemaHelper(proto::schema::CollectionSchema* schema) { + assert(schema); + SchemaHelper schema_helper; + schema_helper.schema = schema; + for (int i = 0; i < schema->fields_size(); ++i) { + auto field = schema->fields(i); + auto it = schema_helper.name_offset.find(field.name()); + if (it != schema_helper.name_offset.end()) + assert(false); + schema_helper.name_offset[field.name()] = i; + schema_helper.id_offset[field.fieldid()] = i; + if (field.is_primary_key()) { + assert(schema_helper.primary_key_offset != -1); + schema_helper.primary_key_offset = i; + } + if (field.is_partition_key()) { + assert(schema_helper.primary_key_offset != -1); + schema_helper.partition_key_offset = i; + } + } + return schema_helper; +} + +inline std::string +convertEscapeSingle(const std::string& in) { + std::vector need_replace_index; + size_t escape_ch_count = 0; + size_t in_string_lenth = in.length(); + for (size_t i = 1; i < in_string_lenth - 1; ++i) { + if (in[i] == '\\') { + escape_ch_count++; + continue; + } + if (in[i] == '"' && escape_ch_count % 2 == 0) { + need_replace_index.push_back(i); + } + + if (in[i] == '\'' && escape_ch_count % 2 != 0) { + need_replace_index.push_back(i); + } + + escape_ch_count = 0; + } + + std::string in_; + in_ += '"'; + size_t start = 1; + for (auto end : need_replace_index) { + if (in[end] == '"') { + in_ += in.substr(start, end - start); + in_ += "\\\""; + } else { + in_ += in.substr(start, end - start - 1); + in_ += '\''; + } + start = end + 1; + } + + in_ += in.substr(start, in.length() - start - 1); + + in_ += '"'; + std::stringstream ss; + ss << in_; + std::string out; + ss >> std::quoted(out); + + return out; +} + +inline bool +hasWildcards(std::string pattern) { + int64_t l = pattern.length(); + int64_t i = l - 1; + for (; i >= 0; i--) { + if (pattern[i] == '%' || pattern[i] == '_') { + if (i > 0 && pattern[i - 1] == '\\') { + i--; + continue; + } + return true; + } + } + return false; +} + +inline int +findLastNotOfWildcards(std::string pattern) { + int loc = pattern.length() - 1; + for (; loc >= 0; loc--) { + if (pattern[loc] == '%' || pattern[loc] == '_') { + if (loc > 0 && 
pattern[loc - 1] == '\\') { + break; + } + } else { + break; + } + } + return loc; +} + +inline std::pair +translatePatternMatch(const std::string& pattern) { + size_t l = pattern.length(); + size_t loc = findLastNotOfWildcards(pattern); + if (loc < 0) { + return std::make_pair(proto::plan::OpType::PrefixMatch, ""); + } + bool exist = hasWildcards(pattern.substr(0, loc + 1)); + + if (loc >= l - 1 && !exist) { + return std::make_pair(proto::plan::OpType::Equal, pattern); + } + + if (!exist) { + return std::make_pair(proto::plan::OpType::PrefixMatch, + pattern.substr(0, loc + 1)); + } + + return std::make_pair(proto::plan::OpType::Match, pattern); +} + +inline bool +canBeComparedDataType(proto::schema::DataType a, proto::schema::DataType b) { + switch (a) { + case proto::schema::DataType::Bool: + return (b == proto::schema::DataType::Bool) || + (b == proto::schema::DataType::JSON); + case proto::schema::DataType::Int8: + return arithmeticDtype(b) || (b == proto::schema::DataType::JSON); + case proto::schema::DataType::Int16: + return arithmeticDtype(b) || (b == proto::schema::DataType::JSON); + case proto::schema::DataType::Int32: + return arithmeticDtype(b) || (b == proto::schema::DataType::JSON); + case proto::schema::DataType::Int64: + return arithmeticDtype(b) || (b == proto::schema::DataType::JSON); + case proto::schema::DataType::Float: + return arithmeticDtype(b) || (b == proto::schema::DataType::JSON); + case proto::schema::DataType::Double: + return arithmeticDtype(b) || (b == proto::schema::DataType::JSON); + case proto::schema::DataType::VarChar: + return b == proto::schema::DataType::String || + b == proto::schema::DataType::VarChar || + b == proto::schema::DataType::JSON; + case proto::schema::DataType::String: + return b == proto::schema::DataType::String || + b == proto::schema::DataType::VarChar || + b == proto::schema::DataType::JSON; + case proto::schema::DataType::JSON: + return true; + default: + return false; + } +} + +inline bool +canBeCompared(ExprWithDtype a, ExprWithDtype b) { + if (a.dtype != proto::schema::DataType::Array && + b.dtype != proto::schema::DataType::Array) { + return canBeComparedDataType(a.dtype, b.dtype); + } + + if (a.dtype == proto::schema::DataType::Array && + b.dtype == proto::schema::DataType::Array) { + return canBeComparedDataType(getArrayElementType(a.expr), + getArrayElementType(b.expr)); + } + + if (a.dtype == proto::schema::DataType::Array) { + return canBeComparedDataType(getArrayElementType(a.expr), b.dtype); + } + + return canBeComparedDataType(b.dtype, getArrayElementType(b.expr)); +} + +inline ExprWithDtype +toValueExpr(proto::plan::GenericValue* value, + google::protobuf::Arena* arena = nullptr) { + auto expr = + google::protobuf::Arena::CreateMessage(arena); + + auto value_expr = + google::protobuf::Arena::CreateMessage(arena); + value_expr->unsafe_arena_set_allocated_value(value); + + expr->unsafe_arena_set_allocated_value_expr(value_expr); + if (value->has_bool_val()) { + return ExprWithDtype(expr, proto::schema::DataType::Bool, false); + } + if (value->has_int64_val()) { + return ExprWithDtype(expr, proto::schema::DataType::Int64, false); + } + if (value->has_float_val()) { + return ExprWithDtype(expr, proto::schema::DataType::Float, false); + } + if (value->has_string_val()) { + return ExprWithDtype(expr, proto::schema::DataType::VarChar, false); + } + if (value->has_array_val()) { + return ExprWithDtype(expr, proto::schema::DataType::Array, false); + } + assert(false); +} + +inline proto::plan::GenericValue* 
+castValue(proto::schema::DataType dtype, + proto::plan::GenericValue* value, + google::protobuf::Arena* arena = nullptr) { + if (dtype == proto::schema::DataType::JSON) + return CreateMessageWithCopy(arena, *value); + if (dtype == proto::schema::DataType::Array && value->has_array_val()) + return CreateMessageWithCopy(arena, *value); + if (dtype == proto::schema::DataType::VarChar && value->has_string_val()) + return CreateMessageWithCopy(arena, *value); + + if (dtype == proto::schema::DataType::Bool && value->has_bool_val()) + return CreateMessageWithCopy(arena, *value); + + if (dtype == proto::schema::DataType::Float || + dtype == proto::schema::DataType::Double) { + if (value->has_float_val()) + return CreateMessageWithCopy(arena, + *value); + ; + if (value->has_int64_val()) { + auto value_tmp = google::protobuf::Arena::CreateMessage< + proto::plan::GenericValue>(arena); + value_tmp->set_float_val(double(value->int64_val())); + return value_tmp; + } + } + + if (dtype == proto::schema::DataType::Int8 || + dtype == proto::schema::DataType::Int16 || + dtype == proto::schema::DataType::Int32 || + dtype == proto::schema::DataType::Int64) { + if (value->has_int64_val()) + return CreateMessageWithCopy(arena, + *value); + } + + assert(false); +} + +inline proto::plan::Expr* +combineArrayLengthExpr(proto::plan::OpType op, + proto::plan::ArithOpType arith_op, + const proto::plan::ColumnInfo& info, + const proto::plan::GenericValue& value, + google::protobuf::Arena* arena = nullptr) { + auto expr = + google::protobuf::Arena::CreateMessage(arena); + auto range_expr = google::protobuf::Arena::CreateMessage< + proto::plan::BinaryArithOpEvalRangeExpr>(arena); + expr->unsafe_arena_set_allocated_binary_arith_op_eval_range_expr( + range_expr); + range_expr->set_op(op); + range_expr->set_arith_op(arith_op); + range_expr->unsafe_arena_set_allocated_value( + CreateMessageWithCopy(arena, value)); + range_expr->unsafe_arena_set_allocated_column_info( + CreateMessageWithCopy(arena, info)); + return expr; +} + +inline proto::plan::Expr* +combineBinaryArithExpr(proto::plan::OpType op, + proto::plan::ArithOpType arith_op, + const proto::plan::ColumnInfo& info, + const proto::plan::GenericValue& operand, + const proto::plan::GenericValue& value, + google::protobuf::Arena* arena = nullptr) { + auto data_type = info.data_type(); + if (data_type != proto::schema::DataType::Array && + info.nested_path_size() != 0) { + data_type = info.element_type(); + } + auto casted_value = castValue( + data_type, + CreateMessageWithCopy(arena, operand)); + auto expr = + google::protobuf::Arena::CreateMessage(arena); + auto range_expr = google::protobuf::Arena::CreateMessage< + proto::plan::BinaryArithOpEvalRangeExpr>(arena); + expr->unsafe_arena_set_allocated_binary_arith_op_eval_range_expr( + range_expr); + range_expr->unsafe_arena_set_allocated_column_info( + CreateMessageWithCopy(arena, info) + + ); + range_expr->set_arith_op(arith_op); + range_expr->unsafe_arena_set_allocated_right_operand(casted_value); + range_expr->unsafe_arena_set_allocated_value( + CreateMessageWithCopy(arena, value)); + range_expr->set_op(op); + + return expr; +} + +inline proto::plan::Expr* +handleBinaryArithExpr(proto::plan::OpType op, + proto::plan::BinaryArithExpr* arith_expr, + proto::plan::ValueExpr* value_expr, + google::protobuf::Arena* arena = nullptr) { + switch (op) { + case proto::plan::OpType::Equal: + break; + case proto::plan::OpType::NotEqual: + break; + default: + assert(false); + } + + auto left_expr = arith_expr->left().column_expr(); 
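+ // Descriptive note (editor-added): each side of the arithmetic expression
+ // may be a column or a literal; the branches below pick the matching
+ // (column, operand) pair, and a column on the right-hand side is only
+ // accepted for the commutative ops (Add, Mul).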
+ auto left_value = arith_expr->left().value_expr(); + auto right_expr = arith_expr->right().column_expr(); + auto right_value = arith_expr->right().value_expr(); + + auto arith_op = arith_expr->op(); + + if (arith_op == proto::plan::ArithOpType::ArrayLength) { + return combineArrayLengthExpr( + op, arith_op, left_expr.info(), value_expr->value(), arena); + } + if (arith_expr->left().has_column_expr() && + arith_expr->right().has_column_expr()) { + assert(false); + } + if (arith_expr->left().has_value_expr() && + arith_expr->right().has_value_expr()) { + assert(false); + } + if (arith_expr->left().has_column_expr() && + arith_expr->right().has_value_expr()) { + return combineBinaryArithExpr(op, + arith_op, + left_expr.info(), + right_value.value(), + value_expr->value(), + arena); + } + if (arith_expr->right().has_column_expr() && + arith_expr->left().has_value_expr()) { + switch (arith_expr->op()) { + case proto::plan::ArithOpType::Add: + return combineBinaryArithExpr(op, + arith_op, + right_expr.info(), + left_value.value(), + value_expr->value(), + arena); + case proto::plan::ArithOpType::Mul: + return combineBinaryArithExpr(op, + arith_op, + right_expr.info(), + left_value.value(), + value_expr->value(), + arena); + default: + assert(false); + } + } + assert(false); +} + +inline proto::plan::Expr* +handleCompareRightValue(proto::plan::OpType op, + ExprWithDtype a, + ExprWithDtype b, + google::protobuf::Arena* arena = nullptr) { + auto data_type = a.dtype; + if (data_type == proto::schema::DataType::Array && + a.expr->column_expr().info().nested_path_size() != 0) { + data_type = a.expr->column_expr().info().element_type(); + } + auto value = b.expr->value_expr().value(); + auto castedvalue = castValue(data_type, &value, arena); + if (a.expr->has_binary_arith_expr()) { + auto value_expr = + google::protobuf::Arena::CreateMessage( + arena); + value_expr->unsafe_arena_set_allocated_value(castedvalue); + return handleBinaryArithExpr( + op, + CreateMessageWithCopy( + arena, a.expr->binary_arith_expr()), + value_expr, + arena); + } + + assert(a.expr->has_column_expr()); + auto info = a.expr->column_expr().info(); + + auto expr = + google::protobuf::Arena::CreateMessage(arena); + auto unary_range_expr = + google::protobuf::Arena::CreateMessage( + arena); + + unary_range_expr->set_op(op); + unary_range_expr->unsafe_arena_set_allocated_column_info( + CreateMessageWithCopy(arena, info)); + + unary_range_expr->unsafe_arena_set_allocated_value(castedvalue); + + expr->unsafe_arena_set_allocated_unary_range_expr(unary_range_expr); + + return expr; +} + +inline proto::plan::Expr* +HandleCompare(proto::plan::OpType op, + ExprWithDtype a, + ExprWithDtype b, + google::protobuf::Arena* arena = nullptr) { + assert(a.expr->has_column_expr() && b.expr->has_column_expr()); + + auto a_info = a.expr->column_expr().info(); + + auto b_info = b.expr->column_expr().info(); + + auto expr = + google::protobuf::Arena::CreateMessage(arena); + auto compare_expr = + google::protobuf::Arena::CreateMessage(arena); + compare_expr->unsafe_arena_set_allocated_left_column_info( + CreateMessageWithCopy(arena, a_info)); + compare_expr->unsafe_arena_set_allocated_right_column_info( + CreateMessageWithCopy(arena, b_info)); + compare_expr->set_op(op); + expr->unsafe_arena_set_allocated_compare_expr(compare_expr); + + return expr; +} + +inline proto::plan::OpType +reverseOrder(proto::plan::OpType op) { + switch (op) { + case proto::plan::OpType::LessThan: + return proto::plan::OpType::GreaterThan; + case 
proto::plan::OpType::LessEqual: + return proto::plan::OpType::GreaterEqual; + case proto::plan::OpType::GreaterThan: + return proto::plan::OpType::LessThan; + case proto::plan::OpType::GreaterEqual: + return proto::plan::OpType::LessEqual; + case proto::plan::OpType::Equal: + return proto::plan::OpType::NotEqual; + case proto::plan::OpType::NotEqual: + return proto::plan::OpType::Equal; + default: + return proto::plan::OpType::Invalid; + } +} + +inline proto::plan::Expr* +HandleCompare(int op, + ExprWithDtype a, + ExprWithDtype b, + google::protobuf::Arena* arena = nullptr) { + assert(canBeCompared(a, b)); + std::map cmpOpMap{ + {PlanParser::LT, proto::plan::OpType::LessThan}, + {PlanParser::LE, proto::plan::OpType::LessEqual}, + {PlanParser::GT, proto::plan::OpType::GreaterThan}, + {PlanParser::GE, proto::plan::OpType::GreaterEqual}, + {PlanParser::EQ, proto::plan::OpType::Equal}, + {PlanParser::NE, proto::plan::OpType::NotEqual}}; + auto cmpop = cmpOpMap[op]; + if (a.expr->has_value_expr()) { + auto op = reverseOrder(cmpop); + return handleCompareRightValue(op, b, a, arena); + } else if (b.expr->has_value_expr()) { + return handleCompareRightValue(cmpop, a, b, arena); + } + + return HandleCompare(cmpop, a, b, arena); +} + +inline bool +checkDirectComparisonBinaryField(proto::plan::ColumnInfo* info) { + if (info->data_type() == proto::schema::DataType::Array && + info->nested_path_size() == 0) { + return false; + } + return true; +} +} // namespace milvus::local diff --git a/src/proto/common.proto b/src/proto/common.proto new file mode 120000 index 0000000..20a63ed --- /dev/null +++ b/src/proto/common.proto @@ -0,0 +1 @@ +../../thirdparty/milvus-proto/proto/common.proto \ No newline at end of file diff --git a/src/proto/feder.proto b/src/proto/feder.proto new file mode 120000 index 0000000..f55ed58 --- /dev/null +++ b/src/proto/feder.proto @@ -0,0 +1 @@ +../../thirdparty/milvus-proto/proto/feder.proto \ No newline at end of file diff --git a/src/proto/manifest.proto b/src/proto/manifest.proto new file mode 120000 index 0000000..e945deb --- /dev/null +++ b/src/proto/manifest.proto @@ -0,0 +1 @@ +../../thirdparty/milvus-storage/cpp/src/proto/manifest.proto \ No newline at end of file diff --git a/src/proto/milvus.proto b/src/proto/milvus.proto new file mode 120000 index 0000000..9b30eee --- /dev/null +++ b/src/proto/milvus.proto @@ -0,0 +1 @@ +../../thirdparty/milvus-proto/proto/milvus.proto \ No newline at end of file diff --git a/src/proto/msg.proto b/src/proto/msg.proto new file mode 120000 index 0000000..60e93ef --- /dev/null +++ b/src/proto/msg.proto @@ -0,0 +1 @@ +../../thirdparty/milvus-proto/proto/msg.proto \ No newline at end of file diff --git a/src/proto/plan.proto b/src/proto/plan.proto new file mode 120000 index 0000000..d57afe4 --- /dev/null +++ b/src/proto/plan.proto @@ -0,0 +1 @@ +../../thirdparty/milvus/internal/proto/plan.proto \ No newline at end of file diff --git a/src/proto/rg.proto b/src/proto/rg.proto new file mode 120000 index 0000000..d0732f3 --- /dev/null +++ b/src/proto/rg.proto @@ -0,0 +1 @@ +../../thirdparty/milvus-proto/proto/rg.proto \ No newline at end of file diff --git a/src/proto/schema.proto b/src/proto/schema.proto new file mode 120000 index 0000000..8df60d6 --- /dev/null +++ b/src/proto/schema.proto @@ -0,0 +1 @@ +../../thirdparty/milvus-proto/proto/schema.proto \ No newline at end of file diff --git a/src/proto/segcore.proto b/src/proto/segcore.proto new file mode 120000 index 0000000..ef3b22d --- /dev/null +++ b/src/proto/segcore.proto @@ -0,0 +1 @@ 
+../../thirdparty/milvus/internal/proto/segcore.proto \ No newline at end of file diff --git a/src/query_task.cpp b/src/query_task.cpp new file mode 100644 index 0000000..e5a0b04 --- /dev/null +++ b/src/query_task.cpp @@ -0,0 +1,250 @@ +#include "query_task.h" +#include +#include +#include +#include "antlr4-runtime.h" +#include "parser/parser.h" +#include "parser/utils.h" +#include "common.h" +#include "pb/plan.pb.h" +#include "schema.pb.h" +#include "schema_util.h" +#include "status.h" +#include "string_util.hpp" +#include "pb/segcore.pb.h" +#include "log/Log.h" + +namespace milvus::local { + +QueryTask::QueryTask(const ::milvus::proto::milvus::QueryRequest* query_request, + const ::milvus::proto::schema::CollectionSchema* schema) + : query_request_(query_request), + schema_(schema), + limit_(-1), + offset_(0), + is_count_(false) { +} +QueryTask::~QueryTask() { +} + +bool +QueryTask::GetOutputFieldIds(std::vector* ids) { + if (output_fields_.size() == 0) { + for (const auto& field : schema_->fields()) { + if (field.fieldid() >= kStartOfUserFieldId && + !schema_util::IsVectorField(field.data_type())) { + ids->push_back(field.fieldid()); + } + } + } else { + std::string pk; + std::map name_ids; + for (const auto& field : schema_->fields()) { + name_ids[field.name()] = field.fieldid(); + if (field.is_primary_key()) { + pk = field.name(); + ids->push_back(field.fieldid()); + } + } + for (const auto& output_field : output_fields_) { + if (output_field == pk) + continue; + auto it = name_ids.find(output_field); + if (it == name_ids.end()) { + LOG_ERROR("Can not find output field {} in schema", + output_field); + return false; + } + if (it->second >= kStartOfUserFieldId) { + ids->push_back(it->second); + } + } + } + return true; +} + +Status +QueryTask::ParseQueryParams(::milvus::proto::plan::PlanNode* plan, + bool expr_empty) { + for (const auto& param : query_request_->query_params()) { + if (param.key() == kLimitKey) { + try { + limit_ = std::stoll(param.value()); + } catch (std::exception& e) { + auto err = string_util::SFormat("Parse limit failed, limit: {}", + param.value()); + return Status::ParameterInvalid(err); + } + } else if (param.key() == kOffsetKey) { + try { + offset_ = std::stoll(param.value()); + } catch (std::exception& e) { + auto err = string_util::SFormat( + "Parse offset failed, offset: {}", param.value()); + return Status::ParameterInvalid(err); + } + } else if (param.key() == kReduceStopForBestKey) { + // not used in local + } + } + if (offset_ < 0 || offset_ >= kTopkLimit) { + return Status::ParameterInvalid( + "Offset should be in range [0, {}], but got {}", + kTopkLimit, + offset_); + } + + if (limit_ <= 0) { + if (is_count_) { + limit_ = -1; + } else if (expr_empty) { + return Status::ParameterInvalid( + "empty expression should be used with limit"); + } else { + limit_ = kTopkLimit - offset_ - 1; + } + } + + // get and validate topk + if (limit_ >= kTopkLimit) { + return Status::ParameterInvalid( + "limit should be in range [1, {}], but got {}", kTopkLimit, limit_); + } + + if ((limit_ + offset_) >= kTopkLimit) { + return Status::ParameterInvalid( + "topk + offset should be in range [1, {}], but got {}", + kTopkLimit, + limit_ + offset_); + } + plan->mutable_query()->set_limit(limit_ + offset_); + return Status::Ok(); +} + +Status +QueryTask::Process(::milvus::proto::plan::PlanNode* plan) { + if (query_request_->output_fields_size() == 1 && + string_util::Trim(string_util::ToLower( + query_request_->output_fields().Get(0))) == kCountStr) { + 
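+ // Descriptive note (editor-added): a single output field equal to kCountStr
+ // marks this as a count query; PostProcess then returns the matched row
+ // count instead of per-field data.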
plan->mutable_query()->set_is_count(true); + is_count_ = true; + } + + CHECK_STATUS( + ParseQueryParams(plan, string_util::Trim(query_request_->expr()) == ""), + ""); + + if (query_request_->expr() != "") { + CHECK_STATUS( + schema_util::ParseExpr(query_request_->expr(), + *schema_, + plan->mutable_query()->mutable_predicates()), + ""); + } + if (is_count_) { + user_output_fields_.push_back(kCountStr); + } else { + if (!schema_util::TranslateOutputFields(query_request_->output_fields(), + *schema_, + true, + &output_fields_, + &user_output_fields_)) { + return Status::ParameterInvalid("Error output fields"); + } + + std::vector output_ids; + if (!GetOutputFieldIds(&output_ids)) { + return Status::ParameterInvalid("Error output fields"); + } + for (auto id : output_ids) { + plan->add_output_field_ids(id); + } + } + + return Status::Ok(); +} + +bool +QueryTask::PostProcess(const RetrieveResult& rt, + ::milvus::proto::milvus::QueryResults* ret) { + milvus::proto::segcore::RetrieveResults seg_ret; + seg_ret.ParseFromArray(rt.retrieve_result_.proto_blob, + rt.retrieve_result_.proto_size); + + if (is_count_) { + auto count_data = ret->add_fields_data(); + count_data->CopyFrom(seg_ret.fields_data(0)); + count_data->set_field_name(kCountStr); + ret->add_output_fields(kCountStr); + return true; + } + + // reduce data by id + std::vector<::milvus::proto::schema::FieldData> reduced_fields; + int64_t ret_size = 0; + for (const auto& field_data : seg_ret.fields_data()) { + ::milvus::proto::schema::FieldData data; + if (!schema_util::ReduceFieldByIDs( + seg_ret.ids(), field_data, &data, &ret_size)) { + return false; + } + reduced_fields.push_back(data); + } + + if (ret_size > offset_) { + for (const auto& field_data : reduced_fields) { + auto new_data = ret->add_fields_data(); + new_data->set_field_id(field_data.field_id()); + new_data->set_type(field_data.type()); + new_data->set_field_name(field_data.field_name()); + new_data->set_is_dynamic(field_data.is_dynamic()); + auto limit = std::min(limit_, ret_size - offset_); + schema_util::SliceFieldData( + field_data, + std::vector>{{offset_, limit}}, + new_data); + } + } else { + for (const auto& field_schema : schema_->fields()) { + if (field_schema.fieldid() >= kStartOfUserFieldId) { + schema_util::FillEmptyField(field_schema, + ret->add_fields_data()); + } + } + } + + FillInFieldInfo(ret); + for (const auto& name : user_output_fields_) { + ret->add_output_fields(name); + } + return true; +} + +void +QueryTask::FillInFieldInfo(::milvus::proto::milvus::QueryResults* result_data) { + if (output_fields_.size() == 0 || result_data->fields_data_size() == 0) { + return; + } + for (size_t i = 0; i < output_fields_.size(); i++) { + const std::string& name = output_fields_[i]; + for (const auto& field : schema_->fields()) { + if (name == field.name()) { + auto field_id = field.fieldid(); + for (int j = 0; j < result_data->fields_data().size(); j++) { + if (field_id == result_data->fields_data(j).field_id()) { + result_data->mutable_fields_data(j)->set_field_name( + field.name()); + result_data->mutable_fields_data(j)->set_field_id( + field.fieldid()); + result_data->mutable_fields_data(j)->set_type( + field.data_type()); + result_data->mutable_fields_data(j)->set_is_dynamic( + field.is_dynamic()); + } + } + } + } + } +} + +} // namespace milvus::local diff --git a/src/query_task.h b/src/query_task.h new file mode 100644 index 0000000..0fe4af9 --- /dev/null +++ b/src/query_task.h @@ -0,0 +1,51 @@ +#pragma once + +#include "pb/milvus.pb.h" +#include "pb/plan.pb.h" 
+#include +#include + +#include "retrieve_result.h" +#include "status.h" + +namespace milvus::local { + +class QueryTask : NonCopyableNonMovable { + public: + QueryTask(const ::milvus::proto::milvus::QueryRequest* query_request, + const ::milvus::proto::schema::CollectionSchema*); + virtual ~QueryTask(); + + Status + Process(::milvus::proto::plan::PlanNode* plan); + + bool + PostProcess(const RetrieveResult& rt, + ::milvus::proto::milvus::QueryResults* ret); + + private: + bool + GetOutputFieldIds(std::vector* ids); + + void + FilterSystemField(); + + void + FillInFieldInfo(::milvus::proto::milvus::QueryResults* result_data); + + Status + ParseQueryParams(::milvus::proto::plan::PlanNode* plan, bool expr_empty); + + private: + const ::milvus::proto::milvus::QueryRequest* query_request_; + const ::milvus::proto::schema::CollectionSchema* schema_; + + std::vector output_fields_; + std::vector user_output_fields_; + + int64_t limit_; + int64_t offset_; + bool is_count_; +}; + +} // namespace milvus::local diff --git a/src/retrieve_result.h b/src/retrieve_result.h new file mode 100644 index 0000000..da9bcf9 --- /dev/null +++ b/src/retrieve_result.h @@ -0,0 +1,27 @@ +#pragma once + +#include "common.h" +#include "segcore/segment_c.h" + +namespace milvus::local { + +class RetrieveResult final : NonCopyableNonMovable { + public: + RetrieveResult() { + retrieve_result_.proto_blob = nullptr; + retrieve_result_.proto_size = 0; + } + ~RetrieveResult() { + if (retrieve_result_.proto_blob != nullptr) { + DeleteRetrieveResult(&retrieve_result_); + retrieve_result_.proto_blob = nullptr; + retrieve_result_.proto_size = 0; + } + }; + + public: + // milvus::proto::segcore::RetrieveResults + CRetrieveResult retrieve_result_; +}; + +} // namespace milvus::local diff --git a/src/schema_util.cpp b/src/schema_util.cpp new file mode 100644 index 0000000..0782a17 --- /dev/null +++ b/src/schema_util.cpp @@ -0,0 +1,684 @@ +#include "schema_util.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "log/Log.h" +#include "pb/plan.pb.h" +#include "pb/segcore.pb.h" +#include "schema.pb.h" +#include "status.h" +#include "string_util.hpp" +#include "parser/utils.h" +#include "parser/parser.h" + +namespace milvus::local { +namespace schema_util { + +using DType = ::milvus::proto::schema::DataType; + +std::any +GetField(const ::milvus::proto::schema::FieldData& field_data, + uint32_t field_index) { + if (field_data.type() == DType::FloatVector) { + std::vector vec; + int64_t dim = field_data.vectors().dim(); + auto vd = field_data.vectors().float_vector(); + for (int index = field_index * dim; index < (field_index + 1) * dim; + index++) { + vec.push_back(vd.data(index)); + } + return vec; + } else if (field_data.type() == DType::BinaryVector) { + int64_t dim = field_data.vectors().dim(); + int64_t size = dim / 8; + std::string v = field_data.vectors().binary_vector().substr( + field_index * size, size); + return v; + } else if (field_data.type() == DType::Float16Vector) { + int64_t dim = field_data.vectors().dim(); + int64_t size = dim * 2; + std::string v = field_data.vectors().float16_vector().substr( + field_index * size, size); + return v; + } else if (field_data.type() == DType::BFloat16Vector) { + int64_t dim = field_data.vectors().dim(); + int64_t size = dim * 2; + std::string v = field_data.vectors().bfloat16_vector().substr( + field_index * size, size); + return v; + } else if (field_data.type() == DType::Bool) { + return 
field_data.scalars().bool_data().data(field_index); + } else if (field_data.type() == DType::Int8 || + field_data.type() == DType::Int16 || + field_data.type() == DType::Int32) { + return field_data.scalars().int_data().data(field_index); + } else if (field_data.type() == DType::Int64) { + return field_data.scalars().long_data().data(field_index); + } else if (field_data.type() == DType::Float) { + return field_data.scalars().float_data().data(field_index); + } else if (field_data.type() == DType::Double) { + return field_data.scalars().double_data().data(field_index); + } else if (field_data.type() == DType::String) { + return field_data.scalars().string_data().data(field_index); + } else if (field_data.type() == DType::VarChar) { + return field_data.scalars().string_data().data(field_index); + } else if (field_data.type() == DType::Array) { + auto array_data = field_data.scalars().array_data(); + ::milvus::proto::schema::ArrayArray new_array; + new_array.set_element_type(array_data.element_type()); + new_array.add_data()->CopyFrom(array_data.data(field_index)); + return new_array; + } else if (field_data.type() == DType::JSON) { + return field_data.scalars().json_data().data(field_index); + } else if (field_data.type() == DType::SparseFloatVector) { + ::milvus::proto::schema::SparseFloatArray sp; + sp.CopyFrom(field_data.vectors().sparse_float_vector()); + return sp; + } else { + LOG_ERROR("Unkown data type: {}", field_data.type()); + return nullptr; + } + return nullptr; +} + +bool +IsVectorField(::milvus::proto::schema::DataType dtype) { + return dtype == ::milvus::proto::schema::DataType::FloatVector || + dtype == ::milvus::proto::schema::DataType::BinaryVector || + dtype == ::milvus::proto::schema::DataType::Float16Vector || + dtype == ::milvus::proto::schema::DataType::BFloat16Vector || + dtype == ::milvus::proto::schema::DataType::SparseFloatVector; +} + +bool +IsSparseVectorType(::milvus::proto::schema::DataType dtype) { + return dtype == ::milvus::proto::schema::DataType::SparseFloatVector; +} + +bool +FindDimFromFieldParams(const ::milvus::proto::schema::FieldSchema& field, + std::string* dim) { + for (const auto& param : field.type_params()) { + if (param.key() == kDimKey) { + dim->assign(param.value()); + return true; + } + } + + for (const auto& param : field.index_params()) { + if (param.key() == kDimKey) { + dim->assign(param.value()); + return true; + } + } + return false; +} + +int64_t +GetDim(const ::milvus::proto::schema::FieldSchema& field) { + if (!IsVectorField(field.data_type())) { + LOG_ERROR("{} is not vector type", field.data_type()); + return -1; + } + if (IsSparseVectorType(field.data_type())) { + LOG_ERROR("GetDim should not invoke on sparse vector type"); + return -1; + } + + std::string dim_str; + bool succ = FindDimFromFieldParams(field, &dim_str); + if (!succ) { + LOG_ERROR("Dim not found"); + return -1; + } + try { + return std::stoll(dim_str); + } catch (const std::invalid_argument& e) { + LOG_ERROR("invalid dimension: {}, {}", dim_str, e.what()); + } + return -1; +} + +bool +FillEmptyField(const ::milvus::proto::schema::FieldSchema& field_schema, + ::milvus::proto::schema::FieldData* field_data) { + field_data->set_field_name(field_schema.name()); + field_data->set_type(field_schema.data_type()); + field_data->set_field_id(field_schema.fieldid()); + field_data->set_is_dynamic(field_schema.is_dynamic()); + + if (field_schema.data_type() == DType::FloatVector) { + int64_t dim = GetDim(field_schema); + if (dim < 0) { + return false; + } + auto vec_field = 
field_data->mutable_vectors(); + vec_field->set_dim(dim); + vec_field->mutable_float_vector(); + } else if (field_schema.data_type() == DType::BinaryVector) { + int64_t dim = GetDim(field_schema); + if (dim < 0) { + return false; + } + auto vec_field = field_data->mutable_vectors(); + vec_field->set_dim(dim); + vec_field->mutable_binary_vector(); + } else if (field_schema.data_type() == DType::Float16Vector) { + int64_t dim = GetDim(field_schema); + if (dim < 0) { + return false; + } + auto vec_field = field_data->mutable_vectors(); + vec_field->set_dim(dim); + vec_field->mutable_float16_vector(); + } else if (field_schema.data_type() == DType::BFloat16Vector) { + int64_t dim = GetDim(field_schema); + if (dim < 0) { + return false; + } + auto vec_field = field_data->mutable_vectors(); + vec_field->set_dim(dim); + vec_field->mutable_bfloat16_vector(); + } else if (field_schema.data_type() == DType::Bool) { + field_data->mutable_scalars()->mutable_bool_data(); + } else if (field_schema.data_type() == DType::Int8 || + field_schema.data_type() == DType::Int16 || + field_schema.data_type() == DType::Int32) { + field_data->mutable_scalars()->mutable_int_data(); + } else if (field_schema.data_type() == DType::Int64) { + field_data->mutable_scalars()->mutable_long_data(); + } else if (field_schema.data_type() == DType::Float) { + field_data->mutable_scalars()->mutable_float_data(); + } else if (field_schema.data_type() == DType::Double) { + field_data->mutable_scalars()->mutable_double_data(); + } else if (field_schema.data_type() == DType::String) { + field_data->mutable_scalars()->mutable_string_data(); + } else if (field_schema.data_type() == DType::VarChar) { + field_data->mutable_scalars()->mutable_string_data(); + } else if (field_schema.data_type() == DType::Array) { + field_data->mutable_scalars()->mutable_array_data(); + } else if (field_schema.data_type() == DType::JSON) { + field_data->mutable_scalars()->mutable_json_data(); + } else if (field_schema.data_type() == DType::SparseFloatVector) { + field_data->mutable_vectors()->mutable_sparse_float_vector(); + } else { + LOG_ERROR("Unkown data type: {}", field_schema.data_type()); + return false; + } + return true; +} + +bool +FindDimFromSchema(const ::milvus::proto::schema::CollectionSchema& schema, + std::string* dim) { + for (const auto& field : schema.fields()) { + if (IsVectorField(field.data_type())) { + return FindDimFromFieldParams(field, dim); + } + } + return false; +} + +std::optional<::milvus::proto::plan::VectorType> +DataTypeToVectorType(::milvus::proto::schema::DataType dtype) { + if (dtype == ::milvus::proto::schema::DataType::FloatVector) { + return ::milvus::proto::plan::VectorType::FloatVector; + } else if (dtype == ::milvus::proto::schema::DataType::BinaryVector) { + return ::milvus::proto::plan::VectorType::BinaryVector; + } else if (dtype == ::milvus::proto::schema::DataType::Float16Vector) { + return ::milvus::proto::plan::VectorType::Float16Vector; + } else if (dtype == ::milvus::proto::schema::DataType::BFloat16Vector) { + return ::milvus::proto::plan::VectorType::BFloat16Vector; + } else if (dtype == ::milvus::proto::schema::DataType::SparseFloatVector) { + return ::milvus::proto::plan::VectorType::SparseFloatVector; + } else { + return std::nullopt; + } +} + +Status +FindVectorField(const ::milvus::proto::schema::CollectionSchema& schema, + const std::string& ann_field, + const ::milvus::proto::schema::FieldSchema** field) { + std::map + vec_fields; + for (const auto& field : schema.fields()) { + if 
(IsVectorField(field.data_type())) { + vec_fields[field.name()] = &field; + } + } + if (vec_fields.size() == 0) { + auto err = string_util::SFormat( + "Can not found vector field in collection {}", schema.name()); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + if (ann_field.empty()) { + if (vec_fields.size() > 1) { + auto err = string_util::SFormat( + "multiple anns_fields exist, please specify a anns_field " + "insearch_params"); + return Status::ParameterInvalid(err); + } else { + *field = vec_fields.begin()->second; + return Status::Ok(); + } + } else { + if (vec_fields.find(ann_field) == vec_fields.end()) { + auto err = + string_util::SFormat("fieldName({}) not found", ann_field); + LOG_ERROR(err); + return Status::ParameterInvalid(err); + } + *field = vec_fields.at(ann_field); + return Status::Ok(); + } +} + +std::string +MergeIndexs(std::vector& indexs) { + ::milvus::proto::segcore::CollectionIndexMeta index_meta; + index_meta.set_maxindexrowcount(kMaxIndexRow); + for (size_t i = 0; i < indexs.size(); i++) { + index_meta.add_index_metas()->ParseFromString(indexs[i]); + } + return index_meta.SerializeAsString(); +} + +std::optional +GetPkId(const ::milvus::proto::schema::CollectionSchema& schema) { + for (const auto& field : schema.fields()) { + if (field.is_primary_key()) { + return field.fieldid(); + } + } + return std::nullopt; +} + +std::optional +GetPkName(const ::milvus::proto::schema::CollectionSchema& schema) { + for (const auto& field : schema.fields()) { + if (field.is_primary_key()) { + return field.name(); + } + } + return std::nullopt; +} + +bool +PickFieldDataByIndex(const ::milvus::proto::schema::FieldData& src_data, + const std::vector& indexes, + ::milvus::proto::schema::FieldData* dst) { + for (int64_t i : indexes) { + switch (src_data.type()) { + case DType::FloatVector: { + dst->mutable_vectors()->set_dim(src_data.vectors().dim()); + auto vec = + std::any_cast>(GetField(src_data, i)); + for (const auto& item : vec) { + dst->mutable_vectors()->mutable_float_vector()->add_data( + item); + } + } break; + + case DType::BinaryVector: { + dst->mutable_vectors()->set_dim(src_data.vectors().dim()); + auto vec = std::any_cast(GetField(src_data, i)); + dst->mutable_vectors()->mutable_binary_vector()->assign( + dst->mutable_vectors()->binary_vector() + vec); + } break; + + case DType::Float16Vector: { + dst->mutable_vectors()->set_dim(src_data.vectors().dim()); + auto vec = std::any_cast(GetField(src_data, i)); + dst->mutable_vectors()->mutable_float16_vector()->assign( + dst->mutable_vectors()->float16_vector() + vec); + } break; + + case DType::BFloat16Vector: { + dst->mutable_vectors()->set_dim(src_data.vectors().dim()); + auto vec = std::any_cast(GetField(src_data, i)); + dst->mutable_vectors()->mutable_bfloat16_vector()->assign( + dst->mutable_vectors()->bfloat16_vector() + vec); + } break; + + case DType::Bool: { + auto data = std::any_cast(GetField(src_data, i)); + dst->mutable_scalars()->mutable_bool_data()->add_data(data); + } break; + + case DType::Int8: { + auto data = std::any_cast(GetField(src_data, i)); + dst->mutable_scalars()->mutable_int_data()->add_data(data); + } break; + + case DType::Int16: { + auto data = std::any_cast(GetField(src_data, i)); + dst->mutable_scalars()->mutable_int_data()->add_data(data); + } break; + + case DType::Int32: { + auto data = std::any_cast(GetField(src_data, i)); + dst->mutable_scalars()->mutable_int_data()->add_data(data); + } break; + + case DType::Int64: { + auto data = std::any_cast(GetField(src_data, i)); 
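+ // Descriptive note (editor-added): Int64 values are stored in long_data,
+ // while the narrower integer types above all share int_data.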
+ dst->mutable_scalars()->mutable_long_data()->add_data(data); + } break; + + case DType::Float: { + auto data = std::any_cast(GetField(src_data, i)); + dst->mutable_scalars()->mutable_float_data()->add_data(data); + } break; + + case DType::Double: { + auto data = std::any_cast(GetField(src_data, i)); + dst->mutable_scalars()->mutable_double_data()->add_data(data); + } break; + + case DType::String: { + auto data = std::any_cast(GetField(src_data, i)); + dst->mutable_scalars()->mutable_string_data()->add_data(data); + } break; + + case DType::VarChar: { + auto data = std::any_cast(GetField(src_data, i)); + dst->mutable_scalars()->mutable_string_data()->add_data(data); + } break; + + case DType::Array: { + auto data = std::any_cast<::milvus::proto::schema::ArrayArray>( + GetField(src_data, i)); + auto arr = dst->mutable_scalars()->mutable_array_data(); + arr->set_element_type(data.element_type()); + arr->add_data()->CopyFrom(data.data(0)); + } break; + case DType::JSON: { + auto data = std::any_cast(GetField(src_data, i)); + dst->mutable_scalars()->mutable_json_data()->add_data(data); + } break; + case DType::SparseFloatVector: { + auto data = + std::any_cast<::milvus::proto::schema::SparseFloatArray>( + GetField(src_data, i)); + dst->mutable_vectors()->mutable_sparse_float_vector()->CopyFrom( + data); + } break; + + default: + LOG_ERROR("Field: [{}-{}] unkown data type: {}", + src_data.field_name(), + src_data.field_id(), + src_data.type()); + return false; + } + } + return true; +} + +bool +SliceFieldData(const ::milvus::proto::schema::FieldData& src_data, + const std::vector>& ranges, + ::milvus::proto::schema::FieldData* dst) { + std::vector indexes; + for (const auto& range : ranges) { + int64_t offset = std::get<0>(range); + int64_t limit = std::get<1>(range); + for (int64_t i = offset; i < offset + limit; i++) { + indexes.push_back(i); + } + } + return PickFieldDataByIndex(src_data, indexes, dst); +} + +// Support wildcard in output fields: +// +//"*" - all fields +// +// For example, A and B are scalar fields, C and D are vector fields, duplicated fields will automatically be removed. 
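+// When add_primary is true, the primary-key field is always appended to the result as well.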
+// +//output_fields=["*"] ==> [A,B,C,D] +//output_fields=["*",A] ==> [A,B,C,D] +//output_fields=["*",C] ==> [A,B,C,D] +bool +TranslateOutputFields( + const ::google::protobuf::RepeatedPtrField& raw_fields, + const ::milvus::proto::schema::CollectionSchema& schema, + bool add_primary, + std::vector* result_outputs, + std::vector* user_output_fields) { + std::string pk_name; + + std::set all_fields; + + // when enable dynamic field, result_field store the real field of collection, + // user_output_field store user-specified name; + std::set result_field; + std::set user_output_field; + + for (const auto& field : schema.fields()) { + if (field.is_primary_key()) { + pk_name = field.name(); + } + if (field.fieldid() >= kStartOfUserFieldId) { + all_fields.insert(field.name()); + } + } + + for (const auto& name : raw_fields) { + auto output_name = string_util::Trim(name); + if (output_name == "*") { + for (const std::string& name : all_fields) { + result_field.insert(name); + user_output_field.insert(name); + } + } else { + if (all_fields.find(output_name) != all_fields.end()) { + result_field.insert(output_name); + user_output_field.insert(output_name); + } else { + if (schema.enable_dynamic_field()) { + milvus::proto::schema::CollectionSchema schema_; + schema_.CopyFrom(schema); + auto helper = milvus::local::CreateSchemaHelper(&schema_); + auto expr = ParseIdentifier(helper, name); + if (expr->column_expr().info().nested_path_size() == 1 && + expr->column_expr().info().nested_path(0) == name) { + result_field.insert(kMetaFieldName); + user_output_field.insert(name); + } + } else { + LOG_ERROR("Field {} not exist", output_name); + return false; + } + } + } + } + if (add_primary) { + result_field.insert(pk_name); + user_output_field.insert(pk_name); + } + for (const std::string& fname : result_field) { + result_outputs->push_back(fname); + } + for (const std::string& fname : user_output_field) { + user_output_fields->push_back(fname); + } + return true; +} + +bool +ReduceFieldByIDs(const ::milvus::proto::schema::IDs& ids, + const ::milvus::proto::schema::FieldData& src, + ::milvus::proto::schema::FieldData* dst, + int64_t* real_size) { + std::set unique_ids; + std::vector indexes; + dst->set_type(src.type()); + dst->set_field_id(src.field_id()); + dst->set_field_name(src.field_name()); + dst->set_is_dynamic(src.is_dynamic()); + if (ids.has_int_id()) { + for (int64_t i = 0; i < ids.int_id().data_size(); ++i) { + auto cur_id = std::to_string(ids.int_id().data(i)); + if (unique_ids.find(cur_id) != unique_ids.end()) + continue; + unique_ids.insert(cur_id); + indexes.push_back(i); + } + } else if (ids.has_str_id()) { + for (int64_t i = 0; i < ids.str_id().data_size(); ++i) { + auto cur_id = ids.str_id().data(i); + if (unique_ids.find(cur_id) != unique_ids.end()) + continue; + unique_ids.insert(cur_id); + indexes.push_back(i); + } + } else { + // empty data + return true; + } + *real_size = unique_ids.size(); + return PickFieldDataByIndex(src, indexes, dst); +} + +Status +ParseExpr(const std::string& expr_str, + const ::milvus::proto::schema::CollectionSchema& schema, + ::milvus::proto::plan::Expr* expr_out) { + try { + antlr4::ANTLRInputStream input(expr_str); + PlanLexer lexer(&input); + antlr4::CommonTokenStream tokens(&lexer); + PlanParser parser(&tokens); + PlanParser::ExprContext* tree = parser.expr(); + auto helper = milvus::local::CreateSchemaHelper( + const_cast(&schema)); + milvus::local::PlanCCVisitor visitor(&helper); + auto ret = visitor.visit(tree); + if (!ret.has_value()) { + return 
Status::ParameterInvalid( + string_util::SFormat("Invalid expr: {}", expr_str)); + } + auto expr = std::any_cast(ret); + expr_out->CopyFrom(*(expr.expr)); + return Status::Ok(); + } catch (std::exception& e) { + return Status::ParameterInvalid( + string_util::SFormat("Invalid expr: {}", expr_str)); + } +} + +bool +SchemaEquals(const std::string& schema_str_l, const std::string& schema_str_r) { + ::milvus::proto::schema::CollectionSchema schema_l, schema_r; + if (!schema_l.ParseFromString(schema_str_l) || + !schema_r.ParseFromString(schema_str_r)) { + LOG_ERROR("Parse schema failed"); + return false; + } + if (schema_l.name() != schema_r.name() || + schema_l.description() != schema_r.description() || + schema_l.enable_dynamic_field() != schema_r.enable_dynamic_field() || + schema_l.fields_size() != schema_r.fields_size() || + !CheckParamsEqual(schema_l.properties(), schema_r.properties())) { + return false; + } + // check field + for (int i = 0; i < schema_l.fields_size(); i++) { + if (schema_l.fields(i).fieldid() != schema_r.fields(i).fieldid() || + schema_l.fields(i).name() != schema_r.fields(i).name() || + schema_l.fields(i).is_primary_key() != + schema_r.fields(i).is_primary_key() || + schema_l.fields(i).description() != + schema_r.fields(i).description() || + schema_l.fields(i).data_type() != schema_r.fields(i).data_type() || + schema_l.fields(i).autoid() != schema_r.fields(i).autoid() || + schema_l.fields(i).is_dynamic() != + schema_r.fields(i).is_dynamic() || + schema_l.fields(i).element_type() != + schema_r.fields(i).element_type() || + !CheckValueFieldEqual(schema_l.fields(i).default_value(), + schema_r.fields(i).default_value()) || + !CheckParamsEqual(schema_l.fields(i).type_params(), + schema_r.fields(i).type_params()) || + !CheckParamsEqual(schema_l.fields(i).index_params(), + schema_r.fields(i).index_params()) + + ) { + return false; + } + } + return true; +} + +bool +CheckParamsEqual(const ::google::protobuf::RepeatedPtrField< + ::milvus::proto::common::KeyValuePair>& left, + const ::google::protobuf::RepeatedPtrField< + ::milvus::proto::common::KeyValuePair>& right) { + if (left.size() != right.size()) + return false; + KVMap right_map; + for (int i = 0; i < right.size(); i++) { + right_map[right[i].key()] = right[i].value(); + } + for (int i = 0; i < left.size(); i++) { + auto it = right_map.find(left[i].key()); + if (it == right_map.end()) + return false; + if (it->second != left[i].value()) + return false; + } + return true; +} + +bool +CheckValueFieldEqual(const ::milvus::proto::schema::ValueField& left, + const ::milvus::proto::schema::ValueField& right) { + if (!left.IsInitialized() && !right.IsInitialized()) + return true; + + if (left.data_case() != right.data_case()) + return false; + + if (left.data_case() == 0) { + return true; + } + + if (left.has_bool_data() && right.has_bool_data() && + left.bool_data() == right.bool_data()) + return true; + if (left.has_int_data() && right.has_int_data() && + left.int_data() == right.int_data()) + return true; + if (left.has_long_data() && right.has_long_data() && + left.long_data() == right.long_data()) + return true; + if (left.has_float_data() && right.has_float_data() && + std::fabs(left.float_data() - right.float_data()) < 0.00001f) + return true; + if (left.has_double_data() && right.has_double_data() && + std::fabs(left.double_data() - right.double_data()) < 0.0000001) + return true; + if (left.has_string_data() && right.has_string_data() && + left.string_data() == right.string_data()) + return true; + if 
(left.has_bytes_data() && right.has_bytes_data() && + left.bytes_data() == right.bytes_data()) + return true; + return false; +} + +} // namespace schema_util + +} // namespace milvus::local diff --git a/src/schema_util.h b/src/schema_util.h new file mode 100644 index 0000000..cca48b9 --- /dev/null +++ b/src/schema_util.h @@ -0,0 +1,115 @@ +#pragma once + +#include +#include +#include +#include +#include "common.h" +#include "common.pb.h" +#include "pb/plan.pb.h" +#include "status.h" +#include "string_util.hpp" +#include "pb/schema.pb.h" +#include + +namespace milvus::local { + +namespace schema_util { + +std::any +GetField(const ::milvus::proto::schema::FieldData& field_data, + uint32_t field_index); + +bool +IsVectorField(::milvus::proto::schema::DataType dtype); + +bool +IsSparseVectorType(::milvus::proto::schema::DataType dtype); + +bool +FindDimFromFieldParams(const ::milvus::proto::schema::FieldSchema& field, + std::string* dim); +int64_t +GetDim(const ::milvus::proto::schema::FieldSchema& field); + +bool +FindDimFromSchema(const ::milvus::proto::schema::CollectionSchema& schema, + std::string* dim); + +std::optional<::milvus::proto::plan::VectorType> +DataTypeToVectorType(::milvus::proto::schema::DataType dtype); + +Status +FindVectorField(const ::milvus::proto::schema::CollectionSchema& schema, + const std::string& ann_field, + const ::milvus::proto::schema::FieldSchema** field); + +std::optional +GetPkId(const ::milvus::proto::schema::CollectionSchema& schema); + +std::optional +GetPkName(const ::milvus::proto::schema::CollectionSchema& schema); + +bool +SliceFieldData(const ::milvus::proto::schema::FieldData& src_data, + const std::vector>& ranges, + ::milvus::proto::schema::FieldData* dst); + +bool +FillEmptyField(const ::milvus::proto::schema::FieldSchema& field_schema, + ::milvus::proto::schema::FieldData* field_data); + +std::string +MergeIndexs(std::vector& indexs); + +bool +SchemaEquals(const std::string& schema_str_l, const std::string& schema_str_r); + +// Support wildcard in output fields: +// +//"*" - all fields +// +// For example, A and B are scalar fields, C and D are vector fields, duplicated fields will automatically be removed. 
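+// Explicitly listed fields are returned as-is (for example, output_fields=[A,C] ==> [A,C]), with the primary-key field appended when add_primary is set. +// A name that is not a schema field is only accepted when the dynamic field is enabled, in which case it is resolved as a key of the dynamic (kMetaFieldName) field; otherwise translation fails.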
+// +//output_fields=["*"] ==> [A,B,C,D] +//output_fields=["*",A] ==> [A,B,C,D] +//output_fields=["*",C] ==> [A,B,C,D] +bool +TranslateOutputFields( + const ::google::protobuf::RepeatedPtrField& raw_fields, + const ::milvus::proto::schema::CollectionSchema& schema, + bool add_primary, + std::vector* result_outputs, + std::vector* user_output_fields); + +bool +ReduceFieldByIDs(const ::milvus::proto::schema::IDs& ids, + const ::milvus::proto::schema::FieldData& src, + ::milvus::proto::schema::FieldData* dst, + int64_t* real_size); + +// PositivelyRelated return if metricType are "ip" or "IP" +inline bool +PositivelyRelated(const std::string& metrics_type) { + auto upper_str = string_util::ToUpper(metrics_type); + return upper_str == KMetricsIPName || upper_str == kMetricsCosineName; +} + +Status +ParseExpr(const std::string& expr_str, + const ::milvus::proto::schema::CollectionSchema& schema, + ::milvus::proto::plan::Expr* expr_out); + +bool +CheckParamsEqual(const ::google::protobuf::RepeatedPtrField< + ::milvus::proto::common::KeyValuePair>& left, + const ::google::protobuf::RepeatedPtrField< + ::milvus::proto::common::KeyValuePair>& right); + +bool +CheckValueFieldEqual(const ::milvus::proto::schema::ValueField& left, + const ::milvus::proto::schema::ValueField& right); + +} // namespace schema_util + +} // namespace milvus::local diff --git a/src/search_result.h b/src/search_result.h new file mode 100644 index 0000000..690aa31 --- /dev/null +++ b/src/search_result.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include +#include "common.h" +#include "common/type_c.h" +#include "segcore/reduce_c.h" + +namespace milvus::local { + +class SearchResult final : NonCopyableNonMovable { + public: + SearchResult(const std::vector& slice_nqs, + const std::vector& slice_topKs) + : slice_nqs_(slice_nqs), slice_topKs_(slice_topKs) { + blob_ = nullptr; + } + ~SearchResult() { + if (blob_ != nullptr) { + DELETE_AND_SET_NULL(blob_, DeleteSearchResultDataBlobs); + result_.clear(); + } + } + + public: + // std::vector> + CSearchResultDataBlobs blob_; + + // milvus::proto::schema::SearchResultData + // ptr to blob_ + std::vector result_; + + public: + std::vector slice_nqs_; + std::vector slice_topKs_; +}; + +} // namespace milvus::local diff --git a/src/search_task.cpp b/src/search_task.cpp new file mode 100644 index 0000000..49cf56b --- /dev/null +++ b/src/search_task.cpp @@ -0,0 +1,294 @@ +#include "search_task.h" +#include +#include +#include +#include +#include +#include "common.h" +#include "pb/plan.pb.h" +#include "antlr4-runtime.h" +#include "parser/parser.h" +#include "parser/utils.h" +#include "log/Log.h" +#include "schema.pb.h" +#include "schema_util.h" +#include "status.h" + +namespace milvus::local { + +SearchTask::SearchTask(::milvus::proto::milvus::SearchRequest* search_reques, + const ::milvus::proto::schema::CollectionSchema* schema) + : search_request_(search_reques), + schema_(schema), + topk_(-1), + offset_(0), + ann_field_(""), + metric_(""), + groupby_field_name_("") { +} +SearchTask::~SearchTask() { +} + +bool +SearchTask::ParseSearchInfo(::milvus::proto::plan::QueryInfo* info) { + int64_t round_decimal = -1; + std::string search_param_str(""); + for (const auto& param : search_request_->search_params()) { + if (param.key() == kTopkKey) { + try { + topk_ = std::stoll(param.value()); + } catch (std::exception& e) { + LOG_ERROR("Parse topk failed, topk: {}, err: {}", + param.value(), + e.what()); + return false; + } + + } else if (param.key() == kOffsetKey) { + try { + offset_ 
= std::stoll(param.value()); + } catch (std::exception& e) { + LOG_ERROR("Parse offset failed, offset: {}, err: {}", + param.value(), + e.what()); + return false; + } + } else if (param.key() == kMetricTypeKey) { + metric_ = param.value(); + } else if (param.key() == kRoundDecimalKey) { + try { + round_decimal = std::stoll(param.value()); + } catch (std::exception& e) { + LOG_ERROR("Parse round_decimal failed, topk: {}, err: {}", + param.value(), + e.what()); + return false; + } + } else if (param.key() == kSearchParamKey) { + search_param_str = param.value(); + } else if (param.key() == kGroupByFieldKey) { + groupby_field_name_ = param.value(); + } else if (param.key() == kAnnFieldKey) { + ann_field_ = param.value(); + } + } + + // get and validate topk + if (topk_ <= 0 || topk_ >= kTopkLimit) { + LOG_ERROR( + "Topk should be in range [1, {}], but got {}", kTopkLimit, topk_); + return false; + } + + if (offset_ >= kTopkLimit) { + LOG_ERROR("Offset should be in range [0, {}], but got {}", + kTopkLimit, + offset_); + return false; + } + + if ((topk_ + offset_) >= kTopkLimit) { + LOG_ERROR("topk + offset should be in range [1, {}], but got {}", + kTopkLimit, + topk_ + offset_); + return false; + } + + if (round_decimal != -1 && (round_decimal > 6 || round_decimal < 0)) { + LOG_ERROR( + "round_decimal {} is invalid, should be -1 or an integer in " + "range [0, 6]", + round_decimal); + return false; + } + + int64_t groupby_field_id = kRowIdField; + if (groupby_field_name_ != "") { + groupby_field_id = -1; + for (const auto& field : schema_->fields()) { + if (groupby_field_name_ == field.name()) { + groupby_field_id = field.fieldid(); + break; + } + } + if (groupby_field_id == -1) { + LOG_ERROR("groupBy field {} not found in schema", + groupby_field_name_); + return false; + } + } + info->set_topk(topk_ + offset_); + info->set_metric_type(metric_); + info->set_search_params(search_param_str); + info->set_round_decimal(round_decimal); + info->set_group_by_field_id(groupby_field_id); + return true; +} + +bool +SearchTask::GetOutputFieldsIds(std::vector* ids) { + std::map name_ids; + for (const auto& field : schema_->fields()) { + name_ids[field.name()] = field.fieldid(); + } + + for (const auto& output_field : output_fields_) { + auto it = name_ids.find(output_field); + if (it == name_ids.end()) { + LOG_ERROR("Can not find output field {} in schema", output_field); + return false; + } + ids->push_back(it->second); + } + return true; +} + +std::optional> +SearchTask::GetVectorField() { + for (const auto& field : schema_->fields()) { + if (schema_util::IsVectorField(field.data_type())) { + return std::make_tuple(field.name(), field.data_type()); + } + } + LOG_ERROR("Can not found vector field"); + return std::nullopt; +} + +Status +SearchTask::Process(::milvus::proto::plan::PlanNode* plan, + std::string* placeholder_group, + std::vector* nqs, + std::vector* topks) { + if (!schema_util::TranslateOutputFields(search_request_->output_fields(), + *schema_, + false, + &output_fields_, + &user_output_fields_)) { + return Status::ParameterInvalid(); + } + + std::vector ids; + if (!GetOutputFieldsIds(&ids)) { + return Status::ParameterInvalid(); + } + for (int64_t id : ids) { + plan->add_output_field_ids(id); + } + auto vector_anns = plan->mutable_vector_anns(); + vector_anns->set_placeholder_tag(kPlaceholderTag); + if (!ParseSearchInfo(vector_anns->mutable_query_info())) { + return Status::ParameterInvalid(); + } + + placeholder_group->assign(search_request_->placeholder_group()); + 
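+ // nqs/topks record the slice layout of this request; SegcoreWrapper::Search later hands the same slice_nqs_/slice_topKs_ vectors to ReduceSearchResultsAndFillData so the reduced result blob can be split back into per-query pieces.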
nqs->push_back(search_request_->nq()); + topks->push_back(vector_anns->query_info().topk()); + + const ::milvus::proto::schema::FieldSchema* field; + auto s = schema_util::FindVectorField(*schema_, ann_field_, &field); + CHECK_STATUS(s, ""); + vector_anns->set_field_id(field->fieldid()); + auto vtype = schema_util::DataTypeToVectorType(field->data_type()); + vector_anns->set_vector_type(*vtype); + + if (!search_request_->dsl().empty()) { + CHECK_STATUS(schema_util::ParseExpr( + search_request_->dsl(), + *schema_, + plan->mutable_vector_anns()->mutable_predicates()), + ""); + } + return Status::Ok(); +} + +bool +SearchTask::PostProcess( + const SearchResult& segcore_reault, + ::milvus::proto::milvus::SearchResults* search_results) { + ::milvus::proto::schema::SearchResultData tmp_ret; + tmp_ret.ParseFromArray(segcore_reault.result_[0].proto_blob, + segcore_reault.result_[0].proto_size); + + search_results->mutable_results()->set_num_queries(tmp_ret.num_queries()); + auto ret_size = tmp_ret.scores_size(); + auto nq = tmp_ret.num_queries(); + int score_coefficient = schema_util::PositivelyRelated(metric_) ? 1 : -1; + + if (nq * offset_ < ret_size) { + for (const auto& name : user_output_fields_) { + search_results->mutable_results()->add_output_fields(name); + } + int64_t limit = std::min(topk_, ret_size / nq - offset_); + search_results->mutable_results()->set_top_k(limit); + std::vector> ranges; + for (int i = 0; i < nq; i++) { + search_results->mutable_results()->mutable_topks()->Add(limit); + ranges.push_back(std::make_tuple(offset_ * i, limit)); + // copy topks and scores + for (int j = offset_ * i; j < (offset_ * i) + limit; j++) { + search_results->mutable_results()->mutable_scores()->Add( + tmp_ret.scores(j) * score_coefficient); + + // copy ids + if (tmp_ret.ids().has_int_id()) { + search_results->mutable_results() + ->mutable_ids() + ->mutable_int_id() + ->add_data(tmp_ret.ids().int_id().data(j)); + } else { + search_results->mutable_results() + ->mutable_ids() + ->mutable_str_id() + ->add_data(tmp_ret.ids().str_id().data(j)); + } + } + } + // copy fields_data + for (const auto& field_data : tmp_ret.fields_data()) { + auto new_data = + search_results->mutable_results()->add_fields_data(); + new_data->set_field_id(field_data.field_id()); + new_data->set_type(field_data.type()); + new_data->set_field_name(field_data.field_name()); + new_data->set_is_dynamic(field_data.is_dynamic()); + schema_util::SliceFieldData(field_data, ranges, new_data); + } + FillInFieldInfo(search_results->mutable_results()); + } else { + for (int i = 0; i < nq; i++) { + search_results->mutable_results()->mutable_topks()->Add(0); + } + } + + return true; +} + +void +SearchTask::FillInFieldInfo( + ::milvus::proto::schema::SearchResultData* result_data) { + if (output_fields_.size() == 0 || result_data->fields_data_size() == 0) { + return; + } + for (size_t i = 0; i < output_fields_.size(); i++) { + const std::string& name = output_fields_[i]; + for (const auto& field : schema_->fields()) { + if (name == field.name()) { + auto field_id = field.fieldid(); + for (int j = 0; j < result_data->fields_data().size(); j++) { + if (field_id == result_data->fields_data(j).field_id()) { + result_data->mutable_fields_data(j)->set_field_name( + field.name()); + result_data->mutable_fields_data(j)->set_field_id( + field.fieldid()); + result_data->mutable_fields_data(j)->set_type( + field.data_type()); + result_data->mutable_fields_data(j)->set_is_dynamic( + field.is_dynamic()); + } + } + } + } + } +} + +} // namespace 
milvus::local diff --git a/src/search_task.h b/src/search_task.h new file mode 100644 index 0000000..16c0441 --- /dev/null +++ b/src/search_task.h @@ -0,0 +1,55 @@ +#pragma once +#include +#include +#include +#include "pb/milvus.pb.h" +#include "pb/plan.pb.h" +#include "schema.pb.h" +#include "search_result.h" +#include "status.h" + +namespace milvus::local { + +class SearchTask final : NonCopyableNonMovable { + public: + SearchTask(::milvus::proto::milvus::SearchRequest* search_reques, + const ::milvus::proto::schema::CollectionSchema* schema); + virtual ~SearchTask(); + + public: + Status + Process(::milvus::proto::plan::PlanNode* plan, + std::string* placeholder_group, + std::vector* nqs, + std::vector* topks); + + bool + PostProcess(const SearchResult& segcore_reaul, + ::milvus::proto::milvus::SearchResults* search_results); + + private: + bool + ParseSearchInfo(::milvus::proto::plan::QueryInfo* info); + + bool + GetOutputFieldsIds(std::vector* ids); + + std::optional> + GetVectorField(); + + void + FillInFieldInfo(::milvus::proto::schema::SearchResultData* result_data); + + private: + ::milvus::proto::milvus::SearchRequest* search_request_; + const ::milvus::proto::schema::CollectionSchema* schema_; + + std::vector output_fields_; + std::vector user_output_fields_; + int64_t topk_, offset_; + std::string ann_field_; + std::string metric_; + std::string groupby_field_name_; +}; + +} // namespace milvus::local diff --git a/src/segcore_wrapper.cpp b/src/segcore_wrapper.cpp new file mode 100644 index 0000000..07bf9ad --- /dev/null +++ b/src/segcore_wrapper.cpp @@ -0,0 +1,269 @@ +#include "segcore_wrapper.h" +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "common/type_c.h" +#include "log/Log.h" +#include "pb/segcore.pb.h" +#include "retrieve_result.h" +#include "segcore/reduce_c.h" +#include "segcore/segment_c.h" +#include "pb/schema.pb.h" +#include "status.h" + +namespace milvus::local { + +const int64_t DEFAULT_MAX_OUTPUT_SIZE = 67108864; + +class RetrievePlanWrapper final : private NonCopyableNonMovable { + public: + RetrievePlanWrapper() : plan_(nullptr) { + } + virtual ~RetrievePlanWrapper() { + DELETE_AND_SET_NULL(plan_, DeleteRetrievePlan); + } + + public: + CRetrievePlan plan_; +}; + +class SearchPlanWrapper final : private NonCopyableNonMovable { + public: + SearchPlanWrapper() : plan_(nullptr) { + } + virtual ~SearchPlanWrapper() { + DELETE_AND_SET_NULL(plan_, DeleteSearchPlan); + } + + public: + CSearchPlan plan_; +}; + +class PlaceholderGroupWrapper final : private NonCopyableNonMovable { + public: + PlaceholderGroupWrapper() : group_(nullptr) { + } + virtual ~PlaceholderGroupWrapper() { + DELETE_AND_SET_NULL(group_, DeletePlaceholderGroup); + } + + public: + CPlaceholderGroup group_; +}; + +class SearchResultWrapper final : private NonCopyableNonMovable { + public: + SearchResultWrapper() : ret_(nullptr) { + } + virtual ~SearchResultWrapper() { + DELETE_AND_SET_NULL(ret_, DeleteSearchResult); + } + + public: + CSearchResult ret_; +}; + +SegcoreWrapper::~SegcoreWrapper() { + if (collection_ != nullptr) { + try { + DELETE_AND_SET_NULL(collection_, DeleteCollection); + } catch (std::exception& e) { + LOG_ERROR("Release collection {} failed", collection_name_); + } + } + if (segment_ != nullptr) { + try { + DELETE_AND_SET_NULL(segment_, DeleteSegment); + } catch (std::exception& e) { + LOG_ERROR("Release segment {} failed", collection_name_); + } + } +} + +Status +SegcoreWrapper::SetCollectionInfo(const std::string& 
collection_name, + const std::string& collection_info) { + assert(collection_ == nullptr); + auto new_collection_info = NewCollectionInfo(collection_info); + try { + collection_ = ::NewCollection(new_collection_info.c_str(), + new_collection_info.size()); + CHECK_STATUS(Status(::NewSegment(collection_, Growing, 0, &segment_)), + "Init segcore failed"); + collection_name_ = collection_name; + return Status::Ok(); + } catch (std::exception& e) { + return Status::SegcoreErr(e.what()); + } +} + +std::string +SegcoreWrapper::NewCollectionInfo(const std::string& info) { + ::milvus::proto::schema::CollectionSchema schema; + schema.ParseFromString(info); + for (auto it = schema.fields().begin(); it != schema.fields().end();) { + if (it->fieldid() < kStartOfUserFieldId) { + schema.mutable_fields()->erase(it); + } else { + ++it; + } + } + return schema.SerializeAsString(); +} + +Status +SegcoreWrapper::SetIndexMeta(const std::string& meta_info) { + try { + ::SetIndexMeta(collection_, meta_info.c_str(), meta_info.size()); + return Status::Ok(); + } catch (std::exception& e) { + LOG_ERROR("Set Index meta failed, err: {}", e.what()); + return Status::SegcoreErr(e.what()); + } +} + +Status +SegcoreWrapper::Insert(int64_t size, const std::string& insert_record_proto) { + try { + int64_t offset = 0; + CHECK_STATUS(Status(::PreInsert(segment_, size, &offset)), + "Pre insert failed, err:"); + + ::milvus::proto::segcore::InsertRecord r; + r.ParseFromString(insert_record_proto); + + std::vector row_ids; + std::vector timestamps; + for (const auto& field_data : r.fields_data()) { + if (field_data.field_id() == kRowIdField) { + for (int64_t rowid : field_data.scalars().long_data().data()) { + row_ids.push_back(rowid); + } + } + if (field_data.field_id() == kTimeStampField) { + for (int64_t ts : field_data.scalars().long_data().data()) { + timestamps.push_back(ts); + } + } + } + + CHECK_STATUS(Status(::Insert(segment_, + offset, + size, + row_ids.data(), + timestamps.data(), + reinterpret_cast( + insert_record_proto.data()), + insert_record_proto.size())), + "Insert failed:"); + return Status::Ok(); + } catch (std::exception& e) { + LOG_ERROR("Insert failed, err: {}", e.what()); + return Status::SegcoreErr(e.what()); + } +} + +Status +SegcoreWrapper::Retrieve(const std::string& plan, RetrieveResult* result) { + try { + RetrievePlanWrapper retrieve_plan; + auto status = Status(::CreateRetrievePlanByExpr( + collection_, plan.c_str(), plan.size(), &retrieve_plan.plan_)); + CHECK_STATUS(status, "Create retrieve plan failed, invalid expr"); + auto rs = Status(::Retrieve({}, + segment_, + retrieve_plan.plan_, + GetTimestamp(), + &(result->retrieve_result_), + DEFAULT_MAX_OUTPUT_SIZE)); + CHECK_STATUS(rs, "Retrieve failed, errs:"); + return Status::Ok(); + } catch (std::exception& e) { + LOG_ERROR("Retrieve failed, err: {}", e.what()); + return Status::SegcoreErr(e.what()); + } +} + +Status +SegcoreWrapper::Search(const std::string& plan, + const std::string& placeholder_group, + SearchResult* result) { + try { + SearchPlanWrapper search_plan; + CHECK_STATUS( + Status(::CreateSearchPlanByExpr( + collection_, plan.c_str(), plan.size(), &search_plan.plan_)), + "Create search plan failed, err:"); + + PlaceholderGroupWrapper group; + CHECK_STATUS( + Status(::ParsePlaceholderGroup(search_plan.plan_, + (void*)placeholder_group.c_str(), + placeholder_group.size(), + &group.group_)), + "Parse placeholder group failed"); + SearchResultWrapper search_result; + CHECK_STATUS(Status(::Search({}, + segment_, + search_plan.plan_, + 
group.group_, + GetTimestamp(), + &(search_result.ret_))), + "Search failed"); + + CHECK_STATUS( + Status(::ReduceSearchResultsAndFillData(&(result->blob_), + search_plan.plan_, + &(search_result.ret_), + 1, + result->slice_nqs_.data(), + result->slice_topKs_.data(), + result->slice_nqs_.size())), + "Reduce search result failed"); + result->result_.resize(result->slice_nqs_.size()); + for (size_t i = 0; i < result->slice_nqs_.size(); i++) { + CHECK_STATUS(Status(::GetSearchResultDataBlob( + &(result->result_[i]), result->blob_, i)), + "Get search reault blob failed"); + } + + return Status::Ok(); + } catch (std::exception& e) { + LOG_ERROR("Search failed, err: {}", e.what()); + return Status::SegcoreErr(e.what()); + } +} + +Status +SegcoreWrapper::DeleteByIds(const std::string& ids, int64_t size) { + CHECK_STATUS(Status(::Delete(segment_, + 0, + size, + reinterpret_cast(ids.data()), + ids.size(), + GetTimestamps(size).data())), + "Detete failed"); + return Status::Ok(); +} + +std::vector +SegcoreWrapper::GetTimestamps(int64_t size) { + auto ts = GetTimestamp(); + return std::vector(size, ts); +} + +uint64_t +SegcoreWrapper::GetTimestamp() { + auto now = std::chrono::system_clock::now(); + auto duration = now.time_since_epoch(); + auto ms = + std::chrono::duration_cast(duration).count(); + return (ms << 18) + cur_id_; +} + +} // namespace milvus::local diff --git a/src/segcore_wrapper.h b/src/segcore_wrapper.h new file mode 100644 index 0000000..4d46a71 --- /dev/null +++ b/src/segcore_wrapper.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include +#include +#include "status.h" +#include "retrieve_result.h" +#include "search_result.h" +#include "segcore/collection_c.h" +#include "segcore/segcore_init_c.h" +#include "segcore/segment_c.h" + +namespace milvus::local { + +class SegcoreWrapper final : NonCopyableNonMovable { + public: + SegcoreWrapper() : collection_(nullptr), cur_id_(0), segment_(nullptr) { + SegcoreSetEnableTempSegmentIndex(true); + } + virtual ~SegcoreWrapper(); + + public: + Status + SetCollectionInfo(const std::string& collection_name_, + const std::string& collection_info); + + Status + SetIndexMeta(const std::string& meta_info); + + Status + CreateIndex(const std::string& meta_info); + + Status + Insert(int64_t size, const std::string& insert_record_proto); + + Status + Retrieve(const std::string& plan, RetrieveResult* result); + + Status + Search(const std::string& plan, + const std::string& placeholder_group, + SearchResult* result); + + Status + DeleteByIds(const std::string& ids, int64_t size); + + private: + std::vector + GetTimestamps(int64_t size); + + uint64_t + GetTimestamp(); + + std::string + NewCollectionInfo(const std::string& info); + + private: + CCollection collection_; + int64_t cur_id_; + CSegmentInterface segment_; + std::string collection_name_; +}; + +} // namespace milvus::local diff --git a/src/server.cpp b/src/server.cpp new file mode 100644 index 0000000..b336478 --- /dev/null +++ b/src/server.cpp @@ -0,0 +1,79 @@ +#include "milvus_service_impl.h" +#include +#include "log/Log.h" +#include "string_util.hpp" +#include +#include +#include +#include +#include + +int +BlockLock(const char* filename) { + int fd = open(filename, O_RDWR | O_CREAT, 0666); + if (fd == -1) { + LOG_ERROR("Open lock file {} failed", filename); + return -1; + } + struct flock fl; + fl.l_type = F_WRLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 0; + // block lock + if (fcntl(fd, F_SETLKW, &fl) == -1) { + close(fd); + return -1; + } + // unlock file + fl.l_type = 
F_UNLCK; + if (fcntl(fd, F_SETLK, &fl) == -1) { + return -1; + } + LOG_ERROR("Process exit"); + close(fd); + return 0; +} + +int +main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + if (!(argc == 3 || argc == 4 || argc == 5)) { + return -1; + } + + std::string work_dir = argv[1]; + std::string address = argv[2]; + std::string log_level = "ERROR"; + if (argc == 4) { + log_level = argv[3]; + } + if (log_level == "INFO") { + google::SetStderrLogging(google::INFO); + } else { + google::SetStderrLogging(google::ERROR); + } + + ::milvus::local::MilvusServiceImpl service(work_dir); + if (!service.Init()) { + LOG_ERROR("Init milvus failed"); + return -1; + } + ::grpc::ServerBuilder builder; + builder.AddListeningPort(address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr<::grpc::Server> server(builder.BuildAndStart()); + LOG_INFO("Start milvus-local success..."); + if (argc == 5) { + auto filename = argv[4]; + /* + Blocked while attempting to acquire a file lock held by the parent process. + When the lock is successfully acquired, it indicates that the parent process has exited, + and the child process should exit as well. + */ + BlockLock(filename); + } else { + server->Wait(); + } + return 0; +} diff --git a/src/status.h b/src/status.h new file mode 100644 index 0000000..e814528 --- /dev/null +++ b/src/status.h @@ -0,0 +1,255 @@ +#pragma once + +#include +#include "common/type_c.h" +#include "string_util.hpp" + +namespace milvus::local { + +// errors are the subset of milvus/pkg/merr/errors.go +enum ErrCode { + Succ = 0, + ErrServiceInternal = 5, + ErrCollectionNotFound = 100, + ErrCollectionNotLoaded = 101, + ErrCollectionNumLimitExceeded = 102, + ErrCollectionNotFullyLoaded = 103, + ErrCollectionLoaded = 104, + ErrCollectionIllegalSchema = 105, + + // not in milvus + ErrCollectionAlreadExist = 199, + + ErrIndexNotFound = 700, + ErrIndexNotSupported = 701, + ErrIndexDuplicate = 702, + + ErrParameterInvalid = 1100, + ErrParameterMissing = 1101, + + ErrMetricNotFound = 1200, + + ErrFieldNotFound = 1700, + ErrFieldInvalidName = 1701, + + ErrMissingRequiredParameters = 1802, + ErrMarshalCollectionSchema = 1803, + ErrInvalidInsertData = 1804, + ErrInvalidSearchResult = 1805, + ErrCheckPrimaryKey = 1806, + + ErrSegcore = 2000, + + ErrUndefined = 65535 + +}; + +class Status { + public: + explicit Status(CStatus cstatus) { + if (cstatus.error_code != 0) { + msg_ = "segcore error"; + code_ = ErrSegcore; + detail_ = cstatus.error_msg; + free((void*)cstatus.error_msg); + cstatus.error_msg = NULL; + } else { + code_ = 0; + msg_ = ""; + detail_ = ""; + } + } + + private: + Status(int code, const std::string& msg, const std::string& detail = "") + : code_(code), msg_(msg), detail_(detail) { + } + + public: + virtual ~Status() = default; + + Status(const Status& rhs) = delete; + Status& + operator=(const Status& rhs) = delete; + + Status(Status&& rhs) : code_(rhs.code_) { + msg_ = std::move(rhs.msg_); + detail_ = std::move(rhs.detail_); + } + + Status& + operator=(Status&& rhs); + + public: + template + static Status + ServiceInternal(const std::string& detail = "", Args&&... args) { + return Status(ErrServiceInternal, + "internal error", + string_util::SFormat(detail, args...)); + } + + template + static Status + SegcoreErr(const std::string& detail = "", Args&&... 
args) { + return Status( + ErrSegcore, "segcore error", string_util::SFormat(detail, args...)); + } + + template + static Status + CollectionNotFound(const std::string& detail = "", Args&&... args) { + return Status(ErrCollectionNotFound, + "collection not found", + string_util::SFormat(detail, args...)); + } + + template + static Status + CollectionAlreadExist(const std::string& detail = "", Args&&... args) { + return Status(ErrCollectionAlreadExist, + "collection alread exists", + string_util::SFormat(detail, args...)); + } + + template + static Status + CollectionNotLoaded(const std::string& detail = "", Args&&... args) { + return Status(ErrCollectionNotLoaded, + "collection not loaded", + string_util::SFormat(detail, args...)); + } + + template + static Status + CollectionLoaded(const std::string& detail = "", Args&&... args) { + return Status(ErrCollectionLoaded, + "collection already loaded", + string_util::SFormat(detail, args...)); + } + + template + static Status + CollectionIllegalSchema(const std::string& detail = "", Args&&... args) { + return Status(ErrCollectionIllegalSchema, + "illegal collection schema", + string_util::SFormat(detail, args...)); + } + + template + static Status + IndexNotFound(const std::string& detail = "", Args&&... args) { + return Status(ErrIndexNotFound, + "index not found", + string_util::SFormat(detail, args...)); + } + + template + static Status + IndexNotSupported(const std::string& detail = "", Args&&... args) { + return Status(ErrIndexNotSupported, + "index type not supported", + string_util::SFormat(detail, args...)); + } + + template + static Status + IndexDuplicate(const std::string& detail = "", Args&&... args) { + return Status(ErrIndexDuplicate, + "index duplicates", + string_util::SFormat(detail, args...)); + } + + template + static Status + ParameterInvalid(const std::string& detail = "", Args&&... args) { + return Status(ErrParameterInvalid, + "invalid parameter", + string_util::SFormat(detail, args...)); + } + + template + static Status + ParameterMissing(const std::string& detail = "", Args&&... args) { + return Status(ErrParameterMissing, + "missing parameter", + string_util::SFormat(detail, args...)); + } + + template + static Status + MetricNotFound(const std::string& detail = "", Args&&... args) { + return Status(ErrMetricNotFound, + "metric not found", + string_util::SFormat(detail, args...)); + } + + template + static Status + FieldNotFound(const std::string& detail = "", Args&&... args) { + return Status(ErrFieldNotFound, + "field not found", + string_util::SFormat(detail, args...)); + } + + template + static Status + FieldInvalidName(const std::string& detail = "", Args&&... args) { + return Status(ErrFieldInvalidName, + "field name invalid", + string_util::SFormat(detail, args...)); + } + + template + static Status + Undefined(const std::string& detail = "", Args&&... args) { + return Status(ErrUndefined, "", string_util::SFormat(detail, args...)); + } + + template + static Status + Ok(const std::string& detail = "", Args&&... 
args) { + return Status(0, "", string_util::SFormat(detail, args...)); + } + + public: + bool + IsOk() { + return code_ == 0; + } + + bool + IsErr() { + return code_ != 0; + } + + int + Code() { + return code_; + } + + const std::string& + Msg() { + return msg_; + } + + const std::string& + Detail() { + return detail_; + } + + private: + int code_; + std::string msg_; + std::string detail_; +}; + +inline Status& +Status::operator=(Status&& rhs) { + code_ = rhs.code_; + msg_ = std::move(rhs.msg_); + detail_ = std::move(rhs.detail_); + return *this; +} + +} // namespace milvus::local diff --git a/src/storage.cpp b/src/storage.cpp new file mode 100644 index 0000000..3ccc546 --- /dev/null +++ b/src/storage.cpp @@ -0,0 +1,164 @@ +#include "storage.h" +#include +#include +#include +#include +#include +#include +#include "collection_data.h" +#include "log/Log.h" +#include "pb/schema.pb.h" + +namespace milvus::local { + +Storage::Storage(const char* db_file) : db_file_(db_file) { +} + +Storage::~Storage() { +} + +bool +Storage::Open() { + try { + db_ptr_ = std::make_unique( + db_file_, + SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE | + SQLite::OPEN_FULLMUTEX); + if (!cm_.Init(db_ptr_.get())) { + return false; + } + std::vector names; + cm_.CollectionNames(&names); + for (const auto& name : names) { + collections_.emplace( + name, std::make_unique(name.c_str())); + } + return true; + } catch (std::exception& e) { + LOG_ERROR("Open storage failed, err: {}", e.what()); + return false; + } +} + +bool +Storage::CreateCollection(const std::string& collection_name, + const std::string& pk_name, + const std::string& schema_proto) { + SQLite::Transaction transaction(*db_ptr_.get()); + if (!cm_.CreateCollection( + db_ptr_.get(), collection_name, pk_name, schema_proto)) { + return false; + } + auto data_ptr = std::make_unique(collection_name.c_str()); + if (!data_ptr->CreateCollection(db_ptr_.get())) { + return false; + } + collections_[collection_name] = std::move(data_ptr); + transaction.commit(); + return true; +} + +bool +Storage::DropCollection(const std::string& collection_name) { + SQLite::Transaction transaction(*db_ptr_.get()); + if (!cm_.DropCollection(db_ptr_.get(), collection_name)) { + LOG_ERROR("Delete collection: {}'s meta failed", collection_name); + return false; + } + if (!collections_[collection_name]->DropCollection(db_ptr_.get())) { + LOG_ERROR("Delete collection: {}'s data failed", collection_name); + return false; + } + collections_.erase(collection_name); + transaction.commit(); + return true; +} + +bool +Storage::LoadCollecton(const std::string& collection_name, + int64_t start, + int64_t size, + std::vector* out_rows) { + collections_[collection_name]->Load(db_ptr_.get(), start, size, out_rows); + return static_cast(out_rows->size()) == size; +} + +bool +Storage::GetCollectionSchema(const std::string& collection_name, + std::string* output_info_str) { + output_info_str->assign(cm_.GetCollectionSchema(collection_name).c_str()); + return true; +} + +bool +Storage::CreateIndex(const std::string& collection_name, + const std::string& index_name, + const std::string& index_proto) { + SQLite::Transaction transaction(*db_ptr_.get()); + if (!cm_.CreateIndex( + db_ptr_.get(), collection_name, index_name, index_proto)) { + return false; + } + transaction.commit(); + return true; +} + +bool +Storage::GetIndex(const std::string& collection_name, + const std::string& index_name, + std::string* output_index_proto) { + return cm_.GetCollectionIndex( + collection_name, index_name, 
output_index_proto); +} + +void +Storage::GetAllIndex(const std::string& collection_name, + const std::string& exclude, + std::vector* index_protos) { + cm_.GetAllIndex(collection_name, exclude, index_protos); +} + +bool +Storage::DropIndex(const std::string& collection_name, + const std::string& index_name) { + SQLite::Transaction transaction(*db_ptr_.get()); + if (!cm_.DropIndex(db_ptr_.get(), collection_name, index_name)) { + return false; + } + transaction.commit(); + return true; +} + +int +Storage::Insert(const std::string collection_name, + const std::vector& rows) { + SQLite::Transaction transaction(*db_ptr_.get()); + for (const auto& row : rows) { + if (collections_[collection_name]->Insert(db_ptr_.get(), + std::get<0>(row).c_str(), + std::get<1>(row)) < 0) { + return -1; + } + } + transaction.commit(); + return rows.size(); +} + +int +Storage::Delete(const std::string collection_name, + const std::vector& ids) { + SQLite::Transaction transaction(*db_ptr_.get()); + int n = collections_[collection_name]->Delete(db_ptr_.get(), ids); + transaction.commit(); + return n; +} + +int64_t +Storage::Count(const std::string& collection_name) { + SQLite::Transaction transaction(*db_ptr_.get()); + int64_t n = collections_[collection_name]->Count(db_ptr_.get()); + transaction.commit(); + return n; +} + +} // namespace milvus::local diff --git a/src/storage.h b/src/storage.h new file mode 100644 index 0000000..a8a6024 --- /dev/null +++ b/src/storage.h @@ -0,0 +1,114 @@ +#pragma once +#include +#include +#include +#include +#include +#include "collection_data.h" +#include "collection_meta.h" + +namespace milvus::local { + +class CollectionMeta; +class CollectionData; + +class Storage final { + public: + explicit Storage(const char* db_file); + ~Storage(); + Storage(const Storage&) = delete; + Storage& + operator=(const Storage&) = delete; + Storage(const Storage&&) = delete; + Storage& + operator=(const Storage&&) = delete; + + public: + bool + Open(); + + bool + CreateCollection(const std::string& collection_name, + const std::string& pk_name, + const std::string& schema_proto); + bool + DropCollection(const std::string& collection_name); + + void + ListCollections(std::vector* collection_names) { + cm_.CollectionNames(collection_names); + } + + /* + * @brief Load collection data + * + * @collection collection name + * @size number of rows to read in one call + * @out_rows output; if out_rows contains fewer rows than size, the collection has been read to the end + */ + bool + LoadCollecton(const std::string& collection_name, + int64_t start, + int64_t size, + std::vector* out_rows); + + bool + CreateIndex(const std::string& collection_name, + const std::string& index_name, + const std::string& index_proto); + + bool + GetIndex(const std::string& collection_name, + const std::string& index_name, + std::string* output_index_proto); + + bool + DropIndex(const std::string& collection_name, + const std::string& index_name); + + void + GetAllIndex(const std::string& collection_name, + const std::string& exclude, + std::vector* index_protos); + + bool + HasIndex(const std::string& collection_name, + const std::string& index_name) { + return cm_.HasIndex(collection_name, index_name); + } + + std::string + GetPrimaryKey(const std::string& collection_name) { + return cm_.GetPkName(collection_name); + } + + bool + GetCollectionSchema(const std::string& collection_name, + std::string* output_info_str); + + // data interface + int + Insert(const std::string collection_name, const std::vector& rows); + + int + Delete(const std::string collection_name, + const std::vector& ids); + + bool + 
CollectionExist(const std::string& collection_name) { + return collections_.find(collection_name) != collections_.end(); + } + + int64_t + Count(const std::string& collection_name); + + private: + CollectionMeta cm_; + std::map> collections_; + + private: + std::unique_ptr db_ptr_; + const char* db_file_; +}; + +} // namespace milvus::local diff --git a/src/string_util.hpp b/src/string_util.hpp new file mode 100644 index 0000000..a20431b --- /dev/null +++ b/src/string_util.hpp @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace milvus::local { + +namespace string_util { + +inline std::string +ToLower(const std::string& str) { + std::string lower_str; + std::transform(str.begin(), + str.end(), + std::back_inserter(lower_str), + [](unsigned char c) { return std::tolower(c); }); + return lower_str; +} + +inline std::string +ToUpper(const std::string& str) { + std::string upper_str; + std::transform(str.begin(), + str.end(), + std::back_inserter(upper_str), + [](unsigned char c) { return std::toupper(c); }); + return upper_str; +} + +inline std::string +Trim(const std::string& str) { + size_t first = str.find_first_not_of(" \t\n\r\f\v"); + if (std::string::npos == first) + return ""; + size_t last = str.find_last_not_of(" \t\n\r\f\v"); + return str.substr(first, (last - first + 1)); +} + +template +inline std::string +SFormat(const std::string& str, Args&&... args) { + return folly::sformat(str, args...); +} + +template +inline std::string +Join(const Delim& delimiter, const Container& container) { + return folly::join(delimiter, container); +} + +inline bool +IsAlpha(char c) { + if ((c < 'A' || c > 'Z') && (c < 'a' || c > 'z')) { + return false; + } + return true; +} + +} // namespace string_util + +} // namespace milvus::local diff --git a/src/type.h b/src/type.h new file mode 100644 index 0000000..f90cc0a --- /dev/null +++ b/src/type.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include +#include + +namespace milvus::local { + +using Row = std::tuple; +using Rows = std::vector; + +} // namespace milvus::local diff --git a/src/unittest/CMakeLists.txt b/src/unittest/CMakeLists.txt new file mode 100644 index 0000000..7f160a9 --- /dev/null +++ b/src/unittest/CMakeLists.txt @@ -0,0 +1,56 @@ +find_package(GTest REQUIRED) +include_directories(${GTest_INCLUDES}) + +find_program(PYTHON_EXECUTABLE NAMES python python3) + +include_directories(${CMAKE_CURRENT_LIST_DIR}) + + +set(MILVUS_LOCAL_TEST_DEPS + milite + GTest::gtest + GTest::gtest_main +) + +add_executable( + milvus_proxy_test + milvus_proxy_test.cpp + test_util.cpp +) +target_link_libraries(milvus_proxy_test + PRIVATE + ${MILVUS_LOCAL_TEST_DEPS} +) + +add_test( + NAME milvus_proxy_test + COMMAND $ +) + +add_executable( + server_test + grpc_server_test.cpp + test_util.cpp +) +target_link_libraries( + server_test + PRIVATE + milvus_service + ${MILVUS_LOCAL_TEST_DEPS} +) + + +add_test( + NAME server_test + COMMAND $ +) + +add_test( + NAME run_examples + COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/run_examples.py +) + +set_tests_properties(run_examples + PROPERTIES + ENVIRONMENT "BIN_PATH=${CMAKE_BINARY_DIR}/lib;PYTHONPATH=${CMAKE_SOURCE_DIR}/python/src" +) diff --git a/src/unittest/grpc_server_test.cpp b/src/unittest/grpc_server_test.cpp new file mode 100644 index 0000000..9b307d1 --- /dev/null +++ b/src/unittest/grpc_server_test.cpp @@ -0,0 +1,284 @@ +#include "milvus_service_impl.h" +#include "pb/milvus.pb.h" +#include "test_util.h" +#include +#include +#include +#include 
"status.h" + +namespace milvus::local { +namespace test { + +const char* tmp_db_name = "server_test.db"; + +TEST(MilvusServiceImplTest, create_collection) { + const char* collection_name = "test_collection"; + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status response; + // drop + auto drop_r = GetDropCollectionRequest(collection_name); + service.DropCollection(&server_context, &drop_r, &response); + + auto r = GetCreateCollectionRequestProto(collection_name); + service.CreateCollection(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + service.CreateCollection(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + } + + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status response; + + // collection alread exists + auto r = GetCreateCollectionRequestProto(collection_name); + service.CreateCollection(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + + auto lr = GetLoadCollectionRequestProto(collection_name); + service.LoadCollection(&server_context, &lr, &response); + EXPECT_EQ(response.code(), 0); + + auto new_lr = GetLoadCollectionRequestProto("not_exist"); + service.LoadCollection(&server_context, &new_lr, &response); + EXPECT_EQ(response.code(), ErrCollectionNotFound); + } +} + +TEST(MilvusServiceImplTest, CreateIndex) { + const char* collection_name = "test_collection"; + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status response; + auto drop_r = GetDropCollectionRequest(collection_name); + service.DropCollection(&server_context, &drop_r, &response); + auto r = GetCreateCollectionRequestProto(collection_name); + service.CreateCollection(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + + { + auto r = + GetCreateIndexRequestProto(collection_name, "test_index", VEC_NAME); + service.CreateIndex(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + + service.CreateIndex(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + } + + { + auto r = + GetCreateIndexRequestProto("not_exist", "test_index", VEC_NAME); + service.CreateIndex(&server_context, &r, &response); + EXPECT_EQ(response.code(), ErrCollectionNotFound); + } +} + +TEST(MilvusServiceImplTest, Insert) { + const char* collection_name = "test_collection"; + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status response; + auto drop_r = GetDropCollectionRequest(collection_name); + service.DropCollection(&server_context, &drop_r, &response); + auto r = GetCreateCollectionRequestProto(collection_name); + service.CreateCollection(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + + { + auto insert_requst = GetInsertRequestProto(collection_name, 3); + ::milvus::proto::milvus::MutationResult insert_response; + service.Insert(&server_context, &insert_requst, &insert_response); + EXPECT_EQ(insert_response.insert_cnt(), 3); + EXPECT_EQ(insert_response.status().code(), 0); + } +} + +TEST(MilvusServiceImplTest, Search) { + const char* collection_name = "test_collection"; + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext 
server_context; + ::milvus::proto::common::Status response; + auto drop_r = GetDropCollectionRequest(collection_name); + service.DropCollection(&server_context, &drop_r, &response); + auto r = GetCreateCollectionRequestProto(collection_name); + service.CreateCollection(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + auto insert_requst = GetInsertRequestProto(collection_name, 3); + ::milvus::proto::milvus::MutationResult insert_response; + service.Insert(&server_context, &insert_requst, &insert_response); + EXPECT_EQ(insert_response.insert_cnt(), 3); + EXPECT_EQ(insert_response.status().code(), 0); + + auto index_req = + GetCreateIndexRequestProto(collection_name, "test_index", VEC_NAME); + service.CreateIndex(&server_context, &index_req, &response); + EXPECT_EQ(response.code(), 0); + + auto search_req = GetSearchRequestProto( + collection_name, + "id in [1, 2, 3]", + std::vector>{{0.1, 0.3, 0.6}, {0.3, 0.3, 0.4}}, + "2", + "COSINE", + "1"); + ::milvus::proto::milvus::SearchResults search_result; + service.Search(&server_context, &search_req, &search_result); + } + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status response; + auto lr = GetLoadCollectionRequestProto(collection_name); + service.LoadCollection(&server_context, &lr, &response); + EXPECT_EQ(response.code(), 0); + auto search_req = GetSearchRequestProto( + collection_name, + "id in [1, 2, 3]", + std::vector>{{0.1, 0.3, 0.6}, {0.3, 0.3, 0.4}}, + "2", + "COSINE", + "1"); + ::milvus::proto::milvus::SearchResults search_result; + service.Search(&server_context, &search_req, &search_result); + } +} + +TEST(MilvusServiceImplTest, Query) { + const char* collection_name = "test_collection"; + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status response; + auto drop_r = GetDropCollectionRequest(collection_name); + service.DropCollection(&server_context, &drop_r, &response); + auto r = GetCreateCollectionRequestProto(collection_name); + service.CreateCollection(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + auto insert_requst = GetInsertRequestProto(collection_name, 3); + ::milvus::proto::milvus::MutationResult insert_response; + service.Insert(&server_context, &insert_requst, &insert_response); + EXPECT_EQ(insert_response.insert_cnt(), 3); + EXPECT_EQ(insert_response.status().code(), 0); + + auto query_req = GetQueryRequestProto(collection_name, + "id in [1, 2, 3]", + "2", + "0", + std::vector{"id"}); + ::milvus::proto::milvus::QueryResults query_result; + service.Query(&server_context, &query_req, &query_result); + EXPECT_EQ( + query_result.fields_data()[0].scalars().long_data().data_size(), 2); + } +} + +TEST(MilvusServiceImplTest, Delete) { + const char* collection_name = "test_collection"; + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status response; + auto drop_r = GetDropCollectionRequest(collection_name); + service.DropCollection(&server_context, &drop_r, &response); + auto r = GetCreateCollectionRequestProto(collection_name); + service.CreateCollection(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + auto insert_requst = GetInsertRequestProto(collection_name, 10); + ::milvus::proto::milvus::MutationResult insert_response; + 
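+ // Insert ten rows, delete two by primary key, then query ids [1, 2, 3] and expect a single remaining hit.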
service.Insert(&server_context, &insert_requst, &insert_response); + EXPECT_EQ(insert_response.insert_cnt(), 10); + EXPECT_EQ(insert_response.status().code(), 0); + + auto delete_req = + GetDeleteRequestProto(collection_name, "id in [1, 2]"); + ::milvus::proto::milvus::MutationResult delete_result; + service.Delete(&server_context, &delete_req, &delete_result); + EXPECT_EQ(delete_result.delete_cnt(), 2); + + auto query_req = GetQueryRequestProto(collection_name, + "id in [1, 2, 3]", + "3", + "0", + std::vector{"id"}); + ::milvus::proto::milvus::QueryResults query_result; + service.Query(&server_context, &query_req, &query_result); + EXPECT_EQ( + query_result.fields_data()[0].scalars().long_data().data_size(), 1); + } + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status response; + auto lr = GetLoadCollectionRequestProto(collection_name); + service.LoadCollection(&server_context, &lr, &response); + EXPECT_EQ(response.code(), 0); + auto query_req = GetQueryRequestProto(collection_name, + "id in [1, 2, 3]", + "3", + "0", + std::vector{"id"}); + ::milvus::proto::milvus::QueryResults query_result; + service.Query(&server_context, &query_req, &query_result); + EXPECT_EQ( + query_result.fields_data()[0].scalars().long_data().data_size(), 1); + } +} + +TEST(MilvusServiceImplTest, describe_collection) { + const char* collection_name = "test_collection"; + + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status drop_res; + auto drop_r = GetDropCollectionRequest(collection_name); + service.DropCollection(&server_context, &drop_r, &drop_res); + ::milvus::proto::milvus::DescribeCollectionResponse res; + auto r = GetDescribeCollectionRequest(collection_name); + service.DescribeCollection(&server_context, &r, &res); + EXPECT_EQ(res.status().code(), ErrCollectionNotFound); + } + + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::common::Status response; + auto r = GetCreateCollectionRequestProto(collection_name); + service.CreateCollection(&server_context, &r, &response); + EXPECT_EQ(response.code(), 0); + + ::milvus::proto::milvus::DescribeCollectionResponse res; + auto dr = GetDescribeCollectionRequest(collection_name); + service.DescribeCollection(&server_context, &dr, &res); + EXPECT_EQ(res.status().code(), 0); + } + + { + ::milvus::local::MilvusServiceImpl service(tmp_db_name); + EXPECT_TRUE(service.Init()); + ::grpc::ServerContext server_context; + ::milvus::proto::milvus::DescribeCollectionResponse res; + auto dr = GetDescribeCollectionRequest(collection_name); + service.DescribeCollection(&server_context, &dr, &res); + EXPECT_EQ(res.status().code(), 0); + } +} + +} // namespace test +} // namespace milvus::local diff --git a/src/unittest/milvus_local_test.cpp b/src/unittest/milvus_local_test.cpp new file mode 100644 index 0000000..ec95af4 --- /dev/null +++ b/src/unittest/milvus_local_test.cpp @@ -0,0 +1,154 @@ +#include "milvus_local.h" +#include +#include +#include +#include +#include +#include "antlr4-runtime.h" +#include "parser/parser.h" +#include "parser/utils.h" +#include "pb/plan.pb.h" +#include "pb/schema.pb.h" +#include "pb/segcore.pb.h" +#include "test_util.h" +#include "type.h" + +namespace milvus::local { + +namespace test { + +TEST(MilvusLocal, h) { + 
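+ // End-to-end exercise of MilvusLocal: create a collection and vector index, insert rows, run a retrieve plan and an ANN search plan built from a parsed expression, delete by ids, then reload the collection from disk.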
std::string collection_name("test_schema"); + milvus::local::Rows data = CreateData(10); + auto schema_str = CreateCollection(); + auto index_str = CreateVectorIndex(); + + std::remove("milvus_data.db"); + { + MilvusLocal ms("./"); + ms.Init(); + ms.CreateCollection(collection_name, PK_NAME, schema_str); + ms.CreateIndex(collection_name, "test_index", index_str); + auto rows = CreateData(20); + std::vector ids; + ms.Insert(collection_name, rows, &ids); + milvus::proto::schema::CollectionSchema schema; + schema.ParseFromString(schema_str); + + // std::cout << schema.DebugString() << std::endl; + std::string exprstr("id in [1, 2, 6, 5, 8, 9]"); + antlr4::ANTLRInputStream input(exprstr); + PlanLexer lexer(&input); + antlr4::CommonTokenStream tokens(&lexer); + PlanParser parser(&tokens); + PlanParser::ExprContext* tree = parser.expr(); + auto helper = milvus::CreateSchemaHelper(&schema); + milvus::PlanCCVisitor visitor(&helper); + auto res = std::any_cast(visitor.visit(tree)); + { + ::milvus::proto::plan::PlanNode plan; + plan.mutable_query()->set_is_count(false); + plan.mutable_query()->set_limit(5); + plan.mutable_query()->set_allocated_predicates(res.expr); + plan.add_output_field_ids(200); + plan.add_output_field_ids(202); + std::cout << plan.DebugString() << std::endl; + RetrieveResult result; + ms.Retrieve(collection_name, plan.SerializeAsString(), &result); + milvus::proto::segcore::RetrieveResults rs; + rs.ParseFromArray(result.retrieve_result_.proto_blob, + result.retrieve_result_.proto_size); + std::cout << rs.DebugString() << std::endl; + } + + { + std::cout << "===============================" << std::endl; + ::milvus::proto::plan::PlanNode plan; + plan.mutable_vector_anns()->set_field_id(201); + plan.mutable_vector_anns()->set_allocated_predicates(res.expr); + plan.mutable_vector_anns()->set_placeholder_tag("$0"); + plan.mutable_vector_anns()->set_vector_type( + ::milvus::proto::plan::VectorType::FloatVector); + plan.mutable_vector_anns()->mutable_query_info()->set_topk(3); + plan.mutable_vector_anns()->mutable_query_info()->set_metric_type( + "IP"); + plan.mutable_vector_anns()->mutable_query_info()->set_search_params( + "{\"nprobe\": 10}"); + plan.mutable_vector_anns()->mutable_query_info()->set_round_decimal( + -1); + plan.add_output_field_ids(200); + plan.add_output_field_ids(202); + std::cout << plan.DebugString() << std::endl; + + milvus::proto::common::PlaceholderGroup raw_group; + auto value = raw_group.add_placeholders(); + value->set_tag("$0"); + value->set_type( + milvus::proto::common::PlaceholderType::FloatVector); + std::vector vec{0.3, 0.5, 0.2}; + value->add_values(vec.data(), vec.size() * sizeof(float)); + + auto slice_nqs = std::vector{1}; + auto slice_topKs = std::vector{3}; + SearchResult result(slice_nqs, slice_topKs); + ms.Search(collection_name, + plan.SerializeAsString(), + raw_group.SerializeAsString(), + &result); + milvus::proto::schema::SearchResultData rz; + rz.ParseFromArray(result.result_[0].proto_blob, + result.result_[0].proto_size); + // std::cout << rz.DebugString() << std::endl; + + milvus::proto::schema::IDs ids; + ids.mutable_int_id()->add_data(0); + ids.mutable_int_id()->add_data(1); + ids.mutable_int_id()->add_data(2); + ms.DeleteByIds(collection_name, + ids.SerializeAsString(), + 3, + std::vector{"0", "1", "2"}); + ms.Search(collection_name, + plan.SerializeAsString(), + raw_group.SerializeAsString(), + &result); + rz.ParseFromArray(result.result_[0].proto_blob, + result.result_[0].proto_size); + // std::cout << rz.DebugString() << 
std::endl; + + ms.ReleaseCollection(collection_name); + } + } + + { + MilvusLocal ms("./"); + ms.Init(); + ms.LoadCollection(collection_name); + ms.ReleaseCollection(collection_name); + } + + std::remove("milvus_data.db"); +} + +TEST(MilvusLocal, parser) { + auto schema_str = CreateCollection(); + milvus::proto::schema::CollectionSchema schema; + schema.ParseFromString(schema_str); + + // std::string exprstr("sc in [1, 2, 3, 4]"); + // antlr4::ANTLRInputStream input(exprstr); + // PlanLexer lexer(&input); + // antlr4::CommonTokenStream tokens(&lexer); + // PlanParser parser(&tokens); + + // PlanParser::ExprContext* tree = parser.expr(); + + // auto helper = milvus::CreateSchemaHelper(&schema); + // milvus::PlanCCVisitor visitor(&helper); + // auto res = std::any_cast(visitor.visit(tree)); + // std::cout << res.expr->DebugString() << std::endl; + // std::cout << "---------------" << std::endl; + // std::cout << res.expr->term_expr().DebugString() << std::endl; +} +} // namespace test +} // namespace milvus::local diff --git a/src/unittest/milvus_proxy_test.cpp b/src/unittest/milvus_proxy_test.cpp new file mode 100644 index 0000000..01b8e61 --- /dev/null +++ b/src/unittest/milvus_proxy_test.cpp @@ -0,0 +1,203 @@ +#include "milvus_proxy.h" +#include +#include +#include +#include "log/Log.h" +#include "pb/milvus.pb.h" +#include "pb/schema.pb.h" +#include "test_util.h" +#include + +namespace milvus::local { +namespace test { + +const char* tmp_db_name = "test.db"; + +TEST(MilvusProxyTest, CreateCollection) { + const char* collection_name = "test_collection"; + { + // create new collection + ::milvus::local::MilvusProxy proxy(tmp_db_name); + EXPECT_TRUE(proxy.Init()); + EXPECT_FALSE(proxy.Init()); + proxy.DropCollection(collection_name); + auto cr = GetCreateCollectionRequestProto(collection_name); + EXPECT_TRUE(proxy.CreateCollection(&cr).IsOk()); + } + + { + // load collection + ::milvus::local::MilvusProxy proxy(tmp_db_name); + EXPECT_TRUE(proxy.Init()); + EXPECT_TRUE(proxy.LoadCollection(collection_name).IsOk()); + // reload is ok + EXPECT_TRUE(proxy.LoadCollection(collection_name).IsOk()); + + EXPECT_TRUE(proxy.ReleaseCollection(collection_name).IsOk()); + EXPECT_TRUE(proxy.ReleaseCollection(collection_name).IsOk()); + + EXPECT_FALSE(proxy.LoadCollection("not_existed").IsOk()); + } +} + +TEST(MilvusProxyTest, CreateIndex) { + const char* collection_name = "test_collection"; + const char* index_name = "test_index"; + + ::milvus::local::MilvusProxy proxy(tmp_db_name); + EXPECT_TRUE(proxy.Init()); + proxy.DropCollection(collection_name); + auto cr = GetCreateCollectionRequestProto(collection_name); + EXPECT_TRUE(proxy.CreateCollection(&cr).IsOk()); + + { + // create new index + auto index_req = + GetCreateIndexRequestProto(collection_name, index_name, VEC_NAME); + EXPECT_TRUE(proxy.CreateIndex(&index_req).IsOk()); + + EXPECT_TRUE(proxy.CreateIndex(&index_req).IsOk()); + } +} + +TEST(MilvusProxyTest, Insert) { + const char* collection_name = "test_collection"; + ::milvus::local::MilvusProxy proxy(tmp_db_name); + EXPECT_TRUE(proxy.Init()); + proxy.DropCollection(collection_name); + auto cr = GetCreateCollectionRequestProto(collection_name); + EXPECT_TRUE(proxy.CreateCollection(&cr).IsOk()); + + { + auto data = GetInsertRequestProto(collection_name, 3); + ::milvus::proto::schema::IDs ids; + proxy.Insert(&data, &ids); + EXPECT_EQ(3, ids.int_id().data_size()); + } +} + +TEST(MilvusProxyTest, search) { + const char* collection_name = "test_collection"; + { + ::milvus::local::MilvusProxy 
+        EXPECT_TRUE(proxy.Init());
+        proxy.DropCollection(collection_name);
+        auto cr = GetCreateCollectionRequestProto(collection_name);
+        EXPECT_TRUE(proxy.CreateCollection(&cr).IsOk());
+        auto data = GetInsertRequestProto(collection_name, 10);
+        ::milvus::proto::schema::IDs ids;
+        proxy.Insert(&data, &ids);
+        EXPECT_EQ(10, ids.int_id().data_size());
+        auto index_req =
+            GetCreateIndexRequestProto(collection_name, "vindex", VEC_NAME);
+        EXPECT_TRUE(proxy.CreateIndex(&index_req).IsOk());
+
+        auto search_req = GetSearchRequestProto(
+            collection_name,
+            "id in [1, 2, 3]",
+            std::vector<std::vector<float>>{{0.1, 0.3, 0.6}, {0.3, 0.3, 0.4}},
+            "2",
+            "COSINE",
+            "1");
+        ::milvus::proto::milvus::SearchResults search_result;
+        EXPECT_TRUE(proxy.Search(&search_req, &search_result).IsOk());
+    }
+
+    {
+        ::milvus::local::MilvusProxy proxy(tmp_db_name);
+        EXPECT_TRUE(proxy.Init());
+        EXPECT_TRUE(proxy.LoadCollection(collection_name).IsOk());
+        auto search_req = GetSearchRequestProto(
+            collection_name,
+            "id in [1, 2, 3]",
+            std::vector<std::vector<float>>{{0.1, 0.3, 0.6}, {0.3, 0.3, 0.4}},
+            "2",
+            "COSINE",
+            "1");
+        ::milvus::proto::milvus::SearchResults search_result;
+        EXPECT_TRUE(proxy.Search(&search_req, &search_result).IsOk());
+    }
+}
+
+TEST(MilvusProxyTest, query) {
+    const char* collection_name = "test_collection";
+    {
+        ::milvus::local::MilvusProxy proxy(tmp_db_name);
+        EXPECT_TRUE(proxy.Init());
+        proxy.DropCollection(collection_name);
+        auto cr = GetCreateCollectionRequestProto(collection_name);
+        EXPECT_TRUE(proxy.CreateCollection(&cr).IsOk());
+        auto data = GetInsertRequestProto(collection_name, 10);
+        ::milvus::proto::schema::IDs ids;
+        proxy.Insert(&data, &ids);
+        EXPECT_EQ(10, ids.int_id().data_size());
+        auto query_req = GetQueryRequestProto(collection_name,
+                                              "id in [1, 2, 3]",
+                                              "2",
+                                              "0",
+                                              std::vector<std::string>{"id"});
+        ::milvus::proto::milvus::QueryResults query_result;
+        EXPECT_TRUE(proxy.Query(&query_req, &query_result).IsOk());
+
+        EXPECT_EQ(
+            query_result.fields_data()[0].scalars().long_data().data_size(), 2);
+    }
+
+    {
+        ::milvus::local::MilvusProxy proxy(tmp_db_name);
+        EXPECT_TRUE(proxy.Init());
+        EXPECT_TRUE(proxy.LoadCollection(collection_name).IsOk());
+        auto query_req = GetQueryRequestProto(
+            collection_name, "id==1", "2", "0", std::vector<std::string>{"id"});
+        ::milvus::proto::milvus::QueryResults query_result;
+        EXPECT_TRUE(proxy.Query(&query_req, &query_result).IsOk());
+        query_result.PrintDebugString();
+    }
+}
+
+TEST(MilvusProxyTest, delete) {
+    const char* collection_name = "test_collection";
+    {
+        ::milvus::local::MilvusProxy proxy(tmp_db_name);
+        EXPECT_TRUE(proxy.Init());
+        proxy.DropCollection(collection_name);
+        auto cr = GetCreateCollectionRequestProto(collection_name);
+        EXPECT_TRUE(proxy.CreateCollection(&cr).IsOk());
+        auto data = GetInsertRequestProto(collection_name, 10);
+        ::milvus::proto::schema::IDs ids;
+        proxy.Insert(&data, &ids);
+        EXPECT_EQ(10, ids.int_id().data_size());
+
+        auto delete_req =
+            GetDeleteRequestProto(collection_name, "id in [1, 2]");
+        ::milvus::proto::milvus::MutationResult response;
+        EXPECT_TRUE(proxy.Delete(&delete_req, &response).IsOk());
+
+        auto query_req = GetQueryRequestProto(collection_name,
+                                              "id in [1, 2, 3]",
+                                              "3",
+                                              "0",
+                                              std::vector<std::string>{"id"});
+        ::milvus::proto::milvus::QueryResults query_result;
+        EXPECT_TRUE(proxy.Query(&query_req, &query_result).IsOk());
+        EXPECT_EQ(
+            query_result.fields_data()[0].scalars().long_data().data_size(), 1);
+    }
+    {
+        ::milvus::local::MilvusProxy proxy(tmp_db_name);
+        EXPECT_TRUE(proxy.Init());
+        proxy.LoadCollection(collection_name);
+        auto query_req = GetQueryRequestProto(collection_name,
+                                              "id in [1, 2, 3]",
+                                              "3",
+                                              "0",
+                                              std::vector<std::string>{"id"});
+        ::milvus::proto::milvus::QueryResults query_result;
+        EXPECT_TRUE(proxy.Query(&query_req, &query_result).IsOk());
+        EXPECT_EQ(
+            query_result.fields_data()[0].scalars().long_data().data_size(), 1);
+    }
+}
+
+} // namespace test
+} // namespace milvus::local
diff --git a/src/unittest/run_examples.py b/src/unittest/run_examples.py
new file mode 100644
index 0000000..97787bf
--- /dev/null
+++ b/src/unittest/run_examples.py
@@ -0,0 +1,25 @@
+import os
+import sys
+import pathlib
+import subprocess
+
+
+def run_all_examples():
+    examples_dir = pathlib.Path(__file__).absolute().parent.parent.parent / 'examples'
+    for f in examples_dir.glob('*.py'):
+        if str(f).endswith('bfloat16_example.py') or str(f).endswith('dynamic_field.py'):
+            continue
+        print(str(f))
+        p = subprocess.Popen(args=[sys.executable, str(f)])
+        p.wait()
+        if p.returncode != 0:
+            return False
+    return True
+
+
+if __name__ == '__main__':
+    if not run_all_examples():
+        exit(-1)
+    exit(0)
+
+
diff --git a/src/unittest/storage_test.cpp b/src/unittest/storage_test.cpp
new file mode 100644
index 0000000..396a93d
--- /dev/null
+++ b/src/unittest/storage_test.cpp
@@ -0,0 +1,39 @@
+#include "storage.h"
+#include <gtest/gtest.h>
+#include "test_util.h"
+
+namespace milvus::local {
+
+TEST(Storage, h) {
+    // auto schema_str = create_test_collection();
+    // auto index_str = create_test_index();
+
+    // ::milvus::proto::msg::InsertRequest insert;
+    // auto row_data = insert.add_row_data();
+
+    // const char* db_path = "test.db";
+    // {
+    //     Storage s(db_path);
+    //     s.open();
+    //     s.create_collection("test", schema_str);
+    //     s.create_index("test", "test_index", index_str);
+    // }
+
+    // {
+    //     Storage s(db_path);
+    //     s.open();
+    //     std::string schema;
+    //     std::string index;
+    //     s.get_collection_schema("test", schema);
+    //     s.get_index("test", "test_index", &index);
+    //     ::milvus::proto::schema::CollectionSchema sc;
+    //     sc.ParseFromString(schema);
+
+    //     milvus::proto::segcore::FieldIndexMeta index_mt;
+    //     index_mt.ParseFromString(index);
+    //     std::cout << index_mt.index_name();
+    // }
+    // std::remove(db_path);
+}
+
+} // namespace milvus::local
diff --git a/src/unittest/test_util.cpp b/src/unittest/test_util.cpp
new file mode 100644
index 0000000..c432ee1
--- /dev/null
+++ b/src/unittest/test_util.cpp
@@ -0,0 +1,283 @@
+#include "test_util.h"
+#include <cstdint>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "common.h"
+#include "log/Log.h"
+#include "type.h"
+#include "pb/milvus.pb.h"
+#include "pb/msg.pb.h"
+#include "pb/schema.pb.h"
+#include "pb/segcore.pb.h"
+
+namespace milvus::local {
+namespace test {
+
+std::string
+CreateCollection(const std::string& collection_name) {
+    ::milvus::proto::schema::CollectionSchema schema;
+
+    schema.set_name(collection_name);
+    schema.set_enable_dynamic_field(false);
+    auto field1 = schema.add_fields();
+    field1->set_fieldid(PK_ID);
+    field1->set_name(PK_NAME);
+    field1->set_is_primary_key(true);
+    field1->set_data_type(::milvus::proto::schema::DataType::Int64);
+
+    auto field2 = schema.add_fields();
+    field2->set_fieldid(VEC_ID);
+    field2->set_name(VEC_NAME);
+    field2->set_data_type(::milvus::proto::schema::DataType::FloatVector);
+    auto params = field2->add_type_params();
+    params->set_key(VEC_DIM_NAME);
+    params->set_value(std::to_string(VEC_DIM));
+
+    auto field3 = schema.add_fields();
+    field3->set_fieldid(SCALAR_ID);
+    field3->set_name(SCALAR_NAME);
+    field3->set_data_type(::milvus::proto::schema::DataType::Int32);
+    return schema.SerializeAsString();
+}
+
+std::string
+CreateVectorIndex() {
+    milvus::proto::segcore::CollectionIndexMeta index_meta;
+    index_meta.set_maxindexrowcount(1000000);
+    auto field_meta = index_meta.add_index_metas();
+    field_meta->set_index_name("vec_index");
+    field_meta->set_fieldid(VEC_ID);
+    field_meta->set_collectionid(0);
+    field_meta->set_is_auto_index(true);
+
+    auto pair = field_meta->add_index_params();
+    pair->set_key("metric_type");
+    pair->set_value("IP");
+
+    return index_meta.SerializeAsString();
+}
+
+milvus::local::Rows
+CreateData(int32_t count) {
+    milvus::local::Rows rs;
+    for (int64_t i = 0; i < count; i++) {
+        ::milvus::proto::segcore::InsertRecord r;
+        r.set_num_rows(1);
+
+        // set pk
+        ::milvus::proto::schema::FieldData* pk = r.add_fields_data();
+        pk->set_field_id(PK_ID);
+        pk->set_field_name(PK_NAME);
+        pk->set_type(::milvus::proto::schema::Int64);
+        pk->mutable_scalars()->mutable_long_data()->add_data(i);
+
+        // set vec
+        ::milvus::proto::schema::FieldData* vec = r.add_fields_data();
+        vec->set_field_id(VEC_ID);
+        vec->set_field_name(VEC_NAME);
+        vec->set_type(::milvus::proto::schema::FloatVector);
+        auto v = vec->mutable_vectors();
+        v->set_dim(VEC_DIM);
+        auto vd = v->mutable_float_vector();
+        vd->add_data(0.1);
+        vd->add_data(0.5);
+        vd->add_data(0.4);
+
+        // set scalar
+        ::milvus::proto::schema::FieldData* sc = r.add_fields_data();
+        sc->set_field_id(SCALAR_ID);
+        sc->set_field_name(SCALAR_NAME);
+        sc->set_type(::milvus::proto::schema::Int32);
+        sc->mutable_scalars()->mutable_int_data()->add_data(i);
+        rs.push_back(std::make_tuple(std::to_string(i), r.SerializeAsString()));
+    }
+    return rs;
+}
+
+::milvus::proto::milvus::CreateCollectionRequest
+GetCreateCollectionRequestProto(const std::string& collection_name) {
+    ::milvus::proto::milvus::CreateCollectionRequest r;
+    r.set_collection_name(collection_name);
+    auto schema_str = CreateCollection(collection_name);
+    r.set_schema(schema_str.data());
+    r.set_consistency_level(::milvus::proto::common::ConsistencyLevel::Strong);
+    return r;
+}
+
+::milvus::proto::milvus::LoadCollectionRequest
+GetLoadCollectionRequestProto(const std::string& collection_name) {
+    ::milvus::proto::milvus::LoadCollectionRequest r;
+    r.set_collection_name(collection_name);
+    return r;
+}
+
+::milvus::proto::milvus::CreateIndexRequest
+GetCreateIndexRequestProto(const std::string& collection_name,
+                           const std::string& index_name,
+                           const std::string& field_name) {
+    ::milvus::proto::milvus::CreateIndexRequest r;
+    r.set_index_name(index_name);
+    r.set_collection_name(collection_name);
+    r.set_field_name(field_name);
+
+    auto p1 = r.add_extra_params();
+    p1->set_key("params");
+    p1->set_value("{}");
+    auto p2 = r.add_extra_params();
+    p2->set_key("metric_type");
+    p2->set_value("IP");
+
+    auto p3 = r.add_extra_params();
+    p3->set_key("index_type");
+    p3->set_value("AUTOINDEX");
+
+    return r;
+}
+
+::milvus::proto::milvus::InsertRequest
+GetInsertRequestProto(const std::string& collection_name, int64_t row_num) {
+    ::milvus::proto::milvus::InsertRequest r;
+    r.set_collection_name(collection_name);
+    r.set_num_rows(row_num);
+
+    // set pk
+    ::milvus::proto::schema::FieldData* pk = r.add_fields_data();
+    pk->set_field_id(PK_ID);
+    pk->set_field_name(PK_NAME);
+    pk->set_type(::milvus::proto::schema::Int64);
+    for (int64_t i = 0; i < row_num; ++i) {
+        pk->mutable_scalars()->mutable_long_data()->add_data(i);
+    }
+
+    // set vec
+    ::milvus::proto::schema::FieldData* vec = r.add_fields_data();
+    vec->set_field_id(VEC_ID);
+    vec->set_field_name(VEC_NAME);
+    vec->set_type(::milvus::proto::schema::FloatVector);
+    auto v = vec->mutable_vectors();
+    v->set_dim(VEC_DIM);
+    auto vd = v->mutable_float_vector();
+    for (int64_t i = 0; i < row_num; ++i) {
+        vd->add_data(0.1 * i);
+        vd->add_data(0.5 * i);
+        vd->add_data(0.4 * i);
+    }
+
+    // set scalar
+    ::milvus::proto::schema::FieldData* sc = r.add_fields_data();
+    sc->set_field_id(SCALAR_ID);
+    sc->set_field_name(SCALAR_NAME);
+    sc->set_type(::milvus::proto::schema::Int32);
+    for (int64_t i = 0; i < row_num; ++i) {
+        sc->mutable_scalars()->mutable_int_data()->add_data(i);
+    }
+
+    return r;
+}
+
+::milvus::proto::milvus::SearchRequest
+GetSearchRequestProto(const std::string& collection_name,
+                      const std::string& expr,
+                      const std::vector<std::vector<float>>& vecs,
+                      const std::string& topk,
+                      const std::string& metric_type,
+                      const std::string& offset) {
+    ::milvus::proto::milvus::SearchRequest r;
+    r.set_collection_name(collection_name);
+    r.set_dsl(expr);
+    r.set_dsl_type(::milvus::proto::common::DslType::BoolExprV1);
+    r.mutable_output_fields()->Add(SCALAR_NAME);
+    r.mutable_output_fields()->Add(VEC_NAME);
+
+    milvus::proto::common::PlaceholderGroup raw_group;
+    auto value = raw_group.add_placeholders();
+    value->set_tag(milvus::local::kPlaceholderTag);
+    value->set_type(milvus::proto::common::PlaceholderType::FloatVector);
+    for (const auto& vec : vecs) {
+        value->add_values(vec.data(), vec.size() * sizeof(float));
+    }
+
+    r.set_placeholder_group(raw_group.SerializeAsString());
+    r.set_nq(vecs.size());
+
+    auto p1 = r.mutable_search_params()->Add();
+    // p1->set_key("search_param");
+    // p1->set_value("{\"nprobe\":10}");
+    p1->set_key("params");
+    p1->set_value("{\"nprobe\":10}");
+
+    auto p2 = r.mutable_search_params()->Add();
+    p2->set_key("round_decimal");
+    p2->set_value("-1");
+
+    auto p3 = r.mutable_search_params()->Add();
+    p3->set_key("ignore_growing");
+    p3->set_value("False");
+
+    auto p4 = r.mutable_search_params()->Add();
+    p4->set_key("topk");
+    p4->set_value(topk);
+
+    // auto p5 = r.mutable_search_params()->Add();
+    // p5->set_key("metric_type");
+    // p5->set_value(metric_type);
+
+    // auto p6 = r.mutable_search_params()->Add();
+    // p6->set_key("offset");
+    // p6->set_value(offset);
+    return r;
+}
+
+::milvus::proto::milvus::QueryRequest
+GetQueryRequestProto(const std::string& collection_name,
+                     const std::string& expr,
+                     const std::string& limit,
+                     const std::string& offset,
+                     const std::vector<std::string>& output_fields) {
+    ::milvus::proto::milvus::QueryRequest r;
+    r.set_collection_name(collection_name);
+    r.set_expr(expr);
+    auto p1 = r.mutable_query_params()->Add();
+    p1->set_key("limit");
+    p1->set_value(limit);
+
+    auto p2 = r.mutable_query_params()->Add();
+    p2->set_key("reduce_stop_for_best");
+    p2->set_value("False");
+
+    auto p3 = r.mutable_query_params()->Add();
+    p3->set_key("ignore_growing");
+    p3->set_value("False");
+
+    auto p4 = r.mutable_query_params()->Add();
+    p4->set_key("offset");
+    p4->set_value(offset);
+    r.set_guarantee_timestamp(1);
+    r.set_use_default_consistency(true);
+    for (const auto& f : output_fields) {
+        r.add_output_fields(f);
+    }
+    return r;
+}
+
+::milvus::proto::milvus::DeleteRequest
+GetDeleteRequestProto(const std::string& collection_name,
+                      const std::string& expr) {
+    ::milvus::proto::milvus::DeleteRequest r;
+    r.set_collection_name(collection_name);
+    r.set_expr(expr);
+    return r;
+}
+
+::milvus::proto::milvus::DescribeCollectionRequest
+GetDescribeCollectionRequest(const std::string& collection_name) {
+    ::milvus::proto::milvus::DescribeCollectionRequest r;
+    r.set_collection_name(collection_name);
+    return r;
+}
+
+} // namespace test
+} // namespace milvus::local
diff --git a/src/unittest/test_util.h b/src/unittest/test_util.h
new file mode 100644
index 0000000..eebf081
--- /dev/null
+++ b/src/unittest/test_util.h
@@ -0,0 +1,75 @@
+#pragma once
+#include <cstdint>
+#include <string>
+#include <vector>
+#include "type.h"
+#include "pb/milvus.pb.h"
+
+namespace milvus::local {
+namespace test {
+
+#define PK_NAME "id"
+#define PK_ID 200
+
+#define VEC_NAME "vec"
+#define VEC_ID 201
+#define VEC_DIM_NAME "dim"
+#define VEC_DIM 3
+
+#define SCALAR_NAME "sc"
+#define SCALAR_ID 202
+
+std::string
+CreateCollection(const std::string& collection_name = "test_schema");
+
+std::string
+CreateVectorIndex();
+
+milvus::local::Rows
+CreateData(int32_t count);
+
+::milvus::proto::milvus::CreateCollectionRequest
+GetCreateCollectionRequestProto(const std::string& collection_name);
+
+::milvus::proto::milvus::LoadCollectionRequest
+GetLoadCollectionRequestProto(const std::string& collection_name);
+
+::milvus::proto::milvus::CreateIndexRequest
+GetCreateIndexRequestProto(const std::string& collection_name,
+                           const std::string& index_name,
+                           const std::string& field_name);
+
+::milvus::proto::milvus::InsertRequest
+GetInsertRequestProto(const std::string& collection_name, int64_t row_num);
+
+::milvus::proto::milvus::SearchRequest
+GetSearchRequestProto(const std::string& collection_name,
+                      const std::string& expr,
+                      const std::vector<std::vector<float>>& vecs,
+                      const std::string& topk,
+                      const std::string& metric_type,
+                      const std::string& offset);
+
+::milvus::proto::milvus::QueryRequest
+GetQueryRequestProto(const std::string& collection_name,
+                     const std::string& expr,
+                     const std::string& limit,
+                     const std::string& offset,
+                     const std::vector<std::string>& output_fields);
+
+::milvus::proto::milvus::DeleteRequest
+GetDeleteRequestProto(const std::string& collection_name,
+                      const std::string& expr);
+
+::milvus::proto::milvus::DescribeCollectionRequest
+GetDescribeCollectionRequest(const std::string& collection_name);
+
+inline ::milvus::proto::milvus::DropCollectionRequest
+GetDropCollectionRequest(const std::string& collection_name) {
+    ::milvus::proto::milvus::DropCollectionRequest request;
+    request.set_collection_name(collection_name);
+    return request;
+}
+
+} // namespace test
+} // namespace milvus::local
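Note for reviewers: the sketch below shows how the request builders declared in test_util.h are meant to compose with MilvusProxy in an additional test case. It is a minimal illustration only, assuming the same gtest setup and "test.db" path used in milvus_proxy_test.cpp above; the test name RoundTrip, the collection name, and the expected counts are hypothetical and not part of the patch.

// Hypothetical composition sketch (not part of the patch).
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "milvus_proxy.h"
#include "test_util.h"

namespace milvus::local {
namespace test {

TEST(MilvusProxyTest, RoundTrip) {  // illustrative test name
    ::milvus::local::MilvusProxy proxy("test.db");
    EXPECT_TRUE(proxy.Init());
    proxy.DropCollection("round_trip");

    // Build the protos with the helpers declared in test_util.h.
    auto create_req = GetCreateCollectionRequestProto("round_trip");
    EXPECT_TRUE(proxy.CreateCollection(&create_req).IsOk());

    auto index_req = GetCreateIndexRequestProto("round_trip", "vindex", VEC_NAME);
    EXPECT_TRUE(proxy.CreateIndex(&index_req).IsOk());

    auto insert_req = GetInsertRequestProto("round_trip", 5);
    ::milvus::proto::schema::IDs ids;
    proxy.Insert(&insert_req, &ids);
    EXPECT_EQ(5, ids.int_id().data_size());

    // Query the primary keys back; limit "10" keeps all five rows.
    auto query_req = GetQueryRequestProto("round_trip",
                                          "id in [0, 1, 2, 3, 4]",
                                          "10",
                                          "0",
                                          std::vector<std::string>{PK_NAME});
    ::milvus::proto::milvus::QueryResults result;
    EXPECT_TRUE(proxy.Query(&query_req, &result).IsOk());
    EXPECT_EQ(5, result.fields_data()[0].scalars().long_data().data_size());
}

} // namespace test
} // namespace milvus::local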
diff --git a/tests/test_milvus_config.py b/tests/test_milvus_config.py
deleted file mode 100644
index 85581d2..0000000
--- a/tests/test_milvus_config.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from io import StringIO
-from typing import Any
-
-import pytest
-from milvus import MilvusServerConfig
-import yaml
-
-
-def file_to_yaml(filepath: str) -> Any:
-    with open(filepath, encoding='utf-8') as stream:
-        return yaml.load(stream, Loader=yaml.FullLoader)
-
-
-@pytest.mark.parametrize('text, key, value', (
-    (
-        """
-a:
-  b:
-    c: 1
-  x: hello
-""", 'a.b.c', 2
-    ),
-    (
-        """
-x:
-  y: 123 ##
-a:
-  b:
-    c: 1 ##
-  x: hello
-""", 'a.x', False
-    )
-))
-def test_config_extra_config(tmpdir, text, key, value):
-    template_file = tmpdir.join("test.yaml")
-    template_file.write(text)
-    config = MilvusServerConfig(template=template_file.strpath)
-    config.load_template()
-    config.base_data_dir = tmpdir.strpath
-    config.set(key, value)
-    config.write_config()
-    milvus_config = tmpdir.join('configs', 'milvus.yaml')
-    data = file_to_yaml(milvus_config)
-    for w in key.split('.'):
-        assert w in data
-        data = data[w]
-    assert data == value
diff --git a/thirdparty/milvus b/thirdparty/milvus
new file mode 160000
index 0000000..ffb6edd
--- /dev/null
+++ b/thirdparty/milvus
@@ -0,0 +1 @@
+Subproject commit ffb6edd433856ba1dc0c4466582c8829711296ac
diff --git a/thirdparty/milvus-proto b/thirdparty/milvus-proto
new file mode 160000
index 0000000..e3012ae
--- /dev/null
+++ b/thirdparty/milvus-proto
@@ -0,0 +1 @@
+Subproject commit e3012ae615bc1fdb0908cccf44c54c8a7da1a222
diff --git a/thirdparty/milvus-storage b/thirdparty/milvus-storage
new file mode 160000
index 0000000..c23ba73
--- /dev/null
+++ b/thirdparty/milvus-storage
@@ -0,0 +1 @@
+Subproject commit c23ba736d7e6dcd21f7e6288525f706746329e8e
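Note for reviewers: both the hand-built PlaceholderGroup in milvus_local_test.cpp and GetSearchRequestProto in test_util.cpp rely on the same packing convention: each query vector is appended to a PlaceholderValue as raw bytes (vec.size() * sizeof(float) per vector), and nq is the number of values added under a tag that matches the plan's placeholder_tag. The standalone sketch below restates that convention; it assumes the generated pb/common.pb.h header (matching the pb/*.pb.h includes used in the tests), and the helper names PackFloatQuery and UnpackFirstVector are illustrative, not part of the patch.

#include <cstring>
#include <string>
#include <vector>
#include "pb/common.pb.h"

// Illustrative helper: pack float query vectors into a serialized
// PlaceholderGroup the same way the tests above do.
inline std::string
PackFloatQuery(const std::vector<std::vector<float>>& vecs,
               const std::string& tag) {
    milvus::proto::common::PlaceholderGroup group;
    auto* value = group.add_placeholders();
    value->set_tag(tag);  // must match the plan's placeholder_tag, e.g. "$0"
    value->set_type(milvus::proto::common::PlaceholderType::FloatVector);
    for (const auto& vec : vecs) {
        // One entry per query vector: dim * sizeof(float) raw bytes.
        value->add_values(vec.data(), vec.size() * sizeof(float));
    }
    return group.SerializeAsString();
}

// Reading a vector back is the reverse cast of the same byte layout.
inline std::vector<float>
UnpackFirstVector(const milvus::proto::common::PlaceholderGroup& group) {
    const std::string& bytes = group.placeholders(0).values(0);
    std::vector<float> vec(bytes.size() / sizeof(float));
    std::memcpy(vec.data(), bytes.data(), bytes.size());
    return vec;
}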