diff --git a/.bazelrc b/.bazelrc index 118c3bd2..7ba09399 100644 --- a/.bazelrc +++ b/.bazelrc @@ -43,7 +43,7 @@ build --copt=-fstack-protector-strong build:linux --copt=-Wl,-z,noexecstack build:macos --copt=-Wa,--noexecstack -test --keep_going +build --keep_going test --test_output=errors build:benchmark --copt -O3 @@ -55,11 +55,12 @@ build:linux --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-l%:libgcc.a # platform specific config # Bazel will automatic pick platform config since we have enable_platform_specific_config set -build:macos --copt="-Xpreprocessor -fopenmp" +build:macos --copt=-Xclang=-fopenmp build:macos --copt=-Wno-unused-command-line-argument build:macos --features=-supports_dynamic_linker -build:macos --macos_minimum_os=12.0 -build:macos --host_macos_minimum_os=12.0 +build:macos --macos_minimum_os=13.0 +build:macos --host_macos_minimum_os=13.0 +build:macos --action_env MACOSX_DEPLOYMENT_TARGET=13.0 build:linux --copt=-fopenmp build:linux --linkopt=-fopenmp diff --git a/.circleci/asan-config.yml b/.circleci/asan-config.yml index 33c6cf87..901cad5d 100644 --- a/.circleci/asan-config.yml +++ b/.circleci/asan-config.yml @@ -19,7 +19,7 @@ version: 2.1 parameters: run-asan: type: boolean - default: false + default: true # Define a job to be invoked later in a workflow. # See: https://circleci.com/docs/2.0/configuration-reference/#jobs @@ -55,7 +55,7 @@ jobs: command: | set +e declare -i test_status - bazel test //libspu/... --features=asan --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]} + bazel test //libspu/... --features=asan --test_timeout=500 --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]} sh ../devtools/rename-junit-xml.sh find bazel-testlogs/ -type f -name "test.log" -print0 | xargs -0 tar -cvzf test_logs.tar.gz diff --git a/.circleci/config.yml b/.circleci/config.yml index 41e5522e..c4617cd9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -17,7 +17,7 @@ version: 2.1 setup: true orbs: - path-filtering: circleci/path-filtering@1.0.0 + path-filtering: circleci/path-filtering@1.1.0 continuation: circleci/continuation@1.0.0 parameters: diff --git a/.circleci/continue-config.yml b/.circleci/continue-config.yml index df2d3faa..50ff328c 100644 --- a/.circleci/continue-config.yml +++ b/.circleci/continue-config.yml @@ -64,7 +64,7 @@ commands: ../devtools/bazel_cache_setup.py --in_file=../gcs.data --out_file=../gcs.json --min_download - run: name: "build" - command: bazel build <> -c opt --ui_event_filters=-info,-debug,-warning --jobs 20 + command: bazel build <> -c opt --ui_event_filters=-info,-debug,-warning - run: name: "test" command: | @@ -120,7 +120,7 @@ jobs: extra_bazel_test_args: --test_env LD_LIBRARY_PATH=/root/miniconda3/lib/ macOS_ut: macos: - xcode: 15.4.0 + xcode: 16.0.0 resource_class: macos.m1.large.gen1 steps: - checkout diff --git a/.circleci/release-config.yml b/.circleci/release-config.yml index c52b787d..762fa599 100644 --- a/.circleci/release-config.yml +++ b/.circleci/release-config.yml @@ -63,7 +63,7 @@ commands: jobs: macOS_publish: macos: - xcode: 15.4.0 + xcode: 16.0.0 resource_class: macos.m1.large.gen1 parameters: python_ver: diff --git a/.clang-tidy b/.clang-tidy index bacbb3a2..227a407d 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -28,7 +28,8 @@ Checks: "abseil-cleanup-ctad, -readability-identifier-length, -readability-function-cognitive-complexity, -readability-magic-numbers, - -readability-named-parameter" + -readability-named-parameter, + -readability-convert-member-functions-to-static" CheckOptions: - key: bugprone-argument-comment.StrictMode diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 5b25f5b2..3dadfd88 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -32,12 +32,12 @@ jobs: steps: - name: "Checkout code" - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@dc50aa9510b46c811795eb24b2f1ba02a914e534 # v2.3.3 + uses: ossf/scorecard-action@62b2cac7ed8198b15735ed49ab1e5cf35480ba46 # v2.4.0 with: results_file: results.sarif results_format: sarif @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@4fa2a7953630fd2f3fb380f21be14ede0169dd4f # v3.25.12 + uses: github/codeql-action/upload-sarif@df409f7d9260372bd5f19e5b04e83cb3c43714ae # v3.27.9 with: sarif_file: results.sarif diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b86bb5b..97ff36b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ > > please add your unreleased change here. +## 20241219 + +- [SPU] 0.9.3b0 release +- [Improvement] Optimize exponential computation for semi2k (**experimental**) +- [Feature] Add more send/recv actions profiling + ## 20240716 - [SPU] 0.9.2b0 release diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0aaf7287..9701c407 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -72,7 +72,7 @@ python3 -m pip install -r requirements-dev.txt #### macOS ```sh -# macOS >= 12.0, Xcode >= 14.0 +# macOS >= 13.0, Xcode >= 15.0 # Install Xcode https://apps.apple.com/us/app/xcode/id497799835?mt=12 diff --git a/README.md b/README.md index 6a1fe7e3..f1c5eb7c 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,9 @@ Please follow [Installation Guidelines](INSTALLATION.md) to install SPU. ## Citing SPU -If you think SPU is helpful for your research or development, please consider citing our [paper](https://www.usenix.org/conference/atc23/presentation/ma): +If you think SPU is helpful for your research or development, please consider citing our papers: + +[USENIX ATC'23](https://www.usenix.org/conference/atc23/presentation/ma) ```text @inproceedings {spu, @@ -69,6 +71,26 @@ If you think SPU is helpful for your research or development, please consider ci } ``` +[ICML'24](https://proceedings.mlr.press/v235/wu24d.html) + +```text +@inproceedings{ditto, + title = {Ditto: Quantization-aware Secure Inference of Transformers upon {MPC}}, + author = {Wu, Haoqi and Fang, Wenjing and Zheng, Yancheng and Ma, Junming and Tan, Jin and Wang, Lei}, + booktitle = {Proceedings of the 41st International Conference on Machine Learning}, + pages = {53346--53365}, + year = {2024}, + editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix}, + volume = {235}, + series = {Proceedings of Machine Learning Research}, + month = {21--27 Jul}, + publisher = {PMLR}, + pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/wu24d/wu24d.pdf}, + url = {https://proceedings.mlr.press/v235/wu24d.html}, + abstract = {Due to the rising privacy concerns on sensitive client data and trained models like Transformers, secure multi-party computation (MPC) techniques are employed to enable secure inference despite attendant overhead. Existing works attempt to reduce the overhead using more MPC-friendly non-linear function approximations. However, the integration of quantization widely used in plaintext inference into the MPC domain remains unclear. To bridge this gap, we propose the framework named Ditto to enable more efficient quantization-aware secure Transformer inference. Concretely, we first incorporate an MPC-friendly quantization into Transformer inference and employ a quantization-aware distillation procedure to maintain the model utility. Then, we propose novel MPC primitives to support the type conversions that are essential in quantization and implement the quantization-aware MPC execution of secure quantized inference. This approach significantly decreases both computation and communication overhead, leading to improvements in overall efficiency. We conduct extensive experiments on Bert and GPT2 models to evaluate the performance of Ditto. The results demonstrate that Ditto is about $3.14\sim 4.40\times$ faster than MPCFormer (ICLR 2023) and $1.44\sim 2.35\times$ faster than the state-of-the-art work PUMA with negligible utility degradation.} +} +``` + ## Acknowledgement We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io) and security advisories made by [VUL337@NISL@THU](https://netsec.ccert.edu.cn/vul337). diff --git a/WORKSPACE b/WORKSPACE index 207eb359..8a40fb58 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -53,6 +53,10 @@ rules_foreign_cc_dependencies( register_preinstalled_tools = True, ) +load("@bazel_features//:deps.bzl", "bazel_features_deps") + +bazel_features_deps() + load("@rules_cuda//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies") rules_cuda_dependencies() diff --git a/bazel/eigen.BUILD b/bazel/eigen.BUILD index 73df40c8..80ccf3ca 100644 --- a/bazel/eigen.BUILD +++ b/bazel/eigen.BUILD @@ -20,6 +20,7 @@ # matrices, and related algorithms. load("@rules_cc//cc:defs.bzl", "cc_library") +load("@yacl//bazel:yacl.bzl", "OMP_DEPS") licenses([ # Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code. @@ -64,6 +65,7 @@ cc_library( ], includes = ["."], visibility = ["//visibility:public"], + deps = OMP_DEPS, ) filegroup( diff --git a/bazel/patches/seal.patch b/bazel/patches/seal.patch deleted file mode 100644 index 20467095..00000000 --- a/bazel/patches/seal.patch +++ /dev/null @@ -1,262 +0,0 @@ -diff --git a/native/src/seal/serializable.h b/native/src/seal/serializable.h -index a940190..e490b30 100644 ---- a/native/src/seal/serializable.h -+++ b/native/src/seal/serializable.h -@@ -135,6 +135,9 @@ namespace seal - return obj_.save(out, size, compr_mode); - } - -+ const T& obj() const { return obj_; } -+ -+ T& obj() { return obj_; } - private: - Serializable(T &&obj) : obj_(std::move(obj)) - {} - -diff --git a/native/src/seal/context.cpp b/native/src/seal/context.cpp -index 887a1312..932d9774 100644 ---- a/native/src/seal/context.cpp -+++ b/native/src/seal/context.cpp -@@ -477,7 +477,8 @@ namespace seal - // more than one modulus in coeff_modulus. This is equivalent to expanding - // the chain by one step. Otherwise, we set first_parms_id_ to equal - // key_parms_id_. -- if (!context_data_map_.at(key_parms_id_)->qualifiers_.parameters_set() || parms.coeff_modulus().size() == 1) -+ if (!context_data_map_.at(key_parms_id_)->qualifiers_.parameters_set() || parms.coeff_modulus().size() == 1 || -+ !parms.use_special_prime()) - { - first_parms_id_ = key_parms_id_; - } -diff --git a/native/src/seal/encryptionparams.cpp b/native/src/seal/encryptionparams.cpp -index 31e07441..c34d0a45 100644 ---- a/native/src/seal/encryptionparams.cpp -+++ b/native/src/seal/encryptionparams.cpp -@@ -23,8 +23,10 @@ namespace seal - uint64_t poly_modulus_degree64 = static_cast(poly_modulus_degree_); - uint64_t coeff_modulus_size64 = static_cast(coeff_modulus_.size()); - uint8_t scheme = static_cast(scheme_); -+ uint8_t use_special_prime = static_cast(use_special_prime); - - stream.write(reinterpret_cast(&scheme), sizeof(uint8_t)); -+ stream.write(reinterpret_cast(&use_special_prime), sizeof(uint8_t)); - stream.write(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); - stream.write(reinterpret_cast(&coeff_modulus_size64), sizeof(uint64_t)); - for (const auto &mod : coeff_modulus_) -@@ -63,6 +65,10 @@ namespace seal - // This constructor will throw if scheme is invalid - EncryptionParameters parms(scheme); - -+ uint8_t use_special_prime; -+ stream.read(reinterpret_cast(&use_special_prime), sizeof(uint8_t)); -+ parms.set_use_special_prime(use_special_prime); -+ - // Read the poly_modulus_degree - uint64_t poly_modulus_degree64 = 0; - stream.read(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); -@@ -128,7 +134,8 @@ namespace seal - size_t total_uint64_count = add_safe( - size_t(1), // scheme - size_t(1), // poly_modulus_degree -- coeff_modulus_size, plain_modulus_.uint64_count()); -+ size_t(1), // use_special_prime -+ coeff_modulus_size); - - auto param_data(allocate_uint(total_uint64_count, pool_)); - uint64_t *param_data_ptr = param_data.get(); -@@ -139,13 +146,15 @@ namespace seal - // Write the poly_modulus_degree. Note that it will always be positive. - *param_data_ptr++ = static_cast(poly_modulus_degree_); - -+ *param_data_ptr++ = static_cast(use_special_prime_); - for (const auto &mod : coeff_modulus_) - { - *param_data_ptr++ = mod.value(); - } - -- set_uint(plain_modulus_.data(), plain_modulus_.uint64_count(), param_data_ptr); -- param_data_ptr += plain_modulus_.uint64_count(); -+ // NOTE(juhou): we skip the plain modulus for parms_id -+ // set_uint(plain_modulus_.data(), plain_modulus_.uint64_count(), param_data_ptr); -+ // param_data_ptr += plain_modulus_.uint64_count(); - - HashFunction::hash(param_data.get(), total_uint64_count, parms_id_); - -diff --git a/native/src/seal/encryptionparams.h b/native/src/seal/encryptionparams.h -index 9e1fbe48..eb71c4ac 100644 ---- a/native/src/seal/encryptionparams.h -+++ b/native/src/seal/encryptionparams.h -@@ -266,6 +266,11 @@ namespace seal - random_generator_ = std::move(random_generator); - } - -+ inline void set_use_special_prime(bool flag) -+ { -+ use_special_prime_ = flag; -+ } -+ - /** - Returns the encryption scheme type. - */ -@@ -274,6 +279,11 @@ namespace seal - return scheme_; - } - -+ bool use_special_prime() const noexcept -+ { -+ return use_special_prime_; -+ } -+ - /** - Returns the degree of the polynomial modulus parameter. - */ -@@ -501,6 +511,8 @@ namespace seal - - Modulus plain_modulus_{}; - -+ bool use_special_prime_ = true; -+ - parms_id_type parms_id_ = parms_id_zero; - }; - } // namespace seal - -diff --git a/native/src/seal/evaluator.cpp b/native/src/seal/evaluator.cpp -index dabd3bab..afaa71dc 100644 ---- a/native/src/seal/evaluator.cpp -+++ b/native/src/seal/evaluator.cpp -@@ -2382,6 +2382,7 @@ namespace seal - size_t encrypted_size = encrypted.size(); - // Use key_context_data where permutation tables exist since previous runs. - auto galois_tool = context_.key_context_data()->galois_tool(); -+ bool is_ntt_form = encrypted.is_ntt_form(); - - // Size check - if (!product_fits_in(coeff_count, coeff_modulus_size)) -@@ -2412,7 +2413,7 @@ namespace seal - // DO NOT CHANGE EXECUTION ORDER OF FOLLOWING SECTION - // BEGIN: Apply Galois for each ciphertext - // Execution order is sensitive, since apply_galois is not inplace! -- if (parms.scheme() == scheme_type::bfv) -+ if (not is_ntt_form) - { - // !!! DO NOT CHANGE EXECUTION ORDER!!! - -@@ -2426,7 +2427,7 @@ namespace seal - // Next transform encrypted.data(1) - galois_tool->apply_galois(encrypted_iter[1], coeff_modulus_size, galois_elt, coeff_modulus, temp); - } -- else if (parms.scheme() == scheme_type::ckks || parms.scheme() == scheme_type::bgv) -+ else - { - // !!! DO NOT CHANGE EXECUTION ORDER!!! - -@@ -2440,10 +2441,6 @@ namespace seal - // Next transform encrypted.data(1) - galois_tool->apply_galois_ntt(encrypted_iter[1], coeff_modulus_size, galois_elt, temp); - } -- else -- { -- throw logic_error("scheme not implemented"); -- } - - // Wipe encrypted.data(1) - set_zero_poly(coeff_count, coeff_modulus_size, encrypted.data(1)); -@@ -2530,6 +2527,7 @@ namespace seal - auto &key_context_data = *context_.key_context_data(); - auto &key_parms = key_context_data.parms(); - auto scheme = parms.scheme(); -+ bool is_ntt_form = encrypted.is_ntt_form(); - - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted)) -@@ -2559,14 +2557,6 @@ namespace seal - { - throw invalid_argument("pool is uninitialized"); - } -- if (scheme == scheme_type::bfv && encrypted.is_ntt_form()) -- { -- throw invalid_argument("BFV encrypted cannot be in NTT form"); -- } -- if (scheme == scheme_type::ckks && !encrypted.is_ntt_form()) -- { -- throw invalid_argument("CKKS encrypted must be in NTT form"); -- } - if (scheme == scheme_type::bgv && !encrypted.is_ntt_form()) - { - throw invalid_argument("BGV encrypted must be in NTT form"); -@@ -2605,7 +2595,7 @@ namespace seal - set_uint(target_iter, decomp_modulus_size * coeff_count, t_target); - - // In CKKS or BGV, t_target is in NTT form; switch back to normal form -- if (scheme == scheme_type::ckks || scheme == scheme_type::bgv) -+ if (is_ntt_form) - { - inverse_ntt_negacyclic_harvey(t_target, decomp_modulus_size, key_ntt_tables); - } -@@ -2632,7 +2622,7 @@ namespace seal - ConstCoeffIter t_operand; - - // RNS-NTT form exists in input -- if ((scheme == scheme_type::ckks || scheme == scheme_type::bgv) && (I == J)) -+ if (is_ntt_form && (I == J)) - { - t_operand = target_iter[J]; - } -@@ -2789,7 +2779,7 @@ namespace seal - SEAL_ITERATE(t_ntt, coeff_count, [fix](auto &K) { K += fix; }); - - uint64_t qi_lazy = qi << 1; // some multiples of qi -- if (scheme == scheme_type::ckks) -+ if (is_ntt_form) - { - // This ntt_negacyclic_harvey_lazy results in [0, 4*qi). - ntt_negacyclic_harvey_lazy(t_ntt, get<2>(J)); -@@ -2802,7 +2792,7 @@ namespace seal - qi_lazy = qi << 2; - #endif - } -- else if (scheme == scheme_type::bfv) -+ else - { - inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J)); - } - -diff --git a/CMakeLists.txt b/CMakeLists.txt -index 1a7a2bfd..bc4ad9d9 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -223,7 +223,7 @@ if(SEAL_USE_INTEL_HEXL) - message(STATUS "Intel HEXL: download ...") - seal_fetch_thirdparty_content(ExternalIntelHEXL) - else() -- find_package(HEXL 1.2.4) -+ find_package(HEXL 1.2.5) - if (NOT TARGET HEXL::hexl) - message(FATAL_ERROR "Intel HEXL: not found") - endif() - -diff --git a/native/src/seal/evaluator.h b/native/src/seal/evaluator.h -index 33bc3c7d..8a00ebea 100644 ---- a/native/src/seal/evaluator.h -+++ b/native/src/seal/evaluator.h -@@ -1199,6 +1199,10 @@ namespace seal - */ - struct EvaluatorPrivateHelper; - -+ void switch_key_inplace( -+ Ciphertext &encrypted, util::ConstRNSIter target_iter, const KSwitchKeys &kswitch_keys, -+ std::size_t key_index, MemoryPoolHandle pool = MemoryManager::GetPool()) const; -+ - private: - Evaluator(const Evaluator ©) = delete; - -@@ -1257,10 +1261,6 @@ namespace seal - apply_galois_inplace(encrypted, galois_tool->get_elt_from_step(0), galois_keys, std::move(pool)); - } - -- void switch_key_inplace( -- Ciphertext &encrypted, util::ConstRNSIter target_iter, const KSwitchKeys &kswitch_keys, -- std::size_t key_index, MemoryPoolHandle pool = MemoryManager::GetPool()) const; -- - void multiply_plain_normal(Ciphertext &encrypted, const Plaintext &plain, MemoryPoolHandle pool) const; - - void multiply_plain_ntt(Ciphertext &encrypted_ntt, const Plaintext &plain_ntt) const; diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 59ed839d..deec975a 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -29,7 +29,6 @@ def spu_deps(): _com_github_emptoolkit_emp_tool() _com_github_emptoolkit_emp_ot() _com_github_facebook_zstd() - _com_github_microsoft_seal() _com_github_eigenteam_eigen() _com_github_nvidia_cutlass() _yacl() @@ -40,10 +39,10 @@ def _yacl(): http_archive, name = "yacl", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3.tar.gz", + "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b8_nightly_20241014.tar.gz", ], - strip_prefix = "yacl-0.4.5b3", - sha256 = "bd89d63312e5e83eff5e001e2cf2135baff321c4b72a309f7d00cc53ce02e1a1", + strip_prefix = "yacl-0.4.5b8_nightly_20241014", + sha256 = "9141792f07eba507ffd21c57ec3df2ad5fdf90ce605ffb7bc1b7b4e84a9c34fa", ) def _libpsi(): @@ -51,10 +50,10 @@ def _libpsi(): http_archive, name = "psi", urls = [ - "https://github.com/secretflow/psi/archive/refs/tags/v0.4.0beta.tar.gz", + "https://github.com/secretflow/psi/archive/refs/tags/v0.5.0.dev241115.tar.gz", ], - strip_prefix = "psi-0.4.0beta", - sha256 = "c2fbf486a66eca9d3ec1725a81d93a7c6e80a9206ef1c9263a1608e0bef95e1a", + strip_prefix = "psi-0.5.0.dev241115", + sha256 = "4d5ccc61282c4f887cee2c12fe3f414dfd7e916952849e92ffb1f6835d657a35", ) def _rules_proto_grpc(): @@ -70,9 +69,9 @@ def _rules_proto_grpc(): def _rules_cuda(): http_archive( name = "rules_cuda", - sha256 = "2f8c8c8c85f727bec4423efecec12d3b751cb0a98bda99f0f9d351608a23b858", - strip_prefix = "rules_cuda-v0.2.1", - urls = ["https://github.com/bazel-contrib/rules_cuda/releases/download/v0.2.1/rules_cuda-v0.2.1.tar.gz"], + sha256 = "c92b334d769a07cd991b7675b2f6076b8b95cd3b28b14268a2f379f8baae58e0", + strip_prefix = "rules_cuda-v0.2.3", + urls = ["https://github.com/bazel-contrib/rules_cuda/releases/download/v0.2.3/rules_cuda-v0.2.3.tar.gz"], ) def _bazel_platform(): @@ -136,8 +135,8 @@ def _bazel_skylib(): ) def _com_github_openxla_xla(): - OPENXLA_COMMIT = "9b0dd58c9b625a2e958f4fc7787a1ff5c95dbb40" - OPENXLA_SHA256 = "f150c5b49e4d4497aae2c79232f1efe2baccaa72223b21dc8715be73eab74417" + OPENXLA_COMMIT = "64bdcc53a1b24abf19b1fe598e6f9b0fe6454470" + OPENXLA_SHA256 = "60918b3a0391fe9e0bd506c9b90170b7b5fa64d06de7ec1f4f0e351a303a88fa" # We need openxla to handle xla/mhlo/stablehlo maybe( @@ -169,10 +168,10 @@ def _com_github_pybind11(): http_archive, name = "pybind11", build_file = "@pybind11_bazel//:pybind11.BUILD", - sha256 = "51631e88960a8856f9c497027f55c9f2f9115cafb08c0005439838a05ba17bfc", - strip_prefix = "pybind11-2.13.1", + sha256 = "e08cb87f4773da97fa7b5f035de8763abc656d87d5773e62f6da0587d1f0ec20", + strip_prefix = "pybind11-2.13.6", urls = [ - "https://github.com/pybind/pybind11/archive/refs/tags/v2.13.1.tar.gz", + "https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.tar.gz", ], ) @@ -225,27 +224,12 @@ def _com_github_emptoolkit_emp_ot(): build_file = "@spulib//bazel:emp-ot.BUILD", ) -def _com_github_microsoft_seal(): - maybe( - http_archive, - name = "com_github_microsoft_seal", - sha256 = "acc2a1a127a85d1e1ffcca3ffd148f736e665df6d6b072df0e42fff64795a13c", - strip_prefix = "SEAL-4.1.2", - type = "tar.gz", - patch_args = ["-p1"], - patches = ["@spulib//bazel:patches/seal.patch"], - urls = [ - "https://github.com/microsoft/SEAL/archive/refs/tags/v4.1.2.tar.gz", - ], - build_file = "@spulib//bazel:seal.BUILD", - ) - def _com_github_eigenteam_eigen(): EIGEN_COMMIT = "66e8f38891841bf88ee976a316c0c78a52f0cee5" EIGEN_SHA256 = "01fcd68409c038bbcfd16394274c2bf71e2bb6dda89a2319e23fc59a2da17210" maybe( http_archive, - name = "com_github_eigenteam_eigen", + name = "eigen_archive", sha256 = EIGEN_SHA256, build_file = "@spulib//bazel:eigen.BUILD", strip_prefix = "eigen-{commit}".format(commit = EIGEN_COMMIT), @@ -257,11 +241,11 @@ def _com_github_eigenteam_eigen(): def _com_github_nvidia_cutlass(): maybe( http_archive, - name = "com_github_nvidia_cutlass", - strip_prefix = "cutlass-3.5.0", + name = "cutlass_archive", + strip_prefix = "cutlass-3.5.1", urls = [ - "https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.0.tar.gz", + "https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.tar.gz", ], - sha256 = "ef6af8526e3ad04f9827f35ee57eec555d09447f70a0ad0cf684a2e426ccbcb6", + sha256 = "20b7247cda2d257cbf8ba59ba3ca40a9211c4da61a9c9913e32b33a2c5883a36", build_file = "@spulib//bazel:nvidia_cutlass.BUILD", ) diff --git a/docs/development/add_protocols.rst b/docs/development/add_protocols.rst index 31b4b5e2..b5c44796 100644 --- a/docs/development/add_protocols.rst +++ b/docs/development/add_protocols.rst @@ -105,7 +105,7 @@ member function and a member variable of an Object, respectively. // register customized kernels template void regKernel() { - regKernel(KernelT::kBindName, std::make_unique()); + regKernel(KernelT::kBindName(), std::make_unique()); } template @@ -116,7 +116,7 @@ member function and a member variable of an Object, respectively. // add customized states template void addState(Args&&... args) { - addState(StateT::kBindName, + addState(StateT::kBindName(), std::make_unique(std::forward(args)...)); } ... @@ -205,7 +205,7 @@ As a result, the ABY3 developer can directly register these kernels through the class AndPP : public BinaryKernel { public: // kernel name for dynamic binding - static constexpr char kBindName[] = "and_pp"; + static constexpr const char* kBindName() { return "and_pp"; } // define cost model ce::CExpr latency() const override { return ce::Const(0); } @@ -248,7 +248,7 @@ When kernels are implemented and registered, a new protocol is finally added. auto* prg_state = ctx->getState(); // dispatch the real implementation to different fields - return DISPATCH_ALL_FIELDS(field, "aby3.mulAA", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { // the real protocol implementation ... }); diff --git a/docs/requirements.txt b/docs/requirements.txt index 3b26b70d..c750e8d9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,12 +1,12 @@ -myst-parser==3.0.1 +myst-parser==4.0.0 rstcheck==6.2.4 -sphinx==7.4.4 -nbsphinx==0.9.4 -sphinx-autobuild==2024.4.16 +sphinx==8.1.3 +nbsphinx==0.9.5 +sphinx-autobuild==2024.10.3 sphinx-markdown-parser==0.2.4 sphinxcontrib-actdiag==3.0.0 sphinxcontrib-blockdiag==3.0.0 -sphinxcontrib-mermaid==0.9.2 +sphinxcontrib-mermaid==1.0.0 sphinxcontrib-nwdiag==2.0.0 sphinxcontrib-seqdiag==3.0.0 pytablewriter==1.2.0 diff --git a/examples/python/ml/jax_lr/README.md b/examples/python/ml/jax_lr/README.md index fc0e6580..bc340199 100644 --- a/examples/python/ml/jax_lr/README.md +++ b/examples/python/ml/jax_lr/README.md @@ -5,7 +5,7 @@ This example demonstrates how to use SPU to train a logistic regression model pr 1. Launch SPU backend runtime ```sh - bazel run -c opt //examples/python/utils:nodectl -- up + bazel run -c opt //examples/python/utils:nodectl -- -c examples/python/conf/2pc_semi2k.json up ``` 2. Run `jax_lr` example diff --git a/experimental/squirrel/BUILD.bazel b/experimental/squirrel/BUILD.bazel index 5717fcbf..64c794a5 100644 --- a/experimental/squirrel/BUILD.bazel +++ b/experimental/squirrel/BUILD.bazel @@ -60,8 +60,8 @@ spu_cc_library( "//libspu/mpc/cheetah/rlwe:cheetah_rlwe", "//libspu/mpc/cheetah/rlwe:lwe", "//libspu/mpc/cheetah/rlwe:packlwes", - "@com_github_eigenteam_eigen//:eigen3", "@com_github_microsoft_seal//:seal", + "@eigen_archive//:eigen3", "@yacl//yacl/utils:elapsed_timer", ], ) diff --git a/experimental/squirrel/README.md b/experimental/squirrel/README.md index 5c33ff83..6d2eb815 100644 --- a/experimental/squirrel/README.md +++ b/experimental/squirrel/README.md @@ -29,11 +29,24 @@ Code under this folder is purely for research demonstration and it's **NOT desig * On one terminal ```sh - bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --standalone=true --rank=0 --has_label=0 --lr=1.0 --subsample=0.8 + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone=true --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --rank=0 --has_label=0 --lr=1.0 --subsample=0.8 ``` * On another terminal ```sh - bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --standalone=true --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone=true --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 + ``` + +* Run on distributed dataset, e.g., using the `breast_cancer` dataset from the SPU repo. + * On one terminal + + ```sh + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=15 --rank1_nfeatures=15 --standalone=false --train=examples/data/breast_cancer_a.csv --rank=0 --has_label=0 --lr=1.0 --subsample=0.8 + ``` + + * On another terminal + + ```sh + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=15 --rank1_nfeatures=15 --standalone=false --train=examples/data/breast_cancer_b.csv --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 ``` diff --git a/experimental/squirrel/bin_matvec_prot.cc b/experimental/squirrel/bin_matvec_prot.cc index 6946fa4a..35e4ee3a 100644 --- a/experimental/squirrel/bin_matvec_prot.cc +++ b/experimental/squirrel/bin_matvec_prot.cc @@ -766,9 +766,12 @@ spu::NdArrayRef BinMatVecProtocol::Recv(const spu::NdArrayRef &vec_in, SPU_ENFORCE_EQ(vec_in.numel(), dim_in); if (not indicator.empty()) { // mat * diag(indicator) should not be all zeros. - SPU_ENFORCE(std::any_of(indicator.begin(), indicator.end(), - [](uint8_t x) { return x > 0; }), - "empty matrix is not allowed"); + if (std::all_of(indicator.begin(), indicator.end(), + [](uint8_t x) { return x == 0; })) { + SPDLOG_WARN( + "Empty matrix! Make sure the 1-bit error will not ruin your " + "computation."); + } } auto eltype = vec_in.eltype(); diff --git a/experimental/squirrel/bin_matvec_prot.h b/experimental/squirrel/bin_matvec_prot.h index 2ac8ea6c..ba34499f 100644 --- a/experimental/squirrel/bin_matvec_prot.h +++ b/experimental/squirrel/bin_matvec_prot.h @@ -70,7 +70,10 @@ struct StlSparseMatrix { // Outputs: // Sender: z0 \in Zk^{m} // Recv: z1 \in Zk^{m} -// such that z0 + z1 = M * (v0 + v1) mod Zk +// such that z0 + z1 = M * (v0 + v1) + e mod Zk +// +// Note that we might introduce 1-bit error `e` due to the coefficient-based +// resharing HE ciphertexts to additive shares. class BinMatVecProtocol { public: BinMatVecProtocol(size_t ring_bitwidth, diff --git a/experimental/squirrel/bin_matvec_prot_test.cc b/experimental/squirrel/bin_matvec_prot_test.cc index 0d5cae29..62c07ead 100644 --- a/experimental/squirrel/bin_matvec_prot_test.cc +++ b/experimental/squirrel/bin_matvec_prot_test.cc @@ -110,7 +110,7 @@ TEST_P(BinMatVecProtTest, Basic) { }); NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _vec(vec); auto expected = BinAccumuate(_vec, mat); NdArrayView got(reveal); @@ -160,7 +160,7 @@ TEST_P(BinMatVecProtTest, WithIndicator) { }); NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _vec(vec); auto expected = BinAccumuate(_vec, mat, absl::MakeConstSpan(indicator)); @@ -172,4 +172,100 @@ TEST_P(BinMatVecProtTest, WithIndicator) { } }); } + +TEST_P(BinMatVecProtTest, EmptyMat) { + using namespace spu; + using namespace spu::mpc; + constexpr size_t kWorldSize = 2; + + FieldType field = std::get<0>(GetParam()); + int64_t dim_in = std::get<0>(std::get<1>(GetParam())); + int64_t dim_out = std::get<1>(std::get<1>(GetParam())); + + StlSparseMatrix mat; + mat.rows_data_.resize(dim_out); + mat.cols_ = dim_in; + + NdArrayRef vec_shr[2]; + vec_shr[0] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + vec_shr[1] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + + NdArrayRef vec = ring_add(vec_shr[0], vec_shr[1]); + + NdArrayRef out_shr[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + BinMatVecProtocol binmat_prot(SizeOf(field) * 8, lctx); + if (0 == lctx->Rank()) { + out_shr[0] = binmat_prot.Send(vec_shr[0], dim_out, dim_in); + } else { + out_shr[1] = binmat_prot.Recv(vec_shr[1], dim_out, dim_in, mat); + } + }); + NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); + + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView _vec(vec); + auto expected = BinAccumuate(_vec, mat); + NdArrayView got(reveal); + + EXPECT_EQ(expected.size(), (size_t)got.numel()); + for (int64_t i = 0; i < dim_out; ++i) { + EXPECT_NEAR(expected[i], got[i], 1); + } + }); +} + +TEST_P(BinMatVecProtTest, WithEmptyIndicator) { + using namespace spu; + using namespace spu::mpc; + constexpr size_t kWorldSize = 2; + + FieldType field = std::get<0>(GetParam()); + int64_t dim_in = std::get<0>(std::get<1>(GetParam())); + int64_t dim_out = std::get<1>(std::get<1>(GetParam())); + + StlSparseMatrix mat; + PrepareBinaryMat(mat, dim_out, dim_in, 0); + + NdArrayRef vec_shr[2]; + vec_shr[0] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + vec_shr[1] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + std::vector indicator(dim_in); + std::default_random_engine rdv; + std::uniform_int_distribution dist(0, 10); + // empty indicator + std::fill_n(indicator.data(), indicator.size(), 0); + + NdArrayRef vec = ring_add(vec_shr[0], vec_shr[1]); + + NdArrayRef out_shr[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + BinMatVecProtocol binmat_prot(SizeOf(field) * 8, lctx); + if (0 == lctx->Rank()) { + out_shr[0] = binmat_prot.Send(vec_shr[0], dim_out, dim_in); + } else { + out_shr[1] = binmat_prot.Recv(vec_shr[1], dim_out, dim_in, mat, + absl::MakeConstSpan(indicator)); + } + }); + NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); + + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView _vec(vec); + NdArrayView got(reveal); + auto expected = BinAccumuate(_vec, mat, absl::MakeConstSpan(indicator)); + + EXPECT_EQ(expected.size(), (size_t)got.numel()); + for (int64_t i = 0; i < dim_out; ++i) { + EXPECT_NEAR(expected[i], got[i], 1); + } + }); +} + } // namespace squirrel::test diff --git a/experimental/squirrel/objectives.cc b/experimental/squirrel/objectives.cc index b9423d2d..baf5f08d 100644 --- a/experimental/squirrel/objectives.cc +++ b/experimental/squirrel/objectives.cc @@ -272,9 +272,9 @@ namespace { res = NdArrayRef(makeType(ftype), in.shape()); } - return DISPATCH_ALL_FIELDS(field, "cheetah.ring_cast", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using from_ring2k_t = ring2k_t; - return DISPATCH_ALL_FIELDS(ftype, "cheetah.ring_cast", [&]() { + return DISPATCH_ALL_FIELDS(ftype, [&]() { using to_ring2k_t = ring2k_t; NdArrayView _in(in); NdArrayView _res(res); @@ -383,7 +383,7 @@ spu::Value Logistic(spu::SPUContext* ctx, const spu::Value& x) { spu::Value Sigmoid(spu::SPUContext* ctx, const spu::Value& x) { namespace sk = spu::kernel; auto c05 = sk::hlo::Constant(ctx, 0.5F, x.shape()); - auto half = sk::hal::right_shift_arithmetic(ctx, x, 1); + auto half = sk::hal::right_shift_arithmetic(ctx, x, {1}); auto divisor = sk::hlo::Add(ctx, sk::hlo::Constant(ctx, 1, x.shape()), sk::hal::f_square(ctx, x)); return sk::hlo::Add(ctx, c05, diff --git a/experimental/squirrel/tree_build_worker.cc b/experimental/squirrel/tree_build_worker.cc index 30c441f9..06724354 100644 --- a/experimental/squirrel/tree_build_worker.cc +++ b/experimental/squirrel/tree_build_worker.cc @@ -230,7 +230,7 @@ void AccumulateHistogram(spu::NdArrayRef buckets_share, size_t nfeatures, // The buckets belong to the i-th feature is // `buckets[i*bucket_size:(i+1)*bucket_size]` auto field = buckets_share.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "AccumulateHistogram", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView histogram(buckets_share); for (size_t j = 0; j < nfeatures; ++j) { size_t start = j * bucket_size; diff --git a/experimental/squirrel/utils.cc b/experimental/squirrel/utils.cc index 3f14436e..58a98d70 100644 --- a/experimental/squirrel/utils.cc +++ b/experimental/squirrel/utils.cc @@ -167,7 +167,7 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx, &kctx, arith.data(), [&](const NdArrayRef& input, const std::shared_ptr& base_ot) { - return DISPATCH_ALL_FIELDS(ft, "ot", [&]() { + return DISPATCH_ALL_FIELDS(ft, [&]() { NdArrayRef ot_out = spu::mpc::ring_zeros(ft, input.shape()); auto inp = absl::MakeConstSpan(&input.at(0), input.numel()); auto oup = absl::MakeSpan(&ot_out.at(0), ot_out.numel()); @@ -193,7 +193,7 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx, &kctx, boolean, [&](absl::Span input, const std::shared_ptr& base_ot) { - return DISPATCH_ALL_FIELDS(ft, "ot", [&]() { + return DISPATCH_ALL_FIELDS(ft, [&]() { NdArrayRef ot_out = spu::mpc::ring_zeros(ft, {(int64_t)input.size()}); auto oup = absl::MakeSpan(&ot_out.at(0), input.size()); base_ot->GetReceiverCOT()->RecvCAMCC(input, oup); @@ -222,7 +222,7 @@ spu::Value MulArithShareWithANDBoolShare(spu::SPUContext* ctx, std::shared_ptr base_ot) { NdArrayRef out(x.eltype(), x.shape()); - DISPATCH_ALL_FIELDS(ft, "camcc", [&]() { + DISPATCH_ALL_FIELDS(ft, [&]() { spu::NdArrayView _ashr(x); auto oup = absl::MakeSpan(&out.at(0), y.size()); std::vector corr(y.size()); diff --git a/experimental/squirrel/utils_test.cc b/experimental/squirrel/utils_test.cc index 3bdc004e..8cde97fd 100644 --- a/experimental/squirrel/utils_test.cc +++ b/experimental/squirrel/utils_test.cc @@ -85,7 +85,7 @@ TEST_F(UtilsTest, ReduceSum) { const double fxp = std::pow(2., rt_config.fxp_fraction_bits()); auto flatten = got.data().reshape({got.numel()}); - DISPATCH_ALL_FIELDS(field, "check", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using s2k = std::make_signed::type; NdArrayView got(flatten); for (int64_t i = 0; i < expected.numel(); ++i) { @@ -136,7 +136,7 @@ TEST_F(UtilsTest, ArgMax) { if (lctx->Rank() == 0) { auto flatten = got.data().reshape({got.numel()}); - DISPATCH_ALL_FIELDS(field, "check", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView got(flatten); for (size_t i = 0; i < expected.size(); ++i) { ASSERT_EQ(expected(i), got[i]); diff --git a/libspu/compiler/BUILD.bazel b/libspu/compiler/BUILD.bazel index 68e0478b..ae5dc261 100644 --- a/libspu/compiler/BUILD.bazel +++ b/libspu/compiler/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library") +load("//bazel:spu.bzl", "spu_cc_library") package( default_visibility = ["//visibility:public"], diff --git a/libspu/compiler/common/compilation_context.h b/libspu/compiler/common/compilation_context.h index 24ebe043..8abfed96 100644 --- a/libspu/compiler/common/compilation_context.h +++ b/libspu/compiler/common/compilation_context.h @@ -16,7 +16,6 @@ #include #include -#include #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" diff --git a/libspu/compiler/compile.cc b/libspu/compiler/compile.cc index 38a59acb..236d6d02 100644 --- a/libspu/compiler/compile.cc +++ b/libspu/compiler/compile.cc @@ -14,8 +14,6 @@ #include "libspu/compiler/compile.h" -#include "mlir/IR/BuiltinOps.h" - #include "libspu/compiler/codegen/codegen.h" #include "libspu/compiler/common/compilation_context.h" #include "libspu/compiler/core/core.h" diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc index e977050c..b7c46d8f 100644 --- a/libspu/compiler/core/core.cc +++ b/libspu/compiler/core/core.cc @@ -16,7 +16,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" diff --git a/libspu/compiler/front_end/BUILD.bazel b/libspu/compiler/front_end/BUILD.bazel index 8b68cf6d..1cbca11b 100644 --- a/libspu/compiler/front_end/BUILD.bazel +++ b/libspu/compiler/front_end/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@yacl//bazel:yacl.bzl", "OMP_DEPS") load("//bazel:spu.bzl", "spu_cc_library") spu_cc_library( @@ -58,9 +59,9 @@ spu_cc_library( "@xla//xla/service:while_loop_constant_sinking", "@xla//xla/service:while_loop_simplifier", "@xla//xla/service:zero_sized_hlo_elimination", - "@xla//xla/service/gpu:dot_dimension_sorter", + "@xla//xla/service/gpu/transforms:dot_dimension_sorter", "@xla//xla/translate/hlo_to_mhlo:hlo_module_importer", - ], + ] + OMP_DEPS, ) spu_cc_library( diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc index e560c772..7f9529e5 100644 --- a/libspu/compiler/front_end/fe.cc +++ b/libspu/compiler/front_end/fe.cc @@ -14,6 +14,7 @@ #include "libspu/compiler/front_end/fe.h" +#include "fmt/ranges.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" @@ -21,7 +22,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "spdlog/spdlog.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc index 363bd76c..f328c966 100644 --- a/libspu/compiler/front_end/hlo_importer.cc +++ b/libspu/compiler/front_end/hlo_importer.cc @@ -30,7 +30,7 @@ #include "xla/service/float_support.h" #include "xla/service/gather_expander.h" #include "xla/service/gather_simplifier.h" -#include "xla/service/gpu/dot_dimension_sorter.h" +#include "xla/service/gpu/transforms/dot_dimension_sorter.h" #include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dce.h" diff --git a/libspu/compiler/front_end/hlo_importer.h b/libspu/compiler/front_end/hlo_importer.h index 8e57b7f0..551f2755 100644 --- a/libspu/compiler/front_end/hlo_importer.h +++ b/libspu/compiler/front_end/hlo_importer.h @@ -28,7 +28,9 @@ class CompilationContext; class HloImporter final { public: + // clang-format off explicit HloImporter(CompilationContext *context) : context_(context) {}; + // clang-format on /// Load a xla module and returns a mlir-hlo module mlir::OwningOpRef diff --git a/libspu/compiler/tests/interpret/abs.mlir b/libspu/compiler/tests/interpret/abs.mlir index dd655bf1..49e409a8 100644 --- a/libspu/compiler/tests/interpret/abs.mlir +++ b/libspu/compiler/tests/interpret/abs.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @abs_op_test_i64_i64_p() { diff --git a/libspu/compiler/tests/interpret/add.mlir b/libspu/compiler/tests/interpret/add.mlir index a763c0ae..5074d3eb 100644 --- a/libspu/compiler/tests/interpret/add.mlir +++ b/libspu/compiler/tests/interpret/add.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @add_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/and.mlir b/libspu/compiler/tests/interpret/and.mlir index 6e3cdfa6..fbcc2a2d 100644 --- a/libspu/compiler/tests/interpret/and.mlir +++ b/libspu/compiler/tests/interpret/and.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @and_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/atan2.mlir b/libspu/compiler/tests/interpret/atan2.mlir index e1938852..52c2b122 100644 --- a/libspu/compiler/tests/interpret/atan2.mlir +++ b/libspu/compiler/tests/interpret/atan2.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @atan2_op_test_f64_f64_pp() { diff --git a/libspu/compiler/tests/interpret/broadcast.mlir b/libspu/compiler/tests/interpret/broadcast.mlir index 1d0e9a63..7e2d17f8 100644 --- a/libspu/compiler/tests/interpret/broadcast.mlir +++ b/libspu/compiler/tests/interpret/broadcast.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @broadcast_in_dim() { %operand = pphlo.constant dense<[[1], [2], [3]]> : tensor<3x1xi64> diff --git a/libspu/compiler/tests/interpret/case.mlir b/libspu/compiler/tests/interpret/case.mlir index 543fbf86..96105d4e 100644 --- a/libspu/compiler/tests/interpret/case.mlir +++ b/libspu/compiler/tests/interpret/case.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @case_negative_index_default() { %index = pphlo.constant dense<-1> : tensor diff --git a/libspu/compiler/tests/interpret/ceil.mlir b/libspu/compiler/tests/interpret/ceil.mlir index d4f74b8b..23c05538 100644 --- a/libspu/compiler/tests/interpret/ceil.mlir +++ b/libspu/compiler/tests/interpret/ceil.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @ceil_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/clamp.mlir b/libspu/compiler/tests/interpret/clamp.mlir index 69c05696..9e76acc8 100644 --- a/libspu/compiler/tests/interpret/clamp.mlir +++ b/libspu/compiler/tests/interpret/clamp.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @clamp_op_test_si64() { %min = pphlo.constant dense<[1, 5, -5]> : tensor<3xi64> diff --git a/libspu/compiler/tests/interpret/concatenate.mlir b/libspu/compiler/tests/interpret/concatenate.mlir index 1d0d7655..63107f22 100644 --- a/libspu/compiler/tests/interpret/concatenate.mlir +++ b/libspu/compiler/tests/interpret/concatenate.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @concatenate() { %input0 = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64> diff --git a/libspu/compiler/tests/interpret/convert.mlir b/libspu/compiler/tests/interpret/convert.mlir index 747047ee..b3b23f3b 100644 --- a/libspu/compiler/tests/interpret/convert.mlir +++ b/libspu/compiler/tests/interpret/convert.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @convert_op_test_1() { %0 = pphlo.constant dense<[0, 1, 8, -9, 0]> : tensor<5xi32> diff --git a/libspu/compiler/tests/interpret/convolution.mlir b/libspu/compiler/tests/interpret/convolution.mlir index 294cdd36..2ce1030c 100644 --- a/libspu/compiler/tests/interpret/convolution.mlir +++ b/libspu/compiler/tests/interpret/convolution.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @main() { %0 = pphlo.constant dense<[[[[ 1.0, 2.0, 3.0, 4.0], diff --git a/libspu/compiler/tests/interpret/cosine.mlir b/libspu/compiler/tests/interpret/cosine.mlir index 72d6f249..5ee36f58 100644 --- a/libspu/compiler/tests/interpret/cosine.mlir +++ b/libspu/compiler/tests/interpret/cosine.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @cosine_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/divide.mlir b/libspu/compiler/tests/interpret/divide.mlir index 56ce53c2..a7540602 100644 --- a/libspu/compiler/tests/interpret/divide.mlir +++ b/libspu/compiler/tests/interpret/divide.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @divide_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/dot_general.mlir b/libspu/compiler/tests/interpret/dot_general.mlir index 29023c81..a4c506ed 100644 --- a/libspu/compiler/tests/interpret/dot_general.mlir +++ b/libspu/compiler/tests/interpret/dot_general.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @dot_general_op_test_si64() { %lhs = pphlo.constant dense<[[[1, 2], [3, 4]], diff --git a/libspu/compiler/tests/interpret/dynamic_slice.mlir b/libspu/compiler/tests/interpret/dynamic_slice.mlir index 672ace7c..d0275fd9 100644 --- a/libspu/compiler/tests/interpret/dynamic_slice.mlir +++ b/libspu/compiler/tests/interpret/dynamic_slice.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @dynamic_slice() { %operand = pphlo.constant dense<[[1, 1, 1], diff --git a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir index a7053428..342b7d3e 100644 --- a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir +++ b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @dynamic_update_slice() { %operand = pphlo.constant dense<[[1, 1, 1, 1], @@ -17,3 +19,22 @@ func.func @dynamic_update_slice() { pphlo.custom_call @expect_eq (%result, %expected) : (tensor<4x4xi64>,tensor<4x4xi64>)->() func.return } + +// ----- + +func.func @dynamic_update_slice() { + %operand = pphlo.constant dense<[[1, 1, 1, 1], + [1, 1, 1, 1], + [1, 2, 2, 2], + [1, 2, 2, 2]]> : tensor<4x4xi64> + %update = pphlo.constant dense<[[1, 1, 1], + [1, 1, 1]]> : tensor<2x3xi64> + %i0 = pphlo.constant dense<4> : tensor + %start_indices0 = pphlo.convert %i0 : (tensor) -> tensor> + %start_indices1 = pphlo.constant dense<4> : tensor + %result = pphlo.dynamic_update_slice %operand, %update, %start_indices0, %start_indices1 : + (tensor<4x4xi64>, tensor<2x3xi64>, tensor>, tensor) -> tensor<4x4x!pphlo.secret> + %expected = pphlo.constant dense<[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]> : tensor<4x4xi64> + pphlo.custom_call @expect_eq (%result, %expected) : (tensor<4x4x!pphlo.secret>, tensor<4x4xi64>)->() + func.return +} diff --git a/libspu/compiler/tests/interpret/equal.mlir b/libspu/compiler/tests/interpret/equal.mlir index f5364638..c8291d1a 100644 --- a/libspu/compiler/tests/interpret/equal.mlir +++ b/libspu/compiler/tests/interpret/equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/exponential.mlir b/libspu/compiler/tests/interpret/exponential.mlir index a85c5bb6..21ec3591 100644 --- a/libspu/compiler/tests/interpret/exponential.mlir +++ b/libspu/compiler/tests/interpret/exponential.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @exponential_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/exponential_minus_one.mlir b/libspu/compiler/tests/interpret/exponential_minus_one.mlir index 1a337a21..35131bbc 100644 --- a/libspu/compiler/tests/interpret/exponential_minus_one.mlir +++ b/libspu/compiler/tests/interpret/exponential_minus_one.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @exponential_minus_one_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/floor.mlir b/libspu/compiler/tests/interpret/floor.mlir index 97ccdb1e..602d0161 100644 --- a/libspu/compiler/tests/interpret/floor.mlir +++ b/libspu/compiler/tests/interpret/floor.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @floor_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/generate_mlir_tests.py b/libspu/compiler/tests/interpret/generate_mlir_tests.py index f9bdef86..460f67f6 100755 --- a/libspu/compiler/tests/interpret/generate_mlir_tests.py +++ b/libspu/compiler/tests/interpret/generate_mlir_tests.py @@ -65,6 +65,7 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None): "or", # "popcnt", "power", + "reciprocal", "reshape", "round_afz", "rsqrt", @@ -99,6 +100,16 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None): f.write( "// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s\n" ) + # FIXME: these tests are not stable for cheetah now + if test not in ["xor", "or", "and"]: + f.write( + "// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n" + ) + # Some test values in max and min are not supported by protocol 5. + if test not in ["max", "min"]: + f.write( + "// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s\n" + ) f.write("// AUTO GENERATED, DO NOT EDIT\n\n") # Emit cases diff --git a/libspu/compiler/tests/interpret/greater.mlir b/libspu/compiler/tests/interpret/greater.mlir index 7f8e76be..92140f85 100644 --- a/libspu/compiler/tests/interpret/greater.mlir +++ b/libspu/compiler/tests/interpret/greater.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @greater_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/greater_equal.mlir b/libspu/compiler/tests/interpret/greater_equal.mlir index 305aaff5..af3ffa7a 100644 --- a/libspu/compiler/tests/interpret/greater_equal.mlir +++ b/libspu/compiler/tests/interpret/greater_equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @greater_equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/if.mlir b/libspu/compiler/tests/interpret/if.mlir index cc98b65d..ef73d547 100644 --- a/libspu/compiler/tests/interpret/if.mlir +++ b/libspu/compiler/tests/interpret/if.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @if_ops_true_branch() { %pred = pphlo.constant dense : tensor diff --git a/libspu/compiler/tests/interpret/iota.mlir b/libspu/compiler/tests/interpret/iota.mlir index a7ee86ed..cc71b040 100644 --- a/libspu/compiler/tests/interpret/iota.mlir +++ b/libspu/compiler/tests/interpret/iota.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @iota_op_test_si8_dim_0() { %0 = pphlo.iota dim = 0 : tensor<3x4xi8> diff --git a/libspu/compiler/tests/interpret/less.mlir b/libspu/compiler/tests/interpret/less.mlir index 58444a29..1a9d3060 100644 --- a/libspu/compiler/tests/interpret/less.mlir +++ b/libspu/compiler/tests/interpret/less.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @less_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/less_equal.mlir b/libspu/compiler/tests/interpret/less_equal.mlir index 9951a569..c454bcc7 100644 --- a/libspu/compiler/tests/interpret/less_equal.mlir +++ b/libspu/compiler/tests/interpret/less_equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @less_equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/log.mlir b/libspu/compiler/tests/interpret/log.mlir index af61d86a..bf309137 100644 --- a/libspu/compiler/tests/interpret/log.mlir +++ b/libspu/compiler/tests/interpret/log.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @log_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/log_plus_one.mlir b/libspu/compiler/tests/interpret/log_plus_one.mlir index 3aec9184..99bcf9ff 100644 --- a/libspu/compiler/tests/interpret/log_plus_one.mlir +++ b/libspu/compiler/tests/interpret/log_plus_one.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @log_plus_one_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/logistic.mlir b/libspu/compiler/tests/interpret/logistic.mlir index eeac6fab..655cb82b 100644 --- a/libspu/compiler/tests/interpret/logistic.mlir +++ b/libspu/compiler/tests/interpret/logistic.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @logistic_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/maximum.mlir b/libspu/compiler/tests/interpret/maximum.mlir index b90553e7..4919c356 100644 --- a/libspu/compiler/tests/interpret/maximum.mlir +++ b/libspu/compiler/tests/interpret/maximum.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @maximum_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/minimum.mlir b/libspu/compiler/tests/interpret/minimum.mlir index 74bb2b83..1853124b 100644 --- a/libspu/compiler/tests/interpret/minimum.mlir +++ b/libspu/compiler/tests/interpret/minimum.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @minimum_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/multiply.mlir b/libspu/compiler/tests/interpret/multiply.mlir index 82b780f3..f7d415c4 100644 --- a/libspu/compiler/tests/interpret/multiply.mlir +++ b/libspu/compiler/tests/interpret/multiply.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @multiply_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/negate.mlir b/libspu/compiler/tests/interpret/negate.mlir index 8c5e74c3..6a00d9b8 100644 --- a/libspu/compiler/tests/interpret/negate.mlir +++ b/libspu/compiler/tests/interpret/negate.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @negate_op_test_i8_i8_p() { diff --git a/libspu/compiler/tests/interpret/not.mlir b/libspu/compiler/tests/interpret/not.mlir index 0fdf44e4..721e1850 100644 --- a/libspu/compiler/tests/interpret/not.mlir +++ b/libspu/compiler/tests/interpret/not.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @not_op_test_i8_i8_p() { diff --git a/libspu/compiler/tests/interpret/not_equal.mlir b/libspu/compiler/tests/interpret/not_equal.mlir index 1bbbd5b3..cb598070 100644 --- a/libspu/compiler/tests/interpret/not_equal.mlir +++ b/libspu/compiler/tests/interpret/not_equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @not_equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/or.mlir b/libspu/compiler/tests/interpret/or.mlir index 79836813..acd9c28f 100644 --- a/libspu/compiler/tests/interpret/or.mlir +++ b/libspu/compiler/tests/interpret/or.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @or_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/pad.mlir b/libspu/compiler/tests/interpret/pad.mlir index abedc73a..521313e2 100644 --- a/libspu/compiler/tests/interpret/pad.mlir +++ b/libspu/compiler/tests/interpret/pad.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @pad() { %operand = pphlo.constant dense<[[0, 0, 0, 0], diff --git a/libspu/compiler/tests/interpret/popcnt.mlir b/libspu/compiler/tests/interpret/popcnt.mlir index e5f83152..efea9133 100644 --- a/libspu/compiler/tests/interpret/popcnt.mlir +++ b/libspu/compiler/tests/interpret/popcnt.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @popcnt_op_test_i64_i64_p() { diff --git a/libspu/compiler/tests/interpret/power.mlir b/libspu/compiler/tests/interpret/power.mlir index 307fce4c..b6bfc475 100644 --- a/libspu/compiler/tests/interpret/power.mlir +++ b/libspu/compiler/tests/interpret/power.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @power_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/reciprocal.mlir b/libspu/compiler/tests/interpret/reciprocal.mlir new file mode 100644 index 00000000..fe1623a5 --- /dev/null +++ b/libspu/compiler/tests/interpret/reciprocal.mlir @@ -0,0 +1,26 @@ +// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s +// AUTO GENERATED, DO NOT EDIT + +func.func @reciprocal_op_test_f64_f64_p() { + %0 = pphlo.constant dense<[[1.0, -200.0], [100.0, 286991.875]]> : tensor<2x2xf64> + %1 = pphlo.reciprocal %0 : (tensor<2x2xf64>)->tensor<2x2xf64> + %2 = pphlo.constant dense<[[1.0, -0.005], [0.01, 0.0]]> : tensor<2x2xf64> + pphlo.custom_call @expect_almost_eq(%1, %2) { tol = 0.01 }: (tensor<2x2xf64>, tensor<2x2xf64>)->() + func.return +} + +// ----- + +func.func @reciprocal_op_test_f64_f64_s() { + %0 = pphlo.constant dense<[[1.0, -200.0], [100.0, 286991.875]]> : tensor<2x2xf64> + %1 = pphlo.convert %0 : (tensor<2x2xf64>)->tensor<2x2x!pphlo.secret> + %2 = pphlo.reciprocal %1 : (tensor<2x2x!pphlo.secret>)->tensor<2x2x!pphlo.secret> + %3 = pphlo.constant dense<[[1.0, -0.005], [0.01, 0.0]]> : tensor<2x2xf64> + %4 = pphlo.convert %2 : (tensor<2x2x!pphlo.secret>)->tensor<2x2xf64> + pphlo.custom_call @expect_almost_eq(%3, %4) { tol = 0.01 }: (tensor<2x2xf64>, tensor<2x2xf64>)->() + func.return +} diff --git a/libspu/compiler/tests/interpret/reduce.mlir b/libspu/compiler/tests/interpret/reduce.mlir index 3b34bf3e..efc1d40d 100644 --- a/libspu/compiler/tests/interpret/reduce.mlir +++ b/libspu/compiler/tests/interpret/reduce.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @reduce() { %input = pphlo.constant dense<[[0, 1, 2, 3, 4, 5]]> : tensor<1x6xi64> diff --git a/libspu/compiler/tests/interpret/reduce_window.mlir b/libspu/compiler/tests/interpret/reduce_window.mlir index 5385ab04..d15cd021 100644 --- a/libspu/compiler/tests/interpret/reduce_window.mlir +++ b/libspu/compiler/tests/interpret/reduce_window.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @reduce_window() { %input = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64> diff --git a/libspu/compiler/tests/interpret/reshape.mlir b/libspu/compiler/tests/interpret/reshape.mlir index 9e483b77..0a670923 100644 --- a/libspu/compiler/tests/interpret/reshape.mlir +++ b/libspu/compiler/tests/interpret/reshape.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @reshape_op_test_i32_i32_p() { diff --git a/libspu/compiler/tests/interpret/reverse.mlir b/libspu/compiler/tests/interpret/reverse.mlir index 63ab9590..832262e2 100644 --- a/libspu/compiler/tests/interpret/reverse.mlir +++ b/libspu/compiler/tests/interpret/reverse.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @reverse() { %operand = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64> diff --git a/libspu/compiler/tests/interpret/ring_cast.mlir b/libspu/compiler/tests/interpret/ring_cast.mlir index 94b53241..a6b6806c 100644 --- a/libspu/compiler/tests/interpret/ring_cast.mlir +++ b/libspu/compiler/tests/interpret/ring_cast.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @cast_1() { %c0 = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> diff --git a/libspu/compiler/tests/interpret/round_nearest_afz.mlir b/libspu/compiler/tests/interpret/round_nearest_afz.mlir index 6601fa1d..40e64ebe 100644 --- a/libspu/compiler/tests/interpret/round_nearest_afz.mlir +++ b/libspu/compiler/tests/interpret/round_nearest_afz.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @round_nearest_afz_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/rsqrt.mlir b/libspu/compiler/tests/interpret/rsqrt.mlir index ef6470c9..5e5e0869 100644 --- a/libspu/compiler/tests/interpret/rsqrt.mlir +++ b/libspu/compiler/tests/interpret/rsqrt.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @rsqrt_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/select.mlir b/libspu/compiler/tests/interpret/select.mlir index 33e9a9c7..c2752761 100644 --- a/libspu/compiler/tests/interpret/select.mlir +++ b/libspu/compiler/tests/interpret/select.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @select_op_test_si64() { %pred = pphlo.constant dense<[true, false, true]> : tensor<3xi1> diff --git a/libspu/compiler/tests/interpret/select_and_scatter.mlir b/libspu/compiler/tests/interpret/select_and_scatter.mlir index 5e55dc59..c7f0057a 100644 --- a/libspu/compiler/tests/interpret/select_and_scatter.mlir +++ b/libspu/compiler/tests/interpret/select_and_scatter.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // FIXME func.func @select_and_scatter_op_test() { diff --git a/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir b/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir index 494666ae..ce4bedbc 100644 --- a/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir +++ b/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @shift_right_arithmetic_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/shift_right_logical.mlir b/libspu/compiler/tests/interpret/shift_right_logical.mlir index 253d6cc3..73d69f0a 100644 --- a/libspu/compiler/tests/interpret/shift_right_logical.mlir +++ b/libspu/compiler/tests/interpret/shift_right_logical.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @shift_right_logical_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/sign.mlir b/libspu/compiler/tests/interpret/sign.mlir index b153a373..105c5cb2 100644 --- a/libspu/compiler/tests/interpret/sign.mlir +++ b/libspu/compiler/tests/interpret/sign.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @sign_op_test_i64_i64_p() { diff --git a/libspu/compiler/tests/interpret/sine.mlir b/libspu/compiler/tests/interpret/sine.mlir index f0b59e5c..1d62529d 100644 --- a/libspu/compiler/tests/interpret/sine.mlir +++ b/libspu/compiler/tests/interpret/sine.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @sine_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/slice.mlir b/libspu/compiler/tests/interpret/slice.mlir index b4b8bdbe..583b977b 100644 --- a/libspu/compiler/tests/interpret/slice.mlir +++ b/libspu/compiler/tests/interpret/slice.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @slice_op() { %operand = pphlo.constant dense<[[0, 0, 1, 0, 0, 1], diff --git a/libspu/compiler/tests/interpret/sort.mlir b/libspu/compiler/tests/interpret/sort.mlir index 837ded37..2433d4ae 100644 --- a/libspu/compiler/tests/interpret/sort.mlir +++ b/libspu/compiler/tests/interpret/sort.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @sort_stable() { %input0 = pphlo.constant dense<[[1, 2, 3], [3, 2, 1]]> : tensor<2x3xi64> diff --git a/libspu/compiler/tests/interpret/sqrt.mlir b/libspu/compiler/tests/interpret/sqrt.mlir index 9248077a..59372c6b 100644 --- a/libspu/compiler/tests/interpret/sqrt.mlir +++ b/libspu/compiler/tests/interpret/sqrt.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @sqrt_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/subtract.mlir b/libspu/compiler/tests/interpret/subtract.mlir index ce8b9633..37032882 100644 --- a/libspu/compiler/tests/interpret/subtract.mlir +++ b/libspu/compiler/tests/interpret/subtract.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @subtract_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/tanh.mlir b/libspu/compiler/tests/interpret/tanh.mlir index 413dc6d2..ab45fa2e 100644 --- a/libspu/compiler/tests/interpret/tanh.mlir +++ b/libspu/compiler/tests/interpret/tanh.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @tanh_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/test_json/reciprocal.json b/libspu/compiler/tests/interpret/test_json/reciprocal.json new file mode 100644 index 00000000..d8e07529 --- /dev/null +++ b/libspu/compiler/tests/interpret/test_json/reciprocal.json @@ -0,0 +1,24 @@ +{ + "name": "reciprocal", + "template": "basic_unary", + "testcases": [ + { + "inputs": [ + { + "data": "[[1.0, -200.0], [100.0, 286991.875]]", + "shape": "2x2", + "dtype": "f64" + } + ], + "expected": [ + { + "data": "[[1.0, -0.005], [0.01, 0.0]]", + "shape": "2x2", + "dtype": "f64" + } + ], + "checker": "expect_almost_eq", + "tol": 0.01 + } + ] +} \ No newline at end of file diff --git a/libspu/compiler/tests/interpret/transpose.mlir b/libspu/compiler/tests/interpret/transpose.mlir index d5ce9b6e..d36883de 100644 --- a/libspu/compiler/tests/interpret/transpose.mlir +++ b/libspu/compiler/tests/interpret/transpose.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @transpose_op_test_si32() { %0 = pphlo.constant dense<[[[1,2],[3,4],[5,6]], [[7,8],[9,10],[11,12]]]> : tensor<2x3x2xi32> diff --git a/libspu/compiler/tests/interpret/while.mlir b/libspu/compiler/tests/interpret/while.mlir index 32789fae..734e199e 100644 --- a/libspu/compiler/tests/interpret/while.mlir +++ b/libspu/compiler/tests/interpret/while.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @while() { // int i = 0; diff --git a/libspu/compiler/tests/interpret/xor.mlir b/libspu/compiler/tests/interpret/xor.mlir index c2e1f1ee..f2b8fba9 100644 --- a/libspu/compiler/tests/interpret/xor.mlir +++ b/libspu/compiler/tests/interpret/xor.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @xor_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir b/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir index ac7492bc..bc21172f 100644 --- a/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir +++ b/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir @@ -1,7 +1,7 @@ // RUN: spu-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_SECRET,VIS_PUBLIC,VIS_PUBLIC --lower-conversion-cast --split-input-file %s | FileCheck %s func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %arg2: tensor) -> tensor<128x5x5x32xf32> { - // CHECK: %1 = "pphlo.select_and_scatter"(%arg0, %arg1, %0) ({ + // CHECK: %1 = "pphlo.select_and_scatter"(%arg0, %arg1, %0) <{window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%arg3: tensor>, %arg4: tensor>): // CHECK: %2 = pphlo.greater_equal %arg3, %arg4 : (tensor>, tensor>) -> tensor> // CHECK: pphlo.return %2 : tensor> @@ -9,7 +9,7 @@ func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %a // CHECK: ^bb0(%arg3: tensor, %arg4: tensor>): // CHECK: %2 = pphlo.add %arg3, %arg4 : (tensor, tensor>) -> tensor> // CHECK: pphlo.return %2 : tensor> - // CHECK: }) {window_dimensions = array, window_strides = array} : (tensor<128x5x5x32x!pphlo.secret>, tensor<128x4x4x32xf32>, tensor>) -> tensor<128x5x5x32x!pphlo.secret> + // CHECK: }) : (tensor<128x5x5x32x!pphlo.secret>, tensor<128x4x4x32xf32>, tensor>) -> tensor<128x5x5x32x!pphlo.secret> %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): %1 = "stablehlo.compare"(%arg3, %arg4) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir b/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir index 2224ee7d..facd338f 100644 --- a/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir +++ b/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir @@ -2,11 +2,11 @@ func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>) { %0 = stablehlo.iota dim = 0 : tensor<20xi32> - // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) ({ + // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({ // CHECK: ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): // CHECK: %2 = pphlo.less %arg1, %arg2 : (tensor, tensor) -> tensor // CHECK: pphlo.return %2 : tensor - // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) + // CHECK: }) : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) %1:2 = "stablehlo.sort"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %2 = stablehlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir b/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir index c4d4bd41..cf7ef73d 100644 --- a/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir +++ b/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir @@ -2,11 +2,11 @@ func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>) { %0 = stablehlo.iota dim = 0 : tensor<20xi32> - // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) ({ + // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({ // CHECK: ^bb0(%arg1: tensor>, %arg2: tensor>, %arg3: tensor, %arg4: tensor): // CHECK: %2 = pphlo.less %arg1, %arg2 : (tensor>, tensor>) -> tensor> // CHECK: pphlo.return %2 : tensor> - // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<20x!pphlo.secret>, tensor<20xi32>) -> (tensor<20x!pphlo.secret>, tensor<20x!pphlo.secret>) + // CHECK: }) : (tensor<20x!pphlo.secret>, tensor<20xi32>) -> (tensor<20x!pphlo.secret>, tensor<20x!pphlo.secret>) %1:2 = "stablehlo.sort"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %2 = stablehlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/passes/optimizations/ops_negative.mlir b/libspu/compiler/tests/passes/optimizations/ops_negative.mlir index 926a5de5..e8644b96 100644 --- a/libspu/compiler/tests/passes/optimizations/ops_negative.mlir +++ b/libspu/compiler/tests/passes/optimizations/ops_negative.mlir @@ -24,7 +24,7 @@ func.func @main() -> tensor { func.func @main() -> tensor { %0 = pphlo.constant dense<[0.000000e+00, -3.40282347E+38]> : tensor<2xf32> - // expected-error @+1 {{op broadcast_dimensions contains invalid value -6 for result with rank 1}} + // expected-error @+1 {{broadcast_dimensions contains invalid value -6 for result with rank 1}} %1 = pphlo.broadcast %0, dims = [-6] : (tensor<2xf32>) -> tensor<2xf32> %2 = pphlo.constant dense<5> : tensor pphlo.return %2 : tensor @@ -33,7 +33,7 @@ func.func @main() -> tensor { // ----- func.func @main() -> tensor { - // expected-error @+1 {{op iota dimension cannot go beyond the output rank or be negative}} + // expected-error @+1 {{iota dimension cannot go beyond the output rank}} %0 = pphlo.iota dim = 1000 : tensor<1xi32> %1 = pphlo.constant dense<5> : tensor pphlo.return %1 : tensor diff --git a/libspu/compiler/tools/spu-lsp.cc b/libspu/compiler/tools/spu-lsp.cc index 0b9ddbd6..9f59509d 100644 --- a/libspu/compiler/tools/spu-lsp.cc +++ b/libspu/compiler/tools/spu-lsp.cc @@ -25,5 +25,6 @@ int main(int argc, char **argv) { registry.insert(); mlir::func::registerInlinerExtension(registry); - return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); + return static_cast( + mlir::failed(mlir::MlirLspServerMain(argc, argv, registry))); } diff --git a/libspu/compiler/tools/spu-translate.cc b/libspu/compiler/tools/spu-translate.cc index 423321a5..89903340 100644 --- a/libspu/compiler/tools/spu-translate.cc +++ b/libspu/compiler/tools/spu-translate.cc @@ -22,12 +22,10 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" #include "mlir/Tools/mlir-translate/Translation.h" -#include "mlir/Transforms/Passes.h" #include "stablehlo/dialect/StablehloOps.h" #include "xtensor/xio.hpp" #include "libspu/compiler/common/compilation_context.h" -#include "libspu/compiler/utils/utils.h" #include "libspu/core/prelude.h" #include "libspu/device/pphlo/pphlo_executor.h" #include "libspu/dialect/pphlo/IR/dialect.h" @@ -77,7 +75,7 @@ void isEqual(const xt::xarray &lhs, const xt::xarray &rhs) { auto error = lhs - rhs; - for (auto v : error) { + for (T v : error) { if (v != 0) { llvm::report_fatal_error(fmt::format("Diff = {}", v).c_str()); } @@ -89,7 +87,7 @@ bool testOpHandler(::spu::SPUContext *sctx, mlir::Operation *op, auto callOp = mlir::dyn_cast(op); if (callOp.getCallTargetName() == "expect_almost_eq") { ::spu::Value runtimeLhs = inputs[0]; - ::spu::Value runtimeRhs = inputs[1]; + const ::spu::Value &runtimeRhs = inputs[1]; if (!runtimeLhs.isPublic()) { runtimeLhs = ::spu::kernel::hal::_s2p(sctx, runtimeLhs) @@ -111,7 +109,7 @@ bool testOpHandler(::spu::SPUContext *sctx, mlir::Operation *op, auto error = xt::fabs(lhs - rhs); - for (auto v : error) { + for (double v : error) { if (v > tol) { llvm::report_fatal_error( fmt::format("Diff {} greater than tol {}", v, tol).c_str()); @@ -123,7 +121,7 @@ bool testOpHandler(::spu::SPUContext *sctx, mlir::Operation *op, if (callOp.getCallTargetName() == "expect_eq") { ::spu::Value runtimeLhs = inputs[0]; - ::spu::Value runtimeRhs = inputs[1]; + const ::spu::Value &runtimeRhs = inputs[1]; if (!runtimeLhs.isPublic()) { runtimeLhs = ::spu::kernel::hal::_s2p(sctx, runtimeLhs) @@ -239,6 +237,16 @@ void evalModule(ModuleOp module) { numParties = 3; break; } + case 4: { + conf.set_protocol(::spu::CHEETAH); + numParties = 2; + break; + } + case 5: { + conf.set_protocol(::spu::SECURENN); + numParties = 3; + break; + } } SPDLOG_INFO(conf.DebugString()); @@ -278,6 +286,6 @@ TranslateFromMLIRRegistration interpretRegistration( } // namespace mlir int main(int argc, char **argv) { - return failed( - mlir::mlirTranslateMain(argc, argv, "SPU interpreter driver\n")); + return static_cast( + failed(mlir::mlirTranslateMain(argc, argv, "SPU interpreter driver\n"))); } diff --git a/libspu/compiler/utils/utils.h b/libspu/compiler/utils/utils.h index faeff8d9..3b06397a 100644 --- a/libspu/compiler/utils/utils.h +++ b/libspu/compiler/utils/utils.h @@ -14,6 +14,7 @@ #pragma once +#include "llvm/ADT/Twine.h" #include "mlir/Support/LogicalResult.h" namespace mlir::spu { diff --git a/libspu/core/config.cc b/libspu/core/config.cc index 7f63fc67..81ca2f9d 100644 --- a/libspu/core/config.cc +++ b/libspu/core/config.cc @@ -62,6 +62,17 @@ void populateRuntimeConfig(RuntimeConfig& cfg) { if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_DEFAULT) { cfg.set_fxp_exp_mode(RuntimeConfig::EXP_TAYLOR); } + if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_PRIME) { + // 0 offset is not supported + if (cfg.experimental_exp_prime_offset() == 0) { + // For FM128 default offset is 13 + if (cfg.field() == FieldType::FM128) { + cfg.set_experimental_exp_prime_offset(13); + } + // TODO: set defaults for other fields, currently only FM128 is + // supported + } + } if (cfg.fxp_exp_iters() == 0) { cfg.set_fxp_exp_iters(8); diff --git a/libspu/core/encoding.cc b/libspu/core/encoding.cc index eb26ce9b..98a17a1a 100644 --- a/libspu/core/encoding.cc +++ b/libspu/core/encoding.cc @@ -60,8 +60,8 @@ NdArrayRef encodeToRing(const PtBufferView& bv, FieldType field, } if (pt_type == PT_F32 || pt_type == PT_F64 || pt_type == PT_F16) { - DISPATCH_FLOAT_PT_TYPES(pt_type, "_", [&]() { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_FLOAT_PT_TYPES(pt_type, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using Float = ScalarT; using T = std::make_signed_t; @@ -100,8 +100,8 @@ NdArrayRef encodeToRing(const PtBufferView& bv, FieldType field, return dst; } else { // handle integer & boolean - DISPATCH_INT_PT_TYPES(pt_type, "_", [&]() { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_INT_PT_TYPES(pt_type, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using Integer = ScalarT; SPU_ENFORCE(sizeof(ring2k_t) >= sizeof(Integer), "integer encoding failed, ring={} could not represent {}", @@ -138,8 +138,8 @@ void decodeFromRing(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits, *out_pt_type = pt_type; } - DISPATCH_ALL_FIELDS(field, "field", [&]() { - DISPATCH_ALL_PT_TYPES(pt_type, "pt_type", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { + DISPATCH_ALL_PT_TYPES(pt_type, [&]() { using T = std::make_signed_t; auto _src = NdArrayView(src); diff --git a/libspu/core/ndarray_ref.h b/libspu/core/ndarray_ref.h index 87e43c2c..a797d803 100644 --- a/libspu/core/ndarray_ref.h +++ b/libspu/core/ndarray_ref.h @@ -20,6 +20,7 @@ #include "absl/types/span.h" #include "fmt/ostream.h" +#include "fmt/ranges.h" #include "yacl/base/buffer.h" #include "libspu/core/bit_utils.h" @@ -409,7 +410,6 @@ struct SimdTrait { NdArrayRef makeConstantArrayRef(const Type& eltype, const Shape& shape); std::ostream& operator<<(std::ostream& out, const NdArrayRef& v); -inline auto format_as(const spu::NdArrayRef& f) { return fmt::streamed(f); } template class NdArrayView { @@ -491,3 +491,10 @@ struct std::hash { return std::hash{}(r.data()); } }; + +namespace fmt { + +template <> +struct formatter : ostream_formatter {}; + +} // namespace fmt diff --git a/libspu/core/object.h b/libspu/core/object.h index a1aa3c03..6ea34d44 100644 --- a/libspu/core/object.h +++ b/libspu/core/object.h @@ -112,7 +112,7 @@ class Object final { template void regKernel() { - regKernel(KernelT::kBindName, std::make_unique()); + regKernel(KernelT::kBindName(), std::make_unique()); } template @@ -137,14 +137,15 @@ class Object final { template void addState(Args&&... args) { - addState(StateT::kBindName, + addState(StateT::kBindName(), std::make_unique(std::forward(args)...)); } template StateT* getState() { - const auto& itr = states_.find(StateT::kBindName); - SPU_ENFORCE(itr != states_.end(), "state={} not found", StateT::kBindName); + const auto& itr = states_.find(StateT::kBindName()); + SPU_ENFORCE(itr != states_.end(), "state={} not found", + StateT::kBindName()); return dynamic_cast(itr->second.get()); } diff --git a/libspu/core/pt_buffer_view.cc b/libspu/core/pt_buffer_view.cc index 50e6f891..3f7e1084 100644 --- a/libspu/core/pt_buffer_view.cc +++ b/libspu/core/pt_buffer_view.cc @@ -50,7 +50,7 @@ NdArrayRef convertToNdArray(PtBufferView bv) { } const auto type = makePtType(bv.pt_type); auto out = NdArrayRef(type, bv.shape); - return DISPATCH_ALL_PT_TYPES(bv.pt_type, "pt_type", [&]() { + return DISPATCH_ALL_PT_TYPES(bv.pt_type, [&]() { using T = ScalarT; if (bv.shape.numel() > 0) { auto* out_ptr = out.data(); diff --git a/libspu/core/shape.h b/libspu/core/shape.h index 3b6a5c40..151872e5 100644 --- a/libspu/core/shape.h +++ b/libspu/core/shape.h @@ -22,6 +22,7 @@ #include #include "absl/types/span.h" +#include "fmt/ranges.h" #include "llvm/ADT/ArrayRef.h" #include "libspu/core/prelude.h" @@ -73,8 +74,6 @@ class Shape : public std::vector { bool empty() const { return Base::empty(); } }; -inline auto format_as(const Shape &s) { return fmt::streamed(s); } - class Index : public std::vector { private: using Base = std::vector; @@ -95,8 +94,6 @@ class Index : public std::vector { } }; -inline auto format_as(const Index &idx) { return fmt::streamed(idx); } - using Stride = int64_t; class Strides : public std::vector { @@ -117,8 +114,6 @@ class Strides : public std::vector { } }; -inline auto format_as(const Strides &s) { return fmt::streamed(s); } - class Sizes : public std::vector { private: using Base = std::vector; @@ -135,8 +130,6 @@ class Sizes : public std::vector { } }; -inline auto format_as(const Sizes &s) { return fmt::streamed(s); } - class Axes : public std::vector { private: using Base = std::vector; @@ -153,8 +146,6 @@ class Axes : public std::vector { } }; -inline auto format_as(const Axes &axes) { return fmt::streamed(axes); } - Strides makeCompactStrides(const Shape &shape); int64_t flattenIndex(const Index &index, const Shape &shape); @@ -191,3 +182,22 @@ inline size_t calcFlattenOffset(const Index &indices, const Shape &shape, } } // namespace spu + +namespace fmt { + +template <> +struct formatter : ostream_formatter {}; + +template <> +struct formatter : ostream_formatter {}; + +template <> +struct formatter : ostream_formatter {}; + +template <> +struct formatter : ostream_formatter {}; + +template <> +struct formatter : ostream_formatter {}; + +} // namespace fmt diff --git a/libspu/core/trace.h b/libspu/core/trace.h index c67dbff0..582257d7 100644 --- a/libspu/core/trace.h +++ b/libspu/core/trace.h @@ -22,7 +22,7 @@ #include #include "absl/types/span.h" -#include "fmt/format.h" +#include "fmt/ranges.h" #include "spdlog/spdlog.h" #include "yacl/link/context.h" @@ -154,6 +154,10 @@ struct ActionRecord final { size_t send_bytes_end; size_t recv_bytes_start; size_t recv_bytes_end; + size_t send_actions_start; + size_t send_actions_end; + size_t recv_actions_start; + size_t recv_actions_end; }; class ProfState final { @@ -238,6 +242,10 @@ class TraceAction final { size_t send_bytes_end_; size_t recv_bytes_start_; size_t recv_bytes_end_; + size_t send_actions_start_; + size_t send_actions_end_; + size_t recv_actions_start_; + size_t recv_actions_end_; int64_t saved_tracer_flag_; @@ -247,6 +255,8 @@ class TraceAction final { if (lctx_) { send_bytes_start_ = lctx_->GetStats()->sent_bytes.load(); recv_bytes_start_ = lctx_->GetStats()->recv_bytes.load(); + send_actions_start_ = lctx_->GetStats()->sent_actions.load(); + recv_actions_start_ = lctx_->GetStats()->recv_actions.load(); } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGB) != 0) { @@ -269,6 +279,8 @@ class TraceAction final { if (lctx_) { send_bytes_end_ = lctx_->GetStats()->sent_bytes.load(); recv_bytes_end_ = lctx_->GetStats()->recv_bytes.load(); + send_actions_end_ = lctx_->GetStats()->sent_actions.load(); + recv_actions_end_ = lctx_->GetStats()->recv_actions.load(); } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGE) != 0) { @@ -279,7 +291,8 @@ class TraceAction final { tracer_->getProfState()->addRecord( ActionRecord{id_, name_, std::move(detail_), flag_, start_, end_, send_bytes_start_, send_bytes_end_, recv_bytes_start_, - recv_bytes_end_}); + recv_bytes_end_, send_actions_start_, send_actions_end_, + recv_actions_start_, recv_actions_end_}); } } diff --git a/libspu/core/type.h b/libspu/core/type.h index 62ebca90..2bcf6819 100644 --- a/libspu/core/type.h +++ b/libspu/core/type.h @@ -43,6 +43,16 @@ class Ring2k { FieldType field() const { return field_; } }; +// This trait means the data is maintained in Galois prime field. +class Gfp { + protected: + uint128_t prime_{0}; + + public: + virtual ~Gfp() = default; + uint128_t p() const { return prime_; } +}; + // The public interface. // // The value of this type is public visible for parties. @@ -384,6 +394,54 @@ class RingTy : public TypeImpl { } }; +// Galois field type of Mersenne primes, e.g., 2^127-1 +class GfmpTy : public TypeImpl { + using Base = TypeImpl; + + protected: + size_t mersenne_prime_exp_; + + public: + using Base::Base; + explicit GfmpTy(FieldType field) { + field_ = field; + mersenne_prime_exp_ = GetMersennePrimeExp(field); + prime_ = (static_cast(1) << mersenne_prime_exp_) - 1; + } + + static std::string_view getStaticId() { return "Gfmp"; } + + size_t size() const override { + if (field_ == FT_INVALID) { + return 0; + } + return SizeOf(GetStorageType(field_)); + } + + size_t mp_exp() const { return mersenne_prime_exp_; } + + void fromString(std::string_view detail) override { + auto comma = detail.find_first_of(','); + auto field_str = detail.substr(0, comma); + auto mp_exp_str = detail.substr(comma + 1); + SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), + "parse failed from={}", detail); + mersenne_prime_exp_ = std::stoul(std::string(mp_exp_str)); + prime_ = (static_cast(1) << mersenne_prime_exp_) - 1; + } + + std::string toString() const override { + return fmt::format("{},{}", FieldType_Name(field()), mersenne_prime_exp_); + } + + bool equals(TypeObject const* other) const override { + auto const* derived_other = dynamic_cast(other); + SPU_ENFORCE(derived_other); + return field() == derived_other->field() && + mp_exp() == derived_other->mp_exp() && p() == derived_other->p(); + } +}; + class TypeContext final { public: using TypeCreateFn = @@ -395,7 +453,8 @@ class TypeContext final { public: TypeContext() { - addTypes(); // Base types that we need to register + addTypes(); // Base types that we need to register } template diff --git a/libspu/core/type_test.cc b/libspu/core/type_test.cc index 88107473..29510b52 100644 --- a/libspu/core/type_test.cc +++ b/libspu/core/type_test.cc @@ -125,4 +125,24 @@ TEST(TypeTest, RingTy) { EXPECT_EQ(Type::fromString(fm128.toString()), fm128); } +TEST(TypeTest, GfmpTy) { + Type gfmp31 = makeType(FM32); + EXPECT_EQ(gfmp31.size(), 4); + EXPECT_TRUE(gfmp31.isa()); + EXPECT_EQ(gfmp31.toString(), "Gfmp"); + EXPECT_EQ(Type::fromString(gfmp31.toString()), gfmp31); + + Type gfmp61 = makeType(FM64); + EXPECT_EQ(gfmp61.size(), 8); + EXPECT_TRUE(gfmp61.isa()); + EXPECT_EQ(gfmp61.toString(), "Gfmp"); + EXPECT_EQ(Type::fromString(gfmp61.toString()), gfmp61); + + Type gfmp127 = makeType(FM128); + EXPECT_EQ(gfmp127.size(), 16); + EXPECT_TRUE(gfmp127.isa()); + EXPECT_EQ(gfmp127.toString(), "Gfmp"); + EXPECT_EQ(Type::fromString(gfmp127.toString()), gfmp127); +} + } // namespace spu diff --git a/libspu/core/type_util.cc b/libspu/core/type_util.cc index 8261e03f..7eac207e 100644 --- a/libspu/core/type_util.cc +++ b/libspu/core/type_util.cc @@ -122,6 +122,22 @@ std::ostream& operator<<(std::ostream& os, ProtocolKind protocol) { return os; } +////////////////////////////////////////////////////////////// +// Field GFP mappings, currently only support Mersenne primes +////////////////////////////////////////////////////////////// +size_t GetMersennePrimeExp(FieldType field) { +#define CASE(Name, ScalarT, MersennePrimeExp) \ + case FieldType::Name: \ + return MersennePrimeExp; \ + break; + switch (field) { + FIELD_TO_MERSENNE_PRIME_EXP_MAP(CASE) + default: + SPU_THROW("unknown supported field {}", field); + } +#undef CASE +} + ////////////////////////////////////////////////////////////// // Field 2k types, TODO(jint) support Zq ////////////////////////////////////////////////////////////// diff --git a/libspu/core/type_util.h b/libspu/core/type_util.h index 2ac1d294..84b04f45 100644 --- a/libspu/core/type_util.h +++ b/libspu/core/type_util.h @@ -97,91 +97,90 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype); // Helper macros to enumerate all py types. // NOLINTNEXTLINE: Global internal used macro. -#define __CASE_PT_TYPE(PT_TYPE, NAME, ...) \ - case (PT_TYPE): { \ - [[maybe_unused]] constexpr std::string_view _kName = NAME; \ - using ScalarT = EnumToPtType::type; \ - return __VA_ARGS__(); \ +#define __CASE_PT_TYPE(PT_TYPE, ...) \ + case (PT_TYPE): { \ + using ScalarT = EnumToPtType::type; \ + return __VA_ARGS__(); \ } -#define DISPATCH_FLOAT_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_FLOAT_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_UINT_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U128, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_UINT_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U128, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_INT_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_INT_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_I1, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_ALL_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_ALL_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_I1, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() std::ostream& operator<<(std::ostream& os, const PtType& pt_type); @@ -213,7 +212,17 @@ FOREACH_PT_TYPES(CASE) std::ostream& operator<<(std::ostream& os, ProtocolKind protocol); ////////////////////////////////////////////////////////////// -// Field 2k types, TODO(jint) support Zq +// Field GFP mappings, currently only support Mersenne primes +////////////////////////////////////////////////////////////// +#define FIELD_TO_MERSENNE_PRIME_EXP_MAP(FN) \ + FN(FM32, uint32_t, 31) \ + FN(FM64, uint64_t, 61) \ + FN(FM128, uint128_t, 127) + +size_t GetMersennePrimeExp(FieldType field); + +////////////////////////////////////////////////////////////// +// Field 2k types ////////////////////////////////////////////////////////////// #define FIELD_TO_STORAGE_MAP(FN) \ FN(FM32, PT_U32) \ @@ -241,26 +250,40 @@ inline size_t SizeOf(FieldType field) { return SizeOf(GetStorageType(field)); } // Helper macros to enumerate all fields // NOLINTNEXTLINE: Global internal used macro. -#define __CASE_FIELD(FIELD, NAME, ...) \ +#define __CASE_FIELD(FIELD, ...) \ case (FIELD): { \ /* inject `_kField` & `_kName` for the continuation call */ \ [[maybe_unused]] constexpr spu::FieldType _kField = FIELD; \ - [[maybe_unused]] constexpr std::string_view _kName = NAME; \ using ring2k_t [[maybe_unused]] = Ring2kTrait<_kField>::scalar_t; \ return __VA_ARGS__(); \ } -#define DISPATCH_ALL_FIELDS(FIELD, NAME, ...) \ - [&] { \ - switch (FIELD) { \ - __CASE_FIELD(spu::FieldType::FM32, NAME, __VA_ARGS__) \ - __CASE_FIELD(spu::FieldType::FM64, NAME, __VA_ARGS__) \ - __CASE_FIELD(spu::FieldType::FM128, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for field={}", #NAME, FIELD); \ - } \ +#define DISPATCH_ALL_FIELDS(FIELD, ...) \ + [&] { \ + switch (FIELD) { \ + __CASE_FIELD(spu::FieldType::FM32, __VA_ARGS__) \ + __CASE_FIELD(spu::FieldType::FM64, __VA_ARGS__) \ + __CASE_FIELD(spu::FieldType::FM128, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for field={}", FIELD); \ + } \ }() +////////////////////////////////////////////////////////////// +// Field Prime types +////////////////////////////////////////////////////////////// +template +struct ScalarTypeToPrime {}; + +#define DEF_TRAITS(Field, ScalarT, Exp) \ + template <> \ + struct ScalarTypeToPrime { \ + static constexpr size_t exp = Exp; \ + static constexpr ScalarT prime = (static_cast(1) << Exp) - 1; \ + }; +FIELD_TO_MERSENNE_PRIME_EXP_MAP(DEF_TRAITS) +#undef DEF_TRAITS + ////////////////////////////////////////////////////////////// // Value range information, should it be here, at top level(jint)? ////////////////////////////////////////////////////////////// diff --git a/libspu/core/xt_helper.h b/libspu/core/xt_helper.h index 3230eada..44921507 100644 --- a/libspu/core/xt_helper.h +++ b/libspu/core/xt_helper.h @@ -63,3 +63,6 @@ NdArrayRef xt_to_ndarray(const xt::xexpression& e) { } } // namespace spu + +template +struct fmt::is_range, char> : std::false_type {}; diff --git a/libspu/device/api.cc b/libspu/device/api.cc index 1b27f881..bae9b462 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -71,12 +71,14 @@ struct CommunicationStats { size_t send_bytes = 0; size_t recv_bytes = 0; size_t send_actions = 0; + size_t recv_actions = 0; void reset(const std::shared_ptr &lctx) { if (!lctx) { return; } send_actions = lctx->GetStats()->sent_actions; + recv_actions = lctx->GetStats()->recv_actions; send_bytes = lctx->GetStats()->sent_bytes; recv_bytes = lctx->GetStats()->recv_bytes; } @@ -88,6 +90,7 @@ struct CommunicationStats { send_bytes = lctx->GetStats()->sent_bytes - send_bytes; recv_bytes = lctx->GetStats()->recv_bytes - recv_bytes; send_actions = lctx->GetStats()->sent_actions - send_actions; + recv_actions = lctx->GetStats()->recv_actions - recv_actions; } }; @@ -108,6 +111,10 @@ struct ActionStats { size_t send_bytes = 0; // total recv bytes. size_t recv_bytes = 0; + // total send actions. + size_t send_actions = 0; + // total recv actions. + size_t recv_actions = 0; inline double getTotalTimeInSecond() const { return std::chrono::duration_cast>(total_time) @@ -183,6 +190,8 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, std::chrono::duration_cast(rec.end - rec.start); stat.send_bytes += (rec.send_bytes_end - rec.send_bytes_start); stat.recv_bytes += (rec.recv_bytes_end - rec.recv_bytes_start); + stat.send_actions += (rec.send_actions_end - rec.send_actions_start); + stat.recv_actions += (rec.recv_actions_end - rec.recv_actions_start); } static std::map kModules = { @@ -213,17 +222,19 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, const auto &stat = stats.find(key)->second; SPDLOG_INFO( "- {}, executed {} times, duration {}s, send bytes {} recv " - "bytes {}", + "bytes {}, send actions {}, recv actions {}", key.name, stat.count, stat.getTotalTimeInSecond(), stat.send_bytes, - stat.recv_bytes); + stat.recv_bytes, stat.send_actions, stat.recv_actions); } } } // print link statistics SPDLOG_INFO( - "Link details: total send bytes {}, recv bytes {}, send actions {}", - comm_stats.send_bytes, comm_stats.recv_bytes, comm_stats.send_actions); + "Link details: total send bytes {}, recv bytes {}, send actions {}, recv " + "actions {}", + comm_stats.send_bytes, comm_stats.recv_bytes, comm_stats.send_actions, + comm_stats.recv_actions); } void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) { diff --git a/libspu/device/utils/BUILD.bazel b/libspu/device/utils/BUILD.bazel index 699a4302..f864ce56 100644 --- a/libspu/device/utils/BUILD.bazel +++ b/libspu/device/utils/BUILD.bazel @@ -14,6 +14,10 @@ load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library") +package( + default_visibility = ["//visibility:public"], +) + spu_cc_library( name = "debug_dump_constant", srcs = [ diff --git a/libspu/dialect/pphlo/IR/dialect.td b/libspu/dialect/pphlo/IR/dialect.td index de2c648c..7d47197c 100644 --- a/libspu/dialect/pphlo/IR/dialect.td +++ b/libspu/dialect/pphlo/IR/dialect.td @@ -32,15 +32,14 @@ def PPHlo_Dialect : Dialect { string summary = "Privacy-Preserving HLO(PPHLO) dialect"; string description = [{ PPHLO represents a high level abstraction for language use by SPU. - It implements a subset of mlir mhlo ops with it's own privacy-preserving focused type system. + It implements a subset of mlir stablehlo ops with it's own privacy-preserving focused type system. - Learn more about mlir hlo at https://github.com/tensorflow/mlir-hlo + Learn more about mlir stablehlo at https://github.com/openxla/stablehlo }]; let name = "pphlo"; let cppNamespace = "::mlir::spu::pphlo"; let useDefaultAttributePrinterParser = 0; let useDefaultTypePrinterParser = 0; - let usePropertiesForAttributes = 0; let hasConstantMaterializer = 1; let extraClassDeclaration = [{ Attribute parseAttribute(DialectAsmParser & parser, Type type) diff --git a/libspu/dialect/pphlo/IR/fold.cc b/libspu/dialect/pphlo/IR/fold.cc index 7946c07b..f4f42e45 100644 --- a/libspu/dialect/pphlo/IR/fold.cc +++ b/libspu/dialect/pphlo/IR/fold.cc @@ -49,6 +49,14 @@ OpFoldResult ReverseOp::fold(FoldAdaptor) { dims, [&](int64_t dim) { return shapedType.getDimSize(dim) == 1; })) { return input; } + + // reverse(reverse(x, dims), dims) = x + if (auto prev = input.getDefiningOp()) { + if (prev.getDimensions() == dims) { + return prev.getOperand(); + } + } + return {}; } diff --git a/libspu/dialect/pphlo/IR/ops.cc b/libspu/dialect/pphlo/IR/ops.cc index 564d44e1..e65313cc 100644 --- a/libspu/dialect/pphlo/IR/ops.cc +++ b/libspu/dialect/pphlo/IR/ops.cc @@ -15,24 +15,16 @@ #include "libspu/dialect/pphlo/IR/ops.h" #include "fmt/format.h" +#include "fmt/ranges.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" +#include "stablehlo/dialect/TypeInference.h" #include "libspu/dialect/pphlo/IR/ops.h.inc" namespace mlir::spu::pphlo { -namespace { - -// Checks if the vector `nums` has duplicates. -bool hasDuplicates(const ArrayRef nums) { - llvm::SmallDenseSet set(nums.begin(), nums.end()); - return set.size() != nums.size(); -} - -} // namespace - template static LogicalResult Verify(T /*op*/) { return success(); @@ -386,75 +378,12 @@ LogicalResult ConcatenateOp::verify() { } LogicalResult BroadcastOp::verify() { - auto operandType = mlir::dyn_cast(getOperand().getType()); - - auto operandRank = operandType.getRank(); - - if (getBroadcastDimensions().empty()) { - if (operandRank == 0) { - return success(); - } - return emitOpError( - llvm::formatv("broadcast_dimensions is absent, but required because " - "operand has non-zero rank ({0})", - operandRank)); - } - - auto dimensionsSize = getBroadcastDimensions().size(); - if (static_cast(dimensionsSize) != operandRank) { - return emitOpError(llvm::formatv( - "broadcast_dimensions size ({0}) does not match operand rank ({1})", - dimensionsSize, operandRank)); - } - - auto dimensions = getBroadcastDimensions(); - if (hasDuplicates(dimensions)) { - return emitOpError("broadcast_dimensions should not have duplicates"); - } - - auto resultType = mlir::dyn_cast(getResult().getType()); - auto resultRank = resultType.getRank(); - - for (size_t i = 0; i != dimensionsSize; ++i) { - auto dimIndex = dimensions[i]; - if ((dimIndex >= resultRank) || (dimIndex < 0)) { - return emitOpError( - llvm::formatv("broadcast_dimensions contains invalid value {0} for " - "result with rank {1}", - dimIndex, resultRank)); - } - - if (!operandType.isDynamicDim(i)) { - auto dimSize = operandType.getDimSize(i); - auto resultDimSize = resultType.getDimSize(dimIndex); - if (dimSize != 1 && dimSize != resultDimSize) { - return emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not equal to " - "1 or size of result dimension {2} ({3})", - i, dimSize, dimIndex, resultDimSize)); - } - } - } - - return success(); + return hlo::verifyBroadcastInDimOp(getLoc(), getOperand(), + getBroadcastDimensions(), getResult()); } LogicalResult IotaOp::verify() { - auto shape = mlir::dyn_cast(getType()); - if (!shape.hasRank()) { - return success(); - } - - if (shape.getRank() == 0) { - return emitOpError() << "does not support scalars."; - } - - auto iotaDimension = static_cast(this->getIotaDimension()); - if (iotaDimension >= shape.getRank() || iotaDimension < 0) { - return emitOpError() - << "iota dimension cannot go beyond the output rank or be negative."; - } - return success(); + return hlo::verifyIotaOp(getLoc(), getIotaDimension(), getResult()); } LogicalResult SliceOp::verify() { diff --git a/libspu/dialect/pphlo/IR/print_parse.cc b/libspu/dialect/pphlo/IR/print_parse.cc index 8716e011..6f693e30 100644 --- a/libspu/dialect/pphlo/IR/print_parse.cc +++ b/libspu/dialect/pphlo/IR/print_parse.cc @@ -105,9 +105,19 @@ ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { if (parser.parseRParen()) { return failure(); } + // Parse optional properties + if (succeeded(parser.parseOptionalLess()) && + (failed(parser.parseAttribute(result.propertiesAttr)) || + failed(parser.parseGreater()))) { + return failure(); + } + + // Parse optional attributes if (parser.parseOptionalAttrDict(result.attributes)) { return failure(); } + + // Parse type signature if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || parser.parseArrow()) { return failure(); diff --git a/libspu/dialect/pphlo/IR/type_inference.cc b/libspu/dialect/pphlo/IR/type_inference.cc index 4e3a9464..1b22d8a7 100644 --- a/libspu/dialect/pphlo/IR/type_inference.cc +++ b/libspu/dialect/pphlo/IR/type_inference.cc @@ -285,7 +285,7 @@ LogicalResult PadOp::inferReturnTypes( ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) { - PadOp::Adaptor adaptor(operands, attributes, {}, regions); + PadOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferPadOp(location, adaptor.getOperand().getType(), adaptor.getPaddingValue().getType(), adaptor.getEdgePaddingLow(), @@ -295,27 +295,27 @@ LogicalResult PadOp::inferReturnTypes( LogicalResult ConcatenateOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferred_return_types) { - ConcatenateOp::Adaptor adaptor(operands, attributes, {}, regions); + ConcatenateOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferConcatenateOp(location, adaptor.getInputs().getTypes(), adaptor.getDimension(), inferred_return_types); } LogicalResult TransposeOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferred_return_types) { - TransposeOp::Adaptor adaptor(operands, attributes, {}, regions); + TransposeOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferTransposeOp(location, adaptor.getOperand(), adaptor.getPermutation(), inferred_return_types); } LogicalResult SliceOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferred_return_types) { - SliceOp::Adaptor adaptor(operands, attributes, {}, regions); + SliceOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferSliceOp(location, adaptor.getOperand().getType(), adaptor.getStartIndices(), adaptor.getLimitIndices(), adaptor.getStrides(), inferred_return_types); @@ -375,9 +375,9 @@ LogicalResult inferDynamicSliceOp(std::optional location, LogicalResult DynamicSliceOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - DynamicSliceOp::Adaptor adaptor(operands, attributes, {}, regions); + DynamicSliceOp::Adaptor adaptor(operands, attributes, properties, regions); return inferDynamicSliceOp(location, adaptor.getOperand().getType(), adaptor.getStartIndices().getTypes(), adaptor.getSliceSizes(), inferredReturnTypes); @@ -420,16 +420,26 @@ LogicalResult inferDynamicUpdateSliceOp( } // dynamic_update_slice_c1 - inferredReturnTypes.emplace_back(RankedTensorType::get( - operandType.getShape(), operandType.getElementType())); + TypeTools tools(operand.getContext()); + auto vis = llvm::map_to_vector(startIndices, [&](mlir::Value v) { + return tools.getTypeVisibility(v.getType()); + }); + vis.emplace_back(tools.getTypeVisibility(operand.getType())); + vis.emplace_back(tools.getTypeVisibility(update.getType())); + + inferredReturnTypes.emplace_back( + RankedTensorType::get(operandType.getShape(), + tools.getType(operandType.getElementType(), + tools.computeCommonVisibility(vis)))); return success(); } LogicalResult DynamicUpdateSliceOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, {}, regions); + DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, properties, + regions); return inferDynamicUpdateSliceOp( location, adaptor.getOperand(), adaptor.getUpdate(), diff --git a/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc b/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc index e17f85e5..97c41370 100644 --- a/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc +++ b/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc @@ -133,24 +133,20 @@ class FuncOpConverter : public OpConversionPattern<::mlir::func::FuncOp> { auto ®ion = op.getBody(); // Convert non-entry blocks - SmallVector conversions; - for (Block &block : llvm::drop_begin(region, 1)) { - conversions.emplace_back(block.getNumArguments()); - TypeConverter::SignatureConversion &back = conversions.back(); + for (Block &block : + llvm::make_early_inc_range(llvm::drop_begin(region, 1))) { + TypeConverter::SignatureConversion conversion( + /*numOrigInputs=*/block.getNumArguments()); for (BlockArgument blockArgument : block.getArguments()) { auto idx = blockArgument.getArgNumber(); auto vis_v = vis_.getValueVisibility(blockArgument); auto convertedType = tools_.getType( typeConverter->convertType(blockArgument.getType()), vis_v); - back.addInputs(idx, convertedType); + conversion.addInputs(idx, convertedType); } - } - if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, - conversions))) { - rewriter.cancelOpModification(op); - return failure(); + rewriter.applySignatureConversion(&block, conversion, getTypeConverter()); } // Convert function arguments using the provided TypeConverter. @@ -1267,10 +1263,9 @@ class HloToPPHloOpConverter Type result_type = typetools_.getType( this->getTypeConverter()->convertType(op.getType()), result_vis); - auto materialized = materializeInputs(op, op->getOperands()); + auto materialized = materializeInputs(op, adaptor.getOperands()); rewriter.replaceOpWithNewOp( - op, result_type, materialized[0], materialized[1], - adaptor.getStartIndices()); + op, TypeRange{result_type}, materialized); return success(); } diff --git a/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc b/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc index 981a8c41..9f801875 100644 --- a/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc +++ b/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc @@ -36,9 +36,8 @@ class CaseConverter : public OpRewritePattern { if (target_type.getNumElements() == in_type.getNumElements()) { return rewriter.create(loc, broadcasted_mask_type, in); } else { - return rewriter.create( - loc, broadcasted_mask_type, in, - llvm::SmallVector(target_type.getRank(), 0)); + return rewriter.create(loc, broadcasted_mask_type, in, + llvm::SmallVector{0}); } } diff --git a/libspu/dialect/pphlo/transforms/partial_sort_to_topk.cc b/libspu/dialect/pphlo/transforms/partial_sort_to_topk.cc index 139f2650..8d10e7a5 100644 --- a/libspu/dialect/pphlo/transforms/partial_sort_to_topk.cc +++ b/libspu/dialect/pphlo/transforms/partial_sort_to_topk.cc @@ -172,8 +172,7 @@ struct SortConversion : public OpRewritePattern { } // rewrite all slices - for (const auto &use : uses) { - auto slice = mlir::dyn_cast(use.getOwner()); + for (auto &slice : slices_to_rewrite) { auto offset = slice.getStartIndices()[sort_dim] - start; llvm::SmallVector new_start(slice.getStartIndices().begin(), slice.getStartIndices().end()); diff --git a/libspu/dialect/utils/BUILD.bazel b/libspu/dialect/utils/BUILD.bazel index acfb2b05..2182bc8d 100644 --- a/libspu/dialect/utils/BUILD.bazel +++ b/libspu/dialect/utils/BUILD.bazel @@ -22,6 +22,7 @@ spu_cc_library( hdrs = glob([ "*.h", ]), + visibility = ["//visibility:public"], deps = [ "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/libspu/kernel/hal/BUILD.bazel b/libspu/kernel/hal/BUILD.bazel index 56a64688..b5aeef41 100644 --- a/libspu/kernel/hal/BUILD.bazel +++ b/libspu/kernel/hal/BUILD.bazel @@ -127,6 +127,7 @@ spu_cc_test( deps = [ ":fxp_approx", "//libspu/kernel:test_util", + "//libspu/mpc/utils:simulate", ], ) @@ -232,7 +233,6 @@ spu_cc_library( hdrs = ["utils.h"], deps = [ ":constants", - ":polymorphic", ":ring", ":shape_ops", "//libspu/core:prelude", diff --git a/libspu/kernel/hal/constants.cc b/libspu/kernel/hal/constants.cc index 1b44cbd0..8659c5e3 100644 --- a/libspu/kernel/hal/constants.cc +++ b/libspu/kernel/hal/constants.cc @@ -103,7 +103,7 @@ spu::Value zeros(SPUContext* ctx, DataType dtype, const Shape& shape) { } Value iota(SPUContext* ctx, DataType dtype, int64_t numel) { - return DISPATCH_ALL_NONE_BOOL_PT_TYPES(getDecodeType(dtype), "iota", [&]() { + return DISPATCH_ALL_NONE_BOOL_PT_TYPES(getDecodeType(dtype), [&]() { std::vector arr(numel); std::iota(arr.begin(), arr.end(), 0); return constant(ctx, arr, dtype, {numel}); diff --git a/libspu/kernel/hal/fxp_approx.cc b/libspu/kernel/hal/fxp_approx.cc index 5cda83df..34667e84 100644 --- a/libspu/kernel/hal/fxp_approx.cc +++ b/libspu/kernel/hal/fxp_approx.cc @@ -29,18 +29,6 @@ namespace spu::kernel::hal { namespace detail { -Value EvaluatePolynomial(SPUContext* ctx, const Value& x, - absl::Span coefficients) { - auto poly = constant(ctx, coefficients[0], x.dtype(), x.shape()); - - for (size_t i = 1; i < coefficients.size(); ++i) { - auto c = constant(ctx, coefficients[i], x.dtype(), x.shape()); - poly = f_mul(ctx, poly, x); - poly = f_add(ctx, poly, c); - } - return poly; -} - Value log_minmax_normalized(SPUContext* ctx, const Value& x) { static std::array kLogCoefficient{ 0.0, 0.9999964239, -0.4998741238, 0.3317990258, -0.2407338084, @@ -69,12 +57,12 @@ Value log_minmax(SPUContext* ctx, const Value& x) { // get most significant non-zero bit of x // we avoid direct using detail::highestOneBit for saving one _prefix_or - auto pre_x1 = _rshift(ctx, pre_x, 1); + auto pre_x1 = _rshift(ctx, pre_x, {1}); auto msb = _xor(ctx, pre_x, pre_x1); // let x = x_norm * factor, where x in [1.0, 2.0) auto factor = _bitrev(ctx, msb, 0, 2 * num_fxp_bits + 1).setDtype(x.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits + 1); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits + 1); auto norm = f_mul(ctx, x, factor); // log(x) = log(x_norm * factor) @@ -83,7 +71,7 @@ Value log_minmax(SPUContext* ctx, const Value& x) { auto log_norm = log_minmax_normalized(ctx, norm); auto log2_e = _lshift(ctx, _sub(ctx, k, _constant(ctx, num_fxp_bits + 1, x.shape())), - num_fxp_bits) + {static_cast(num_fxp_bits)}) .setDtype(x.dtype()); auto k_log2 = constant(ctx, std::log(2), x.dtype(), x.shape()); auto log_e = f_mul(ctx, log2_e, k_log2); @@ -145,7 +133,7 @@ Value log2_pade(SPUContext* ctx, const Value& x) { // let x = x_norm * factor, where x in [0.5, 1.0) auto msb = detail::highestOneBit(ctx, x); auto factor = _bitrev(ctx, msb, 0, 2 * num_fxp_bits).setDtype(x.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); auto norm = f_mul(ctx, x, factor); // log2(x) = log2(x_norm * factor) @@ -154,7 +142,7 @@ Value log2_pade(SPUContext* ctx, const Value& x) { return _add( ctx, log2_pade_normalized(ctx, norm), _lshift(ctx, _sub(ctx, k, _constant(ctx, num_fxp_bits, x.shape())), - num_fxp_bits)) + {static_cast(num_fxp_bits)})) .setDtype(x.dtype()); } @@ -213,6 +201,31 @@ Value exp_taylor(SPUContext* ctx, const Value& x) { return res; } +Value exp_prime(SPUContext* ctx, const Value& x) { + auto clamped_x = x; + auto offset = ctx->config().experimental_exp_prime_offset(); + auto fxp = ctx->getFxpBits(); + if (!ctx->config().experimental_exp_prime_disable_lower_bound()) { + // currently the bound is tied to FM128 + SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128); + auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E; + clamped_x = _clamp_lower(ctx, clamped_x, + constant(ctx, lower_bound, x.dtype(), x.shape())) + .setDtype(x.dtype()); + } + if (ctx->config().experimental_exp_prime_enable_upper_bound()) { + // currently the bound is tied to FM128 + SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128); + auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E; + clamped_x = _clamp_upper(ctx, clamped_x, + constant(ctx, upper_bound, x.dtype(), x.shape())) + .setDtype(x.dtype()); + } + + auto ret = dynDispatch(ctx, "exp_a", clamped_x); + return ret.setDtype(x.dtype()); +} + namespace { // Pade approximation of exp2(x), x is in [0, 1]. @@ -260,15 +273,18 @@ Value exp2_pade(SPUContext* ctx, const Value& x) { const size_t bit_width = SizeOf(ctx->getField()) * 8; const auto x_bshare = _prefer_b(ctx, x); - const auto x_msb = _rshift(ctx, x_bshare, bit_width - 1); - auto x_integer = _rshift(ctx, x_bshare, fbits); + const auto x_msb = + _rshift(ctx, x_bshare, {static_cast(bit_width - 1)}); + auto x_integer = _rshift(ctx, x_bshare, {static_cast(fbits)}); auto x_fraction = - _sub(ctx, x, _lshift(ctx, x_integer, fbits)).setDtype(x.dtype()); + _sub(ctx, x, _lshift(ctx, x_integer, {static_cast(fbits)})) + .setDtype(x.dtype()); auto ret = exp2_pade_normalized(ctx, x_fraction); for (size_t idx = 0; idx < int_bits; idx++) { - auto a = _and(ctx, _rshift(ctx, x_integer, idx), k1); - detail::hintNumberOfBits(a, 1); + auto a = + _and(ctx, _rshift(ctx, x_integer, {static_cast(idx)}), k1); + a = detail::maskNumberOfBits(ctx, a, 1); a = _prefer_a(ctx, a); const auto K = 1U << std::min(1UL << idx, bit_width - 2); ret = _mul(ctx, ret, @@ -448,13 +464,22 @@ Value f_exp(SPUContext* ctx, const Value& x) { case RuntimeConfig::EXP_PADE: { // The valid input for exp_pade is [-kInputLimit, kInputLimit]. // TODO(junfeng): should merge clamp into exp_pade to save msb ops. - const float kInputLimit = 32 / std::log2(std::exp(1)); + const float kInputLimit = 32.0 / std::log2(std::exp(1)); const auto clamped_x = _clamp(ctx, x, constant(ctx, -kInputLimit, x.dtype(), x.shape()), constant(ctx, kInputLimit, x.dtype(), x.shape())) .setDtype(x.dtype()); return detail::exp_pade(ctx, clamped_x); } + case RuntimeConfig::EXP_PRIME: + if (ctx->hasKernel("exp_a")) { + return detail::exp_prime(ctx, x); + } else { + SPU_THROW( + "exp_a is not implemented for this protocol, currently only " + "2pc " + "semi2k is supported."); + } default: SPU_THROW("unexpected exp approximation method {}", ctx->config().fxp_exp_mode()); @@ -543,7 +568,7 @@ static Value rsqrt_init_guess(SPUContext* ctx, const Value& x, const Value& z) { // let u in [0.25, 0.5) auto z_rev = _bitrev(ctx, z, 0, 2 * f); - detail::hintNumberOfBits(z_rev, 2 * f); + z_rev = detail::maskNumberOfBits(ctx, z_rev, 2 * f); auto u = _trunc(ctx, _mul(ctx, x, z_rev)).setDtype(x.dtype()); @@ -583,17 +608,17 @@ static Value rsqrt_comp(SPUContext* ctx, const Value& x, const Value& z) { auto lo_mask = _constant(ctx, (static_cast(1) << (k / 2)) - 1, x.shape()); auto z_even = _and(ctx, z_sep, lo_mask); - auto z_odd = _and(ctx, _rshift(ctx, z_sep, k / 2), lo_mask); + auto z_odd = + _and(ctx, _rshift(ctx, z_sep, {static_cast(k / 2)}), lo_mask); // a[i] = z[2*i] ^ z[2*i+1] a = _xor(ctx, z_odd, z_even); // b ^= z[2*i] b = _bit_parity(ctx, z_even, k / 2); - detail::hintNumberOfBits(b, 1); } auto a_rev = _bitrev(ctx, a, 0, (f / 2) * 2); - detail::hintNumberOfBits(a_rev, (f / 2) * 2); + a_rev = detail::maskNumberOfBits(ctx, a_rev, (f / 2) * 2); // do compensation // Note: @@ -623,7 +648,7 @@ static Value rsqrt_np2(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); // let e = NP2(x), z = 2^(e+f) - return _lshift(ctx, detail::highestOneBit(ctx, x), 1); + return _lshift(ctx, detail::highestOneBit(ctx, x), {1}); } // Reference: diff --git a/libspu/kernel/hal/fxp_approx.h b/libspu/kernel/hal/fxp_approx.h index fa401887..44c724f3 100644 --- a/libspu/kernel/hal/fxp_approx.h +++ b/libspu/kernel/hal/fxp_approx.h @@ -38,6 +38,8 @@ Value exp2_pade(SPUContext* ctx, const Value& x); // Works for range [-12.0, 18.0] Value exp_pade(SPUContext* ctx, const Value& x); +Value exp_prime(SPUContext* ctx, const Value& x); + Value tanh_chebyshev(SPUContext* ctx, const Value& x); } // namespace detail diff --git a/libspu/kernel/hal/fxp_approx_test.cc b/libspu/kernel/hal/fxp_approx_test.cc index c79dc434..d540eb2b 100644 --- a/libspu/kernel/hal/fxp_approx_test.cc +++ b/libspu/kernel/hal/fxp_approx_test.cc @@ -20,6 +20,7 @@ #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/type_cast.h" #include "libspu/kernel/test_util.h" +#include "libspu/mpc/utils/simulate.h" namespace spu::kernel::hal { @@ -78,10 +79,35 @@ TEST(FxpTest, ExponentialPade) { << y; } +TEST(FxpTest, ExponentialPrime) { + std::cout << "test exp_prime" << std::endl; + spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { + RuntimeConfig conf; + conf.set_protocol(ProtocolKind::SEMI2K); + conf.set_field(FieldType::FM128); + conf.set_fxp_fraction_bits(40); + conf.set_experimental_enable_exp_prime(true); + SPUContext ctx = test::makeSPUContext(conf, lctx); + + auto offset = ctx.config().experimental_exp_prime_offset(); + auto fxp = ctx.getFxpBits(); + auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E; + auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E; + + xt::xarray x = xt::linspace(lower_bound, upper_bound, 4000); + + Value a = test::makeValue(&ctx, x, VIS_SECRET); + Value c = detail::exp_prime(&ctx, a); + auto y = dump_public_as(&ctx, reveal(&ctx, c)); + EXPECT_TRUE(xt::allclose(xt::exp(x), y, 0.01, 0.001)) + << xt::exp(x) << std::endl + << y; + }); +} + TEST(FxpTest, Log) { // GIVEN SPUContext ctx = test::makeSPUContext(); - xt::xarray x = {{0.05, 0.5}, {5, 50}}; // public log { diff --git a/libspu/kernel/hal/fxp_base.cc b/libspu/kernel/hal/fxp_base.cc index 209782e1..486f66b4 100644 --- a/libspu/kernel/hal/fxp_base.cc +++ b/libspu/kernel/hal/fxp_base.cc @@ -34,26 +34,34 @@ Value polynomial(SPUContext* ctx, const Value& x, SPU_ENFORCE(x.isFxp()); SPU_ENFORCE(!coeffs.empty()); - if (coeffs.size() == 1U) { + if (coeffs.size() == 1U || x.numel() == 0) { return coeffs[0]; } - Value x_pow = constant(ctx, 1.0F, x.dtype(), x.shape()); - Value res = _mul(ctx, x_pow, coeffs[0]); + // Use a parallel circuit to calculate x, x^2, x^3, ..., x^n. + // The general log(n) algorithm + // algorithm: + // Step 0. x + // Step 1. x, x2 + // Step 2. x, x2, x3, x4 + // ... + std::vector x_prefix(1, x); + size_t degree = coeffs.size() - 1; + for (int64_t i = 0; i < Log2Ceil(degree); ++i) { + size_t x_size = std::min(x_prefix.size(), degree - x_prefix.size()); + std::vector x_pow(x_size, x_prefix.back()); + // TODO: this can be further optimized to use sign hint + vmap(x_prefix.begin(), x_prefix.begin() + x_size, x_pow.begin(), + x_pow.end(), std::back_inserter(x_prefix), + [ctx, sign_x](const Value& a, const Value& b) { + return f_mul(ctx, a, b, sign_x); + }); + } + + Value res = _mul(ctx, constant(ctx, 1.0F, x.dtype(), x.shape()), coeffs[0]); const auto fbits = ctx->getFxpBits(); for (size_t i = 1; i < coeffs.size(); i++) { - if ((i & 1) == 0U) { - // x^{even order} is always positive - x_pow = _trunc(ctx, _mul(ctx, x_pow, x), fbits, SignType::Positive); - } else { - if (i > 1) { - x_pow = _trunc(ctx, _mul(ctx, x_pow, x), fbits, sign_x); - } else { - // i=1, then save a _trunc - x_pow = x; - } - } - res = _add(ctx, res, _mul(ctx, x_pow, coeffs[i])); + res = _add(ctx, res, _mul(ctx, x_prefix[i - 1], coeffs[i])); } return _trunc(ctx, res, fbits, sign_ret).setDtype(x.dtype()); @@ -72,7 +80,7 @@ Value polynomial(SPUContext* ctx, const Value& x, Value highestOneBit(SPUContext* ctx, const Value& x) { auto y = _prefix_or(ctx, x); - auto y1 = _rshift(ctx, y, 1); + auto y1 = _rshift(ctx, y, {1}); return _xor(ctx, y, y1); } @@ -85,8 +93,14 @@ void hintNumberOfBits(const Value& a, size_t nbits) { } } -namespace { +Value maskNumberOfBits(SPUContext* ctx, const Value& in, size_t nbits) { + auto k1 = constant(ctx, static_cast(1), spu::DT_I64, in.shape()); + auto mask = _sub(ctx, _lshift(ctx, k1, {static_cast(nbits)}), k1); + auto out = _and(ctx, in, mask).setDtype(in.dtype()); + return out; +} +namespace { Value reciprocal_goldschmidt_normalized_approx(SPUContext* ctx, const Value& b_abs, const Value& factor) { @@ -178,7 +192,8 @@ Value div_goldschmidt_general(SPUContext* ctx, const Value& a, const Value& b, // factor = 2^{f-m} = 2^{-m} * 2^f, the fixed point repr of 2^{-m} const size_t num_fxp_bits = ctx->getFxpBits(); auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); // also, we use factor twice factor = _prefer_a(ctx, factor); @@ -209,7 +224,7 @@ Value reciprocal_goldschmidt_positive(SPUContext* ctx, const Value& b_abs) { const size_t num_fxp_bits = ctx->getFxpBits(); auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b_abs.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); // also, we use factor twice factor = _prefer_a(ctx, factor); @@ -237,13 +252,12 @@ Value reciprocal_goldschmidt(SPUContext* ctx, const Value& b) { // factor = 2^{f-m} = 2^{-m} * 2^f, the fixed point repr of 2^{-m} const size_t num_fxp_bits = ctx->getFxpBits(); auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); // also, we use factor twice factor = _prefer_a(ctx, factor); // compute approximation of normalize b_abs auto r = reciprocal_goldschmidt_normalized_approx(ctx, b_abs, factor); - r = f_mul(ctx, r, factor, SignType::Positive); return _mux(ctx, is_negative, _negate(ctx, r), r).setDtype(b.dtype()); @@ -370,8 +384,8 @@ Value f_floor(SPUContext* ctx, const Value& x) { SPU_ENFORCE(x.isFxp()); - const size_t fbits = ctx->getFxpBits(); - return _lshift(ctx, _arshift(ctx, x, fbits), fbits).setDtype(x.dtype()); + const int64_t fbits = ctx->getFxpBits(); + return _lshift(ctx, _arshift(ctx, x, {fbits}), {fbits}).setDtype(x.dtype()); } Value f_ceil(SPUContext* ctx, const Value& x) { diff --git a/libspu/kernel/hal/fxp_base.h b/libspu/kernel/hal/fxp_base.h index c72022b6..61fe7bca 100644 --- a/libspu/kernel/hal/fxp_base.h +++ b/libspu/kernel/hal/fxp_base.h @@ -30,6 +30,8 @@ Value highestOneBit(SPUContext* ctx, const Value& x); void hintNumberOfBits(const Value& a, size_t nbits); +Value maskNumberOfBits(SPUContext* ctx, const Value& a, size_t nbits); + // we provide this general function to support some special cases (a or b has // guarranteed sign) in fxp_approx for better both performance and accuracy. Value div_goldschmidt_general(SPUContext* ctx, const Value& a, const Value& b, diff --git a/libspu/kernel/hal/fxp_cleartext.cc b/libspu/kernel/hal/fxp_cleartext.cc index 818e0b23..b5061d5b 100644 --- a/libspu/kernel/hal/fxp_cleartext.cc +++ b/libspu/kernel/hal/fxp_cleartext.cc @@ -57,7 +57,7 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& in, FN&& fn) { auto pt_type = getDecodeType(in.dtype()); for (auto iter = fp_arr.begin(); iter != fp_arr.end(); ++iter) { - DISPATCH_FLOAT_PT_TYPES(pt_type, "pt_type", [&]() { + DISPATCH_FLOAT_PT_TYPES(pt_type, [&]() { auto* ptr = reinterpret_cast(&*iter); *ptr = fn(*ptr); }); @@ -92,9 +92,9 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& x, const Value& y, for (auto itr_x = flp_x.begin(), itr_y = flp_y.begin(); itr_x != flp_x.end(); itr_x++, itr_y++) { - DISPATCH_FLOAT_PT_TYPES(x_pt_type, "x_pt_type", [&]() { + DISPATCH_FLOAT_PT_TYPES(x_pt_type, [&]() { auto* ptr_x = reinterpret_cast(&*itr_x); - DISPATCH_FLOAT_PT_TYPES(y_pt_type, "y_pt_type", [&]() { + DISPATCH_FLOAT_PT_TYPES(y_pt_type, [&]() { auto* ptr_y = reinterpret_cast(&*itr_y); *ptr_x = fn(*ptr_x, *ptr_y); }); diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc index 1c67c997..b7c34ff2 100644 --- a/libspu/kernel/hal/permute.cc +++ b/libspu/kernel/hal/permute.cc @@ -553,7 +553,8 @@ std::vector _bit_decompose(SPUContext *ctx, const spu::Value &x, rets_b.reserve(nbits); for (size_t bit = 0; bit < nbits; ++bit) { - auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit); + auto x_bshare_shift = + right_shift_logical(ctx, x_bshare, {static_cast(bit)}); rets_b.push_back(_and(ctx, x_bshare_shift, k1)); } @@ -703,10 +704,10 @@ spu::Value _apply_perm_ss(SPUContext *ctx, const Value &x, const Value &perm) { // Find mergeable keys from keys. Consecutive public/private(belong to one // owner) keys can be merged. Assume there are six keys, i.e., public_key0, -// bob_key0, bob_key1, alice_key0, alice_key1, secret_key0. We can merge the six -// keys into bob_new_key, alice_new_key, secret_key0 for the following sorting. -// This function will return a vector of indices [3,5,6] which means key[0,3), -// key[3,5), and key[5,6) can be merged. +// bob_key0, bob_key1, alice_key0, alice_key1, secret_key0. We can merge the +// six keys into bob_new_key, alice_new_key, secret_key0 for the following +// sorting. This function will return a vector of indices [3,5,6] which means +// key[0,3), key[3,5), and key[5,6) can be merged. std::vector _find_mergeable_keys(SPUContext *ctx, absl::Span keys) { std::vector split_indices; @@ -1158,7 +1159,7 @@ std::vector permute(SPUContext *ctx, for (auto const &input : inputs) { auto transposed = hal::transpose(ctx, input, perm); auto reshaped = hal::reshape(ctx, transposed, {N, W}); - inputs2d.push_back(reshaped); + inputs2d.push_back(std::move(reshaped)); } // Call permute1d for each dim to permute. diff --git a/libspu/kernel/hal/polymorphic.cc b/libspu/kernel/hal/polymorphic.cc index 34cf5f77..21680cc6 100644 --- a/libspu/kernel/hal/polymorphic.cc +++ b/libspu/kernel/hal/polymorphic.cc @@ -356,7 +356,7 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) { const auto bit_width = SizeOf(ctx->getField()) * 8; auto y_b = _prefer_b(ctx, y); - auto msb_y = _rshift(ctx, y_b, bit_width - 1); + auto msb_y = _rshift(ctx, y_b, {static_cast(bit_width - 1)}); auto x_abs1 = _equal(ctx, abs(ctx, x), k1); auto ret = _constant(ctx, 1, x.shape()); @@ -379,7 +379,9 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) { // e.g. y=0101, then ret = (x) * (1) * (x^(2^2)) * (1) = x^5 for (size_t idx = 0; idx < y_bits; idx++) { // x^(2^idx) * y_{idx} - auto cur_pow = _mux(ctx, _and(ctx, _rshift(ctx, y_b, idx), k1), base, k1); + auto cur_pow = _mux( + ctx, _and(ctx, _rshift(ctx, y_b, {static_cast(idx)}), k1), + base, k1); ret = _mul(ctx, cur_pow, ret); if (idx < y_bits - 1) { base = _mul(ctx, base, base); @@ -409,8 +411,9 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) { // the final sign is decided on both sign of x and the parity of y // when x<0 and y is odd, e.g. (-2)^3 = -8 - auto odd = _and(ctx, _rshift(ctx, y, ctx->getFxpBits()), - _constant(ctx, 1, y.shape())); + auto odd = + _and(ctx, _rshift(ctx, y, {static_cast(ctx->getFxpBits())}), + _constant(ctx, 1, y.shape())); auto sign = _and(ctx, msb, odd); return _mux(ctx, sign, _negate(ctx, val), val).setDtype(x.dtype()); @@ -488,19 +491,20 @@ Value bitcast(SPUContext* ctx, const Value& x, DataType dtype) { return Value(x.data().clone(), dtype); } -Value left_shift(SPUContext* ctx, const Value& x, size_t bits) { +Value left_shift(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_HAL_DISP(ctx, x, bits); return _lshift(ctx, x, bits).setDtype(x.dtype()); } -Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits) { +Value right_shift_logical(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_HAL_DISP(ctx, x, bits); return _rshift(ctx, x, bits).setDtype(x.dtype()); } -Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits) { +Value right_shift_arithmetic(SPUContext* ctx, const Value& x, + const Sizes& bits) { SPU_TRACE_HAL_DISP(ctx, x, bits); return _arshift(ctx, x, bits).setDtype(x.dtype()); diff --git a/libspu/kernel/hal/polymorphic.h b/libspu/kernel/hal/polymorphic.h index 58e90f00..9b47b85d 100644 --- a/libspu/kernel/hal/polymorphic.h +++ b/libspu/kernel/hal/polymorphic.h @@ -187,11 +187,12 @@ Value clamp(SPUContext* ctx, const Value& x, const Value& min, // @param dtype, second input value Value bitcast(SPUContext* ctx, const Value& x, DataType dtype); -Value left_shift(SPUContext* ctx, const Value& x, size_t bits); +Value left_shift(SPUContext* ctx, const Value& x, const Sizes& bits); -Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits); +Value right_shift_logical(SPUContext* ctx, const Value& x, const Sizes& bits); -Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits); +Value right_shift_arithmetic(SPUContext* ctx, const Value& x, + const Sizes& bits); /// the element-wise base-2 logarithm of x // @param in, should be positive, or the result is implementation defined. diff --git a/libspu/kernel/hal/prot_wrapper.cc b/libspu/kernel/hal/prot_wrapper.cc index 10743c10..7e03454d 100644 --- a/libspu/kernel/hal/prot_wrapper.cc +++ b/libspu/kernel/hal/prot_wrapper.cc @@ -30,11 +30,11 @@ namespace spu::kernel::hal { return mpc::NAME(ctx, in); \ } -#define MAP_SHIFT_OP(NAME) \ - Value _##NAME(SPUContext* ctx, const Value& in, size_t bits) { \ - SPU_TRACE_HAL_DISP(ctx, in, bits); \ - auto ret = mpc::NAME(ctx, in, bits); \ - return ret; \ +#define MAP_SHIFT_OP(NAME) \ + Value _##NAME(SPUContext* ctx, const Value& in, const Sizes& bits) { \ + SPU_TRACE_HAL_DISP(ctx, in, bits); \ + auto ret = mpc::NAME(ctx, in, bits); \ + return ret; \ } #define MAP_BITREV_OP(NAME) \ @@ -163,6 +163,10 @@ Value _s2v(SPUContext* ctx, const Value& in, int owner) { MAP_UNARY_OP(not_p) MAP_UNARY_OP(not_s) MAP_UNARY_OP(not_v) +// Negate family +MAP_UNARY_OP(negate_p) +MAP_UNARY_OP(negate_s) +MAP_UNARY_OP(negate_v) // Msb family MAP_UNARY_OP(msb_p) MAP_UNARY_OP(msb_s) diff --git a/libspu/kernel/hal/prot_wrapper.h b/libspu/kernel/hal/prot_wrapper.h index 6294f834..a3c138ca 100644 --- a/libspu/kernel/hal/prot_wrapper.h +++ b/libspu/kernel/hal/prot_wrapper.h @@ -44,6 +44,10 @@ Value _not_p(SPUContext* ctx, const Value& in); Value _not_s(SPUContext* ctx, const Value& in); Value _not_v(SPUContext* ctx, const Value& in); +Value _negate_p(SPUContext* ctx, const Value& in); +Value _negate_s(SPUContext* ctx, const Value& in); +Value _negate_v(SPUContext* ctx, const Value& in); + Value _msb_p(SPUContext* ctx, const Value& in); Value _msb_s(SPUContext* ctx, const Value& in); Value _msb_v(SPUContext* ctx, const Value& in); @@ -52,17 +56,17 @@ Value _equal_pp(SPUContext* ctx, const Value& x, const Value& y); std::optional _equal_sp(SPUContext* ctx, const Value& x, const Value& y); std::optional _equal_ss(SPUContext* ctx, const Value& x, const Value& y); -Value _lshift_p(SPUContext* ctx, const Value& in, size_t bits); -Value _lshift_s(SPUContext* ctx, const Value& in, size_t bits); -Value _lshift_v(SPUContext* ctx, const Value& in, size_t bits); +Value _lshift_p(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _lshift_s(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _lshift_v(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _rshift_p(SPUContext* ctx, const Value& in, size_t bits); -Value _rshift_s(SPUContext* ctx, const Value& in, size_t bits); -Value _rshift_v(SPUContext* ctx, const Value& in, size_t bits); +Value _rshift_p(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _rshift_s(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _rshift_v(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _arshift_p(SPUContext* ctx, const Value& in, size_t bits); -Value _arshift_s(SPUContext* ctx, const Value& in, size_t bits); -Value _arshift_v(SPUContext* ctx, const Value& in, size_t bits); +Value _arshift_p(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _arshift_s(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _arshift_v(SPUContext* ctx, const Value& in, const Sizes& bits); Value _trunc_p(SPUContext* ctx, const Value& in, size_t bits, SignType sign); Value _trunc_s(SPUContext* ctx, const Value& in, size_t bits, SignType sign); diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc index 08b873f6..725fd498 100644 --- a/libspu/kernel/hal/ring.cc +++ b/libspu/kernel/hal/ring.cc @@ -77,25 +77,25 @@ Value _cast_type(SPUContext* ctx, const Value& x, const Type& to) { SPU_THROW("unsupport unary op={} for {}", #Name, in); \ } \ } - IMPL_UNARY_OP(_not) +IMPL_UNARY_OP(_negate) IMPL_UNARY_OP(_msb) IMPL_UNARY_OP(_square) #undef IMPL_UNARY_OP -#define IMPL_SHIFT_OP(Name) \ - Value Name(SPUContext* ctx, const Value& in, size_t bits) { \ - SPU_TRACE_HAL_LEAF(ctx, in, bits); \ - if (in.isPublic()) { \ - return Name##_p(ctx, in, bits); \ - } else if (in.isSecret()) { \ - return Name##_s(ctx, in, bits); \ - } else if (in.isPrivate()) { \ - return Name##_v(ctx, in, bits); \ - } else { \ - SPU_THROW("unsupport unary op={} for {}", #Name, in); \ - } \ +#define IMPL_SHIFT_OP(Name) \ + Value Name(SPUContext* ctx, const Value& in, const Sizes& bits) { \ + SPU_TRACE_HAL_LEAF(ctx, in, bits); \ + if (in.isPublic()) { \ + return Name##_p(ctx, in, bits); \ + } else if (in.isSecret()) { \ + return Name##_s(ctx, in, bits); \ + } else if (in.isPrivate()) { \ + return Name##_v(ctx, in, bits); \ + } else { \ + SPU_THROW("unsupport unary op={} for {}", #Name, in); \ + } \ } IMPL_SHIFT_OP(_lshift) @@ -438,13 +438,6 @@ Value _equal(SPUContext* ctx, const Value& x, const Value& y) { _xor(ctx, _less(ctx, y, x), _k1)); } -Value _negate(SPUContext* ctx, const Value& x) { - SPU_TRACE_HAL_LEAF(ctx, x); - - // negate(x) = not(x) + 1 - return _add(ctx, _not(ctx, x), _constant(ctx, 1, x.shape())); -} - Value _sign(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); @@ -479,14 +472,25 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b) { Value _clamp(SPUContext* ctx, const Value& x, const Value& minv, const Value& maxv) { SPU_TRACE_HAL_LEAF(ctx, x, minv, maxv); - // clamp lower bound, res = x < minv ? minv : x auto res = _mux(ctx, _less(ctx, x, minv), minv, x); - // clamp upper bound, res = res < maxv ? res, maxv return _mux(ctx, _less(ctx, res, maxv), res, maxv); } +// TODO: refactor polymorphic, and may use select functions in polymorphic +Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv) { + SPU_TRACE_HAL_LEAF(ctx, x, minv); + // clamp lower bound, res = x < minv ? minv : x + return _mux(ctx, _less(ctx, x, minv), minv, x); +} + +Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv) { + SPU_TRACE_HAL_LEAF(ctx, x, maxv); + // clamp upper bound, x = x < maxv ? x, maxv + return _mux(ctx, _less(ctx, x, maxv), x, maxv); +} + Value _constant(SPUContext* ctx, uint128_t init, const Shape& shape) { return _make_p(ctx, init, shape); } @@ -497,7 +501,7 @@ Value _bit_parity(SPUContext* ctx, const Value& x, size_t bits) { SPU_ENFORCE(absl::has_single_bit(bits), "currently only support power of 2"); auto ret = _prefer_b(ctx, x); while (bits > 1) { - ret = _xor(ctx, ret, _rshift(ctx, ret, bits / 2)); + ret = _xor(ctx, ret, _rshift(ctx, ret, {static_cast(bits / 2)})); bits /= 2; } @@ -518,7 +522,7 @@ Value _popcount(SPUContext* ctx, const Value& x, size_t bits) { std::vector vs; vs.reserve(bits); for (size_t idx = 0; idx < bits; idx++) { - auto x_ = _rshift(ctx, xb, idx); + auto x_ = _rshift(ctx, xb, {static_cast(idx)}); x_ = _and(ctx, x_, _constant(ctx, 1U, x.shape())); if (x_.storage_type().isa()) { @@ -547,8 +551,8 @@ Value _prefix_or(SPUContext* ctx, const Value& x) { auto b0 = _prefer_b(ctx, x); const size_t bit_width = SizeOf(ctx->getField()) * 8; for (int idx = 0; idx < absl::bit_width(bit_width) - 1; idx++) { - const size_t offset = 1UL << idx; - auto b1 = _rshift(ctx, b0, offset); + const int64_t offset = 1L << idx; + auto b1 = _rshift(ctx, b0, {offset}); b0 = _or(ctx, b0, b1); } return b0; @@ -574,8 +578,8 @@ Value _bitdeintl(SPUContext* ctx, const Value& in) { // out = (out & keep) ^ ((out >> shift) & move) ^ ((out & move) << shift); out = _xor(ctx, _xor(ctx, _and(ctx, out, keep), - _and(ctx, _rshift(ctx, out, shift), move)), - _lshift(ctx, _and(ctx, out, move), shift)); + _and(ctx, _rshift(ctx, out, {shift}), move)), + _lshift(ctx, _and(ctx, out, move), {shift})); } return out; } diff --git a/libspu/kernel/hal/ring.h b/libspu/kernel/hal/ring.h index b2fb6f65..f0bbb01b 100644 --- a/libspu/kernel/hal/ring.h +++ b/libspu/kernel/hal/ring.h @@ -71,11 +71,11 @@ Value _equal(SPUContext* ctx, const Value& x, const Value& y); Value _less(SPUContext* ctx, const Value& x, const Value& y); -Value _lshift(SPUContext* ctx, const Value& in, size_t bits); +Value _lshift(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _rshift(SPUContext* ctx, const Value& in, size_t bits); +Value _rshift(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _arshift(SPUContext* ctx, const Value& in, size_t bits); +Value _arshift(SPUContext* ctx, const Value& in, const Sizes& bits); Value _trunc(SPUContext* ctx, const Value& x, size_t bits = 0, SignType sign = SignType::Unknown); @@ -88,6 +88,11 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b); // TODO: test me Value _clamp(SPUContext* ctx, const Value& x, const Value& minv, const Value& maxv); + +Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv); + +Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv); + // Make a public value from uint128_t init value. // // If the current working field has less than 128bit, the lower sizeof(field) diff --git a/libspu/kernel/hal/shape_ops.cc b/libspu/kernel/hal/shape_ops.cc index 6337719f..ffcfed60 100644 --- a/libspu/kernel/hal/shape_ops.cc +++ b/libspu/kernel/hal/shape_ops.cc @@ -43,10 +43,10 @@ Value update_slice(SPUContext* ctx, const Value& in, const Value& update, SPU_TRACE_HAL_DISP(ctx, in, start_indices); if (in.storage_type() != update.storage_type()) { - auto u = - _cast_type(ctx, update, in.storage_type()).setDtype(update.dtype()); - - return update_slice(ctx, in, u, start_indices); + auto ct = _common_type(ctx, update.storage_type(), in.storage_type()); + auto u = _cast_type(ctx, update, ct).setDtype(update.dtype()); + auto i = _cast_type(ctx, in, ct).setDtype(in.dtype()); + return update_slice(ctx, i, u, start_indices); } return _update_slice(ctx, in, update, start_indices).setDtype(in.dtype()); diff --git a/libspu/kernel/hal/type_cast.cc b/libspu/kernel/hal/type_cast.cc index 53b65742..89ad0142 100644 --- a/libspu/kernel/hal/type_cast.cc +++ b/libspu/kernel/hal/type_cast.cc @@ -27,7 +27,8 @@ Value int2fxp(SPUContext* ctx, const Value& x, DataType to_type) { SPU_TRACE_HAL_LEAF(ctx, x); SPU_ENFORCE(x.isInt(), "expect integer, got {}", x.dtype()); - return _lshift(ctx, x, ctx->getFxpBits()).setDtype(to_type); + return _lshift(ctx, x, {static_cast(ctx->getFxpBits())}) + .setDtype(to_type); } // Casting fxp to integer. @@ -49,12 +50,12 @@ Value fxp2int(SPUContext* ctx, const Value& x, DataType to_type) { SPU_TRACE_HAL_LEAF(ctx, x); SPU_ENFORCE(x.isFxp()); - const size_t fxp_bits = ctx->getFxpBits(); + const int64_t fxp_bits = ctx->getFxpBits(); const Value kOneMinusEps = _constant(ctx, (1 << fxp_bits) - 1, x.shape()); // (x + 0.99 * (x < 0)) >> fxp_bits return _arshift(ctx, _add(ctx, x, _mul(ctx, kOneMinusEps, _msb(ctx, x))), - fxp_bits) + {fxp_bits}) .setDtype(to_type); } @@ -77,6 +78,12 @@ Value reveal(SPUContext* ctx, const Value& x) { return _s2p(ctx, x).setDtype(x.dtype()); } +Value reveal_to(SPUContext* ctx, const Value& x, size_t rank) { + SPU_TRACE_HAL_LEAF(ctx, x, rank); + SPU_ENFORCE(x.isSecret()); + return _s2v(ctx, x, rank).setDtype(x.dtype()); +} + Value dtype_cast(SPUContext* ctx, const Value& in, DataType to_type) { SPU_TRACE_HAL_DISP(ctx, in, to_type); diff --git a/libspu/kernel/hal/type_cast.h b/libspu/kernel/hal/type_cast.h index cfcc05fb..43ee72bb 100644 --- a/libspu/kernel/hal/type_cast.h +++ b/libspu/kernel/hal/type_cast.h @@ -35,4 +35,9 @@ Value seal(SPUContext* ctx, const Value& x); // @param in, the input value Value reveal(SPUContext* ctx, const Value& x); +/// reveal a secret to a specific party +// @param in, the input value +// @param rank, the rank of the party to reveal to +Value reveal_to(SPUContext* ctx, const Value& x, size_t rank); + } // namespace spu::kernel::hal diff --git a/libspu/kernel/hal/utils.h b/libspu/kernel/hal/utils.h index eb1ab1bc..3d2e218b 100644 --- a/libspu/kernel/hal/utils.h +++ b/libspu/kernel/hal/utils.h @@ -16,13 +16,29 @@ #include "libspu/core/context.h" #include "libspu/core/value.h" +#include "libspu/core/vectorize.h" #include "libspu/kernel/hal/constants.h" -#include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/ring.h" #include "libspu/kernel/hal/shape_ops.h" namespace spu::kernel::hal { +////////////////////////////////////////////////////////////////////////////// +// Shape utils +////////////////////////////////////////////////////////////////////////////// + +/// the squeeze function, i.e., removes dimensions of size 1 from the shape of +/// a tensor. +// @param in, the input +// @param dim, the dimension to be squeezed +Value squeeze(SPUContext* ctx, const Value& in, int64_t dim = 0); + +/// the unsqueeze function, i.e., expands a tensor with a length 1 axis +/// inserted at index axis. +// @param in, the input +// @param dim, the dimension to be unsqueezed +Value unsqueeze(SPUContext* ctx, const Value& in, int64_t dim = 0); + // This is SPU's version of JAX's associative_scan // See: // https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html @@ -32,61 +48,82 @@ namespace spu::kernel::hal { // for the detailed algorithm explanation // // fn: an associative binary Function -// in: a 1-d tensor +// in: a tensor, scan the last axis template spu::Value associative_scan(Fn&& fn, SPUContext* ctx, const Value& in) { - SPU_ENFORCE(in.shape().ndim() == 1U, "input should be 1d"); - const auto numel = in.numel(); - if (numel < 2) { + SPU_ENFORCE(in.shape().ndim() >= 1U, "input should not be scalar"); + // First reshape to 2D {M, N} tensor, scan each N elements + const Shape shape = in.shape(); + const auto N = shape.back(); + // in case some empty tensors + if (N < 2 || shape.numel() == 0) { return in; } + const auto M = shape.numel() / N; + spu::Value in_2d = hal::reshape(ctx, in, {M, N}); // merge consecutive even/odd index elements - auto reduced_elems = fn(ctx, hal::slice(ctx, in, {0}, {numel - 1}, {2}), - hal::slice(ctx, in, {1}, {numel}, {2})); - // process half elements recursively and get odd index elements - auto odd_elems = associative_scan(fn, ctx, reduced_elems); + spu::Value odd_elems; + std::vector odd_vec; + std::vector even_vec; + { + for (int64_t i = 0; i < M; ++i) { + odd_vec.push_back(hal::slice(ctx, in_2d, {i, 0}, {i + 1, N - 1}, {1, 2})); + even_vec.push_back(hal::slice(ctx, in_2d, {i, 1}, {i + 1, N}, {1, 2})); + } + std::vector reduced_elems_vec; + vmap(odd_vec.begin(), odd_vec.end(), even_vec.begin(), even_vec.end(), + std::back_inserter(reduced_elems_vec), + [&](const spu::Value& odd, const spu::Value& even) { + return fn(ctx, odd, even); + }); + + auto concat_reduced_elems = hal::concatenate(ctx, reduced_elems_vec, 0); + + // process half elements recursively and get odd index elements + odd_elems = associative_scan(fn, ctx, concat_reduced_elems); + } // get even index elements + odd_vec.clear(); + even_vec.clear(); spu::Value even_elems; - if (numel % 2 == 0) { - even_elems = - fn(ctx, hal::slice(ctx, odd_elems, {0}, {odd_elems.numel() - 1}, {1}), - hal::slice(ctx, in, {2}, {numel}, {2})); - } else { - even_elems = fn(ctx, odd_elems, hal::slice(ctx, in, {2}, {numel}, {2})); + { + std::vector even_elems_vec; + for (int64_t i = 0; i < M; ++i) { + if (N % 2 == 0) { + odd_vec.push_back(hal::slice(ctx, odd_elems, {i, 0}, + {i + 1, odd_elems.shape().back() - 1}, + {1, 1})); + } else { + odd_vec.push_back(hal::slice(ctx, odd_elems, {i, 0}, + {i + 1, odd_elems.shape().back()}, {})); + } + even_vec.push_back(hal::slice(ctx, in_2d, {i, 2}, {i + 1, N}, {1, 2})); + } + vmap(odd_vec.begin(), odd_vec.end(), even_vec.begin(), even_vec.end(), + std::back_inserter(even_elems_vec), + [&](const spu::Value& odd, const spu::Value& even) { + return fn(ctx, odd, even); + }); + + even_elems = hal::concatenate(ctx, even_elems_vec, 0); } // concat the 0th element - auto final_even_elems = - hal::concatenate(ctx, {hal::slice(ctx, in, {0}, {1}), even_elems}, 0); + auto final_even_elems = hal::concatenate( + ctx, {hal::slice(ctx, in_2d, {0, 0}, {M, 1}), even_elems}, 1); // concat even and odd elems interleavely auto zero = hal::constant(ctx, 0U, in.dtype(), {1}); - auto pad_even = - hal::pad(ctx, final_even_elems, zero, {0}, - {final_even_elems.numel() == odd_elems.numel() ? 1 : 0}, {1}); - auto pad_odd = - hal::pad(ctx, odd_elems, zero, {1}, - {final_even_elems.numel() == odd_elems.numel() ? 0 : 1}, {1}); + auto pad_even = hal::pad( + ctx, final_even_elems, zero, {0, 0}, + {0, final_even_elems.numel() == odd_elems.numel() ? 1 : 0}, {0, 1}); + auto pad_odd = hal::pad( + ctx, odd_elems, zero, {0, 1}, + {0, final_even_elems.numel() == odd_elems.numel() ? 0 : 1}, {0, 1}); auto ret = hal::_add(ctx, pad_even, pad_odd).setDtype(in.dtype()); - return ret; + return hal::reshape(ctx, ret, in.shape()); } -////////////////////////////////////////////////////////////////////////////// -// Shape utils -////////////////////////////////////////////////////////////////////////////// - -/// the squeeze function, i.e., removes dimensions of size 1 from the shape of a -/// tensor. -// @param in, the input -// @param dim, the dimension to be squeezed -Value squeeze(SPUContext* ctx, const Value& in, int64_t dim = 0); - -/// the unsqueeze function, i.e., expands a tensor with a length 1 axis inserted -/// at index axis. -// @param in, the input -// @param dim, the dimension to be unsqueezed -Value unsqueeze(SPUContext* ctx, const Value& in, int64_t dim = 0); - } // namespace spu::kernel::hal \ No newline at end of file diff --git a/libspu/kernel/hal/utils_test.cc b/libspu/kernel/hal/utils_test.cc index 087f83f7..f221693c 100644 --- a/libspu/kernel/hal/utils_test.cc +++ b/libspu/kernel/hal/utils_test.cc @@ -24,7 +24,7 @@ namespace spu::kernel::hal { namespace { -TEST(UtilsTest, associative_scan) { +TEST(UtilsTest, associative_scan_1d) { SPUContext ctx = test::makeSPUContext(); { @@ -82,6 +82,62 @@ TEST(UtilsTest, associative_scan) { } } +TEST(UtilsTest, associative_scan_2d) { + SPUContext ctx = test::makeSPUContext(); + + { + const xt::xarray x = {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}; + const xt::xarray prefix_sum = {{1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}}; + Value a = test::makeValue(&ctx, x, VIS_SECRET); + Value b = associative_scan(hal::add, &ctx, a); + auto ret = dump_public_as(&ctx, hal::reveal(&ctx, b)); + EXPECT_TRUE(prefix_sum == ret) << x << std::endl + << prefix_sum << std::endl + << ret; + } + + { + const xt::xarray x = {{1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}}; + const xt::xarray prefix_prod = {{1, 2, 6, 24, 120}, + {1, 2, 6, 24, 120}}; + Value a = test::makeValue(&ctx, x, VIS_SECRET); + Value b = associative_scan(hal::mul, &ctx, a); + auto ret = dump_public_as(&ctx, hal::reveal(&ctx, b)); + EXPECT_TRUE(prefix_prod == ret) << x << std::endl + << prefix_prod << std::endl + << ret; + } + + { + const xt::xarray x = {{true, true, true, false, true, false}, + {true, true, true, false, true, false}}; + const xt::xarray prefix_and = { + {true, true, true, false, false, false}, + {true, true, true, false, false, false}}; + + Value a = test::makeValue(&ctx, x, VIS_SECRET); + Value b = associative_scan(hal::bitwise_and, &ctx, a); + auto ret = dump_public_as(&ctx, hal::reveal(&ctx, b)); + EXPECT_TRUE(prefix_and == ret) << x << std::endl + << prefix_and << std::endl + << ret; + } + + { + const xt::xarray x = {{true, true, true, false, true, false}, + {true, true, true, false, true, false}}; + const xt::xarray prefix_or = {{true, true, true, true, true, true}, + {true, true, true, true, true, true}}; + + Value a = test::makeValue(&ctx, x, VIS_SECRET); + Value b = associative_scan(hal::bitwise_or, &ctx, a); + auto ret = dump_public_as(&ctx, hal::reveal(&ctx, b)); + EXPECT_TRUE(prefix_or == ret) << x << std::endl + << prefix_or << std::endl + << ret; + } +} + TEST(UtilsTest, Squeeze) { // GIVEN xt::xarray x = xt::ones({2, 1, 2, 1, 2}); diff --git a/libspu/kernel/hlo/BUILD.bazel b/libspu/kernel/hlo/BUILD.bazel index e798fa5e..d9ee64ca 100644 --- a/libspu/kernel/hlo/BUILD.bazel +++ b/libspu/kernel/hlo/BUILD.bazel @@ -110,6 +110,7 @@ spu_cc_test( ":casting", ":const", "//libspu/kernel:test_util", + "//libspu/mpc/utils:simulate", ], ) diff --git a/libspu/kernel/hlo/basic_unary.cc b/libspu/kernel/hlo/basic_unary.cc index ce1d2c17..44cc9e36 100644 --- a/libspu/kernel/hlo/basic_unary.cc +++ b/libspu/kernel/hlo/basic_unary.cc @@ -118,19 +118,19 @@ spu::Value Round_RNTE(SPUContext *ctx, const spu::Value &in) { // so comp = b && (c || a) SPU_ENFORCE(!in.isComplex()); SPU_ENFORCE(in.isFxp(), "Round only supports fxp"); - const auto fxp_bits = ctx->getFxpBits(); + const int64_t fxp_bits = ctx->getFxpBits(); const auto k1 = hal::_constant(ctx, 1U, in.shape()); auto x_prime = hal::_prefer_b(ctx, in); auto y = hal::floor(ctx, x_prime); - auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits), k1); - auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits - 1), k1); + auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, {fxp_bits}), k1); + auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, {fxp_bits - 1}), k1); std::vector cs; cs.reserve(fxp_bits - 1); - for (size_t idx = 0; idx < fxp_bits - 1; idx++) { - auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, idx), k1); + for (int64_t idx = 0; idx < fxp_bits - 1; idx++) { + auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, {idx}), k1); cs.push_back(std::move(x_)); } auto c = vreduce(cs.begin(), cs.end(), [&](const Value &a, const Value &b) { diff --git a/libspu/kernel/hlo/casting.cc b/libspu/kernel/hlo/casting.cc index 1b87c653..b3c0993c 100644 --- a/libspu/kernel/hlo/casting.cc +++ b/libspu/kernel/hlo/casting.cc @@ -52,6 +52,10 @@ spu::Value Reveal(SPUContext *ctx, const spu::Value &in) { return hal::reveal(ctx, in); } +spu::Value RevealTo(SPUContext *ctx, const spu::Value &in, size_t rank) { + return hal::reveal_to(ctx, in, rank); +} + spu::Value Seal(SPUContext *ctx, const spu::Value &in) { return hal::seal(ctx, in); } diff --git a/libspu/kernel/hlo/casting.h b/libspu/kernel/hlo/casting.h index 469ba42d..cf5a5be9 100644 --- a/libspu/kernel/hlo/casting.h +++ b/libspu/kernel/hlo/casting.h @@ -29,6 +29,8 @@ spu::Value Bitcast(SPUContext *ctx, const spu::Value &in, DataType dst_dtype); spu::Value Reveal(SPUContext *ctx, const spu::Value &in); +spu::Value RevealTo(SPUContext *ctx, const spu::Value &in, size_t rank); + spu::Value Seal(SPUContext *ctx, const spu::Value &in); } // namespace spu::kernel::hlo diff --git a/libspu/kernel/hlo/casting_test.cc b/libspu/kernel/hlo/casting_test.cc index f103a155..6e59b4b3 100644 --- a/libspu/kernel/hlo/casting_test.cc +++ b/libspu/kernel/hlo/casting_test.cc @@ -20,23 +20,48 @@ #include "libspu/core/value.h" #include "libspu/kernel/hlo/const.h" #include "libspu/kernel/test_util.h" +#include "libspu/mpc/utils/simulate.h" namespace spu::kernel::hlo { -TEST(ConstTest, Empty) { - SPUContext sctx = test::makeSPUContext(); +class CastingTest + : public ::testing::TestWithParam> {}; - auto empty_c = Constant(&sctx, true, {0}); +TEST_P(CastingTest, Empty) { + FieldType field = std::get<0>(GetParam()); + ProtocolKind prot = std::get<1>(GetParam()); - // Seal - auto s_empty = Seal(&sctx, empty_c); + mpc::utils::simulate( + 3, [&](const std::shared_ptr &lctx) { + SPUContext sctx = test::makeSPUContext(prot, field, lctx); + auto empty_c = Constant(&sctx, true, {0}); - // Reveal - auto p_empty = Reveal(&sctx, s_empty); + // Seal + auto s_empty = Seal(&sctx, empty_c); - EXPECT_EQ(p_empty.numel(), 0); - EXPECT_EQ(p_empty.shape().size(), 1); - EXPECT_EQ(p_empty.shape()[0], 0); + // Reveal + auto p_empty = Reveal(&sctx, s_empty); + + // RevealTo + auto v_empty = RevealTo(&sctx, s_empty, 0); + + EXPECT_EQ(p_empty.numel(), 0); + EXPECT_EQ(p_empty.shape().size(), 1); + EXPECT_EQ(p_empty.shape()[0], 0); + + EXPECT_EQ(v_empty.numel(), 0); + EXPECT_EQ(v_empty.shape().size(), 1); + EXPECT_EQ(v_empty.shape()[0], 0); + }); } +INSTANTIATE_TEST_SUITE_P( + CastingTestInstances, CastingTest, + testing::Combine(testing::Values(FieldType::FM64, FieldType::FM128), + testing::Values(ProtocolKind::REF2K, ProtocolKind::SEMI2K, + ProtocolKind::ABY3)), + [](const testing::TestParamInfo &p) { + return fmt::format("{}x{}", std::get<0>(p.param), std::get<1>(p.param)); + }); + } // namespace spu::kernel::hlo diff --git a/libspu/kernel/hlo/shift.cc b/libspu/kernel/hlo/shift.cc index 08fe7140..0b6f3f1d 100644 --- a/libspu/kernel/hlo/shift.cc +++ b/libspu/kernel/hlo/shift.cc @@ -26,31 +26,8 @@ namespace spu::kernel::hlo { template spu::Value shift_impl_p(SPUContext *ctx, const spu::Value &lhs, const spu::Value &rhs, const Fn &f) { - auto shift_bits = hal::dump_public_as(ctx, rhs); - if (std::all_of(rhs.strides().begin(), rhs.strides().end(), - [](int64_t s) { return s == 0; })) { - // rhs is a splat - return f(ctx, lhs, shift_bits[0]); - } - - // Not a splat... - spu::Value ret = - hal::constant(ctx, static_cast(0), lhs.dtype(), lhs.shape()); - auto dtype_size = getWidth(lhs.dtype()); - for (size_t bits = 0; bits < dtype_size; ++bits) { - if (std::none_of(shift_bits.begin(), shift_bits.end(), [&bits](int8_t b) { - return b == static_cast(bits); - })) { - continue; - } - auto current_bits = hal::constant(ctx, static_cast(bits), - rhs.dtype(), rhs.shape()); - auto mask = hal::equal(ctx, rhs, current_bits); - auto shifted = f(ctx, lhs, bits); - ret = hal::add(ctx, ret, hal::mul(ctx, mask, shifted)); - } - - return ret; + auto shift_bits = hal::dump_public_as(ctx, rhs); + return f(ctx, lhs, {shift_bits.begin(), shift_bits.end()}); } template @@ -63,7 +40,7 @@ spu::Value shift_impl_s(SPUContext *ctx, const spu::Value &lhs, auto current_bits = hal::constant(ctx, static_cast(bits), rhs.dtype(), rhs.shape()); auto mask = hal::equal(ctx, rhs, current_bits); - auto shifted = f(ctx, lhs, bits); + auto shifted = f(ctx, lhs, {static_cast(bits)}); ret = hal::add(ctx, ret, hal::mul(ctx, mask, shifted)); } diff --git a/libspu/mpc/ab_api.cc b/libspu/mpc/ab_api.cc index 7d02fe47..6ba634ad 100644 --- a/libspu/mpc/ab_api.cc +++ b/libspu/mpc/ab_api.cc @@ -86,7 +86,7 @@ Value rand_b(SPUContext* ctx, const Shape& shape) { FORCE_DISPATCH(ctx, shape); } -Value not_a(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); } +Value negate_a(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); } Value add_ap(SPUContext* ctx, const Value& x, const Value& y) { FORCE_DISPATCH(ctx, x, y); @@ -133,7 +133,7 @@ OptionalAPI mul_a1bv(SPUContext* ctx, const Value& x, const Value& y) { return NotAvailable; } -Value lshift_a(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_a(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } @@ -201,15 +201,15 @@ OptionalAPI xor_bv(SPUContext* ctx, const Value& x, const Value& y) { return NotAvailable; } -Value lshift_b(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value rshift_b(SPUContext* ctx, const Value& x, size_t nbits) { +Value rshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value arshift_b(SPUContext* ctx, const Value& x, size_t nbits) { +Value arshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } @@ -254,10 +254,10 @@ Value bitintl_b(SPUContext* ctx, const Value& x, size_t stride) { auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape()); int64_t S = static_cast(1) << idx; // out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S); - out = xor_bb( - ctx, - xor_bb(ctx, and_bp(ctx, out, K), and_bp(ctx, rshift_b(ctx, out, S), M)), - lshift_b(ctx, and_bp(ctx, out, M), S)); + out = xor_bb(ctx, + xor_bb(ctx, and_bp(ctx, out, K), + and_bp(ctx, rshift_b(ctx, out, {S}), M)), + lshift_b(ctx, and_bp(ctx, out, M), {S})); } out = setNumBits(out, numBits(x)); return out; @@ -283,10 +283,10 @@ Value bitdeintl_b(SPUContext* ctx, const Value& x, size_t stride) { auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape()); int64_t S = static_cast(1) << idx; // out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S); - out = xor_bb( - ctx, - xor_bb(ctx, and_bp(ctx, out, K), and_bp(ctx, rshift_b(ctx, out, S), M)), - lshift_b(ctx, and_bp(ctx, out, M), S)); + out = xor_bb(ctx, + xor_bb(ctx, and_bp(ctx, out, K), + and_bp(ctx, rshift_b(ctx, out, {S}), M)), + lshift_b(ctx, and_bp(ctx, out, M), {S})); } out = setNumBits(out, numBits(x)); return out; @@ -318,9 +318,9 @@ Value ppa_kogge_stone(SPUContext* ctx, const Value& lhs, const Value& rhs, auto G = and_bb(ctx, lhs, rhs); for (int idx = 0; idx < Log2Ceil(nbits); ++idx) { - const size_t offset = 1UL << idx; - auto G1 = lshift_b(ctx, G, offset); - auto P1 = lshift_b(ctx, P, offset); + const int64_t offset = static_cast(1) << idx; + auto G1 = lshift_b(ctx, G, {offset}); + auto P1 = lshift_b(ctx, P, {offset}); // P1 = P & P1 // G1 = G ^ (P & G1) @@ -332,7 +332,7 @@ Value ppa_kogge_stone(SPUContext* ctx, const Value& lhs, const Value& rhs, } // out = (G << 1) ^ p0 - auto C = lshift_b(ctx, G, 1); + auto C = lshift_b(ctx, G, {1}); return xor_bb(ctx, xor_bb(ctx, lhs, rhs), C); } @@ -343,7 +343,7 @@ std::pair bit_scatter(SPUContext* ctx, const Value& in, SPU_ENFORCE(absl::has_single_bit(nbits), "unsupported {}", nbits); auto out = bitdeintl_b(ctx, in, stride); - auto hi = rshift_b(ctx, out, nbits / 2); + auto hi = rshift_b(ctx, out, {static_cast(nbits / 2)}); auto mask = hack_make_p(ctx, (static_cast(1) << (nbits / 2)) - 1, in.shape()); auto lo = and_bp(ctx, out, mask); @@ -357,7 +357,7 @@ Value bit_gather(SPUContext* ctx, const Value& hi, const Value& lo, SPU_ENFORCE(nbits == numBits(lo), "nbits mismatch {}, {}", nbits, numBits(lo)); - auto out = xor_bb(ctx, lshift_b(ctx, hi, nbits), lo); + auto out = xor_bb(ctx, lshift_b(ctx, hi, {static_cast(nbits)}), lo); return bitintl_b(ctx, out, stride); } @@ -395,8 +395,8 @@ Value ppa_sklansky(SPUContext* ctx, Value const& lhs, Value const& rhs, auto Gs = and_bp(ctx, Gl, s_mask); auto Ps = and_bp(ctx, Pl, s_mask); for (int j = 0; j < idx; j++) { - Gs = xor_bb(ctx, Gs, rshift_b(ctx, Gs, 1 << j)); - Ps = xor_bb(ctx, Ps, rshift_b(ctx, Ps, 1 << j)); + Gs = xor_bb(ctx, Gs, rshift_b(ctx, Gs, {1 << j})); + Ps = xor_bb(ctx, Ps, rshift_b(ctx, Ps, {1 << j})); } // SPU_ENFORCE(numBits(Ps) == bit_width / 2); // SPU_ENFORCE(numBits(Gs) == bit_width / 2); @@ -416,7 +416,7 @@ Value ppa_sklansky(SPUContext* ctx, Value const& lhs, Value const& rhs, } // out = (G0 << 1) ^ p0 - auto C = lshift_b(ctx, G, 1); + auto C = lshift_b(ctx, G, {1}); return xor_bb(ctx, xor_bb(ctx, lhs, rhs), C); } @@ -460,8 +460,8 @@ Value carry_a2b(SPUContext* ctx, const Value& x, const Value& y, size_t k) { while (k > 1) { if (k % 2 != 0) { k += 1; - P = lshift_b(ctx, P, 1); - G = lshift_b(ctx, G, 1); + P = lshift_b(ctx, P, {1}); + G = lshift_b(ctx, G, {1}); } auto [P1, P0] = bit_scatter(ctx, P, 0); auto [G1, G0] = bit_scatter(ctx, G, 0); diff --git a/libspu/mpc/ab_api.h b/libspu/mpc/ab_api.h index 72a8476f..68cbb23f 100644 --- a/libspu/mpc/ab_api.h +++ b/libspu/mpc/ab_api.h @@ -29,7 +29,7 @@ Value msb_a2b(SPUContext* ctx, const Value& x); Value rand_a(SPUContext* ctx, const Shape& shape); Value rand_b(SPUContext* ctx, const Shape& shape); -Value not_a(SPUContext* ctx, const Value& x); +Value negate_a(SPUContext* ctx, const Value& x); Value equal_ap(SPUContext* ctx, const Value& x, const Value& y); Value equal_aa(SPUContext* ctx, const Value& x, const Value& y); @@ -46,7 +46,7 @@ OptionalAPI mul_av(SPUContext* ctx, const Value& x, const Value& y); Value mul_a1b(SPUContext* ctx, const Value& x, const Value& y); OptionalAPI mul_a1bv(SPUContext* ctx, const Value& x, const Value& y); -Value lshift_a(SPUContext* ctx, const Value& x, size_t nbits); +Value lshift_a(SPUContext* ctx, const Value& x, const Sizes& nbits); Value trunc_a(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); Value mmul_ap(SPUContext* ctx, const Value& x, const Value& y); @@ -72,9 +72,9 @@ Value xor_bb(SPUContext* ctx, const Value& x, const Value& y); OptionalAPI xor_bv(SPUContext* ctx, const Value& x, const Value& y); // TODO -Value lshift_b(SPUContext* ctx, const Value& x, size_t nbits); -Value rshift_b(SPUContext* ctx, const Value& x, size_t nbits); -Value arshift_b(SPUContext* ctx, const Value& x, size_t nbits); +Value lshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value rshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value arshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits); // Bit reverse for binary share. Value bitrev_b(SPUContext* ctx, const Value& x, size_t start, size_t end); diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc index 13ef4080..82d802b3 100644 --- a/libspu/mpc/ab_api_test.cc +++ b/libspu/mpc/ab_api_test.cc @@ -195,7 +195,7 @@ TEST_P(ArithmeticTest, MulA1B) { return; } - const size_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field()) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH @@ -204,12 +204,11 @@ TEST_P(ArithmeticTest, MulA1B) { auto p1 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH ? Shape({200, 26}) : kShape); - p1 = rshift_p(obj.get(), p1, K - 1); auto a0 = p2a(obj.get(), p0); auto a1 = p2b(obj.get(), p1); // hint runtime this is a 1bit value. - a1 = lshift_b(obj.get(), a1, K - 1); - a1 = rshift_b(obj.get(), a1, K - 1); + // Sometimes, the underlying value is not strictly 1bit + a1.storage_type().as()->setNbits(1); /* WHEN */ auto prev = obj->prot()->getState()->getStats(); @@ -217,7 +216,9 @@ TEST_P(ArithmeticTest, MulA1B) { auto cost = obj->prot()->getState()->getStats() - prev; auto r_aa = a2p(obj.get(), tmp); - auto r_pp = mul_pp(obj.get(), p0, p1); + auto r_pp = + mul_pp(obj.get(), p0, + rshift_p(obj.get(), lshift_p(obj.get(), p1, {K - 1}), {K - 1})); /* THEN */ EXPECT_VALUE_EQ(r_aa, r_pp); @@ -238,12 +239,12 @@ TEST_P(ArithmeticTest, MulAV) { return; } - const size_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field()) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); auto p1 = rand_p(obj.get(), kShape); - p1 = rshift_p(obj.get(), p1, K - 1); + p1 = rshift_p(obj.get(), p1, {K - 1}); auto a0 = p2a(obj.get(), p0); auto a1 = p2v(obj.get(), p1, 0); @@ -275,17 +276,17 @@ TEST_P(ArithmeticTest, MulA1BV) { return; } - const size_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field()) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); auto p1 = rand_p(obj.get(), kShape); - p1 = rshift_p(obj.get(), p1, K - 1); + p1 = rshift_p(obj.get(), p1, {K - 1}); auto a0 = p2a(obj.get(), p0); auto a1 = p2v(obj.get(), p1, 0); // hint runtime this is a 1bit value. - a1 = lshift_v(obj.get(), a1, K - 1); - a1 = rshift_v(obj.get(), a1, K - 1); + a1 = lshift_v(obj.get(), a1, {K - 1}); + a1 = rshift_v(obj.get(), a1, {K - 1}); // auto a1 = b2v(obj.get(), _a1, 0); /* WHEN */ @@ -436,7 +437,7 @@ TEST_P(ArithmeticTest, MatMulAV) { }); } -TEST_P(ArithmeticTest, NotA) { +TEST_P(ArithmeticTest, NegateA) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); const size_t npc = std::get<2>(GetParam()); @@ -450,15 +451,15 @@ TEST_P(ArithmeticTest, NotA) { /* WHEN */ auto prev = obj->prot()->getState()->getStats(); - auto r_a = not_a(obj.get(), a0); + auto r_a = negate_a(obj.get(), a0); auto cost = obj->prot()->getState()->getStats() - prev; auto r_p = a2p(obj.get(), r_a); - auto r_pp = a2p(obj.get(), not_a(obj.get(), a0)); + auto r_pp = a2p(obj.get(), negate_a(obj.get(), a0)); /* THEN */ EXPECT_VALUE_EQ(r_p, r_pp); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("not_a"), "not_a", + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("negate_a"), "negate_a", conf.field(), kShape, npc, cost)); }); } @@ -482,10 +483,10 @@ TEST_P(ArithmeticTest, LShiftA) { } /* WHEN */ auto prev = obj->prot()->getState()->getStats(); - auto tmp = lshift_a(obj.get(), a0, bits); + auto tmp = lshift_a(obj.get(), a0, {static_cast(bits)}); auto cost = obj->prot()->getState()->getStats() - prev; auto r_b = a2p(obj.get(), tmp); - auto r_p = lshift_p(obj.get(), p0, bits); + auto r_p = lshift_p(obj.get(), p0, {static_cast(bits)}); /* THEN */ EXPECT_VALUE_EQ(r_b, r_p); @@ -513,10 +514,11 @@ TEST_P(ArithmeticTest, TruncA) { if (!kernel->hasMsbError()) { // trunc requires MSB to be zero. - p0 = arshift_p(obj.get(), p0, 1); + p0 = arshift_p(obj.get(), p0, {1}); } else { // has msb error, only use lowest 10 bits. - p0 = arshift_p(obj.get(), p0, SizeOf(conf.field()) * 8 - 10); + p0 = arshift_p(obj.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 10)}); } /* GIVEN */ @@ -529,7 +531,7 @@ TEST_P(ArithmeticTest, TruncA) { auto cost = obj->prot()->getState()->getStats() - prev; auto r_a = a2p(obj.get(), a1); - auto r_p = arshift_p(obj.get(), p0, bits); + auto r_p = arshift_p(obj.get(), p0, {static_cast(bits)}); /* THEN */ EXPECT_VALUE_ALMOST_EQ(r_a, r_p, npc); @@ -674,11 +676,11 @@ TEST_BOOLEAN_BINARY_OP(xor) } \ /* WHEN */ \ auto prev = obj->prot()->getState()->getStats(); \ - auto tmp = OP##_b(obj.get(), b0, bits); \ + auto tmp = OP##_b(obj.get(), b0, {static_cast(bits)}); \ auto cost = \ obj->prot()->getState()->getStats() - prev; \ auto r_b = b2p(obj.get(), tmp); \ - auto r_p = OP##_p(obj.get(), p0, bits); \ + auto r_p = OP##_p(obj.get(), p0, {static_cast(bits)}); \ \ /* THEN */ \ EXPECT_VALUE_EQ(r_b, r_p); \ @@ -837,7 +839,7 @@ TEST_P(ConversionTest, MSB) { // SECURENN has an msb input range here if (conf.protocol() == ProtocolKind::SECURENN) { - p0 = arshift_p(obj.get(), p0, 1); + p0 = arshift_p(obj.get(), p0, {1}); } auto a0 = p2a(obj.get(), p0); @@ -850,8 +852,10 @@ TEST_P(ConversionTest, MSB) { /* THEN */ EXPECT_TRUE(verifyCost(obj->prot()->getKernel("msb_a2b"), "msb_a2b", conf.field(), kShape, npc, cost)); - EXPECT_VALUE_EQ(rshift_p(obj.get(), p0, SizeOf(conf.field()) * 8 - 1), - b2p(obj.get(), b1)); + EXPECT_VALUE_EQ( + rshift_p(obj.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 1)}), + b2p(obj.get(), b1)); }); } diff --git a/libspu/mpc/aby3/arithmetic.cc b/libspu/mpc/aby3/arithmetic.cc index a23749ff..62a78436 100644 --- a/libspu/mpc/aby3/arithmetic.cc +++ b/libspu/mpc/aby3/arithmetic.cc @@ -41,7 +41,7 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, auto numel = a.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() -> std::vector { + return DISPATCH_ALL_FIELDS(field, [&]() -> std::vector { using ashr_el_t = ring2k_t; NdArrayRef m0(makeType(field), a.shape()); NdArrayRef m1(makeType(field), a.shape()); @@ -52,7 +52,7 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, NdArrayView _m1(m1); return DISPATCH_UINT_PT_TYPES( - b.eltype().as()->getBacktype(), "_", + b.eltype().as()->getBacktype(), [&]() -> std::vector { using bshr_t = std::array; if (self_rank == sender) { @@ -82,8 +82,6 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, return {c1, c2, m0, m1}; } else if (self_rank == (sender + 1) % 3) { - prg_state->genPrssPair(field, a.shape(), - PrgState::GenPrssCtrl::None); auto c1 = prg_state ->genPrssPair(field, a.shape(), PrgState::GenPrssCtrl::First) @@ -95,8 +93,6 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, ->genPrssPair(field, a.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, a.shape(), - PrgState::GenPrssCtrl::None); return {c2}; } @@ -109,7 +105,7 @@ std::vector ring_cast_boolean(const NdArrayRef& x) { const size_t numel = x.numel(); std::vector res(numel); - DISPATCH_UINT_PT_TYPES(x.eltype().as()->pt_type(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(x.eltype().as()->pt_type(), [&]() { NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { res[idx] = static_cast(_x[idx] & 0x1); @@ -127,7 +123,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { NdArrayRef out(makeType(field), shape); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; std::vector r0(shape.numel()); @@ -153,7 +149,7 @@ NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using pshr_el_t = ring2k_t; using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -184,7 +180,7 @@ NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto rank = comm->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; using pshr_el_t = ring2k_t; @@ -226,7 +222,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto* comm = ctx->getState(); const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using vshr_el_t = ring2k_t; using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -267,7 +263,7 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { size_t owner_rank = in_ty->owner(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -304,14 +300,11 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { }); } -NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - auto* comm = ctx->getState(); +NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto* in_ty = in.eltype().as(); const auto field = in_ty->field(); - auto rank = comm->getRank(); - - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = std::make_unsigned_t; using shr_t = std::array; @@ -319,16 +312,9 @@ NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayView _out(out); NdArrayView _in(in); - // neg(x) = not(x) + 1 - // not(x) = neg(x) - 1 pforeach(0, in.numel(), [&](int64_t idx) { _out[idx][0] = -_in[idx][0]; _out[idx][1] = -_in[idx][1]; - if (rank == 0) { - _out[idx][1] -= 1; - } else if (rank == 1) { - _out[idx][0] -= 1; - } }); return out; @@ -349,7 +335,7 @@ NdArrayRef AddAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto rank = comm->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -376,7 +362,7 @@ NdArrayRef AddAA::proc(KernelEvalContext*, const NdArrayRef& lhs, SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); const auto field = lhs_ty->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using shr_t = std::array; NdArrayRef out(makeType(field), lhs.shape()); @@ -403,7 +389,7 @@ NdArrayRef MulAP::proc(KernelEvalContext*, const NdArrayRef& lhs, SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); const auto field = lhs_ty->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -426,7 +412,7 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto* comm = ctx->getState(); auto* prg_state = ctx->getState(); - return DISPATCH_ALL_FIELDS(field, "aby3.mulAA", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -575,7 +561,7 @@ NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, .reshape(r2.first.shape()); } - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { // using = ring2k_t; NdArrayView r1_0(r1.first); NdArrayView r1_1(r1.second); @@ -661,13 +647,13 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, auto z1 = ring_sum({t0, t2.get(), r.get()}); auto f = std::async([&] { ring_assign(o1, z1); }); - ring_assign(o2, comm->rotate(z1, kBindName)); // comm => 1, k + ring_assign(o2, comm->rotate(z1, kBindName())); // comm => 1, k f.get(); #ifdef CUDA_ENABLED } else { matmul_aa_gpu(x, y, o1); ring_add_(o1, r.get()); - ring_assign(o2, comm->rotate(o1, kBindName)); // comm => 1, k + ring_assign(o2, comm->rotate(o1, kBindName())); // comm => 1, k } #endif @@ -675,11 +661,12 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto* in_ty = in.eltype().as(); const auto field = in_ty->field(); + bool is_splat = bits.size() == 1; - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using shr_t = std::array; NdArrayRef out(makeType(field), in.shape()); @@ -687,8 +674,9 @@ NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - _out[idx][0] = _in[idx][0] << bits; - _out[idx][1] = _in[idx][1] << bits; + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = _in[idx][0] << shift_bit; + _out[idx][1] = _in[idx][1] << shift_bit; }); return out; @@ -722,22 +710,23 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& in, comm->addCommStatsManually(1, kComm); // comm => 1, 2 // ret + const Sizes shift_bit = {static_cast(bits)}; switch (comm->getRank()) { case 0: { - const auto z1 = ring_arshift(x1, bits); - const auto z2 = comm->recv(1, x1.eltype(), kBindName); + const auto z1 = ring_arshift(x1, shift_bit); + const auto z2 = comm->recv(1, x1.eltype(), kBindName()); return makeAShare(z1, z2, field); } case 1: { auto r1 = r_future.get().second; - const auto z1 = ring_sub(ring_arshift(ring_add(x1, x2), bits), r1); - comm->sendAsync(0, z1, kBindName); + const auto z1 = ring_sub(ring_arshift(ring_add(x1, x2), shift_bit), r1); + comm->sendAsync(0, z1, kBindName()); return makeAShare(z1, r1, field); } case 2: { - const auto z2 = ring_arshift(x2, bits); + const auto z2 = ring_arshift(x2, shift_bit); return makeAShare(r_future.get().first, z2, field); } @@ -790,7 +779,7 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t P2 = (pivot + 2) % 3; NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "aby3.truncpr", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; diff --git a/libspu/mpc/aby3/arithmetic.h b/libspu/mpc/aby3/arithmetic.h index 75135fa7..e7ff3520 100644 --- a/libspu/mpc/aby3/arithmetic.h +++ b/libspu/mpc/aby3/arithmetic.h @@ -26,7 +26,7 @@ namespace spu::mpc::aby3 { class A2P : public UnaryKernel { public: - static constexpr char kBindName[] = "a2p"; + static constexpr const char* kBindName() { return "a2p"; } ce::CExpr latency() const override { // 1 * rotate: 1 @@ -43,7 +43,7 @@ class A2P : public UnaryKernel { class P2A : public UnaryKernel { public: - static constexpr char kBindName[] = "p2a"; + static constexpr const char* kBindName() { return "p2a"; } ce::CExpr latency() const override { #ifdef ENABLE_MASK_DURING_ABY3_P2A @@ -66,7 +66,7 @@ class P2A : public UnaryKernel { class A2V : public RevealToKernel { public: - static constexpr char kBindName[] = "a2v"; + static constexpr const char* kBindName() { return "a2v"; } // TODO: communication is unbalanced Kind kind() const override { return Kind::Dynamic; } @@ -87,7 +87,7 @@ class A2V : public RevealToKernel { class V2A : public UnaryKernel { public: - static constexpr char kBindName[] = "v2a"; + static constexpr const char* kBindName() { return "v2a"; } // TODO: communication is unbalanced Kind kind() const override { return Kind::Dynamic; } @@ -107,7 +107,7 @@ class V2A : public UnaryKernel { class RandA : public RandKernel { public: - static constexpr char kBindName[] = "rand_a"; + static constexpr const char* kBindName() { return "rand_a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -116,9 +116,9 @@ class RandA : public RandKernel { NdArrayRef proc(KernelEvalContext* ctx, const Shape& shape) const override; }; -class NotA : public UnaryKernel { +class NegateA : public UnaryKernel { public: - static constexpr char kBindName[] = "not_a"; + static constexpr const char* kBindName() { return "negate_a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -132,7 +132,7 @@ class NotA : public UnaryKernel { //////////////////////////////////////////////////////////////////// class AddAP : public BinaryKernel { public: - static constexpr char kBindName[] = "add_ap"; + static constexpr const char* kBindName() { return "add_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -144,7 +144,7 @@ class AddAP : public BinaryKernel { class AddAA : public BinaryKernel { public: - static constexpr char kBindName[] = "add_aa"; + static constexpr const char* kBindName() { return "add_aa"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -159,7 +159,7 @@ class AddAA : public BinaryKernel { //////////////////////////////////////////////////////////////////// class MulAP : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_ap"; + static constexpr const char* kBindName() { return "mul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -171,7 +171,7 @@ class MulAP : public BinaryKernel { class MulAA : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_aa"; + static constexpr const char* kBindName() { return "mul_aa"; } ce::CExpr latency() const override { // 1 * rotate: 1 @@ -189,7 +189,7 @@ class MulAA : public BinaryKernel { class MulA1B : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_a1b"; + static constexpr const char* kBindName() { return "mul_a1b"; } ce::CExpr latency() const override { return ce::Const(2); } @@ -204,7 +204,7 @@ class MulA1B : public BinaryKernel { //////////////////////////////////////////////////////////////////// class MatMulAP : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_ap"; + static constexpr const char* kBindName() { return "mmul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -216,7 +216,7 @@ class MatMulAP : public MatmulKernel { class MatMulAA : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_aa"; + static constexpr const char* kBindName() { return "mmul_aa"; } ce::CExpr latency() const override { // 1 * rotate: 1 @@ -236,14 +236,14 @@ class MatMulAA : public MatmulKernel { class LShiftA : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_a"; + static constexpr const char* kBindName() { return "lshift_a"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; // Refer to: @@ -252,7 +252,7 @@ class LShiftA : public ShiftKernel { // - https://eprint.iacr.org/2018/403.pdf class TruncA : public TruncAKernel { public: - static constexpr char kBindName[] = "trunc_a"; + static constexpr const char* kBindName() { return "trunc_a"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -274,7 +274,7 @@ class TruncA : public TruncAKernel { // - https://arxiv.org/pdf/1910.12435.pdf class TruncAPr : public TruncAKernel { public: - static constexpr char kBindName[] = "trunc_a"; + static constexpr const char* kBindName() { return "trunc_a"; } ce::CExpr latency() const override { return ce::Const(3); } diff --git a/libspu/mpc/aby3/boolean.cc b/libspu/mpc/aby3/boolean.cc index 8e07307a..96033cc5 100644 --- a/libspu/mpc/aby3/boolean.cc +++ b/libspu/mpc/aby3/boolean.cc @@ -42,11 +42,11 @@ void CommonTypeB::evaluate(KernelEvalContext* ctx) const { NdArrayRef CastTypeB::proc(KernelEvalContext*, const NdArrayRef& in, const Type& to_type) const { NdArrayRef out(to_type, in.shape()); - DISPATCH_UINT_PT_TYPES(in.eltype().as()->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in.eltype().as()->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; - DISPATCH_UINT_PT_TYPES(to_type.as()->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(to_type.as()->getBacktype(), [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -69,11 +69,11 @@ NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const PtType btype = in.eltype().as()->getBacktype(); const auto field = ctx->getState()->getDefaultField(); - return DISPATCH_UINT_PT_TYPES(btype, "aby3.b2p", [&]() { + return DISPATCH_UINT_PT_TYPES(btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using pshr_el_t = ring2k_t; NdArrayRef out(makeType(field), in.shape()); @@ -99,12 +99,12 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto* in_ty = in.eltype().as(); const auto field = in_ty->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { const size_t nbits = maxBitWidth(in); const PtType btype = calcBShareBacktype(nbits); NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; @@ -134,12 +134,12 @@ NdArrayRef B2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, const PtType btype = in.eltype().as()->getBacktype(); const auto field = ctx->getState()->getDefaultField(); - return DISPATCH_UINT_PT_TYPES(btype, "aby3.b2v", [&]() { + return DISPATCH_UINT_PT_TYPES(btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; NdArrayView _in(in); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using vshr_scalar_t = ring2k_t; auto out_ty = makeType(field, rank); @@ -177,7 +177,7 @@ NdArrayRef AndBP::proc(KernelEvalContext*, const NdArrayRef& lhs, const auto* lhs_ty = lhs.eltype().as(); const auto* rhs_ty = rhs.eltype().as(); - return DISPATCH_ALL_FIELDS(rhs_ty->field(), "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_ty->field(), [&]() { using rhs_scalar_t = ring2k_t; const size_t rhs_nbits = maxBitWidth(rhs); @@ -186,13 +186,13 @@ NdArrayRef AndBP::proc(KernelEvalContext*, const NdArrayRef& lhs, NdArrayView _rhs(rhs); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -224,17 +224,17 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const PtType out_btype = calcBShareBacktype(out_nbits); NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); - return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), [&]() { using rhs_el_t = ScalarT; using rhs_shr_t = std::array; NdArrayView _rhs(rhs); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -269,7 +269,7 @@ NdArrayRef XorBP::proc(KernelEvalContext*, const NdArrayRef& lhs, const auto* lhs_ty = lhs.eltype().as(); const auto* rhs_ty = rhs.eltype().as(); - return DISPATCH_ALL_FIELDS(rhs_ty->field(), "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_ty->field(), [&]() { using rhs_scalar_t = ring2k_t; const size_t rhs_nbits = maxBitWidth(rhs); @@ -280,13 +280,13 @@ NdArrayRef XorBP::proc(KernelEvalContext*, const NdArrayRef& lhs, NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -311,19 +311,19 @@ NdArrayRef XorBB::proc(KernelEvalContext*, const NdArrayRef& lhs, const size_t out_nbits = std::max(lhs_ty->nbits(), rhs_ty->nbits()); const PtType out_btype = calcBShareBacktype(out_nbits); - return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), [&]() { using rhs_el_t = ScalarT; using rhs_shr_t = std::array; NdArrayView _rhs(rhs); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -343,21 +343,24 @@ NdArrayRef XorBB::proc(KernelEvalContext*, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto* in_ty = in.eltype().as(); // TODO: the hal dtype should tell us about the max number of possible bits. const auto field = ctx->getState()->getDefaultField(); - const size_t out_nbits = std::min(in_ty->nbits() + bits, SizeOf(field) * 8); + const size_t out_nbits = std::min( + in_ty->nbits() + *std::max_element(bits.begin(), bits.end()), + SizeOf(field) * 8); const PtType out_btype = calcBShareBacktype(out_nbits); + bool is_splat = bits.size() == 1; - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -366,8 +369,9 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, pforeach(0, in.numel(), [&](int64_t idx) { const auto& v = _in[idx]; - _out[idx][0] = static_cast(v[0]) << bits; - _out[idx][1] = static_cast(v[1]) << bits; + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = static_cast(v[0]) << shift_bit; + _out[idx][1] = static_cast(v[1]) << shift_bit; }); return out; @@ -376,19 +380,19 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, } NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto* in_ty = in.eltype().as(); - bits = std::min(in_ty->nbits(), bits); - size_t out_nbits = in_ty->nbits(); - out_nbits -= std::min(out_nbits, bits); + int64_t out_nbits = in_ty->nbits(); + out_nbits -= std::min(out_nbits, *std::min_element(bits.begin(), bits.end())); const PtType out_btype = calcBShareBacktype(out_nbits); + bool is_splat = bits.size() == 1; - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_shr_t = std::array; NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -397,8 +401,9 @@ NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, pforeach(0, in.numel(), [&](int64_t idx) { const auto& v = _in[idx]; - _out[idx][0] = static_cast(v[0] >> bits); - _out[idx][1] = static_cast(v[1] >> bits); + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = static_cast(v[0] >> shift_bit); + _out[idx][1] = static_cast(v[1] >> shift_bit); }); return out; @@ -407,9 +412,10 @@ NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, } NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto field = ctx->getState()->getDefaultField(); const auto* in_ty = in.eltype().as(); + bool is_splat = bits.size() == 1; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -418,7 +424,7 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const PtType out_btype = in_ty->getBacktype(); const size_t out_nbits = in_ty->nbits(); - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using el_t = std::make_signed_t; using shr_t = std::array; @@ -428,8 +434,9 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, pforeach(0, in.numel(), [&](int64_t idx) { const auto& v = _in[idx]; - _out[idx][0] = v[0] >> bits; - _out[idx][1] = v[1] >> bits; + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = v[0] >> shift_bit; + _out[idx][1] = v[1] >> shift_bit; }); return out; @@ -444,13 +451,13 @@ NdArrayRef BitrevB::proc(KernelEvalContext*, const NdArrayRef& in, size_t start, const size_t out_nbits = std::max(in_ty->nbits(), end); const PtType out_btype = calcBShareBacktype(out_nbits); - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -488,7 +495,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, SPU_ENFORCE(absl::has_single_bit(nbits)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using el_t = ScalarT; using shr_t = std::array; NdArrayView _out(out); @@ -511,7 +518,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, SPU_ENFORCE(absl::has_single_bit(nbits)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using el_t = ScalarT; using shr_t = std::array; NdArrayView _out(out); diff --git a/libspu/mpc/aby3/boolean.h b/libspu/mpc/aby3/boolean.h index b3036620..dac53b10 100644 --- a/libspu/mpc/aby3/boolean.h +++ b/libspu/mpc/aby3/boolean.h @@ -22,7 +22,7 @@ namespace spu::mpc::aby3 { class CommonTypeB : public Kernel { public: - static constexpr char kBindName[] = "common_type_b"; + static constexpr const char* kBindName() { return "common_type_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -33,7 +33,7 @@ class CommonTypeB : public Kernel { class CastTypeB : public CastTypeKernel { public: - static constexpr char kBindName[] = "cast_type_b"; + static constexpr const char* kBindName() { return "cast_type_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -45,7 +45,7 @@ class CastTypeB : public CastTypeKernel { class B2P : public UnaryKernel { public: - static constexpr char kBindName[] = "b2p"; + static constexpr const char* kBindName() { return "b2p"; } ce::CExpr latency() const override { // rotate : 1 @@ -62,7 +62,7 @@ class B2P : public UnaryKernel { class P2B : public UnaryKernel { public: - static constexpr char kBindName[] = "p2b"; + static constexpr const char* kBindName() { return "p2b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -73,7 +73,7 @@ class P2B : public UnaryKernel { class B2V : public RevealToKernel { public: - static constexpr char kBindName[] = "b2v"; + static constexpr const char* kBindName() { return "b2v"; } ce::CExpr latency() const override { // 1 * send/recv: 1 @@ -91,7 +91,7 @@ class B2V : public RevealToKernel { class AndBP : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bp"; + static constexpr const char* kBindName() { return "and_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -103,7 +103,7 @@ class AndBP : public BinaryKernel { class AndBB : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bb"; + static constexpr const char* kBindName() { return "and_bb"; } ce::CExpr latency() const override { // rotate : 1 @@ -121,7 +121,7 @@ class AndBB : public BinaryKernel { class XorBP : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bp"; + static constexpr const char* kBindName() { return "xor_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -133,7 +133,7 @@ class XorBP : public BinaryKernel { class XorBB : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bb"; + static constexpr const char* kBindName() { return "xor_bb"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -145,43 +145,43 @@ class XorBB : public BinaryKernel { class LShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_b"; + static constexpr const char* kBindName() { return "lshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class RShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "rshift_b"; + static constexpr const char* kBindName() { return "rshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class ARShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "arshift_b"; + static constexpr const char* kBindName() { return "arshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class BitrevB : public BitrevKernel { public: - static constexpr char kBindName[] = "bitrev_b"; + static constexpr const char* kBindName() { return "bitrev_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -193,7 +193,7 @@ class BitrevB : public BitrevKernel { class BitIntlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitintl_b"; + static constexpr const char* kBindName() { return "bitintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -205,7 +205,7 @@ class BitIntlB : public BitSplitKernel { class BitDeintlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitdeintl_b"; + static constexpr const char* kBindName() { return "bitdeintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } diff --git a/libspu/mpc/aby3/conversion.cc b/libspu/mpc/aby3/conversion.cc index 070450c5..8154acac 100644 --- a/libspu/mpc/aby3/conversion.cc +++ b/libspu/mpc/aby3/conversion.cc @@ -65,11 +65,11 @@ NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_t = std::array; NdArrayView _in(in); - DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; @@ -152,7 +152,7 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (in_nbits == 0) { // special case, it's known to be zero. - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView> _out(out); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = 0; @@ -165,11 +165,11 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto* comm = ctx->getState(); auto* prg_state = ctx->getState(); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using bshr_t = std::array; NdArrayView _in(in); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -324,7 +324,7 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (in_nbits == 0) { // special case, it's known to be zero. - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView> _out(out); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = 0; @@ -345,12 +345,12 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { size_t P1 = (pivot + 1) % 3; size_t P2 = (pivot + 2) % 3; - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; NdArrayView _in(in); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -391,11 +391,6 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { }); } else if (comm->getRank() == P1) { // the receiver - prg_state->fillPrssPair(nullptr, nullptr, r0.size(), - PrgState::GenPrssCtrl::None); - prg_state->fillPrssPair(nullptr, nullptr, r0.size(), - PrgState::GenPrssCtrl::None); - auto b2 = bitDecompose(getShare(in, 0), in_nbits); // ot.recv @@ -494,12 +489,12 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef lo(out_type, in.shape()); NdArrayRef hi(out_type, in.shape()); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; NdArrayView _in(in); - DISPATCH_UINT_PT_TYPES(out_backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(out_backtype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -574,7 +569,7 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { makeType(GetStorageType(field), SizeOf(field) * 8); NdArrayRef m(bshr_type, in.shape()); NdArrayRef n(bshr_type, in.shape()); - DISPATCH_ALL_FIELDS(field, "aby3.msb.split", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -620,7 +615,9 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // Compute the k'th bit. // (m^n)[k] ^ carry - auto msb = xor_bb(sctx, rshift_b(sctx, xor_bb(sctx, wrap_m, wrap_n), nbits), + auto msb = xor_bb(sctx, + rshift_b(sctx, xor_bb(sctx, wrap_m, wrap_n), + {static_cast(nbits)}), carry); return UnwrapValue(msb); @@ -660,10 +657,10 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { size_t P1 = (pivot + 1) % 3; size_t P2 = (pivot + 2) % 3; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; - DISPATCH_UINT_PT_TYPES(in_bshr_btype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_bshr_btype, [&]() { using bshr_el_t = ScalarT; std::vector zero_flag_3pc_0(numel); std::vector zero_flag_3pc_1(numel); @@ -715,10 +712,6 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { PrgState::GenPrssCtrl::First); } else { pforeach(0, numel, [&](int64_t idx) { a_s[idx] = _in[idx][1]; }); - prg_state->fillPrssPair({}, {}, numel, - PrgState::GenPrssCtrl::None); - prg_state->fillPrssPair({}, {}, numel, - PrgState::GenPrssCtrl::None); r_arith = comm->recv(P0, "r_arith"); r_bool = comm->recv(P0, "r_bool"); } @@ -754,8 +747,6 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { // P1 zero_flag = (not(c_p xor [r]b0)^ rz, rb1) pforeach(0, numel, [&](int64_t idx) { zero_flag_3pc_1[idx] = r_bool[idx]; }); - prg_state->fillPrssPair({}, {}, numel, - PrgState::GenPrssCtrl::None); auto flag_split = comm->recv(P1, "flag_split"); pforeach(0, numel, [&](int64_t idx) { @@ -866,7 +857,7 @@ NdArrayRef EqualAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const auto field = lhs_ty->field(); NdArrayRef out(makeType(field), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using shr_t = std::array; NdArrayView _out(out); NdArrayView _lhs(lhs); @@ -893,7 +884,7 @@ NdArrayRef EqualAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto rank = comm->getRank(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; diff --git a/libspu/mpc/aby3/conversion.h b/libspu/mpc/aby3/conversion.h index c44a9574..b704b72d 100644 --- a/libspu/mpc/aby3/conversion.h +++ b/libspu/mpc/aby3/conversion.h @@ -27,7 +27,7 @@ namespace spu::mpc::aby3 { // Latency: 2 + log(nbits) from 2 rotate and 1 ppa. class A2B : public UnaryKernel { public: - static constexpr char kBindName[] = "a2b"; + static constexpr const char* kBindName() { return "a2b"; } ce::CExpr latency() const override { // 1 * AddBB : log(k) + 1 @@ -47,7 +47,7 @@ class A2B : public UnaryKernel { class B2ASelector : public UnaryKernel { public: - static constexpr char kBindName[] = "b2a"; + static constexpr const char* kBindName() { return "b2a"; } Kind kind() const override { return Kind::Dynamic; } @@ -59,7 +59,7 @@ class B2ASelector : public UnaryKernel { // https://encrypto.de/papers/DSZ15.pdf class B2AByPPA : public UnaryKernel { public: - static constexpr char kBindName[] = "b2a"; + static constexpr const char* kBindName() { return "b2a"; } ce::CExpr latency() const override { // 2 * rotate : 2 @@ -82,7 +82,7 @@ class B2AByPPA : public UnaryKernel { // https://eprint.iacr.org/2018/403.pdf class B2AByOT : public UnaryKernel { public: - static constexpr char kBindName[] = "b2a"; + static constexpr const char* kBindName() { return "b2a"; } ce::CExpr latency() const override { return ce::Const(2); } @@ -101,7 +101,7 @@ class B2AByOT : public UnaryKernel { class MsbA2B : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_a2b"; + static constexpr const char* kBindName() { return "msb_a2b"; } ce::CExpr latency() const override { // 1 * carry : log(k) + 1 @@ -120,7 +120,7 @@ class MsbA2B : public UnaryKernel { class EqualAA : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_aa"; + static constexpr const char* kBindName() { return "equal_aa"; } Kind kind() const override { return Kind::Dynamic; } @@ -130,7 +130,7 @@ class EqualAA : public BinaryKernel { class EqualAP : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_ap"; + static constexpr const char* kBindName() { return "equal_ap"; } Kind kind() const override { return Kind::Dynamic; } @@ -140,7 +140,7 @@ class EqualAP : public BinaryKernel { class CommonTypeV : public Kernel { public: - static constexpr char kBindName[] = "common_type_v"; + static constexpr const char* kBindName() { return "common_type_v"; } Kind kind() const override { return Kind::Dynamic; } diff --git a/libspu/mpc/aby3/io.cc b/libspu/mpc/aby3/io.cc index bc3be20c..3ded34a9 100644 --- a/libspu/mpc/aby3/io.cc +++ b/libspu/mpc/aby3/io.cc @@ -148,7 +148,7 @@ NdArrayRef Aby3Io::fromShares(const std::vector& shares) const { SPU_ENFORCE(field_ == eltype.as()->field()); NdArrayRef out(makeType(field_), shares[0].shape()); - DISPATCH_ALL_FIELDS(field_, "_", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { using el_t = ring2k_t; using shr_t = std::array; NdArrayView _out(out); @@ -166,10 +166,10 @@ NdArrayRef Aby3Io::fromShares(const std::vector& shares) const { } else if (eltype.isa()) { NdArrayRef out(makeType(field_), shares[0].shape()); - DISPATCH_ALL_FIELDS(field_, "_", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { NdArrayView _out(out); - DISPATCH_UINT_PT_TYPES(eltype.as()->getBacktype(), "_", [&] { + DISPATCH_UINT_PT_TYPES(eltype.as()->getBacktype(), [&] { using shr_t = std::array; for (size_t si = 0; si < shares.size(); si++) { NdArrayView _s(shares[si]); diff --git a/libspu/mpc/aby3/oram.cc b/libspu/mpc/aby3/oram.cc index 2ccf83f0..7eeb2089 100644 --- a/libspu/mpc/aby3/oram.cc +++ b/libspu/mpc/aby3/oram.cc @@ -37,7 +37,7 @@ NdArrayRef OramOneHotAA::proc(KernelEvalContext *ctx, const NdArrayRef &in, const auto field = eltype.as()->field(); NdArrayRef out(makeType(field), {s}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; NdArrayView out_(out); @@ -48,7 +48,7 @@ NdArrayRef OramOneHotAA::proc(KernelEvalContext *ctx, const NdArrayRef &in, // generate aeskey for dpf auto [self_aes_keys, next_aes_keys] = oram::genAesKey(ctx, 1); - auto *octx = new oram::OramContext(s); + auto octx = oram::OramContext(s); for (int64_t j = 0; j < 3; j++) { // in round (rank - 1), as helper @@ -64,16 +64,16 @@ NdArrayRef OramOneHotAA::proc(KernelEvalContext *ctx, const NdArrayRef &in, auto target_point = dpf_rank ? target_idxs_[0][0] ^ target_idxs_[0][1] : target_idxs_[0][0]; // dpf gen - octx->genDpf(ctx, static_cast(j), aes_key, - target_point); + octx.genDpf(ctx, static_cast(j), aes_key, + target_point); // B2A - octx->onehotB2A(ctx, static_cast(j)); + octx.onehotB2A(ctx, static_cast(j)); } } pforeach(0, s, [&](int64_t k) { for (int64_t j = 0; j < 2; j++) { - out_[k][j] = octx->dpf_e[j][k]; + out_[k][j] = octx.dpf_e[j][k]; } }); }); @@ -91,7 +91,7 @@ NdArrayRef OramOneHotAP::proc(KernelEvalContext *ctx, const NdArrayRef &in, const auto numel = in.numel(); NdArrayRef out(makeType(field), {s}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; NdArrayView out_(out); @@ -124,15 +124,16 @@ NdArrayRef OramOneHotAP::proc(KernelEvalContext *ctx, const NdArrayRef &in, comm->sendAsync(dst_rank, {aes_key}, "aes_key"); aes_key += comm->recv(dst_rank, "aes_key")[0]; - auto *octx = new oram::OramContext(s); + auto octx = oram::OramContext(s); + // dpf gen - octx->genDpf(ctx, static_cast(1), aes_key, - target_point_2pc_[0]); + octx.genDpf(ctx, static_cast(1), aes_key, + target_point_2pc_[0]); // B2A - octx->onehotB2A(ctx, static_cast(1)); + octx.onehotB2A(ctx, static_cast(1)); int64_t j = comm->getRank() == 0 ? 1 : 0; - pforeach(0, s, [&](int64_t k) { out_[k] = octx->dpf_e[j][k]; }); + pforeach(0, s, [&](int64_t k) { out_[k] = octx.dpf_e[j][k]; }); } }); @@ -150,7 +151,7 @@ NdArrayRef OramReadOA::proc(KernelEvalContext *ctx, const NdArrayRef &onehot, NdArrayRef out(makeType(field), {1, index_times}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -187,7 +188,7 @@ NdArrayRef OramReadOA::proc(KernelEvalContext *ctx, const NdArrayRef &onehot, auto f = std::async([&] { ring_assign(o1, z1); }); // reshare - ring_assign(o2, comm->rotate(z1, kBindName)); + ring_assign(o2, comm->rotate(z1, kBindName())); f.get(); }); @@ -208,7 +209,7 @@ NdArrayRef OramReadOP::proc(KernelEvalContext *ctx, const NdArrayRef &onehot, auto o2 = getSecondShare(out); int64_t db_numel = onehot.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -243,7 +244,7 @@ NdArrayRef OramReadOP::proc(KernelEvalContext *ctx, const NdArrayRef &onehot, ring_add_(out2pc, r.get()); auto f = std::async([&] { ring_assign(o1, out2pc); }); - ring_assign(o2, comm->rotate(out2pc, kBindName)); + ring_assign(o2, comm->rotate(out2pc, kBindName())); f.get(); }); @@ -286,15 +287,11 @@ Triple> genOramBeaverPrim(KernelEvalContext *ctx, int64_t num, std::vector beaver_triple(num * 3); if (comm->getRank() == adjust_rank) { - prg->fillPrssPair(nullptr, nullptr, num * 3, - PrgState::GenPrssCtrl::None); prg->fillPrssPair(nullptr, beaver_triple.data(), num * 3, PrgState::GenPrssCtrl::Second); } else { prg->fillPrssPair(beaver_triple.data(), nullptr, num * 3, PrgState::GenPrssCtrl::First); - prg->fillPrssPair(nullptr, nullptr, num * 3, - PrgState::GenPrssCtrl::None); } std::vector a(beaver_triple.begin(), beaver_triple.begin() + num); @@ -443,14 +440,14 @@ void OramContext::onehotB2A(KernelEvalContext *ctx, DpfGenCtrl ctrl) { const std::vector v = convert_help_v[dpf_idx]; std::for_each(e.begin(), e.end(), [&](T ele) { pm += ele; }); std::for_each(v.begin(), v.end(), [&](T ele) { F -= ele; }); - auto blinded_pm = pm + r[0]; + T blinded_pm = pm + r[0]; // open blinded_pm comm->sendAsync(dst_rank, {blinded_pm}, "open(blinded_pm)"); blinded_pm += comm->recv(dst_rank, "open(blinded_pm)")[0]; auto pm_mul_F = mul2pc(ctx, {pm}, {F}, static_cast(ctrl)); - auto blinded_F = pm_mul_F[0] + r[0]; + T blinded_F = pm_mul_F[0] + r[0]; // open blinded_F comm->sendAsync(dst_rank, {blinded_F}, "open(blinded_F)"); @@ -484,9 +481,9 @@ void OramContext::genDpf(KernelEvalContext *ctx, DpfGenCtrl ctrl, uint128_t aes_key, uint128_t target_point) { auto *comm = ctx->getState(); - auto *odpf = new OramDpf(dpf_size_, yacl::crypto::SecureRandU128(), aes_key, - static_cast(target_point)); - odpf->gen(ctx, ctrl); + auto odpf = OramDpf(dpf_size_, yacl::crypto::SecureRandU128(), aes_key, + static_cast(target_point)); + odpf.gen(ctx, ctrl); auto dpf_rank = comm->getRank() == static_cast(ctrl); int64_t dpf_idx = dpf_rank ? 0 : 1; @@ -494,10 +491,10 @@ void OramContext::genDpf(KernelEvalContext *ctx, DpfGenCtrl ctrl, // cast e and v to T type and convert v to arith // leave convert e outside - std::transform(odpf->final_e.begin(), odpf->final_e.begin() + dpf_size_, + std::transform(odpf.final_e.begin(), odpf.final_e.begin() + dpf_size_, dpf_e[dpf_idx].begin(), [&](uint8_t x) { return neg_flag * static_cast(x); }); - std::transform(odpf->final_v.begin(), odpf->final_v.begin() + dpf_size_, + std::transform(odpf.final_v.begin(), odpf.final_v.begin() + dpf_size_, convert_help_v[dpf_idx].begin(), [&](uint128_t x) { return neg_flag * static_cast(x); }); }; diff --git a/libspu/mpc/aby3/oram.h b/libspu/mpc/aby3/oram.h index ac50e7c9..3e958978 100644 --- a/libspu/mpc/aby3/oram.h +++ b/libspu/mpc/aby3/oram.h @@ -24,7 +24,7 @@ namespace spu::mpc::aby3 { // Ashared index, Ashared database class OramOneHotAA : public OramOneHotKernel { public: - static constexpr char kBindName[] = "oram_onehot_aa"; + static constexpr const char* kBindName() { return "oram_onehot_aa"; } Kind kind() const override { return Kind::Dynamic; } @@ -35,7 +35,7 @@ class OramOneHotAA : public OramOneHotKernel { // Ashared index, Public database class OramOneHotAP : public OramOneHotKernel { public: - static constexpr char kBindName[] = "oram_onehot_ap"; + static constexpr const char* kBindName() { return "oram_onehot_ap"; } Kind kind() const override { return Kind::Dynamic; } @@ -45,7 +45,7 @@ class OramOneHotAP : public OramOneHotKernel { class OramReadOA : public OramReadKernel { public: - static constexpr char kBindName[] = "oram_read_aa"; + static constexpr const char* kBindName() { return "oram_read_aa"; } ce::CExpr latency() const override { // 1 * rotate: 1 @@ -64,7 +64,7 @@ class OramReadOA : public OramReadKernel { class OramReadOP : public OramReadKernel { public: - static constexpr char kBindName[] = "oram_read_ap"; + static constexpr const char* kBindName() { return "oram_read_ap"; } ce::CExpr latency() const override { // 1 * rotate: 1 @@ -101,6 +101,8 @@ class OramDpf { std::vector final_e; OramDpf() = delete; + + // clang-format off explicit OramDpf(int64_t numel, DpfKeyT root_seed, uint128_t aes_key, uint128_t target_point) : cw(Log2Ceil(numel), 0), @@ -113,6 +115,7 @@ class OramDpf { root_seed_(root_seed), aes_crypto_(yacl::crypto::SymmetricCrypto::CryptoType::AES128_ECB, aes_key, 1) {}; + // clang-format on // genrate 2pc-dpf according to 'ctrl' void gen(KernelEvalContext* ctx, DpfGenCtrl ctrl); @@ -135,10 +138,13 @@ class OramContext { std::vector> convert_help_v; OramContext() = default; + + // clang-format off explicit OramContext(int64_t dpf_size) : dpf_e(2, std::vector(dpf_size)), convert_help_v(2, std::vector(dpf_size)), dpf_size_(dpf_size) {}; + // clang-format on void genDpf(KernelEvalContext* ctx, DpfGenCtrl ctrl, uint128_t aes_key, uint128_t target_point); diff --git a/libspu/mpc/aby3/ot.cc b/libspu/mpc/aby3/ot.cc index 122e0fa3..8f5f4863 100644 --- a/libspu/mpc/aby3/ot.cc +++ b/libspu/mpc/aby3/ot.cc @@ -71,8 +71,6 @@ std::pair Ot3::genMasks() { } } else { SPU_ENFORCE(comm_->getRank() == roles_.receiver); - prg_state_->genPrssPair(field_, shape_, PrgState::GenPrssCtrl::None); - prg_state_->genPrssPair(field_, shape_, PrgState::GenPrssCtrl::None); } return {w0, w1}; diff --git a/libspu/mpc/aby3/permute.cc b/libspu/mpc/aby3/permute.cc index edbeed7f..bd5f95c1 100644 --- a/libspu/mpc/aby3/permute.cc +++ b/libspu/mpc/aby3/permute.cc @@ -23,38 +23,18 @@ namespace spu::mpc::aby3 { -namespace { - -PermVector ring2pv(const NdArrayRef& x) { - SPU_ENFORCE(x.eltype().isa(), "must be ring2k_type, got={}", - x.eltype()); - const auto field = x.eltype().as()->field(); - PermVector pv(x.numel()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { - NdArrayView _x(x); - pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); - }); - return pv; -} - -} // namespace - NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const { NdArrayRef out(makeType(), shape); - // generate a RandU64 pair as permutation seeds auto* prg_state = ctx->getState(); - const auto [seed_self, seed_next] = - prg_state->genPrssPair(FieldType::FM64, {1}, PrgState::GenPrssCtrl::Both); - NdArrayView _seed_self(seed_self); - NdArrayView _seed_next(seed_next); - const auto pv_self = genRandomPerm(out.numel(), _seed_self[0]); - const auto pv_next = genRandomPerm(out.numel(), _seed_next[0]); + const auto& pvs = prg_state->genPrssPermPair(out.numel()); + const auto& pv_self = pvs.first; + const auto& pv_next = pvs.second; const auto field = out.eltype().as()->field(); auto out1 = getFirstShare(out); auto out2 = getSecondShare(out); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _out1(out1); NdArrayView _out2(out2); pforeach(0, out.numel(), [&](int64_t idx) { @@ -74,11 +54,11 @@ NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto field = in.eltype().as()->field(); auto* prg_state = ctx->getState(); - PermVector pv_self = ring2pv(getFirstShare(perm)); - PermVector pv_next = ring2pv(getSecondShare(perm)); + auto pv_self = getFirstShare(perm); + auto pv_next = getSecondShare(perm); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -90,77 +70,87 @@ NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, prg_state->fillPrssPair(a0.data(), a1.data(), a0.size(), PrgState::GenPrssCtrl::Both); - if (comm->getRank() == 0) { - std::vector tmp(numel); - std::vector delta(numel); - pforeach(0, numel, [&](int64_t idx) { - tmp[idx] = _in[pv_self[idx]][0] + _in[pv_self[idx]][1] - a0[idx]; - }); - pforeach(0, numel, - [&](int64_t idx) { delta[idx] = tmp[pv_next[idx]] - a1[idx]; }); - comm->sendAsync(2, delta, "delta"); - - // 2to3 re-share - std::vector r0(numel); - std::vector r1(numel); - prg_state->fillPrssPair(r0.data(), r1.data(), r1.size(), - PrgState::GenPrssCtrl::Both); - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = r0[idx]; - _out[idx][1] = r1[idx]; - }); - - } else if (comm->getRank() == 1) { - auto gama = comm->recv(2, "gama"); - std::vector tmp(numel); - std::vector beta(numel); - - pforeach(0, numel, - [&](int64_t idx) { tmp[idx] = gama[pv_self[idx]] + a0[idx]; }); - pforeach(0, numel, [&](int64_t idx) { beta[idx] = tmp[pv_next[idx]]; }); - - // 2to3 re-share - std::vector r0(numel); - prg_state->fillPrssPair(r0.data(), nullptr, r0.size(), - PrgState::GenPrssCtrl::First); - pforeach(0, numel, [&](int64_t idx) { beta[idx] -= r0[idx]; }); - - comm->sendAsync(2, beta, "2to3"); - tmp = comm->recv(2, "2to3"); - - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = r0[idx]; - _out[idx][1] = beta[idx] + tmp[idx]; - }); - - } else if (comm->getRank() == 2) { - std::vector gama(numel); - std::vector beta(numel); - pforeach(0, numel, [&](int64_t idx) { - gama[idx] = _in[pv_next[idx]][0] + a1[idx]; - }); - comm->sendAsync(1, gama, "gama"); - auto delta = comm->recv(0, "delta"); - pforeach(0, numel, [&](int64_t idx) { beta[idx] = delta[pv_self[idx]]; }); - - // 2to3 re-share - std::vector r1(numel); - prg_state->fillPrssPair(nullptr, r1.data(), r1.size(), - PrgState::GenPrssCtrl::Second); - pforeach(0, numel, [&](int64_t idx) { // - beta[idx] -= r1[idx]; - }); - comm->sendAsync(1, beta, "2to3"); - auto tmp = comm->recv(1, "2to3"); - - // rebuild the final result. - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = beta[idx] + tmp[idx]; - _out[idx][1] = r1[idx]; - }); - } else { - SPU_THROW("Party number exceeds 3!"); - } + const auto pv_field = pv_self.eltype().as()->field(); + DISPATCH_ALL_FIELDS(pv_field, [&]() { + using pv_t = ring2k_t; + NdArrayView _pv_self(pv_self); + NdArrayView _pv_next(pv_next); + if (comm->getRank() == 0) { + std::vector tmp(numel); + std::vector delta(numel); + pforeach(0, numel, [&](int64_t idx) { + tmp[idx] = _in[_pv_self[idx]][0] + _in[_pv_self[idx]][1] - a0[idx]; + }); + pforeach(0, numel, [&](int64_t idx) { + delta[idx] = tmp[_pv_next[idx]] - a1[idx]; + }); + comm->sendAsync(2, delta, "delta"); + + // 2to3 re-share + std::vector r0(numel); + std::vector r1(numel); + prg_state->fillPrssPair(r0.data(), r1.data(), r1.size(), + PrgState::GenPrssCtrl::Both); + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = r1[idx]; + }); + + } else if (comm->getRank() == 1) { + auto gama = comm->recv(2, "gama"); + std::vector tmp(numel); + std::vector beta(numel); + + pforeach(0, numel, [&](int64_t idx) { + tmp[idx] = gama[_pv_self[idx]] + a0[idx]; + }); + pforeach(0, numel, + [&](int64_t idx) { beta[idx] = tmp[_pv_next[idx]]; }); + + // 2to3 re-share + std::vector r0(numel); + prg_state->fillPrssPair(r0.data(), nullptr, r0.size(), + PrgState::GenPrssCtrl::First); + pforeach(0, numel, [&](int64_t idx) { beta[idx] -= r0[idx]; }); + + comm->sendAsync(2, beta, "2to3"); + tmp = comm->recv(2, "2to3"); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = beta[idx] + tmp[idx]; + }); + + } else if (comm->getRank() == 2) { + std::vector gama(numel); + std::vector beta(numel); + pforeach(0, numel, [&](int64_t idx) { + gama[idx] = _in[_pv_next[idx]][0] + a1[idx]; + }); + comm->sendAsync(1, gama, "gama"); + auto delta = comm->recv(0, "delta"); + pforeach(0, numel, + [&](int64_t idx) { beta[idx] = delta[_pv_self[idx]]; }); + + // 2to3 re-share + std::vector r1(numel); + prg_state->fillPrssPair(nullptr, r1.data(), r1.size(), + PrgState::GenPrssCtrl::Second); + pforeach(0, numel, [&](int64_t idx) { // + beta[idx] -= r1[idx]; + }); + comm->sendAsync(1, beta, "2to3"); + auto tmp = comm->recv(1, "2to3"); + + // rebuild the final result. + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = beta[idx] + tmp[idx]; + _out[idx][1] = r1[idx]; + }); + } else { + SPU_THROW("Party number exceeds 3!"); + } + }); }); return out; } @@ -170,11 +160,10 @@ NdArrayRef PermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); if (out.numel() != 0) { - PermVector pv = ring2pv(perm); - const auto& in1 = getFirstShare(in); - const auto& in2 = getSecondShare(in); - auto perm1 = applyPerm(in1, pv); - auto perm2 = applyPerm(in2, pv); + const auto in1 = getFirstShare(in); + const auto in2 = getSecondShare(in); + auto perm1 = applyPerm(in1, perm); + auto perm2 = applyPerm(in2, perm); auto out1 = getFirstShare(out); auto out2 = getSecondShare(out); @@ -194,11 +183,11 @@ NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto field = in.eltype().as()->field(); auto* prg_state = ctx->getState(); - PermVector pv_self = ring2pv(getFirstShare(perm)); - PermVector pv_next = ring2pv(getSecondShare(perm)); + auto pv_self = getFirstShare(perm); + auto pv_next = getSecondShare(perm); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -210,82 +199,91 @@ NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, prg_state->fillPrssPair(a0.data(), a1.data(), a0.size(), PrgState::GenPrssCtrl::Both); - if (comm->getRank() == 0) { - std::vector beta(numel); - std::vector tmp(numel); - auto gama = comm->recv(2, "gama"); - - pforeach(0, numel, [&](int64_t idx) { - tmp[pv_next[idx]] = gama[idx] + a1[pv_next[idx]]; - }); - pforeach(0, numel, [&](int64_t idx) { beta[pv_self[idx]] = tmp[idx]; }); - - // 2to3 re-share - std::vector r1(numel); - prg_state->fillPrssPair(nullptr, r1.data(), r1.size(), - PrgState::GenPrssCtrl::Second); - pforeach(0, numel, [&](int64_t idx) { // - beta[idx] -= r1[idx]; - }); - comm->sendAsync(2, beta, "2to3"); - tmp = comm->recv(2, "2to3"); - - // rebuild the final result. - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = beta[idx] + tmp[idx]; - _out[idx][1] = r1[idx]; - }); - } else if (comm->getRank() == 1) { - std::vector tmp(numel); - std::vector delta(numel); - - pforeach(0, numel, [&](int64_t idx) { - tmp[pv_next[idx]] = _in[idx][0] + _in[idx][1] - a1[pv_next[idx]]; - }); - pforeach(0, numel, [&](int64_t idx) { - delta[pv_self[idx]] = tmp[idx] - a0[pv_self[idx]]; - }); - comm->sendAsync(2, delta, "delta"); - - // 2to3 re-share - std::vector r0(numel); - std::vector r1(numel); - prg_state->fillPrssPair(r0.data(), r1.data(), r1.size(), - PrgState::GenPrssCtrl::Both); - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = r0[idx]; - _out[idx][1] = r1[idx]; - }); - - } else if (comm->getRank() == 2) { - std::vector gama(numel); - std::vector beta(numel); - pforeach(0, numel, [&](int64_t idx) { - gama[pv_self[idx]] = _in[idx][1] + a0[pv_self[idx]]; - }); - comm->sendAsync(0, gama, "gama"); - auto delta = comm->recv(1, "delta"); - pforeach(0, numel, [&](int64_t idx) { beta[pv_next[idx]] = delta[idx]; }); - - // 2to3 re-share - std::vector r0(numel); - prg_state->fillPrssPair(r0.data(), nullptr, r0.size(), - PrgState::GenPrssCtrl::First); - pforeach(0, numel, [&](int64_t idx) { // - beta[idx] -= r0[idx]; - }); - - comm->sendAsync(0, beta, "2to3"); - auto tmp = comm->recv(0, "2to3"); - - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = r0[idx]; - _out[idx][1] = beta[idx] + tmp[idx]; - }); - - } else { - SPU_THROW("Party number exceeds 3!"); - } + const auto pv_field = pv_self.eltype().as()->field(); + DISPATCH_ALL_FIELDS(pv_field, [&]() { + using pv_t = ring2k_t; + NdArrayView _pv_self(pv_self); + NdArrayView _pv_next(pv_next); + + if (comm->getRank() == 0) { + std::vector beta(numel); + std::vector tmp(numel); + auto gama = comm->recv(2, "gama"); + + pforeach(0, numel, [&](int64_t idx) { + tmp[_pv_next[idx]] = gama[idx] + a1[_pv_next[idx]]; + }); + pforeach(0, numel, + [&](int64_t idx) { beta[_pv_self[idx]] = tmp[idx]; }); + + // 2to3 re-share + std::vector r1(numel); + prg_state->fillPrssPair(nullptr, r1.data(), r1.size(), + PrgState::GenPrssCtrl::Second); + pforeach(0, numel, [&](int64_t idx) { // + beta[idx] -= r1[idx]; + }); + comm->sendAsync(2, beta, "2to3"); + tmp = comm->recv(2, "2to3"); + + // rebuild the final result. + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = beta[idx] + tmp[idx]; + _out[idx][1] = r1[idx]; + }); + } else if (comm->getRank() == 1) { + std::vector tmp(numel); + std::vector delta(numel); + + pforeach(0, numel, [&](int64_t idx) { + tmp[_pv_next[idx]] = _in[idx][0] + _in[idx][1] - a1[_pv_next[idx]]; + }); + pforeach(0, numel, [&](int64_t idx) { + delta[_pv_self[idx]] = tmp[idx] - a0[_pv_self[idx]]; + }); + comm->sendAsync(2, delta, "delta"); + + // 2to3 re-share + std::vector r0(numel); + std::vector r1(numel); + prg_state->fillPrssPair(r0.data(), r1.data(), r1.size(), + PrgState::GenPrssCtrl::Both); + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = r1[idx]; + }); + + } else if (comm->getRank() == 2) { + std::vector gama(numel); + std::vector beta(numel); + pforeach(0, numel, [&](int64_t idx) { + gama[_pv_self[idx]] = _in[idx][1] + a0[_pv_self[idx]]; + }); + comm->sendAsync(0, gama, "gama"); + auto delta = comm->recv(1, "delta"); + pforeach(0, numel, + [&](int64_t idx) { beta[_pv_next[idx]] = delta[idx]; }); + + // 2to3 re-share + std::vector r0(numel); + prg_state->fillPrssPair(r0.data(), nullptr, r0.size(), + PrgState::GenPrssCtrl::First); + pforeach(0, numel, [&](int64_t idx) { // + beta[idx] -= r0[idx]; + }); + + comm->sendAsync(0, beta, "2to3"); + auto tmp = comm->recv(0, "2to3"); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = beta[idx] + tmp[idx]; + }); + + } else { + SPU_THROW("Party number exceeds 3!"); + } + }); }); return out; } @@ -295,12 +293,11 @@ NdArrayRef InvPermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); if (out.numel() != 0) { - PermVector pv = ring2pv(perm); - const auto& in1 = getFirstShare(in); - const auto& in2 = getSecondShare(in); + const auto in1 = getFirstShare(in); + const auto in2 = getSecondShare(in); - auto perm1 = applyInvPerm(in1, pv); - auto perm2 = applyInvPerm(in2, pv); + auto perm1 = applyInvPerm(in1, perm); + auto perm2 = applyInvPerm(in2, perm); auto out1 = getFirstShare(out); auto out2 = getSecondShare(out); diff --git a/libspu/mpc/aby3/permute.h b/libspu/mpc/aby3/permute.h index 8a332da5..a197009b 100644 --- a/libspu/mpc/aby3/permute.h +++ b/libspu/mpc/aby3/permute.h @@ -20,7 +20,7 @@ namespace spu::mpc::aby3 { class RandPermM : public RandKernel { public: - static constexpr char kBindName[] = "rand_perm_m"; + static constexpr const char* kBindName() { return "rand_perm_m"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -31,7 +31,7 @@ class RandPermM : public RandKernel { class PermAM : public PermKernel { public: - static constexpr char kBindName[] = "perm_am"; + static constexpr const char* kBindName() { return "perm_am"; } Kind kind() const override { return Kind::Dynamic; } @@ -41,7 +41,7 @@ class PermAM : public PermKernel { class PermAP : public PermKernel { public: - static constexpr char kBindName[] = "perm_ap"; + static constexpr const char* kBindName() { return "perm_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -53,7 +53,7 @@ class PermAP : public PermKernel { class InvPermAM : public PermKernel { public: - static constexpr char kBindName[] = "inv_perm_am"; + static constexpr const char* kBindName() { return "inv_perm_am"; } Kind kind() const override { return Kind::Dynamic; } @@ -63,7 +63,7 @@ class InvPermAM : public PermKernel { class InvPermAP : public PermKernel { public: - static constexpr char kBindName[] = "inv_perm_ap"; + static constexpr const char* kBindName() { return "inv_perm_ap"; } ce::CExpr latency() const override { return ce::Const(0); } diff --git a/libspu/mpc/aby3/protocol.cc b/libspu/mpc/aby3/protocol.cc index 9c92af92..9ca2e019 100644 --- a/libspu/mpc/aby3/protocol.cc +++ b/libspu/mpc/aby3/protocol.cc @@ -54,7 +54,7 @@ void regAby3Protocol(SPUContext* ctx, aby3::B2P, aby3::P2B, aby3::A2B, // Conversion2 aby3::B2ASelector, /*aby3::B2AByOT, aby3::B2AByPPA*/ // B2A aby3::CastTypeB, // Cast - aby3::NotA, // Not + aby3::NegateA, // Negate aby3::AddAP, aby3::AddAA, // Add aby3::MulAP, aby3::MulAA, aby3::MulA1B, // Mul aby3::MatMulAP, aby3::MatMulAA, // MatMul diff --git a/libspu/mpc/aby3/value.h b/libspu/mpc/aby3/value.h index fbb413a0..8396b121 100644 --- a/libspu/mpc/aby3/value.h +++ b/libspu/mpc/aby3/value.h @@ -58,7 +58,7 @@ std::vector getShareAs(const NdArrayRef& in, size_t share_idx) { auto numel = in.numel(); std::vector res(numel); - DISPATCH_UINT_PT_TYPES(share.eltype().as()->pt_type(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(share.eltype().as()->pt_type(), [&]() { NdArrayView _share(share); for (auto idx = 0; idx < numel; ++idx) { res[idx] = _share[idx]; diff --git a/libspu/mpc/api.cc b/libspu/mpc/api.cc index 2008e8aa..3404f00f 100644 --- a/libspu/mpc/api.cc +++ b/libspu/mpc/api.cc @@ -14,8 +14,6 @@ #include "libspu/mpc/api.h" -#include - #include "libspu/core/trace.h" #include "libspu/mpc/ab_api.h" @@ -175,7 +173,11 @@ Value s2v(SPUContext* ctx, const Value& x, size_t owner) { return a2v(ctx, x, owner); } else { SPU_ENFORCE(IsB(x)); - return b2v(ctx, x, owner); + if (ctx->hasKernel("b2v")) { + return b2v(ctx, x, owner); + } else { + return a2v(ctx, _2a(ctx, x), owner); + } } } @@ -275,22 +277,42 @@ Value rand_s(SPUContext* ctx, const Shape& shape) { return rand_a(ctx, shape); } +// only works for Z2k. +// Neg(x) = Not(x) + 1 +// Not(x) = Neg(x) - 1 +Value not_v(SPUContext* ctx, const Value& x) { + SPU_TRACE_MPC_DISP(ctx, x); + auto k1 = make_p(ctx, 1, x.shape()); + return add_vp(ctx, negate_v(ctx, x), negate_p(ctx, k1)); +} + +Value not_p(SPUContext* ctx, const Value& x) { + SPU_TRACE_MPC_DISP(ctx, x); + auto k1 = make_p(ctx, 1, x.shape()); + return add_pp(ctx, negate_p(ctx, x), negate_p(ctx, k1)); +} + Value not_s(SPUContext* ctx, const Value& x) { + SPU_TRACE_MPC_DISP(ctx, x); + if (x.storage_type().isa()) { + auto ones = make_p(ctx, -1, x.shape()); + return xor_bp(ctx, x, ones); + } else { + SPU_ENFORCE(x.storage_type().isa()); + auto k1 = make_p(ctx, 1, x.shape()); + return add_sp(ctx, negate_s(ctx, x), negate_p(ctx, k1)); + } +} + +Value negate_s(SPUContext* ctx, const Value& x) { SPU_TRACE_MPC_DISP(ctx, x); TRY_DISPATCH(ctx, x); - // TODO: Both A&B could handle not(invert). - // if (x.eltype().isa()) { - // return not_b(ctx, x); - //} else { - // SPU_ENFORCE(x.eltype().isa()); - // return not_a(ctx, x); - //} - return not_a(ctx, _2a(ctx, x)); + return negate_a(ctx, _2a(ctx, x)); } -Value not_v(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); } +Value negate_v(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); } -Value not_p(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); } +Value negate_p(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); } ////////////////////////////////////////////////////////////////////////////// @@ -303,13 +325,14 @@ Value msb_s(SPUContext* ctx, const Value& x) { if (ctx->hasKernel("msb_a2b")) { if (IsB(x)) { - return rshift_b(ctx, x, SizeOf(field) * 8 - 1); + return rshift_b(ctx, x, {static_cast(SizeOf(field) * 8 - 1)}); } else { // fast path, directly apply msb x AShare, result a BShare. return msb_a2b(ctx, x); } } else { - return rshift_b(ctx, _2b(ctx, x), SizeOf(field) * 8 - 1); + return rshift_b(ctx, _2b(ctx, x), + {static_cast(SizeOf(field) * 8 - 1)}); } } @@ -601,7 +624,7 @@ Value xor_pp(SPUContext* ctx, const Value& x, const Value& y) { ////////////////////////////////////////////////////////////////////////////// -Value lshift_s(SPUContext* ctx, const Value& x, size_t bits) { +Value lshift_s(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_MPC_DISP(ctx, x, bits); TRY_DISPATCH(ctx, x, bits); if (IsA(x)) { @@ -613,43 +636,43 @@ Value lshift_s(SPUContext* ctx, const Value& x, size_t bits) { } } -Value lshift_v(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value lshift_p(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } ////////////////////////////////////////////////////////////////////////////// -Value rshift_s(SPUContext* ctx, const Value& x, size_t bits) { +Value rshift_s(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_MPC_DISP(ctx, x, bits); TRY_DISPATCH(ctx, x, bits); return rshift_b(ctx, _2b(ctx, x), bits); } -Value rshift_v(SPUContext* ctx, const Value& x, size_t nbits) { +Value rshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value rshift_p(SPUContext* ctx, const Value& x, size_t nbits) { +Value rshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } ////////////////////////////////////////////////////////////////////////////// -Value arshift_s(SPUContext* ctx, const Value& x, size_t bits) { +Value arshift_s(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_MPC_DISP(ctx, x, bits); TRY_DISPATCH(ctx, x, bits); return arshift_b(ctx, _2b(ctx, x), bits); } -Value arshift_v(SPUContext* ctx, const Value& x, size_t nbits) { +Value arshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value arshift_p(SPUContext* ctx, const Value& x, size_t nbits) { +Value arshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } @@ -662,11 +685,15 @@ Value trunc_s(SPUContext* ctx, const Value& x, size_t bits, SignType sign) { } Value trunc_v(SPUContext* ctx, const Value& x, size_t nbits, SignType sign) { - FORCE_DISPATCH(ctx, x, nbits, sign); + // FIXME: trunc_v use shift kernel + const Sizes trunc_bits = {static_cast(nbits)}; + FORCE_DISPATCH(ctx, x, trunc_bits, sign); } Value trunc_p(SPUContext* ctx, const Value& x, size_t nbits, SignType sign) { - FORCE_DISPATCH(ctx, x, nbits, sign); + // FIXME: trunc_p use shift kernel + const Sizes trunc_bits = {static_cast(nbits)}; + FORCE_DISPATCH(ctx, x, trunc_bits, sign); } ////////////////////////////////////////////////////////////////////////////// diff --git a/libspu/mpc/api.h b/libspu/mpc/api.h index 7d0b2333..50656f9a 100644 --- a/libspu/mpc/api.h +++ b/libspu/mpc/api.h @@ -91,11 +91,16 @@ Value make_p(SPUContext* ctx, uint128_t init, const Shape& shape); Value rand_p(SPUContext* ctx, const Shape& shape); Value rand_s(SPUContext* ctx, const Shape& shape); -// Compute bitwise_not(invert) of a value in ring 2k space. +// Compute bitwise not of a value. Value not_p(SPUContext* ctx, const Value& x); Value not_s(SPUContext* ctx, const Value& x); Value not_v(SPUContext* ctx, const Value& x); +// Compute negate of a value. +Value negate_p(SPUContext* ctx, const Value& x); +Value negate_s(SPUContext* ctx, const Value& x); +Value negate_v(SPUContext* ctx, const Value& x); + Value msb_p(SPUContext* ctx, const Value& x); Value msb_s(SPUContext* ctx, const Value& x); Value msb_v(SPUContext* ctx, const Value& x); @@ -143,17 +148,17 @@ Value xor_vv(SPUContext* ctx, const Value& x, const Value& y); Value xor_vp(SPUContext* ctx, const Value& x, const Value& y); Value xor_pp(SPUContext* ctx, const Value& x, const Value& y); -Value lshift_s(SPUContext* ctx, const Value& x, size_t nbits); -Value lshift_v(SPUContext* ctx, const Value& x, size_t nbits); -Value lshift_p(SPUContext* ctx, const Value& x, size_t nbits); +Value lshift_s(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value lshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value lshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits); -Value rshift_s(SPUContext* ctx, const Value& x, size_t nbits); -Value rshift_v(SPUContext* ctx, const Value& x, size_t nbits); -Value rshift_p(SPUContext* ctx, const Value& x, size_t nbits); +Value rshift_s(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value rshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value rshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits); -Value arshift_s(SPUContext* ctx, const Value& x, size_t nbits); -Value arshift_v(SPUContext* ctx, const Value& x, size_t nbits); -Value arshift_p(SPUContext* ctx, const Value& x, size_t nbits); +Value arshift_s(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value arshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value arshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits); Value trunc_s(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); Value trunc_v(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); diff --git a/libspu/mpc/api_test.cc b/libspu/mpc/api_test.cc index 33568d77..d890258f 100644 --- a/libspu/mpc/api_test.cc +++ b/libspu/mpc/api_test.cc @@ -245,6 +245,7 @@ TEST_BINARY_OP(xor) TEST_UNARY_OP_V(OP) \ TEST_UNARY_OP_P(OP) +TEST_UNARY_OP(negate) TEST_UNARY_OP(not ) TEST_UNARY_OP_V(msb) TEST_UNARY_OP_P(msb) @@ -261,7 +262,7 @@ TEST_P(ApiTest, MsbS) { // SECURENN has an msb input range requirement here if (conf.protocol() == ProtocolKind::SECURENN) { - p0 = arshift_p(sctx.get(), p0, 1); + p0 = arshift_p(sctx.get(), p0, {1}); } auto r_s = s2p(sctx.get(), msb_s(sctx.get(), p2s(sctx.get(), p0))); @@ -272,88 +273,92 @@ TEST_P(ApiTest, MsbS) { }); } -#define TEST_UNARY_OP_WITH_BIT_S(OP) \ - TEST_P(ApiTest, OP##S) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate( \ - npc, [&](const std::shared_ptr& lctx) { \ - auto sctx = factory(conf, lctx); \ - \ - /* GIVEN */ \ - auto x_p = rand_p(sctx.get(), kShape); \ - auto x_s = p2s(sctx.get(), x_p); \ - \ - for (auto bits : kShiftBits) { \ - if (bits >= SizeOf(conf.field()) * 8) { \ - continue; \ - } \ - /* WHEN */ \ - auto r_s = s2p(sctx.get(), OP##_s(sctx.get(), x_s, bits)); \ - auto r_p = OP##_p(sctx.get(), x_p, bits); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(r_s, r_p); \ - } \ - }); \ +#define TEST_UNARY_OP_WITH_BIT_S(OP) \ + TEST_P(ApiTest, OP##S) { \ + const auto factory = std::get<0>(GetParam()); \ + const RuntimeConfig& conf = std::get<1>(GetParam()); \ + const size_t npc = std::get<2>(GetParam()); \ + \ + utils::simulate( \ + npc, [&](const std::shared_ptr& lctx) { \ + auto sctx = factory(conf, lctx); \ + \ + /* GIVEN */ \ + auto x_p = rand_p(sctx.get(), kShape); \ + auto x_s = p2s(sctx.get(), x_p); \ + \ + for (auto bits : kShiftBits) { \ + if (bits >= SizeOf(conf.field()) * 8) { \ + continue; \ + } \ + /* WHEN */ \ + auto r_s = s2p(sctx.get(), OP##_s(sctx.get(), x_s, \ + {static_cast(bits)})); \ + auto r_p = OP##_p(sctx.get(), x_p, {static_cast(bits)}); \ + \ + /* THEN */ \ + EXPECT_VALUE_EQ(r_s, r_p); \ + } \ + }); \ } -#define TEST_UNARY_OP_WITH_BIT_V(OP) \ - TEST_P(ApiTest, OP##V) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate( \ - npc, [&](const std::shared_ptr& lctx) { \ - auto sctx = factory(conf, lctx); \ - \ - for (size_t rank = 0; rank < npc; rank++) { \ - /* GIVEN */ \ - auto x_p = rand_p(sctx.get(), kShape); \ - auto x_v = p2v(sctx.get(), x_p, rank); \ - \ - for (auto bits : kShiftBits) { \ - if (bits >= SizeOf(conf.field()) * 8) { \ - continue; \ - } \ - /* WHEN */ \ - auto r_v = v2p(sctx.get(), OP##_v(sctx.get(), x_v, bits)); \ - auto r_p = OP##_p(sctx.get(), x_p, bits); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(r_v, r_p); \ - } \ - } \ - }); \ +#define TEST_UNARY_OP_WITH_BIT_V(OP) \ + TEST_P(ApiTest, OP##V) { \ + const auto factory = std::get<0>(GetParam()); \ + const RuntimeConfig& conf = std::get<1>(GetParam()); \ + const size_t npc = std::get<2>(GetParam()); \ + \ + utils::simulate( \ + npc, [&](const std::shared_ptr& lctx) { \ + auto sctx = factory(conf, lctx); \ + \ + for (size_t rank = 0; rank < npc; rank++) { \ + /* GIVEN */ \ + auto x_p = rand_p(sctx.get(), kShape); \ + auto x_v = p2v(sctx.get(), x_p, rank); \ + \ + for (auto bits : kShiftBits) { \ + if (bits >= SizeOf(conf.field()) * 8) { \ + continue; \ + } \ + /* WHEN */ \ + auto r_v = \ + v2p(sctx.get(), \ + OP##_v(sctx.get(), x_v, {static_cast(bits)})); \ + auto r_p = \ + OP##_p(sctx.get(), x_p, {static_cast(bits)}); \ + \ + /* THEN */ \ + EXPECT_VALUE_EQ(r_v, r_p); \ + } \ + } \ + }); \ } -#define TEST_UNARY_OP_WITH_BIT_P(OP) \ - TEST_P(ApiTest, OP##P) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate(npc, \ - [&](const std::shared_ptr& lctx) { \ - auto sctx = factory(conf, lctx); \ - \ - /* GIVEN */ \ - auto p0 = rand_p(sctx.get(), kShape); \ - \ - for (auto bits : kShiftBits) { /* WHEN */ \ - if (bits >= SizeOf(conf.field()) * 8) { \ - continue; \ - } \ - auto r_p = OP##_p(sctx.get(), p0, bits); \ - auto r_pp = OP##_p(sctx.get(), p0, bits); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(r_p, r_pp); \ - } \ - }); \ +#define TEST_UNARY_OP_WITH_BIT_P(OP) \ + TEST_P(ApiTest, OP##P) { \ + const auto factory = std::get<0>(GetParam()); \ + const RuntimeConfig& conf = std::get<1>(GetParam()); \ + const size_t npc = std::get<2>(GetParam()); \ + \ + utils::simulate( \ + npc, [&](const std::shared_ptr& lctx) { \ + auto sctx = factory(conf, lctx); \ + \ + /* GIVEN */ \ + auto p0 = rand_p(sctx.get(), kShape); \ + \ + for (auto bits : kShiftBits) { /* WHEN */ \ + if (bits >= SizeOf(conf.field()) * 8) { \ + continue; \ + } \ + auto r_p = OP##_p(sctx.get(), p0, {static_cast(bits)}); \ + auto r_pp = OP##_p(sctx.get(), p0, {static_cast(bits)}); \ + \ + /* THEN */ \ + EXPECT_VALUE_EQ(r_p, r_pp); \ + } \ + }); \ } #define TEST_UNARY_OP_WITH_BIT(OP) \ @@ -379,12 +384,13 @@ TEST_P(ApiTest, TruncS) { : kShape); // TODO: here we assume has msb error, only use lowest 10 bits. - p0 = arshift_p(sctx.get(), p0, SizeOf(conf.field()) * 8 - 10); + p0 = arshift_p(sctx.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 10)}); const size_t bits = 2; auto r_s = s2p(sctx.get(), trunc_s(sctx.get(), p2s(sctx.get(), p0), bits, SignType::Unknown)); - auto r_p = arshift_p(sctx.get(), p0, bits); + auto r_p = arshift_p(sctx.get(), p0, {bits}); /* THEN */ EXPECT_VALUE_ALMOST_EQ(r_s, r_p, npc); diff --git a/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc b/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc index b05235a6..d71d2bb1 100644 --- a/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc @@ -84,7 +84,7 @@ TEST_P(CheetahConv2dTest, Basic) { flatten(ring_add(toNdArray(result[0]), toNdArray(result[1]))); const int64_t kMaxDiff = 1; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ArrayView c(computed); NdArrayView exp(expected); for (auto idx = 0; idx < expected.numel(); idx++) { diff --git a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc index 79697937..df93112d 100644 --- a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc @@ -68,7 +68,7 @@ TEST_P(CheetahDotTest, Basic) { EXPECT_EQ(expected.numel(), computed.numel()); const int64_t kMaxDiff = 1; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto e = NdArrayView(expected); auto c = NdArrayView(computed); @@ -119,7 +119,7 @@ TEST_P(CheetahDotTest, BatchDot) { [[maybe_unused]] constexpr int64_t kMaxDiff = 1; int64_t max_diff = 0; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto e = NdArrayView(expected); auto c = NdArrayView(computed); diff --git a/libspu/mpc/cheetah/arith/cheetah_mul.cc b/libspu/mpc/cheetah/arith/cheetah_mul.cc index 072b2978..779041a9 100644 --- a/libspu/mpc/cheetah/arith/cheetah_mul.cc +++ b/libspu/mpc/cheetah/arith/cheetah_mul.cc @@ -114,6 +114,10 @@ struct CheetahMul::Impl : public EnableCPRNG { NdArrayRef MulOLE(const NdArrayRef &shr, yacl::link::Context *conn, bool evaluator, uint32_t msg_width_hint); + NdArrayRef MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr, + yacl::link::Context *conn, bool evaluator, + uint32_t msg_width_hint); + protected: void LocalExpandSEALContexts(size_t target); @@ -167,6 +171,16 @@ struct CheetahMul::Impl : public EnableCPRNG { absl::Span rnd_mask, yacl::link::Context *conn = nullptr); + // Enc(x0) * y1 + Enc(y0) * x1 + rand_mask + void FMAThenResponse(FieldType field, int64_t num_elts, + const Options &options, + absl::Span ciphers_x0, + absl::Span ciphers_y0, + absl::Span plains_x1, + absl::Span plains_y1, + absl::Span rnd_mask, + yacl::link::Context *conn = nullptr); + void PrepareRandomMask(FieldType field, int64_t size, const Options &options, std::vector &mask); @@ -386,6 +400,79 @@ NdArrayRef CheetahMul::Impl::MulOLE(const NdArrayRef &shr, return DecryptArray(field, numel, options, recv_ct).reshape(shr.shape()); } +NdArrayRef CheetahMul::Impl::MulShare(const NdArrayRef &xshr, + const NdArrayRef &yshr, + yacl::link::Context *conn, bool evaluator, + uint32_t msg_width_hint) { + if (conn == nullptr) { + conn = lctx_.get(); + } + + auto eltype = xshr.eltype(); + SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype); + SPU_ENFORCE(yshr.eltype().isa(), "must be ring_type, got={}", + yshr.eltype()); + SPU_ENFORCE(xshr.numel() > 0); + SPU_ENFORCE_EQ(xshr.shape(), yshr.shape()); + + auto field = eltype.as()->field(); + Options options; + options.ring_bitlen = SizeOf(field) * 8; + options.msg_bitlen = + msg_width_hint == 0 ? options.ring_bitlen : msg_width_hint; + SPU_ENFORCE(options.msg_bitlen > 0 && + options.msg_bitlen <= options.ring_bitlen); + LazyExpandSEALContexts(options, conn); + LazyInitModSwitchHelper(options); + + size_t numel = xshr.numel(); + int nxt_rank = conn->NextRank(); + + // x0*y0 + + x1 * y1 + if (evaluator) { + std::vector encoded_x0; + std::vector encoded_y0; + EncodeArray(xshr, false, options, &encoded_x0); + EncodeArray(yshr, false, options, &encoded_y0); + + size_t payload_sze = encoded_x0.size(); + std::vector recv_ct_x1(payload_sze); + std::vector recv_ct_y1(payload_sze); + auto io_task = std::async(std::launch::async, [&]() { + for (size_t idx = 0; idx < payload_sze; ++idx) { + recv_ct_x1[idx] = conn->Recv(nxt_rank, ""); + } + for (size_t idx = 0; idx < payload_sze; ++idx) { + recv_ct_y1[idx] = conn->Recv(nxt_rank, ""); + } + }); + + std::vector random_share_mask; + PrepareRandomMask(field, xshr.numel(), options, random_share_mask); + + // wait for IO + io_task.get(); + FMAThenResponse(field, numel, options, recv_ct_x1, recv_ct_y1, encoded_x0, + encoded_y0, absl::MakeConstSpan(random_share_mask), conn); + // convert x \in [0, P) to [0, 2^k) by round(2^k*x/P) + auto &ms_helper = ms_helpers_.find(options)->second; + auto out = ms_helper.ModulusDownRNS(field, xshr.shape(), random_share_mask) + .reshape(xshr.shape()); + ring_add_(out, ring_mul(xshr, yshr)); + return out; + } + + size_t payload_sze = EncryptArrayThenSend(xshr, options, conn); + (void)EncryptArrayThenSend(yshr, options, conn); + std::vector recv_ct(payload_sze); + for (size_t idx = 0; idx < payload_sze; ++idx) { + recv_ct[idx] = conn->Recv(nxt_rank, ""); + } + auto out = DecryptArray(field, numel, options, recv_ct).reshape(xshr.shape()); + ring_add_(out, ring_mul(xshr, yshr)); + return out; +} + size_t CheetahMul::Impl::EncryptArrayThenSend(const NdArrayRef &array, const Options &options, yacl::link::Context *conn) { @@ -573,6 +660,72 @@ void CheetahMul::Impl::MulThenResponse(FieldType, int64_t num_elts, } } +void CheetahMul::Impl::FMAThenResponse( + FieldType, int64_t num_elts, const Options &options, + absl::Span ciphers_x0, + absl::Span ciphers_y0, + absl::Span plains_x1, absl::Span plains_y1, + absl::Span rnd_mask, yacl::link::Context *conn) { + SPU_ENFORCE(!ciphers_x0.empty(), "CheetahMul: empty cipher"); + SPU_ENFORCE(!ciphers_y0.empty(), "CheetahMul: empty cipher"); + SPU_ENFORCE_EQ(ciphers_x0.size(), ciphers_y0.size()); + SPU_ENFORCE_EQ(plains_x1.size(), ciphers_x0.size(), + "CheetahMul: ct/pt size mismatch"); + SPU_ENFORCE_EQ(plains_y1.size(), ciphers_y0.size(), + "CheetahMul: ct/pt size mismatch"); + + const int64_t num_splits = CeilDiv(num_elts, num_slots()); + const int64_t num_seal_ctx = WorkingContextSize(options); + const int64_t num_ciphers = num_seal_ctx * num_splits; + SPU_ENFORCE(ciphers_x0.size() == (size_t)num_ciphers, + "CheetahMul : expect {} != {}", num_ciphers, ciphers_x0.size()); + SPU_ENFORCE(rnd_mask.size() == (size_t)num_elts * num_seal_ctx, + "CheetahMul: rnd_mask size mismatch"); + + std::vector response(num_ciphers); + yacl::parallel_for(0, num_ciphers, [&](int64_t job_bgn, int64_t job_end) { + RLWECt ct_x; + RLWECt ct_y; + std::vector u64tmp(num_slots(), 0); + for (int64_t job_id = job_bgn; job_id < job_end; ++job_id) { + int64_t cntxt_id = job_id / num_splits; + int64_t split_id = job_id % num_splits; + + int64_t slice_bgn = split_id * num_slots(); + int64_t slice_n = std::min(num_slots(), num_elts - slice_bgn); + // offset by context id + slice_bgn += cntxt_id * num_elts; + + DecodeSEALObject(ciphers_x0[job_id], seal_cntxts_[cntxt_id], &ct_x); + DecodeSEALObject(ciphers_y0[job_id], seal_cntxts_[cntxt_id], &ct_y); + + // ct_x <- Re-randomize(ct_x * pt_y + ct_y * pt_x) - random_mask + simd_mul_instances_[cntxt_id]->FMAThenReshareInplace( + {&ct_x, 1}, {&ct_y, 1}, plains_y1.subspan(job_id, 1), + plains_x1.subspan(job_id, 1), rnd_mask.subspan(slice_bgn, slice_n), + *peer_pub_key_, seal_cntxts_[cntxt_id]); + + response[job_id] = EncodeSEALObject(ct_x); + } + }); + + if (conn == nullptr) { + conn = lctx_.get(); + } + + int nxt_rank = conn->NextRank(); + for (int64_t i = 0; i < num_ciphers; i += kCtAsyncParallel) { + int64_t this_batch = std::min(num_ciphers - i, kCtAsyncParallel); + conn->Send(nxt_rank, response[i], + fmt::format("FMAThenResponse ct[{}] to rank{}", i, nxt_rank)); + for (int64_t j = 1; j < this_batch; ++j) { + conn->SendAsync( + nxt_rank, response[i + j], + fmt::format("FMAThenResponse ct[{}] to rank{}", i + j, nxt_rank)); + } + } +} + NdArrayRef CheetahMul::Impl::DecryptArray( FieldType field, int64_t size, const Options &options, const std::vector &ct_array) { @@ -625,6 +778,20 @@ size_t CheetahMul::OLEBatchSize() const { return impl_->OLEBatchSize(); } +NdArrayRef CheetahMul::MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr, + yacl::link::Context *conn, bool is_evaluator, + uint32_t msg_width_hint) { + SPU_ENFORCE(impl_ != nullptr); + SPU_ENFORCE(conn != nullptr); + return impl_->MulShare(xshr, yshr, conn, is_evaluator, msg_width_hint); +} + +NdArrayRef CheetahMul::MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr, + bool is_evaluator, uint32_t msg_width_hint) { + SPU_ENFORCE(impl_ != nullptr); + return impl_->MulShare(xshr, yshr, nullptr, is_evaluator, msg_width_hint); +} + NdArrayRef CheetahMul::MulOLE(const NdArrayRef &inp, yacl::link::Context *conn, bool is_evaluator, uint32_t msg_width_hint) { SPU_ENFORCE(impl_ != nullptr); diff --git a/libspu/mpc/cheetah/arith/cheetah_mul.h b/libspu/mpc/cheetah/arith/cheetah_mul.h index e687477a..304a7871 100644 --- a/libspu/mpc/cheetah/arith/cheetah_mul.h +++ b/libspu/mpc/cheetah/arith/cheetah_mul.h @@ -44,14 +44,27 @@ class CheetahMul { void LazyInitKeys(FieldType field, uint32_t msg_width_hint = 0); + // x, y => [x*y] for two private inputs // NOTE: make sure to call InitKeys first NdArrayRef MulOLE(const NdArrayRef& inp, yacl::link::Context* conn, bool is_evaluator, uint32_t msg_width_hint = 0); + // x, y => [x*y] for two private inputs // NOTE: make sure to call InitKeys first NdArrayRef MulOLE(const NdArrayRef& inp, bool is_evaluator, uint32_t msg_width_hint = 0); + // [x], [y] => [x*y] for two shares + // NOTE: make sure to call InitKeys first + NdArrayRef MulShare(const NdArrayRef& x, const NdArrayRef& y, + yacl::link::Context* conn, bool is_evaluator, + uint32_t msg_width_hint = 0); + + // [x], [y] => [x*y] for two shares + // NOTE: make sure to call InitKeys first + NdArrayRef MulShare(const NdArrayRef& x, const NdArrayRef& y, + bool is_evaluator, uint32_t msg_width_hint = 0); + int Rank() const; size_t OLEBatchSize() const; diff --git a/libspu/mpc/cheetah/arith/cheetah_mul_test.cc b/libspu/mpc/cheetah/arith/cheetah_mul_test.cc index d22040c1..4cd9635e 100644 --- a/libspu/mpc/cheetah/arith/cheetah_mul_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_mul_test.cc @@ -207,4 +207,37 @@ TEST_P(CheetahMulTest, MixedRingSizeMul) { EXPECT_TRUE(ring_all_equal(expected2, computed2, kMaxDiff)); } +TEST_P(CheetahMulTest, MulShare) { + size_t kWorldSize = 2; + auto field = std::get<0>(GetParam()); + int64_t n = std::get<1>(GetParam()); + bool allow_approx = std::get<2>(GetParam()); + + auto a_bits = ring_rand(field, {n}); + auto b_bits = ring_rand(field, {n}); + + std::vector a_shr(kWorldSize); + std::vector b_shr(kWorldSize); + a_shr[0] = ring_rand(field, {n}); + b_shr[0] = ring_rand(field, {n}); + a_shr[1] = ring_sub(a_bits, a_shr[0]); + b_shr[1] = ring_sub(b_bits, b_shr[0]); + + std::vector result(kWorldSize); + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + int rank = lctx->Rank(); + // (a0 + a1) * (b0 + b1) + // a0*b0 + a0*b1 + a1*b0 + a1*b1 + auto mul = std::make_shared(lctx, allow_approx); + + result[rank] = mul->MulShare(a_shr[rank], b_shr[rank], rank == 0); + }); + + auto expected = ring_mul(a_bits, b_bits); + auto computed = ring_add(result[0], result[1]); + + const int64_t kMaxDiff = allow_approx ? 1 : 0; + EXPECT_TRUE(ring_all_equal(expected, computed, kMaxDiff)); +} + } // namespace spu::mpc::cheetah::test diff --git a/libspu/mpc/cheetah/arith/common.cc b/libspu/mpc/cheetah/arith/common.cc index 3deda13e..f87d92ab 100644 --- a/libspu/mpc/cheetah/arith/common.cc +++ b/libspu/mpc/cheetah/arith/common.cc @@ -102,7 +102,7 @@ NdArrayRef ring_conv2d(const NdArrayRef &tensor, const NdArrayRef &filter, NdArrayRef _filter = filter.reshape(fs); NdArrayRef _ret = ring_zeros(field, result_shape); - DISPATCH_ALL_FIELDS(field, "ring_conv2d", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { // NOTE(juhou): valid padding so offset are always 0. constexpr int64_t padh = 0; constexpr int64_t padw = 0; diff --git a/libspu/mpc/cheetah/arith/matmat_prot.cc b/libspu/mpc/cheetah/arith/matmat_prot.cc index ebb80f9f..919a13fa 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot.cc @@ -113,7 +113,7 @@ NdArrayRef ConcatSubMatrix(const NdArrayRef& mat, const Shape2D& mat_shape, // NOTE: zero padding via initialization NdArrayRef flatten = ring_zeros(field, {num_coeff}); - DISPATCH_ALL_FIELDS(field, "ConcatSubMat", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using uT = std::make_unsigned::type; for (int64_t r = 0, rr = starts[0]; r < extents[0]; ++r, ++rr) { @@ -423,7 +423,7 @@ NdArrayRef MatMatProtocol::ParseResult(FieldType field, const Meta& meta, for (int64_t r = 0; r < row_ext; ++r) { for (int64_t c = 0; c < col_ext; ++c) { int64_t dst_idx = (r + row_start) * meta.dims[2] + col_start + c; - DISPATCH_ALL_FIELDS(field, "ParseResult", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { matmat.at(dst_idx) = result_poly.at(r * subdims[2] + c); }); @@ -532,13 +532,13 @@ NdArrayRef MatMatProtocol::ParsePackLWEsResult( std::vector decoded_vectors(ans_poly.size()); for (size_t i = 0; i < ans_poly.size(); ++i) { decoded_vectors[i] = - msh.ModulusDownRNS(field, {(int64_t)packing_width}, + msh.ModulusDownRNS(field, {static_cast(packing_width)}, {ans_poly[i].data(), ans_poly[i].coeff_count()}); } NdArrayRef matmat = ring_zeros(field, {meta.dims[0] * meta.dims[2]}); - DISPATCH_ALL_FIELDS(field, "pack_lwes_results", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xmatmat(matmat); for (size_t i = 0; i < ans_poly.size(); ++i) { diff --git a/libspu/mpc/cheetah/arith/matmat_prot_test.cc b/libspu/mpc/cheetah/arith/matmat_prot_test.cc index 0b01f0d0..3a5f6521 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot_test.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot_test.cc @@ -153,7 +153,7 @@ TEST_P(MatMatProtTest, Plain) { EXPECT_EQ(expected.numel(), computed.numel()); - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto xe = NdArrayView(expected); auto xc = NdArrayView(computed); for (int64_t i = 0; i < xc.numel(); ++i) { @@ -215,7 +215,7 @@ TEST_P(MatMatProtTest, EncLHS) { EXPECT_EQ(expected.numel(), computed.numel()); - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto xe = NdArrayView(expected); auto xc = NdArrayView(computed); for (int64_t i = 0; i < xc.numel(); ++i) { @@ -276,7 +276,7 @@ TEST_P(MatMatProtTest, EncRHS) { EXPECT_EQ(expected.numel(), computed.numel()); - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto xe = NdArrayView(expected); auto xc = NdArrayView(computed); for (int64_t i = 0; i < xc.numel(); ++i) { diff --git a/libspu/mpc/cheetah/arith/simd_mul_prot.cc b/libspu/mpc/cheetah/arith/simd_mul_prot.cc index f347b87b..74d0b291 100644 --- a/libspu/mpc/cheetah/arith/simd_mul_prot.cc +++ b/libspu/mpc/cheetah/arith/simd_mul_prot.cc @@ -222,6 +222,66 @@ void SIMDMulProt::MulThenReshareInplace(absl::Span ct, } } +// Compute ct0 * pt1 + ct1 * pt1 - mask mod p +void SIMDMulProt::FMAThenReshareInplace(absl::Span ct0, + absl::Span ct1, + absl::Span pt0, + absl::Span pt1, + absl::Span share_mask, + const RLWEPublicKey &public_key, + const seal::SEALContext &context) { + SPU_ENFORCE_EQ(ct0.size(), ct1.size()); + SPU_ENFORCE_EQ(pt0.size(), pt1.size()); + SPU_ENFORCE_EQ(ct0.size(), pt0.size()); + SPU_ENFORCE_EQ(CeilDiv(share_mask.size(), (size_t)simd_lane_), ct0.size()); + + seal::Evaluator evaluator(context); + RLWECt zero_enc; + RLWEPt rnd; + + constexpr int kMarginBitsForDec = 10; + seal::parms_id_type final_level_id = context.last_parms_id(); + while (final_level_id != context.first_parms_id()) { + auto cntxt = context.get_context_data(final_level_id); + if (cntxt->total_coeff_modulus_bit_count() >= + kMarginBitsForDec + cntxt->parms().plain_modulus().bit_count()) { + break; + } + final_level_id = cntxt->prev_context_data()->parms_id(); + } + + RLWECt tmp_ct; + for (size_t i = 0; i < ct0.size(); ++i) { + // 1. Ct-Pt Mul + evaluator.multiply_plain_inplace(ct0[i], pt0[i]); + evaluator.multiply_plain(ct1[i], pt1[i], tmp_ct); + evaluator.add_inplace(ct0[i], tmp_ct); + + // 2. Noise flooding + NoiseFloodInplace(ct0[i], context); + + // 3. Drop some modulus for a smaller communication + evaluator.mod_switch_to_inplace(ct0[i], final_level_id); + + // 4. Re-randomize via adding enc(0) + seal::util::encrypt_zero_asymmetric(public_key, context, ct0[i].parms_id(), + ct0[i].is_ntt_form(), zero_enc); + evaluator.add_inplace(ct0[i], zero_enc); + + // 5. Additive share + size_t slice_bgn = i * simd_lane_; + size_t slice_n = + std::min((size_t)simd_lane_, share_mask.size() - slice_bgn); + EncodeSingle(share_mask.subspan(slice_bgn, slice_n), rnd); + evaluator.sub_plain_inplace(ct0[i], rnd); + + // 6. Truncate for smaller communication + if (ct0[i].coeff_modulus_size() == 1) { + TruncateBFVForDecryption(ct0[i], context); + } + } +} + void SIMDMulProt::NoiseFloodInplace(RLWECt &ct, const seal::SEALContext &context) { SPU_ENFORCE(seal::is_metadata_valid_for(ct, context)); diff --git a/libspu/mpc/cheetah/arith/simd_mul_prot.h b/libspu/mpc/cheetah/arith/simd_mul_prot.h index 22d79f02..c18d7da5 100644 --- a/libspu/mpc/cheetah/arith/simd_mul_prot.h +++ b/libspu/mpc/cheetah/arith/simd_mul_prot.h @@ -55,11 +55,19 @@ class SIMDMulProt : public EnableCPRNG { const RLWEPublicKey& public_key, const seal::SEALContext& context); - void MulThenReshareInplaceOneBit(absl::Span ct, - absl::Span pt, - absl::Span share_mask, - const RLWEPublicKey& public_key, - const seal::SEALContext& context); + // ct0 * pt0 + ct1 * pt1 + mask + void FMAThenReshareInplace(absl::Span ct0, + absl::Span ct1, + absl::Span pt0, + absl::Span pt1, + absl::Span share_mask, + const RLWEPublicKey& public_key, + const seal::SEALContext& context); + + [[deprecated]] void MulThenReshareInplaceOneBit( + absl::Span ct, absl::Span pt, + absl::Span share_mask, const RLWEPublicKey& public_key, + const seal::SEALContext& context); inline int64_t SIMDLane() const { return simd_lane_; } diff --git a/libspu/mpc/cheetah/arith/simd_mul_prot_test.cc b/libspu/mpc/cheetah/arith/simd_mul_prot_test.cc index 006baaee..3c15ff57 100644 --- a/libspu/mpc/cheetah/arith/simd_mul_prot_test.cc +++ b/libspu/mpc/cheetah/arith/simd_mul_prot_test.cc @@ -87,7 +87,7 @@ class SIMDMulTest : public ::testing::TestWithParam, public EnableCPRNG { }; INSTANTIATE_TEST_SUITE_P( - Cheetah, SIMDMulTest, testing::Values(true, false), + Cheetah, SIMDMulTest, testing::Values(true), [](const testing::TestParamInfo &p) { return fmt::format("{}", p.param ? "NoiseFlood" : "Approx"); }); @@ -116,20 +116,10 @@ TEST_P(SIMDMulTest, Basic) { simd_mul_prot_->SymEncrypt(encode_b, *rlwe_sk_, *context_, false, absl::MakeSpan(encrypt_b)); - if (GetParam()) { - RandomPlain(absl::MakeSpan(out_a)); - simd_mul_prot_->MulThenReshareInplace(absl::MakeSpan(encrypt_b), encode_a, - absl::MakeConstSpan(out_a), - *rlwe_pk_, *context_); - } else { - simd_mul_prot_->MulThenReshareInplaceOneBit( - absl::MakeSpan(encrypt_b), encode_a, absl::MakeSpan(out_a), *rlwe_pk_, - *context_); - } - if (rep == 0) { - printf("rep ct.L %zd\n", encrypt_b[0].coeff_modulus_size()); - } - + RandomPlain(absl::MakeSpan(out_a)); + simd_mul_prot_->MulThenReshareInplace(absl::MakeSpan(encrypt_b), encode_a, + absl::MakeConstSpan(out_a), *rlwe_pk_, + *context_); auto _out_b = absl::MakeSpan(out_b); for (size_t i = 0; i < num_pt; ++i) { seal::Plaintext pt; diff --git a/libspu/mpc/cheetah/arith/vector_encoder.cc b/libspu/mpc/cheetah/arith/vector_encoder.cc index 52b71488..3ee37a2b 100644 --- a/libspu/mpc/cheetah/arith/vector_encoder.cc +++ b/libspu/mpc/cheetah/arith/vector_encoder.cc @@ -16,7 +16,6 @@ #include "libspu/core/prelude.h" #include "libspu/core/type_util.h" -#include "libspu/core/xt_helper.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -84,7 +83,7 @@ void VectorEncoder::Backward(const NdArrayRef &vec, RLWEPt *out, const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "Backward", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto tmp_buff = ring_zeros(field, {(int64_t)poly_deg_}); auto xvec = NdArrayView(vec); auto xtmp = NdArrayView(tmp_buff); diff --git a/libspu/mpc/cheetah/arith/vector_encoder_test.cc b/libspu/mpc/cheetah/arith/vector_encoder_test.cc index 6056a5c6..0231eeb7 100644 --- a/libspu/mpc/cheetah/arith/vector_encoder_test.cc +++ b/libspu/mpc/cheetah/arith/vector_encoder_test.cc @@ -125,7 +125,7 @@ TEST_P(VectorEncoderTest, ForwardBackward) { auto computed = ms_helper_->ModulusDownRNS(field_, {1L}, absl::MakeSpan(cnst)); - DISPATCH_ALL_FIELDS(field_, "Check", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { NdArrayView got(computed); NdArrayView v0(vec0); NdArrayView v1(vec1); diff --git a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc index e90e80ae..fbf85043 100644 --- a/libspu/mpc/cheetah/arithmetic.cc +++ b/libspu/mpc/cheetah/arithmetic.cc @@ -74,7 +74,7 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const int rank = ctx->getState()->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const u2k mask = (static_cast(1) << shft) - 1; NdArrayRef adjusted = ring_zeros(field, x.shape()); @@ -245,7 +245,7 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, int64_t batch_sze = ctx->getState()->get()->OLEBatchSize(); int64_t numel = x.numel(); - if (numel >= batch_sze) { + if (numel >= 2 * batch_sze) { return mulDirectly(ctx, x, y); } return mulWithBeaver(ctx, x, y); @@ -309,7 +309,7 @@ NdArrayRef MulAA::mulWithBeaver(KernelEvalContext* ctx, const NdArrayRef& x, auto* comm = ctx->getState(); // Open x - a & y - b auto res = vmap({ring_sub(x, a), ring_sub(y, b)}, [&](const NdArrayRef& s) { - return comm->allReduce(ReduceOp::ADD, s, kBindName); + return comm->allReduce(ReduceOp::ADD, s, kBindName()); }); auto x_a = std::move(res[0]); auto y_b = std::move(res[1]); @@ -326,6 +326,46 @@ NdArrayRef MulAA::mulWithBeaver(KernelEvalContext* ctx, const NdArrayRef& x, return z.as(x.eltype()); } +#if 1 +NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + // Compute (x0 + x1) * (y0+ y1) + auto* comm = ctx->getState(); + auto* mul_prot = ctx->getState()->get(); + mul_prot->LazyInitKeys(x.eltype().as()->field()); + + auto fx = x.reshape({x.numel()}); + auto fy = y.reshape({y.numel()}); + const int64_t n = fx.numel(); + const int64_t nhalf = n / 2; + const int rank = comm->getRank(); + + // For long vectors, split into two subtasks. + auto dupx = ctx->getState()->duplx(); + std::future task = std::async(std::launch::async, [&] { + return mul_prot->MulShare(fx.slice({nhalf}, {n}, {1}), + fy.slice({nhalf}, {n}, {1}), dupx.get(), + /*evaluator*/ rank == 0); + }); + + std::vector out_slices(2); + out_slices[0] = + mul_prot->MulShare(fx.slice({0}, {nhalf}, {1}), + fy.slice({0}, {nhalf}, {1}), /*evaluato*/ rank != 0); + out_slices[1] = task.get(); + + NdArrayRef out(x.eltype(), x.shape()); + int64_t offset = 0; + for (auto& out_slice : out_slices) { + std::memcpy(out.data() + offset, out_slice.data(), + out_slice.numel() * out.elsize()); + offset += out_slice.numel() * out.elsize(); + } + return out; +} +#else +// Old code for MulAA using two OLEs which commnuicate about 30% more than the +// above version. NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const { // (x0 + x1) * (y0+ y1) @@ -335,7 +375,6 @@ NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, mul_prot->LazyInitKeys(x.eltype().as()->field()); const int rank = comm->getRank(); - // auto fy = y.reshape({y.numel()}); auto dupx = ctx->getState()->duplx(); std::future task = std::async(std::launch::async, [&] { @@ -355,6 +394,7 @@ NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, NdArrayRef x0y1 = task.get(); return ring_add(x0y1, ring_add(x1y0, ring_mul(x, y))).as(x.eltype()); } +#endif NdArrayRef MatMulVVS::proc(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const { diff --git a/libspu/mpc/cheetah/arithmetic.h b/libspu/mpc/cheetah/arithmetic.h index a0a1a5a6..ebbb0ffa 100644 --- a/libspu/mpc/cheetah/arithmetic.h +++ b/libspu/mpc/cheetah/arithmetic.h @@ -19,7 +19,7 @@ namespace spu::mpc::cheetah { class RandA : public RandKernel { public: - static constexpr char kBindName[] = "rand_a"; + static constexpr const char* kBindName() { return "rand_a"; } Kind kind() const override { return Kind::Dynamic; } @@ -28,7 +28,7 @@ class RandA : public RandKernel { class P2A : public UnaryKernel { public: - static constexpr char kBindName[] = "p2a"; + static constexpr const char* kBindName() { return "p2a"; } Kind kind() const override { return Kind::Dynamic; } @@ -37,7 +37,7 @@ class P2A : public UnaryKernel { class A2P : public UnaryKernel { public: - static constexpr char kBindName[] = "a2p"; + static constexpr const char* kBindName() { return "a2p"; } Kind kind() const override { return Kind::Dynamic; } @@ -46,7 +46,7 @@ class A2P : public UnaryKernel { class A2V : public RevealToKernel { public: - static constexpr char kBindName[] = "a2v"; + static constexpr const char* kBindName() { return "a2v"; } Kind kind() const override { return Kind::Dynamic; } @@ -56,16 +56,16 @@ class A2V : public RevealToKernel { class V2A : public UnaryKernel { public: - static constexpr char kBindName[] = "v2a"; + static constexpr const char* kBindName() { return "v2a"; } Kind kind() const override { return Kind::Dynamic; } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; -class NotA : public UnaryKernel { +class NegateA : public UnaryKernel { public: - static constexpr char kBindName[] = "not_a"; + static constexpr const char* kBindName() { return "negate_a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -76,7 +76,7 @@ class NotA : public UnaryKernel { class AddAP : public BinaryKernel { public: - static constexpr char kBindName[] = "add_ap"; + static constexpr const char* kBindName() { return "add_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -88,7 +88,7 @@ class AddAP : public BinaryKernel { class AddAA : public BinaryKernel { public: - static constexpr char kBindName[] = "add_aa"; + static constexpr const char* kBindName() { return "add_aa"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -100,7 +100,7 @@ class AddAA : public BinaryKernel { class MulAP : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_ap"; + static constexpr const char* kBindName() { return "mul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -113,7 +113,7 @@ class MulAP : public BinaryKernel { class MulA1B : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_a1b"; + static constexpr const char* kBindName() { return "mul_a1b"; } Kind kind() const override { return Kind::Dynamic; } @@ -123,7 +123,7 @@ class MulA1B : public BinaryKernel { class MulA1BV : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_a1bv"; + static constexpr const char* kBindName() { return "mul_a1bv"; } Kind kind() const override { return Kind::Dynamic; } @@ -142,7 +142,7 @@ class MulAA : public BinaryKernel { NdArrayRef squareDirectly(KernelEvalContext* ctx, const NdArrayRef& x) const; public: - static constexpr char kBindName[] = "mul_aa"; + static constexpr const char* kBindName() { return "mul_aa"; } Kind kind() const override { return Kind::Dynamic; } @@ -152,7 +152,7 @@ class MulAA : public BinaryKernel { class SquareA : public UnaryKernel { public: - static constexpr char kBindName[] = "square_a"; + static constexpr const char* kBindName() { return "square_a"; } Kind kind() const override { return Kind::Dynamic; } @@ -165,7 +165,7 @@ class MulAV : public BinaryKernel { const NdArrayRef& rhs) const; public: - static constexpr char kBindName[] = "mul_av"; + static constexpr const char* kBindName() { return "mul_av"; } Kind kind() const override { return Kind::Dynamic; } @@ -178,7 +178,7 @@ class MulAV : public BinaryKernel { //////////////////////////////////////////////////////////////////// class MatMulAP : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_ap"; + static constexpr const char* kBindName() { return "mmul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -190,7 +190,7 @@ class MatMulAP : public MatmulKernel { class MatMulAV : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_av"; + static constexpr const char* kBindName() { return "mmul_av"; } Kind kind() const override { return Kind::Dynamic; } // LHS: m x k @@ -201,7 +201,7 @@ class MatMulAV : public MatmulKernel { class MatMulVVS : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_vvs"; + static constexpr const char* kBindName() { return "mmul_vvs"; } Kind kind() const override { return Kind::Dynamic; } // LHS: m x k @@ -212,7 +212,7 @@ class MatMulVVS : public MatmulKernel { class MatMulAA : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_aa"; + static constexpr const char* kBindName() { return "mmul_aa"; } Kind kind() const override { return Kind::Dynamic; } // LHS: m x k @@ -223,7 +223,7 @@ class MatMulAA : public MatmulKernel { class Conv2DAA : public Conv2DKernel { public: - static constexpr char kBindName[] = "conv2d_aa"; + static constexpr const char* kBindName() { return "conv2d_aa"; } Kind kind() const override { return Kind::Dynamic; } @@ -234,7 +234,7 @@ class Conv2DAA : public Conv2DKernel { class TruncA : public TruncAKernel { public: - static constexpr char kBindName[] = "trunc_a"; + static constexpr const char* kBindName() { return "trunc_a"; } Kind kind() const override { return Kind::Dynamic; } @@ -250,19 +250,19 @@ class TruncA : public TruncAKernel { class LShiftA : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_a"; + static constexpr const char* kBindName() { return "lshift_a"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class MsbA2B : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_a2b"; + static constexpr const char* kBindName() { return "msb_a2b"; } MsbA2B(size_t nbits = 0) : nbits_(nbits) {} @@ -276,7 +276,7 @@ class MsbA2B : public UnaryKernel { class EqualAA : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_aa"; + static constexpr const char* kBindName() { return "equal_aa"; } EqualAA(size_t nbits = 0) : nbits_(nbits) {} @@ -291,7 +291,7 @@ class EqualAA : public BinaryKernel { class EqualAP : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_ap"; + static constexpr const char* kBindName() { return "equal_ap"; } Kind kind() const override { return Kind::Dynamic; } @@ -301,7 +301,7 @@ class EqualAP : public BinaryKernel { class LessAP : public BinaryKernel { public: - static constexpr char kBindName[] = "f_less_ap"; + static constexpr const char* kBindName() { return "f_less_ap"; } Kind kind() const override { return Kind::Dynamic; } @@ -311,7 +311,7 @@ class LessAP : public BinaryKernel { class LessPA : public BinaryKernel { public: - static constexpr char kBindName[] = "f_less_pa"; + static constexpr const char* kBindName() { return "f_less_pa"; } Kind kind() const override { return Kind::Dynamic; } diff --git a/libspu/mpc/cheetah/arithmetic_semi2k.cc b/libspu/mpc/cheetah/arithmetic_semi2k.cc index 89ef7daa..85dadfc3 100644 --- a/libspu/mpc/cheetah/arithmetic_semi2k.cc +++ b/libspu/mpc/cheetah/arithmetic_semi2k.cc @@ -24,7 +24,7 @@ namespace spu::mpc::cheetah { NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { auto* prg_state = ctx->getState(); const auto field = ctx->getState()->getDefaultField(); - return ring_rshift(prg_state->genPriv(field, shape), 2) + return ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); } @@ -47,7 +47,7 @@ NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); auto* comm = ctx->getState(); - auto out = comm->allReduce(ReduceOp::ADD, in, kBindName); + auto out = comm->allReduce(ReduceOp::ADD, in, kBindName()); return out.as(makeType(field)); } @@ -59,7 +59,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { std::vector share(numel); NdArrayView _in(in); pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); @@ -101,13 +101,8 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { return x.as(makeType(field)); } -NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - auto* comm = ctx->getState(); +NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto res = ring_neg(in); - if (comm->getRank() == 0) { - const auto field = in.eltype().as()->field(); - ring_add_(res, ring_not(ring_zeros(field, in.shape()))); - } return res.as(in.eltype()); } @@ -143,10 +138,7 @@ NdArrayRef MatMulAP::proc(KernelEvalContext*, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { - const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; - + const Sizes& bits) const { return ring_lshift(in, bits).as(in.eltype()); } diff --git a/libspu/mpc/cheetah/boolean.h b/libspu/mpc/cheetah/boolean.h index cdca12db..c3799708 100644 --- a/libspu/mpc/cheetah/boolean.h +++ b/libspu/mpc/cheetah/boolean.h @@ -20,7 +20,7 @@ namespace spu::mpc::cheetah { class CommonTypeB : public Kernel { public: - static constexpr char kBindName[] = "common_type_b"; + static constexpr const char* kBindName() { return "common_type_b"; } Kind kind() const override { return Kind::Dynamic; } @@ -29,7 +29,7 @@ class CommonTypeB : public Kernel { class CastTypeB : public CastTypeKernel { public: - static constexpr char kBindName[] = "cast_type_b"; + static constexpr const char* kBindName() { return "cast_type_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -41,7 +41,7 @@ class CastTypeB : public CastTypeKernel { class B2P : public UnaryKernel { public: - static constexpr char kBindName[] = "b2p"; + static constexpr const char* kBindName() { return "b2p"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -52,7 +52,7 @@ class B2P : public UnaryKernel { class P2B : public UnaryKernel { public: - static constexpr char kBindName[] = "p2b"; + static constexpr const char* kBindName() { return "p2b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -63,7 +63,7 @@ class P2B : public UnaryKernel { class AndBP : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bp"; + static constexpr const char* kBindName() { return "and_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -75,7 +75,7 @@ class AndBP : public BinaryKernel { class AndBB : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bb"; + static constexpr const char* kBindName() { return "and_bb"; } Kind kind() const override { return Kind::Dynamic; } @@ -85,7 +85,7 @@ class AndBB : public BinaryKernel { class XorBP : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bp"; + static constexpr const char* kBindName() { return "xor_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -97,7 +97,7 @@ class XorBP : public BinaryKernel { class XorBB : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bb"; + static constexpr const char* kBindName() { return "xor_bb"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -109,43 +109,43 @@ class XorBB : public BinaryKernel { class LShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_b"; + static constexpr const char* kBindName() { return "lshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class RShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "rshift_b"; + static constexpr const char* kBindName() { return "rshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class ARShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "arshift_b"; + static constexpr const char* kBindName() { return "arshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class BitrevB : public BitrevKernel { public: - static constexpr char kBindName[] = "bitrev_b"; + static constexpr const char* kBindName() { return "bitrev_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -157,7 +157,7 @@ class BitrevB : public BitrevKernel { class BitIntlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitintl_b"; + static constexpr const char* kBindName() { return "bitintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -169,7 +169,7 @@ class BitIntlB : public BitSplitKernel { class BitDeintlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitdeintl_b"; + static constexpr const char* kBindName() { return "bitdeintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } diff --git a/libspu/mpc/cheetah/boolean_semi2k.cc b/libspu/mpc/cheetah/boolean_semi2k.cc index 786d0c38..c432e38a 100644 --- a/libspu/mpc/cheetah/boolean_semi2k.cc +++ b/libspu/mpc/cheetah/boolean_semi2k.cc @@ -25,7 +25,7 @@ namespace { size_t getNumBits(const NdArrayRef& in) { if (in.eltype().isa()) { const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", + return DISPATCH_ALL_FIELDS(field, [&]() { return maxBitWidth(in); }); } else if (in.eltype().isa()) { return in.eltype().as()->nbits(); @@ -64,7 +64,7 @@ NdArrayRef CastTypeB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); auto* comm = ctx->getState(); - auto out = comm->allReduce(ReduceOp::XOR, in, kBindName); + auto out = comm->allReduce(ReduceOp::XOR, in, kBindName()); return out.as(makeType(field)); } @@ -93,7 +93,7 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); NdArrayView _out(out); @@ -130,32 +130,30 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - - size_t out_nbits = in.eltype().as()->nbits() + shift; + size_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); return makeBShare(ring_lshift(in, shift), field, out_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t nbits = in.eltype().as()->nbits(); - size_t out_nbits = nbits - std::min(nbits, shift); - SPU_ENFORCE(nbits <= SizeOf(field) * 8); + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); return makeBShare(ring_rshift(in, shift), field, out_nbits); } NdArrayRef ARShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -183,7 +181,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); @@ -204,7 +202,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); diff --git a/libspu/mpc/cheetah/conversion.h b/libspu/mpc/cheetah/conversion.h index 1272824a..ad468801 100644 --- a/libspu/mpc/cheetah/conversion.h +++ b/libspu/mpc/cheetah/conversion.h @@ -20,7 +20,7 @@ namespace spu::mpc::cheetah { class A2B : public UnaryKernel { public: - static constexpr char kBindName[] = "a2b"; + static constexpr const char* kBindName() { return "a2b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -31,7 +31,7 @@ class A2B : public UnaryKernel { class B2A : public UnaryKernel { public: - static constexpr char kBindName[] = "b2a"; + static constexpr const char* kBindName() { return "b2a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -42,7 +42,7 @@ class B2A : public UnaryKernel { class CommonTypeV : public Kernel { public: - static constexpr char kBindName[] = "common_type_v"; + static constexpr const char* kBindName() { return "common_type_v"; } Kind kind() const override { return Kind::Dynamic; } diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot.cc b/libspu/mpc/cheetah/nonlinear/compare_prot.cc index 7ee00a9f..94d8b2a1 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot.cc @@ -16,13 +16,11 @@ #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" -#include "yacl/link/link.h" #include "libspu/core/type.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -58,14 +56,14 @@ void SetLeafOTMsg(absl::Span ot_messages, uint8_t digit, NdArrayRef CompareProtocol::DoCompute(const NdArrayRef& inp, bool greater_than, NdArrayRef* keep_eq, int64_t bitwidth) { auto field = inp.eltype().as()->field(); - int64_t num_digits = CeilDiv(bitwidth, (int64_t)compare_radix_); + int64_t num_digits = CeilDiv(bitwidth, static_cast(compare_radix_)); size_t radix = static_cast(1) << compare_radix_; // one-of-N OT int64_t num_cmp = inp.numel(); // init to all zero std::vector digits(num_cmp * num_digits, 0); // Step 1 break into digits \in [0, radix) - DISPATCH_ALL_FIELDS(field, "break_digits", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const auto mask_radix = makeBitsMask(compare_radix_); NdArrayView xinp(inp); @@ -142,7 +140,7 @@ NdArrayRef CompareProtocol::DoCompute(const NdArrayRef& inp, bool greater_than, ring_zeros(field, {static_cast(num_digits * num_cmp)}) .as(boolean_t); - DISPATCH_ALL_FIELDS(field, "copy_leaf", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xprev_cmp(prev_cmp); NdArrayView xprev_eq(prev_eq); pforeach(0, prev_cmp.numel(), [&](int64_t i) { diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc index a04da8b4..5abf60dc 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc @@ -14,12 +14,9 @@ #include "libspu/mpc/cheetah/nonlinear/compare_prot.h" -#include - #include "gtest/gtest.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -51,7 +48,7 @@ TEST_P(CompareProtTest, Compare) { inp[0] = ring_rand(field, shape); inp[1] = ring_rand(field, shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xinp = NdArrayView(inp[0]); xinp[0] = 1; xinp[1] = 10; @@ -60,7 +57,11 @@ TEST_P(CompareProtTest, Compare) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } }); NdArrayRef cmp_oup[2]; @@ -74,7 +75,7 @@ TEST_P(CompareProtTest, Compare) { cmp_oup[rank] = _c; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xinp0 = NdArrayView(inp[0]); @@ -100,7 +101,7 @@ TEST_P(CompareProtTest, CompareBitWidth) { inp[0] = ring_rand(field, {n, 2}); inp[1] = ring_rand(field, {n, 2}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t mask = (static_cast(1) << bw) - 1; auto xinp = NdArrayView(inp[0]); xinp[0] = 1; @@ -111,7 +112,11 @@ TEST_P(CompareProtTest, CompareBitWidth) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); }); @@ -142,7 +147,7 @@ TEST_P(CompareProtTest, CompareBitWidth) { cmp_oup[rank] = _c; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xinp0 = NdArrayView(inp[0]); @@ -172,7 +177,7 @@ TEST_P(CompareProtTest, WithEq) { inp[1] = _inp[0].slice({0, 0, 0}, {10, 10, 10}, {2, 3, 2}); shape = inp[0].shape(); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xinp = NdArrayView(inp[0]); xinp[0] = 1; xinp[1] = 10; @@ -181,7 +186,11 @@ TEST_P(CompareProtTest, WithEq) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } }); NdArrayRef cmp_oup[2]; @@ -197,7 +206,7 @@ TEST_P(CompareProtTest, WithEq) { eq_oup[rank] = _e; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xeq0 = NdArrayView(eq_oup[0]); @@ -229,7 +238,7 @@ TEST_P(CompareProtTest, WithEqBitWidth) { inp[0] = ring_rand(field, {n, 2}); inp[1] = ring_rand(field, {n, 2}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t mask = (static_cast(1) << bw) - 1; auto xinp = NdArrayView(inp[0]); xinp[0] = 1; @@ -240,7 +249,11 @@ TEST_P(CompareProtTest, WithEqBitWidth) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); }); @@ -271,7 +284,7 @@ TEST_P(CompareProtTest, WithEqBitWidth) { eq_oup[rank] = _e; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xeq0 = NdArrayView(eq_oup[0]); diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot.cc b/libspu/mpc/cheetah/nonlinear/equal_prot.cc index 1ad852e6..2d04c886 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/equal_prot.cc @@ -16,13 +16,11 @@ #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" -#include "yacl/link/link.h" #include "libspu/core/type.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -61,7 +59,7 @@ NdArrayRef EqualProtocol::DoCompute(const NdArrayRef& inp, size_t bit_width) { // init to all zero std::vector digits(num_cmp * num_digits, 0); - DISPATCH_ALL_FIELDS(field, "Equal_break_digits", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const auto mask_radix = makeBitsMask(compare_radix_); const auto mask_remain = makeBitsMask(remain); @@ -133,7 +131,7 @@ NdArrayRef EqualProtocol::DoCompute(const NdArrayRef& inp, size_t bit_width) { // m0[1], m1[1], ..., mN[1] // ... // m0[M], m1[M], ..., mN[M] - DISPATCH_ALL_FIELDS(field, "Equal_transpose", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xprev_eq(prev_eq); for (int64_t r = 0; r < num_cmp; ++r) { for (int64_t c = 0; c < num_digits; ++c) { diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc index 7a51cbd5..4eeead58 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc @@ -45,7 +45,7 @@ TEST_P(EqualProtTest, Basic) { inp[0] = ring_rand(field, shape); inp[1] = ring_rand(field, shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xinp0 = NdArrayView(inp[0]); auto xinp1 = NdArrayView(inp[1]); std::copy_n(&xinp1[0], 5, &xinp0[0]); @@ -64,7 +64,7 @@ TEST_P(EqualProtTest, Basic) { SPU_ENFORCE_EQ(eq_oup[0].shape(), shape); SPU_ENFORCE_EQ(eq_oup[1].shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xeq0 = NdArrayView(eq_oup[0]); auto xeq1 = NdArrayView(eq_oup[1]); auto xinp0 = NdArrayView(inp[0]); diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc index 3edb5117..32e8dff1 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc @@ -16,7 +16,6 @@ #include "libspu/core/type.h" #include "libspu/mpc/cheetah/nonlinear/compare_prot.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" @@ -66,7 +65,7 @@ NdArrayRef TruncateProtocol::ComputeWrap(const NdArrayRef& inp, wrap_bool = compare_prot.Compute(inp, true); } else { auto adjusted = ring_neg(inp); - DISPATCH_ALL_FIELDS(field, "wrap_adjust", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xadj(adjusted); pforeach(0, inp.numel(), [&](int64_t i) { xadj[i] -= 1; }); }); @@ -93,7 +92,7 @@ NdArrayRef TruncateProtocol::MSB1ToWrap(const NdArrayRef& inp, const size_t bw = SizeOf(field) * 8; NdArrayRef cot_output = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "MSB1ToWrap", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; NdArrayView xinp(inp); auto xout = absl::MakeSpan(&cot_output.at(0), cot_output.numel()); @@ -148,7 +147,7 @@ NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, outp = ring_randbit(field, inp.shape()); std::vector send(numel * N); - DISPATCH_ALL_FIELDS(field, "MSB0_adjust", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; NdArrayView xinp(inp); NdArrayView xrnd(outp); @@ -166,7 +165,7 @@ NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, sender->Flush(); } else { std::vector choices(numel, 0); - DISPATCH_ALL_FIELDS(field, "MSB0_adjust", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; NdArrayView xinp(inp); for (int64_t i = 0; i < numel; ++i) { @@ -179,7 +178,7 @@ NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, absl::MakeSpan(recv), nbits); outp = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "MSB0_finalize", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xoup(outp); pforeach(0, numel, [&](int64_t i) { xoup[i] = static_cast(recv[i] & 1); @@ -220,7 +219,7 @@ NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { if (rank == 0) { NdArrayRef tmp = inp.clone(); - DISPATCH_ALL_FIELDS(field, "trunc_with_heuristic", [&] { + DISPATCH_ALL_FIELDS(field, [&] { NdArrayView _inp(tmp); ring2k_t big_value = static_cast(1) << (bit_width - kHeuristicBound); @@ -230,7 +229,7 @@ NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { tmp = Compute(tmp, meta); - DISPATCH_ALL_FIELDS(field, "trunc_with_heuristic", [&] { + DISPATCH_ALL_FIELDS(field, [&] { NdArrayView _outp(tmp); ring2k_t big_value = static_cast(1) << (bit_width - kHeuristicBound - shift); @@ -246,7 +245,7 @@ NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { NdArrayRef wrap_ashr; NdArrayRef out = ring_zeros(field, inp.shape()); - return DISPATCH_ALL_FIELDS(field, "Truncate", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { const ring2k_t component = (static_cast(1) << (bit_width - 1)); NdArrayView xinp(inp); diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc index 4cf591ac..f1c737e8 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc @@ -14,12 +14,9 @@ #include "libspu/mpc/cheetah/nonlinear/truncate_prot.h" -#include - #include "gtest/gtest.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -65,7 +62,7 @@ TEST_P(TruncateProtTest, Basic) { sign = SignType::Unknown; } else { auto msg = ring_rand(field, {n}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xmsg = NdArrayView(msg); size_t bw = SizeOf(field) * 8; if (msb == "Zero") { @@ -111,7 +108,7 @@ TEST_P(TruncateProtTest, Basic) { EXPECT_EQ(oup[0].shape(), oup[1].shape()); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using signed_t = std::make_signed::type; using usigned_t = std::make_unsigned::type; @@ -163,9 +160,10 @@ TEST_P(TruncateProtTest, Heuristic) { NdArrayRef inp[2]; inp[0] = ring_rand(field, {n}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto msg = ring_rand(field, {n}); - ring_rshift_(msg, TruncateProtocol::kHeuristicBound); + ring_rshift_(msg, + {static_cast(TruncateProtocol::kHeuristicBound)}); NdArrayView xmsg(msg); for (int64_t i = 0; i < n; i += 2) { xmsg[i] = -xmsg[i]; @@ -191,7 +189,7 @@ TEST_P(TruncateProtTest, Heuristic) { [[maybe_unused]] int count_zero = 0; [[maybe_unused]] int count_pos = 0; [[maybe_unused]] int count_neg = 0; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using signed_t = std::make_signed::type; auto xout0 = NdArrayView(oup[0]); diff --git a/libspu/mpc/cheetah/ot/BUILD.bazel b/libspu/mpc/cheetah/ot/BUILD.bazel index 54663cc6..00a1dfc7 100644 --- a/libspu/mpc/cheetah/ot/BUILD.bazel +++ b/libspu/mpc/cheetah/ot/BUILD.bazel @@ -72,6 +72,8 @@ spu_cc_test( srcs = ["ot_util_test.cc"], deps = [ ":ot_util", + "//libspu/mpc/common:communicator", "//libspu/mpc/utils:ring_ops", + "//libspu/mpc/utils:simulate", ], ) diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.cc b/libspu/mpc/cheetah/ot/basic_ot_prot.cc index 201a57ad..34d05a1c 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.cc @@ -68,122 +68,89 @@ NdArrayRef BasicOTProtocols::B2A(const NdArrayRef &inp) { return PackedB2A(inp); } +// Convert the packed boolean shares to arithmetic share +// Input x in Z2k is the packed of b-bits for 1 <= b <= k. +// That is x0, x1, ..., x{b-1} +// Output y in Z2k such that y = \sum_i x{i}*2^i mod 2^k +// +// Ref: The ABY paper https://encrypto.de/papers/DSZ15.pdf Section E NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { const auto *share_t = inp.eltype().as(); auto field = inp.eltype().as()->field(); - const int64_t n = inp.numel(); - size_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits(); - if (n >= 8) { - // 8bits-align for a larger input - nbits = (nbits + 7) / 8 * 8; - } - SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field)); - - auto rand_bits = DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { - if ((nbits & 7) or (n * inp.elsize()) & 7) { - // The SseTranspose requires the #rows and #columns is multiple of 8. - // Thus, we call the less efficient RandBits on margin cases. - return RandBits(field, {static_cast(n * nbits)}); - } + const int64_t ring_width = SizeOf(field) * 8; - // More efficient randbits that ultilize collapse COTs. - int64_t B = nbits; - auto r = ring_randbit(field, {n * B}).as(makeType(field, 1)); - const int64_t numl = r.numel(); + const int64_t n = inp.numel(); + const int64_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits(); + const int64_t numel = n * nbits; - NdArrayRef oup = ring_zeros(field, r.shape()); + NdArrayRef cot_oup = ring_zeros(field, {numel}); + NdArrayRef arith_oup = ring_zeros(field, inp.shape()); + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; - auto input = NdArrayView(r); - auto output = absl::MakeSpan(&oup.at(0), numl); - SPU_ENFORCE(oup.isCompact()); + auto input = NdArrayView(inp); + auto cot_output = absl::MakeSpan(&cot_oup.at(0), cot_oup.numel()); if (Rank() == 0) { - std::vector corr_data(numl); - // NOTE(lwj): Masking to make sure there is only single bit. - for (int64_t i = 0; i < numl; ++i) { - // corr=-2*xi - corr_data[i] = -((input[i] & 1) << 1); + std::vector corr_data(numel); + + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + auto msk = makeBitsMask(ring_width - k); + for (int64_t j = 0; j < n; ++j) { + // corr[k] = -2*x0_k + corr_data[i + j] = -2 * ((input[j] >> k) & 1); + corr_data[i + j] &= msk; + } } // Run the multiple COT in the collapse mode. - // That is, the i-th COT returns output of `nbits - i` bits. - ferret_sender_->SendCAMCC_Collapse(absl::MakeSpan(corr_data), output, - /*bw*/ nbits, /*num_level*/ nbits); - ferret_sender_->Flush(); + // That is, the k-th COT returns output of `ring_width - k` bits. + // + // The k-th COT gives the arithmetic share of the k-th bit of the input + // according to x_0 ^ x_1 = x_0 + x_1 - 2 * x_0 * x_1 + ferret_sender_->SendCAMCC_Collapse(absl::MakeSpan(corr_data), cot_output, + /*bw*/ ring_width, + /*num_level*/ nbits); - for (int64_t i = 0; i < numl; ++i) { - output[i] = (input[i] & 1) - output[i]; + ferret_sender_->Flush(); + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + for (int64_t j = 0; j < n; ++j) { + cot_output[i + j] = ((input[j] >> k) & 1) - cot_output[i + j]; + } } } else { - std::vector choices(numl); - for (int64_t i = 0; i < numl; ++i) { - choices[i] = static_cast(input[i] & 1); - } - ferret_receiver_->RecvCAMCC_Collapse(absl::MakeSpan(choices), output, - nbits, nbits); - - for (int64_t i = 0; i < numl; ++i) { - output[i] = (input[i] & 1) + output[i]; + // choice[k] is the k-th bit x1_k + std::vector choices(numel); + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + for (int64_t j = 0; j < n; ++j) { + choices[i + j] = (input[j] >> k) & 1; + } } - } - // oup.shape B x (n * T) - std::vector tmp(B * n * inp.elsize()); + ferret_receiver_->RecvCAMCC_Collapse(absl::MakeSpan(choices), cot_output, + ring_width, nbits); - // bit matrix transpose - SseTranspose(oup.data(), tmp.data(), B, n * inp.elsize()); - - std::copy_n(tmp.data(), tmp.size(), oup.data()); - return oup; - }); - - // convert the bit form to integer form - auto rand = [&](NdArrayRef _bits) { - SPU_ENFORCE(_bits.isCompact(), "need compact input"); - const int64_t n = _bits.numel() / nbits; - // init as all 0s. - auto iform = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "conv_to_bits", [&]() { - auto bits = NdArrayView(_bits); - auto digit = NdArrayView(iform); - for (int64_t i = 0; i < n; ++i) { - // LSB is bits[0]; MSB is bits[nbits - 1] - // We iterate the bits in reversed order - const size_t offset = i * nbits; - digit[i] = 0; - for (size_t j = nbits; j > 0; --j) { - digit[i] = (digit[i] << 1) | (bits[offset + j - 1] & 1); + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + for (int64_t j = 0; j < n; ++j) { + cot_output[i + j] = ((input[j] >> k) & 1) + cot_output[i + j]; } } - }); - return iform; - }(rand_bits); - - // open c = x ^ r - auto opened = OpenShare(ring_xor(inp, rand), ReduceOp::XOR, nbits, conn_); + } - // compute c + (1 - 2*c)* - NdArrayRef oup = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "packed_b2a", [&]() { - using u2k = std::make_unsigned::type; - int rank = Rank(); - auto xr = NdArrayView(rand_bits); - auto xc = NdArrayView(opened); - auto xo = NdArrayView(oup); - - for (int64_t i = 0; i < n; ++i) { - const size_t offset = i * nbits; - u2k this_elt = xc[i]; - for (size_t j = 0; j < nbits; ++j, this_elt >>= 1) { - u2k c_ij = this_elt & 1; - ring2k_t one_bit = (1 - c_ij * 2) * xr[offset + j]; - if (rank == 0) { - one_bit += c_ij; - } - xo[i] += (one_bit << j); + // = \sum_k 2^k * + // where is the arithmetic share of the k-th bit + NdArrayView arith(arith_oup); + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + for (int64_t j = 0; j < n; ++j) { + arith[j] += (cot_output[i + j] << k); } } }); - return oup; + + return arith_oup; } // Math: @@ -205,7 +172,7 @@ NdArrayRef BasicOTProtocols::SingleB2A(const NdArrayRef &inp, int bit_width) { const int64_t n = inp.numel(); NdArrayRef oup = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const u2k msk = makeBitsMask(bit_width); auto input = NdArrayView(inp); @@ -373,7 +340,7 @@ std::array BasicOTProtocols::AndTriple(FieldType field, auto AND_b = ring_zeros(field, shape); auto AND_c = ring_zeros(field, shape); - DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto AND_xa = NdArrayView(AND_a); auto AND_xb = NdArrayView(AND_b); auto AND_xc = NdArrayView(AND_c); @@ -430,7 +397,7 @@ std::array BasicOTProtocols::CorrelatedAndTriple( auto AND_b1 = ring_zeros(field, shape); auto AND_c1 = ring_zeros(field, shape); - DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto AND_xa = NdArrayView(AND_a); auto AND_xb0 = NdArrayView(AND_b0); auto AND_xc0 = NdArrayView(AND_c0); @@ -465,7 +432,7 @@ NdArrayRef BasicOTProtocols::Multiplexer(const NdArrayRef &msg, std::vector sel(size); // Compute (x0 + x1) * (b0 ^ b1) // Also b0 ^ b1 = 1 - 2*b0*b1 - return DISPATCH_ALL_FIELDS(field, "Multiplexer", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _msg(msg); NdArrayView _sel(select); auto corr_data = absl::MakeSpan(&_corr_data.at(0), size); @@ -506,7 +473,7 @@ NdArrayRef BasicOTProtocols::PrivateMulxRecv(const NdArrayRef &msg, auto recv = ring_zeros(field, msg.shape()); std::vector sel(size); - DISPATCH_ALL_FIELDS(field, "convert", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _sel(select); pforeach(0, size, [&](int64_t i) { sel[i] = static_cast(_sel[i] & 1); }); @@ -526,7 +493,7 @@ NdArrayRef BasicOTProtocols::PrivateMulxRecv(const NdArrayRef &msg, // Compute (x0 + x1) * b // x0 * b + x1 * b // COT compute - DISPATCH_ALL_FIELDS(field, "MultiplexerOnPrivate", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _msg(msg); auto _recv = absl::MakeSpan(&recv.at(0), size); @@ -549,7 +516,7 @@ NdArrayRef BasicOTProtocols::PrivateMulxSend(const NdArrayRef &msg) { auto recv = ring_zeros(field, msg.shape()); // Compute (x0 + x1) * b // x0 * b + x1 * b - DISPATCH_ALL_FIELDS(field, "MultiplexerOnPrivate", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto _msg = absl::MakeConstSpan(&msg.at(0), size); auto _recv = absl::MakeSpan(&recv.at(0), size); diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc index aa18321c..17f78a4d 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc @@ -82,7 +82,7 @@ TEST_P(BasicOTProtTest, SingleB2A) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -110,7 +110,7 @@ TEST_P(BasicOTProtTest, SingleB2A) { EXPECT_EQ(ashr0.shape(), ashr1.shape()); EXPECT_EQ(shape, ashr0.shape()); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView b0(bshr0); NdArrayView b1(bshr1); NdArrayView a0(ashr0); @@ -138,7 +138,7 @@ TEST_P(BasicOTProtTest, PackedB2A) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -165,7 +165,7 @@ TEST_P(BasicOTProtTest, PackedB2A) { EXPECT_EQ(ashr0.shape(), ashr1.shape()); EXPECT_EQ(ashr0.shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView b0(bshr0); NdArrayView b1(bshr1); NdArrayView a0(ashr0); @@ -197,7 +197,7 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -228,7 +228,7 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { EXPECT_EQ(ashr0.shape(), ashr1.shape()); EXPECT_EQ(ashr0.shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView b0(bshr0); NdArrayView b1(bshr1); NdArrayView a0(ashr0); @@ -267,7 +267,7 @@ TEST_P(BasicOTProtTest, AndTripleSparse) { } }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t max = static_cast(1) << target_nbits; NdArrayView a0(triple[0][0]); NdArrayView b0(triple[0][1]); @@ -304,7 +304,7 @@ TEST_P(BasicOTProtTest, BitwiseAnd) { for (int i : {0, 1}) { lhs[i] = ring_rand(field, shape).as(boolean_t); rhs[i] = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "mask", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t mask = makeBitsMask(bw); NdArrayView L(lhs[i]); NdArrayView R(rhs[i]); @@ -324,7 +324,7 @@ TEST_P(BasicOTProtTest, BitwiseAnd) { auto expected = ring_and(ring_xor(lhs[0], lhs[1]), ring_xor(rhs[0], rhs[1])); auto got = ring_xor(out[0], out[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView e(expected); NdArrayView g(got); @@ -350,7 +350,7 @@ TEST_P(BasicOTProtTest, AndTripleFull) { } }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView a0(packed_triple[0][0]); NdArrayView b0(packed_triple[0][1]); NdArrayView c0(packed_triple[0][2]); @@ -383,7 +383,7 @@ TEST_P(BasicOTProtTest, Multiplexer) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(1); NdArrayView xb0(bshr0); NdArrayView xb1(bshr1); @@ -407,7 +407,7 @@ TEST_P(BasicOTProtTest, Multiplexer) { EXPECT_EQ(computed[0].shape(), computed[1].shape()); EXPECT_EQ(computed[0].shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView a0(ashr0); NdArrayView a1(ashr1); NdArrayView b0(bshr0); @@ -444,7 +444,7 @@ TEST_P(BasicOTProtTest, CorrelatedAndTriple) { EXPECT_EQ(corr_triple[1][0].shape(), corr_triple[1][i].shape()); } - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto a0 = NdArrayView(corr_triple[0][0]); auto b0 = NdArrayView(corr_triple[0][1]); auto c0 = NdArrayView(corr_triple[0][2]); @@ -487,7 +487,7 @@ TEST_P(BasicOTProtTest, PrivateMulx) { auto ashr1 = ring_rand(field, shape); auto choices = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "bit", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(1); NdArrayView xb(choices); pforeach(0, xb.numel(), [&](int64_t i) { xb[i] &= mask; }); @@ -507,7 +507,7 @@ TEST_P(BasicOTProtTest, PrivateMulx) { EXPECT_EQ(computed[0].shape(), computed[1].shape()); EXPECT_EQ(computed[0].shape(), shape); - DISPATCH_ALL_FIELDS(field, "check", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView a0(ashr0); NdArrayView a1(ashr1); NdArrayView c(choices); diff --git a/libspu/mpc/cheetah/ot/emp/ferret_test.cc b/libspu/mpc/cheetah/ot/emp/ferret_test.cc index af6cb96a..35353964 100644 --- a/libspu/mpc/cheetah/ot/emp/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/emp/ferret_test.cc @@ -55,7 +55,7 @@ TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { return static_cast(uniform(rdv) & 1); }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView correlation(_correlation); std::vector computed[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { @@ -86,7 +86,7 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -123,7 +123,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -163,7 +163,7 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { size_t kWorldSize = 2; int64_t n = 100; auto field = GetParam(); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using scalar_t = ring2k_t; std::default_random_engine rdv; std::uniform_int_distribution uniform(0, -1); @@ -205,4 +205,68 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { }); } +template +T makeMask(int bw) { + if (bw == sizeof(T) * 8) { + return static_cast(-1); + } + return (static_cast(1) << bw) - 1; +} + +TEST_P(FerretCOTTest, COT_Collapse) { + size_t kWorldSize = 2; + int64_t n = 8; + auto field = GetParam(); + + const auto bw = SizeOf(field) * 8; + const int level = bw; + + // generate random choices and correlation + const auto _correlation = ring_rand(field, {static_cast(n * level)}); + const auto N = _correlation.numel(); + + NdArrayRef oup1 = ring_zeros(field, _correlation.shape()); + NdArrayRef oup2 = ring_zeros(field, _correlation.shape()); + + std::vector choices(N, 1); + + DISPATCH_ALL_FIELDS(field, [&]() { + using u2k = std::make_unsigned::type; + + auto out1_span = absl::MakeSpan(&oup1.at(0), N); + auto out2_span = absl::MakeSpan(&oup2.at(0), N); + + NdArrayView correlation(_correlation); + + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + + EmpFerretOt ferret(conn, rank == 0); + if (rank == 0) { + ferret.SendCAMCC_Collapse(makeConstSpan(correlation), out1_span, bw, + level); + ferret.Flush(); + + } else { + ferret.RecvCAMCC_Collapse(absl::MakeSpan(choices), out2_span, bw, + level); + } + }); + + // Sample-major order + // n || n || n || .... || n + // k=level||k=level - 1||k=level - 2|| .... + for (int64_t i = 0; i < N; i += n) { + const auto cur_bw = bw - (i / n); + const auto mask = makeMask(cur_bw); + for (int64_t j = 0; j < n; ++j) { + ring2k_t c = (-out1_span[i + j] + out2_span[i + j]) & mask; + ring2k_t e = (choices[i + j] ? correlation[i + j] : 0) & mask; + + ASSERT_EQ(c, e); + } + } + }); +} } // namespace spu::mpc::cheetah::test diff --git a/libspu/mpc/cheetah/ot/ot_util.cc b/libspu/mpc/cheetah/ot/ot_util.cc index 1d308a2f..6b418a3f 100644 --- a/libspu/mpc/cheetah/ot/ot_util.cc +++ b/libspu/mpc/cheetah/ot/ot_util.cc @@ -40,6 +40,30 @@ void U8ToBool(absl::Span bits, uint8_t u8) { } } +template +static T _makeBitsMask(size_t nbits) { + size_t max = sizeof(T) * 8; + if (nbits == 0) { + nbits = max; + } + SPU_ENFORCE(nbits <= max); + T mask = static_cast(-1); + if (nbits < max) { + mask = (static_cast(1) << nbits) - 1; + } + return mask; +} + +static void maskArray(NdArrayRef array, FieldType field, size_t bw) { + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView view(array); + auto msk = _makeBitsMask(bw); + for (int64_t i = 0; i < view.numel(); ++i) { + view[i] &= msk; + } + }); +} + NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, std::shared_ptr conn) { SPU_ENFORCE(conn != nullptr); @@ -52,20 +76,27 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, nbits = fwidth; } SPU_ENFORCE(nbits <= fwidth, "nbits out-of-bound"); - bool packable = fwidth > nbits; - if (not packable) { - return conn->allReduce(op, shr, "open"); - } + size_t space_bits = op == ReduceOp::ADD ? nbits + 1 : nbits; size_t numel = shr.numel(); - size_t compact_numel = CeilDiv(numel * nbits, fwidth); + size_t compact_numel = CeilDiv(numel * space_bits, fwidth); + + if (space_bits > nbits and 0 != (fwidth % space_bits)) { + // FIXME(lwj): for Add, we can have a better ZipArray to handle a ring + // element that placed in two different blocks. + // For now, we use ZipArray for Add only when one element is just fit in one + // block. + auto out = conn->allReduce(op, shr, "open"); + maskArray(out, field, nbits); + return out; + } NdArrayRef out(shr.eltype(), {(int64_t)numel}); - DISPATCH_ALL_FIELDS(field, "zip", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto inp = absl::MakeConstSpan(&shr.at(0), numel); auto oup = absl::MakeSpan(&out.at(0), compact_numel); - size_t used = ZipArray(inp, nbits, oup); + size_t used = ZipArray(inp, space_bits, oup); SPU_ENFORCE_EQ(used, compact_numel); std::vector opened; @@ -76,8 +107,16 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, } oup = absl::MakeSpan(&out.at(0), numel); - UnzipArray(absl::MakeConstSpan(opened), nbits, oup); + UnzipArray(absl::MakeConstSpan(opened), space_bits, oup); + + if (space_bits > nbits and nbits < fwidth) { + auto msk = (static_cast(1) << nbits) - 1; + for (size_t i = 0; i < numel; ++i) { + oup[i] &= msk; + } + } }); + return out.reshape(shr.shape()); } @@ -94,8 +133,10 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, __attribute__((target("sse2"))) #endif void SseTranspose(uint8_t *out, uint8_t const *inp, uint64_t nrows, uint64_t ncols) { - uint64_t rr, cc; - int i, h; + uint64_t rr; + uint64_t cc; + int i; + int h; union { __m128i x; uint8_t b[16]; @@ -116,7 +157,9 @@ void SseTranspose(uint8_t *out, uint8_t const *inp, uint64_t nrows, uint64_t nco *(uint16_t *)&OUT(rr, cc + i) = _mm_movemask_epi8(vec); } } - if (rr == nrows) return; + if (rr == nrows) { + return; + } // The remainder is a block of 8x(16n+8) bits (n may be 0). // Do a PAIR of 8x8 blocks in each step: @@ -153,9 +196,12 @@ void SseTranspose(uint8_t *out, uint8_t const *inp, uint64_t nrows, uint64_t nco if (cc == ncols) return; // Do the remaining 8x8 block: - for (i = 0; i < 8; ++i) tmp.b[i] = INP(rr + i, cc); - for (i = 8; --i >= 0; tmp.x = _mm_slli_epi64(tmp.x, 1)) + for (i = 0; i < 8; ++i) { + tmp.b[i] = INP(rr + i, cc); + } + for (i = 8; --i >= 0; tmp.x = _mm_slli_epi64(tmp.x, 1)) { OUT(rr, cc + i) = _mm_movemask_epi8(tmp.x); + } } } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/ot_util_test.cc b/libspu/mpc/cheetah/ot/ot_util_test.cc index ea5e0aac..2105f408 100644 --- a/libspu/mpc/cheetah/ot/ot_util_test.cc +++ b/libspu/mpc/cheetah/ot/ot_util_test.cc @@ -14,11 +14,11 @@ #include "libspu/mpc/cheetah/ot/ot_util.h" -#include - #include "gtest/gtest.h" +#include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" +#include "libspu/mpc/utils/simulate.h" namespace spu::mpc::cheetah::test { @@ -40,7 +40,7 @@ TEST_P(OtUtilTest, ZipArray) { auto unzip = ring_zeros(field, {n}); - DISPATCH_ALL_FIELDS(field, "UT_ZipArray", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { for (size_t bw : {1, 2, 4, 7, 15, 16}) { int64_t pack_load = elsze * 8 / bw; auto zip = ring_zeros(field, {(n + pack_load - 1) / pack_load}); @@ -69,7 +69,7 @@ TEST_P(OtUtilTest, ZipArrayBit) { auto unzip = ring_zeros(field, {n}); - DISPATCH_ALL_FIELDS(field, "UT_ZipArrayBit", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { const size_t elsze = SizeOf(field); for (size_t bw : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { size_t width = elsze * 8; @@ -83,7 +83,6 @@ TEST_P(OtUtilTest, ZipArrayBit) { auto _zip = absl::MakeSpan(&zip.at(0), zip.numel()); auto _unzip = absl::MakeSpan(&unzip.at(0), unzip.numel()); pforeach(0, array.numel(), [&](int64_t i) { inp[i] &= mask; }); - size_t zip_sze = ZipArrayBit(inp, bw, _zip); SPU_ENFORCE(zip_sze == pack_sze); @@ -101,4 +100,77 @@ TEST_P(OtUtilTest, ZipArrayBit) { }); } -} // namespace spu::mpc::cheetah::test \ No newline at end of file +template +T makeBitsMask(size_t nbits) { + size_t max = sizeof(T) * 8; + if (nbits == 0) { + nbits = max; + } + SPU_ENFORCE(nbits <= max); + T mask = static_cast(-1); + if (nbits < max) { + mask = (static_cast(1) << nbits) - 1; + } + return mask; +} + +void MaskArray(NdArrayRef array, FieldType field, size_t bw) { + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView view(array); + auto msk = makeBitsMask(bw); + for (int64_t i = 0; i < view.numel(); ++i) { + view[i] &= msk; + } + }); +} + +TEST_P(OtUtilTest, OpenShare_ADD) { + const auto field = GetParam(); + Shape shape = {1000L}; + + for (size_t bw_offset : {0, 15, 17}) { + size_t bw = SizeOf(field) * 8 - bw_offset; + NdArrayRef inp[2]; + utils::simulate(2, [&](std::shared_ptr ctx) { + int rank = ctx->Rank(); + + inp[rank] = ring_rand(field, shape); + MaskArray(inp[rank], field, bw); + + auto conn = std::make_shared(ctx); + auto opened = OpenShare(inp[rank], ReduceOp::ADD, bw, conn); + if (rank == 0) return; + auto expected = ring_add(inp[0], inp[1]); + MaskArray(expected, field, bw); + + ASSERT_TRUE(std::memcmp(&opened.at(0), &expected.at(0), + opened.elsize() * opened.numel()) == 0); + }); + } +} + +TEST_P(OtUtilTest, OpenShare_XOR) { + const auto field = GetParam(); + Shape shape = {1000L}; + + for (size_t bw_offset : {0, 3, 15}) { + size_t bw = SizeOf(field) * 8 - bw_offset; + NdArrayRef inp[2]; + utils::simulate(2, [&](std::shared_ptr ctx) { + int rank = ctx->Rank(); + + inp[rank] = ring_rand(field, shape); + MaskArray(inp[rank], field, bw); + + auto conn = std::make_shared(ctx); + auto opened = OpenShare(inp[rank], ReduceOp::XOR, bw, conn); + if (rank == 0) return; + auto expected = ring_xor(inp[0], inp[1]); + + ASSERT_TRUE(std::memcmp(&opened.at(0), &expected.at(0), + opened.elsize() * opened.numel()) == 0); + }); + } +} + +} // namespace spu::mpc::cheetah::test diff --git a/libspu/mpc/cheetah/ot/yacl/ferret_test.cc b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc index c78abe62..1d5822ca 100644 --- a/libspu/mpc/cheetah/ot/yacl/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc @@ -55,7 +55,7 @@ TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { return static_cast(uniform(rdv) & 1); }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView correlation(_correlation); std::vector computed[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { @@ -87,7 +87,7 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -125,7 +125,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -166,7 +166,7 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { int64_t n = 1 << 10; auto field = std::get<0>(GetParam()); auto use_ss = std::get<1>(GetParam()); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using scalar_t = ring2k_t; std::default_random_engine rdv; std::uniform_int_distribution uniform(0, -1); @@ -210,4 +210,69 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { }); } +template +T makeMask(int bw) { + if (bw == sizeof(T) * 8) { + return static_cast(-1); + } + return (static_cast(1) << bw) - 1; +} + +TEST_P(FerretCOTTest, COT_Collapse) { + size_t kWorldSize = 2; + int64_t n = 8; + auto field = std::get<0>(GetParam()); + auto use_ss = std::get<1>(GetParam()); + + const auto bw = SizeOf(field) * 8; + const int level = bw; + + // generate random choices and correlation + const auto _correlation = ring_rand(field, {static_cast(n * level)}); + const auto N = _correlation.numel(); + + NdArrayRef oup1 = ring_zeros(field, _correlation.shape()); + NdArrayRef oup2 = ring_zeros(field, _correlation.shape()); + + std::vector choices(N, 1); + + DISPATCH_ALL_FIELDS(field, [&]() { + using u2k = std::make_unsigned::type; + + auto out1_span = absl::MakeSpan(&oup1.at(0), N); + auto out2_span = absl::MakeSpan(&oup2.at(0), N); + + NdArrayView correlation(_correlation); + + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + + YaclFerretOt ferret(conn, rank == 0, use_ss); + if (rank == 0) { + ferret.SendCAMCC_Collapse(makeConstSpan(correlation), out1_span, bw, + level); + ferret.Flush(); + + } else { + ferret.RecvCAMCC_Collapse(absl::MakeSpan(choices), out2_span, bw, + level); + } + }); + + // Sample-major order + // n || n || n || .... || n + // k=level||k=level - 1||k=level - 2|| .... + for (int64_t i = 0; i < N; i += n) { + const auto cur_bw = bw - (i / n); + const auto mask = makeMask(cur_bw); + for (int64_t j = 0; j < n; ++j) { + ring2k_t c = (-out1_span[i + j] + out2_span[i + j]) & mask; + ring2k_t e = (choices[i + j] ? correlation[i + j] : 0) & mask; + + ASSERT_EQ(c, e); + } + } + }); +} } // namespace spu::mpc::cheetah::test diff --git a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h index 294e095f..bd37640b 100644 --- a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h +++ b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h @@ -19,8 +19,8 @@ #include "yacl/kernel/algorithms/base_ot.h" #include "yacl/kernel/algorithms/ferret_ote.h" #include "yacl/kernel/algorithms/iknp_ote.h" -#include "yacl/kernel/algorithms/ot_store.h" #include "yacl/kernel/algorithms/softspoken_ote.h" +#include "yacl/kernel/type/ot_store.h" #include "libspu/core/prelude.h" #include "libspu/mpc/cheetah/ot/ot_util.h" diff --git a/libspu/mpc/cheetah/protocol.cc b/libspu/mpc/cheetah/protocol.cc index a18da61e..8b638d57 100644 --- a/libspu/mpc/cheetah/protocol.cc +++ b/libspu/mpc/cheetah/protocol.cc @@ -63,7 +63,7 @@ void regCheetahProtocol(SPUContext* ctx, ctx->prot() ->regKernel(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "ModulusUpAt", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ring2u = std::make_unsigned::type; impl_->ModulusUpAt(NdArrayView(src), mod_idx, out); }); @@ -450,7 +450,7 @@ void ModulusSwitchHelper::CenteralizeAt(const NdArrayRef &src, size_t mod_idx, SPU_ENFORCE_EQ(numel, out.size()); SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "CenteralizeAt", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ring2u = std::make_unsigned::type; impl_->CenteralizeAt(NdArrayView(src), mod_idx, out); }); @@ -481,7 +481,7 @@ void ModulusSwitchHelper::ModulusDownRNS(absl::Span src, size_t num_elt = out.numel(); SPU_ENFORCE_EQ(num_elt * num_modulus, src.size()); - return DISPATCH_ALL_FIELDS(field, "ModulusDownRNS", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using ring2u = std::make_unsigned::type; absl::Span out_wrap(reinterpret_cast(out.data()), num_elt); diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc index ecb5beee..2aa7a203 100644 --- a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc @@ -120,9 +120,10 @@ TEST_P(RLWE2LWETest, ModulusSwitch_UpDown) { } // e = round(r'/Delta) mod t \in R_t auto src = absl::MakeSpan(pt.data(), pt.coeff_count()); - auto cmp = ms_helper_->ModulusDownRNS(field_, {(int64_t)poly_deg}, src); + auto cmp = ms_helper_->ModulusDownRNS( + field_, {static_cast(poly_deg)}, src); // check r =? e - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto expected = NdArrayView(_vec); auto computed = NdArrayView(cmp); for (int64_t i = 0; i < expected.numel(); ++i) { @@ -145,7 +146,7 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { // r <- R_q RLWEPt rnd; UniformPoly(*context_, &rnd); - auto &modulus = context_->first_context_data()->parms().coeff_modulus(); + const auto &modulus = context_->first_context_data()->parms().coeff_modulus(); // b = a' - r RLWEPt poly1; { @@ -158,8 +159,8 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { // r' = round(r/Delta) mod t \in R_t // b' = round(b/Delta) mod t \in R_t - auto shr0 = ring_zeros(field_, {(int64_t)poly_deg}); - auto shr1 = ring_zeros(field_, {(int64_t)poly_deg}); + auto shr0 = ring_zeros(field_, {static_cast(poly_deg)}); + auto shr1 = ring_zeros(field_, {static_cast(poly_deg)}); { auto src = absl::MakeSpan(rnd.data(), rnd.coeff_count()); ms_helper_->ModulusDownRNS(src, shr0); @@ -168,7 +169,7 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { ms_helper_->ModulusDownRNS(src, shr1); } - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto expected = NdArrayView(vec_a); auto computed0 = NdArrayView(shr0); auto computed1 = NdArrayView(shr1); diff --git a/libspu/mpc/cheetah/rlwe/packlwes_test.cc b/libspu/mpc/cheetah/rlwe/packlwes_test.cc index 6d924c26..4bda1a27 100644 --- a/libspu/mpc/cheetah/rlwe/packlwes_test.cc +++ b/libspu/mpc/cheetah/rlwe/packlwes_test.cc @@ -13,8 +13,6 @@ // limitations under the License. #include "libspu/mpc/cheetah/rlwe/packlwes.h" -#include - #include "gtest/gtest.h" #include "seal/seal.h" @@ -235,7 +233,8 @@ TEST_P(PackLWEsTest, Phantom) { N_encoder_->Forward(array, &pt, true); NttInplace(pt, *N_context_); - RLWECt rlwe0, rlwe1; + RLWECt rlwe0; + RLWECt rlwe1; CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, rlwe0)); CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, rlwe1)); if (rlwe0.is_ntt_form()) { @@ -345,7 +344,7 @@ void VectorEncoder::Backward(const NdArrayRef &vec, RLWEPt *out, const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "Backward", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto tmp_buff = ring_zeros(field, {(int64_t)poly_deg_}); auto xvec = NdArrayView(vec); auto xtmp = NdArrayView(tmp_buff); diff --git a/libspu/mpc/cheetah/state.cc b/libspu/mpc/cheetah/state.cc index 1f754baa..dded2f5f 100644 --- a/libspu/mpc/cheetah/state.cc +++ b/libspu/mpc/cheetah/state.cc @@ -16,8 +16,6 @@ #include -#include "spdlog/spdlog.h" - #include "libspu/core/context.h" #include "libspu/core/ndarray_ref.h" #include "libspu/core/prelude.h" @@ -89,7 +87,7 @@ void CheetahMulState::makeSureCacheSize(FieldType field, int64_t numel) { cross.slice({num_beaver}, {2 * num_beaver}, {1})), ring_mul(beaver[0], beaver[1])); - DISPATCH_ALL_FIELDS(field, "makeSureCacheSize", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { for (size_t i : {0, 1, 2}) { auto tmp = ring_zeros(field, {num_beaver + cached_sze_}); NdArrayView new_cache(tmp); diff --git a/libspu/mpc/cheetah/state.h b/libspu/mpc/cheetah/state.h index eaf20eec..0ee3e575 100644 --- a/libspu/mpc/cheetah/state.h +++ b/libspu/mpc/cheetah/state.h @@ -75,7 +75,7 @@ class CheetahMulState : public State { : mul_prot_(std::move(mul_prot)) {} public: - static constexpr char kBindName[] = "CheetahMul"; + static constexpr const char* kBindName() { return "CheetahMul"; } explicit CheetahMulState(const std::shared_ptr& lctx, bool enable_mul_lsb_error = false) { @@ -100,7 +100,7 @@ class CheetahDotState : public State { : dot_prot_(std::move(dot_prot)) {} public: - static constexpr char kBindName[] = "CheetahDot"; + static constexpr const char* kBindName() { return "CheetahDot"; } explicit CheetahDotState(const std::shared_ptr& lctx, bool disable_matmul_pack = false) { @@ -125,7 +125,7 @@ class CheetahOTState : public State { CheetahOtKind ot_kind_; public: - static constexpr char kBindName[] = "CheetahOT"; + static constexpr const char* kBindName() { return "CheetahOT"; } explicit CheetahOTState(size_t maximum_instances, CheetahOtKind ot_kind) : maximum_instances_(std::min(kMaxOTParallel, maximum_instances)), diff --git a/libspu/mpc/common/BUILD.bazel b/libspu/mpc/common/BUILD.bazel index 017c316a..f22bc908 100644 --- a/libspu/mpc/common/BUILD.bazel +++ b/libspu/mpc/common/BUILD.bazel @@ -42,6 +42,7 @@ spu_cc_library( hdrs = ["communicator.h"], deps = [ "//libspu/core:object", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/link:context", "@yacl//yacl/link/algorithm:allgather", @@ -65,6 +66,7 @@ spu_cc_library( hdrs = ["prg_state.h"], deps = [ "//libspu/core:object", + "//libspu/mpc/utils:permute", "@yacl//yacl/crypto/rand", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/link:context", @@ -87,6 +89,7 @@ spu_cc_library( hdrs = ["prg_tensor.h"], deps = [ "//libspu/core:ndarray_ref", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/crypto/tools:prg", ], diff --git a/libspu/mpc/common/communicator.cc b/libspu/mpc/common/communicator.cc index 0d20dddd..b7dc4089 100644 --- a/libspu/mpc/common/communicator.cc +++ b/libspu/mpc/common/communicator.cc @@ -14,6 +14,7 @@ #include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc { @@ -26,19 +27,22 @@ std::shared_ptr stealBuffer(yacl::Buffer&& buf) { return std::make_shared(std::move(buf)); } -std::shared_ptr getOrCreateCompactBuf(const NdArrayRef& in) { - if (in.numel() * in.elsize() != static_cast(in.buf()->size())) { - return in.clone().buf(); +NdArrayRef getOrCreateCompactArray(const NdArrayRef& in) { + if (!in.isCompact()) { + return in.clone(); } - return in.buf(); + + return in; } } // namespace NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in, std::string_view tag) { - const auto buf = getOrCreateCompactBuf(in); - std::vector bufs = yacl::link::AllGather(lctx_, *buf, tag); + const auto array = getOrCreateCompactArray(in); + yacl::ByteContainerView bv(reinterpret_cast(array.data()), + in.numel() * in.elsize()); + std::vector bufs = yacl::link::AllGather(lctx_, bv, tag); SPU_ENFORCE(bufs.size() == getWorldSize()); auto res = in.clone(); @@ -50,7 +54,11 @@ NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in, auto arr = NdArrayRef(stealBuffer(std::move(bufs[idx])), in.eltype(), in.shape(), makeCompactStrides(in.shape()), kOffset); if (op == ReduceOp::ADD) { - ring_add_(res, arr); + if (in.eltype().isa()) { + gfmp_add_mod_(res, arr); + } else { + ring_add_(res, arr); + } } else if (op == ReduceOp::XOR) { ring_xor_(res, arr); } else { @@ -59,7 +67,7 @@ NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in, } stats_.latency += 1; - stats_.comm += buf->size() * (lctx_->WorldSize() - 1); + stats_.comm += in.numel() * in.elsize() * (lctx_->WorldSize() - 1); return res; } @@ -67,8 +75,10 @@ NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in, NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root, std::string_view tag) { SPU_ENFORCE(root < lctx_->WorldSize()); - const auto buf = getOrCreateCompactBuf(in); - std::vector bufs = yacl::link::Gather(lctx_, *buf, root, tag); + const auto array = getOrCreateCompactArray(in); + yacl::ByteContainerView bv(reinterpret_cast(array.data()), + in.numel() * in.elsize()); + std::vector bufs = yacl::link::Gather(lctx_, bv, root, tag); auto res = in.clone(); if (getRank() == root) { @@ -81,7 +91,11 @@ NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root, NdArrayRef(stealBuffer(std::move(bufs[idx])), in.eltype(), in.shape(), makeCompactStrides(in.shape()), kOffset); if (op == ReduceOp::ADD) { - ring_add_(res, arr); + if (in.eltype().isa()) { + gfmp_add_mod_(res, arr); + } else { + ring_add_(res, arr); + } } else if (op == ReduceOp::XOR) { ring_xor_(res, arr); } else { @@ -89,30 +103,69 @@ NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root, } } } - stats_.latency += 1; - stats_.comm += buf->size(); + stats_.comm += in.numel() * in.elsize(); return res; } NdArrayRef Communicator::rotate(const NdArrayRef& in, std::string_view tag) { - const auto buf = getOrCreateCompactBuf(in); - lctx_->SendAsync(lctx_->PrevRank(), *buf, tag); + const auto array = getOrCreateCompactArray(in); + yacl::ByteContainerView bv(reinterpret_cast(array.data()), + in.numel() * in.elsize()); + lctx_->SendAsync(lctx_->PrevRank(), bv, tag); auto res_buf = lctx_->Recv(lctx_->NextRank(), tag); stats_.latency += 1; - stats_.comm += buf->size(); + stats_.comm += in.numel() * in.elsize(); return NdArrayRef(stealBuffer(std::move(res_buf)), in.eltype(), in.shape(), makeCompactStrides(in.shape()), kOffset); } +std::vector Communicator::gather(const NdArrayRef& in, size_t root, + std::string_view tag) { + const auto array = getOrCreateCompactArray(in); + yacl::ByteContainerView bv(reinterpret_cast(array.data()), + array.numel() * array.elsize()); + auto bufs = yacl::link::Gather(lctx_, bv, root, tag); + + stats_.latency += 1; + stats_.comm += array.numel() * array.elsize(); + + auto res = std::vector(getWorldSize()); + if (root == getRank()) { + SPU_ENFORCE_EQ(bufs.size(), getWorldSize()); + for (size_t idx = 0; idx < bufs.size(); idx++) { + res[idx] = + NdArrayRef(stealBuffer(std::move(bufs[idx])), in.eltype(), in.shape(), + makeCompactStrides(in.shape()), kOffset); + } + } + return res; +} + +NdArrayRef Communicator::broadcast(const NdArrayRef& in, size_t root, + std::string_view tag) { + const auto array = getOrCreateCompactArray(in); + yacl::ByteContainerView bv(reinterpret_cast(array.data()), + array.elsize() * array.numel()); + auto buf = yacl::link::Broadcast(lctx_, bv, root, tag); + + stats_.latency += 1; + stats_.comm += in.elsize() * in.numel(); + + return NdArrayRef(stealBuffer(std::move(buf)), in.eltype(), in.shape(), + makeCompactStrides(in.shape()), kOffset); +} + void Communicator::sendAsync(size_t dst_rank, const NdArrayRef& in, std::string_view tag) { - const auto buf = getOrCreateCompactBuf(in); - lctx_->SendAsync(dst_rank, *buf, tag); + const auto array = getOrCreateCompactArray(in); + yacl::ByteContainerView bv(reinterpret_cast(array.data()), + in.numel() * in.elsize()); + lctx_->SendAsync(dst_rank, bv, tag); } NdArrayRef Communicator::recv(size_t src_rank, const Type& eltype, @@ -123,4 +176,4 @@ NdArrayRef Communicator::recv(size_t src_rank, const Type& eltype, return NdArrayRef(stealBuffer(std::move(buf)), eltype, {numel}, {1}, kOffset); } -} // namespace spu::mpc +} // namespace spu::mpc \ No newline at end of file diff --git a/libspu/mpc/common/communicator.h b/libspu/mpc/common/communicator.h index d566450e..f6103937 100644 --- a/libspu/mpc/common/communicator.h +++ b/libspu/mpc/common/communicator.h @@ -50,7 +50,7 @@ enum class ReduceOp { // gap. class Communicator : public State { public: - static constexpr char kBindName[] = "Communicator"; + static constexpr const char* kBindName() { return "Communicator"; } struct Stats { // @@ -103,6 +103,11 @@ class Communicator : public State { NdArrayRef allReduce(ReduceOp op, const NdArrayRef& in, std::string_view tag); + std::vector gather(const NdArrayRef& in, size_t root, + std::string_view tag); + + NdArrayRef broadcast(const NdArrayRef& in, size_t root, std::string_view tag); + NdArrayRef reduce(ReduceOp op, const NdArrayRef& in, size_t root, std::string_view tag); diff --git a/libspu/mpc/common/prg_state.cc b/libspu/mpc/common/prg_state.cc index 0f68fcff..374fba19 100644 --- a/libspu/mpc/common/prg_state.cc +++ b/libspu/mpc/common/prg_state.cc @@ -19,6 +19,8 @@ #include "yacl/link/algorithm/allgather.h" #include "yacl/utils/serialize.h" +#include "libspu/mpc/utils/permute.h" + namespace spu::mpc { PrgState::PrgState() { @@ -108,4 +110,15 @@ NdArrayRef PrgState::genPubl(FieldType field, const Shape& shape) { return res; } +Index PrgState::genPrivPerm(size_t numel) { + return genRandomPerm(numel, priv_seed_, &priv_counter_); +} + +std::pair PrgState::genPrssPermPair(size_t numel) { + std::pair res; + res.first = genRandomPerm(numel, self_seed_, &r0_counter_); + res.second = genRandomPerm(numel, next_seed_, &r1_counter_); + return res; +} + } // namespace spu::mpc diff --git a/libspu/mpc/common/prg_state.h b/libspu/mpc/common/prg_state.h index 7e8938e6..93c0461b 100644 --- a/libspu/mpc/common/prg_state.h +++ b/libspu/mpc/common/prg_state.h @@ -15,6 +15,7 @@ #pragma once #include "absl/types/span.h" +#include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" #include "yacl/link/context.h" @@ -39,12 +40,13 @@ class PrgState : public State { uint64_t priv_counter_ = 0; // Pseudorandom Secret Sharing seeds. - uint128_t next_seed_ = 0; uint128_t self_seed_ = 0; - uint64_t prss_counter_ = 0; + uint128_t next_seed_ = 0; + uint64_t r0_counter_ = 0; // cnt for self_seed + uint64_t r1_counter_ = 0; // cnt for next_seed public: - static constexpr char kBindName[] = "PrgState"; + static constexpr const char* kBindName() { return "PrgState"; } static constexpr auto kAesType = yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; @@ -59,13 +61,18 @@ class PrgState : public State { NdArrayRef genPubl(FieldType field, const Shape& shape); + Index genPrivPerm(size_t numel); + + // Generate a random permutation pair (p0, p1). + std::pair genPrssPermPair(size_t numel); + // Generate a random pair (r0, r1), where // r1 = next_party.r0 // // This correlation could be used to construct zero shares. // // Note: ignore_first, ignore_second is for perf improvement. - enum class GenPrssCtrl { Both, First, Second, None }; + enum class GenPrssCtrl { Both, First, Second }; std::pair genPrssPair(FieldType field, const Shape& shape, GenPrssCtrl ctrl); @@ -73,29 +80,21 @@ class PrgState : public State { template void fillPrssPair(T* r0, T* r1, size_t numel, GenPrssCtrl ctrl) { switch (ctrl) { - case GenPrssCtrl::None: { - // Nothing to generate, pure dummy - prss_counter_ = yacl::crypto::DummyUpdateRandomCount(prss_counter_, - numel * sizeof(T)); - return; - } case GenPrssCtrl::First: { - prss_counter_ = yacl::crypto::FillPRand( - kAesType, self_seed_, 0, prss_counter_, absl::MakeSpan(r0, numel)); + r0_counter_ = yacl::crypto::FillPRand( + kAesType, self_seed_, 0, r0_counter_, absl::MakeSpan(r0, numel)); return; } case GenPrssCtrl::Second: { - prss_counter_ = yacl::crypto::FillPRand( - kAesType, next_seed_, 0, prss_counter_, absl::MakeSpan(r1, numel)); + r1_counter_ = yacl::crypto::FillPRand( + kAesType, next_seed_, 0, r1_counter_, absl::MakeSpan(r1, numel)); return; } case GenPrssCtrl::Both: { - auto counter0 = yacl::crypto::FillPRand( - kAesType, self_seed_, 0, prss_counter_, absl::MakeSpan(r0, numel)); - auto counter1 = yacl::crypto::FillPRand( - kAesType, next_seed_, 0, prss_counter_, absl::MakeSpan(r1, numel)); - SPU_ENFORCE(counter0 == counter1); - prss_counter_ = counter0; + r0_counter_ = yacl::crypto::FillPRand( + kAesType, self_seed_, 0, r0_counter_, absl::MakeSpan(r0, numel)); + r1_counter_ = yacl::crypto::FillPRand( + kAesType, next_seed_, 0, r1_counter_, absl::MakeSpan(r1, numel)); return; } } diff --git a/libspu/mpc/common/prg_tensor.h b/libspu/mpc/common/prg_tensor.h index ced75981..d3b3e704 100644 --- a/libspu/mpc/common/prg_tensor.h +++ b/libspu/mpc/common/prg_tensor.h @@ -15,6 +15,7 @@ #pragma once #include "libspu/core/ndarray_ref.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc { @@ -22,24 +23,46 @@ namespace spu::mpc { using PrgSeed = uint128_t; using PrgCounter = uint64_t; +// Gfmp is regarded as word +// standing for Galois Field with Mersenne Prime. +enum class ElementType { kRing, kGfmp }; + struct PrgArrayDesc { Shape shape; FieldType field; PrgCounter prg_counter; + ElementType eltype; }; inline NdArrayRef prgCreateArray(FieldType field, const Shape& shape, PrgSeed seed, PrgCounter* counter, - PrgArrayDesc* desc) { + PrgArrayDesc* desc, + ElementType eltype = ElementType::kRing) { if (desc != nullptr) { - *desc = {Shape(shape.begin(), shape.end()), field, *counter}; + *desc = {Shape(shape.begin(), shape.end()), field, *counter, eltype}; + } + if (eltype == ElementType::kGfmp) { + return gfmp_rand(field, shape, seed, counter); + } else { + return ring_rand(field, shape, seed, counter); } - return ring_rand(field, shape, seed, counter); } inline NdArrayRef prgReplayArray(PrgSeed seed, const PrgArrayDesc& desc) { PrgCounter counter = desc.prg_counter; - return ring_rand(desc.field, desc.shape, seed, &counter); + if (desc.eltype == ElementType::kGfmp) { + return gfmp_rand(desc.field, desc.shape, seed, &counter); + } else { + return ring_rand(desc.field, desc.shape, seed, &counter); + } +} + +inline NdArrayRef prgReplayArrayMutable(PrgSeed seed, PrgArrayDesc& desc) { + if (desc.eltype == ElementType::kGfmp) { + return gfmp_rand(desc.field, desc.shape, seed, &desc.prg_counter); + } else { + return ring_rand(desc.field, desc.shape, seed, &desc.prg_counter); + } } } // namespace spu::mpc diff --git a/libspu/mpc/common/pv2k.cc b/libspu/mpc/common/pv2k.cc index 6c777779..ea38bf9d 100644 --- a/libspu/mpc/common/pv2k.cc +++ b/libspu/mpc/common/pv2k.cc @@ -37,7 +37,7 @@ inline int64_t getOwner(const NdArrayRef& x) { class P2V : public RevealToKernel { public: - static constexpr char kBindName[] = "p2v"; + static constexpr const char* kBindName() { return "p2v"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -58,7 +58,7 @@ class P2V : public RevealToKernel { class V2P : public UnaryKernel { public: - static constexpr char kBindName[] = "v2p"; + static constexpr const char* kBindName() { return "v2p"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -73,7 +73,7 @@ class V2P : public UnaryKernel { auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "v2p", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector priv(numel); NdArrayView _in(in); @@ -90,7 +90,7 @@ class V2P : public UnaryKernel { class MakeP : public Kernel { public: - static constexpr char kBindName[] = "make_p"; + static constexpr const char* kBindName() { return "make_p"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -113,7 +113,7 @@ class MakeP : public Kernel { Strides(shape.size(), 0), // strides 0); - DISPATCH_ALL_FIELDS(field, "pub2k.make_p", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { arr.at(Index(shape.size(), 0)) = static_cast(init); }); return Value(arr, DT_INVALID); @@ -122,7 +122,7 @@ class MakeP : public Kernel { class RandP : public RandKernel { public: - static constexpr char kBindName[] = "rand_p"; + static constexpr const char* kBindName() { return "rand_p"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -136,9 +136,9 @@ class RandP : public RandKernel { } }; -class NotP : public UnaryKernel { +class NegateP : public UnaryKernel { public: - static constexpr char kBindName[] = "not_p"; + static constexpr const char* kBindName() { return "negate_p"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -146,13 +146,13 @@ class NotP : public UnaryKernel { NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in) const override { const auto field = in.eltype().as()->field(); - return ring_not(in).as(makeType(field)); + return ring_neg(in).as(makeType(field)); } }; -class NotV : public UnaryKernel { +class NegateV : public UnaryKernel { public: - static constexpr char kBindName[] = "not_v"; + static constexpr const char* kBindName() { return "negate_v"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -160,7 +160,7 @@ class NotV : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { if (isOwner(ctx, in.eltype())) { - return ring_not(in).as(in.eltype()); + return ring_neg(in).as(in.eltype()); } else { return in; } @@ -169,20 +169,21 @@ class NotV : public UnaryKernel { class MsbP : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_p"; + static constexpr const char* kBindName() { return "msb_p"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in) const override { - return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); + return ring_rshift(in, {static_cast(in.elsize() * 8 - 1)}) + .as(in.eltype()); } }; class MsbV : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_v"; + static constexpr const char* kBindName() { return "msb_v"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -190,7 +191,8 @@ class MsbV : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { if (isOwner(ctx, in.eltype())) { - return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); + return ring_rshift(in, {static_cast(in.elsize() * 8 - 1)}) + .as(in.eltype()); } else { return in; } @@ -199,7 +201,7 @@ class MsbV : public UnaryKernel { class EqualPP : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_pp"; + static constexpr const char* kBindName() { return "equal_pp"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -215,7 +217,7 @@ class EqualPP : public BinaryKernel { class EqualVVV : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_vvv"; + static constexpr const char* kBindName() { return "equal_vvv"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -234,15 +236,13 @@ class EqualVVV : public BinaryKernel { class EqualVP : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_vp"; + static constexpr const char* kBindName() { return "equal_vp"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const override { - SPU_ENFORCE_EQ(x.eltype(), y.eltype()); - if (isOwner(ctx, x.eltype())) { return ring_equal(x, y).as(x.eltype()); } else { @@ -253,7 +253,7 @@ class EqualVP : public BinaryKernel { class AddPP : public BinaryKernel { public: - static constexpr char kBindName[] = "add_pp"; + static constexpr const char* kBindName() { return "add_pp"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -267,7 +267,7 @@ class AddPP : public BinaryKernel { class AddVVV : public BinaryKernel { public: - static constexpr char kBindName[] = "add_vvv"; + static constexpr const char* kBindName() { return "add_vvv"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -286,7 +286,7 @@ class AddVVV : public BinaryKernel { class AddVP : public BinaryKernel { public: - static constexpr char kBindName[] = "add_vp"; + static constexpr const char* kBindName() { return "add_vp"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -303,7 +303,7 @@ class AddVP : public BinaryKernel { class MulPP : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_pp"; + static constexpr const char* kBindName() { return "mul_pp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -318,7 +318,7 @@ class MulPP : public BinaryKernel { class MulVP : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_vp"; + static constexpr const char* kBindName() { return "mul_vp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -336,7 +336,7 @@ class MulVP : public BinaryKernel { class MulVVV : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_vvv"; + static constexpr const char* kBindName() { return "mul_vvv"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -354,7 +354,7 @@ class MulVVV : public BinaryKernel { class MatMulPP : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_pp"; + static constexpr const char* kBindName() { return "mmul_pp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -369,7 +369,7 @@ class MatMulPP : public MatmulKernel { class MatMulVVV : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_vvv"; + static constexpr const char* kBindName() { return "mmul_vvv"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -386,7 +386,7 @@ class MatMulVVV : public MatmulKernel { class MatMulVP : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_vp"; + static constexpr const char* kBindName() { return "mmul_vp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -402,7 +402,7 @@ class MatMulVP : public MatmulKernel { class AndPP : public BinaryKernel { public: - static constexpr char kBindName[] = "and_pp"; + static constexpr const char* kBindName() { return "and_pp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -417,7 +417,7 @@ class AndPP : public BinaryKernel { class AndVVV : public BinaryKernel { public: - static constexpr char kBindName[] = "and_vvv"; + static constexpr const char* kBindName() { return "and_vvv"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -436,7 +436,7 @@ class AndVVV : public BinaryKernel { class AndVP : public BinaryKernel { public: - static constexpr char kBindName[] = "and_vp"; + static constexpr const char* kBindName() { return "and_vp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -454,7 +454,7 @@ class AndVP : public BinaryKernel { class XorPP : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_pp"; + static constexpr const char* kBindName() { return "xor_pp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -469,7 +469,7 @@ class XorPP : public BinaryKernel { class XorVVV : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_vvv"; + static constexpr const char* kBindName() { return "xor_vvv"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -487,7 +487,7 @@ class XorVVV : public BinaryKernel { class XorVP : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_vp"; + static constexpr const char* kBindName() { return "xor_vp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -505,28 +505,28 @@ class XorVP : public BinaryKernel { class LShiftP : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_p"; + static constexpr const char* kBindName() { return "lshift_p"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_lshift(in, bits).as(in.eltype()); } }; class LShiftV : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_v"; + static constexpr const char* kBindName() { return "lshift_v"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { return ring_lshift(in, bits).as(in.eltype()); } else { @@ -537,28 +537,28 @@ class LShiftV : public ShiftKernel { class RShiftP : public ShiftKernel { public: - static constexpr char kBindName[] = "rshift_p"; + static constexpr const char* kBindName() { return "rshift_p"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_rshift(in, bits).as(in.eltype()); } }; class RShiftV : public ShiftKernel { public: - static constexpr char kBindName[] = "rshift_v"; + static constexpr const char* kBindName() { return "rshift_v"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { return ring_rshift(in, bits).as(in.eltype()); } else { @@ -569,28 +569,28 @@ class RShiftV : public ShiftKernel { class ARShiftP : public ShiftKernel { public: - static constexpr char kBindName[] = "arshift_p"; + static constexpr const char* kBindName() { return "arshift_p"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_arshift(in, bits).as(in.eltype()); } }; class ARShiftV : public ShiftKernel { public: - static constexpr char kBindName[] = "arshift_v"; + static constexpr const char* kBindName() { return "arshift_v"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { return ring_arshift(in, bits).as(in.eltype()); } else { @@ -607,8 +607,8 @@ NdArrayRef rounded_arshift(const NdArrayRef& in, size_t bits) { // https://stackoverflow.com/questions/14008330/how-do-you-multiply-two-fixed-point-numbers // Under certain pattern, like sum(mul(A, B)), error can accumulate in a // fairly significant way - auto v1 = ring_arshift(in, bits); - auto v2 = ring_arshift(in, bits - 1); + auto v1 = ring_arshift(in, {static_cast(bits)}); + auto v2 = ring_arshift(in, {static_cast(bits - 1)}); ring_and_(v2, ring_ones(in.eltype().as()->field(), in.shape())); ring_add_(v1, v2); return v1; @@ -616,30 +616,32 @@ NdArrayRef rounded_arshift(const NdArrayRef& in, size_t bits) { class TruncP : public ShiftKernel { public: - static constexpr char kBindName[] = "trunc_p"; + static constexpr const char* kBindName() { return "trunc_p"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { - return rounded_arshift(in, bits).as(in.eltype()); + const Sizes& bits) const override { + SPU_ENFORCE(bits.size() == 1, "truncation bits should be splat"); + return rounded_arshift(in, bits[0]).as(in.eltype()); } }; class TruncV : public ShiftKernel { public: - static constexpr char kBindName[] = "trunc_v"; + static constexpr const char* kBindName() { return "trunc_v"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { - return rounded_arshift(in, bits).as(in.eltype()); + SPU_ENFORCE(bits.size() == 1, "truncation bits should be splat"); + return rounded_arshift(in, bits[0]).as(in.eltype()); } else { return in; } @@ -648,7 +650,7 @@ class TruncV : public ShiftKernel { class BitrevP : public BitrevKernel { public: - static constexpr char kBindName[] = "bitrev_p"; + static constexpr const char* kBindName() { return "bitrev_p"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -666,7 +668,7 @@ class BitrevP : public BitrevKernel { class BitrevV : public BitrevKernel { public: - static constexpr char kBindName[] = "bitrev_v"; + static constexpr const char* kBindName() { return "bitrev_v"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -688,7 +690,7 @@ class BitrevV : public BitrevKernel { class GenInvPermP : public GenInvPermKernel { public: - static constexpr char kBindName[] = "gen_inv_perm_p"; + static constexpr const char* kBindName() { return "gen_inv_perm_p"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -701,7 +703,7 @@ class GenInvPermP : public GenInvPermKernel { auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "gen_inv_perm_p", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; std::vector perm(numel); std::iota(perm.begin(), perm.end(), 0); @@ -722,7 +724,7 @@ class GenInvPermP : public GenInvPermKernel { class GenInvPermV : public GenInvPermKernel { public: - static constexpr char kBindName[] = "gen_inv_perm_v"; + static constexpr const char* kBindName() { return "gen_inv_perm_v"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -735,7 +737,7 @@ class GenInvPermV : public GenInvPermKernel { auto numel = in.numel(); const auto field = in.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "gen_inv_perm_v", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; std::vector perm(numel); std::iota(perm.begin(), perm.end(), 0); @@ -759,7 +761,7 @@ class GenInvPermV : public GenInvPermKernel { class InvPermPP : public PermKernel { public: - static constexpr char kBindName[] = "inv_perm_pp"; + static constexpr const char* kBindName() { return "inv_perm_pp"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -769,7 +771,7 @@ class InvPermPP : public PermKernel { SPU_ENFORCE_EQ(x.eltype(), y.eltype()); NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -784,7 +786,7 @@ class InvPermPP : public PermKernel { class InvPermVV : public PermKernel { public: - static constexpr char kBindName[] = "inv_perm_vv"; + static constexpr const char* kBindName() { return "inv_perm_vv"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -795,7 +797,7 @@ class InvPermVV : public PermKernel { if (isOwner(ctx, x.eltype())) { NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -813,7 +815,7 @@ class InvPermVV : public PermKernel { class PermPP : public PermKernel { public: - static constexpr char kBindName[] = "perm_pp"; + static constexpr const char* kBindName() { return "perm_pp"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -823,7 +825,7 @@ class PermPP : public PermKernel { SPU_ENFORCE_EQ(x.eltype(), y.eltype()); NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -838,7 +840,7 @@ class PermPP : public PermKernel { class PermVV : public PermKernel { public: - static constexpr char kBindName[] = "perm_vv"; + static constexpr const char* kBindName() { return "perm_vv"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -849,7 +851,7 @@ class PermVV : public PermKernel { if (isOwner(ctx, x.eltype())) { NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -867,7 +869,7 @@ class PermVV : public PermKernel { class MergeKeysP : public MergeKeysKernel { public: - static constexpr char kBindName[] = "merge_keys_p"; + static constexpr const char* kBindName() { return "merge_keys_p"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -878,7 +880,7 @@ class MergeKeysP : public MergeKeysKernel { NdArrayRef out(inputs[0].eltype(), inputs[0].shape()); const auto field = inputs[0].eltype().as()->field(); const auto numel = inputs[0].numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _out(out); _out[0] = 0; @@ -899,7 +901,7 @@ class MergeKeysP : public MergeKeysKernel { class MergeKeysV : public MergeKeysKernel { public: - static constexpr char kBindName[] = "merge_keys_v"; + static constexpr const char* kBindName() { return "merge_keys_v"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } @@ -917,7 +919,7 @@ class MergeKeysV : public MergeKeysKernel { NdArrayRef out(inputs[0].eltype(), inputs[0].shape()); const auto field = inputs[0].eltype().as()->field(); const auto numel = inputs[0].numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _out(out); _out[0] = 0; @@ -952,7 +954,7 @@ void regPV2kTypes() { void regPV2kKernels(Object* obj) { obj->regKernelgetParam(0); - size_t bits = ctx->getParam(1); + const auto& bits = ctx->getParam(1); + + SPU_ENFORCE( + bits.size() == 1 || in.numel() == static_cast(bits.size()), + "numel mismatch {} {}", in.numel(), bits.size()); auto res = proc(ctx, UnwrapValue(in), bits); diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h index 12383391..d75b3426 100644 --- a/libspu/mpc/kernel.h +++ b/libspu/mpc/kernel.h @@ -42,7 +42,7 @@ class ShiftKernel : public Kernel { public: void evaluate(KernelEvalContext* ctx) const override; virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const = 0; + const Sizes& bits) const = 0; }; class BinaryKernel : public Kernel { diff --git a/libspu/mpc/ref2k/ref2k.cc b/libspu/mpc/ref2k/ref2k.cc index 1d76ea95..28b8c176 100644 --- a/libspu/mpc/ref2k/ref2k.cc +++ b/libspu/mpc/ref2k/ref2k.cc @@ -47,7 +47,7 @@ void registerTypes() { class Ref2kCommonTypeS : public Kernel { public: - static constexpr char kBindName[] = "common_type_s"; + static constexpr const char* kBindName() { return "common_type_s"; } Kind kind() const override { return Kind::Dynamic; } @@ -64,7 +64,7 @@ class Ref2kCommonTypeS : public Kernel { class Ref2kCommonTypeV : public Kernel { public: - static constexpr char kBindName[] = "common_type_v"; + static constexpr const char* kBindName() { return "common_type_v"; } Kind kind() const override { return Kind::Dynamic; } @@ -86,7 +86,7 @@ class Ref2kCommonTypeV : public Kernel { class Ref2kCastTypeS : public CastTypeKernel { public: - static constexpr char kBindName[] = "cast_type_s"; + static constexpr const char* kBindName() { return "cast_type_s"; } Kind kind() const override { return Kind::Dynamic; } @@ -102,7 +102,7 @@ class Ref2kCastTypeS : public CastTypeKernel { class Ref2kP2S : public UnaryKernel { public: - static constexpr char kBindName[] = "p2s"; + static constexpr const char* kBindName() { return "p2s"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -115,7 +115,7 @@ class Ref2kP2S : public UnaryKernel { class Ref2kS2P : public UnaryKernel { public: - static constexpr char kBindName[] = "s2p"; + static constexpr const char* kBindName() { return "s2p"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -128,7 +128,7 @@ class Ref2kS2P : public UnaryKernel { class Ref2kS2V : public RevealToKernel { public: - static constexpr char kBindName[] = "s2v"; + static constexpr const char* kBindName() { return "s2v"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -149,7 +149,8 @@ class Ref2kS2V : public RevealToKernel { class Ref2kV2S : public UnaryKernel { public: - static constexpr char kBindName[] = "v2s"; + static constexpr const char* kBindName() { return "v2s"; } + Kind kind() const override { return Kind::Dynamic; } ce::CExpr latency() const override { return ce::Const(0); } @@ -165,7 +166,7 @@ class Ref2kV2S : public UnaryKernel { int64_t numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "v2s", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector _send(numel); NdArrayView _in(in); @@ -185,7 +186,7 @@ class Ref2kV2S : public UnaryKernel { class Ref2kRandS : public RandKernel { public: - static constexpr char kBindName[] = "rand_s"; + static constexpr const char* kBindName() { return "rand_s"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -196,13 +197,13 @@ class Ref2kRandS : public RandKernel { const auto field = ctx->getState()->getDefaultField(); return ring_rshift( - state->genPubl(field, shape).as(makeType(field)), 2); + state->genPubl(field, shape).as(makeType(field)), {2}); } }; -class Ref2kNotS : public UnaryKernel { +class Ref2kNegateS : public UnaryKernel { public: - static constexpr char kBindName[] = "not_s"; + static constexpr const char* kBindName() { return "negate_s"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -210,13 +211,13 @@ class Ref2kNotS : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { const auto field = in.eltype().as()->field(); - return ring_not(in).as(makeType(field)); + return ring_neg(in).as(makeType(field)); } }; class Ref2kAddSS : public BinaryKernel { public: - static constexpr char kBindName[] = "add_ss"; + static constexpr const char* kBindName() { return "add_ss"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -231,7 +232,7 @@ class Ref2kAddSS : public BinaryKernel { class Ref2kAddSP : public BinaryKernel { public: - static constexpr char kBindName[] = "add_sp"; + static constexpr const char* kBindName() { return "add_sp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -245,7 +246,7 @@ class Ref2kAddSP : public BinaryKernel { class Ref2kMulSS : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_ss"; + static constexpr const char* kBindName() { return "mul_ss"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -260,7 +261,7 @@ class Ref2kMulSS : public BinaryKernel { class Ref2kMulSP : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_sp"; + static constexpr const char* kBindName() { return "mul_sp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -274,7 +275,7 @@ class Ref2kMulSP : public BinaryKernel { class Ref2kMatMulSS : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_ss"; + static constexpr const char* kBindName() { return "mmul_ss"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -289,7 +290,7 @@ class Ref2kMatMulSS : public MatmulKernel { class Ref2kMatMulSP : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_sp"; + static constexpr const char* kBindName() { return "mmul_sp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -303,7 +304,7 @@ class Ref2kMatMulSP : public MatmulKernel { class Ref2kAndSS : public BinaryKernel { public: - static constexpr char kBindName[] = "and_ss"; + static constexpr const char* kBindName() { return "and_ss"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -318,7 +319,7 @@ class Ref2kAndSS : public BinaryKernel { class Ref2kAndSP : public BinaryKernel { public: - static constexpr char kBindName[] = "and_sp"; + static constexpr const char* kBindName() { return "and_sp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -332,7 +333,7 @@ class Ref2kAndSP : public BinaryKernel { class Ref2kXorSS : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_ss"; + static constexpr const char* kBindName() { return "xor_ss"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -347,7 +348,7 @@ class Ref2kXorSS : public BinaryKernel { class Ref2kXorSP : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_sp"; + static constexpr const char* kBindName() { return "xor_sp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -361,35 +362,35 @@ class Ref2kXorSP : public BinaryKernel { class Ref2kLShiftS : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_s"; + static constexpr const char* kBindName() { return "lshift_s"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_lshift(in, bits).as(in.eltype()); } }; class Ref2kRShiftS : public ShiftKernel { public: - static constexpr char kBindName[] = "rshift_s"; + static constexpr const char* kBindName() { return "rshift_s"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_rshift(in, bits).as(in.eltype()); } }; class Ref2kBitrevS : public BitrevKernel { public: - static constexpr char kBindName[] = "bitrev_s"; + static constexpr const char* kBindName() { return "bitrev_s"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -407,21 +408,21 @@ class Ref2kBitrevS : public BitrevKernel { class Ref2kARShiftS : public ShiftKernel { public: - static constexpr char kBindName[] = "arshift_s"; + static constexpr const char* kBindName() { return "arshift_s"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_arshift(in, bits).as(in.eltype()); } }; class Ref2kTruncS : public TruncAKernel { public: - static constexpr char kBindName[] = "trunc_s"; + static constexpr const char* kBindName() { return "trunc_s"; } bool hasMsbError() const override { return false; } @@ -441,8 +442,8 @@ class Ref2kTruncS : public TruncAKernel { // https://stackoverflow.com/questions/14008330/how-do-you-multiply-two-fixed-point-numbers // Under certain pattern, like sum(mul(A, B)), error can accumulate in a // fairly significant way - auto v1 = ring_arshift(in, bits); - auto v2 = ring_arshift(in, bits - 1); + auto v1 = ring_arshift(in, {static_cast(bits)}); + auto v2 = ring_arshift(in, {static_cast(bits - 1)}); ring_and_(v2, ring_ones(in.eltype().as()->field(), in.shape())); ring_add_(v1, v2); return v1; @@ -451,14 +452,15 @@ class Ref2kTruncS : public TruncAKernel { class Ref2kMsbS : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_s"; + static constexpr const char* kBindName() { return "msb_s"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { - return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); + return ring_rshift(in, {static_cast(in.elsize() * 8 - 1)}) + .as(in.eltype()); } }; @@ -487,7 +489,7 @@ void regRef2kProtocol(SPUContext* ctx, ctx->prot() ->regKernel #include +#include "yacl/crypto/rand/rand.h" + #include "libspu/core/type_util.h" -#include "libspu/core/vectorize.h" -#include "libspu/mpc/ab_api.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" @@ -37,7 +37,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { std::vector share(numel); NdArrayView _in(in); pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); @@ -109,7 +109,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { // - https://eprint.iacr.org/2019/599.pdf // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison // operations. - return ring_rshift(prg_state->genPriv(field, shape), 2) + return ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); } @@ -130,34 +130,12 @@ NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); auto* comm = ctx->getState(); - auto out = comm->allReduce(ReduceOp::ADD, in, kBindName); + auto out = comm->allReduce(ReduceOp::ADD, in, kBindName()); return out.as(makeType(field)); } -NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - auto* comm = ctx->getState(); - - // First, let's show negate could be locally processed. - // let X = sum(Xi) % M - // let Yi = neg(Xi) = M-Xi - // - // we get - // Y = sum(Yi) % M - // = n*M - sum(Xi) % M - // = -sum(Xi) % M - // = -X % M - // - // 'not' could be processed accordingly. - // not(X) - // = M-1-X # by definition, not is the complement of 2^k - // = neg(X) + M-1 - // +NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto res = ring_neg(in); - if (comm->getRank() == 0) { - const auto field = in.eltype().as()->field(); - ring_add_(res, ring_not(ring_zeros(field, in.shape()))); - } - return res.as(in.eltype()); } @@ -200,10 +178,7 @@ NdArrayRef MatMulAP::proc(KernelEvalContext* ctx, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { - const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; - + const Sizes& bits) const { return ring_lshift(in, bits).as(in.eltype()); } @@ -216,10 +191,10 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto rank = comm->getRank(); const auto numel = in.numel(); const auto field = in.eltype().as()->field(); - const size_t k = SizeOf(field) * 8; + const int k = SizeOf(field) * 8; NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "securenn.truncpr", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; auto r = prg_state->genPriv(field, in.shape()); @@ -232,9 +207,11 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto rb_recon = comm->reduce(ReduceOp::ADD, rb, 2, "rb"); if (rank == 2) { - auto adjust1 = - ring_sub(ring_rshift(ring_lshift(r_recon, 1), bits + 1), rc_recon); - auto adjust2 = ring_sub(ring_rshift(r_recon, k - 1), rb_recon); + auto adjust1 = ring_sub(ring_rshift(ring_lshift(r_recon, {1}), + {static_cast(bits + 1)}), + rc_recon); + auto adjust2 = ring_sub( + ring_rshift(r_recon, {static_cast(k - 1)}), rb_recon); comm->sendAsync(0, adjust1, "adjust1"); comm->sendAsync(0, adjust2, "adjust2"); } @@ -272,7 +249,7 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, x_plus_r[idx] = x + _r[idx]; }); // open + = c - c = comm->allReduce(x_plus_r, kBindName); + c = comm->allReduce(x_plus_r, kBindName()); } pforeach(0, numel, [&](int64_t idx) { @@ -362,7 +339,6 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, b = prg_state ->genPrssPair(field, x.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, x.shape(), PrgState::GenPrssCtrl::None); c = comm->recv(2, ty, "c"); c = c.reshape(x.shape()); } @@ -527,7 +503,6 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, .second; b = prg_state->genPrssPair(field, shape2, PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, shape3, PrgState::GenPrssCtrl::None); c = comm->recv(2, ty, "c"); c = c.reshape(shape3); @@ -594,7 +569,7 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, const auto kComm = a.elsize() * size; comm->addCommStatsManually(4, 4 * log_p * kComm + 6 * kComm); - DISPATCH_ALL_FIELDS(field, "securenn.sc", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; const U L_1 = (U)(~0); // 2^k - 1 // P0 and P1 add the share of zero @@ -785,10 +760,6 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, } // P0 and P1 end execute if (rank == 2) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, L_1 - 1); - auto a_0 = comm->recv(0, ty, "a_"); auto a_1 = comm->recv(1, ty, "a_"); a_0 = a_0.reshape(a.shape()); @@ -811,7 +782,8 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, NdArrayView _dp_x_p0(dp_x_p0); NdArrayView _dp_x_p1(dp_x_p1); - NdArrayRef delta_p0(ty, a.shape()); + NdArrayRef delta_p0 = + ring_rand_range(field, a.shape(), 0, L_1 - 1); // (ty, a.shape()); NdArrayRef delta_p1(ty, a.shape()); NdArrayView _delta_p0(delta_p0); NdArrayView _delta_p1(delta_p1); @@ -830,7 +802,6 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, } // split delta in Z_(L_1) - _delta_p0[idx] = dis(gen); _delta_p1[idx] = _delta[idx] - _delta_p0[idx]; if (_delta[idx] < _delta_p0[idx]) _delta_p1[idx] -= (U)1; // when overflow @@ -843,7 +814,7 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, comm->sendAsync(1, delta_p1, "delta"); // split eta_ in Z_(L_1) - NdArrayRef eta_p0(ty, a.shape()); + NdArrayRef eta_p0 = ring_rand_range(field, a.shape(), 0, L_1 - 1); NdArrayRef eta_p1(ty, a.shape()); NdArrayView _eta_p0(eta_p0); NdArrayView _eta_p1(eta_p1); @@ -870,7 +841,6 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, } // split eta_ in Z_(L_1) - _eta_p0[idx] = dis(gen); _eta_p1[idx] = _eta_[idx] - _eta_p0[idx]; if (_eta_[idx] < _eta_p0[idx]) _eta_p1[idx] -= (U)1; // when overflow }); // end pforeach @@ -902,7 +872,7 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto kComm = in.elsize() * size; comm->addCommStatsManually(5, 13 * kComm + 4 * kComm * log_p); - DISPATCH_ALL_FIELDS(field, "securenn.msb", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; const U L_1 = (U)(~0); @@ -916,10 +886,6 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto [u_r0, u_r1] = prg_state->genPrssPair(field, {size * k}, PrgState::GenPrssCtrl::Both); if (rank == 2) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, L_1 - 1); - // random for beaver // P2 generate a0, a1, b0, b1, c0 by PRF // and calculate c1 @@ -935,12 +901,12 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto c1 = ring_sub(ring_mul(ring_add(a0, a1), ring_add(b0, b1)), c0); // end beaver (c1 will be sent with x to reduce one round latency) - NdArrayRef x(ty, in.shape()); + NdArrayRef x = ring_rand_range(field, in.shape(), 0, L_1 - 1); NdArrayView _x(x); // split x into x_p0 and x_p1 in Z_(L-1), (L=2^k) - NdArrayRef x_p0(ty, in.shape()); + NdArrayRef x_p0 = ring_rand_range(field, in.shape(), 0, L_1 - 1); NdArrayRef x_p1(ty, in.shape()); NdArrayView _x_p0(x_p0); NdArrayView _x_p1(x_p1); @@ -959,11 +925,9 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef lsb_x(ty, in.shape()); NdArrayView _lsb_x(lsb_x); pforeach(0, size, [&](int64_t idx) { - _x[idx] = dis(gen); auto dp_x = bitDecompose(_x[idx], k); // vector // split x - _x_p0[idx] = dis(gen); _x_p1[idx] = _x[idx] - _x_p0[idx]; if (_x[idx] < _x_p0[idx]) _x_p1[idx] -= (U)1; // when overflow @@ -1050,7 +1014,6 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { prg_state ->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::None); beaver_c = comm->recv(2, ty, "beaver_c"); beaver_c = beaver_c.reshape(in.shape()); } @@ -1241,7 +1204,7 @@ NdArrayRef Msb_opt::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto kComm = in.elsize() * size; comm->addCommStatsManually(5, 9 * kComm + 3 * kComm * log_p); - DISPATCH_ALL_FIELDS(field, "securenn.msb", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; const U L_1 = (U)(~0); @@ -1265,10 +1228,6 @@ NdArrayRef Msb_opt::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto [beta_0, beta_1] = prg_state->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::Both); if (rank == 2) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, L_1 - 1); - // random for beaver // P2 generate a0, a1, b0, b1, c0 by PRF // and calculate c1 @@ -1395,7 +1354,6 @@ NdArrayRef Msb_opt::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { prg_state ->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::None); beaver_c = comm->recv(2, ty, "beaver_c"); beaver_c = beaver_c.reshape(in.shape()); } diff --git a/libspu/mpc/securenn/arithmetic.h b/libspu/mpc/securenn/arithmetic.h index 599bf14a..947a5917 100644 --- a/libspu/mpc/securenn/arithmetic.h +++ b/libspu/mpc/securenn/arithmetic.h @@ -20,7 +20,7 @@ namespace spu::mpc::securenn { class A2V : public RevealToKernel { public: - static constexpr char kBindName[] = "a2v"; + static constexpr const char* kBindName() { return "a2v"; } // TODO: communication is unbalanced Kind kind() const override { return Kind::Dynamic; } @@ -35,7 +35,7 @@ class A2V : public RevealToKernel { class V2A : public UnaryKernel { public: - static constexpr char kBindName[] = "v2a"; + static constexpr const char* kBindName() { return "v2a"; } // TODO: communication is unbalanced Kind kind() const override { return Kind::Dynamic; } @@ -49,7 +49,7 @@ class V2A : public UnaryKernel { class RandA : public RandKernel { public: - static constexpr char kBindName[] = "rand_a"; + static constexpr const char* kBindName() { return "rand_a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -60,7 +60,7 @@ class RandA : public RandKernel { class P2A : public UnaryKernel { public: - static constexpr char kBindName[] = "p2a"; + static constexpr const char* kBindName() { return "p2a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -71,7 +71,7 @@ class P2A : public UnaryKernel { class A2P : public UnaryKernel { public: - static constexpr char kBindName[] = "a2p"; + static constexpr const char* kBindName() { return "a2p"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -80,9 +80,9 @@ class A2P : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; -class NotA : public UnaryKernel { +class NegateA : public UnaryKernel { public: - static constexpr char kBindName[] = "not_a"; + static constexpr const char* kBindName() { return "negate_a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -96,7 +96,7 @@ class NotA : public UnaryKernel { //////////////////////////////////////////////////////////////////// class AddAP : public BinaryKernel { public: - static constexpr char kBindName[] = "add_ap"; + static constexpr const char* kBindName() { return "add_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -108,7 +108,7 @@ class AddAP : public BinaryKernel { class AddAA : public BinaryKernel { public: - static constexpr char kBindName[] = "add_aa"; + static constexpr const char* kBindName() { return "add_aa"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -123,7 +123,7 @@ class AddAA : public BinaryKernel { //////////////////////////////////////////////////////////////////// class MulAP : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_ap"; + static constexpr const char* kBindName() { return "mul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -135,7 +135,7 @@ class MulAP : public BinaryKernel { class MulAA : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_aa"; + static constexpr const char* kBindName() { return "mul_aa"; } ce::CExpr latency() const override { // online @@ -153,7 +153,7 @@ class MulAA : public BinaryKernel { //////////////////////////////////////////////////////////////////// class MatMulAP : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_ap"; + static constexpr const char* kBindName() { return "mmul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -165,14 +165,14 @@ class MatMulAP : public MatmulKernel { class LShiftA : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_a"; + static constexpr const char* kBindName() { return "lshift_a"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; // Refer to: @@ -181,7 +181,7 @@ class LShiftA : public ShiftKernel { // https://eprint.iacr.org/2020/338.pdf class TruncAPr : public TruncAKernel { public: - static constexpr char kBindName[] = "trunc_a"; + static constexpr const char* kBindName() { return "trunc_a"; } Kind kind() const override { return Kind::Static; } // offline + online @@ -201,8 +201,7 @@ class TruncAPr : public TruncAKernel { class MatMulAA : public MatmulKernel { public: - // static constexpr char kBindName[] = "mmul_aa_2pc"; - static constexpr char kBindName[] = "mmul_aa"; + static constexpr const char* kBindName() { return "mmul_aa"; } ce::CExpr latency() const override { // beaver + online @@ -223,7 +222,7 @@ class MatMulAA : public MatmulKernel { class MatMulAA_simple : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_aa_simple"; + static constexpr const char* kBindName() { return "mmul_aa_simple"; } ce::CExpr latency() const override { // beaver + online @@ -244,7 +243,7 @@ class MatMulAA_simple : public MatmulKernel { class Msb : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_a2a"; + static constexpr const char* kBindName() { return "msb_a2a"; } ce::CExpr latency() const override { return ce::Const(5); } ce::CExpr comm() const override { @@ -258,7 +257,7 @@ class Msb : public UnaryKernel { class Msb_opt : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_opt_a2a"; + static constexpr const char* kBindName() { return "msb_opt_a2a"; } ce::CExpr latency() const override { return ce::Const(5); } ce::CExpr comm() const override { @@ -272,7 +271,7 @@ class Msb_opt : public UnaryKernel { class ShareConvert : public UnaryKernel { public: - static constexpr char kBindName[] = "sc"; + static constexpr const char* kBindName() { return "sc"; } ce::CExpr latency() const override { return ce::Const(4); } ce::CExpr comm() const override { const auto log_p = 9; diff --git a/libspu/mpc/securenn/boolean.cc b/libspu/mpc/securenn/boolean.cc index 7d08673e..c122b9e1 100644 --- a/libspu/mpc/securenn/boolean.cc +++ b/libspu/mpc/securenn/boolean.cc @@ -31,7 +31,7 @@ namespace { size_t getNumBits(const NdArrayRef& in) { if (in.eltype().isa()) { const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", + return DISPATCH_ALL_FIELDS(field, [&]() { return maxBitWidth(in); }); } else if (in.eltype().isa()) { return in.eltype().as()->nbits(); @@ -91,7 +91,7 @@ NdArrayRef CastTypeB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); auto* comm = ctx->getState(); - auto out = comm->allReduce(ReduceOp::XOR, in, kBindName); + auto out = comm->allReduce(ReduceOp::XOR, in, kBindName()); return out.as(makeType(field)); } @@ -120,7 +120,7 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); NdArrayView _out(out); @@ -146,12 +146,12 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const int64_t numel = lhs.numel(); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; int64_t numBytes = numel * SizeOf(backtype); @@ -197,7 +197,6 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, b = prg_state ->genPrssPair(field, {numField}, PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, {numField}, PrgState::GenPrssCtrl::None); c = comm->recv(2, ty, "c"); c = c.reshape({numField}); } @@ -257,32 +256,32 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t out_nbits = in.eltype().as()->nbits() + shift; - out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); + int64_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); + out_nbits = std::clamp(out_nbits, 0L, + static_cast(SizeOf(field) * 8)); return makeBShare(ring_lshift(in, shift), field, out_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t nbits = in.eltype().as()->nbits(); - size_t out_nbits = nbits - std::min(nbits, shift); - SPU_ENFORCE(nbits <= SizeOf(field) * 8); + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); return makeBShare(ring_rshift(in, shift), field, out_nbits); } NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -310,7 +309,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); @@ -331,7 +330,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); diff --git a/libspu/mpc/securenn/boolean.h b/libspu/mpc/securenn/boolean.h index 35433c0b..a19f1092 100644 --- a/libspu/mpc/securenn/boolean.h +++ b/libspu/mpc/securenn/boolean.h @@ -20,7 +20,7 @@ namespace spu::mpc::securenn { class CommonTypeB : public Kernel { public: - static constexpr char kBindName[] = "common_type_b"; + static constexpr const char* kBindName() { return "common_type_b"; } Kind kind() const override { return Kind::Dynamic; } @@ -29,7 +29,7 @@ class CommonTypeB : public Kernel { class CastTypeB : public CastTypeKernel { public: - static constexpr char kBindName[] = "cast_type_b"; + static constexpr const char* kBindName() { return "cast_type_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -41,7 +41,7 @@ class CastTypeB : public CastTypeKernel { class B2P : public UnaryKernel { public: - static constexpr char kBindName[] = "b2p"; + static constexpr const char* kBindName() { return "b2p"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -52,7 +52,7 @@ class B2P : public UnaryKernel { class P2B : public UnaryKernel { public: - static constexpr char kBindName[] = "p2b"; + static constexpr const char* kBindName() { return "p2b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -63,7 +63,7 @@ class P2B : public UnaryKernel { class AndBP : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bp"; + static constexpr const char* kBindName() { return "and_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -75,7 +75,7 @@ class AndBP : public BinaryKernel { class AndBB : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bb"; + static constexpr const char* kBindName() { return "and_bb"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -87,7 +87,7 @@ class AndBB : public BinaryKernel { class XorBP : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bp"; + static constexpr const char* kBindName() { return "xor_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -99,7 +99,7 @@ class XorBP : public BinaryKernel { class XorBB : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bb"; + static constexpr const char* kBindName() { return "xor_bb"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -111,43 +111,43 @@ class XorBB : public BinaryKernel { class LShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_b"; + static constexpr const char* kBindName() { return "lshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class RShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "rshift_b"; + static constexpr const char* kBindName() { return "rshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class ARShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "arshift_b"; + static constexpr const char* kBindName() { return "arshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class BitrevB : public BitrevKernel { public: - static constexpr char kBindName[] = "bitrev_b"; + static constexpr const char* kBindName() { return "bitrev_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -159,7 +159,7 @@ class BitrevB : public BitrevKernel { class BitIntlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitintl_b"; + static constexpr const char* kBindName() { return "bitintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -171,7 +171,7 @@ class BitIntlB : public BitSplitKernel { class BitDeintlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitdeintl_b"; + static constexpr const char* kBindName() { return "bitdeintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } diff --git a/libspu/mpc/securenn/conversion.cc b/libspu/mpc/securenn/conversion.cc index 1fe99fe4..cc78358c 100644 --- a/libspu/mpc/securenn/conversion.cc +++ b/libspu/mpc/securenn/conversion.cc @@ -73,8 +73,8 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { auto r_b = wrap_a2b(ctx->sctx(), r_a); // evaluate adder circuit on x & r, and reveal x+r - auto x_plus_r = comm->allReduce(ReduceOp::XOR, - wrap_add_bb(ctx->sctx(), x, r_b), kBindName); + auto x_plus_r = comm->allReduce( + ReduceOp::XOR, wrap_add_bb(ctx->sctx(), x, r_b), kBindName()); // compute -r + (x+r) ring_neg_(r_a); @@ -112,8 +112,12 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, auto randbits = prg_state->genPriv(field, {numel * static_cast(nbits)}); // reconstruct ranbits - if (rank == 0) comm->sendAsync(2, randbits, "randbits0"); - if (rank == 1) comm->sendAsync(2, randbits, "randbits1"); + if (rank == 0) { + comm->sendAsync(2, randbits, "randbits0"); + } + if (rank == 1) { + comm->sendAsync(2, randbits, "randbits1"); + } if (rank == 2) { auto randbits0 = comm->recv(0, makeType(field), "randbits0"); randbits0 = randbits0.reshape(randbits.shape()); @@ -132,7 +136,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, auto res = NdArrayRef(makeType(field), x.shape()); - DISPATCH_ALL_FIELDS(field, kBindName, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; NdArrayView _randbits(randbits); diff --git a/libspu/mpc/securenn/conversion.h b/libspu/mpc/securenn/conversion.h index 06544825..b769fd45 100644 --- a/libspu/mpc/securenn/conversion.h +++ b/libspu/mpc/securenn/conversion.h @@ -19,7 +19,7 @@ namespace spu::mpc::securenn { class A2B : public UnaryKernel { public: - static constexpr char kBindName[] = "a2b"; + static constexpr const char* kBindName() { return "a2b"; } ce::CExpr latency() const override { return (Log(ce::K()) + 1) // adder-circuit; @@ -39,7 +39,7 @@ class A2B : public UnaryKernel { class B2A : public UnaryKernel { public: - static constexpr char kBindName[] = "b2a"; + static constexpr const char* kBindName() { return "b2a"; } ce::CExpr latency() const override { return (Log(ce::K()) + 1) * Log(ce::N()) // A2B @@ -60,7 +60,7 @@ class B2A : public UnaryKernel { class B2A_Randbit : public UnaryKernel { public: - static constexpr char kBindName[] = "b2a"; + static constexpr const char* kBindName() { return "b2a"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -74,7 +74,7 @@ class B2A_Randbit : public UnaryKernel { class Msb_a2b : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_a2b"; + static constexpr const char* kBindName() { return "msb_a2b"; } ce::CExpr latency() const override { #ifndef OPT_SECURENN_MSB @@ -110,7 +110,7 @@ class Msb_a2b : public UnaryKernel { class CommonTypeV : public Kernel { public: - static constexpr char kBindName[] = "common_type_v"; + static constexpr const char* kBindName() { return "common_type_v"; } Kind kind() const override { return Kind::Dynamic; } diff --git a/libspu/mpc/securenn/protocol.cc b/libspu/mpc/securenn/protocol.cc index 0fe800b0..3140b06c 100644 --- a/libspu/mpc/securenn/protocol.cc +++ b/libspu/mpc/securenn/protocol.cc @@ -50,7 +50,7 @@ void regSecurennProtocol(SPUContext* ctx, ctx->prot() ->regKernel< securenn::P2A, securenn::A2P, securenn::A2V, securenn::V2A, // - securenn::NotA, // + securenn::NegateA, // securenn::AddAP, securenn::AddAA, // securenn::MulAP, securenn::MulAA, // securenn::MatMulAP, securenn::MatMulAA, securenn::MatMulAA_simple, // diff --git a/libspu/mpc/securenn/state.h b/libspu/mpc/securenn/state.h index 2e673c8b..905da8ea 100644 --- a/libspu/mpc/securenn/state.h +++ b/libspu/mpc/securenn/state.h @@ -24,7 +24,7 @@ class SecurennState : public State { public: SecurennState() = default; - static constexpr char kBindName[] = "SecurennState"; + static constexpr const char* kBindName() { return "SecurennState"; } std::unique_ptr fork() override { auto ret = std::unique_ptr(new SecurennState); diff --git a/libspu/mpc/semi2k/BUILD.bazel b/libspu/mpc/semi2k/BUILD.bazel index e845ee73..dfef1ec7 100644 --- a/libspu/mpc/semi2k/BUILD.bazel +++ b/libspu/mpc/semi2k/BUILD.bazel @@ -46,6 +46,34 @@ spu_cc_library( ], ) +spu_cc_library( + name = "prime_utils", + srcs = ["prime_utils.cc"], + hdrs = ["prime_utils.h"], + deps = [ + ":state", + ":type", + "//libspu/mpc:kernel", + "//libspu/mpc/common:communicator", + "//libspu/mpc/utils:gfmp", + "//libspu/mpc/utils:ring_ops", + ], +) + +spu_cc_library( + name = "exp", + srcs = ["exp.cc"], + hdrs = ["exp.h"], + deps = [ + ":prime_utils", + ":state", + ":type", + "//libspu/mpc:kernel", + "//libspu/mpc/utils:gfmp", + "//libspu/mpc/utils:ring_ops", + ], +) + spu_cc_library( name = "conversion", srcs = ["conversion.cc"], @@ -83,6 +111,7 @@ spu_cc_library( ":arithmetic", ":boolean", ":conversion", + ":exp", ":permute", ":state", "//libspu/mpc/common:prg_state", @@ -94,7 +123,10 @@ spu_cc_test( name = "protocol_test", srcs = ["protocol_test.cc"], deps = [ + ":exp", + ":prime_utils", ":protocol", + ":type", "//libspu/mpc:ab_api_test", "//libspu/mpc:api_test", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_server", diff --git a/libspu/mpc/semi2k/arithmetic.cc b/libspu/mpc/semi2k/arithmetic.cc index 90591da6..25c31934 100644 --- a/libspu/mpc/semi2k/arithmetic.cc +++ b/libspu/mpc/semi2k/arithmetic.cc @@ -38,7 +38,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { // - https://eprint.iacr.org/2019/599.pdf // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison // operations. - return ring_rshift(prg_state->genPriv(field, shape), 2) + return ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); } @@ -61,7 +61,7 @@ NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); auto* comm = ctx->getState(); - auto out = comm->allReduce(ReduceOp::ADD, in, kBindName); + auto out = comm->allReduce(ReduceOp::ADD, in, kBindName()); return out.as(makeType(field)); } @@ -73,7 +73,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { std::vector share(numel); NdArrayView _in(in); pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); @@ -116,30 +116,8 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { return x.as(makeType(field)); } -NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - auto* comm = ctx->getState(); - - // First, let's show negate could be locally processed. - // let X = sum(Xi) % M - // let Yi = neg(Xi) = M-Xi - // - // we get - // Y = sum(Yi) % M - // = n*M - sum(Xi) % M - // = -sum(Xi) % M - // = -X % M - // - // 'not' could be processed accordingly. - // not(X) - // = M-1-X # by definition, not is the complement of 2^k - // = neg(X) + M-1 - // +NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto res = ring_neg(in); - if (comm->getRank() == 0) { - const auto field = in.eltype().as()->field(); - ring_add_(res, ring_not(ring_zeros(field, in.shape()))); - } - return res.as(in.eltype()); } @@ -289,13 +267,17 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, auto [a, b, c, x_a, y_b] = MulOpen(ctx, x, y, false); // Zi = Ci + (X - A) * Bi + (Y - B) * Ai + <(X - A) * (Y - B)> - auto z = ring_add( - ring_add(ring_mul(std::move(b), x_a), ring_mul(std::move(a), y_b)), c); + ring_mul_(b, x_a); + ring_mul_(a, y_b); + ring_add_(b, a); + ring_add_(b, c); + if (comm->getRank() == 0) { // z += (X-A) * (Y-B); - ring_add_(z, ring_mul(std::move(x_a), y_b)); + ring_mul_(x_a, y_b); + ring_add_(b, x_a); } - return z.as(x.eltype()); + return b.as(x.eltype()); } NdArrayRef SquareA::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { @@ -340,6 +322,51 @@ NdArrayRef SquareA::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { return z.as(x.eltype()); } +// Let x be AShrTy, y be BShrTy, nbits(y) == 1 +// (x0+x1) * (y0^y1) = (x0+x1) * (y0+y1-2y0y1) +// we define xx0 = (1-2y0)x0, xx1 = (1-2y1)x1 +// yy0 = y0, yy1 = y1 +// if we can compute z0+z1 = xx0*yy1 + xx1*yy0 (which can be easily got from Mul +// Beaver), then (x0+x1) * (y0^y1) = (z0 + z1) + (x0y0 + x1y1) +NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + SPU_ENFORCE(x.eltype().as()->field() == + y.eltype().as()->field()); + + const auto field = x.eltype().as()->field(); + auto* comm = ctx->getState(); + + // IMPORTANT: the underlying value of y is not exactly 0 or 1, so we must mask + // it explicitly. + auto yy = ring_bitmask(y, 0, 1).as(makeType(field)); + // To optimize memory usage, re-use xx buffer + auto xx = ring_ones(field, x.shape()); + ring_sub_(xx, ring_lshift(yy, {1})); + ring_mul_(xx, x); + + auto [a, b, c, xx_a, yy_b] = MulOpen(ctx, xx, yy, false); + + // Zi = Ci + (XX - A) * Bi + (YY - B) * Ai + <(XX - A) * (YY - B)> - XXi * YYi + // We re-use b to compute z + ring_mul_(b, xx_a); + ring_mul_(a, yy_b); + ring_add_(b, a); + ring_add_(b, c); + + ring_mul_(xx, yy); + ring_sub_(b, xx); + if (comm->getRank() == 0) { + // z += (XX-A) * (YY-B); + ring_mul_(xx_a, yy_b); + ring_add_(b, xx_a); + } + + // zi += xi * yi + ring_add_(b, ring_mul(x, yy)); + + return b.as(x.eltype()); +} + //////////////////////////////////////////////////////////////////// // matmul family //////////////////////////////////////////////////////////////////// @@ -364,10 +391,7 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { - const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; - + const Sizes& bits) const { return ring_lshift(in, bits).as(in.eltype()); } @@ -381,7 +405,7 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x, if (comm->getWorldSize() == 2) { // SecureML, local truncation. // Ref: Theorem 1. https://eprint.iacr.org/2017/396.pdf - return ring_arshift(x, bits).as(x.eltype()); + return ring_arshift(x, {static_cast(bits)}).as(x.eltype()); } else { // ABY3, truncation pair method. // Ref: Section 5.1.2 https://eprint.iacr.org/2018/403.pdf @@ -396,10 +420,10 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x, x.shape()); // open x - r - auto x_r = comm->allReduce(ReduceOp::ADD, ring_sub(x, r), kBindName); + auto x_r = comm->allReduce(ReduceOp::ADD, ring_sub(x, r), kBindName()); auto res = rb; if (comm->getRank() == 0) { - ring_add_(res, ring_arshift(x_r, bits)); + ring_add_(res, ring_arshift(x_r, {static_cast(bits)})); } // res = [x-r] + [r], x which [*] is truncation operation. @@ -418,7 +442,7 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "semi2k.truncpr", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; auto [r, rc, rb] = beaver->TruncPr(field, numel, bits); SPU_ENFORCE(static_cast(r.size()) == numel * SizeOf(field)); @@ -447,7 +471,7 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, x_plus_r[idx] = x + _r[idx]; }); // open + = c - c = comm->allReduce(x_plus_r, kBindName); + c = comm->allReduce(x_plus_r, kBindName()); } pforeach(0, numel, [&](int64_t idx) { diff --git a/libspu/mpc/semi2k/arithmetic.h b/libspu/mpc/semi2k/arithmetic.h index 221a9f44..8b887ed5 100644 --- a/libspu/mpc/semi2k/arithmetic.h +++ b/libspu/mpc/semi2k/arithmetic.h @@ -20,7 +20,7 @@ namespace spu::mpc::semi2k { class RandA : public RandKernel { public: - static constexpr char kBindName[] = "rand_a"; + static constexpr const char* kBindName() { return "rand_a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -31,7 +31,7 @@ class RandA : public RandKernel { class P2A : public UnaryKernel { public: - static constexpr char kBindName[] = "p2a"; + static constexpr const char* kBindName() { return "p2a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -42,7 +42,7 @@ class P2A : public UnaryKernel { class A2P : public UnaryKernel { public: - static constexpr char kBindName[] = "a2p"; + static constexpr const char* kBindName() { return "a2p"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -53,7 +53,7 @@ class A2P : public UnaryKernel { class A2V : public RevealToKernel { public: - static constexpr char kBindName[] = "a2v"; + static constexpr const char* kBindName() { return "a2v"; } // TODO: communication is unbalanced Kind kind() const override { return Kind::Dynamic; } @@ -68,7 +68,7 @@ class A2V : public RevealToKernel { class V2A : public UnaryKernel { public: - static constexpr char kBindName[] = "v2a"; + static constexpr const char* kBindName() { return "v2a"; } // TODO: communication is unbalanced Kind kind() const override { return Kind::Dynamic; } @@ -80,9 +80,9 @@ class V2A : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; -class NotA : public UnaryKernel { +class NegateA : public UnaryKernel { public: - static constexpr char kBindName[] = "not_a"; + static constexpr const char* kBindName() { return "negate_a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -96,7 +96,7 @@ class NotA : public UnaryKernel { //////////////////////////////////////////////////////////////////// class AddAP : public BinaryKernel { public: - static constexpr char kBindName[] = "add_ap"; + static constexpr const char* kBindName() { return "add_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -108,7 +108,7 @@ class AddAP : public BinaryKernel { class AddAA : public BinaryKernel { public: - static constexpr char kBindName[] = "add_aa"; + static constexpr const char* kBindName() { return "add_aa"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -123,7 +123,7 @@ class AddAA : public BinaryKernel { //////////////////////////////////////////////////////////////////// class MulAP : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_ap"; + static constexpr const char* kBindName() { return "mul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -135,7 +135,7 @@ class MulAP : public BinaryKernel { class MulAA : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_aa"; + static constexpr const char* kBindName() { return "mul_aa"; } ce::CExpr latency() const override { // TODO: consider beaver @@ -150,7 +150,7 @@ class MulAA : public BinaryKernel { class SquareA : public UnaryKernel { public: - static constexpr char kBindName[] = "square_a"; + static constexpr const char* kBindName() { return "square_a"; } ce::CExpr latency() const override { // TODO: consider beaver @@ -162,12 +162,28 @@ class SquareA : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override; }; +// Note: only for 2PC. +class MulA1B : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "mul_a1b"; } + + ce::CExpr latency() const override { + // TODO: consider beaver + return ce::Const(1); + } + + ce::CExpr comm() const override { return ce::K() * 2; } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + //////////////////////////////////////////////////////////////////// // matmul family //////////////////////////////////////////////////////////////////// class MatMulAP : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_ap"; + static constexpr const char* kBindName() { return "mmul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -179,7 +195,7 @@ class MatMulAP : public MatmulKernel { class MatMulAA : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_aa"; + static constexpr const char* kBindName() { return "mmul_aa"; } ce::CExpr latency() const override { // only count online for now. @@ -199,19 +215,19 @@ class MatMulAA : public MatmulKernel { class LShiftA : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_a"; + static constexpr const char* kBindName() { return "lshift_a"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class TruncA : public TruncAKernel { public: - static constexpr char kBindName[] = "trunc_a"; + static constexpr const char* kBindName() { return "trunc_a"; } // TODO: handle case > 3PC Kind kind() const override { return Kind::Dynamic; } @@ -237,7 +253,7 @@ class TruncA : public TruncAKernel { // https://eprint.iacr.org/2020/338.pdf class TruncAPr : public TruncAKernel { public: - static constexpr char kBindName[] = "trunc_a"; + static constexpr const char* kBindName() { return "trunc_a"; } Kind kind() const override { return Kind::Static; } @@ -257,7 +273,7 @@ class TruncAPr : public TruncAKernel { class BeaverCacheKernel : public Kernel { public: - static constexpr char kBindName[] = "beaver_cache"; + static constexpr const char* kBindName() { return "beaver_cache"; } void evaluate(KernelEvalContext* ctx) const override; }; diff --git a/libspu/mpc/semi2k/beaver/beaver_cache.h b/libspu/mpc/semi2k/beaver/beaver_cache.h index 3c9d9ade..5de345ff 100644 --- a/libspu/mpc/semi2k/beaver/beaver_cache.h +++ b/libspu/mpc/semi2k/beaver/beaver_cache.h @@ -32,9 +32,11 @@ namespace spu::mpc::semi2k { class BeaverCache { public: + // clang-format off BeaverCache() : cache_db_(fmt::format("BeaverCache.{}.{}.{}", getpid(), fmt::ptr(this), std::random_device()())) {}; + // clang-format on ~BeaverCache() { db_.reset(); try { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel b/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel index 423b53b0..5f0bd1e1 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel +++ b/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel @@ -24,6 +24,8 @@ spu_cc_library( "//libspu/mpc/common:prg_tensor", "//libspu/mpc/semi2k/beaver:beaver_interface", "//libspu/mpc/semi2k/beaver/beaver_impl/trusted_party", + "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_stream", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@com_github_microsoft_seal//:seal", "@yacl//yacl/link", @@ -39,6 +41,7 @@ spu_cc_test( ":beaver_ttp", "//libspu/core:xt_helper", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_server", + "//libspu/mpc/utils:gfmp", "//libspu/mpc/utils:permute", "//libspu/mpc/utils:simulate", "@com_google_googletest//:gtest", @@ -52,9 +55,11 @@ spu_cc_library( deps = [ "//libspu/mpc/common:prg_tensor", "//libspu/mpc/semi2k/beaver:beaver_interface", + "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_stream", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:service_cc_proto", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", - "@yacl//yacl/crypto/pke:asymmetric_sm2_crypto", + "@yacl//yacl/crypto/pke:sm2_enc", "@yacl//yacl/link", "@yacl//yacl/link/algorithm:barrier", "@yacl//yacl/utils:parallel", diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc index 636766ec..300a2f6e 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc @@ -14,7 +14,6 @@ #include -#include "fmt/format.h" #include "gtest/gtest.h" #include "yacl/crypto/key_utils.h" #include "yacl/link/algorithm/barrier.h" @@ -25,6 +24,7 @@ #include "libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.h" +#include "libspu/mpc/utils/gfmp.h" #include "libspu/mpc/utils/permute.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -159,6 +159,60 @@ std::vector open_buffer(std::vector& in_buffers, } return ret; } + +template +std::vector open_buffer_gfmp(std::vector& in_buffers, + FieldType k_field, + const std::vector& shapes, + size_t k_world_size, bool add_open) { + std::vector ret; + + auto reduce = [&](NdArrayRef& r, yacl::Buffer& b) { + if (b.size() == 0) { + return; + } + EXPECT_EQ(b.size(), r.shape().numel() * SizeOf(k_field)); + NdArrayRef a(std::make_shared(std::move(b)), ret[0].eltype(), + r.shape()); + auto Ta = r.eltype(); + gfmp_add_mod_(r, a.as(Ta)); + }; + if constexpr (std::is_same_v) { + ret.resize(3); + SPU_ENFORCE(shapes.size() == 3); + for (size_t i = 0; i < shapes.size(); i++) { + ret[i] = gfmp_zeros(k_field, shapes[i]); + } + for (Rank r = 0; r < k_world_size; r++) { + auto& [a_buf, b_buf, c_buf] = in_buffers[r]; + reduce(ret[0], a_buf); + reduce(ret[1], b_buf); + reduce(ret[2], c_buf); + } + } else if constexpr (std::is_same_v) { + ret.resize(2); + SPU_ENFORCE(shapes.size() == 2); + for (size_t i = 0; i < shapes.size(); i++) { + ret[i] = gfmp_zeros(k_field, shapes[i]); + } + for (Rank r = 0; r < k_world_size; r++) { + auto& [a_buf, b_buf] = in_buffers[r]; + reduce(ret[0], a_buf); + reduce(ret[1], b_buf); + } + } else if constexpr (std::is_same_v) { + ret.resize(1); + SPU_ENFORCE(shapes.size() == 1); + for (size_t i = 0; i < shapes.size(); i++) { + ret[i] = gfmp_zeros(k_field, shapes[i]); + } + for (Rank r = 0; r < k_world_size; r++) { + auto& a_buf = in_buffers[r]; + reduce(ret[0], a_buf); + } + } + return ret; +} } // namespace TEST_P(BeaverTest, Mul_large) { @@ -187,7 +241,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _c(open[2]); @@ -214,13 +268,13 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -240,13 +294,13 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -267,15 +321,15 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -297,17 +351,17 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { // mul not support transpose. // enforce ne - EXPECT_NE(_cache_a[idx], _a[idx]); - EXPECT_NE(_cache_b[idx], _b[idx]); + EXPECT_NE(_a_cache[idx], _a[idx]); + EXPECT_NE(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -342,7 +396,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _c(open[2]); @@ -369,13 +423,13 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -395,13 +449,13 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -422,15 +476,15 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -452,17 +506,17 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { // mul not support transpose. // enforce ne - EXPECT_NE(_cache_a[idx], _a[idx]); - EXPECT_NE(_cache_b[idx], _b[idx]); + EXPECT_NE(_a_cache[idx], _a[idx]); + EXPECT_NE(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -471,6 +525,176 @@ TEST_P(BeaverTest, Mul) { } } +TEST_P(BeaverTest, MulGfmp) { + const auto factory = std::get<0>(GetParam()).first; + const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const int64_t kMaxDiff = std::get<3>(GetParam()); + const size_t adjust_rank = std::get<4>(GetParam()); + const int64_t kNumel = 7; + + std::vector triples(kWorldSize); + + std::vector x_desc(kWorldSize); + std::vector y_desc(kWorldSize); + NdArrayRef x_cache; + NdArrayRef y_cache; + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + &y_desc[lctx->Rank()], ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + auto prime = ScalarTypeToPrime::prime; + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + + x_cache = open[0]; + y_cache = open[1]; + } + { + utils::simulate(kWorldSize, + [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + x_desc[lctx->Rank()].status = Beaver::Replay; + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + nullptr, ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _a_cache(x_cache); + NdArrayView _b(open[1]); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + auto prime = ScalarTypeToPrime::prime; + EXPECT_EQ(_a_cache[idx], _a[idx]); + auto t = mul_mod(_a[idx], _b[idx]) % prime; + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + y_desc[lctx->Rank()].status = Beaver::Replay; + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, nullptr, &y_desc[lctx->Rank()], + ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _b_cache(y_cache); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + EXPECT_EQ(_b_cache[idx], _b[idx]); + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto prime = ScalarTypeToPrime::prime; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + x_desc[lctx->Rank()].status = Beaver::Replay; + y_desc[lctx->Rank()].status = Beaver::Replay; + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + &y_desc[lctx->Rank()], ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto prime = ScalarTypeToPrime::prime; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + x_desc[lctx->Rank()].status = Beaver::TransposeReplay; + y_desc[lctx->Rank()].status = Beaver::TransposeReplay; + // mul not support transpose. + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + &y_desc[lctx->Rank()], ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + // mul not support transpose. + // enforce ne + EXPECT_NE(_a_cache[idx], _a[idx]); + EXPECT_NE(_b_cache[idx], _b[idx]); + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto prime = ScalarTypeToPrime::prime; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } +} + TEST_P(BeaverTest, And) { const auto factory = std::get<0>(GetParam()).first; const size_t kWorldSize = std::get<1>(GetParam()); @@ -539,7 +763,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(open[0], open[1]); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { @@ -565,13 +789,13 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(x_cache, open[1]); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -592,13 +816,13 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(open[0], y_cache); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -620,16 +844,16 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(x_cache, y_cache); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -651,19 +875,19 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(x_cache, y_cache); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { auto transpose_a = open[0].transpose(); NdArrayView _a(transpose_a); - NdArrayView _cache_a(y_cache); + NdArrayView _a_cache(y_cache); auto transpose_b = open[1].transpose(); NdArrayView _b(transpose_b); - NdArrayView _cache_b(x_cache); + NdArrayView _b_cache(x_cache); auto transpose_r = res.transpose(); NdArrayView _r(transpose_r); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -684,13 +908,13 @@ TEST_P(BeaverTest, Dot) { auto open = open_buffer(triples, kField, std::vector(3, {M * K}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -725,7 +949,7 @@ TEST_P(BeaverTest, Dot_large) { open_buffer(triples, kField, {{M, K}, {K, N}, {M, N}}, kWorldSize, true); auto res = ring_mmul(open[0], open[1]); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { @@ -741,7 +965,7 @@ TEST_P(BeaverTest, Trunc) { const FieldType kField = std::get<2>(GetParam()); const size_t adjust_rank = std::get<4>(GetParam()); const int64_t kNumel = 7; - const size_t kBits = 5; + const int64_t kBits = 5; std::vector pairs; pairs.resize(kWorldSize); @@ -756,7 +980,7 @@ TEST_P(BeaverTest, Trunc) { EXPECT_EQ(pairs.size(), kWorldSize); auto open = open_buffer(pairs, kField, {{kNumel}, {kNumel}}, kWorldSize, true); - EXPECT_TRUE(ring_all_equal(ring_arshift(open[0], kBits), open[1], 0)); + EXPECT_TRUE(ring_all_equal(ring_arshift(open[0], {kBits}), open[1], 0)); } TEST_P(BeaverTest, TruncPr) { @@ -782,7 +1006,7 @@ TEST_P(BeaverTest, TruncPr) { auto open = open_buffer(rets, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "semi2k.truncpr.ut", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { using T = ring2k_t; auto sum_r_iter = open[0].begin(); auto sum_rc_iter = open[1].begin(); @@ -821,7 +1045,7 @@ TEST_P(BeaverTest, Randbit) { EXPECT_EQ(shares.size(), kWorldSize); auto open = open_buffer(shares, kField, {{kNumel}}, kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { using scalar_t = typename Ring2kTrait<_kField>::scalar_t; auto x = xt_adapt(open[0]); EXPECT_TRUE(xt::all(x <= xt::ones_like(x))); @@ -871,7 +1095,9 @@ TEST_P(BeaverTest, PermPair) { const size_t adjust_rank = std::get<4>(GetParam()); const int64_t kNumel = 10; std::random_device rd; - const auto r_perm = genRandomPerm(kNumel, rd()); + uint128_t seed = rd(); + uint64_t ctr = rd(); + const auto r_perm = genRandomPerm(kNumel, seed, &ctr); for (size_t r = 0; r < kWorldSize; ++r) { std::vector pairs(kWorldSize); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc index 574d892a..f876209d 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc @@ -23,6 +23,7 @@ #include "libspu/mpc/common/prg_tensor.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::semi2k { @@ -32,9 +33,9 @@ namespace { inline size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; } void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, - const std::vector& encrypted_seeds, - PrgCounter counter, PrgSeed self_seed) { + PrgCounter counter, PrgSeed self_seed, + ElementType eltype = ElementType::kRing) { if (desc == nullptr || desc->status != Beaver::Init) { return; } @@ -43,12 +44,13 @@ void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, desc->prg_counter = counter; desc->encrypted_seeds = encrypted_seeds; desc->seed = self_seed; + desc->eltype = eltype; } } // namespace BeaverTfpUnsafe::BeaverTfpUnsafe(std::shared_ptr lctx) - : lctx_(std::move(std::move(lctx))), + : lctx_(std::move(lctx)), seed_(yacl::crypto::SecureRandSeed()), counter_(0) { auto buf = yacl::SerializeUint128(seed_); @@ -67,7 +69,8 @@ BeaverTfpUnsafe::BeaverTfpUnsafe(std::shared_ptr lctx) BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, ReplayDesc* x_desc, - ReplayDesc* y_desc) { + ReplayDesc* y_desc, + ElementType eltype) { std::vector ops(3); Shape shape({size, 1}); std::vector> replay_seeds(3); @@ -75,9 +78,13 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, auto if_replay = [&](const ReplayDesc* replay_desc, size_t idx) { if (replay_desc == nullptr || replay_desc->status != Beaver::Replay) { ops[idx].seeds = seeds_; - return prgCreateArray(field, shape, seed_, &counter_, &ops[idx].desc); + // enforce the eltypes in ops + ops[idx].desc.eltype = eltype; + return prgCreateArray(field, shape, seed_, &counter_, &ops[idx].desc, + eltype); } else { SPU_ENFORCE(replay_desc->field == field); + SPU_ENFORCE(replay_desc->eltype == eltype); SPU_ENFORCE(replay_desc->size == size); if (lctx_->Rank() == 0) { SPU_ENFORCE(replay_desc->encrypted_seeds.size() == lctx_->WorldSize()); @@ -90,25 +97,31 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, } ops[idx].seeds = replay_seeds[idx]; ops[idx].desc.field = field; + ops[idx].desc.eltype = eltype; ops[idx].desc.shape = shape; ops[idx].desc.prg_counter = replay_desc->prg_counter; } PrgCounter tmp_counter = replay_desc->prg_counter; return prgCreateArray(field, shape, replay_desc->seed, &tmp_counter, - nullptr); + nullptr, eltype); } }; - FillReplayDesc(x_desc, field, size, seeds_buff_, counter_, seed_); + FillReplayDesc(x_desc, field, size, seeds_buff_, counter_, seed_, eltype); auto a = if_replay(x_desc, 0); - FillReplayDesc(y_desc, field, size, seeds_buff_, counter_, seed_); + FillReplayDesc(y_desc, field, size, seeds_buff_, counter_, seed_, eltype); auto b = if_replay(y_desc, 1); - auto c = prgCreateArray(field, shape, seed_, &counter_, &ops[2].desc); + auto c = prgCreateArray(field, shape, seed_, &counter_, &ops[2].desc, eltype); if (lctx_->Rank() == 0) { ops[2].seeds = seeds_; - auto adjust = TrustedParty::adjustMul(ops); - ring_add_(c, adjust); + auto adjust = TrustedParty::adjustMul(absl::MakeSpan(ops)); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjust.as(T)); + } else { + ring_add_(c, adjust); + } } Triple ret; @@ -119,6 +132,37 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, return ret; } +BeaverTfpUnsafe::Pair BeaverTfpUnsafe::MulPriv(FieldType field, int64_t size, + ElementType eltype) { + std::vector ops(2); + Shape shape({size, 1}); + + ops[0].seeds = seeds_; + // enforce the eltypes in ops + ops[0].desc.eltype = eltype; + ops[1].desc.eltype = eltype; + auto a_or_b = + prgCreateArray(field, shape, seed_, &counter_, &ops[0].desc, eltype); + auto c = prgCreateArray(field, shape, seed_, &counter_, &ops[1].desc, eltype); + + if (lctx_->Rank() == 0) { + ops[1].seeds = seeds_; + auto adjust = TrustedParty::adjustMulPriv(absl::MakeSpan(ops)); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjust.as(T)); + } else { + ring_add_(c, adjust); + } + } + + Pair ret; + std::get<0>(ret) = std::move(*a_or_b.buf()); + std::get<1>(ret) = std::move(*c.buf()); + + return ret; +} + BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Square(FieldType field, int64_t size, ReplayDesc* x_desc) { std::vector ops(2); @@ -158,7 +202,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Square(FieldType field, int64_t size, if (lctx_->Rank() == 0) { ops[1].seeds = seeds_; - auto adjust = TrustedParty::adjustSquare(ops); + auto adjust = TrustedParty::adjustSquare(absl::MakeSpan(ops)); ring_add_(b, adjust); } @@ -223,7 +267,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Dot(FieldType field, int64_t m, if (lctx_->Rank() == 0) { ops[2].seeds = seeds_; - auto adjust = TrustedParty::adjustDot(ops); + auto adjust = TrustedParty::adjustDot(absl::MakeSpan(ops)); ring_add_(c, adjust); } @@ -250,7 +294,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::And(int64_t size) { for (auto& op : ops) { op.seeds = seeds_; } - auto adjust = TrustedParty::adjustAnd(ops); + auto adjust = TrustedParty::adjustAnd(absl::MakeSpan(ops)); ring_xor_(c, adjust); } @@ -276,7 +320,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Trunc(FieldType field, int64_t size, for (auto& op : ops) { op.seeds = seeds_; } - auto adjust = TrustedParty::adjustTrunc(ops, bits); + auto adjust = TrustedParty::adjustTrunc(absl::MakeSpan(ops), bits); ring_add_(b, adjust); } @@ -300,7 +344,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::TruncPr(FieldType field, int64_t size, for (auto& op : ops) { op.seeds = seeds_; } - auto adjusts = TrustedParty::adjustTruncPr(ops, bits); + auto adjusts = TrustedParty::adjustTruncPr(absl::MakeSpan(ops), bits); ring_add_(rc, std::get<0>(adjusts)); ring_add_(rb, std::get<1>(adjusts)); } @@ -322,7 +366,7 @@ BeaverTfpUnsafe::Array BeaverTfpUnsafe::RandBit(FieldType field, int64_t size) { for (auto& op : ops) { op.seeds = seeds_; } - auto adjust = TrustedParty::adjustRandBit(ops); + auto adjust = TrustedParty::adjustRandBit(absl::MakeSpan(ops)); ring_add_(a, adjust); } @@ -348,10 +392,11 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::PermPair( auto pv_buf = lctx_->Recv(perm_rank, kTag); ring_add_(b, TrustedParty::adjustPerm( - ops, absl::MakeSpan(pv_buf.data(), - pv_buf.size() / sizeof(int64_t)))); + absl::MakeSpan(ops), + absl::MakeSpan(pv_buf.data(), + pv_buf.size() / sizeof(int64_t)))); } else { - ring_add_(b, TrustedParty::adjustPerm(ops, perm_vec)); + ring_add_(b, TrustedParty::adjustPerm(absl::MakeSpan(ops), perm_vec)); } } else if (perm_rank == lctx_->Rank()) { lctx_->SendAsync( @@ -380,7 +425,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Eqz(FieldType field, int64_t size) { for (auto& op : ops) { op.seeds = seeds_; } - auto adjust = TrustedParty::adjustEqz(ops); + auto adjust = TrustedParty::adjustEqz(absl::MakeSpan(ops)); ring_xor_(b, adjust); } diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h index 9ca11bca..2f26a716 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h @@ -45,7 +45,11 @@ class BeaverTfpUnsafe final : public Beaver { explicit BeaverTfpUnsafe(std::shared_ptr lctx); Triple Mul(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr, - ReplayDesc* y_desc = nullptr) override; + ReplayDesc* y_desc = nullptr, + ElementType eltype = ElementType::kRing) override; + + Pair MulPriv(FieldType field, int64_t size, + ElementType eltype = ElementType::kRing) override; Pair Square(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr) override; diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc index 4f3094c3..be2e9e86 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc @@ -14,14 +14,17 @@ #include "libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h" +#include #include #include -#include "yacl/crypto/pke/asymmetric_sm2_crypto.h" +#include "yacl/crypto/pke/sm2_enc.h" #include "yacl/crypto/rand/rand.h" #include "yacl/link/algorithm/allgather.h" #include "libspu/mpc/common/prg_tensor.h" +#include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace brpc { @@ -39,7 +42,8 @@ inline size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; } void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, const std::vector& encrypted_seeds, - PrgCounter counter, PrgSeed self_seed) { + PrgCounter counter, PrgSeed self_seed, + ElementType eltype = ElementType::kRing) { if (desc == nullptr || desc->status != Beaver::Init) { return; } @@ -48,6 +52,7 @@ void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, desc->prg_counter = counter; desc->encrypted_seeds = encrypted_seeds; desc->seed = self_seed; + desc->eltype = eltype; } template @@ -59,11 +64,15 @@ AdjustRequest BuildAdjustRequest( SPU_ENFORCE(!descs.empty()); uint32_t field_size; + ElementType eltype = ElementType::kRing; + for (size_t i = 0; i < descs.size(); i++) { const auto& desc = descs[i]; auto* input = ret.add_prg_inputs(); input->set_prg_count(desc.prg_counter); field_size = SizeOf(desc.field); + eltype = desc.eltype; + input->set_buffer_len(desc.shape.numel() * SizeOf(desc.field)); absl::Span seeds; @@ -81,22 +90,158 @@ AdjustRequest BuildAdjustRequest( beaver::ttp_server::AdjustAndRequest>) { ret.set_field_size(field_size); } + if constexpr (std::is_same_v || + std::is_same_v) { + if (eltype == ElementType::kGfmp) + ret.set_element_type(beaver::ttp_server::ElType::GFMP); + } + return ret; } template struct dependent_false : std::false_type {}; +class StreamReader : public brpc::StreamInputHandler { + public: + enum class Status : int8_t { + kNotFinished, + kNormalFinished, + kAbnormalFinished, + kStreamFailed, + }; + + StreamReader(int32_t num_buf, size_t buf_len) { + SPU_ENFORCE(num_buf > 0); + SPU_ENFORCE(buf_len > 0); + buf_vec_.resize(num_buf); + buf_len_ = buf_len; + future_finished_ = promise_finished_.get_future(); + future_closed_ = promise_closed_.get_future(); + } + + int on_received_messages(brpc::StreamId id, butil::IOBuf* const messages[], + size_t size) override { + SPDLOG_DEBUG("on_received_messages, stream id: {}", id); + for (size_t i = 0; i < size; ++i) { + if (status_ != Status::kNotFinished) { + SPDLOG_ERROR("unexpected messages received"); + return -1; + } + + SPDLOG_DEBUG("receive buf size: {}", messages[i]->size()); + const auto& message = messages[i]; + beaver::ttp_server::BeaverDownStreamMeta meta; + message->copy_to(&meta, sizeof(meta)); + message->pop_front(sizeof(meta)); + if (meta.err_code != 0) { + SPDLOG_ERROR("response error from server, err_code: {}, err_text: {}", + meta.err_code, message->to_string()); + status_ = Status::kAbnormalFinished; + promise_finished_.set_value(status_); + return -2; + } + + SPU_ENFORCE(message->length() % buf_vec_.size() == 0); + size_t msg_len = message->length() / buf_vec_.size(); + for (size_t buf_idx = 0; buf_idx < buf_vec_.size(); ++buf_idx) { + message->append_to(&buf_vec_[buf_idx], msg_len, buf_idx * msg_len); + } + + SPU_ENFORCE(buf_vec_[0].length() <= buf_len_, + "unexpected bytes received"); + if (buf_vec_[0].length() == buf_len_) { + status_ = Status::kNormalFinished; + promise_finished_.set_value(status_); + } + } + return 0; + } + + void on_idle_timeout(brpc::StreamId id) override { + SPDLOG_WARN("Stream {} idle timeout", id); + } + + void on_closed(brpc::StreamId id) override { + SPDLOG_DEBUG("Stream {} closed", id); + promise_closed_.set_value(); + } + + void on_failed(brpc::StreamId id, int error_code, + const std::string& error_text) override { + SPDLOG_ERROR("Stream {} failed, error_code: {}, error_text: {}", id, + error_code, error_text); + status_ = Status::kStreamFailed; + promise_finished_.set_value(status_); + } + + const auto& GetBufVecRef() const { + SPU_ENFORCE(status_ == Status::kNormalFinished); + return buf_vec_; + } + + Status WaitFinished() { return future_finished_.get(); }; + + void WaitClosed() { future_closed_.wait(); } + + private: + std::vector buf_vec_; + size_t buf_len_; + Status status_ = Status::kNotFinished; + std::promise promise_finished_; + std::promise promise_closed_; + std::future future_finished_; + std::future future_closed_; +}; + +// Obtain a tuple containing num_buf and buf_len template -std::vector RpcCall(brpc::Channel& channel, AdjustRequest req, - FieldType ret_field) { +std::tuple GetBufferLength(const AdjustRequest& req) { + if constexpr (std::is_same_v) { + SPU_ENFORCE_EQ(req.prg_inputs().size(), 3); + return {1, req.prg_inputs()[2].buffer_len()}; + } else if constexpr (std::is_same_v< + AdjustRequest, + beaver::ttp_server::AdjustTruncPrRequest>) { + SPU_ENFORCE_GE(req.prg_inputs().size(), 1); + return {2, req.prg_inputs()[0].buffer_len()}; + } else { + SPU_ENFORCE_GE(req.prg_inputs().size(), 1); + return {1, req.prg_inputs()[0].buffer_len()}; + } +} + +template +std::vector RpcCall( + brpc::Channel& channel, AdjustRequest req, FieldType ret_field, + const std::vector* upstream_messages = nullptr) { brpc::Controller cntl; beaver::ttp_server::BeaverService::Stub stub(&channel); beaver::ttp_server::AdjustResponse rsp; + auto [num_buf, buf_len] = GetBufferLength(req); + StreamReader reader(num_buf, buf_len); + brpc::StreamOptions stream_options; + stream_options.max_buf_size = 2 * beaver::ttp_server::kUpStreamChunkSize; + stream_options.handler = &reader; + brpc::StreamId stream_id; + SPU_ENFORCE_EQ(brpc::StreamCreate(&stream_id, cntl, &stream_options), 0, + "Failed to create stream"); + auto cleanup = absl::MakeCleanup([&stream_id, &reader]() { + SPU_ENFORCE(brpc::StreamClose(stream_id) == 0); + reader.WaitClosed(); + }); + if constexpr (std::is_same_v) { stub.AdjustMul(&cntl, &req, &rsp, nullptr); + } else if constexpr (std::is_same_v< + AdjustRequest, + beaver::ttp_server::AdjustMulPrivRequest>) { + stub.AdjustMulPriv(&cntl, &req, &rsp, nullptr); } else if constexpr (std::is_same_v< AdjustRequest, beaver::ttp_server::AdjustSquareRequest>) { @@ -135,14 +280,31 @@ std::vector RpcCall(brpc::Channel& channel, AdjustRequest req, "Adjust server failed code={}, error={}", ErrorCode_Name(rsp.code()), rsp.message()); + if (upstream_messages != nullptr) { + for (const auto& message : *upstream_messages) { + int ret = brpc::StreamWrite(stream_id, message); + if (ret == EAGAIN) { + SPU_ENFORCE_EQ(brpc::StreamWait(stream_id, nullptr), 0); + ret = brpc::StreamWrite(stream_id, message); + } + SPU_ENFORCE_EQ(ret, 0, "Write stream failed"); + SPDLOG_DEBUG("write buf size {} to stream id {}", message.length(), + stream_id); + } + } + + auto status = reader.WaitFinished(); + SPU_ENFORCE(status == StreamReader::Status::kNormalFinished, + "Stream reader finished abnormally, status: {}", + static_cast(status)); std::vector ret; - for (const auto& output : rsp.adjust_outputs()) { - SPU_ENFORCE(output.size() % SizeOf(ret_field) == 0); - int64_t size = output.size() / SizeOf(ret_field); + for (const auto& buf : reader.GetBufVecRef()) { + SPU_ENFORCE(buf.length() % SizeOf(ret_field) == 0); + int64_t size = buf.length() / SizeOf(ret_field); // FIXME: change beaver interface: change return type to buffer. NdArrayRef array(makeType(ret_field), {size}); // FIXME: TTP adjuster server and client MUST have same endianness. - std::memcpy(array.data(), output.data(), output.size()); + buf.copy_to(array.data()); ret.push_back(std::move(array)); } @@ -152,7 +314,7 @@ std::vector RpcCall(brpc::Channel& channel, AdjustRequest req, } // namespace BeaverTtp::BeaverTtp(std::shared_ptr lctx, Options ops) - : lctx_(std::move(std::move(lctx))), + : lctx_(std::move(lctx)), seed_(yacl::crypto::SecureRandSeed()), counter_(0), options_(std::move(ops)) { @@ -179,7 +341,7 @@ BeaverTtp::BeaverTtp(std::shared_ptr lctx, Options ops) yacl::Buffer encrypted_seed; { - std::unique_ptr encryptor; + std::unique_ptr encryptor; auto lower_schema = absl::AsciiStrToLower(options_.asym_crypto_schema); if (lower_schema == "sm2") { encryptor = std::make_unique( @@ -197,15 +359,18 @@ BeaverTtp::BeaverTtp(std::shared_ptr lctx, Options ops) "BEAVER_TTP:SYNC_ENCRYPTED_SEEDS"); } +// TODO: kGfmp supports more operations BeaverTtp::Triple BeaverTtp::Mul(FieldType field, int64_t size, - ReplayDesc* x_desc, ReplayDesc* y_desc) { + ReplayDesc* x_desc, ReplayDesc* y_desc, + ElementType eltype) { std::vector descs(3); std::vector> descs_seed(3, encrypted_seeds_); Shape shape({size, 1}); auto if_replay = [&](const ReplayDesc* replay_desc, size_t idx) { if (replay_desc == nullptr || replay_desc->status != Beaver::Replay) { - return prgCreateArray(field, shape, seed_, &counter_, &descs[idx]); + return prgCreateArray(field, shape, seed_, &counter_, &descs[idx], + eltype); } else { SPU_ENFORCE(replay_desc->field == field); SPU_ENFORCE(replay_desc->size == size); @@ -213,27 +378,35 @@ BeaverTtp::Triple BeaverTtp::Mul(FieldType field, int64_t size, if (lctx_->Rank() == options_.adjust_rank) { descs_seed[idx] = replay_desc->encrypted_seeds; descs[idx].field = field; + descs[idx].eltype = eltype; descs[idx].shape = shape; descs[idx].prg_counter = replay_desc->prg_counter; } PrgCounter tmp_counter = replay_desc->prg_counter; return prgCreateArray(field, shape, replay_desc->seed, &tmp_counter, - &descs[idx]); + &descs[idx], eltype); } }; - FillReplayDesc(x_desc, field, size, encrypted_seeds_, counter_, seed_); + FillReplayDesc(x_desc, field, size, encrypted_seeds_, counter_, seed_, + eltype); auto a = if_replay(x_desc, 0); - FillReplayDesc(y_desc, field, size, encrypted_seeds_, counter_, seed_); + FillReplayDesc(y_desc, field, size, encrypted_seeds_, counter_, seed_, + eltype); auto b = if_replay(y_desc, 1); - auto c = prgCreateArray(field, shape, seed_, &counter_, &descs[2]); + auto c = prgCreateArray(field, shape, seed_, &counter_, &descs[2], eltype); if (lctx_->Rank() == options_.adjust_rank) { auto req = BuildAdjustRequest( descs, descs_seed); auto adjusts = RpcCall(channel_, req, field); SPU_ENFORCE_EQ(adjusts.size(), 1U); - ring_add_(c, adjusts[0].reshape(shape)); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjusts[0].reshape(shape).as(T)); + } else { + ring_add_(c, adjusts[0].reshape(shape)); + } } Triple ret; @@ -244,6 +417,34 @@ BeaverTtp::Triple BeaverTtp::Mul(FieldType field, int64_t size, return ret; } +BeaverTtp::Pair BeaverTtp::MulPriv(FieldType field, int64_t size, + ElementType eltype) { + std::vector descs(2); + std::vector> descs_seed(2, encrypted_seeds_); + Shape shape({size, 1}); + auto a_or_b = + prgCreateArray(field, shape, seed_, &counter_, &descs[0], eltype); + auto c = prgCreateArray(field, shape, seed_, &counter_, &descs[1], eltype); + if (lctx_->Rank() == options_.adjust_rank) { + auto req = BuildAdjustRequest( + descs, descs_seed); + auto adjusts = RpcCall(channel_, req, field); + SPU_ENFORCE_EQ(adjusts.size(), 1U); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjusts[0].reshape(shape).as(T)); + } else { + ring_add_(c, adjusts[0].reshape(shape)); + } + } + + Pair ret; + std::get<0>(ret) = std::move(*a_or_b.buf()); + std::get<1>(ret) = std::move(*c.buf()); + + return ret; +} + BeaverTtp::Pair BeaverTtp::Square(FieldType field, int64_t size, ReplayDesc* x_desc) { std::vector descs(2); @@ -466,10 +667,20 @@ BeaverTtp::Pair BeaverTtp::PermPair(FieldType field, int64_t size, if (lctx_->Rank() == perm_rank) { auto req = BuildAdjustRequest( descs, descs_seed); - for (auto p : perm_vec) { - req.add_perm_vec(p); + std::vector stream_data; + size_t left_buf_size = perm_vec.size() * sizeof(int64_t); + size_t chunk_idx = 0; + while (left_buf_size > 0) { + using beaver::ttp_server::kUpStreamChunkSize; + size_t cur_chunk_size = std::min(left_buf_size, kUpStreamChunkSize); + stream_data.emplace_back(); + stream_data.back().append(reinterpret_cast(perm_vec.data()) + + (chunk_idx * kUpStreamChunkSize), + cur_chunk_size); + ++chunk_idx; + left_buf_size -= cur_chunk_size; } - auto adjusts = RpcCall(channel_, req, field); + auto adjusts = RpcCall(channel_, req, field, &stream_data); SPU_ENFORCE_EQ(adjusts.size(), 1U); ring_add_(b, adjusts[0].reshape(b.shape())); } diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h index ecb39237..501d5eac 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h @@ -66,7 +66,11 @@ class BeaverTtp final : public Beaver { ~BeaverTtp() override = default; Triple Mul(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr, - ReplayDesc* y_desc = nullptr) override; + ReplayDesc* y_desc = nullptr, + ElementType eltype = ElementType::kRing) override; + + Pair MulPriv(FieldType field, int64_t size, + ElementType eltype = ElementType::kRing) override; Pair Square(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr) override; diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel index 5a503213..2613516a 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel @@ -21,7 +21,9 @@ spu_cc_library( srcs = ["trusted_party.cc"], hdrs = ["trusted_party.h"], deps = [ + "//libspu/core:type_util", "//libspu/mpc/common:prg_tensor", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:permute", "//libspu/mpc/utils:ring_ops", ], diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc index c16c6af3..75c528d6 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc @@ -14,36 +14,64 @@ #include "libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h" +#include "libspu/core/type_util.h" +#include "libspu/mpc/common/prg_tensor.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/permute.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::semi2k { namespace { +enum class ReduceOp : uint8_t { + ADD = 0, + XOR = 1, + MUL = 2, +}; + enum class RecOp : uint8_t { ADD = 0, XOR = 1, }; -std::vector reconstruct( - RecOp op, absl::Span ops) { +std::vector reduce(ReduceOp op, + absl::Span ops) { std::vector rs(ops.size()); const auto world_size = ops[0].seeds.size(); + for (size_t rank = 0; rank < world_size; rank++) { for (size_t idx = 0; idx < ops.size(); idx++) { // FIXME: TTP adjuster server and client MUST have same endianness. - auto t = prgReplayArray(ops[idx].seeds[rank], ops[idx].desc); + NdArrayRef t; + if (rank < world_size - 1) { + t = prgReplayArray(ops[idx].seeds[rank], ops[idx].desc); + } else { + t = prgReplayArrayMutable(ops[idx].seeds[rank], ops[idx].desc); + } if (rank == 0) { rs[idx] = t; } else { - if (op == RecOp::ADD) { - ring_add_(rs[idx], t); - } else if (op == RecOp::XOR) { + if (op == ReduceOp::ADD) { + if (ops[idx].desc.eltype == ElementType::kGfmp) { + // TODO: generalize the reduction + gfmp_add_mod_(rs[idx], t); + } else { + ring_add_(rs[idx], t); + } + } else if (op == ReduceOp::XOR) { + // gfmp has no xor implementation ring_xor_(rs[idx], t); + } else if (op == ReduceOp::MUL) { + if (ops[idx].desc.eltype == ElementType::kGfmp) { + // TODO: generalize the reduction + gfmp_mul_mod_(rs[idx], t); + } else { + ring_mul_(rs[idx], t); + } } else { - SPU_ENFORCE("not supported reconstruct op"); + SPU_THROW("not supported reduction op"); } } } @@ -52,11 +80,17 @@ std::vector reconstruct( return rs; } +std::vector reconstruct(RecOp op, + absl::Span ops) { + return reduce(ReduceOp(op), ops); +} + void checkOperands(absl::Span ops, bool skip_shape = false, bool allow_transpose = false) { for (size_t idx = 1; idx < ops.size(); idx++) { SPU_ENFORCE(skip_shape || ops[0].desc.shape == ops[idx].desc.shape); SPU_ENFORCE(allow_transpose || ops[0].transpose == false); + SPU_ENFORCE(ops[0].desc.eltype == ops[idx].desc.eltype); SPU_ENFORCE(ops[0].desc.field == ops[idx].desc.field); SPU_ENFORCE(ops[0].seeds.size() == ops[idx].seeds.size(), "{} <> {}", ops[0].seeds.size(), ops[idx].seeds.size()); @@ -65,24 +99,56 @@ void checkOperands(absl::Span ops, } // namespace -NdArrayRef TrustedParty::adjustMul(absl::Span ops) { +// TODO: gfmp support more operations +NdArrayRef TrustedParty::adjustMul(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops); auto rs = reconstruct(RecOp::ADD, ops); // adjust = rs[0] * rs[1] - rs[2]; - return ring_sub(ring_mul(rs[0], rs[1]), rs[2]); + if (ops[0].desc.eltype == ElementType::kGfmp) { + return gfmp_sub_mod(gfmp_mul_mod(rs[0], rs[1]), rs[2]); + } else { + ring_mul_(rs[0], rs[1]); + ring_sub_(rs[0], rs[2]); + return rs[0]; + } +} + +// ops are [a_or_b, c] +// P0 generate a, c0 +// P1 generate b, c1 +// The adjustment is ab - (c0 + c1), +// which only needs to be sent to adjust party, e.g. P0. +// P0 with adjust is ab - c1 = ab - (c0 + c1) + c0 +// Therefore, +// P0 holds: a, ab - c1 +// P1 holds: b, c1 +NdArrayRef TrustedParty::adjustMulPriv(absl::Span ops) { + SPU_ENFORCE_EQ(ops.size(), 2U); + checkOperands(ops); + + auto ab = reduce(ReduceOp::MUL, ops.subspan(0, 1))[0]; + auto c = reconstruct(RecOp::ADD, ops.subspan(1, 1))[0]; + // adjust = ab - c; + if (ops[0].desc.eltype == ElementType::kGfmp) { + return gfmp_sub_mod(ab, c); + } else { + return ring_sub(ab, c); + } } -NdArrayRef TrustedParty::adjustSquare(absl::Span ops) { +NdArrayRef TrustedParty::adjustSquare(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 2U); auto rs = reconstruct(RecOp::ADD, ops); // adjust = rs[0] * rs[0] - rs[1]; - return ring_sub(ring_mul(rs[0], rs[0]), rs[1]); + ring_mul_(rs[0], rs[0]); + ring_sub_(rs[0], rs[1]); + return rs[0]; } -NdArrayRef TrustedParty::adjustDot(absl::Span ops) { +NdArrayRef TrustedParty::adjustDot(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops, true, true); SPU_ENFORCE(ops[2].transpose == false); @@ -96,30 +162,35 @@ NdArrayRef TrustedParty::adjustDot(absl::Span ops) { } // adjust = rs[0] dot rs[1] - rs[2]; - return ring_sub(ring_mmul(rs[0], rs[1]), rs[2]); + auto dot = ring_mmul(rs[0], rs[1]); + ring_sub_(dot, rs[2]); + return dot; } -NdArrayRef TrustedParty::adjustAnd(absl::Span ops) { +NdArrayRef TrustedParty::adjustAnd(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops); auto rs = reconstruct(RecOp::XOR, ops); // adjust = (rs[0] & rs[1]) ^ rs[2]; - return ring_xor(ring_and(rs[0], rs[1]), rs[2]); + ring_and_(rs[0], rs[1]); + ring_xor_(rs[0], rs[2]); + return rs[0]; } -NdArrayRef TrustedParty::adjustTrunc(absl::Span ops, - size_t bits) { +NdArrayRef TrustedParty::adjustTrunc(absl::Span ops, size_t bits) { SPU_ENFORCE_EQ(ops.size(), 2U); checkOperands(ops); auto rs = reconstruct(RecOp::ADD, ops); // adjust = (rs[0] >> bits) - rs[1]; - return ring_sub(ring_arshift(rs[0], bits), rs[1]); + ring_arshift_(rs[0], {static_cast(bits)}); + ring_sub_(rs[0], rs[1]); + return rs[0]; } std::pair TrustedParty::adjustTruncPr( - absl::Span ops, size_t bits) { + absl::Span ops, size_t bits) { // descs[0] is r, descs[1] adjust to r[k-2, bits], descs[2] adjust to r[k-1] SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops); @@ -127,24 +198,28 @@ std::pair TrustedParty::adjustTruncPr( auto rs = reconstruct(RecOp::ADD, ops); // adjust1 = ((rs[0] << 1) >> (bits + 1)) - rs[1]; - auto adjust1 = ring_sub(ring_rshift(ring_lshift(rs[0], 1), bits + 1), rs[1]); + auto adjust1 = ring_lshift(rs[0], {1}); + ring_rshift_(adjust1, {static_cast(bits + 1)}); + ring_sub_(adjust1, rs[1]); // adjust2 = (rs[0] >> (k - 1)) - rs[2]; const size_t k = SizeOf(ops[0].desc.field) * 8; - auto adjust2 = ring_sub(ring_rshift(rs[0], k - 1), rs[2]); - + auto adjust2 = ring_rshift(rs[0], {static_cast(k - 1)}); + ring_sub_(adjust2, rs[2]); return {adjust1, adjust2}; } -NdArrayRef TrustedParty::adjustRandBit(absl::Span ops) { +NdArrayRef TrustedParty::adjustRandBit(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 1U); auto rs = reconstruct(RecOp::ADD, ops); // adjust = bitrev - rs[0]; - return ring_sub(ring_randbit(ops[0].desc.field, ops[0].desc.shape), rs[0]); + auto randbits = ring_randbit(ops[0].desc.field, ops[0].desc.shape); + ring_sub_(randbits, rs[0]); + return randbits; } -NdArrayRef TrustedParty::adjustEqz(absl::Span ops) { +NdArrayRef TrustedParty::adjustEqz(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 2U); checkOperands(ops); auto rs_a = reconstruct(RecOp::ADD, ops.subspan(0, 1)); @@ -153,7 +228,7 @@ NdArrayRef TrustedParty::adjustEqz(absl::Span ops) { return ring_xor(rs_a[0], rs_b[0]); } -NdArrayRef TrustedParty::adjustPerm(absl::Span ops, +NdArrayRef TrustedParty::adjustPerm(absl::Span ops, absl::Span perm_vec) { SPU_ENFORCE_EQ(ops.size(), 2U); auto rs = reconstruct(RecOp::ADD, ops); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h index de58b591..60098256 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h @@ -31,24 +31,26 @@ class TrustedParty { bool transpose{false}; }; - static NdArrayRef adjustMul(absl::Span); + static NdArrayRef adjustMul(absl::Span); - static NdArrayRef adjustSquare(absl::Span); + static NdArrayRef adjustMulPriv(absl::Span); - static NdArrayRef adjustDot(absl::Span); + static NdArrayRef adjustSquare(absl::Span); - static NdArrayRef adjustAnd(absl::Span); + static NdArrayRef adjustDot(absl::Span); - static NdArrayRef adjustTrunc(absl::Span, size_t bits); + static NdArrayRef adjustAnd(absl::Span); - static std::pair adjustTruncPr( - absl::Span, size_t bits); + static NdArrayRef adjustTrunc(absl::Span, size_t bits); - static NdArrayRef adjustRandBit(absl::Span); + static std::pair adjustTruncPr(absl::Span, + size_t bits); - static NdArrayRef adjustEqz(absl::Span); + static NdArrayRef adjustRandBit(absl::Span); - static NdArrayRef adjustPerm(absl::Span, + static NdArrayRef adjustEqz(absl::Span); + + static NdArrayRef adjustPerm(absl::Span, absl::Span perm_vec); }; diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/BUILD.bazel b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/BUILD.bazel index 816841b3..8ccd5d27 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/BUILD.bazel +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/BUILD.bazel @@ -28,15 +28,21 @@ proto_library( srcs = ["service.proto"], ) +spu_cc_library( + name = "beaver_stream", + hdrs = ["beaver_stream.h"], +) + spu_cc_library( name = "beaver_server", srcs = ["beaver_server.cc"], hdrs = ["beaver_server.h"], deps = [ + ":beaver_stream", ":service_cc_proto", "//libspu/mpc/semi2k/beaver/beaver_impl/trusted_party", "@com_github_brpc_brpc//:brpc", - "@yacl//yacl/crypto/pke:asymmetric_sm2_crypto", + "@yacl//yacl/crypto/pke:sm2_enc", ], ) diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc index 8b7a9903..2a5136ec 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc @@ -15,17 +15,19 @@ #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.h" #include +#include #include #include "absl/strings/ascii.h" #include "spdlog/spdlog.h" #include "yacl/base/byte_container_view.h" #include "yacl/base/exception.h" -#include "yacl/crypto/pke/asymmetric_sm2_crypto.h" +#include "yacl/crypto/pke/sm2_enc.h" #include "libspu/core/ndarray_ref.h" #include "libspu/mpc/common/prg_tensor.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h" +#include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.pb.h" @@ -49,9 +51,9 @@ class DecryptError : public yacl::Exception { template std::tuple, std::vector>, size_t> -BuildOperand( - const AdjustRequest& req, uint32_t field_size, - const std::unique_ptr& decryptor) { +BuildOperand(const AdjustRequest& req, uint32_t field_size, + const std::unique_ptr& decryptor, + ElementType eltype) { std::vector ops; std::vector> seeds; size_t pad_length = 0; @@ -139,7 +141,7 @@ BuildOperand( } seeds.emplace_back(std::move(seed)); ops.push_back( - TrustedParty::Operand{{shape, type, prg_count}, seeds.back()}); + TrustedParty::Operand{{shape, type, prg_count, eltype}, seeds.back()}); } if constexpr (std::is_same_v) { @@ -174,22 +176,139 @@ std::vector StripNdArray(std::vector& nds, template struct dependent_false : std::false_type {}; +class StreamReader : public brpc::StreamInputHandler { + public: + enum class Status : int8_t { + kNotFinished, + kNormalFinished, + kAbnormalFinished, + kStreamFailed, + }; + + explicit StreamReader(size_t total_buf_len) { + total_buf_len_ = total_buf_len; + future_finished_ = promise_finished_.get_future(); + future_closed_ = promise_closed_.get_future(); + } + + int on_received_messages(brpc::StreamId id, butil::IOBuf* const messages[], + size_t size) override { + SPDLOG_DEBUG("on_received_messages, stream id: {}", id); + for (size_t i = 0; i < size; ++i) { + if (status_ != Status::kNotFinished) { + SPDLOG_WARN("unexpected messages received"); + return -1; + } + const auto& message = messages[i]; + SPDLOG_DEBUG("receive buf size: {}", message->size()); + buf_.append(message->movable()); + if (buf_.length() == total_buf_len_) { + status_ = Status::kNormalFinished; + promise_finished_.set_value(status_); + } else if (buf_.length() > total_buf_len_) { + SPDLOG_ERROR("buf length ({}) greater than expected buf size ({})", + buf_.length(), total_buf_len_); + status_ = Status::kAbnormalFinished; + promise_finished_.set_value(status_); + } + } + return 0; + } + + void on_idle_timeout(brpc::StreamId id) override { + SPDLOG_INFO("Stream {} idle timeout", id); + } + + void on_closed(brpc::StreamId id) override { + SPDLOG_DEBUG("Stream {} closed", id); + promise_closed_.set_value(); + } + + void on_failed(brpc::StreamId id, int error_code, + const std::string& error_text) override { + SPDLOG_ERROR("Stream {} failed, error_code: {}, error_text: {}", id, + error_code, error_text); + status_ = Status::kStreamFailed; + promise_finished_.set_value(status_); + } + + const auto& GetBufRef() const { + SPU_ENFORCE(status_ == Status::kNormalFinished); + return buf_; + } + + Status WaitFinished() { return future_finished_.get(); }; + + void WaitClosed() { future_closed_.wait(); } + + private: + butil::IOBuf buf_; + size_t total_buf_len_; + Status status_ = Status::kNotFinished; + std::promise promise_finished_; + std::promise promise_closed_; + std::future future_finished_; + std::future future_closed_; +}; + template -std::vector AdjustImpl( - const AdjustRequest& req, - const std::unique_ptr& decryptor) { - std::vector ret; - size_t field_size; - if constexpr (std::is_same_v) { - field_size = 128 / 8; - } else { - field_size = req.field_size(); +size_t GetBufferLength(const AdjustRequest& req) { + if constexpr (std::is_same_v) { + if (req.prg_inputs().size() > 0 && req.field_size() > 0) { + return req.prg_inputs()[0].buffer_len() / req.field_size() * + sizeof(int64_t); + } else { + SPDLOG_ERROR("Invalid request, prg_inputs size: {}, field_size: {}", + req.prg_inputs().size(), req.field_size()); + } } - auto [ops, seeds, pad_length] = BuildOperand(req, field_size, decryptor); + return 0; +} + +void SendStreamData(brpc::StreamId stream_id, + absl::Span buf_vec) { + SPU_ENFORCE(!buf_vec.empty()); + for (size_t idx = 1; idx < buf_vec.size(); ++idx) { + SPU_ENFORCE_EQ(buf_vec[0].size(), buf_vec[idx].size()); + } + + size_t chunk_size = kDownStreamChunkSize / buf_vec.size(); + // FIXME: TTP adjuster server and client MUST have same endianness. + size_t left_buf_size = buf_vec[0].size(); + int64_t chunk_idx = 0; + while (left_buf_size > 0) { + butil::IOBuf io_buf; + BeaverDownStreamMeta meta; + io_buf.append(&meta, sizeof(meta)); + + size_t cur_chunk_size = std::min(left_buf_size, chunk_size); + for (const auto& buf : buf_vec) { + int ret = io_buf.append(buf.data() + (chunk_idx * chunk_size), + cur_chunk_size); + SPU_ENFORCE_EQ(ret, 0, "Append data to IO buffer failed"); + } + + // StreamWrite result cannot be EAGAIN, given that we have not set + // max_buf_size + SPU_ENFORCE_EQ(brpc::StreamWrite(stream_id, io_buf), 0); + left_buf_size -= cur_chunk_size; + ++chunk_idx; + } +} + +template +std::vector AdjustImpl(const AdjustRequest& req, + absl::Span ops, + StreamReader& stream_reader) { + std::vector ret; if constexpr (std::is_same_v) { auto adjust = TrustedParty::adjustMul(ops); ret.push_back(std::move(adjust)); + } else if constexpr (std::is_same_v) { + auto adjust = TrustedParty::adjustMulPriv(ops); + ret.push_back(std::move(adjust)); } else if constexpr (std::is_same_v) { auto adjust = TrustedParty::adjustSquare(ops); ret.push_back(std::move(adjust)); @@ -213,7 +332,14 @@ std::vector AdjustImpl( auto adjust = TrustedParty::adjustEqz(ops); ret.push_back(std::move(adjust)); } else if constexpr (std::is_same_v) { - std::vector pv(req.perm_vec().begin(), req.perm_vec().end()); + auto status = stream_reader.WaitFinished(); + SPU_ENFORCE(status == StreamReader::Status::kNormalFinished, + "Stream reader finished abnormally, status: {}", + static_cast(status)); + const auto& buf = stream_reader.GetBufRef(); + SPU_ENFORCE(buf.length() % sizeof(int64_t) == 0); + std::vector pv(buf.length() / sizeof(int64_t)); + buf.copy_to(pv.data()); auto adjust = TrustedParty::adjustPerm(ops, pv); ret.push_back(std::move(adjust)); } else { @@ -221,14 +347,70 @@ std::vector AdjustImpl( "not support AdjustRequest type"); } - return StripNdArray(ret, pad_length); + return ret; +} + +template +void AdjustAndSend( + const AdjustRequest& req, brpc::StreamId stream_id, + StreamReader& stream_reader, + const std::unique_ptr& decryptor) { + size_t field_size; + if constexpr (std::is_same_v) { + field_size = 128 / 8; + } else { + field_size = req.field_size(); + } + ElementType eltype = ElementType::kRing; + // enable eltype for selected requests here + // later all requests may support gfmp + if constexpr (std::is_same_v || + std::is_same_v) { + if (req.element_type() == ElType::GFMP) { + eltype = ElementType::kGfmp; + } + } + auto [ops, seeds, pad_length] = + BuildOperand(req, field_size, decryptor, eltype); + + if constexpr (std::is_same_v || + std::is_same_v) { + auto adjusts = AdjustImpl(req, absl::MakeSpan(ops), stream_reader); + auto buf_vec = StripNdArray(adjusts, pad_length); + SendStreamData(stream_id, buf_vec); + return; + } + + SPU_ENFORCE_EQ(beaver::ttp_server::kReplayChunkSize % 128, 0U); + SPU_ENFORCE(!ops.empty()); + for (size_t idx = 1; idx < ops.size(); idx++) { + SPU_ENFORCE(ops[0].desc.shape == ops[idx].desc.shape); + } + int64_t left_elements = ops[0].desc.shape.at(0); + int64_t chunk_elements = + beaver::ttp_server::kReplayChunkSize / SizeOf(ops[0].desc.field); + while (left_elements > 0) { + int64_t cur_elements = std::min(left_elements, chunk_elements); + left_elements -= cur_elements; + for (auto& op : ops) { + op.desc.shape[0] = cur_elements; + } + auto adjusts = AdjustImpl(req, absl::MakeSpan(ops), stream_reader); + if (left_elements > 0) { + auto buf_vec = StripNdArray(adjusts, 0); + SendStreamData(stream_id, buf_vec); + } else { + auto buf_vec = StripNdArray(adjusts, pad_length); + SendStreamData(stream_id, buf_vec); + } + } } } // namespace class ServiceImpl final : public BeaverService { private: - std::unique_ptr decryptor_; + std::unique_ptr decryptor_; public: ServiceImpl(const std::string& asym_crypto_schema, @@ -246,34 +428,59 @@ class ServiceImpl final : public BeaverService { void Adjust(::google::protobuf::RpcController* controller, const AdjustRequest* req, AdjustResponse* rsp, ::google::protobuf::Closure* done) const { - brpc::ClosureGuard done_guard(done); auto* cntl = static_cast(controller); std::string client_side(butil::endpoint2str(cntl->remote_side()).c_str()); + brpc::StreamId stream_id = brpc::INVALID_STREAM_ID; + auto request = *req; + StreamReader reader(GetBufferLength(*req)); + + // To address the scenario where clients transmit data after an RPC + // response, give precedence to setting up absl::MakeCleanup before invoking + // brpc::ClosureGuard to ensure proper resource management + auto cleanup = absl::MakeCleanup([&]() { + auto cleanup = absl::MakeCleanup([&]() { + if (stream_id != brpc::INVALID_STREAM_ID) { + // To avoid encountering a core dump, it is essential to close the + // process stream prior to the destruction of the StreamReader object + reader.WaitClosed(); + } + }); + try { + AdjustAndSend(request, stream_id, reader, decryptor_); + } catch (const DecryptError& e) { + auto err = fmt::format("Seed Decrypt error {}", e.what()); + SPDLOG_ERROR("{}, client {}", err, + client_side); // TODO: catch the function name + BeaverDownStreamMeta meta; + meta.err_code = ErrorCode::SeedDecryptError; + butil::IOBuf buf; + SPU_ENFORCE_EQ(buf.append(&meta, sizeof(meta)), 0); + SPU_ENFORCE_EQ(buf.append(err.c_str()), 0); + brpc::StreamWrite(stream_id, buf); + return; + } catch (const std::exception& e) { + auto err = fmt::format("adjust error {}", e.what()); + SPDLOG_ERROR("{}, client {}", err, client_side); + BeaverDownStreamMeta meta; + meta.err_code = ErrorCode::OpAdjustError; + butil::IOBuf buf; + SPU_ENFORCE_EQ(buf.append(&meta, sizeof(meta)), 0); + SPU_ENFORCE_EQ(buf.append(err.c_str()), 0); + brpc::StreamWrite(stream_id, buf); + return; + } + }); - std::vector adjusts; - try { - adjusts = AdjustImpl(*req, decryptor_); - } catch (const DecryptError& e) { - auto err = fmt::format("Seed Decrypt error {}", e.what()); - SPDLOG_ERROR("{}, client {}", err, client_side); - rsp->set_code(ErrorCode::SeedDecryptError); - rsp->set_message(err); - return; - } catch (const std::exception& e) { - auto err = fmt::format("adjust error {}", e.what()); - SPDLOG_ERROR("{}, client {}", err, client_side); - rsp->set_code(ErrorCode::OpAdjustError); - rsp->set_message(err); + brpc::ClosureGuard done_guard(done); + brpc::StreamOptions stream_options; + stream_options.max_buf_size = 0; // there is no flow control for downstream + stream_options.handler = &reader; + if (brpc::StreamAccept(&stream_id, *cntl, &stream_options) != 0) { + SPDLOG_ERROR("Failed to accept stream"); + rsp->set_code(ErrorCode::StreamAcceptError); return; } - rsp->set_code(ErrorCode::OK); - for (auto& a : adjusts) { - // FIXME: TTP adjuster server and client MUST have same endianness. - rsp->add_adjust_outputs(a.data(), a.size()); - // how to move this buffer to pb ? - a.reset(); - } } void AdjustMul(::google::protobuf::RpcController* controller, @@ -282,6 +489,12 @@ class ServiceImpl final : public BeaverService { Adjust(controller, req, rsp, done); } + void AdjustMulPriv(::google::protobuf::RpcController* controller, + const AdjustMulPrivRequest* req, AdjustResponse* rsp, + ::google::protobuf::Closure* done) override { + Adjust(controller, req, rsp, done); + } + void AdjustSquare(::google::protobuf::RpcController* controller, const AdjustSquareRequest* req, AdjustResponse* rsp, ::google::protobuf::Closure* done) override { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc index 1680325e..0a701cf3 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc @@ -91,8 +91,7 @@ int main(int argc, char* argv[]) { std::string key; SPU_ENFORCE( butil::Base64Decode(ttp_server_config::FLAGS_server_private_key, &key)); - decode_private_key = - yacl::Buffer(decode_private_key.data(), decode_private_key.size()); + decode_private_key = yacl::Buffer(key.data(), key.size()); } spu::mpc::semi2k::beaver::ttp_server::ServerOptions ops{ diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h new file mode 100644 index 00000000..04dcc88e --- /dev/null +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h @@ -0,0 +1,31 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace spu::mpc::semi2k::beaver::ttp_server { + +constexpr size_t kReplayChunkSize = 50 * 1024 * 1024; // bytes + +constexpr size_t kUpStreamChunkSize = 50 * 1024 * 1024; // bytes +constexpr size_t kDownStreamChunkSize = 50 * 1024 * 1024; // bytes + +// A list of buffer streams +struct BeaverDownStreamMeta { + int32_t err_code = 0; +}; + +} // namespace spu::mpc::semi2k::beaver::ttp_server \ No newline at end of file diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto index 435bf9d6..23fd3025 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto @@ -22,8 +22,16 @@ enum ErrorCode { OK = 0; OpAdjustError = 1; SeedDecryptError = 2; + StreamAcceptError = 3; } +// The type of element in the field. +// Match the enum in libspu/mpc/common/prg_tensor.h +enum ElType { + UNSPECIFIED = 0; + RING = 1; + GFMP = 2; +} // PRG generated buffer metainfo. // BeaverService replay PRG to generate same buffer using each party's prg_seed // encrypted by server's public key. PrgBufferMeta represent {world_size} @@ -41,6 +49,8 @@ service BeaverService { // V1 adjust ops rpc AdjustMul(AdjustMulRequest) returns (AdjustResponse); + rpc AdjustMulPriv(AdjustMulPrivRequest) returns (AdjustResponse); + rpc AdjustSquare(AdjustSquareRequest) returns (AdjustResponse); rpc AdjustDot(AdjustDotRequest) returns (AdjustResponse); @@ -68,6 +78,27 @@ message AdjustMulRequest { // adjust_c = ra * rb - rc // make // ra * rb = (adjust_c + rc) + + // element type supported: "GFMP", "RING" + ElType element_type = 3; + // if element type is "GFMP" then all ring ops will be changed to gfmp +} + +message AdjustMulPrivRequest { + // input 2 prg buffer + // first is a or b [one party holds a slice, another b slice] + // second is c + repeated PrgBufferMeta prg_inputs = 1; + // What field size should be used to interpret buffer content + uint32 field_size = 2; + // output + // adjust_c = a * b - rc + // make + // a * b = (adjust_c + rc) + + // element type supported: "GFMP", "RING" + ElType element_type = 3; + // if element type is "GFMP" then all ring ops will be changed to gfmp } message AdjustSquareRequest { @@ -170,8 +201,6 @@ message AdjustPermRequest { repeated PrgBufferMeta prg_inputs = 1; // What field size should be used to interpret buffer content uint32 field_size = 2; - // permutation vector - repeated int64 perm_vec = 3; // output // adjust_b = (apply inverse permutation perm_vec to ra) - rb // make @@ -181,6 +210,4 @@ message AdjustPermRequest { message AdjustResponse { ErrorCode code = 1; string message = 2; - // Adjust output array buffer - repeated bytes adjust_outputs = 3; } diff --git a/libspu/mpc/semi2k/beaver/beaver_interface.h b/libspu/mpc/semi2k/beaver/beaver_interface.h index d610f380..89c58267 100644 --- a/libspu/mpc/semi2k/beaver/beaver_interface.h +++ b/libspu/mpc/semi2k/beaver/beaver_interface.h @@ -41,6 +41,7 @@ class Beaver { std::vector encrypted_seeds; int64_t size; FieldType field; + ElementType eltype; }; using Array = yacl::Buffer; @@ -50,8 +51,11 @@ class Beaver { virtual ~Beaver() = default; virtual Triple Mul(FieldType field, int64_t size, - ReplayDesc* x_desc = nullptr, - ReplayDesc* y_desc = nullptr) = 0; + ReplayDesc* x_desc = nullptr, ReplayDesc* y_desc = nullptr, + ElementType eltype = ElementType::kRing) = 0; + + virtual Pair MulPriv(FieldType field, int64_t size, + ElementType eltype = ElementType::kRing) = 0; virtual Pair Square(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr) = 0; diff --git a/libspu/mpc/semi2k/boolean.cc b/libspu/mpc/semi2k/boolean.cc index 4dae73a1..707eb91a 100644 --- a/libspu/mpc/semi2k/boolean.cc +++ b/libspu/mpc/semi2k/boolean.cc @@ -30,7 +30,7 @@ namespace { size_t getNumBits(const NdArrayRef& in) { if (in.eltype().isa()) { const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", + return DISPATCH_ALL_FIELDS(field, [&]() { return maxBitWidth(in); }); } else if (in.eltype().isa()) { return in.eltype().as()->nbits(); @@ -90,7 +90,7 @@ NdArrayRef CastTypeB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); auto* comm = ctx->getState(); - auto out = comm->allReduce(ReduceOp::XOR, in, kBindName); + auto out = comm->allReduce(ReduceOp::XOR, in, kBindName()); return out.as(makeType(field)); } @@ -119,7 +119,7 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); NdArrayView _out(out); @@ -144,12 +144,12 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, // semi2k always use the same storage type. NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; // TODO: redefine beaver interface, generate variadic beaver and bits. @@ -215,32 +215,31 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t out_nbits = in.eltype().as()->nbits() + shift; + size_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); return makeBShare(ring_lshift(in, shift), field, out_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t nbits = in.eltype().as()->nbits(); - size_t out_nbits = nbits - std::min(nbits, shift); - SPU_ENFORCE(nbits <= SizeOf(field) * 8); + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); return makeBShare(ring_rshift(in, shift), field, out_nbits); } NdArrayRef ARShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -268,7 +267,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); @@ -289,7 +288,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); diff --git a/libspu/mpc/semi2k/boolean.h b/libspu/mpc/semi2k/boolean.h index 8c7e6835..766e39c8 100644 --- a/libspu/mpc/semi2k/boolean.h +++ b/libspu/mpc/semi2k/boolean.h @@ -20,7 +20,7 @@ namespace spu::mpc::semi2k { class CommonTypeB : public Kernel { public: - static constexpr char kBindName[] = "common_type_b"; + static constexpr const char* kBindName() { return "common_type_b"; } Kind kind() const override { return Kind::Dynamic; } @@ -29,7 +29,7 @@ class CommonTypeB : public Kernel { class CastTypeB : public CastTypeKernel { public: - static constexpr char kBindName[] = "cast_type_b"; + static constexpr const char* kBindName() { return "cast_type_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -41,7 +41,7 @@ class CastTypeB : public CastTypeKernel { class B2P : public UnaryKernel { public: - static constexpr char kBindName[] = "b2p"; + static constexpr const char* kBindName() { return "b2p"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -52,7 +52,7 @@ class B2P : public UnaryKernel { class P2B : public UnaryKernel { public: - static constexpr char kBindName[] = "p2b"; + static constexpr const char* kBindName() { return "p2b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -63,7 +63,7 @@ class P2B : public UnaryKernel { class AndBP : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bp"; + static constexpr const char* kBindName() { return "and_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -75,7 +75,7 @@ class AndBP : public BinaryKernel { class AndBB : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bb"; + static constexpr const char* kBindName() { return "and_bb"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -87,7 +87,7 @@ class AndBB : public BinaryKernel { class XorBP : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bp"; + static constexpr const char* kBindName() { return "xor_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -99,7 +99,7 @@ class XorBP : public BinaryKernel { class XorBB : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bb"; + static constexpr const char* kBindName() { return "xor_bb"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -111,43 +111,43 @@ class XorBB : public BinaryKernel { class LShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_b"; + static constexpr const char* kBindName() { return "lshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class RShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "rshift_b"; + static constexpr const char* kBindName() { return "rshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class ARShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "arshift_b"; + static constexpr const char* kBindName() { return "arshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class BitrevB : public BitrevKernel { public: - static constexpr char kBindName[] = "bitrev_b"; + static constexpr const char* kBindName() { return "bitrev_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -159,7 +159,7 @@ class BitrevB : public BitrevKernel { class BitIntlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitintl_b"; + static constexpr const char* kBindName() { return "bitintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -171,7 +171,7 @@ class BitIntlB : public BitSplitKernel { class BitDeintlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitdeintl_b"; + static constexpr const char* kBindName() { return "bitdeintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc index dce1857f..018d7ac9 100644 --- a/libspu/mpc/semi2k/conversion.cc +++ b/libspu/mpc/semi2k/conversion.cc @@ -99,8 +99,8 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { auto r_b = wrap_a2b(ctx->sctx(), r_a); // evaluate adder circuit on x & r, and reveal x+r - auto x_plus_r = comm->allReduce(ReduceOp::XOR, - wrap_add_bb(ctx->sctx(), x, r_b), kBindName); + auto x_plus_r = comm->allReduce( + ReduceOp::XOR, wrap_add_bb(ctx->sctx(), x, r_b), kBindName()); // compute -r + (x+r) ring_neg_(r_a); @@ -135,7 +135,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, rand_numel * SizeOf(field)); auto res = NdArrayRef(makeType(field), x.shape()); - DISPATCH_ALL_FIELDS(field, kBindName, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; absl::Span _randbits(randbits.data(), rand_numel); @@ -143,7 +143,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, // algorithm begins. // Ref: III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; std::vector x_xor_r(numel); @@ -229,13 +229,13 @@ std::vector B2A_Disassemble::proc(KernelEvalContext* ctx, for (int64_t idx = 0; idx < nbits; ++idx) { res.emplace_back(makeType(field), x.shape()); } - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; absl::Span _randbits(randbits.data(), rand_numel); NdArrayView _x(x); - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; std::vector x_xor_r(numel); @@ -305,7 +305,9 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // Compute the k'th bit. // (m^n)[k] ^ carry - auto msb = xor_bb(sctx, rshift_b(sctx, xor_bb(sctx, m, n), k), carry); + auto msb = xor_bb( + sctx, rshift_b(sctx, xor_bb(sctx, m, n), {static_cast(k)}), + carry); return UnwrapValue(msb); } @@ -327,7 +329,7 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { // beaver samples r and deals [r]a and [r]b // receal c = a+r // check a == 0 <=> c == r - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; auto [ra_buf, rb_buf] = beaver->Eqz(field, numel); @@ -355,11 +357,11 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { // TODO: fix AND triple // in beaver->AND(field, shape), min FM32, need min 1byte to reduce comm NdArrayRef round_out = rb.as(makeType(field)); - size_t cur_bits = round_out.eltype().as()->nbits(); + int64_t cur_bits = round_out.eltype().as()->nbits(); while (cur_bits != 1) { cur_bits /= 2; - round_out = - wrap_and_bb(ctx->sctx(), round_out, ring_rshift(round_out, cur_bits)); + round_out = wrap_and_bb(ctx->sctx(), round_out, + ring_rshift(round_out, {cur_bits})); } // 1 bit info in lsb diff --git a/libspu/mpc/semi2k/conversion.h b/libspu/mpc/semi2k/conversion.h index 891a23cd..d55d044d 100644 --- a/libspu/mpc/semi2k/conversion.h +++ b/libspu/mpc/semi2k/conversion.h @@ -20,7 +20,7 @@ namespace spu::mpc::semi2k { class A2B : public UnaryKernel { public: - static constexpr char kBindName[] = "a2b"; + static constexpr const char* kBindName() { return "a2b"; } ce::CExpr latency() const override { return (Log(ce::K()) + 1) // adder-circuit; @@ -40,7 +40,7 @@ class A2B : public UnaryKernel { class B2A : public UnaryKernel { public: - static constexpr char kBindName[] = "b2a"; + static constexpr const char* kBindName() { return "b2a"; } ce::CExpr latency() const override { return (Log(ce::K()) + 1) * Log(ce::N()) // A2B @@ -61,7 +61,7 @@ class B2A : public UnaryKernel { class B2A_Randbit : public UnaryKernel { public: - static constexpr char kBindName[] = "b2a"; + static constexpr const char* kBindName() { return "b2a"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -75,7 +75,7 @@ class B2A_Randbit : public UnaryKernel { class B2A_Disassemble : public DisassembleKernel { public: - static constexpr char kBindName[] = "b2a_disassemble"; + static constexpr const char* kBindName() { return "b2a_disassemble"; } ce::CExpr latency() const override { return ce::Const(1); } @@ -91,7 +91,7 @@ class B2A_Disassemble : public DisassembleKernel { // Note: current only for 2PC. class MsbA2B : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_a2b"; + static constexpr const char* kBindName() { return "msb_a2b"; } ce::CExpr latency() const override { // 1 * carry: log(k) + 1 @@ -109,7 +109,7 @@ class MsbA2B : public UnaryKernel { class EqualAA : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_aa"; + static constexpr const char* kBindName() { return "equal_aa"; } ce::CExpr latency() const override { // 1 * edabits + logk * andbb @@ -126,7 +126,7 @@ class EqualAA : public BinaryKernel { class EqualAP : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_ap"; + static constexpr const char* kBindName() { return "equal_ap"; } ce::CExpr latency() const override { // 1 * edabits + logk * andbb @@ -143,7 +143,7 @@ class EqualAP : public BinaryKernel { class CommonTypeV : public Kernel { public: - static constexpr char kBindName[] = "common_type_v"; + static constexpr const char* kBindName() { return "common_type_v"; } Kind kind() const override { return Kind::Dynamic; } diff --git a/libspu/mpc/semi2k/exp.cc b/libspu/mpc/semi2k/exp.cc new file mode 100644 index 00000000..34dba15f --- /dev/null +++ b/libspu/mpc/semi2k/exp.cc @@ -0,0 +1,97 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "exp.h" + +#include "prime_utils.h" +#include "type.h" + +#include "libspu/mpc/utils/gfmp.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::semi2k { + +// Given [x*2^fxp] mod 2k for x +// compute [exp(x) * 2^fxp] mod 2^k + +// Assume x is in valid range, otherwise the error may be too large to +// use this method. + +NdArrayRef ExpA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const size_t fxp = ctx->sctx()->getFxpBits(); + SPU_ENFORCE( + fxp < 64, + "fxp must be less than 64 for this method, or shift bit overflow ", + "may occur"); + auto field = in.eltype().as()->field(); + NdArrayRef x = in.clone(); + NdArrayRef out; + + // TODO: set different values for FM64 FM32 + const size_t kExpFxp = (field == FieldType::FM128) ? 24 : 13; + + const int rank = ctx->sctx()->lctx()->Rank(); + DISPATCH_ALL_FIELDS(field, [&]() { + auto total_fxp = kExpFxp + fxp; + // note that x is already encoded with fxp + // this conv scale further converts x int fixed point numbers with + // total_fxp + const ring2k_t exp_conv_scale = std::roundf(M_LOG2E * (1L << kExpFxp)); + + // offset scale should directly encoded to a fixed point with total_fxp + const ring2k_t offset = + ctx->sctx()->config().experimental_exp_prime_offset(); + const ring2k_t offset_scaled = offset << total_fxp; + + NdArrayView _x(x); + if (rank == 0) { + pforeach(0, x.numel(), [&](ring2k_t i) { + _x[i] *= exp_conv_scale; + _x[i] += offset_scaled; + }); + } else { + pforeach(0, x.numel(), [&](ring2k_t i) { _x[i] *= exp_conv_scale; }); + } + size_t shr_width = SizeOf(field) * 8 - fxp; + + const ring2k_t kBit = 1; + auto shifted_bit = kBit << total_fxp; + const ring2k_t frac_mask = shifted_bit - 1; + + auto int_part = ring_arshift(x, {static_cast(total_fxp)}); + + // convert from ring-share (int-part) to a prime share over p - 1 + int_part = ProbConvRing2k(int_part, rank, shr_width); + NdArrayView int_part_view(int_part); + + pforeach(0, x.numel(), [&](int64_t i) { + // y = 2^int_part mod p + ring2k_t y = exp_mod(2, int_part_view[i]); + // z = 2^fract_part in RR + double frac_part = static_cast(_x[i] & frac_mask) / shifted_bit; + frac_part = std::pow(2., frac_part); + + // Multiply the 2^{int_part} * 2^{frac_part} mod p + // note that mul_mod uses mersenne prime as modulus according to field + int_part_view[i] = mul_mod( + y, static_cast(std::roundf(frac_part * (kBit << fxp)))); + }); + + NdArrayRef muled = MulPrivModMP(ctx, int_part.as(makeType(field))); + + out = ConvMP(ctx, muled, offset + fxp); + }); + return out.as(in.eltype()); +} + +} // namespace spu::mpc::semi2k \ No newline at end of file diff --git a/libspu/mpc/semi2k/exp.h b/libspu/mpc/semi2k/exp.h new file mode 100644 index 00000000..fcc4711e --- /dev/null +++ b/libspu/mpc/semi2k/exp.h @@ -0,0 +1,37 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/mpc/kernel.h" + +namespace spu::mpc::semi2k { + +// Given [x*2^fxp] mod 2k for x +// compute [exp(x) * 2^fxp] mod 2^k +// Example: +// spu::mpc::semi2k::ExpA exp; +// outp = exp.proc(&kcontext, ring2k_shr); +class ExpA : public UnaryKernel { + public: + static constexpr const char* kBindName() { return "exp_a"; } + + ce::CExpr latency() const override { return ce::Const(2); } + + ce::CExpr comm() const override { return 2 * ce::K(); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +} // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/permute.cc b/libspu/mpc/semi2k/permute.cc index 28278f8c..71f68ef4 100644 --- a/libspu/mpc/semi2k/permute.cc +++ b/libspu/mpc/semi2k/permute.cc @@ -40,18 +40,35 @@ inline int64_t getOwner(const NdArrayRef& x) { return x.eltype().as()->owner(); } +Index ring2pv(const NdArrayRef& x) { + SPU_ENFORCE(x.eltype().isa(), "must be ring2k_type, got={}", + x.eltype()); + const auto field = x.eltype().as()->field(); + Index pv(x.numel()); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _x(x); + pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); + }); + return pv; +} + // Secure inverse permutation of x by perm_rank's permutation pv // The idea here is: // Input permutation pv, beaver generates perm pair {, } that // InversePermute(A, pv) = B. So we can get = InversePermute(open( - // ), pv) + that y = InversePermute(x, pv). NdArrayRef SecureInvPerm(KernelEvalContext* ctx, const NdArrayRef& x, - size_t perm_rank, absl::Span pv) { + const NdArrayRef& perm, size_t perm_rank) { const auto lctx = ctx->lctx(); const auto field = x.eltype().as()->field(); auto* beaver = ctx->getState()->beaver(); auto numel = x.numel(); + Index pv; + if (perm.eltype().isa() || + (perm.eltype().isa() && isOwner(ctx, perm.eltype()))) { + pv = ring2pv(perm); + } auto [a_buf, b_buf] = beaver->PermPair(field, numel, perm_rank, pv); NdArrayRef a(std::make_shared(std::move(a_buf)), x.eltype(), @@ -75,14 +92,11 @@ NdArrayRef SecureInvPerm(KernelEvalContext* ctx, const NdArrayRef& x, NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const { NdArrayRef out(makeType(), shape); - // generate a RandU64 as permutation seed auto* prg_state = ctx->getState(); - const auto seed = prg_state->genPriv(FieldType::FM64, {1}); - NdArrayView _seed(seed); - const auto perm_vector = genRandomPerm(out.numel(), _seed[0]); + const auto perm_vector = prg_state->genPrivPerm(out.numel()); const auto field = out.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _out(out); pforeach(0, out.numel(), [&](int64_t idx) { _out[idx] = ring2k_t(perm_vector[idx]); }); @@ -95,51 +109,37 @@ NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { auto* comm = ctx->getState(); - PermVector pv = ring2pv(perm); NdArrayRef out(in); for (size_t i = 0; i < comm->getWorldSize(); ++i) { - out = SecureInvPerm(ctx, out, i, pv); + out = SecureInvPerm(ctx, out, perm, i); } - return out; } NdArrayRef PermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { - PermVector pv = ring2pv(perm); - auto out = applyPerm(in, pv); - return out; + return applyPerm(in, perm); } NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { auto* comm = ctx->getState(); - PermVector pv = ring2pv(perm); NdArrayRef out(in); - auto inv_pv = genInversePerm(pv); + auto inv_perm = genInversePerm(perm); for (int i = comm->getWorldSize() - 1; i >= 0; --i) { - out = SecureInvPerm(ctx, out, i, inv_pv); + out = SecureInvPerm(ctx, out, inv_perm, i); } - return out; } NdArrayRef InvPermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { - PermVector pv = ring2pv(perm); - auto out = applyInvPerm(in, pv); - return out; + return applyInvPerm(in, perm); } NdArrayRef InvPermAV::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { - PermVector pv; - const auto lctx = ctx->lctx(); - if (isOwner(ctx, perm.eltype())) { - pv = ring2pv(perm); - } - auto out = SecureInvPerm(ctx, in, getOwner(perm), pv); - return out; + return SecureInvPerm(ctx, in, perm, getOwner(perm)); } } // namespace spu::mpc::semi2k \ No newline at end of file diff --git a/libspu/mpc/semi2k/permute.h b/libspu/mpc/semi2k/permute.h index 20f35996..d581f73c 100644 --- a/libspu/mpc/semi2k/permute.h +++ b/libspu/mpc/semi2k/permute.h @@ -20,7 +20,7 @@ namespace spu::mpc::semi2k { class RandPermM : public RandKernel { public: - static constexpr char kBindName[] = "rand_perm_m"; + static constexpr const char* kBindName() { return "rand_perm_m"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -31,7 +31,7 @@ class RandPermM : public RandKernel { class PermAM : public PermKernel { public: - static constexpr char kBindName[] = "perm_am"; + static constexpr const char* kBindName() { return "perm_am"; } ce::CExpr latency() const override { return ce::N(); } @@ -43,7 +43,7 @@ class PermAM : public PermKernel { class PermAP : public PermKernel { public: - static constexpr char kBindName[] = "perm_ap"; + static constexpr const char* kBindName() { return "perm_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -55,7 +55,7 @@ class PermAP : public PermKernel { class InvPermAM : public PermKernel { public: - static constexpr char kBindName[] = "inv_perm_am"; + static constexpr const char* kBindName() { return "inv_perm_am"; } ce::CExpr latency() const override { return ce::N(); } @@ -67,7 +67,7 @@ class InvPermAM : public PermKernel { class InvPermAP : public PermKernel { public: - static constexpr char kBindName[] = "inv_perm_ap"; + static constexpr const char* kBindName() { return "inv_perm_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -79,7 +79,7 @@ class InvPermAP : public PermKernel { class InvPermAV : public PermKernel { public: - static constexpr char kBindName[] = "inv_perm_av"; + static constexpr const char* kBindName() { return "inv_perm_av"; } // communication is unbalanced Kind kind() const override { return Kind::Dynamic; } diff --git a/libspu/mpc/semi2k/prime_utils.cc b/libspu/mpc/semi2k/prime_utils.cc new file mode 100644 index 00000000..3d406f24 --- /dev/null +++ b/libspu/mpc/semi2k/prime_utils.cc @@ -0,0 +1,202 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "prime_utils.h" + +#include "type.h" + +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/semi2k/state.h" +#include "libspu/mpc/utils/gfmp.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::semi2k { + +NdArrayRef ProbConvRing2k(const NdArrayRef& inp_share, int rank, + size_t shr_width) { + SPU_ENFORCE(inp_share.eltype().isa()); + SPU_ENFORCE(rank >= 0 && rank <= 1); + + auto eltype = inp_share.eltype(); + NdArrayRef output_share(eltype, inp_share.shape()); + + auto ring_ty = eltype.as()->field(); + uint128_t shifted_bit = 1; + shifted_bit <<= shr_width; + auto mask = shifted_bit - 1; + // x mod p - 1 + // in our case p > 2^shr_width + + DISPATCH_ALL_FIELDS(ring_ty, [&]() { + const auto prime = ScalarTypeToPrime::prime; + ring2k_t prime_minus_one = (prime - 1); + NdArrayView inp(inp_share); + NdArrayView output_share_view(output_share); + pforeach(0, output_share.numel(), [&](int64_t i) { + output_share_view[i] = + rank == 0 ? ((inp[i] & mask) % prime_minus_one) + // numerical considerations here + // we wanted to work on ring 2k or field p - 1 + // however, if we do not add p -1 + // then the computation will resort to int128 + // due to the way computer works + : ((inp[i] & mask) + prime_minus_one - shifted_bit) % + prime_minus_one; + }); + }); + return output_share; +} + +NdArrayRef UnflattenBuffer(yacl::Buffer&& buf, const NdArrayRef& x) { + return NdArrayRef(std::make_shared(std::move(buf)), x.eltype(), + x.shape()); +} + +// P0 holds x,P1 holds y +// Beaver generates ab = c_0 + c_1 +// Give (a, c_0) to P0 +// Give (b, c_1) to P1 +std::tuple MulPrivPrep(KernelEvalContext* ctx, + const NdArrayRef& x) { + const auto field = x.eltype().as()->field(); + auto* beaver = ctx->getState()->beaver(); + + // generate beaver multiple triple. + NdArrayRef a_or_b; + NdArrayRef c; + + const size_t numel = x.shape().numel(); + auto [a_or_b_buf, c_buf] = beaver->MulPriv( + field, numel, // + x.eltype().isa() ? ElementType::kGfmp : ElementType::kRing); + SPU_ENFORCE(static_cast(a_or_b_buf.size()) == numel * SizeOf(field)); + SPU_ENFORCE(static_cast(c_buf.size()) == numel * SizeOf(field)); + + a_or_b = UnflattenBuffer(std::move(a_or_b_buf), x); + c = UnflattenBuffer(std::move(c_buf), x); + + return {std::move(a_or_b), std::move(c)}; +} + +// P0 holds x,P1 holds y +// Beaver generates ab = c_0 + c_1 +// Give (a, c_0) to P0 +// Give (b, c_1) to P1 +// +// - P0 sends (x+a) to P1 ; P1 sends (y+b) to P0 +// - P0 calculates z0 = x(y+b) + c0 ; P1 calculates z1 = -b(x+a) + c1 +NdArrayRef MulPriv(KernelEvalContext* ctx, const NdArrayRef& x) { + SPU_ENFORCE(x.eltype().isa()); + auto* comm = ctx->getState(); + + NdArrayRef a_or_b, c, xa_or_yb; + + std::tie(a_or_b, c) = MulPrivPrep(ctx, x); + + // P0 sends (x+a) to P1 ; P1 sends (y+b) to P0 + comm->sendAsync(comm->nextRank(), ring_add(a_or_b, x), "(x + a) or (y + b)"); + xa_or_yb = comm->recv(comm->prevRank(), x.eltype(), "(x + a) or (y + b)") + .reshape(x.shape()); + // note that our rings are commutative. + if (comm->getRank() == 0) { + ring_add_(c, ring_mul(std::move(xa_or_yb), x)); + } + if (comm->getRank() == 1) { + ring_sub_(c, ring_mul(std::move(xa_or_yb), a_or_b)); + } + return c; +} + +NdArrayRef MulPrivModMP(KernelEvalContext* ctx, const NdArrayRef& x) { + SPU_ENFORCE(x.eltype().isa()); + auto* comm = ctx->getState(); + + NdArrayRef a_or_b, c, xa_or_yb; + std::tie(a_or_b, c) = MulPrivPrep(ctx, x); + + comm->sendAsync(comm->nextRank(), gfmp_add_mod(a_or_b, x), "xa_or_yb"); + xa_or_yb = + comm->recv(comm->prevRank(), x.eltype(), "xa_or_yb").reshape(x.shape()); + + // note that our rings are commutative. + if (comm->getRank() == 0) { + gfmp_add_mod_(c, gfmp_mul_mod(std::move(xa_or_yb), x)); + } + if (comm->getRank() == 1) { + gfmp_sub_mod_(c, gfmp_mul_mod(std::move(xa_or_yb), a_or_b)); + } + return c; +} + +// We assume the input is ``positive'' +// Given h0 + h1 = h mod p and h < p / 2 +// Define b0 = 1{h0 >= p/2} +// b1 = 1{h1 >= p/2} +// Compute w = 1{h0 + h1 >= p} +// It can be proved that w = (b0 or b1) = not (not b0 and not b1) +NdArrayRef WrapBitModMP(KernelEvalContext* ctx, const NdArrayRef& x) { + // create a wrap bit NdArrayRef of the same shape as in + NdArrayRef b(x.eltype(), x.shape()); + + // for each element, we compute b = 1{h < p/2} for each private share piece + const auto numel = x.numel(); + const auto field = x.eltype().as()->field(); + + DISPATCH_ALL_FIELDS(field, [&]() { + ring2k_t prime = ScalarTypeToPrime::prime; + ring2k_t phalf = prime >> 1; + NdArrayView _x(x); + NdArrayView _b(b); + pforeach(0, numel, [&](int64_t idx) { + _b[idx] = static_cast(_x[idx] < phalf); + }); + + // do private mul + b = MulPriv(ctx, b.as(makeType(field))); + + // map 1 to 0 and 0 to 1, use 1 - x + if (ctx->getState()->getRank() == 0) { + pforeach(0, numel, [&](int64_t idx) { _b[idx] = 1 - _b[idx]; }); + } else { + pforeach(0, numel, [&](int64_t idx) { _b[idx] = -_b[idx]; }); + } + }); + + return b; +} +// Mersenne Prime share -> Ring2k share + +NdArrayRef ConvMP(KernelEvalContext* ctx, const NdArrayRef& h, + uint truncate_nbits) { + // calculate wrap bit + NdArrayRef w = WrapBitModMP(ctx, h); + const auto field = h.eltype().as()->field(); + const auto numel = h.numel(); + + // x = (h - p * w) mod 2^k + + NdArrayRef x(makeType(field), h.shape()); + DISPATCH_ALL_FIELDS(field, [&]() { + auto prime = ScalarTypeToPrime::prime; + NdArrayView h_view(h); + NdArrayView _x(x); + NdArrayView w_view(w); + pforeach(0, numel, [&](int64_t idx) { + _x[idx] = static_cast(h_view[idx] >> truncate_nbits) - + static_cast(prime >> truncate_nbits) * w_view[idx]; + }); + }); + return x; +} + +} // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/prime_utils.h b/libspu/mpc/semi2k/prime_utils.h new file mode 100644 index 00000000..a04acf3a --- /dev/null +++ b/libspu/mpc/semi2k/prime_utils.h @@ -0,0 +1,46 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "libspu/core/context.h" +#include "libspu/mpc/kernel.h" + +namespace spu::mpc::semi2k { +// Ring2k share -> Mersenne Prime - 1 share +// Given x0 + x1 = x mod 2^k +// Compute h0 + h1 = x mod p with probability > 1 - |x|/2^k +NdArrayRef ProbConvRing2k(const NdArrayRef& inp_share, int rank, + size_t shr_width); + +// Mul open private share +std::tuple MulPrivPrep(KernelEvalContext* ctx, + const NdArrayRef& x); + +// Note that [x] = (x_alice, x_bob) and x_alice + x_bob = x +// Note that we actually want to find the muliplication of x_alice and x_bob +// this function is currently achieved by doing (x_alice, 0) * (0, x_bob) +// optimization is possible. +NdArrayRef MulPrivModMP(KernelEvalContext* ctx, const NdArrayRef& x); +// We assume the input is ``positive'' +// Given h0 + h1 = h mod p and h < p / 2 +// Define b0 = 1{h0 >= p/2} +// b1 = 1{h1 >= p/2} +// Compute w = 1{h0 + h1 >= p} +// It can be proved that w = (b0 or b1) +NdArrayRef WrapBitModMP(KernelEvalContext* ctx, const NdArrayRef& x); + +// Mersenne Prime share -> Ring2k share +NdArrayRef ConvMP(KernelEvalContext* ctx, const NdArrayRef& h, + uint truncate_nbits); +} // namespace spu::mpc::semi2k \ No newline at end of file diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc index 3acd8344..33d6226b 100644 --- a/libspu/mpc/semi2k/protocol.cc +++ b/libspu/mpc/semi2k/protocol.cc @@ -20,6 +20,7 @@ #include "libspu/mpc/semi2k/arithmetic.h" #include "libspu/mpc/semi2k/boolean.h" #include "libspu/mpc/semi2k/conversion.h" +#include "libspu/mpc/semi2k/exp.h" #include "libspu/mpc/semi2k/permute.h" #include "libspu/mpc/semi2k/state.h" #include "libspu/mpc/semi2k/type.h" @@ -51,7 +52,7 @@ void regSemi2kProtocol(SPUContext* ctx, ctx->prot() ->regKernel< semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, // - semi2k::NotA, // + semi2k::NegateA, // semi2k::AddAP, semi2k::AddAA, // semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, // semi2k::MatMulAP, semi2k::MatMulAA, // @@ -76,6 +77,13 @@ void regSemi2kProtocol(SPUContext* ctx, if (lctx->WorldSize() == 2) { ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + + // only supports 2pc fm128 for now + if (ctx->getField() == FieldType::FM128 && + ctx->config().experimental_enable_exp_prime()) { + ctx->prot()->regKernel(); + } } // ctx->prot()->regKernel(); } diff --git a/libspu/mpc/semi2k/protocol_test.cc b/libspu/mpc/semi2k/protocol_test.cc index 66911344..eb1a6c60 100644 --- a/libspu/mpc/semi2k/protocol_test.cc +++ b/libspu/mpc/semi2k/protocol_test.cc @@ -25,7 +25,11 @@ #include "libspu/mpc/api_test.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.h" +#include "libspu/mpc/semi2k/exp.h" +#include "libspu/mpc/semi2k/prime_utils.h" #include "libspu/mpc/semi2k/state.h" +#include "libspu/mpc/semi2k/type.h" +#include "libspu/mpc/utils/gfmp.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -36,6 +40,12 @@ RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; conf.set_protocol(ProtocolKind::SEMI2K); conf.set_field(field); + if (field == FieldType::FM64) { + conf.set_fxp_fraction_bits(17); + } else if (field == FieldType::FM128) { + conf.set_fxp_fraction_bits(40); + } + conf.set_experimental_enable_exp_prime(true); return conf; } @@ -404,4 +414,173 @@ TEST_P(BeaverCacheTest, SquareA) { }); } +TEST_P(BeaverCacheTest, priv_mul_test) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + // only supports 2 party (not counting beaver) + if (npc != 2) { + return; + } + NdArrayRef ring2k_shr[2]; + + int64_t numel = 1; + FieldType field = conf.field(); + + std::vector real_vec(numel); + for (int64_t i = 0; i < numel; ++i) { + real_vec[i] = 2; + } + + auto rnd_msg = gfmp_zeros(field, {numel}); + + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView xmsg(rnd_msg); + pforeach(0, numel, [&](int64_t i) { xmsg[i] = std::round(real_vec[i]); }); + }); + + ring2k_shr[0] = rnd_msg; + ring2k_shr[1] = rnd_msg; + + NdArrayRef input, outp_pub; + NdArrayRef outp[2]; + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + KernelEvalContext kcontext(obj.get()); + + int rank = lctx->Rank(); + + outp[rank] = spu::mpc::semi2k::MulPrivModMP(&kcontext, ring2k_shr[rank]); + }); + auto got = gfmp_add_mod(outp[0], outp[1]); + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView got_view(got); + + double max_err = 0.0; + double min_err = 99.0; + for (int64_t i = 0; i < numel; ++i) { + double expected = real_vec[i] * real_vec[i]; + double got = static_cast(got_view[i]); + max_err = std::max(max_err, std::abs(expected - got)); + min_err = std::min(min_err, std::abs(expected - got)); + } + ASSERT_LE(min_err, 1e-3); + ASSERT_LE(max_err, 1e-3); + }); +} + +TEST_P(BeaverCacheTest, exp_mod_test) { + const RuntimeConfig& conf = std::get<1>(GetParam()); + FieldType field = conf.field(); + + DISPATCH_ALL_FIELDS(field, [&]() { + // exponents < 32 + ring2k_t exponents[5] = {10, 21, 27}; + + for (ring2k_t exponent : exponents) { + ring2k_t y = exp_mod(2, exponent); + ring2k_t prime = ScalarTypeToPrime::prime; + ring2k_t prime_minus_one = (prime - 1); + ring2k_t shifted_bit = 1; + shifted_bit <<= exponent; + EXPECT_EQ(y, shifted_bit % prime_minus_one); + } + }); +} + +TEST_P(BeaverCacheTest, ExpA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + // exp only supports 2 party (not counting beaver) + // only supports FM128 for now + // note not using ctx->hasKernel("exp_a") because we are testing kernel + // registration as well. + if (npc != 2 || conf.field() != FieldType::FM128) { + return; + } + auto fxp = conf.fxp_fraction_bits(); + + NdArrayRef ring2k_shr[2]; + + int64_t numel = 100; + FieldType field = conf.field(); + + // how to define and achieve high pricision for e^20 + std::uniform_real_distribution dist(-18.0, 15.0); + std::default_random_engine rd; + std::vector real_vec(numel); + for (int64_t i = 0; i < numel; ++i) { + // make the input a fixed point number, eliminate the fixed point encoding + // error + real_vec[i] = + static_cast(std::round((dist(rd) * (1L << fxp)))) / (1L << fxp); + } + + auto rnd_msg = ring_zeros(field, {numel}); + + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView xmsg(rnd_msg); + pforeach(0, numel, [&](int64_t i) { + xmsg[i] = std::round(real_vec[i] * (1L << fxp)); + }); + }); + + ring2k_shr[0] = ring_rand(field, rnd_msg.shape()) + .as(makeType(field)); + ring2k_shr[1] = ring_sub(rnd_msg, ring2k_shr[0]) + .as(makeType(field)); + + NdArrayRef outp_pub; + NdArrayRef outp[2]; + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + KernelEvalContext kcontext(obj.get()); + + int rank = lctx->Rank(); + + size_t bytes = lctx->GetStats()->sent_bytes; + size_t action = lctx->GetStats()->sent_actions; + + spu::mpc::semi2k::ExpA exp; + outp[rank] = exp.proc(&kcontext, ring2k_shr[rank]); + + bytes = lctx->GetStats()->sent_bytes - bytes; + action = lctx->GetStats()->sent_actions - action; + SPDLOG_INFO("ExpA ({}) for n = {}, sent {} MiB ({} B per), actions {}", + field, numel, bytes * 1. / 1024. / 1024., bytes * 1. / numel, + action); + }); + assert(outp[0].eltype() == ring2k_shr[0].eltype()); + auto got = ring_add(outp[0], outp[1]); + ring_print(got, "exp result"); + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView got_view(got); + + double max_err = 0.0; + for (int64_t i = 0; i < numel; ++i) { + double expected = std::exp(real_vec[i]); + expected = static_cast(std::round((expected * (1L << fxp)))) / + (1L << fxp); + double got = static_cast(got_view[i]) / (1L << fxp); + // cout left here for future improvement + std::cout << "expected: " << fmt::format("{0:f}", expected) + << ", got: " << fmt::format("{0:f}", got) << std::endl; + std::cout << "expected: " + << fmt::format("{0:b}", + static_cast(expected * (1L << fxp))) + << ", got: " << fmt::format("{0:b}", got_view[i]) << std::endl; + max_err = std::max(max_err, std::abs(expected - got)); + } + ASSERT_LE(max_err, 1e-0); + }); +} } // namespace spu::mpc::test diff --git a/libspu/mpc/semi2k/state.h b/libspu/mpc/semi2k/state.h index 7811d7f3..f9b394ca 100644 --- a/libspu/mpc/semi2k/state.h +++ b/libspu/mpc/semi2k/state.h @@ -33,7 +33,7 @@ class Semi2kState : public State { Semi2kState() = default; public: - static constexpr char kBindName[] = "Semi2kState"; + static constexpr const char* kBindName() { return "Semi2kState"; } explicit Semi2kState(const RuntimeConfig& conf, const std::shared_ptr& lctx) { diff --git a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc index b49e34c2..496a7059 100644 --- a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc +++ b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc @@ -69,6 +69,10 @@ bool verifyCost(Kernel* kernel, std::string_view name, FieldType field, return succeed; } +Value hack_make_p(SPUContext* ctx, uint128_t init, const Shape& shape) { + return dynDispatch(ctx, "make_p", init, shape); +} + } // namespace TEST_P(BooleanTest, NotB) { @@ -88,7 +92,9 @@ TEST_P(BooleanTest, NotB) { auto r_b = dynDispatch(obj.get(), "not_b", b0); auto cost = obj->prot()->getState()->getStats() - prev; auto r_p = b2p(obj.get(), r_b); - auto r_pp = not_p(obj.get(), p0); + + auto ones = hack_make_p(obj.get(), -1, kShape); + auto r_pp = xor_pp(obj.get(), p0, ones); /* THEN */ EXPECT_VALUE_EQ(r_p, r_pp); @@ -242,7 +248,7 @@ TEST_P(ConversionTest, BitLT) { auto re = b2p(obj.get(), tmp); const auto field = p0.storage_type().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; size_t numel = kShape.numel(); @@ -291,7 +297,7 @@ TEST_P(ConversionTest, BitLE) { auto re = b2p(obj.get(), tmp); const auto field = p0.storage_type().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; size_t numel = kShape.numel(); NdArrayView p0_data(p0.data()); diff --git a/libspu/mpc/spdz2k/arithmetic.cc b/libspu/mpc/spdz2k/arithmetic.cc index 6e902a61..ceb7e75a 100644 --- a/libspu/mpc/spdz2k/arithmetic.cc +++ b/libspu/mpc/spdz2k/arithmetic.cc @@ -47,9 +47,9 @@ NdArrayRef CastRing(const NdArrayRef& in, FieldType out_field) { const auto in_field = in_ty->field(); auto out = ring_zeros(out_field, in.shape()); - return DISPATCH_ALL_FIELDS(in_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(in_field, [&]() { NdArrayView _in(in); - return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(out_field, [&]() { NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { _out[idx] = static_cast(_in[idx]); @@ -106,7 +106,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { // - https://eprint.iacr.org/2019/599.pdf // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison // operations. - auto x = ring_rshift(prg_state->genPriv(field, shape), 2) + auto x = ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); auto x_mac = beaver->AuthArrayRef(x, field, k, s); return makeAShare(x, x_mac, field); @@ -212,10 +212,8 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { return res; } -NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { +NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); - const auto key = ctx->getState()->key(); - auto* comm = ctx->getState(); // in const auto& x = getValueShare(in); @@ -225,14 +223,6 @@ NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto neg_x = ring_neg(x); auto neg_x_mac = ring_neg(x_mac); - // add public M-1 - const auto& neg_ones = ring_not(ring_zeros(field, in.shape())); - if (comm->getRank() == 0) { - ring_add_(neg_x, neg_ones); - } - const auto& ones = ring_ones(field, in.shape()); - ring_sub_(neg_x_mac, ring_mul(ones, key)); - return makeAShare(neg_x, neg_x_mac, field); } @@ -296,7 +286,7 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { auto* comm = ctx->getState(); auto* beaver = ctx->getState()->beaver(); const auto key = ctx->getState()->key(); - const size_t k = ctx->getState()->k(); + const int64_t k = ctx->getState()->k(); const size_t s = ctx->getState()->s(); // 1. Generate a random, shared value [r] @@ -305,8 +295,8 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { // 2. Locally construct [y] const auto& x = getValueShare(in); const auto& x_mac = getMacShare(in); - auto y = ring_add(x, ring_lshift(r, k)); - auto y_mac = ring_add(x_mac, ring_lshift(r_mac, k)); + auto y = ring_add(x, ring_lshift(r, {k})); + auto y_mac = ring_add(x_mac, ring_lshift(r_mac, {k})); // 3. Open the value auto plain_y = comm->allReduce(ReduceOp::ADD, y, kBindName); @@ -334,7 +324,7 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { static NdArrayRef wrap_lshift_a(SPUContext* ctx, const NdArrayRef& x, size_t k) { - return UnwrapValue(lshift_a(ctx, WrapValue(x), k)); + return UnwrapValue(lshift_a(ctx, WrapValue(x), {static_cast(k)})); } static NdArrayRef wrap_add_aa(SPUContext* ctx, const NdArrayRef& x, @@ -490,7 +480,7 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, // open e, f auto res = vmap({e, f}, [&](const NdArrayRef& s) { - return comm->allReduce(ReduceOp::ADD, s, kBindName); + return comm->allReduce(ReduceOp::ADD, s, kBindName()); }); auto p_e = std::move(res[0]); auto p_f = std::move(res[1]); @@ -557,7 +547,7 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, // open x-a & y-b auto res = vmap({ring_sub(x, a), ring_sub(y, b)}, [&](const NdArrayRef& s) { - return comm->allReduce(ReduceOp::ADD, s, kBindName); + return comm->allReduce(ReduceOp::ADD, s, kBindName()); }); auto p_e = std::move(res[0]); auto p_f = std::move(res[1]); @@ -579,9 +569,8 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; // in const auto& x = getValueShare(in); @@ -617,7 +606,9 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& in, beaver->BatchOpen(ring_sub(x, r), ring_sub(x_mac, r_mac), k, s); SPU_ENFORCE(beaver->BatchMacCheck(x_r, check_mac, k, s)); size_t bit_len = SizeOf(field) * 8; - auto tr_x_r = ring_arshift(ring_lshift(x_r, bit_len - k), bit_len - k + bits); + auto tr_x_r = + ring_arshift(ring_lshift(x_r, {static_cast(bit_len - k)}), + {static_cast(bit_len - k + bits)}); ring_bitmask_(tr_x_r, 0, k); // res = [x-r] + [r], which [*] is truncation operation. diff --git a/libspu/mpc/spdz2k/arithmetic.h b/libspu/mpc/spdz2k/arithmetic.h index deb782d5..46765787 100644 --- a/libspu/mpc/spdz2k/arithmetic.h +++ b/libspu/mpc/spdz2k/arithmetic.h @@ -22,7 +22,7 @@ NdArrayRef GetMacShare(KernelEvalContext* ctx, const NdArrayRef& in); class RandA : public RandKernel { public: - static constexpr char kBindName[] = "rand_a"; + static constexpr const char* kBindName() { return "rand_a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -33,7 +33,7 @@ class RandA : public RandKernel { class P2A : public UnaryKernel { public: - static constexpr char kBindName[] = "p2a"; + static constexpr const char* kBindName() { return "p2a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -44,7 +44,7 @@ class P2A : public UnaryKernel { class A2P : public UnaryKernel { public: - static constexpr char kBindName[] = "a2p"; + static constexpr const char* kBindName() { return "a2p"; } Kind kind() const override { return Kind::Dynamic; } @@ -53,7 +53,7 @@ class A2P : public UnaryKernel { class A2V : public RevealToKernel { public: - static constexpr char kBindName[] = "a2v"; + static constexpr const char* kBindName() { return "a2v"; } Kind kind() const override { return Kind::Dynamic; } @@ -67,7 +67,7 @@ class A2V : public RevealToKernel { class V2A : public UnaryKernel { public: - static constexpr char kBindName[] = "v2a"; + static constexpr const char* kBindName() { return "v2a"; } Kind kind() const override { return Kind::Dynamic; } @@ -78,9 +78,9 @@ class V2A : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; -class NotA : public UnaryKernel { +class NegateA : public UnaryKernel { public: - static constexpr char kBindName[] = "not_a"; + static constexpr const char* kBindName() { return "negate_a"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -94,7 +94,7 @@ class NotA : public UnaryKernel { //////////////////////////////////////////////////////////////////// class AddAP : public BinaryKernel { public: - static constexpr char kBindName[] = "add_ap"; + static constexpr const char* kBindName() { return "add_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -106,7 +106,7 @@ class AddAP : public BinaryKernel { class AddAA : public BinaryKernel { public: - static constexpr char kBindName[] = "add_aa"; + static constexpr const char* kBindName() { return "add_aa"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -121,7 +121,7 @@ class AddAA : public BinaryKernel { //////////////////////////////////////////////////////////////////// class MulAP : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_ap"; + static constexpr const char* kBindName() { return "mul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -133,7 +133,7 @@ class MulAP : public BinaryKernel { class MulAA : public BinaryKernel { public: - static constexpr char kBindName[] = "mul_aa"; + static constexpr const char* kBindName() { return "mul_aa"; } Kind kind() const override { return Kind::Dynamic; } @@ -146,7 +146,7 @@ class MulAA : public BinaryKernel { //////////////////////////////////////////////////////////////////// class MatMulAP : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_ap"; + static constexpr const char* kBindName() { return "mmul_ap"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -158,7 +158,7 @@ class MatMulAP : public MatmulKernel { class MatMulAA : public MatmulKernel { public: - static constexpr char kBindName[] = "mmul_aa"; + static constexpr const char* kBindName() { return "mmul_aa"; } // TODO(jint) express M, N, K Kind kind() const override { return Kind::Dynamic; } @@ -173,14 +173,14 @@ class MatMulAA : public MatmulKernel { class LShiftA : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_a"; + static constexpr const char* kBindName() { return "lshift_a"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; // Refer to: @@ -189,7 +189,7 @@ class LShiftA : public ShiftKernel { // - https://eprint.iacr.org/2019/599.pdf class TruncA : public TruncAKernel { public: - static constexpr char kBindName[] = "trunc_a"; + static constexpr const char* kBindName() { return "trunc_a"; } Kind kind() const override { return Kind::Dynamic; } diff --git a/libspu/mpc/spdz2k/beaver/BUILD.bazel b/libspu/mpc/spdz2k/beaver/BUILD.bazel index 5ea2a3f6..b7f4d80e 100644 --- a/libspu/mpc/spdz2k/beaver/BUILD.bazel +++ b/libspu/mpc/spdz2k/beaver/BUILD.bazel @@ -80,7 +80,7 @@ spu_cc_library( "//libspu/mpc/spdz2k/ot:tiny_ot", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/kernel/algorithms:ot_store", + "@yacl//yacl/kernel/type:ot_store", "@yacl//yacl/link", "@yacl//yacl/utils:matrix_utils", "@yacl//yacl/utils:serialize", diff --git a/libspu/mpc/spdz2k/beaver/beaver_test.cc b/libspu/mpc/spdz2k/beaver/beaver_test.cc index 0395ee21..c1508522 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_test.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_test.cc @@ -126,7 +126,7 @@ TEST_P(BeaverTest, AuthAnd) { EXPECT_TRUE(ring_all_equal(ring_mul(sum_c, sum_key), sum_c_mac)) << sum_c << sum_key << sum_c_mac; - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _valid_a(valid_a); NdArrayView _valid_b(valid_b); NdArrayView _valid_c(valid_c); @@ -314,7 +314,8 @@ TEST_P(BeaverTest, AuthTrunc) { const size_t bit_len = SizeOf(kField) * 8; auto trunc_sum_a = - ring_arshift(ring_lshift(sum_a, bit_len - k), bit_len - k + kBits); + ring_arshift(ring_lshift(sum_a, {static_cast(bit_len - k)}), + {static_cast(bit_len - k + kBits)}); ring_bitmask_(trunc_sum_a, 0, k); EXPECT_TRUE(ring_all_equal(trunc_sum_a, ring_bitmask(sum_b, 0, k))) @@ -386,7 +387,7 @@ TEST_P(BeaverTest, AuthDot) { << sum_c << sum_key << sum_c_mac; auto res = ring_mmul(sum_a, sum_b); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _sum_a(sum_a); NdArrayView _sum_c(sum_c); NdArrayView _res(res); diff --git a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc index 3c7a5ca7..3ab9fd0e 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc @@ -49,7 +49,7 @@ uint128_t BeaverTfpUnsafe::InitSpdzKey(FieldType field, size_t s) { const int64_t size = 1; auto a = prgCreateArray(field, {size}, seed_, &counter_, &desc); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _a(a); if (comm_->getRank() == 0) { auto t = tp_.adjustSpdzKey(desc); @@ -260,8 +260,10 @@ std::pair BeaverTfpUnsafe::BatchOpen( // Open the low k_bits only // value = value + r_val * 2^k // mac = mac + r_mac * 2^k - auto masked_val = ring_add(value, ring_lshift(r_val, k)); - auto masked_mac = ring_add(mac, ring_lshift(r_mac, k)); + auto masked_val = + ring_add(value, ring_lshift(r_val, {static_cast(k)})); + auto masked_mac = + ring_add(mac, ring_lshift(r_mac, {static_cast(k)})); auto open_val = comm_->allReduce(ReduceOp::ADD, masked_val, kBindName); return {open_val, masked_mac}; diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc index f0481944..f8d090fc 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc @@ -23,7 +23,7 @@ #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" #include "yacl/kernel/algorithms/base_ot.h" -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "yacl/utils/serialize.h" #include "libspu/mpc/common/prg_tensor.h" @@ -68,7 +68,7 @@ NdArrayRef ring_sqrt2k(const NdArrayRef& x, size_t bits = 0) { } auto ret = ring_zeros(field, x.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _ret(ret); NdArrayView _x(x); @@ -101,7 +101,7 @@ NdArrayRef ring_inv2k(const NdArrayRef& x, size_t bits = 0) { } auto ret = ring_zeros(field, x.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _ret(ret); NdArrayView _x(x); @@ -118,7 +118,7 @@ std::vector ring_cast_vector_boolean(const NdArrayRef& x) { const auto field = x.eltype().as()->field(); std::vector res(x.numel()); - DISPATCH_ALL_FIELDS(field, "RingOps", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); yacl::parallel_for(0, x.numel(), 4096, [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { @@ -192,7 +192,7 @@ uint128_t BeaverTinyOt::InitSpdzKey(FieldType, size_t s) { // - https://eprint.iacr.org/2018/482.pdf NdArrayRef BeaverTinyOt::AuthArrayRef(const NdArrayRef& x, FieldType field, size_t k, size_t s) { - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; // 1. l_ = max(l, r + s, 2s) @@ -346,7 +346,7 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, // Generate authorize bits in the form of B-Share NdArrayRef spdz_choices(makeType(field), {tinyot_num * 3 + sigma}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; auto _size = auth_abcr.choices.size(); NdArrayView _spdz_choices(spdz_choices); @@ -392,7 +392,7 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, auto seed = GenSharedSeed(comm_); auto prg = yacl::crypto::Prg(seed); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _check_spdz_bit(check_spdz_bit); NdArrayView _check_spdz_mac(check_spdz_mac); @@ -546,7 +546,7 @@ BeaverTinyOt::Pair_Pair BeaverTinyOt::AuthTrunc(FieldType field, NdArrayRef tr_val(b_val.eltype(), shape); NdArrayRef tr_mac(b_val.eltype(), shape); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using PShrT = ring2k_t; NdArrayView _val(b_val); NdArrayView _mac(b_mac); @@ -642,7 +642,7 @@ BeaverTinyOt::Pair BeaverTinyOt::AuthRandBit(FieldType field, SPU_ENFORCE(ring_all_equal(ring_bitmask(square, 0, 1), ones)); auto root = ring_sqrt2k(square, k + 2); auto root_inv = ring_inv2k(root, k + 2); - auto root_inv_div2 = ring_rshift(root_inv, 1); + auto root_inv_div2 = ring_rshift(root_inv, {1}); auto d = ring_mul(root_inv_div2, y); auto d_mac = ring_mul(root_inv_div2, y_mac); @@ -751,17 +751,19 @@ std::pair BeaverTinyOt::BatchOpen( // Open the low k_bits only // value = value + r * 2^k // mac = mac + r_mac * 2^k - auto masked_val = ring_add(value, ring_lshift(r_val, k)); - auto masked_mac = ring_add(mac, ring_lshift(r_mac, k)); + auto masked_val = + ring_add(value, ring_lshift(r_val, {static_cast(k)})); + auto masked_mac = + ring_add(mac, ring_lshift(r_mac, {static_cast(k)})); - // Because we would use Maccheck to comfirm the open value. + // Because we would use Maccheck to confirm the open value. // Thus, we don't need commit them. auto open_val = comm_->allReduce(ReduceOp::ADD, masked_val, kBindName); return {open_val, masked_mac}; } void BeaverTinyOt::rotSend(FieldType field, NdArrayRef* q0, NdArrayRef* q1) { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPDLOG_DEBUG("rotSend start with numel {}", q0->numel()); @@ -784,7 +786,7 @@ void BeaverTinyOt::rotSend(FieldType field, NdArrayRef* q0, NdArrayRef* q1) { // todo: use dynamic_bitset instead of ArrayRef for `a` to improve performance void BeaverTinyOt::rotRecv(FieldType field, const NdArrayRef& a, NdArrayRef* s) { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPDLOG_DEBUG("rotRecv start with numel {}", a.numel()); @@ -814,7 +816,7 @@ void BeaverTinyOt::rotRecv(FieldType field, const NdArrayRef& a, // SPDZ2k: Efficient MPC mod 2k for Dishonest Majority // - https://eprint.iacr.org/2018/482.pdf NdArrayRef BeaverTinyOt::voleSend(FieldType field, const NdArrayRef& x) { - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); @@ -831,7 +833,7 @@ NdArrayRef BeaverTinyOt::voleSend(FieldType field, const NdArrayRef& x) { } NdArrayRef BeaverTinyOt::voleRecv(FieldType field, const NdArrayRef& alpha) { - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); @@ -915,7 +917,7 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthMul(FieldType field, size_t s) { auto _size = shape.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPDLOG_DEBUG("AuthMul start..."); diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.h b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h index b0d35372..6f8ec9ad 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.h +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h @@ -14,7 +14,7 @@ #pragma once -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "yacl/link/context.h" #include "libspu/mpc/common/prg_state.h" diff --git a/libspu/mpc/spdz2k/beaver/trusted_party.cc b/libspu/mpc/spdz2k/beaver/trusted_party.cc index 978f1eee..287d1f76 100644 --- a/libspu/mpc/spdz2k/beaver/trusted_party.cc +++ b/libspu/mpc/spdz2k/beaver/trusted_party.cc @@ -248,9 +248,10 @@ std::vector TrustedParty::adjustAuthTrunc( ring_add_(r0[0], ring_sub(rs[0], t_rs)); // r0[1] += (rs[0] >> bits) - rs[1]; - const size_t bit_len = SizeOf(field) * 8; + const int64_t bit_len = SizeOf(field) * 8; auto tr_rs0 = - ring_arshift(ring_lshift(rs[0], bit_len - k), bit_len - k + bits); + ring_arshift(ring_lshift(rs[0], {static_cast(bit_len - k)}), + {static_cast(bit_len - k + bits)}); ring_bitmask_(tr_rs0, 0, k); ring_add_(r0[1], ring_sub(tr_rs0, rs[1])); diff --git a/libspu/mpc/spdz2k/boolean.cc b/libspu/mpc/spdz2k/boolean.cc index 226de5ad..7f230a28 100644 --- a/libspu/mpc/spdz2k/boolean.cc +++ b/libspu/mpc/spdz2k/boolean.cc @@ -38,7 +38,7 @@ NdArrayRef P2Value(FieldType out_field, const NdArrayRef& in, size_t k, size_t new_nbits = 0) { const auto* in_ty = in.eltype().as(); const auto in_field = in_ty->field(); - return DISPATCH_ALL_FIELDS(in_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(in_field, [&]() { using PShrT = ring2k_t; size_t valid_nbits = k; @@ -52,7 +52,7 @@ NdArrayRef P2Value(FieldType out_field, const NdArrayRef& in, size_t k, out_shape.back() *= valid_nbits; auto out = ring_zeros(out_field, out_shape); - return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(out_field, [&]() { using BShrT = ring2k_t; NdArrayView _in(in); NdArrayView _out(out); @@ -279,10 +279,10 @@ NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // 2. Maccheck SPU_ENFORCE(beaver_ptr->BatchMacCheck(pub, mac, 1, s)); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using BShrT = ring2k_t; auto& value = pub; - return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(out_field, [&]() { using PShrT = ring2k_t; NdArrayRef out(makeType(out_field), in.shape()); @@ -374,7 +374,7 @@ NdArrayRef BitrevB::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto ret = x.clone(); auto ret_mac = x_mac.clone(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); NdArrayView _ret_mac(ret_mac); NdArrayView _x(x); @@ -515,37 +515,48 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { + SPU_ENFORCE(bits.size() == 1, "only support splat shift bits, but got {}", + bits); + const auto field = in.eltype().as()->field(); const auto k = ctx->getState()->k(); const size_t nbits = in.eltype().as()->nbits(); - size_t res_nbits = nbits + bits; + size_t bit = bits[0]; + size_t res_nbits = nbits + bit; - if (bits >= k) { + if (bit >= k) { res_nbits = 1; } else if (res_nbits > k) { res_nbits = k; } - auto [ret, ret_mac] = LShiftBImpl(in, bits, k); + auto [ret, ret_mac] = LShiftBImpl(in, bit, k); return makeBShare(ret, ret_mac, field, res_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { + SPU_ENFORCE(bits.size() == 1, "only support splat shift bits, but got {}", + bits); + + size_t bit = bits[0]; const auto field = in.eltype().as()->field(); const auto nbits = in.eltype().as()->nbits(); - size_t new_nbis = nbits > bits ? nbits - bits : 1; - auto [ret, ret_mac] = RShiftBImpl(in, bits); + size_t new_nbis = nbits > bit ? nbits - bit : 1; + auto [ret, ret_mac] = RShiftBImpl(in, bit); return makeBShare(ret, ret_mac, field, new_nbis); } static NdArrayRef wrap_rshift_b(SPUContext* ctx, const NdArrayRef& x, - size_t bits) { + const Sizes& bits) { return UnwrapValue(rshift_b(ctx, WrapValue(x), bits)); } NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { + SPU_ENFORCE(bits.size() == 1, "only support splat shift bits, but got {}", + bits); + const auto field = in.eltype().as()->field(); const auto k = ctx->getState()->k(); const auto nbits = in.eltype().as()->nbits(); @@ -553,7 +564,7 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, if (nbits != k) { return wrap_rshift_b(ctx->sctx(), in, bits); } else { - auto [ret, ret_mac] = ARShiftBImpl(in, bits, k); + auto [ret, ret_mac] = ARShiftBImpl(in, bits[0], k); return makeBShare(ret, ret_mac, field, k); } } @@ -565,7 +576,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto k = ctx->getState()->k(); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; if (in.eltype().isa()) { @@ -612,7 +623,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto k = ctx->getState()->k(); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; if (in.eltype().isa()) { diff --git a/libspu/mpc/spdz2k/boolean.h b/libspu/mpc/spdz2k/boolean.h index 6d2d9c9f..bc2fd92a 100644 --- a/libspu/mpc/spdz2k/boolean.h +++ b/libspu/mpc/spdz2k/boolean.h @@ -21,7 +21,7 @@ namespace spu::mpc::spdz2k { class CommonTypeB : public Kernel { public: - static constexpr char kBindName[] = "common_type_b"; + static constexpr const char* kBindName() { return "common_type_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -32,7 +32,7 @@ class CommonTypeB : public Kernel { class CastTypeB : public CastTypeKernel { public: - static constexpr char kBindName[] = "cast_type_b"; + static constexpr const char* kBindName() { return "cast_type_b"; } Kind kind() const override { return Kind::Dynamic; } @@ -42,7 +42,7 @@ class CastTypeB : public CastTypeKernel { class B2P : public UnaryKernel { public: - static constexpr char kBindName[] = "b2p"; + static constexpr const char* kBindName() { return "b2p"; } Kind kind() const override { return Kind::Dynamic; } @@ -51,7 +51,7 @@ class B2P : public UnaryKernel { class P2B : public UnaryKernel { public: - static constexpr char kBindName[] = "p2b"; + static constexpr const char* kBindName() { return "p2b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -62,7 +62,7 @@ class P2B : public UnaryKernel { class NotB : public UnaryKernel { public: - static constexpr char kBindName[] = "not_b"; + static constexpr const char* kBindName() { return "not_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -73,7 +73,7 @@ class NotB : public UnaryKernel { class BitrevB : public BitrevKernel { public: - static constexpr char kBindName[] = "bitrev_b"; + static constexpr const char* kBindName() { return "bitrev_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -85,7 +85,7 @@ class BitrevB : public BitrevKernel { class AndBP : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bp"; + static constexpr const char* kBindName() { return "and_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -97,7 +97,7 @@ class AndBP : public BinaryKernel { class AndBB : public BinaryKernel { public: - static constexpr char kBindName[] = "and_bb"; + static constexpr const char* kBindName() { return "and_bb"; } ce::CExpr latency() const override { // rotate : 1 @@ -115,7 +115,7 @@ class AndBB : public BinaryKernel { class XorBP : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bp"; + static constexpr const char* kBindName() { return "xor_bp"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -127,7 +127,7 @@ class XorBP : public BinaryKernel { class XorBB : public BinaryKernel { public: - static constexpr char kBindName[] = "xor_bb"; + static constexpr const char* kBindName() { return "xor_bb"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -139,43 +139,43 @@ class XorBB : public BinaryKernel { class LShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "lshift_b"; + static constexpr const char* kBindName() { return "lshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class RShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "rshift_b"; + static constexpr const char* kBindName() { return "rshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class ARShiftB : public ShiftKernel { public: - static constexpr char kBindName[] = "arshift_b"; + static constexpr const char* kBindName() { return "arshift_b"; } ce::CExpr latency() const override { return ce::Const(0); } ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class BitIntlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitintl_b"; + static constexpr const char* kBindName() { return "bitintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -187,7 +187,7 @@ class BitIntlB : public BitSplitKernel { class BitDeintlB : public BitSplitKernel { public: - static constexpr char kBindName[] = "bitdeintl_b"; + static constexpr const char* kBindName() { return "bitdeintl_b"; } ce::CExpr latency() const override { return ce::Const(0); } diff --git a/libspu/mpc/spdz2k/conversion.cc b/libspu/mpc/spdz2k/conversion.cc index 8f8ded08..122d1c26 100644 --- a/libspu/mpc/spdz2k/conversion.cc +++ b/libspu/mpc/spdz2k/conversion.cc @@ -113,7 +113,7 @@ CircuitBasicBlock MakeSPDZBasicBlock(SPUContext* ctx) { cbb._and = [=](T const& x, T const& y) -> T { COMMUTATIVE_DISPATCH(and_pp, and_bp, and_bb); }; - cbb.lshift = [=](T const& x, size_t bits) -> T { + cbb.lshift = [=](T const& x, const Sizes& bits) -> T { if (_IsP(x)) { return lshift_p(ctx, x, bits); } else if (_IsB(x)) { @@ -121,7 +121,7 @@ CircuitBasicBlock MakeSPDZBasicBlock(SPUContext* ctx) { } SPU_THROW("unsupported op x={}", x); }; - cbb.rshift = [=](T const& x, size_t bits) -> T { + cbb.rshift = [=](T const& x, const Sizes& bits) -> T { if (_IsP(x)) { return rshift_p(ctx, x, bits); } else if (_IsB(x)) { @@ -191,7 +191,7 @@ NdArrayRef Bit2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { ring_bitmask_(c, 0, 1); // 5. [x] = c + [r] - 2 * c * [r] - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _c(c); NdArrayView _r(r); NdArrayView _r_mac(r_mac); @@ -292,7 +292,7 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { ring_bitmask_(c, 0, 1); // 4. [x] = c + [r] - 2 * c * [r] - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayRef out(makeType(field, true), out_shape); NdArrayRef expand_out(makeType(field, true), _in.shape()); @@ -354,8 +354,8 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // then set r = \sum r_i 2^{i} for (int64_t i = 0; i < k; ++i) { auto [_r_i, _r_i_mac] = beaver->AuthRandBit(field, in.shape(), k, s); - ring_add_(_r_val, ring_lshift(_r_i, i)); - ring_add_(_r_mac, ring_lshift(_r_i_mac, i)); + ring_add_(_r_val, ring_lshift(_r_i, {i})); + ring_add_(_r_mac, ring_lshift(_r_i_mac, {i})); // record r_i & r_i_mac _r_vec.emplace_back(std::move(_r_i)); _r_mac_vec.emplace_back(std::move(_r_i_mac)); @@ -381,8 +381,8 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto ty = makeType(field); for (int64_t i = 0; i < k - 1; ++i) { - ring_add_(_ar, ring_lshift(_r_vec[i], i)); - ring_add_(_ar_mac, ring_lshift(_r_mac_vec[i], i)); + ring_add_(_ar, ring_lshift(_r_vec[i], {i})); + ring_add_(_ar_mac, ring_lshift(_r_mac_vec[i], {i})); auto at_r_i = makeAShare(_r_vec[i], _r_mac_vec[i], field); auto bt_r_i = wrap_a2bit(ctx->sctx(), at_r_i); @@ -415,8 +415,8 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // d = a - a' auto _au = getValueShare(au); auto _au_mac = getMacShare(au); - auto _aa = ring_sub(ring_lshift(_au, k - 1), _ar); - auto _aa_mac = ring_sub(ring_lshift(_au_mac, k - 1), _ar_mac); + auto _aa = ring_sub(ring_lshift(_au, {k - 1}), _ar); + auto _aa_mac = ring_sub(ring_lshift(_au_mac, {k - 1}), _ar_mac); if (comm->getRank() == 0) { ring_add_(_aa, _c_open); } @@ -426,18 +426,18 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // 7. let e = d + 2^{k-1} b, then open e auto [_b, _b_mac] = beaver->AuthRandBit(field, in.shape(), k, s); - auto _e = ring_add(_d, ring_lshift(_b, k - 1)); - auto _e_mac = ring_add(_d_mac, ring_lshift(_b_mac, k - 1)); + auto _e = ring_add(_d, ring_lshift(_b, {k - 1})); + auto _e_mac = ring_add(_d_mac, ring_lshift(_b_mac, {k - 1})); auto [e_open, e_zero_mac] = beaver->BatchOpen(_e, _e_mac, k, s); SPU_ENFORCE(beaver->BatchMacCheck(e_open, e_zero_mac, k, s)); // 8. e' be the most significant bit of e - auto _ee = ring_bitmask(ring_rshift(e_open, k - 1), 0, 1); + auto _ee = ring_bitmask(ring_rshift(e_open, {k - 1}), 0, 1); // 9. output e_{k-1} + b - 2 e_{k-1} b - auto _ret = ring_sub(_b, ring_lshift(ring_mul(_b, _ee), 1)); - auto _ret_mac = ring_sub(_b_mac, ring_lshift(ring_mul(_b_mac, _ee), 1)); + auto _ret = ring_sub(_b, ring_lshift(ring_mul(_b, _ee), {1})); + auto _ret_mac = ring_sub(_b_mac, ring_lshift(ring_mul(_b_mac, _ee), {1})); if (comm->getRank() == 0) { ring_add_(_ret, _ee); } diff --git a/libspu/mpc/spdz2k/conversion.h b/libspu/mpc/spdz2k/conversion.h index 29209cdf..8ffd4276 100644 --- a/libspu/mpc/spdz2k/conversion.h +++ b/libspu/mpc/spdz2k/conversion.h @@ -27,7 +27,7 @@ using ce::N; class A2B : public UnaryKernel { public: - static constexpr char kBindName[] = "a2b"; + static constexpr const char* kBindName() { return "a2b"; } CExpr latency() const override { // 1 * AddBB : log(k) + 1 @@ -48,7 +48,7 @@ class A2B : public UnaryKernel { class A2Bit : public UnaryKernel { public: - static constexpr char kBindName[] = "a2bit"; + static constexpr const char* kBindName() { return "a2bit"; } CExpr latency() const override { // 1 * AddBB : log(k) + 1 @@ -69,7 +69,7 @@ class A2Bit : public UnaryKernel { class Bit2A : public UnaryKernel { public: - static constexpr char kBindName[] = "bit2a"; + static constexpr const char* kBindName() { return "bit2a"; } CExpr latency() const override { // 1 * AddBB : log(k) + 1 @@ -90,7 +90,7 @@ class Bit2A : public UnaryKernel { class BitDec : public UnaryKernel { public: - static constexpr char kBindName[] = "bit_dec"; + static constexpr const char* kBindName() { return "bit_dec"; } CExpr latency() const override { // 1 * AddBB : log(k) + 1 @@ -112,7 +112,7 @@ class BitDec : public UnaryKernel { // https://encrypto.de/papers/DSZ15.pdf class B2A : public UnaryKernel { public: - static constexpr char kBindName[] = "b2a"; + static constexpr const char* kBindName() { return "b2a"; } CExpr latency() const override { // 2 * rotate : 2 @@ -133,7 +133,7 @@ class B2A : public UnaryKernel { class MSB : public UnaryKernel { public: - static constexpr char kBindName[] = "msb_a2b"; + static constexpr const char* kBindName() { return "msb_a2b"; } Kind kind() const override { return Kind::Dynamic; } @@ -142,7 +142,7 @@ class MSB : public UnaryKernel { class AddBB : public BinaryKernel { public: - static constexpr char kBindName[] = "add_bb"; + static constexpr const char* kBindName() { return "add_bb"; } CExpr latency() const override { // Cost from other gates (from KoggeStoneAdder): @@ -166,7 +166,7 @@ class AddBB : public BinaryKernel { class AddBP : public BinaryKernel { public: - static constexpr char kBindName[] = "add_bp"; + static constexpr const char* kBindName() { return "add_bp"; } CExpr latency() const override { return Const(0); } @@ -178,7 +178,7 @@ class AddBP : public BinaryKernel { class BitLTBB : public BinaryKernel { public: - static constexpr char kBindName[] = "bitlt_bb"; + static constexpr const char* kBindName() { return "bitlt_bb"; } CExpr latency() const override { return Const(0); } @@ -190,7 +190,7 @@ class BitLTBB : public BinaryKernel { class BitLEBB : public BinaryKernel { public: - static constexpr char kBindName[] = "bitle_bb"; + static constexpr const char* kBindName() { return "bitle_bb"; } CExpr latency() const override { return Const(0); } diff --git a/libspu/mpc/spdz2k/io.cc b/libspu/mpc/spdz2k/io.cc index 136ba2d7..dff59022 100644 --- a/libspu/mpc/spdz2k/io.cc +++ b/libspu/mpc/spdz2k/io.cc @@ -60,9 +60,9 @@ std::vector Spdz2kIo::toShares(const NdArrayRef& raw, const auto runtime_field = getRuntimeField(field); NdArrayRef x(makeType(runtime_field), raw.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _raw(raw); - DISPATCH_ALL_FIELDS(runtime_field, "_", [&]() { + DISPATCH_ALL_FIELDS(runtime_field, [&]() { NdArrayView _x(x); pforeach(0, raw.numel(), [&](int64_t idx) { _x[idx] = static_cast(_raw[idx]); @@ -113,9 +113,9 @@ NdArrayRef Spdz2kIo::fromShares(const std::vector& shares) const { { NdArrayRef x(makeType(field_), res.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _res(res); - DISPATCH_ALL_FIELDS(field_, "_", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { _x[idx] = static_cast(_res[idx]); diff --git a/libspu/mpc/spdz2k/ot/BUILD.bazel b/libspu/mpc/spdz2k/ot/BUILD.bazel index 45143795..124adaf1 100644 --- a/libspu/mpc/spdz2k/ot/BUILD.bazel +++ b/libspu/mpc/spdz2k/ot/BUILD.bazel @@ -68,7 +68,7 @@ spu_cc_library( "//libspu/mpc/utils:ring_ops", "@com_github_emptoolkit_emp_tool//:emp-tool", "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/kernel/algorithms:ot_store", + "@yacl//yacl/kernel/type:ot_store", "@yacl//yacl/link", ], ) diff --git a/libspu/mpc/spdz2k/ot/kos_ote.h b/libspu/mpc/spdz2k/ot/kos_ote.h index 58b9fc07..4ba1f1b9 100644 --- a/libspu/mpc/spdz2k/ot/kos_ote.h +++ b/libspu/mpc/spdz2k/ot/kos_ote.h @@ -15,7 +15,7 @@ #pragma once #include "absl/types/span.h" #include "yacl/base/dynamic_bitset.h" -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "yacl/link/link.h" namespace spu::mpc::spdz2k { diff --git a/libspu/mpc/spdz2k/ot/tiny_ot.h b/libspu/mpc/spdz2k/ot/tiny_ot.h index cb099095..ebe38563 100644 --- a/libspu/mpc/spdz2k/ot/tiny_ot.h +++ b/libspu/mpc/spdz2k/ot/tiny_ot.h @@ -13,7 +13,7 @@ // limitations under the License. #include -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "libspu/mpc/common/communicator.h" diff --git a/libspu/mpc/spdz2k/protocol.cc b/libspu/mpc/spdz2k/protocol.cc index 46db690a..ffbefd55 100644 --- a/libspu/mpc/spdz2k/protocol.cc +++ b/libspu/mpc/spdz2k/protocol.cc @@ -50,7 +50,7 @@ void regSpdz2kProtocol(SPUContext* ctx, ctx->prot()->addState(ctx->config(), lctx); ctx->prot() ->regKernel(); diff --git a/libspu/mpc/spdz2k/state.h b/libspu/mpc/spdz2k/state.h index 03981e02..c5d93838 100644 --- a/libspu/mpc/spdz2k/state.h +++ b/libspu/mpc/spdz2k/state.h @@ -78,7 +78,7 @@ class Spdz2kState : public State { } public: - static constexpr char kBindName[] = "Spdz2kState"; + static constexpr const char* kBindName() { return "Spdz2kState"; } static constexpr auto kAesType = yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; diff --git a/libspu/mpc/spdz2k/value.cc b/libspu/mpc/spdz2k/value.cc index a4e49fb8..029f0651 100644 --- a/libspu/mpc/spdz2k/value.cc +++ b/libspu/mpc/spdz2k/value.cc @@ -56,7 +56,7 @@ NdArrayRef makeBShare(const NdArrayRef& s1, const NdArrayRef& s2, NdArrayRef res(ty, new_shape); int64_t res_numel = res.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView> _res(res); pforeach(0, res_numel * k, [&](int64_t i) { @@ -108,7 +108,7 @@ NdArrayRef getShare(const NdArrayRef& in, int64_t share_idx) { } else { NdArrayRef ret(ty, new_shape); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { size_t numel = in.numel(); NdArrayView _ret(ret); NdArrayView> _in(in); @@ -141,7 +141,7 @@ size_t maxNumBits(const NdArrayRef& lhs, const NdArrayRef& rhs) { } const auto* rhs_ty = rhs.eltype().as(); const auto rhs_field = rhs_ty->field(); - return DISPATCH_ALL_FIELDS(rhs_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_field, [&]() { using PShrT = ring2k_t; return std::max(lhs.eltype().as()->nbits(), maxBitWidth(rhs)); @@ -158,7 +158,7 @@ size_t minNumBits(const NdArrayRef& lhs, const NdArrayRef& rhs) { } const auto* rhs_ty = rhs.eltype().as(); const auto rhs_field = rhs_ty->field(); - return DISPATCH_ALL_FIELDS(rhs_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_field, [&]() { using PShrT = ring2k_t; return std::min(lhs.eltype().as()->nbits(), maxBitWidth(rhs)); diff --git a/libspu/mpc/standard_shape/kernels.h b/libspu/mpc/standard_shape/kernels.h index c4826504..246dffe5 100644 --- a/libspu/mpc/standard_shape/kernels.h +++ b/libspu/mpc/standard_shape/kernels.h @@ -20,7 +20,7 @@ namespace spu::mpc::standard_shape { class Broadcast : public BroadcastKernel { public: - static constexpr char kBindName[] = "broadcast"; + static constexpr const char* kBindName() { return "broadcast"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -32,7 +32,7 @@ class Broadcast : public BroadcastKernel { class Reshape : public ShapeBasedKernel { public: - static constexpr char kBindName[] = "reshape"; + static constexpr const char* kBindName() { return "reshape"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -44,7 +44,7 @@ class Reshape : public ShapeBasedKernel { class ExtractSlice : public ExtractSliceKernel { public: - static constexpr char kBindName[] = "extract_slice"; + static constexpr const char* kBindName() { return "extract_slice"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -57,7 +57,7 @@ class ExtractSlice : public ExtractSliceKernel { class UpdateSlice : public UpdateSliceKernel { public: - static constexpr char kBindName[] = "update_slice"; + static constexpr const char* kBindName() { return "update_slice"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -69,7 +69,7 @@ class UpdateSlice : public UpdateSliceKernel { class Transpose : public DimsBasedKernel { public: - static constexpr char kBindName[] = "transpose"; + static constexpr const char* kBindName() { return "transpose"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -81,7 +81,7 @@ class Transpose : public DimsBasedKernel { class Reverse : public DimsBasedKernel { public: - static constexpr char kBindName[] = "reverse"; + static constexpr const char* kBindName() { return "reverse"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -93,7 +93,7 @@ class Reverse : public DimsBasedKernel { class Fill : public ShapeBasedKernel { public: - static constexpr char kBindName[] = "fill"; + static constexpr const char* kBindName() { return "fill"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -105,7 +105,7 @@ class Fill : public ShapeBasedKernel { class Pad : public PadKernel { public: - static constexpr char kBindName[] = "pad"; + static constexpr const char* kBindName() { return "pad"; } ce::CExpr latency() const override { return ce::Const(0); } @@ -119,7 +119,7 @@ class Pad : public PadKernel { class Concate : public ConcateKernel { public: - static constexpr char kBindName[] = "concatenate"; + static constexpr const char* kBindName() { return "concatenate"; } ce::CExpr latency() const override { return ce::Const(0); } diff --git a/libspu/mpc/tools/benchmark.cc b/libspu/mpc/tools/benchmark.cc index 615fd918..77d6e631 100644 --- a/libspu/mpc/tools/benchmark.cc +++ b/libspu/mpc/tools/benchmark.cc @@ -158,8 +158,8 @@ DEFINE_BENCHMARK(BenchAndSP, NumelArgs); DEFINE_BENCHMARK(BenchXorSP, NumelArgs); DEFINE_BENCHMARK(BenchS2P, NumelArgs); DEFINE_BENCHMARK(BenchP2S, NumelArgs); -DEFINE_BENCHMARK(BenchNotS, NumelArgs); -DEFINE_BENCHMARK(BenchNotP, NumelArgs); +DEFINE_BENCHMARK(BenchNegateS, NumelArgs); +DEFINE_BENCHMARK(BenchNegateP, NumelArgs); DEFINE_BENCHMARK(BenchLShiftS, NumelShiftArgs); DEFINE_BENCHMARK(BenchLShiftP, NumelShiftArgs); @@ -177,7 +177,7 @@ DEFINE_BENCHMARK(BenchRandB, NumelArgs); DEFINE_BENCHMARK(BenchP2A, NumelArgs); DEFINE_BENCHMARK(BenchA2P, NumelArgs); DEFINE_BENCHMARK(BenchMsbA2b, NumelArgs); -DEFINE_BENCHMARK(BenchNotA, NumelArgs); +DEFINE_BENCHMARK(BenchNegateA, NumelArgs); DEFINE_BENCHMARK(BenchAddAP, NumelArgs); DEFINE_BENCHMARK(BenchMulAP, NumelArgs); DEFINE_BENCHMARK(BenchAddAA, NumelArgs); diff --git a/libspu/mpc/tools/benchmark.h b/libspu/mpc/tools/benchmark.h index 811dc9aa..ce64c17c 100644 --- a/libspu/mpc/tools/benchmark.h +++ b/libspu/mpc/tools/benchmark.h @@ -172,10 +172,14 @@ class OpData { } for (auto& b1 : b1s) { b1 = p2b(obj_, rand_p(obj_, Shape{state.range(1)})); - b1 = lshift_b(obj_, b1, - SizeOf(static_cast(state.range(0))) * 8 - 1); - b1 = rshift_b(obj_, b1, - SizeOf(static_cast(state.range(0))) * 8 - 1); + b1 = lshift_b( + obj_, b1, + {static_cast( + SizeOf(static_cast(state.range(0))) * 8 - 1)}); + b1 = rshift_b( + obj_, b1, + {static_cast( + SizeOf(static_cast(state.range(0))) * 8 - 1)}); } } virtual ~OpData() = default; @@ -216,14 +220,14 @@ MPC_BENCH_DEFINE(BenchAddSP, OpData1S1P, add_sp, ss[0], ps[0]) MPC_BENCH_DEFINE(BenchMulSP, OpData1S1P, mul_sp, ss[0], ps[0]) MPC_BENCH_DEFINE(BenchAndSP, OpData1S1P, and_sp, ss[0], ps[0]) MPC_BENCH_DEFINE(BenchXorSP, OpData1S1P, xor_sp, ss[0], ps[0]) -MPC_BENCH_DEFINE(BenchNotS, OpData1S, not_s, ss[0]) -MPC_BENCH_DEFINE(BenchNotP, OpData1P, not_p, ps[0]) -MPC_BENCH_DEFINE(BenchLShiftS, OpData1S, lshift_s, ss[0], state.range(2)) -MPC_BENCH_DEFINE(BenchLShiftP, OpData1P, lshift_p, ps[0], state.range(2)) -MPC_BENCH_DEFINE(BenchRShiftS, OpData1S, rshift_s, ss[0], state.range(2)) -MPC_BENCH_DEFINE(BenchRShiftP, OpData1P, rshift_p, ps[0], state.range(2)) -MPC_BENCH_DEFINE(BenchARShiftS, OpData1S, arshift_s, ss[0], state.range(2)) -MPC_BENCH_DEFINE(BenchARShiftP, OpData1P, arshift_p, ps[0], state.range(2)) +MPC_BENCH_DEFINE(BenchNegateS, OpData1S, negate_s, ss[0]) +MPC_BENCH_DEFINE(BenchNegateP, OpData1P, negate_p, ps[0]) +MPC_BENCH_DEFINE(BenchLShiftS, OpData1S, lshift_s, ss[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchLShiftP, OpData1P, lshift_p, ps[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchRShiftS, OpData1S, rshift_s, ss[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchRShiftP, OpData1P, rshift_p, ps[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchARShiftS, OpData1S, arshift_s, ss[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchARShiftP, OpData1P, arshift_p, ps[0], {state.range(2)}) MPC_BENCH_DEFINE(BenchTruncS, OpData1S, trunc_s_wrapper, ss[0], state.range(2)) MPC_BENCH_DEFINE(BenchS2P, OpData1S, s2p, ss[0]) MPC_BENCH_DEFINE(BenchP2S, OpData1P, p2s, ps[0]) @@ -237,13 +241,13 @@ MPC_BENCH_DEFINE(BenchRandB, OpDataBasic, rand_b, Shape{state.range(1)}) MPC_BENCH_DEFINE(BenchP2A, OpData1P, p2a, ps[0]) MPC_BENCH_DEFINE(BenchA2P, OpData1A, a2p, as[0]) MPC_BENCH_DEFINE(BenchMsbA2b, OpData1A, msb_a2b, as[0]) -MPC_BENCH_DEFINE(BenchNotA, OpData1A, not_a, as[0]) +MPC_BENCH_DEFINE(BenchNegateA, OpData1A, negate_a, as[0]) MPC_BENCH_DEFINE(BenchAddAP, OpData1A1P, add_ap, as[0], ps[0]) MPC_BENCH_DEFINE(BenchMulAP, OpData1A1P, mul_ap, as[0], ps[0]) MPC_BENCH_DEFINE(BenchAddAA, OpData2A, add_aa, as[0], as[1]) MPC_BENCH_DEFINE(BenchMulAA, OpData2A, mul_aa, as[0], as[1]) MPC_BENCH_DEFINE(BenchMulA1B, OpData1A1B1, mul_a1b, as[0], b1s[0]) -MPC_BENCH_DEFINE(BenchLShiftA, OpData1A, lshift_a, as[0], state.range(2)) +MPC_BENCH_DEFINE(BenchLShiftA, OpData1A, lshift_a, as[0], {state.range(2)}) MPC_BENCH_DEFINE(BenchTruncA, OpData1A, trunc_a_wrapper, as[0], state.range(2)) // MPC_BENCH_DEFINE(BenchMMulAP, OpData1MA1MP, mmul_ap, mas[0], mps[0], // state.range(1), state.range(1), state.range(1)) @@ -257,9 +261,9 @@ MPC_BENCH_DEFINE(BenchAndBP, OpData1B1P, and_bp, bs[0], ps[0]) MPC_BENCH_DEFINE(BenchAndBB, OpData2B, and_bb, bs[0], bs[1]) MPC_BENCH_DEFINE(BenchXorBP, OpData1B1P, xor_bp, bs[0], ps[0]) MPC_BENCH_DEFINE(BenchXorBB, OpData2B, xor_bb, bs[0], bs[1]) -MPC_BENCH_DEFINE(BenchLShiftB, OpData1B, lshift_b, bs[0], state.range(2)) -MPC_BENCH_DEFINE(BenchRShiftB, OpData1B, rshift_b, bs[0], state.range(2)) -MPC_BENCH_DEFINE(BenchARShiftB, OpData1B, arshift_b, bs[0], state.range(2)) +MPC_BENCH_DEFINE(BenchLShiftB, OpData1B, lshift_b, bs[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchRShiftB, OpData1B, rshift_b, bs[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchARShiftB, OpData1B, arshift_b, bs[0], {state.range(2)}) MPC_BENCH_DEFINE(BenchBitRevB, OpData1B, bitrev_b, bs[0], 0, SizeOf(static_cast(state.range(0)))) MPC_BENCH_DEFINE(BenchBitIntlB, OpData1B, bitintl_b, bs[0], 0) diff --git a/libspu/mpc/utils/BUILD.bazel b/libspu/mpc/utils/BUILD.bazel index dea5932c..2b494ff1 100644 --- a/libspu/mpc/utils/BUILD.bazel +++ b/libspu/mpc/utils/BUILD.bazel @@ -22,6 +22,7 @@ spu_cc_library( hdrs = ["circuits.h"], deps = [ "//libspu/core:bit_utils", + "//libspu/core:shape", "//libspu/core:vectorize", ], ) @@ -49,6 +50,7 @@ spu_cc_library( hdrs = ["permute.h"], deps = [ "//libspu/core:ndarray_ref", + "@yacl//yacl/crypto/rand", ], ) @@ -74,6 +76,7 @@ spu_cc_library( deps = [ ":linalg", "//libspu/core:ndarray_ref", + "//libspu/core:type_util", "@yacl//yacl/crypto/rand", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/utils:parallel", @@ -88,6 +91,36 @@ spu_cc_test( ], ) +spu_cc_library( + name = "gfmp_ops", + srcs = ["gfmp_ops.cc"], + hdrs = ["gfmp_ops.h"], + copts = select({ + "@platforms//cpu:x86_64": [ + "-mavx", + ], + "//conditions:default": [], + }), + deps = [ + ":gfmp", + ":linalg", + ":ring_ops", + "//libspu/core:ndarray_ref", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/crypto/tools:prg", + "@yacl//yacl/utils:parallel", + ], +) + +spu_cc_library( + name = "gfmp", + hdrs = ["gfmp.h"], + deps = [ + "//libspu/core:type_util", + "@yacl//yacl/base:int128", + ], +) + spu_cc_binary( name = "ring_ops_bench", srcs = ["ring_ops_bench.cc"], @@ -105,7 +138,7 @@ spu_cc_library( linkopts = OMP_LINKFLAGS, deps = [ "//libspu/core:parallel_utils", - "@com_github_eigenteam_eigen//:eigen3", + "@eigen_archive//:eigen3", ] + OMP_DEPS, ) diff --git a/libspu/mpc/utils/circuits.h b/libspu/mpc/utils/circuits.h index 3d032d6b..06cad9a0 100644 --- a/libspu/mpc/utils/circuits.h +++ b/libspu/mpc/utils/circuits.h @@ -22,6 +22,7 @@ #include "yacl/base/int128.h" #include "libspu/core/bit_utils.h" +#include "libspu/core/shape.h" #include "libspu/core/vectorize.h" namespace spu::mpc { @@ -35,10 +36,10 @@ struct CircuitBasicBlock { using And = std::function; // (logical) left shift - using LShift = std::function; + using LShift = std::function; // (logical) right shift - using RShift = std::function; + using RShift = std::function; // Init a constant. using InitLike = std::function; @@ -72,9 +73,9 @@ T kogge_stone(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, auto G = ctx._and(lhs, rhs); for (int idx = 0; idx < Log2Ceil(nbits); ++idx) { - const size_t offset = 1UL << idx; - auto G1 = ctx.lshift(G, offset); - auto P1 = ctx.lshift(P, offset); + const int64_t offset = 1L << idx; + auto G1 = ctx.lshift(G, {offset}); + auto P1 = ctx.lshift(P, {offset}); // P1 = P & P1 // G1 = G ^ (P & G1) @@ -90,7 +91,7 @@ T kogge_stone(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, } // out = (G << 1) ^ p0 - auto C = ctx.lshift(G, 1); + auto C = ctx.lshift(G, {1}); return ctx._xor(ctx._xor(lhs, rhs), C); } @@ -122,12 +123,12 @@ T sklansky(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, auto G = ctx._and(lhs, rhs); for (int idx = 0; idx < Log2Ceil(nbits); ++idx) { const auto s_mask = ctx.init_like(G, kSelMask[idx]); - auto G1 = ctx.lshift(ctx._and(G, s_mask), 1); - auto P1 = ctx.lshift(ctx._and(P, s_mask), 1); + auto G1 = ctx.lshift(ctx._and(G, s_mask), {1}); + auto P1 = ctx.lshift(ctx._and(P, s_mask), {1}); for (int j = 0; j < idx; j++) { - G1 = ctx._xor(G1, ctx.lshift(G1, 1 << j)); - P1 = ctx._xor(P1, ctx.lshift(P1, 1 << j)); + G1 = ctx._xor(G1, ctx.lshift(G1, {1 << j})); + P1 = ctx._xor(P1, ctx.lshift(P1, {1 << j})); } const auto k_mask = ctx.init_like(G, kKeepMasks[idx]); @@ -147,7 +148,7 @@ T sklansky(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, } // out = (G0 << 1) ^ p0 - auto C = ctx.lshift(G, 1); + auto C = ctx.lshift(G, {1}); return ctx._xor(ctx._xor(lhs, rhs), C); } @@ -181,21 +182,22 @@ T odd_even_split(const CircuitBasicBlock& ctx, const T& v, size_t nbits) { }}; // let r = v - T r = ctx.lshift(v, 0); + T r = ctx.lshift(v, {0}); for (int idx = 0; idx + 1 < Log2Ceil(nbits); ++idx) { // r = (r & keep) ^ ((r >> i) & move) ^ ((r & move) << i) const auto keep = ctx.init_like(r, kKeepMasks[idx]); const auto move = ctx.init_like(r, kSwapMasks[idx]); r = ctx._xor(ctx._and(r, keep), - ctx._xor(ctx._and(ctx.rshift(r, 1 << idx), move), - ctx.lshift(ctx._and(r, move), 1 << idx))); + ctx._xor(ctx._and(ctx.rshift(r, {1 << idx}), move), + ctx.lshift(ctx._and(r, move), {1 << idx}))); } if (!absl::has_single_bit(nbits)) { // handle non 2^k bits case. T mask = ctx.init_like(r, (1ULL << (nbits / 2)) - 1); - r = ctx._xor(ctx.lshift(ctx.rshift(r, 1 << Log2Floor(nbits)), nbits / 2), + r = ctx._xor(ctx.lshift(ctx.rshift(r, {1 << Log2Floor(nbits)}), + {static_cast(nbits / 2)}), ctx._and(r, mask)); } @@ -228,7 +230,7 @@ T carry_out(const CircuitBasicBlock& ctx, const T& x, const T& y, auto perm = odd_even_split(ctx, in, kk); T mask = ctx.init_like(perm, (static_cast(1) << hk) - 1); T t0 = ctx._and(perm, mask); - T t1 = ctx._and(ctx.rshift(perm, hk), mask); + T t1 = ctx._and(ctx.rshift(perm, {static_cast(hk)}), mask); ctx.set_nbits(t0, hk); ctx.set_nbits(t1, hk); return std::make_tuple(t0, t1); @@ -247,8 +249,8 @@ T carry_out(const CircuitBasicBlock& ctx, const T& x, const T& y, while (k > 1) { if (k % 2 != 0) { k += 1; - P = ctx.lshift(P, 1); - G = ctx.lshift(G, 1); + P = ctx.lshift(P, {1}); + G = ctx.lshift(G, {1}); } auto [P0, P1] = bit_split(P, k); auto [G0, G1] = bit_split(G, k); diff --git a/libspu/mpc/utils/circuits_test.cc b/libspu/mpc/utils/circuits_test.cc index 1e8e034e..c0e82433 100644 --- a/libspu/mpc/utils/circuits_test.cc +++ b/libspu/mpc/utils/circuits_test.cc @@ -26,8 +26,8 @@ CircuitBasicBlock makeScalarCBB() { CircuitBasicBlock cbb; cbb._xor = [](T const& lhs, T const& rhs) -> T { return lhs ^ rhs; }; cbb._and = [](T const& lhs, T const& rhs) -> T { return lhs & rhs; }; - cbb.lshift = [](T const& x, size_t bits) -> T { return x << bits; }; - cbb.rshift = [](T const& x, size_t bits) -> T { return x >> bits; }; + cbb.lshift = [](T const& x, const Sizes& bits) -> T { return x << bits[0]; }; + cbb.rshift = [](T const& x, const Sizes& bits) -> T { return x >> bits[0]; }; cbb.init_like = [](T const&, uint128_t init) -> T { return static_cast(init); }; @@ -53,16 +53,16 @@ CircuitBasicBlock makeVectorCBB() { std::bit_and<>()); return res; }; - cbb.lshift = [](C const& x, size_t bits) -> C { + cbb.lshift = [](C const& x, const Sizes& bits) -> C { C res; std::transform(x.begin(), x.end(), std::back_inserter(res), - [&](const auto& e) { return e << bits; }); + [&](const auto& e) { return e << bits[0]; }); return res; }; - cbb.rshift = [](C const& x, size_t bits) -> C { + cbb.rshift = [](C const& x, const Sizes& bits) -> C { C res; std::transform(x.begin(), x.end(), std::back_inserter(res), - [&](const auto& e) { return e >> bits; }); + [&](const auto& e) { return e >> bits[0]; }); return res; }; cbb.set_nbits = [](C&, size_t) -> void {}; diff --git a/libspu/mpc/utils/gfmp.h b/libspu/mpc/utils/gfmp.h new file mode 100644 index 00000000..8dfc4f5f --- /dev/null +++ b/libspu/mpc/utils/gfmp.h @@ -0,0 +1,168 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/base/int128.h" + +#include "libspu/core/type_util.h" + +#define EIGEN_HAS_OPENMP + +#include "Eigen/Core" + +namespace spu::mpc { + +inline uint8_t mul(uint8_t x, uint8_t y, uint8_t* z) { + uint16_t hi = static_cast(x) * static_cast(y); + auto lo = static_cast(hi); + if (z != nullptr) { + *z = static_cast(hi >> 8); + } + return lo; +} + +inline uint32_t mul(uint32_t x, uint32_t y, uint32_t* z) { + uint64_t hi = static_cast(x) * static_cast(y); + auto lo = static_cast(hi); + if (z != nullptr) { + *z = static_cast(hi >> 32); + } + return lo; +} + +inline uint64_t mul(uint64_t x, uint64_t y, uint64_t* z) { + uint128_t hi = static_cast(x) * static_cast(y); + auto lo = static_cast(hi); + if (z != nullptr) { + *z = static_cast(hi >> 64); + } + return lo; +} + +inline uint128_t mul(uint128_t x, uint128_t y, uint128_t* z) { + uint64_t x_lo = x & 0xFFFFFFFFFFFFFFFF; + uint64_t x_hi = x >> 64; + uint64_t y_lo = y & 0xFFFFFFFFFFFFFFFF; + uint64_t y_hi = y >> 64; + + uint128_t lo = static_cast(x_lo) * y_lo; + + uint128_t xl_yh = static_cast(x_lo) * y_hi; + uint128_t xh_yl = static_cast(x_hi) * y_lo; + + lo += xl_yh << 64; + uint128_t hi = static_cast(lo < (xl_yh << 64)); + + lo += xh_yl << 64; + hi += static_cast(lo < (xh_yl << 64)); + hi += static_cast(x_hi) * y_hi; + + hi += xl_yh >> 64; + hi += xh_yl >> 64; + if (z != nullptr) { + *z = hi; + } + return lo; +} + +template , bool> = true> +inline T mul_mod(T x, T y) { + T c = 0; + T e = mul(x, y, &c); + T p = ScalarTypeToPrime::prime; + size_t mp_exp = ScalarTypeToPrime::exp; + T ret = (e & p) + ((e >> mp_exp) ^ (c << (sizeof(T) * 8 - mp_exp))); + return (ret >= p) ? ret - p : ret; +} + +template , bool> = true> +inline T add_mod(T x, T y) { + T ret = x + y; + T p = ScalarTypeToPrime::prime; + return (ret >= p) ? ret - p : ret; +} + +template , bool> = true> +inline T add_inv(T x) { + T p = ScalarTypeToPrime::prime; + return x ^ p; +} + +// Extended Euclidean Algorithm +// ax + by = gcd(a, b) +template , bool> = true> +void extend_gcd(T a, T b, T& x, T& y) { + if (b == 0) { + x = 1; + y = 0; + return; + } + extend_gcd(b, static_cast(a % b), y, x); + T tmp = mul_mod(static_cast(a / b), x); + y = add_mod(y, add_inv(tmp)); +} + +template , bool> = true> +inline T mul_inv(T in) { + T x; + T y; + T p = ScalarTypeToPrime::prime; + extend_gcd(p, in, x, y); + return y; +} + +template , bool> = true> +inline T mod_p(T in) { + T p = ScalarTypeToPrime::prime; + size_t mp_exp = ScalarTypeToPrime::exp; + T i = (in & p) + (in >> mp_exp); + return i >= p ? i - p : i; +} + +// the following code references SEAL library +// https://github.com/microsoft/SEAL/blob/main/src/seal/util/uintarithsmallmod.cpp +template , bool> = true> +inline T exp_mod(T operand, T exponent) { + // Fast cases + if (exponent == 0) { + // Result is supposed to be only one digit + return 1; + } + + if (exponent == 1) { + return operand; + } + + // Perform binary exponentiation. + T power = operand; + T product = 0; + T intermediate = 1; + + // Initially: power = operand and intermediate = 1, product is irrelevant. + while (true) { + if (exponent & 1) { + product = mul_mod(power, intermediate); + std::swap(product, intermediate); + } + exponent >>= 1; + if (exponent == 0) { + break; + } + product = mul_mod(power, power); + std::swap(product, power); + } + return intermediate; +} +} // namespace spu::mpc \ No newline at end of file diff --git a/libspu/mpc/utils/gfmp_ops.cc b/libspu/mpc/utils/gfmp_ops.cc new file mode 100644 index 00000000..c336a8ba --- /dev/null +++ b/libspu/mpc/utils/gfmp_ops.cc @@ -0,0 +1,251 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define PFOR_GRAIN_SIZE 4096 + +#include "libspu/mpc/utils/gfmp_ops.h" + +#include + +#include "absl/types/span.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/prg.h" + +#include "libspu/mpc/utils/gfmp.h" +#include "libspu/mpc/utils/linalg.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc { +namespace { + +#define SPU_ENFORCE_RING(x) \ + SPU_ENFORCE((x).eltype().isa(), "expect ring type, got={}", \ + (x).eltype()); + +#define SPU_ENFORCE_GFMP(x) \ + SPU_ENFORCE((x).eltype().isa(), "expect gfmp type, got={}", \ + (x).eltype()); + +#define ENFORCE_EQ_ELSIZE_AND_SHAPE(lhs, rhs) \ + SPU_ENFORCE((lhs).elsize() == (rhs).elsize(), \ + "type size mismatch lhs={}, rhs={}", (lhs).eltype(), \ + (rhs).eltype()); \ + SPU_ENFORCE((lhs).shape() == (rhs).shape(), \ + "numel mismatch, lhs={}, rhs={}", lhs, rhs); + +// Fast mod operation for Mersenne prime +void gfmp_mod_impl(NdArrayRef& ret, const NdArrayRef& x) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + const auto* ty = ret.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = mod_p(_x[idx]); }); + }); +} + +void gfmp_mul_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, + [&](int64_t idx) { _ret[idx] = mul_mod(_x[idx], _y[idx]); }); + }); +} + +void gfmp_add_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, + [&](int64_t idx) { _ret[idx] = add_mod(_x[idx], _y[idx]); }); + }); +} + +void gfmp_sub_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = add_mod(_x[idx], add_inv(_y[idx])); + }); + }); +} + +void gfmp_div_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = mul_mod(_x[idx], mul_inv(_y[idx])); + }); + }); +} + +} // namespace +NdArrayRef gfmp_zeros(FieldType field, const Shape& shape) { + NdArrayRef ret(makeType(field), shape); + auto numel = ret.numel(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = 0; }); + return ret; + }); +} +NdArrayRef gfmp_rand(FieldType field, const Shape& shape) { + uint64_t cnt = 0; + return gfmp_rand(field, shape, yacl::crypto::SecureRandSeed(), &cnt); +} + +NdArrayRef gfmp_rand(FieldType field, const Shape& shape, uint128_t prg_seed, + uint64_t* prg_counter) { + constexpr yacl::crypto::SymmetricCrypto::CryptoType kCryptoType = + yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; + constexpr uint128_t kAesInitialVector = 0U; + NdArrayRef res(makeType(field), shape); + DISPATCH_ALL_FIELDS(field, [&]() { + *prg_counter = yacl::crypto::FillPRandWithMersennePrime( + kCryptoType, prg_seed, kAesInitialVector, *prg_counter, + absl::MakeSpan(&res.at(0), res.numel())); + }); + return res; +} + +NdArrayRef gfmp_mod(const NdArrayRef& x) { + SPU_ENFORCE_GFMP(x); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_mod_impl(ret, x); + return ret; +} + +void gfmp_mod_(NdArrayRef& x) { + SPU_ENFORCE_GFMP(x); + gfmp_mod_impl(x, x); +} + +NdArrayRef gfmp_mul_mod(const NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_mul_mod_impl(ret, x, y); + return ret; +} + +void gfmp_mul_mod_(NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + gfmp_mul_mod_impl(x, x, y); +} + +NdArrayRef gfmp_div_mod(const NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_div_mod_impl(ret, x, y); + return ret; +} + +void gfmp_div_mod_(NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + gfmp_div_mod_impl(x, x, y); +} + +NdArrayRef gfmp_add_mod(const NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_add_mod_impl(ret, x, y); + return ret; +} + +void gfmp_add_mod_(NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + gfmp_add_mod_impl(x, x, y); +} + +NdArrayRef gfmp_sub_mod(const NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_sub_mod_impl(ret, x, y); + return ret; +} + +void gfmp_sub_mod_(NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + gfmp_sub_mod_impl(x, x, y); +} + +// not requiring and not casting field. +void gfmp_exp_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, + [&](int64_t idx) { _ret[idx] = exp_mod(_x[idx], _y[idx]); }); + }); +} + +NdArrayRef gfmp_exp_mod(const NdArrayRef& x, const NdArrayRef& y) { + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_exp_mod_impl(ret, x, y); + return ret; +} + +void gfmp_exp_mod_(NdArrayRef& x, const NdArrayRef& y) { + gfmp_exp_mod_impl(x, x, y); +} + +} // namespace spu::mpc diff --git a/libspu/mpc/utils/gfmp_ops.h b/libspu/mpc/utils/gfmp_ops.h new file mode 100644 index 00000000..31d275e8 --- /dev/null +++ b/libspu/mpc/utils/gfmp_ops.h @@ -0,0 +1,45 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/core/ndarray_ref.h" + +namespace spu::mpc { + +NdArrayRef gfmp_rand(FieldType field, const Shape& shape); +NdArrayRef gfmp_rand(FieldType field, const Shape& shape, uint128_t prg_seed, + uint64_t* prg_counter); + +NdArrayRef gfmp_zeros(FieldType field, const Shape& shape); + +NdArrayRef gfmp_mod(const NdArrayRef& x); +void gfmp_mod_(NdArrayRef& x); + +NdArrayRef gfmp_mul_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_mul_mod_(NdArrayRef& x, const NdArrayRef& y); + +NdArrayRef gfmp_div_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_div_mod_(NdArrayRef& x, const NdArrayRef& y); + +NdArrayRef gfmp_add_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_add_mod_(NdArrayRef& x, const NdArrayRef& y); + +NdArrayRef gfmp_sub_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_sub_mod_(NdArrayRef& x, const NdArrayRef& y); + +NdArrayRef gfmp_exp_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_exp_mod_(NdArrayRef& x, const NdArrayRef& y); + +} // namespace spu::mpc diff --git a/libspu/mpc/utils/permute.cc b/libspu/mpc/utils/permute.cc index 1c5d2920..62a8e85a 100644 --- a/libspu/mpc/utils/permute.cc +++ b/libspu/mpc/utils/permute.cc @@ -17,17 +17,19 @@ #include #include +#include "yacl/crypto/rand/rand.h" + #include "libspu/core/ndarray_ref.h" #include "libspu/core/type_util.h" namespace spu::mpc { -PermVector ring2pv(const NdArrayRef& x) { +Index ring2pv(const NdArrayRef& x) { SPU_ENFORCE(x.eltype().isa(), "must be ring2k_type, got={}", x.eltype()); const auto field = x.eltype().as()->field(); - PermVector pv(x.numel()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + Index pv(x.numel()); + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); }); @@ -39,7 +41,7 @@ NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span pv) { NdArrayRef y(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kPermModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); NdArrayView _y(y); for (int64_t i = 0; i < y.numel(); i++) { @@ -49,12 +51,32 @@ NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span pv) { return y; } +NdArrayRef applyInvPerm(const NdArrayRef& x, const NdArrayRef& pv) { + SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor"); + SPU_ENFORCE_EQ(x.shape(), pv.shape(), "x and pv should have same shape"); + + NdArrayRef y(x.eltype(), x.shape()); + const auto field = x.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _x(x); + NdArrayView _y(y); + const auto pv_field = pv.eltype().as()->field(); + DISPATCH_ALL_FIELDS(pv_field, [&]() { + NdArrayView _pv(pv); + for (int64_t i = 0; i < y.numel(); i++) { + _y[_pv[i]] = _x[i]; + } + }); + }); + return y; +} + NdArrayRef applyPerm(const NdArrayRef& x, absl::Span pv) { SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor"); NdArrayRef y(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kPermModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); NdArrayView _y(y); for (int64_t i = 0; i < y.numel(); i++) { @@ -64,29 +86,45 @@ NdArrayRef applyPerm(const NdArrayRef& x, absl::Span pv) { return y; } -PermVector genRandomPerm(size_t size, uint64_t seed) { - PermVector perm(size); - std::iota(perm.begin(), perm.end(), 0); - // TODO: change PRNG to CSPRNG - std::mt19937 rng(seed); - std::shuffle(perm.begin(), perm.end(), rng); - return perm; +NdArrayRef applyPerm(const NdArrayRef& x, const NdArrayRef& pv) { + SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor"); + SPU_ENFORCE_EQ(x.shape(), pv.shape(), "x and pv should have same shape"); + + NdArrayRef y(x.eltype(), x.shape()); + const auto field = x.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _x(x); + NdArrayView _y(y); + const auto pv_field = pv.eltype().as()->field(); + DISPATCH_ALL_FIELDS(pv_field, [&]() { + NdArrayView _pv(pv); + for (int64_t i = 0; i < y.numel(); i++) { + _y[i] = _x[_pv[i]]; + } + }); + }); + return y; } -PermVector genInversePerm(absl::Span pv) { - PermVector ret(pv.size()); - for (size_t i = 0; i < pv.size(); ++i) { - ret[pv[i]] = i; - } +NdArrayRef genInversePerm(const NdArrayRef& perm) { + NdArrayRef ret(perm.eltype(), perm.shape()); + auto field = perm.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _perm(perm); + for (int64_t i = 0; i < perm.numel(); ++i) { + _ret[_perm[i]] = ring2k_t(i); + } + }); return ret; } -PermVector genPermBySort(const NdArrayRef& x) { +Index genPermBySort(const NdArrayRef& x) { SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor"); - PermVector perm(x.shape()[0]); + Index perm(x.shape()[0]); std::iota(perm.begin(), perm.end(), 0); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kPermModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); @@ -96,4 +134,11 @@ PermVector genPermBySort(const NdArrayRef& x) { return perm; } +Index genRandomPerm(size_t numel, uint128_t seed, uint64_t* ctr) { + Index perm(numel); + std::iota(perm.begin(), perm.end(), 0); + yacl::crypto::ReplayShuffle(perm.begin(), perm.end(), seed, ctr); + return perm; +} + } // namespace spu::mpc \ No newline at end of file diff --git a/libspu/mpc/utils/permute.h b/libspu/mpc/utils/permute.h index 9034d79b..5c4eb1d2 100644 --- a/libspu/mpc/utils/permute.h +++ b/libspu/mpc/utils/permute.h @@ -20,24 +20,23 @@ namespace spu::mpc { constexpr char kPermModule[] = "Permute"; -using PermVector = std::vector; - -PermVector genRandomPerm(size_t size, uint64_t seed); - -PermVector genInversePerm(absl::Span pv); +NdArrayRef genInversePerm(const NdArrayRef& perm); // generate permutation vector that can make x ordered -PermVector genPermBySort(const NdArrayRef& x); +Index genPermBySort(const NdArrayRef& x); // reorder 1-d tensor element by applying inverse permutation. // ret = ApplyInvPerm(x, pv) -> ret[pv[i]] = x[i] NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span pv); +NdArrayRef applyInvPerm(const NdArrayRef& x, const NdArrayRef& pv); // reorder 1-d tensor element by applying permutation. // ret = ApplyPerm(x, pv) -> ret[i] = x[pv[i]] NdArrayRef applyPerm(const NdArrayRef& x, absl::Span pv); +NdArrayRef applyPerm(const NdArrayRef& x, const NdArrayRef& pv); // get a permutation vector from a ring -PermVector ring2pv(const NdArrayRef& x); +Index ring2pv(const NdArrayRef& x); +Index genRandomPerm(size_t numel, uint128_t seed, uint64_t* ctr); } // namespace spu::mpc \ No newline at end of file diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc index 0d49f743..e0e67b2b 100644 --- a/libspu/mpc/utils/ring_ops.cc +++ b/libspu/mpc/utils/ring_ops.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define PFOR_GRAIN_SIZE 4096 - #include "libspu/mpc/utils/ring_ops.h" #include @@ -29,8 +27,6 @@ namespace spu::mpc { namespace { -constexpr char kModule[] = "RingOps"; - #define SPU_ENFORCE_RING(x) \ SPU_ENFORCE((x).eltype().isa(), "expect ring type, got={}", \ (x).eltype()); @@ -47,7 +43,7 @@ constexpr char kModule[] = "RingOps"; ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); \ const auto field = x.eltype().as()->field(); \ const int64_t numel = ret.numel(); \ - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ + return DISPATCH_ALL_FIELDS(field, [&]() { \ using T = std::make_signed_t; \ NdArrayView _x(x); \ NdArrayView _ret(ret); \ @@ -67,7 +63,7 @@ DEF_UNARY_RING_OP(ring_neg, -); ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); \ const auto field = x.eltype().as()->field(); \ const int64_t numel = ret.numel(); \ - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ + return DISPATCH_ALL_FIELDS(field, [&]() { \ NdArrayView _x(x); \ NdArrayView _y(y); \ NdArrayView _ret(ret); \ @@ -86,40 +82,56 @@ DEF_BINARY_RING_OP(ring_xor, ^); #undef DEF_BINARY_RING_OP -void ring_arshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { +void ring_arshift_impl(NdArrayRef& ret, const NdArrayRef& x, + const Sizes& bits) { ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + bool is_splat = bits.size() == 1; + SPU_ENFORCE(static_cast(bits.size()) == x.numel() || is_splat, + "mismatched numel {} vs {}", bits.size(), x.numel()); const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { // According to K&R 2nd edition the results are implementation-dependent for // right shifts of signed values, but "usually" its arithmetic right shift. using S = std::make_signed::type; NdArrayView _ret(ret); NdArrayView _x(x); - pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] >> bits; }); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = _x[idx] >> (is_splat ? bits[0] : bits[idx]); + }); }); } -void ring_rshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { +void ring_rshift_impl(NdArrayRef& ret, const NdArrayRef& x, const Sizes& bits) { ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + bool is_splat = bits.size() == 1; + SPU_ENFORCE(static_cast(bits.size()) == x.numel() || is_splat, + "mismatched numel {} vs {}", bits.size(), x.numel()); const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; NdArrayView _ret(ret); NdArrayView _x(x); - pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] >> bits; }); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = _x[idx] >> (is_splat ? bits[0] : bits[idx]); + }); }); } -void ring_lshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { +void ring_lshift_impl(NdArrayRef& ret, const NdArrayRef& x, const Sizes& bits) { ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + bool is_splat = bits.size() == 1; + SPU_ENFORCE(static_cast(bits.size()) == x.numel() || is_splat, + "mismatched numel {} vs {}", bits.size(), x.numel()); const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); NdArrayView _x(x); - pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] << bits; }); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = _x[idx] << (is_splat ? bits[0] : bits[idx]); + }); }); } @@ -130,7 +142,7 @@ void ring_bitrev_impl(NdArrayRef& ret, const NdArrayRef& x, size_t start, const auto field = x.eltype().as()->field(); const auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; // optimize: use faster reverse method. @@ -161,7 +173,7 @@ void ring_bitmask_impl(NdArrayRef& ret, const NdArrayRef& x, size_t low, SPU_ENFORCE(low < high && high <= SizeOf(field) * 8); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; U mask = 0; if (high - low < SizeOf(field) * 8) { @@ -184,7 +196,7 @@ void ring_print(const NdArrayRef& x, std::string_view name) { SPU_ENFORCE_RING(x); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; std::string out; @@ -211,7 +223,7 @@ NdArrayRef ring_rand(FieldType field, const Shape& shape) { NdArrayRef ring_rand(FieldType field, const Shape& shape, uint128_t prg_seed, uint64_t* prg_counter) { constexpr yacl::crypto::SymmetricCrypto::CryptoType kCryptoType = - yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; + yacl::crypto::SymmetricCrypto::CryptoType::AES128_ECB; constexpr uint128_t kAesInitialVector = 0U; NdArrayRef res(makeType(field), shape); @@ -222,21 +234,27 @@ NdArrayRef ring_rand(FieldType field, const Shape& shape, uint128_t prg_seed, return res; } -NdArrayRef ring_rand_range(FieldType field, const Shape& shape, int32_t min, - int32_t max) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(min, max); +NdArrayRef ring_rand_range(FieldType field, const Shape& shape, uint128_t min, + uint128_t max) { + constexpr yacl::crypto::SymmetricCrypto::CryptoType kCryptoType = + yacl::crypto::SymmetricCrypto::CryptoType::AES128_ECB; + constexpr uint64_t kAesInitialVector = 0U; + uint64_t cnt = 0; NdArrayRef x(makeType(field), shape); auto numel = x.numel(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { + std::vector rand_range(numel); + yacl::crypto::FillPRandWithLtN( + kCryptoType, yacl::crypto::SecureRandSeed(), kAesInitialVector, cnt, + absl::MakeSpan(rand_range), static_cast(max - min + 1)); SPU_ENFORCE(sizeof(ring2k_t) >= sizeof(int32_t)); auto iter = x.begin(); for (auto idx = 0; idx < numel; ++idx, ++iter) { - iter.getScalarValue() = static_cast(dis(gen)); + iter.getScalarValue() = + rand_range[idx] + static_cast(min); } }); @@ -250,7 +268,7 @@ void ring_assign(NdArrayRef& x, const NdArrayRef& y) { const auto numel = x.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _y(y); NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { _x[idx] = _y[idx]; }); @@ -261,7 +279,7 @@ NdArrayRef ring_zeros(FieldType field, const Shape& shape) { NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); pforeach(0, numel, [&](int64_t idx) { _ret[idx] = ring2k_t(0); }); return ret; @@ -272,7 +290,7 @@ NdArrayRef ring_ones(FieldType field, const Shape& shape) { NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); pforeach(0, numel, [&](int64_t idx) { _ret[idx] = ring2k_t(1); }); return ret; @@ -280,17 +298,15 @@ NdArrayRef ring_ones(FieldType field, const Shape& shape) { } NdArrayRef ring_randbit(FieldType field, const Shape& shape) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> distrib(0, RAND_MAX); - NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + auto rand_bytes = yacl::crypto::RandBytes(numel, false); + + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); for (auto idx = 0; idx < numel; ++idx) { - _ret[idx] = distrib(gen) & 0x1; + _ret[idx] = static_cast(rand_bytes[idx]) & 0x1; } return ret; }); @@ -341,7 +357,7 @@ void ring_mul_impl(NdArrayRef& ret, const NdArrayRef& x, uint128_t y) { const auto numel = x.numel(); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _x(x); NdArrayView _ret(ret); @@ -368,7 +384,7 @@ void ring_mmul_impl(NdArrayRef& z, const NdArrayRef& lhs, SPU_ENFORCE(rhs.eltype().isa(), "rhs not ring, got={}", rhs.eltype()); const auto field = lhs.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { const auto lhs_stride_scale = lhs.elsize() / sizeof(ring2k_t); const auto rhs_stride_scale = rhs.elsize() / sizeof(ring2k_t); const auto ret_stride_scale = z.elsize() / sizeof(ring2k_t); @@ -436,31 +452,35 @@ void ring_equal_(NdArrayRef& x, const NdArrayRef& y) { ring_equal_impl(x, x, y); } -NdArrayRef ring_arshift(const NdArrayRef& x, size_t bits) { +NdArrayRef ring_arshift(const NdArrayRef& x, const Sizes& bits) { NdArrayRef res(x.eltype(), x.shape()); ring_arshift_impl(res, x, bits); return res; } -void ring_arshift_(NdArrayRef& x, size_t bits) { +void ring_arshift_(NdArrayRef& x, const Sizes& bits) { ring_arshift_impl(x, x, bits); } -NdArrayRef ring_rshift(const NdArrayRef& x, size_t bits) { +NdArrayRef ring_rshift(const NdArrayRef& x, const Sizes& bits) { NdArrayRef res(x.eltype(), x.shape()); ring_rshift_impl(res, x, bits); return res; } -void ring_rshift_(NdArrayRef& x, size_t bits) { ring_rshift_impl(x, x, bits); } +void ring_rshift_(NdArrayRef& x, const Sizes& bits) { + ring_rshift_impl(x, x, bits); +} -NdArrayRef ring_lshift(const NdArrayRef& x, size_t bits) { +NdArrayRef ring_lshift(const NdArrayRef& x, const Sizes& bits) { NdArrayRef res(x.eltype(), x.shape()); ring_lshift_impl(res, x, bits); return res; } -void ring_lshift_(NdArrayRef& x, size_t bits) { ring_lshift_impl(x, x, bits); } +void ring_lshift_(NdArrayRef& x, const Sizes& bits) { + ring_lshift_impl(x, x, bits); +} NdArrayRef ring_bitrev(const NdArrayRef& x, size_t start, size_t end) { NdArrayRef res(x.eltype(), x.shape()); @@ -503,7 +523,7 @@ bool ring_all_equal(const NdArrayRef& x, const NdArrayRef& y, size_t abs_err) { auto numel = x.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); @@ -529,7 +549,7 @@ std::vector ring_cast_boolean(const NdArrayRef& x) { auto numel = x.numel(); std::vector res(numel); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { res[idx] = static_cast(_x[idx] & 0x1); @@ -549,7 +569,7 @@ NdArrayRef ring_select(const std::vector& c, const NdArrayRef& x, NdArrayRef z(x.eltype(), x.shape()); const int64_t numel = c.size(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); NdArrayView _y(y); NdArrayView _z(z); diff --git a/libspu/mpc/utils/ring_ops.h b/libspu/mpc/utils/ring_ops.h index cea10a06..035cd955 100644 --- a/libspu/mpc/utils/ring_ops.h +++ b/libspu/mpc/utils/ring_ops.h @@ -18,38 +18,13 @@ namespace spu::mpc { -#define DEF_RVALUE_BINARY_RING_OP(op_name, commutative) \ - template \ - typename std::enable_if< \ - std::is_same_v>> && \ - std::is_same_v>>, \ - NdArrayRef>::type \ - op_name(X&& x, Y&& y) { \ - if constexpr (std::is_rvalue_reference_v) { \ - op_name##_(x, y); \ - if constexpr (std::is_rvalue_reference_v) { \ - NdArrayRef dummy = std::move(y); \ - } \ - return std::move(x); \ - } else if constexpr (std::is_rvalue_reference_v && \ - COMMUTATIVE) { \ - op_name##_(y, x); \ - return std::move(y); \ - } else { \ - return op_name(static_cast(x), \ - static_cast(y)); \ - } \ - } - void ring_print(const NdArrayRef& x, std::string_view name = "_"); NdArrayRef ring_rand(FieldType field, const Shape& shape); NdArrayRef ring_rand(FieldType field, const Shape& shape, uint128_t prg_seed, uint64_t* prg_counter); -NdArrayRef ring_rand_range(FieldType field, const Shape& shape, int32_t min, - int32_t max); +NdArrayRef ring_rand_range(FieldType field, const Shape& shape, uint128_t min, + uint128_t max); NdArrayRef ring_zeros(FieldType field, const Shape& shape); @@ -65,15 +40,12 @@ void ring_neg_(NdArrayRef& x); NdArrayRef ring_add(const NdArrayRef& x, const NdArrayRef& y); void ring_add_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_add, true); NdArrayRef ring_sub(const NdArrayRef& x, const NdArrayRef& y); void ring_sub_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_sub, false); NdArrayRef ring_mul(const NdArrayRef& x, const NdArrayRef& y); void ring_mul_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_mul, true); NdArrayRef ring_mul(const NdArrayRef& x, uint128_t y); void ring_mul_(NdArrayRef& x, uint128_t y); @@ -87,24 +59,21 @@ void ring_not_(NdArrayRef& x); NdArrayRef ring_and(const NdArrayRef& x, const NdArrayRef& y); void ring_and_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_and, true); NdArrayRef ring_xor(const NdArrayRef& x, const NdArrayRef& y); void ring_xor_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_xor, true); NdArrayRef ring_equal(const NdArrayRef& x, const NdArrayRef& y); void ring_equal_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_equal, true); -NdArrayRef ring_arshift(const NdArrayRef& x, size_t bits); -void ring_arshift_(NdArrayRef& x, size_t bits); +NdArrayRef ring_arshift(const NdArrayRef& x, const Sizes& bits); +void ring_arshift_(NdArrayRef& x, const Sizes& bits); -NdArrayRef ring_rshift(const NdArrayRef& x, size_t bits); -void ring_rshift_(NdArrayRef& x, size_t bits); +NdArrayRef ring_rshift(const NdArrayRef& x, const Sizes& bits); +void ring_rshift_(NdArrayRef& x, const Sizes& bits); -NdArrayRef ring_lshift(const NdArrayRef& x, size_t bits); -void ring_lshift_(NdArrayRef& x, size_t bits); +NdArrayRef ring_lshift(const NdArrayRef& x, const Sizes& bits); +void ring_lshift_(NdArrayRef& x, const Sizes& bits); NdArrayRef ring_bitrev(const NdArrayRef& x, size_t start, size_t end); void ring_bitrev_(NdArrayRef& x, size_t start, size_t end); @@ -138,6 +107,4 @@ void ring_set_value(NdArrayRef& in, const T& value) { pforeach(0, in.numel(), [&](int64_t idx) { _in[idx] = value; }); }; -#undef DEF_RVALUE_BINARY_RING_OP - } // namespace spu::mpc diff --git a/libspu/spu.proto b/libspu/spu.proto index 93793d83..a9050c5d 100644 --- a/libspu/spu.proto +++ b/libspu/spu.proto @@ -240,6 +240,7 @@ message RuntimeConfig { EXP_DEFAULT = 0; // Implementation defined. EXP_PADE = 1; // The pade approximation. EXP_TAYLOR = 2; // Taylor series approximation. + EXP_PRIME = 3; // exp prime only available for some implementations } // The exponent approximation method. @@ -331,6 +332,25 @@ message RuntimeConfig { uint64 experimental_inter_op_concurrency = 104; // Enable use of private type bool experimental_enable_colocated_optimization = 105; + + // enable experimental exp prime method + bool experimental_enable_exp_prime = 106; + + // The offset parameter for exp prime methods. + // control the valid range of exp prime method. + // valid range is: + // ((47 - offset - 2fxp)/log_2(e), (125 - 2fxp - offset)/log_2(e)) + // clamp to value would be + // lower bound: (48 - offset - 2fxp)/log_2(e) + // higher bound: (124 - 2fxp - offset)/log_2(e) + // default offset is 13, 0 offset is not supported. + uint32 experimental_exp_prime_offset = 107; + // whether to apply the clamping lower bound + // default to enable it + bool experimental_exp_prime_disable_lower_bound = 108; + // whether to apply the clamping upper bound + // default to disable it + bool experimental_exp_prime_enable_upper_bound = 109; } message TTPBeaverConfig { diff --git a/libspu/version.h b/libspu/version.h index 632933d0..6bfefca5 100644 --- a/libspu/version.h +++ b/libspu/version.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define SPU_VERSION "0.9.2b0" +#define SPU_VERSION "0.9.3b0" #include diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..fbcb1153 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,5 @@ +{ + "executionEnvironments": [ + {"root": "."} + ] +} diff --git a/requirements-dev.txt b/requirements-dev.txt index 6ee814f8..33bcd4da 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ pandas>=1.4.2 -flax -scikit-learn +flax<0.10.0 +scikit-learn<1.6.0 # for tests absl-py>=1.1.0 tensorflow-cpu>=2.12.0; sys_platform == "linux" and platform_machine == 'x86_64' diff --git a/requirements.txt b/requirements.txt index ac8fd6cc..2c52e827 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ grpcio>=1.42.0,!=1.48.0 -numpy>=1.22.0 +numpy>=1.22.0, <2 # FIXME: for SF compatibility protobuf>=4, <5 cloudpickle>=2.0.0 multiprocess>=0.70.12.2 cachetools>=5.0.0 -jax[cpu]>=0.4.16, <=0.4.26 # FIXME: Jax 0.4.26+ select perf issue +jax[cpu]>=0.4.16, <=0.4.34 # FIXME: Jax 0.4.26+ select perf issue termcolor>=2.0.0 diff --git a/setup.py b/setup.py index f1585582..cb1cac99 100644 --- a/setup.py +++ b/setup.py @@ -256,9 +256,9 @@ def has_ext_modules(self): if sys.platform == "darwin": # Due to a bug in conda x64 python, platform tag has to be 10_16 for X64 wheel if platform.machine() == "x86_64": - plat_name = "macosx_12_0_x86_64" + plat_name = "macosx_13_0_x86_64" else: - plat_name = "macosx_12_0_arm64" + plat_name = "macosx_13_0_arm64" elif platform.machine() == "aarch64": # Linux aarch64 plat_name = "manylinux_2_28_aarch64" diff --git a/sml/ensemble/BUILD.bazel b/sml/ensemble/BUILD.bazel index 7832e732..2572dc68 100644 --- a/sml/ensemble/BUILD.bazel +++ b/sml/ensemble/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "adaboost", + srcs = ["adaboost.py"], + deps = [ + "//sml/tree", + ], +) + +py_library( + name = "forest", + srcs = ["forest.py"], + deps = [ + "//sml/tree", + ], +) diff --git a/sml/ensemble/adaboost.py b/sml/ensemble/adaboost.py new file mode 100644 index 00000000..c37743dd --- /dev/null +++ b/sml/ensemble/adaboost.py @@ -0,0 +1,283 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# 不支持early_stop + +import copy +import warnings + +import jax.numpy as jnp +from jax import lax + +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + + +class AdaBoostClassifier: + """A adaboost classifier based on DecisionTreeClassifier. + + Parameters + ---------- + estimator : {"dtc"}, default="dtc" + Specifies the type of model or algorithm to be used for training. + Supported estimators are "dtc". + + n_estimators : int + The number of estimators. Must specify an integer > 0. + + learning_rate : float + The step size used to update the model weights during training. + It's an float, must learning_rate > 0. + + algorithm : str (default='discrete') + The boosting algorithm to use. Only the SAMME discrete algorithm is used in this implementation. + In scikit-learn, the Real Boosting Algorithm (SAMME.R) will be deprecated. + + epsilon : float (default=1e-5) + A small positive value used in calculations to avoid division by zero and other numerical issues. + Must be greater than 0 and less than 0.1. + + """ + + def __init__( + self, + estimator, + n_estimators, + learning_rate, + algorithm, + epsilon=1e-5, + ): + assert isinstance( + estimator, sml_dtc + ), "Estimator other than sml_dtc is not supported." + assert ( + n_estimators is not None and n_estimators > 0 + ), "n_estimators should not be None and must > 0." + assert algorithm == "discrete", ( + "Only support SAMME discrete algorithm. " + "In scikit-learn, the Real Boosting Algorithm (SAMME.R) will be deprecated. " + "You can refer to the official documentation for more details: " + "https://github.com/scikit-learn/scikit-learn/issues/26784" + ) + assert epsilon > 0 and epsilon < 0.1, "epsilon must be > 0 and < 0.1." + + self.estimator = estimator + self.n_estimators = n_estimators + self.learning_rate = learning_rate + self.algorithm = algorithm + self.epsilon = epsilon + + self.n_classes = estimator.n_labels + + self.estimators_ = [] + self.estimator_weight_ = jnp.zeros(self.n_estimators, dtype=jnp.float32) + self.estimator_errors_ = jnp.ones(self.n_estimators, dtype=jnp.float32) + self.estimator_flags_ = jnp.zeros(self.n_estimators, dtype=jnp.bool_) + self.early_stop = False # 添加 early_stop 标志 + + def _num_samples(self, x): + """返回x中的样本数量.""" + if hasattr(x, 'fit'): + # 检查是否是一个estimator + raise TypeError('Expected sequence or array-like, got estimator') + if ( + not hasattr(x, '__len__') + and not hasattr(x, 'shape') + and not hasattr(x, '__array__') + ): + raise TypeError("Expected sequence or array-like, got %s" % type(x)) + + if hasattr(x, 'shape'): + if len(x.shape) == 0: # scalar + raise TypeError( + "Singleton array %r cannot be considered a valid collection." % x + ) + return x.shape[0] + else: + return len(x) + + def _check_sample_weight(self, sample_weight, X): + ''' + Description: Validate and process sample weights. + + Parameters: + - sample_weight: Can be None, a scalar (int or float), or a 1D array-like. + - X: Input data from which to determine the number of samples. + + Returns: + - sample_weight: A 1D array of sample weights, one for each sample in X. + + Sample weight scenarios: + 1. None: + - If sample_weight is None, it will be initialized to an array of ones, + meaning all samples are equally weighted. + 2. Scalar (int or float): + - If sample_weight is a scalar, it will be converted to an array where + each sample's weight is equal to the scalar value. + 3. Array-like: + - If sample_weight is an array or array-like, it will be converted to a JAX array. + - The array must be 1D and its length must match the number of samples. + - If these conditions are not met, an error will be raised. + ''' + n_samples = self._num_samples(X) + + if sample_weight is None: + sample_weight = jnp.ones(n_samples, dtype=jnp.float32) + elif isinstance(sample_weight, (jnp.int32, jnp.float32)): + sample_weight = jnp.full(n_samples, sample_weight, dtype=jnp.float32) + else: + sample_weight = jnp.asarray(sample_weight, dtype=jnp.float32) + if sample_weight.ndim != 1: + raise ValueError("Sample weight must be 1D array or scalar") + + if sample_weight.shape[0] != n_samples: + raise ValueError( + "sample_weight.shape == {}, expected {}!".format( + sample_weight.shape, (n_samples,) + ) + ) + + return sample_weight + + def fit(self, X, y, sample_weight=None): + sample_weight = self._check_sample_weight( + sample_weight, + X, + ) + sample_weight /= sample_weight.sum() + + self.classes = y + + epsilon = self.epsilon + + for iboost in range(self.n_estimators): + sample_weight = jnp.clip(sample_weight, a_min=epsilon, a_max=None) + + estimator = copy.deepcopy(self.estimator) + sample_weight, estimator_weight, estimator_error, flag = ( + self._boost_discrete( + iboost, + X, + y, + sample_weight, + estimator, + ) + ) + + self.estimator_weight_ = self.estimator_weight_.at[iboost].set( + estimator_weight + ) + self.estimator_errors_ = self.estimator_errors_.at[iboost].set( + estimator_error + ) + self.estimator_flags_ = self.estimator_flags_.at[iboost].set(flag) + + sample_weight_sum = jnp.sum(sample_weight) + if iboost < self.n_estimators - 1: + sample_weight /= sample_weight_sum + + return self + + def _boost_discrete(self, iboost, X, y, sample_weight, estimator): + """Implement a single boost using the SAMME discrete algorithm.""" + self.estimators_.append(estimator) + + n_classes = self.n_classes + epsilon = self.epsilon + + estimator.fit(X, y, sample_weight=sample_weight) + + y_predict = estimator.predict(X) + + incorrect = y_predict != y + estimator_error = jnp.mean( + jnp.average(incorrect, weights=sample_weight, axis=0) + ) + is_small_error = estimator_error <= epsilon + + self.early_stop = jnp.logical_or(self.early_stop, is_small_error) + + def true_0_fun(sample_weight): + return sample_weight, 1.0, 0.0, jnp.array(False, dtype=jnp.bool_) + + def false_0_fun(sample_weight, estimator_error, incorrect, n_classes): + flag = estimator_error < 1.0 - (1.0 / n_classes) + flag = jnp.where(self.early_stop, jnp.array(False, dtype=jnp.bool_), flag) + + estimator_weight = self.learning_rate * ( + jnp.log((1.0 - estimator_error) / estimator_error) + + jnp.log(n_classes - 1.0) + ) + sample_weight_updated = sample_weight * jnp.exp( + estimator_weight * incorrect + ) + + sample_weight = jnp.where(flag, sample_weight_updated, sample_weight) + estimator_weight = jnp.where(flag, estimator_weight, 0.0) + + return sample_weight, estimator_weight, estimator_error, flag + + sample_weight_true, estimator_weight_true, estimator_error_true, flag_true = ( + true_0_fun(sample_weight) + ) + ( + sample_weight_false, + estimator_weight_false, + estimator_error_false, + flag_false, + ) = false_0_fun(sample_weight, estimator_error, incorrect, n_classes) + + sample_weight = jnp.where( + is_small_error, sample_weight_true, sample_weight_false + ) + estimator_weight = jnp.where( + is_small_error, estimator_weight_true, estimator_weight_false + ) + estimator_error = jnp.where( + is_small_error, estimator_error_true, estimator_error_false + ) + flag = jnp.where(is_small_error, flag_true, flag_false) + + return sample_weight, estimator_weight, estimator_error, flag + + def predict(self, X): + pred = self.decision_function(X) + + if self.n_classes == 2: + return self.classes.take(pred > 0, axis=0) + + return self.classes.take(jnp.argmax(pred, axis=1), axis=0) + + def decision_function(self, X): + n_classes = self.n_classes + classes = self.classes[:, jnp.newaxis] + + pred = sum( + jnp.where( + (estimator.predict(X) == classes).T, + w, + -1 / (n_classes - 1) * w, + ) + * flag + for estimator, w, flag in zip( + self.estimators_, self.estimator_weight_, self.estimator_flags_ + ) + ) + + weights_flags = self.estimator_weight_ * self.estimator_flags_ + pred /= jnp.sum(weights_flags) + + if n_classes == 2: + pred[:, 0] *= -1 + return pred.sum(axis=1) + return pred diff --git a/sml/ensemble/emulations/BUILD.bazel b/sml/ensemble/emulations/BUILD.bazel index 7832e732..2a3d0bd4 100644 --- a/sml/ensemble/emulations/BUILD.bazel +++ b/sml/ensemble/emulations/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +py_binary( + name = "adaboost_emul", + srcs = ["adaboost_emul.py"], + deps = [ + "//sml/ensemble:adaboost", + "//sml/utils:emulation", + ], +) + +py_binary( + name = "forest_emul", + srcs = ["forest_emul.py"], + deps = [ + "//sml/ensemble:forest", + "//sml/utils:emulation", + ], +) diff --git a/sml/ensemble/emulations/adaboost_emul.py b/sml/ensemble/emulations/adaboost_emul.py new file mode 100644 index 00000000..73cc6178 --- /dev/null +++ b/sml/ensemble/emulations/adaboost_emul.py @@ -0,0 +1,120 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import AdaBoostClassifier +from sklearn.tree import DecisionTreeClassifier + +import sml.utils.emulation as emulation +from sml.ensemble.adaboost import AdaBoostClassifier as sml_Adaboost +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + +MAX_DEPTH = 3 +CONFIG_FILE = emulation.CLUSTER_ABY3_3PC + + +def emul_ada(mode=emulation.Mode.MULTIPROCESS): + def proc_wrapper( + estimator, + n_estimators, + learning_rate, + algorithm, + epsilon, + ): + ada_custom = sml_Adaboost( + estimator=estimator, + n_estimators=n_estimators, + learning_rate=learning_rate, + algorithm=algorithm, + epsilon=epsilon, + ) + + def proc(X, y): + ada_custom_fit = ada_custom.fit(X, y, sample_weight=None) + result = ada_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) + emulator.up() + + # load mock data + X, y = load_data() + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + base_estimator = DecisionTreeClassifier(max_depth=3) # 基分类器 + ada = AdaBoostClassifier( + estimator=base_estimator, + n_estimators=3, + learning_rate=1.0, + algorithm="SAMME", + ) + + start = time.time() + ada = ada.fit(X, y) + score_plain = ada.score(X, y) + end = time.time() + print(f"Running time in SKlearn: {end - start:.2f}s") + + # mark these data to be protected in SPU + X_spu, y_spu = emulator.seal(X, y) + + # run + dtc = sml_dtc("gini", "best", 3, 3) + proc = proc_wrapper( + estimator=dtc, + n_estimators=3, + learning_rate=1.0, + algorithm="discrete", + epsilon=1e-5, + ) + start = time.time() + + result = emulator.run(proc)(X_spu, y_spu) + end = time.time() + score_encrpted = jnp.mean((result == y)) + print(f"Running time in SPU: {end - start:.2f}s") + + # print acc + print(f"Accuracy in SKlearn: {score_plain:.2f}") + print(f"Accuracy in SPU: {score_encrpted:.2f}") + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_ada(emulation.Mode.MULTIPROCESS) diff --git a/sml/ensemble/emulations/forest_emul.py b/sml/ensemble/emulations/forest_emul.py new file mode 100644 index 00000000..90437386 --- /dev/null +++ b/sml/ensemble/emulations/forest_emul.py @@ -0,0 +1,125 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier + +import sml.utils.emulation as emulation +from sml.ensemble.forest import RandomForestClassifier as sml_rfc + +MAX_DEPTH = 3 +CONFIG_FILE = emulation.CLUSTER_ABY3_3PC + + +def emul_forest(mode=emulation.Mode.MULTIPROCESS): + def proc_wrapper( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + rf_custom = sml_rfc( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ) + + def proc(X, y): + rf_custom_fit = rf_custom.fit(X, y) + result = rf_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) + emulator.up() + + # load mock data + X, y = load_data() + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + rf = RandomForestClassifier( + n_estimators=3, + max_features=None, + criterion='gini', + max_depth=MAX_DEPTH, + bootstrap=False, + max_samples=None, + ) + start = time.time() + rf = rf.fit(X, y) + score_plain = rf.score(X, y) + end = time.time() + print(f"Running time in SKlearn: {end - start:.2f}s") + + # mark these data to be protected in SPU + X_spu, y_spu = emulator.seal(X, y) + + # run + proc = proc_wrapper( + n_estimators=3, + max_features=0.7, + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=False, + max_samples=None, + n_labels=n_labels, + ) + start = time.time() + result = emulator.run(proc)(X_spu, y_spu) + end = time.time() + score_encrpted = jnp.mean((result == y)) + print(f"Running time in SPU: {end - start:.2f}s") + + # print acc + print(f"Accuracy in SKlearn: {score_plain:.2f}") + print(f"Accuracy in SPU: {score_encrpted:.2f}") + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_forest(emulation.Mode.MULTIPROCESS) diff --git a/sml/ensemble/forest.py b/sml/ensemble/forest.py new file mode 100644 index 00000000..c2dd2649 --- /dev/null +++ b/sml/ensemble/forest.py @@ -0,0 +1,201 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import random + +import jax +import jax.numpy as jnp +from jax import lax + +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + + +class RandomForestClassifier: + """A random forest classifier based on DecisionTreeClassifier. + + Parameters + ---------- + n_estimators : int + The number of trees in the forest. Must specify an integer > 0. + + max_features : int, float, "auto", "sqrt", "log2", or None. + The number of features to consider when looking for the best split. + If it's an integer, must 0 < integer < n_features. + If it's an float, must 0 < float <= 1. + + criterion : {"gini"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity. + + splitter : {"best"}, default="best" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split. + + max_depth : int + The maximum depth of the tree. Must specify an integer > 0. + + bootstrap : bool + Whether bootstrap samples are used when building trees. + + max_samples : int, float ,None, default=None + The number of samples to draw from X to train each base estimator. + This parameter is only valid if bootstrap is ture. + If it's an integer, must 0 < integer < n_samples. + If it's an float, must 0 < float <= 1. + + n_labels: int + The max number of labels. + + """ + + def __init__( + self, + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + assert criterion == "gini", "criteria other than gini is not supported." + assert splitter == "best", "splitter other than best is not supported." + assert ( + n_estimators is not None and n_estimators > 0 + ), "n_estimators should not be None and must > 0." + assert ( + max_depth is not None and max_depth > 0 + ), "max_depth should not be None and must > 0." + assert isinstance( + bootstrap, bool + ), "bootstrap should be a boolean value (True or False)" + + self.n_estimators = n_estimators + self.max_features = max_features + self.criterion = criterion + self.splitter = splitter + self.max_depth = max_depth + self.bootstrap = bootstrap + self.max_samples = max_samples + self.n_labels = n_labels + + self.trees = [] + self.features_indices = [] + + def _calculate_max_samples(self, max_samples, n_samples): + if isinstance(max_samples, int): + assert ( + max_samples <= n_samples + ), "max_samples should not exceed n_samples when it's an integer" + return max_samples + elif isinstance(max_samples, float): + assert ( + 0 < max_samples <= 1 + ), "max_samples should be in the range (0, 1] when it's a float" + return int(max_samples * n_samples) + else: + return n_samples + + def _bootstrap_sample(self, X, y): + n_samples = X.shape[0] + max_samples = self._calculate_max_samples(self.max_samples, n_samples) + + if not self.bootstrap: + return X, y + + # 实现bootstrap + population = range(n_samples) + indices = random.sample(population, max_samples) + + indices = jnp.array(indices) + return X[indices], y[indices] + + def _select_features(self, n, k): + indices = range(n) + selected_elements = random.sample(indices, k) + return selected_elements + + def _calculate_max_features(self, max_features, n_features): + if isinstance(max_features, int): + assert ( + 0 < max_features <= n_features + ), "0 < max_features <= n_features when it's an integer" + return max_features + + elif isinstance(max_features, float): + assert ( + 0 < max_features <= 1 + ), "max_features should be in the range (0, 1] when it's a float" + return int(max_features * n_features) + + elif isinstance(max_features, str): + if max_features == 'sqrt': + return int(math.sqrt(n_features)) + elif max_features == 'log2': + return int(math.log2(n_features)) + else: + return n_features + else: + return n_features + + def fit(self, X, y): + n_samples, n_features = X.shape + self.n_features = n_features + self.max_features = self._calculate_max_features( + self.max_features, self.n_features + ) + self.label_list = jnp.arange(self.n_labels) + + self.trees = [] + self.features_indices = [] + + for _ in range(self.n_estimators): + X_sample, y_sample = self._bootstrap_sample(X, y) + features = self._select_features(self.n_features, self.max_features) + + tree = sml_dtc(self.criterion, self.splitter, self.max_depth, self.n_labels) + tree.fit(X_sample[:, features], y_sample) + self.trees.append(tree) + self.features_indices.append(features) + + return self + + def jax_mode_row_vectorized(self, data): + label_list = jnp.array(self.label_list) + + data_expanded = jnp.expand_dims(data, axis=-1) + label_expanded = jnp.expand_dims(label_list, axis=0) + + mask = (data_expanded == label_expanded).astype(jnp.int32) + + counts = jnp.sum(mask, axis=1) + mode_indices = jnp.argmax(counts, axis=1) + + modes = label_list[mode_indices] + return modes + + def predict(self, X): + predictions_list = [] + for i, tree in enumerate(self.trees): + features = self.features_indices[i] + predictions = tree.predict(X[:, features]) + predictions_list.append(predictions) + + tree_predictions = jnp.array(predictions_list).T + + y_pred = self.jax_mode_row_vectorized(tree_predictions) + + return y_pred.ravel() diff --git a/sml/ensemble/tests/BUILD.bazel b/sml/ensemble/tests/BUILD.bazel index 7832e732..6815cf85 100644 --- a/sml/ensemble/tests/BUILD.bazel +++ b/sml/ensemble/tests/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "adaboost_test", + srcs = ["adaboost_test.py"], + deps = [ + "//sml/ensemble:adaboost", + "//spu:init", + "//spu/utils:simulation", + ], +) + +py_test( + name = "forest_test", + srcs = ["forest_test.py"], + deps = [ + "//sml/ensemble:forest", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/ensemble/tests/adaboost_test.py b/sml/ensemble/tests/adaboost_test.py new file mode 100644 index 00000000..71f6b92d --- /dev/null +++ b/sml/ensemble/tests/adaboost_test.py @@ -0,0 +1,108 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import AdaBoostClassifier +from sklearn.tree import DecisionTreeClassifier + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.ensemble.adaboost import AdaBoostClassifier as sml_Adaboost +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + +MAX_DEPTH = 3 + + +class UnitTests(unittest.TestCase): + def test_Ada(self): + def proc_wrapper( + estimator, + n_estimators, + learning_rate, + algorithm, + epsilon, + ): + ada_custom = sml_Adaboost( + estimator=estimator, + n_estimators=n_estimators, + learning_rate=learning_rate, + algorithm=algorithm, + epsilon=epsilon, + ) + + def proc(X, y): + ada_custom_fit = ada_custom.fit(X, y, sample_weight=None) + result = ada_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + X, y = load_data() + n_samples, n_features = X.shape + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + base_estimator = DecisionTreeClassifier(max_depth=3) # 基分类器 + ada = AdaBoostClassifier( + estimator=base_estimator, + n_estimators=3, + learning_rate=1.0, + algorithm="SAMME", + ) + ada = ada.fit(X, y) + score_plain = ada.score(X, y) + + # run + dtc = sml_dtc("gini", "best", 3, 3) + proc = proc_wrapper( + estimator=dtc, + n_estimators=3, + learning_rate=1.0, + algorithm="discrete", + epsilon=1e-5, + ) + + result = spsim.sim_jax(sim, proc)(X, y) + print(result) + score_encrypted = jnp.mean(result == y) + + # print acc + print(f"Accuracy in SKlearn: {score_plain}") + print(f"Accuracy in SPU: {score_encrypted}") + + +if __name__ == '__main__': + unittest.main() diff --git a/sml/ensemble/tests/forest_test.py b/sml/ensemble/tests/forest_test.py new file mode 100644 index 00000000..6c9d3280 --- /dev/null +++ b/sml/ensemble/tests/forest_test.py @@ -0,0 +1,117 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.ensemble.forest import RandomForestClassifier as sml_rfc + +MAX_DEPTH = 3 + + +class UnitTests(unittest.TestCase): + def test_forest(self): + def proc_wrapper( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + rf_custom = sml_rfc( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ) + + def proc(X, y): + rf_custom_fit = rf_custom.fit(X, y) + + result = rf_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + # bandwidth and latency only work for docker mode + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + # load mock data + X, y = load_data() + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + rf = RandomForestClassifier( + n_estimators=3, + max_features="log2", + criterion='gini', + max_depth=MAX_DEPTH, + bootstrap=True, + max_samples=0.7, + ) + rf = rf.fit(X, y) + score_plain = rf.score(X, y) + tree_predictions = jnp.array([tree.predict(X) for tree in rf.estimators_]) + + # run + proc = proc_wrapper( + n_estimators=3, + max_features="log2", + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=True, + max_samples=0.7, + n_labels=n_labels, + ) + + result = spsim.sim_jax(sim, proc)(X, y) + + score_encrpted = jnp.mean((result == y)) + + # print acc + print(f"Accuracy in SKlearn: {score_plain}") + print(f"Accuracy in SPU: {score_encrpted}") + + +if __name__ == "__main__": + unittest.main() diff --git a/sml/linear_model/BUILD.bazel b/sml/linear_model/BUILD.bazel index 6fa078a8..fa4fdd15 100644 --- a/sml/linear_model/BUILD.bazel +++ b/sml/linear_model/BUILD.bazel @@ -54,3 +54,11 @@ py_binary( "//sml/linear_model/utils:solver", ], ) + +py_library( + name = "quantile", + srcs = ["quantile.py"], + deps = [ + "//sml/linear_model/utils:_linprog_simplex", + ], +) diff --git a/sml/linear_model/emulations/BUILD.bazel b/sml/linear_model/emulations/BUILD.bazel index 6778cd0c..46df0f92 100644 --- a/sml/linear_model/emulations/BUILD.bazel +++ b/sml/linear_model/emulations/BUILD.bazel @@ -62,3 +62,12 @@ py_binary( "//sml/utils:emulation", ], ) + +py_binary( + name = "quantile_emul", + srcs = ["quantile_emul.py"], + deps = [ + "//sml/linear_model:quantile", + "//sml/utils:emulation", + ], +) diff --git a/sml/linear_model/emulations/quantile_emul.py b/sml/linear_model/emulations/quantile_emul.py new file mode 100644 index 00000000..ed81b4c3 --- /dev/null +++ b/sml/linear_model/emulations/quantile_emul.py @@ -0,0 +1,104 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import jax.numpy as jnp +from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor + +import sml.utils.emulation as emulation +from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor + +CONFIG_FILE = emulation.CLUSTER_ABY3_3PC + + +def emul_quantile(mode=emulation.Mode.MULTIPROCESS): + def proc_wrapper( + quantile, + alpha, + fit_intercept, + lr, + max_iter, + ): + quantile_custom = SmlQuantileRegressor( + quantile=quantile, + alpha=alpha, + fit_intercept=fit_intercept, + lr=lr, + max_iter=max_iter, + ) + + def proc(X, y): + quantile_custom_fit = quantile_custom.fit(X, y) + result = quantile_custom_fit.predict(X) + return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_ + + return proc + + def generate_data(): + from jax import random + + key = random.PRNGKey(42) + key, subkey = random.split(key) + X = random.normal(subkey, (100, 2)) + y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1 + return X, y + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) + emulator.up() + + # load mock data + X, y = generate_data() + + # compare with sklearn + quantile_sklearn = SklearnQuantileRegressor( + quantile=0.2, alpha=0.1, fit_intercept=True, solver='highs' + ) + start = time.time() + quantile_sklearn_fit = quantile_sklearn.fit(X, y) + y_pred_plain = quantile_sklearn_fit.predict(X) + rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2)) + end = time.time() + print(f"Running time in SKlearn: {end - start:.2f}s") + print(quantile_sklearn_fit.coef_) + print(quantile_sklearn_fit.intercept_) + + # mark these data to be protected in SPU + X_spu, y_spu = emulator.seal(X, y) + + # run + # Larger max_iter can give higher accuracy, but it will take more time to run + proc = proc_wrapper( + quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=200 + ) + start = time.time() + result, coef, intercept = emulator.run(proc)(X_spu, y_spu) + end = time.time() + rmse_encrpted = jnp.sqrt(jnp.mean((y - result) ** 2)) + print(f"Running time in SPU: {end - start:.2f}s") + print(coef) + print(intercept) + + # print RMSE + print(f"RMSE in SKlearn: {rmse_plain:.2f}") + print(f"RMSE in SPU: {rmse_encrpted:.2f}") + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_quantile(emulation.Mode.MULTIPROCESS) diff --git a/sml/linear_model/quantile.py b/sml/linear_model/quantile.py new file mode 100644 index 00000000..549e67ae --- /dev/null +++ b/sml/linear_model/quantile.py @@ -0,0 +1,196 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +import pandas as pd +from jax import grad + +from sml.linear_model.utils._linprog_simplex import _linprog_simplex + + +class QuantileRegressor: + """ + Initialize the quantile regression model. + Parameters + ---------- + quantile : float, default=0.5 + The quantile to be predicted. Must be between 0 and 1. + A quantile of 0.5 corresponds to the median (50th percentile). + alpha : float, default=1.0 + Regularization strength; must be a positive float. + Larger values specify stronger regularization, reducing model complexity. + fit_intercept : bool, default=True + Whether to calculate the intercept for the model. + If False, no intercept will be used in calculations, meaning the model will + assume that the data is already centered. + lr : float, default=0.01 + Learning rate for the optimization process. This controls the size of + the steps taken in each iteration towards minimizing the objective function. + max_iter : int, default=1000 + The maximum number of iterations for the optimization algorithm. + This controls how long the model will continue to update the weights + before stopping. + max_val : float, default=1e10 + The maximum value allowed for the model parameters. + Attributes + ---------- + coef_ : array-like of shape (n_features,) + The coefficients (weights) assigned to the input features. These will be + learned during model fitting. + intercept_ : float + The intercept (bias) term. If `fit_intercept=True`, this will be + learned during model fitting. + """ + + def __init__( + self, + quantile=0.5, + alpha=1.0, + fit_intercept=True, + lr=0.01, + max_iter=1000, + max_val=1e10, + ): + self.quantile = quantile + self.alpha = alpha + self.fit_intercept = fit_intercept + self.lr = lr + self.max_iter = max_iter + self.max_val = max_val + + self.coef_ = None + self.intercept_ = None + + def fit(self, X, y, sample_weight=None): + """ + Fit the quantile regression model using linear programming. + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training data. + y : array-like of shape (n_samples,) + Target values. + sample_weight : array-like of shape (n_samples,), optional + Individual weights for each sample. If not provided, all samples + are assumed to have equal weight. + Returns + ------- + self : object + Returns an instance of self. + Steps: + 1. Determine the number of parameters (`n_params`), accounting for the intercept if needed. + 2. Define the objective function `c`, incorporating both the L1 regularization and the pinball loss. + 3. Set up the equality constraint matrix `A_eq` and vector `b_eq` based on the input data `X` and `y`. + 4. Solve the linear programming problem using `_linprog_simplex`. + 5. Extract the model parameters (intercept and coefficients) from the solution. + """ + n_samples, n_features = X.shape + n_params = n_features + + if sample_weight is None: + sample_weight = jnp.ones((n_samples,)) + + if self.fit_intercept: + n_params += 1 + + alpha = jnp.sum(sample_weight) * self.alpha + + # After rescaling alpha, the minimization problem is + # min sum(pinball loss) + alpha * L1 + # Use linear programming formulation of quantile regression + # min_x c x + # A_eq x = b_eq + # 0 <= x + # x = (s0, s, t0, t, u, v) = slack variables >= 0 + # intercept = s0 - t0 + # coef = s - t + # c = (0, alpha * 1_p, 0, alpha * 1_p, quantile * 1_n, (1-quantile) * 1_n) + # residual = y - X@coef - intercept = u - v + # A_eq = (1_n, X, -1_n, -X, diag(1_n), -diag(1_n)) + # b_eq = y + # p = n_features + # n = n_samples + # 1_n = vector of length n with entries equal one + # see https://stats.stackexchange.com/questions/384909/ + c = jnp.concatenate( + [ + jnp.full(2 * n_params, fill_value=alpha), + sample_weight * self.quantile, + sample_weight * (1 - self.quantile), + ] + ) + + if self.fit_intercept: + c = c.at[0].set(0) + c = c.at[n_params].set(0) + + eye = jnp.eye(n_samples) + if self.fit_intercept: + ones = jnp.ones((n_samples, 1)) + A = jnp.concatenate([ones, X, -ones, -X, eye, -eye], axis=1) + else: + A = jnp.concatenate([X, -X, eye, -eye], axis=1) + + b = y + + result = _linprog_simplex( + c, A, b, maxiter=self.max_iter, tol=1e-3, max_val=self.max_val + ) + + solution = result + + params = solution[:n_params] - solution[n_params : 2 * n_params] + + if self.fit_intercept: + self.coef_ = params[1:] + self.intercept_ = params[0] + else: + self.coef_ = params + self.intercept_ = 0.0 + return self + + def predict(self, X): + """ + Predict target values using the fitted quantile regression model. + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Input data for which predictions are to be made. + Returns + ------- + y_pred : array-like of shape (n_samples,) + Predicted target values. + Notes + ----- + The predict method computes the predicted target values using the model's + learned coefficients and intercept (if fit_intercept=True). + - If the model includes an intercept, a column of ones is added to the input data `X` to account + for the intercept in the linear combination. + - The method then computes the dot product between the modified `X` and the stacked vector of + intercept and coefficients. + - If there is no intercept, the method simply computes the dot product between `X` and the coefficients. + """ + + assert ( + self.coef_ is not None and self.intercept_ is not None + ), "Model has not been fitted yet. Please fit the model before predicting." + + n_features = len(self.coef_) + assert X.shape[1] == n_features, ( + f"Input X must have {n_features} features, " + f"but got {X.shape[1]} features instead." + ) + + return jnp.dot(X, self.coef_) + self.intercept_ diff --git a/sml/linear_model/tests/BUILD.bazel b/sml/linear_model/tests/BUILD.bazel index 1fa04f86..f729c206 100644 --- a/sml/linear_model/tests/BUILD.bazel +++ b/sml/linear_model/tests/BUILD.bazel @@ -70,3 +70,13 @@ py_test( "//spu/utils:simulation", ], ) + +py_test( + name = "quantile_test", + srcs = ["quantile_test.py"], + deps = [ + "//sml/linear_model:quantile", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/linear_model/tests/glm_test.py b/sml/linear_model/tests/glm_test.py index a9e7c734..f42b09b6 100644 --- a/sml/linear_model/tests/glm_test.py +++ b/sml/linear_model/tests/glm_test.py @@ -106,7 +106,7 @@ def accuracy_test(model, std_model, y, coef, num=5): assert norm_diff < 1e-2 -def proc_test(proc): +def proc_test(proc, x, y): """ Test if the results of the specified fitting algorithm are correct. @@ -121,8 +121,8 @@ def proc_test(proc): """ # Run the simulation and get the results - sim_res = spsim.sim_jax(sim, proc)() - res = proc() + sim_res = spsim.sim_jax(sim, proc)(x, y) + res = proc(x, y) # Calculate the difference between simulation and actual results norm_diff = jnp.linalg.norm(sim_res - res) @@ -130,10 +130,10 @@ def proc_test(proc): print(proc.__name__, "-norm_diff:", "%.5f" % norm_diff) # Assert that the difference is within the tolerance - assert norm_diff < 1e-4 + assert norm_diff < 5e-1 -def proc_ncSolver(): +def proc_ncSolver(X, y): """ Fit Generalized Linear Regression model using Newton-Cholesky algorithm and return the model coefficients. @@ -163,7 +163,7 @@ def proc_lbfgsSolver(): return model.coef_ -def proc_Poisson(): +def proc_Poisson(X, round_exp_y): """ Fit Generalized Linear Regression model using PoissonRegressor and return the model coefficients. @@ -178,7 +178,7 @@ def proc_Poisson(): return model.coef_ -def proc_Gamma(): +def proc_Gamma(X, exp_y): """ Fit Generalized Linear Regression model using GammaRegressor and return the model coefficients. @@ -193,7 +193,7 @@ def proc_Gamma(): return model.coef_ -def proc_Tweedie(): +def proc_Tweedie(X, exp_y): """ Fit Generalized Linear Regression model using TweedieRegressor and return the model coefficients. @@ -239,22 +239,22 @@ def test_Tweedie_accuracy(self, power=1.5): def test_ncSolver_encrypted(self): # Test if the results of the Newton-Cholesky solver are correct after encryption - proc_test(proc_ncSolver) + proc_test(proc_ncSolver, X, y) print('test_ncSolver_encrypted: OK') def test_Poisson_encrypted(self): # Test if the results of the PoissonRegressor model are correct after encryption - proc_test(proc_Poisson) + proc_test(proc_Poisson, X, round_exp_y) print('test_Poisson_encrypted: OK') def test_gamma_encrypted(self): # Test if the results of the GammaRegressor model are correct after encryption - proc_test(proc_Gamma) + proc_test(proc_Gamma, X, exp_y) print('test_gamma_encrypted: OK') def test_Tweedie_encrypted(self): # Test if the results of the TweedieRegressor model are correct after encryption - proc_test(proc_Tweedie) + proc_test(proc_Tweedie, X, exp_y) print('test_Tweedie_encrypted: OK') diff --git a/sml/linear_model/tests/quantile_test.py b/sml/linear_model/tests/quantile_test.py new file mode 100644 index 00000000..d9d12f1f --- /dev/null +++ b/sml/linear_model/tests/quantile_test.py @@ -0,0 +1,93 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import jax.numpy as jnp +from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor + + +class UnitTests(unittest.TestCase): + def test_quantile(self): + def proc_wrapper( + quantile, + alpha, + fit_intercept, + lr, + max_iter, + ): + quantile_custom = SmlQuantileRegressor( + quantile=quantile, + alpha=alpha, + fit_intercept=fit_intercept, + lr=lr, + max_iter=max_iter, + ) + + def proc(X, y): + quantile_custom_fit = quantile_custom.fit(X, y) + result = quantile_custom_fit.predict(X) + return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_ + + return proc + + n_samples, n_features = 100, 2 + + def generate_data(): + from jax import random + + key = random.PRNGKey(42) + key, subkey = random.split(key) + X = random.normal(subkey, (100, 2)) + y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1 + return X, y + + # bandwidth and latency only work for docker mode + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + X, y = generate_data() + + # compare with sklearn + quantile_sklearn = SklearnQuantileRegressor( + quantile=0.2, alpha=0.1, fit_intercept=True, solver='revised simplex' + ) + quantile_sklearn_fit = quantile_sklearn.fit(X, y) + y_pred_plain = quantile_sklearn_fit.predict(X) + rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2)) + print(f"RMSE in SKlearn: {rmse_plain:.2f}") + print(quantile_sklearn_fit.coef_) + print(quantile_sklearn_fit.intercept_) + + # run + # Larger max_iter can give higher accuracy, but it will take more time to run + proc = proc_wrapper( + quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=20 + ) + result, coef, intercept = spsim.sim_jax(sim, proc)(X, y) + rmse_encrpted = jnp.sqrt(jnp.mean((y - result) ** 2)) + + # print RMSE + print(f"RMSE in SPU: {rmse_encrpted:.2f}") + print(coef) + print(intercept) + + +if __name__ == "__main__": + unittest.main() diff --git a/sml/linear_model/utils/BUILD.bazel b/sml/linear_model/utils/BUILD.bazel index 7c13def5..27329073 100644 --- a/sml/linear_model/utils/BUILD.bazel +++ b/sml/linear_model/utils/BUILD.bazel @@ -31,3 +31,8 @@ py_library( name = "solver", srcs = ["solver.py"], ) + +py_library( + name = "_linprog_simplex", + srcs = ["_linprog_simplex.py"], +) diff --git a/sml/linear_model/utils/_linprog_simplex.py b/sml/linear_model/utils/_linprog_simplex.py new file mode 100644 index 00000000..0ae02578 --- /dev/null +++ b/sml/linear_model/utils/_linprog_simplex.py @@ -0,0 +1,156 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +from jax import jit, lax + + +def _pivot_col(T, tol=1e-5): + mask = T[-1, :-1] >= -tol + + all_masked = jnp.all(mask) + + ma = jnp.where(mask, jnp.inf, T[-1, :-1]) + min_col = jnp.argmin(ma) + + valid = ~all_masked + result = jnp.where(all_masked, 0, min_col) + + return valid, result + + +def _pivot_row(T, pivcol, phase, tol=1e-5, max_val=1e10): + if phase == 1: + k = 2 + else: + k = 1 + + mask = T[:-k, pivcol] <= tol + ma = jnp.where(mask, jnp.inf, T[:-k, pivcol]) + mb = jnp.where(mask, jnp.inf, T[:-k, -1]) + + q = jnp.where(ma >= max_val, jnp.inf, mb / ma) + + min_rows = jnp.nanargmin(q) + all_masked = jnp.all(mask) + + row = min_rows + row = jnp.where(all_masked, 0, row) + + return ~all_masked, row + + +def _apply_pivot(T, basis, pivrow, pivcol): + pivrow = jnp.int32(pivrow) + pivcol = jnp.int32(pivcol) + + basis = basis.at[pivrow].set(pivcol) + + pivrow_one_hot = jax.nn.one_hot(pivrow, T.shape[0]) + pivcol_one_hot = jax.nn.one_hot(pivcol, T.shape[1]) + + pivval = jnp.dot(pivrow_one_hot, jnp.dot(T, pivcol_one_hot)) + + updated_row = T[pivrow] / pivval + T = pivrow_one_hot[:, None] * updated_row + T * (1 - pivrow_one_hot[:, None]) + + scalar = jnp.dot(T, pivcol_one_hot).reshape(-1, 1) + + updated_T = T - scalar * T[pivrow] + + row_restore_matrix = pivrow_one_hot[:, None] * T[pivrow] + updated_T = row_restore_matrix + updated_T * (1 - pivrow_one_hot[:, None]) + + return updated_T, basis + + +def _solve_simplex( + T, + n, + basis, + maxiter=100, + tol=1e-5, + max_val=1e10, + phase=2, +): + complete = False + + num = 0 + pivcol = 0 + pivrow = 0 + while num < maxiter: + pivcol_found, pivcol = _pivot_col(T, tol) + + def cal_pivcol_found_True(T, pivcol, phase, tol, complete): + pivrow_found, pivrow = _pivot_row(T, pivcol, phase, tol, max_val) + + pivrow_isnot_found = pivrow_found == False + complete = jnp.where(pivrow_isnot_found, True, complete) + + return pivrow, complete + + pivcol_is_found = pivcol_found == True + pivrow_True, complete_True = cal_pivcol_found_True( + T, pivcol, phase, tol, complete + ) + + pivrow = jnp.where(pivcol_is_found, pivrow_True, 0) + + complete = jnp.where(pivcol_is_found, complete_True, complete) + + complete_is_False = complete == False + apply_T, apply_basis = _apply_pivot(T, basis, pivrow, pivcol) + T = jnp.where(complete_is_False, apply_T, T) + basis = jnp.where(complete_is_False, apply_basis, basis) + num = num + 1 + + return T, basis + + +def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, max_val=1e10): + n, m = A.shape + + # All constraints must have b >= 0. + is_negative_constraint = jnp.less(b, 0) + A = jnp.where(is_negative_constraint[:, None], A * -1, A) + b = jnp.where(is_negative_constraint, b * -1, b) + + av = jnp.arange(n) + m + basis = av.copy() + + row_constraints = jnp.hstack((A, jnp.eye(n), b[:, jnp.newaxis])) + row_objective = jnp.hstack((c, jnp.zeros(n), c0)) + row_pseudo_objective = -row_constraints.sum(axis=0) + row_pseudo_objective = row_pseudo_objective.at[av].set(0) + T = jnp.vstack((row_constraints, row_objective, row_pseudo_objective)) + + # phase 1 + T, basis = _solve_simplex( + T, n, basis, maxiter=maxiter, tol=tol, max_val=max_val, phase=1 + ) + + T_new = T[:-1, :] + T = jnp.delete(T_new, av, 1, assume_unique_indices=True) + + # phase 2 + T, basis = _solve_simplex( + T, n, basis, maxiter=maxiter, tol=tol, max_val=max_val, phase=2 + ) + + solution = jnp.zeros(n + m) + solution = solution.at[basis[:n]].set(T[:n, -1]) + x = solution[:m] + + return x diff --git a/sml/metrics/classification/BUILD.bazel b/sml/metrics/classification/BUILD.bazel index 9dd7d487..0bcaebe5 100644 --- a/sml/metrics/classification/BUILD.bazel +++ b/sml/metrics/classification/BUILD.bazel @@ -21,6 +21,7 @@ py_library( srcs = ["classification.py"], deps = [ ":auc", + "//sml/preprocessing", "//spu/ops/groupby", ], ) diff --git a/sml/metrics/classification/auc.py b/sml/metrics/classification/auc.py index 1a21536d..be20526c 100644 --- a/sml/metrics/classification/auc.py +++ b/sml/metrics/classification/auc.py @@ -20,12 +20,14 @@ from spu.ops.groupby import groupby_sorted -def binary_clf_curve(sorted_pairs: jnp.array) -> Tuple[jnp.array, jnp.array, jnp.array]: +def binary_clf_curve( + sorted_pairs: jnp.ndarray, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Calculate true and false positives per binary classification threshold (can be used for roc curve or precision/recall curve). Results may include trailing zeros. Args: - sorted_pairs: jnp.array + sorted_pairs: jnp.ndarray y_true y_score pairs sorted by y_score in decreasing order Returns: fps: 1d ndarray @@ -57,6 +59,7 @@ def binary_clf_curve(sorted_pairs: jnp.array) -> Tuple[jnp.array, jnp.array, jnp fps = seg_end_marks * fps thresholds = seg_end_marks * thresholds thresholds, fps, tps = jax.lax.sort([-thresholds] + [fps, tps], num_keys=1) + return fps, tps, -thresholds diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py index 2d3882fc..1eec8a57 100644 --- a/sml/metrics/classification/classification.py +++ b/sml/metrics/classification/classification.py @@ -16,10 +16,12 @@ import jax import jax.numpy as jnp -from auc import binary_roc_auc +from sml.preprocessing.preprocessing import label_binarize from spu.ops.groupby import groupby, groupby_sum +from .auc import binary_clf_curve, binary_roc_auc + def roc_auc_score(y_true, y_pred): sorted_arr = create_sorted_label_score_pair(y_true, y_pred) @@ -222,3 +224,155 @@ def fun_score( else: raise ValueError("average should be None or 'binary'") return fun_result + + +def precision_recall_curve( + y_true: jnp.ndarray, y_score: jnp.ndarray, pos_label=1, score_eps=1e-5 +): + """Compute precision-recall pairs for different probability thresholds. + + Note: this implementation is restricted to the binary classification task. + + Parameters + ---------- + y_true : 1d array-like of shape (n,). True binary labels. + + y_score : 1d array-like of shape (n,). Target scores, non-negative. + + pos_label : int, default=1. The label of the positive class. + + score_eps : float, default=1e-5. The lower bound for y_score. + + Returns + ------- + precisions : ndarray of shape (n + 1,). + Precision values where element i is the precision s.t. + score >= thresholds[i] and the last element is 1. + + recalls : ndarray of shape (n + 1,). + Increasing recall values where element i is the recall s.t. + score >= thresholds[i] and the last element is 0. + + thresholds : ndarray of shape (n,). + Decreasing thresholds used to compute precision and recall. + Results might include trailing zeros. + """ + + # normalize the input + y_true = jnp.where(y_true == pos_label, 1, 0) + y_score = jnp.where( + y_score < score_eps, score_eps, y_score + ) # to avoid messing up trailing zero and score zero + + # compute TP and FP + sorted_pairs = create_sorted_label_score_pair(y_true, y_score) + fp, tp, thresholds = binary_clf_curve(sorted_pairs) + + # compute precision and recalls + mask = jnp.where(thresholds > 0, 1, 0) # tied value entries have mask=0 + precisions = jnp.where(mask, tp / (tp + fp + 1e-5), 0) + max_tp = jnp.max(tp) + recalls = jnp.where(max_tp == 0, jnp.ones_like(tp), tp / max_tp) + + return ( + jnp.hstack((1, precisions)), + jnp.hstack((0, recalls)), + thresholds, + ) + + +def average_precision_score( + y_true: jnp.ndarray, + y_score: jnp.ndarray, + classes=(0, 1), + average="macro", + pos_label=1, + score_eps=1e-5, +): + """Compute average precision (AP) from prediction scores. + + .. math:: + \\text{AP} = \\sum_n (R_n - R_{n-1}) P_n + + Parameters + ------- + y_true : array-like of shape (n_samples,) + True labels. + + y_score : array-like of shape (n_samples,) or (n_samples, n_classes) + Estimated target scores as returned by a classifier, non-negative. + + classes : 1d array-like, shape (n_classes,), default=(0,1) as for binary classification + Uniquely holds the label for each class. + SPU cannot support dynamic shape, so this parameter needs to be designated. + + average : {'macro', 'micro', None}, default='macro' + This parameter is required for multiclass/multilabel targets and + will be ignored when y_true is binary. + + 'macro': + Calculate metrics for each label, and find their unweighted mean. + 'micro': + Calculate metrics globally by considering each element of the label + indicator matrix as a label. + None: + Scores for each class are returned. + + pos_label : int, default=1 + The label of the positive class. Only applied to binary y_true. + + score_eps : float, default=1e-5. The lower bound for y_score. + + Returns + ------- + average_precision : float + Average precision score. + """ + + assert average in ( + 'macro', + 'micro', + None, + ), 'average must be either "macro", "micro" or None' + + def binary_average_precision(y_true, y_score, pos_label=1): + """Compute the average precision for binary classification.""" + precisions, recalls, _ = precision_recall_curve( + y_true, y_score, pos_label=pos_label, score_eps=score_eps + ) + + return jnp.sum(jnp.diff(recalls) * precisions[1:]) + + n_classes = len(classes) + if n_classes <= 2: + # binary classification + # given y_true all the same is a special case considered as binary classification + return binary_average_precision(y_true, y_score, pos_label=pos_label) + else: + # multi-class classification + # binarize labels using one-vs-all scheme into multilabel-indicator + y_true = label_binarize(y_true, classes=classes, n_classes=n_classes) + + if average == "micro": + y_true = y_true.ravel() + y_score = y_score.ravel() + elif average == "macro": + pass + + # extend the classes dimension if needed + if y_true.ndim == 1: + y_true = y_true[:, jnp.newaxis] + if y_score.ndim == 1: + y_score = y_score[:, jnp.newaxis] + + # compute score for each class + n_classes = y_score.shape[1] + score = jnp.zeros((n_classes,)) + for c in range(n_classes): + binary_ap = binary_average_precision( + y_true[:, c].ravel(), y_score[:, c].ravel(), pos_label=pos_label + ) + score = score.at[c].set(binary_ap) + + # average the scores + return jnp.average(score) if average else score diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index 0cb51686..afa2e65a 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import numpy as np from sklearn import metrics +from sklearn.metrics import average_precision_score as sk_average_precision_score # add ops dir to the path sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) @@ -25,6 +26,7 @@ import sml.utils.emulation as emulation from sml.metrics.classification.classification import ( accuracy_score, + average_precision_score, f1_score, precision_score, recall_score, @@ -42,7 +44,7 @@ def emul_auc(mode: emulation.Mode.MULTIPROCESS): # Run result = emulator.run(roc_auc_score)( - y_true, y_pred + *emulator.seal(y_true, y_pred) ) # X, y should be two-dimension array print(result) @@ -97,7 +99,7 @@ def check(spu_result, sk_result): y_true = jnp.array([0, 1, 1, 0, 1, 1]) y_pred = jnp.array([0, 0, 1, 0, 1, 1]) spu_result = emulator.run(proc, static_argnums=(2, 5))( - y_true, y_pred, 'binary', None, 1, False + *emulator.seal(y_true, y_pred), 'binary', None, 1, False ) sk_result = sklearn_proc(y_true, y_pred) check(spu_result, sk_result) @@ -106,12 +108,83 @@ def check(spu_result, sk_result): y_true = jnp.array([0, 1, 1, 0, 2, 1]) y_pred = jnp.array([0, 0, 1, 0, 2, 1]) spu_result = emulator.run(proc, static_argnums=(2, 5))( - y_true, y_pred, None, [0, 1, 2], 1, True + *emulator.seal(y_true, y_pred), None, [0, 1, 2], 1, True ) sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) check(spu_result, sk_result) +def emul_average_precision_score(mode: emulation.Mode.MULTIPROCESS): + def procBinary(y_true, y_score, **kwargs): + sk_res = sk_average_precision_score(y_true, y_score, **kwargs) + spu_res = emulator.run(average_precision_score)( + *emulator.seal(y_true, y_score), **kwargs + ) + return sk_res, spu_res + + def check(res1, res2): + return np.testing.assert_allclose(res1, res2, rtol=1e-3, atol=1e-3) + + # --- Test binary classification --- + # 0-1 labels, no tied value + y_true = jnp.array([0, 0, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0.1, 0.4, 0.35, 0.8], dtype=jnp.float32) + check(*procBinary(y_true, y_score)) + # 0-1 labels, with tied value, even length + y_true = jnp.array([0, 0, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0.4, 0.4, 0.4, 0.25], dtype=jnp.float32) + check(*procBinary(y_true, y_score)) + # 0-1 labels, with tied value, odd length + y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0.4, 0.4, 0.4, 0.25, 0.25], dtype=jnp.float32) + check(*procBinary(y_true, y_score)) + # customized labels + y_true = jnp.array([2, 2, 3, 3], dtype=jnp.int32) + y_score = jnp.array([0.1, 0.2, 0.3, 0.4], dtype=jnp.float32) + check(*procBinary(y_true, y_score, pos_label=3)) + # larger random dataset + y_true = jnp.array(np.random.randint(0, 2, 100), dtype=jnp.int32) + y_score = jnp.array(np.hstack((0, 1, np.random.random(98))), dtype=jnp.float32) + check(*procBinary(y_true, y_score)) + # single label edge case + y_true = jnp.array([0, 0, 0, 0], dtype=jnp.int32) + y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32) + check(*procBinary(y_true, y_score)) + y_true = jnp.array([1, 1, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32) + check(*procBinary(y_true, y_score)) + # zero score edge case + y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0, 0, 0, 0.25, 0.25], dtype=jnp.float32) + check(*procBinary(y_true, y_score)) + # score > 1 edge case + y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32) + y_score = jnp.array([1.5, 1.5, 1.5, 0.25, 0.25], dtype=jnp.float32) + check(*procBinary(y_true, y_score)) + + # --- Test multiclass classification --- + y_true = np.array([0, 0, 1, 1, 2, 2], dtype=jnp.int32) + y_score = np.array( + [ + [0.7, 0.2, 0.1], + [0.4, 0.3, 0.3], + [0.1, 0.8, 0.1], + [0.2, 0.3, 0.5], + [0.4, 0.4, 0.2], + [0.1, 0.2, 0.7], + ], + dtype=jnp.float32, + ) + classes = jnp.unique(y_true) + # test over three supported average options + for average in ["macro", "micro", None]: + sk_res = sk_average_precision_score(y_true, y_score, average=average) + spu_res = emulator.run(average_precision_score, static_argnums=(3,))( + *emulator.seal(y_true, y_score), classes, average + ) + check(sk_res, spu_res) + + if __name__ == "__main__": try: # bandwidth and latency only work for docker mode @@ -124,5 +197,6 @@ def check(spu_result, sk_result): emulator.up() emul_auc(emulation.Mode.MULTIPROCESS) emul_Classification(emulation.Mode.MULTIPROCESS) + emul_average_precision_score(emulation.Mode.MULTIPROCESS) finally: emulator.down() diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index 2ff4a3f4..aeda6721 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -27,10 +27,12 @@ # add ops dir to the path sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) +from sklearn.metrics import average_precision_score as sk_average_precision_score from sklearn.metrics import roc_auc_score as sk_roc_auc_score from sml.metrics.classification.classification import ( accuracy_score, + average_precision_score, bin_counts, equal_obs, f1_score, @@ -155,6 +157,80 @@ def check(spu_result, sk_result): sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2]) check(spu_result, sk_result) + def test_average_precision_score(self): + sim = spsim.Simulator.simple( + 2, spu_pb2.ProtocolKind.SEMI2K, spu_pb2.FieldType.FM64 + ) + + def proc(y_true, y_score, **kwargs): + sk_res = sk_average_precision_score(y_true, y_score, **kwargs) + spu_res = spsim.sim_jax(sim, average_precision_score)( + y_true, y_score, **kwargs + ) + return sk_res, spu_res + + def check(res1, res2): + return np.testing.assert_allclose(res1, res2, rtol=1e-3, atol=1e-3) + + # --- Test binary classification --- + # 0-1 labels, no tied value + y_true = jnp.array([0, 0, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0.1, 0.4, 0.35, 0.8], dtype=jnp.float32) + check(*proc(y_true, y_score)) + # 0-1 labels, with tied value, even length + y_true = jnp.array([0, 0, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0.4, 0.4, 0.4, 0.25], dtype=jnp.float32) + check(*proc(y_true, y_score)) + # 0-1 labels, with tied value, odd length + y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0.4, 0.4, 0.4, 0.25, 0.25], dtype=jnp.float32) + check(*proc(y_true, y_score)) + # customized labels + y_true = jnp.array([2, 2, 3, 3], dtype=jnp.int32) + y_score = jnp.array([0.1, 0.2, 0.3, 0.4], dtype=jnp.float32) + check(*proc(y_true, y_score, pos_label=3)) + # larger random dataset + y_true = jnp.array(np.random.randint(0, 2, 100), dtype=jnp.int32) + y_score = jnp.array(np.hstack((0, 1, np.random.random(98))), dtype=jnp.float32) + check(*proc(y_true, y_score)) + # single label edge case + y_true = jnp.array([0, 0, 0, 0], dtype=jnp.int32) + y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32) + check(*proc(y_true, y_score)) + y_true = jnp.array([1, 1, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32) + check(*proc(y_true, y_score)) + # zero score edge case + y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32) + y_score = jnp.array([0, 0, 0, 0.25, 0.25], dtype=jnp.float32) + check(*proc(y_true, y_score)) + # score > 1 edge case + y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32) + y_score = jnp.array([1.5, 1.5, 1.5, 0.25, 0.25], dtype=jnp.float32) + check(*proc(y_true, y_score)) + + # --- Test multiclass classification --- + y_true = np.array([0, 0, 1, 1, 2, 2], dtype=jnp.int32) + y_score = np.array( + [ + [0.7, 0.2, 0.1], + [0.4, 0.3, 0.3], + [0.1, 0.8, 0.1], + [0.2, 0.3, 0.5], + [0.4, 0.4, 0.2], + [0.1, 0.2, 0.7], + ], + dtype=jnp.float32, + ) + classes = jnp.unique(y_true) + # test over three supported average options + for average in ["macro", "micro", None]: + sk_res = sk_average_precision_score(y_true, y_score, average=average) + spu_res = spsim.sim_jax(sim, average_precision_score, static_argnums=(3,))( + y_true, y_score, classes, average + ) + check(sk_res, spu_res) + if __name__ == "__main__": unittest.main() diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index 4232aebb..ca4e18d7 100644 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -340,7 +340,7 @@ def fit(self, X, *, zero_variance=False, contain_nan=False): contain_nan : bool, default=False Set to True to handle the nan value. - This option desiced whether to use nanmin and nanmax to compute the minimum + This option decides whether to use nanmin and nanmax to compute the minimum and maximum. """ self._reset() @@ -364,7 +364,7 @@ def partial_fit(self, X, *, zero_variance=False, contain_nan=False): contain_nan : bool, default=False Set to True to handle the nan value. - This option desiced whether to use nanmin and nanmax to compute the minimum + This option decides whether to use nanmin and nanmax to compute the minimum and maximum. """ feature_range = self.feature_range @@ -434,7 +434,7 @@ def fit_transform(self, X, *, zero_variance=False, contain_nan=False): contain_nan : bool, default=False Set to True to handle the nan value. - This option desiced whether to use nanmin and nanmax to compute the minimum + This option decides whether to use nanmin and nanmax to compute the minimum and maximum. Returns @@ -495,7 +495,7 @@ def fit(self, X, zero_maxabs=False, contain_nan=False): contain_nan : bool, default=False Set to True to handle the nan value. - This option desiced whether to use nanmin and nanmax to compute the minimum + This option decides whether to use nanmin and nanmax to compute the minimum and maximum. """ self._reset() @@ -519,7 +519,7 @@ def partial_fit(self, X, *, zero_maxabs=False, contain_nan=False): contain_nan : bool, default=False Set to True to handle the nan value. - This option desiced whether to use nanmin and nanmax to compute the minimum + This option decides whether to use nanmin and nanmax to compute the minimum and maximum. """ first_pass = not hasattr(self, "n_samples_seen_") @@ -569,7 +569,7 @@ def fit_transform(self, X, *, zero_maxabs=False, contain_nan=False): contain_nan : bool, default=False Set to True to handle the nan value. - This option desiced whether to use nanmin and nanmax to compute the minimum + This option decides whether to use nanmin and nanmax to compute the minimum and maximum. Returns @@ -687,9 +687,9 @@ def loop_body(i, st): class KBinsDiscretizer: """Bin continuous data into intervals. - Attribute encode is not implemented, since there is currently no onehotencoder + Attribute encode is not implemented, since there is currently no OneHotEncoder in sml. - Attribute subsample is not implemented, since random choise in SPU runtime does + Attribute subsample is not implemented, since random choice in SPU runtime does not work as expected. Parameters @@ -744,7 +744,7 @@ def fit( Note that there is currently no support for handling constant values in a feature , since it introduces much redundant boolean computation because dynamic shape is not supported. - In sklarn, feature with constant value will be replaced with 0 after transformation. + In sklearn, feature with constant value will be replaced with 0 after transformation. (see https://github.com/scikit-learn/scikit-learn/blob/d139ff234b0f8ec30287e26e0bc801bdafdfbb1a/sklearn/preprocessing/tests/test_discretization.py#L192) Parameters diff --git a/sml/preprocessing/tests/preprocessing_test.py b/sml/preprocessing/tests/preprocessing_test.py index 5be80047..19fffec4 100644 --- a/sml/preprocessing/tests/preprocessing_test.py +++ b/sml/preprocessing/tests/preprocessing_test.py @@ -349,7 +349,7 @@ def test_kbinsdiscretizer_uniform(self): ) def kbinsdiscretize(X): - transformer = KBinsDiscretizer(n_bins=3, strategy='uniform') + transformer = KBinsDiscretizer(n_bins=5, strategy='uniform') transformed = transformer.fit_transform(X) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed @@ -359,7 +359,7 @@ def kbinsdiscretize(X): ) transformer = preprocessing.KBinsDiscretizer( - n_bins=3, encode='ordinal', strategy='uniform', subsample=None + n_bins=5, encode='ordinal', strategy='uniform', subsample=None ) sk_transformed = transformer.fit_transform(X) sk_inv_transformed = transformer.inverse_transform(sk_transformed) @@ -382,14 +382,15 @@ def test_kbinsdiscretizer_uniform_diverse_n_bins(self): def kbinsdiscretize(X, n_bins): transformer = KBinsDiscretizer( - n_bins=3, diverse_n_bins=n_bins, strategy='uniform' + n_bins=max_bins, diverse_n_bins=n_bins, strategy='uniform' ) transformed = transformer.fit_transform(X) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed X = jnp.array([[0, 0, 0, 0], [0, 1, 1, 0], [1, 2, 2, 1], [1, 2, 2, 2]]) - n_bins = jnp.array([2, 3, 3, 3]) + n_bins = jnp.array([3, 5, 5, 5]) + max_bins = int(jnp.max(n_bins)) transformer = preprocessing.KBinsDiscretizer( n_bins=n_bins, encode='ordinal', strategy='uniform', subsample=None @@ -415,16 +416,20 @@ def test_kbinsdiscretizer_uniform_diverse_n_bins_no_vectorize(self): 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 ) + # When you set vectorize to False, diverse_n_bins should be public. def kbinsdiscretize(X): transformer = KBinsDiscretizer( - n_bins=3, diverse_n_bins=np.array([2, 3, 3, 3]), strategy='uniform' + n_bins=max_bins, + diverse_n_bins=np.array([3, 5, 5, 5]), + strategy='uniform', ) transformed = transformer.fit_transform(X, vectorize=False) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed X = jnp.array([[0, 0, 0, 0], [0, 1, 1, 0], [1, 2, 2, 1], [1, 2, 2, 2]]) - n_bins = jnp.array([2, 3, 3, 3]) + n_bins = jnp.array([3, 5, 5, 5]) + max_bins = int(jnp.max(n_bins)) transformer = preprocessing.KBinsDiscretizer( n_bins=n_bins, encode='ordinal', strategy='uniform', subsample=None @@ -449,7 +454,7 @@ def test_kbinsdiscretizer_quantile(self): ) def kbinsdiscretize(X): - transformer = KBinsDiscretizer(n_bins=3, strategy='quantile') + transformer = KBinsDiscretizer(n_bins=5, strategy='quantile') transformed = transformer.fit_transform(X) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed @@ -459,7 +464,7 @@ def kbinsdiscretize(X): ) transformer = preprocessing.KBinsDiscretizer( - 3, encode='ordinal', strategy='quantile', subsample=None + 5, encode='ordinal', strategy='quantile', subsample=None ) sk_transformed = transformer.fit_transform(X) sk_inv_transformed = transformer.inverse_transform(sk_transformed) @@ -483,14 +488,15 @@ def test_kbinsdiscretizer_quantile_diverse_n_bins(self): def kbinsdiscretize(X, n_bins): transformer = KBinsDiscretizer( - n_bins=3, diverse_n_bins=n_bins, strategy='quantile' + n_bins=max_bins, diverse_n_bins=n_bins, strategy='quantile' ) transformed = transformer.fit_transform(X, remove_bin=True) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed X = jnp.array([[0, 0, 0, 0], [0, 1, 1, 0], [1, 2, 2, 1], [1, 2, 2, 2]]) - n_bins = jnp.array([2, 3, 3, 3]) + n_bins = jnp.array([3, 5, 5, 5]) + max_bins = int(jnp.max(n_bins)) transformer = preprocessing.KBinsDiscretizer( n_bins=n_bins, encode='ordinal', strategy='quantile', subsample=None @@ -519,14 +525,15 @@ def test_kbinsdiscretizer_quantile_diverse_n_bins2(self): def kbinsdiscretize(X, n_bins): transformer = KBinsDiscretizer( - n_bins=4, diverse_n_bins=n_bins, strategy='quantile' + n_bins=max_bins, diverse_n_bins=n_bins, strategy='quantile' ) transformed = transformer.fit_transform(X, remove_bin=True) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed X = jnp.array([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]) - n_bins = jnp.array([2, 4, 4, 4]) + n_bins = jnp.array([4, 5, 5, 5]) + max_bins = int(jnp.max(n_bins)) transformer = preprocessing.KBinsDiscretizer( n_bins=n_bins, encode='ordinal', strategy='quantile', subsample=None @@ -555,14 +562,17 @@ def test_kbinsdiscretizer_quantile_diverse_n_bins_no_vectorize(self): def kbinsdiscretize(X): transformer = KBinsDiscretizer( - n_bins=3, diverse_n_bins=np.array([2, 3, 3, 3]), strategy='quantile' + n_bins=max_bins, + diverse_n_bins=np.array([3, 5, 5, 5]), + strategy='quantile', ) transformed = transformer.fit_transform(X, vectorize=False, remove_bin=True) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed X = jnp.array([[0, 0, 0, 0], [0, 1, 1, 0], [1, 2, 2, 1], [1, 2, 2, 2]]) - n_bins = jnp.array([2, 3, 3, 3]) + n_bins = jnp.array([3, 5, 5, 5]) + max_bins = int(jnp.max(n_bins)) transformer = preprocessing.KBinsDiscretizer( n_bins=n_bins, encode='ordinal', strategy='quantile', subsample=None @@ -588,7 +598,7 @@ def test_kbinsdiscretizer_quantile_eliminate(self): ) def kbinsdiscretize(X): - transformer = KBinsDiscretizer(n_bins=3, strategy='quantile') + transformer = KBinsDiscretizer(n_bins=2, strategy='quantile') transformed = transformer.fit_transform(X, remove_bin=True) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed @@ -603,7 +613,7 @@ def kbinsdiscretize(X): ) transformer = preprocessing.KBinsDiscretizer( - 3, encode='ordinal', strategy='quantile', subsample=None + 2, encode='ordinal', strategy='quantile', subsample=None ) sk_transformed = transformer.fit_transform(X) sk_inv_transformed = transformer.inverse_transform(sk_transformed) @@ -625,18 +635,25 @@ def test_kbinsdiscretizer_quantile_sample_weight(self): ) def kbinsdiscretize(X, sample_weight): - transformer = KBinsDiscretizer(n_bins=3, strategy='quantile') + transformer = KBinsDiscretizer(n_bins=2, strategy='quantile') transformed = transformer.fit_transform( X, sample_weight=sample_weight, remove_bin=True ) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed - X = jnp.array([[0, 0, 0, 0], [0, 1, 1, 1], [1, 2, 2, 2], [1, 2, 2, 2]]) + X = jnp.array( + [ + [0.2, 0.2, 0.3, 0.4], + [0.5, 1.1, 1.2, 1], + [0.7, 2.12, 2.3, 2.1], + [1, 2.51, 2.9, 2.6], + ] + ) sample_weight = jnp.array([1, 1, 3, 1]) transformer = preprocessing.KBinsDiscretizer( - 3, encode='ordinal', strategy='quantile', subsample=None + 2, encode='ordinal', strategy='quantile', subsample=None ) transformer.fit(X, sample_weight=sample_weight) sk_transformed = transformer.transform(X) @@ -670,9 +687,18 @@ def kbinsdiscretize(X, n_bins, sample_weight): inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed - X = jnp.array([[0, 0, 0, 0], [0, 1, 1, 1], [1, 2, 2, 2], [1, 2, 2, 2]]) - n_bins = jnp.array([2, 3, 3, 3]) - sample_weight = jnp.array([1, 1, 3, 1]) + X = jnp.array( + [ + [0.2, 0.2, 0.3, 0.4], + [0.5, 1.1, 1.2, 1], + [0.7, 2.12, 2.3, 2.1], + [1, 2.51, 2.9, 2.6], + [1.3, 2.8, 3.1, 2.12], + [1.9, 2.91, 3.4, 2.99], + ] + ) + n_bins = jnp.array([2, 2, 3, 3]) + sample_weight = jnp.array([1, 1, 3, 1, 1, 1]) transformer = preprocessing.KBinsDiscretizer( n_bins=n_bins, encode='ordinal', strategy='quantile', subsample=None @@ -701,7 +727,7 @@ def test_kbinsdiscretizer_quantile_sample_weight_diverse_n_bins2(self): def kbinsdiscretize(X, n_bins, sample_weight): transformer = KBinsDiscretizer( - n_bins=4, diverse_n_bins=n_bins, strategy='quantile' + n_bins=5, diverse_n_bins=n_bins, strategy='quantile' ) transformed = transformer.fit_transform( X, sample_weight=sample_weight, remove_bin=True @@ -709,9 +735,19 @@ def kbinsdiscretize(X, n_bins, sample_weight): inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed - X = jnp.array([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]) - n_bins = jnp.array([2, 4, 4, 4]) - sample_weight = jnp.array([1, 1, 3, 1]) + X = jnp.array( + [ + [1.0, 1.2, 1, 1], + [2, 2, 2.6, 2.1], + [3.1, 3.11, 3.48, 3.09], + [4, 4.1, 4.4, 4.6], + [5, 5.2, 5.88, 5.11], + [6, 6.4, 6.2, 6.4], + [7, 7, 7.2, 7], + ] + ) + n_bins = jnp.array([2, 3, 4, 5]) + sample_weight = jnp.array([1, 1, 3, 1, 2, 1, 1]) transformer = preprocessing.KBinsDiscretizer( n_bins=n_bins, encode='ordinal', strategy='quantile', subsample=None @@ -740,7 +776,7 @@ def test_kbinsdiscretizer_quantile_sample_weight_diverse_n_bins_no_vectorize(sel def kbinsdiscretize(X, sample_weight): transformer = KBinsDiscretizer( - n_bins=3, diverse_n_bins=np.array([2, 3, 3, 3]), strategy='quantile' + n_bins=5, diverse_n_bins=n_bins, strategy='quantile' ) transformed = transformer.fit_transform( X, vectorize=False, sample_weight=sample_weight, remove_bin=True @@ -748,9 +784,19 @@ def kbinsdiscretize(X, sample_weight): inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed - X = jnp.array([[0, 0, 0, 0], [0, 1, 1, 0], [1, 2, 2, 1], [1, 2, 2, 2]]) - n_bins = jnp.array([2, 3, 3, 3]) - sample_weight = jnp.array([1, 1, 3, 1]) + X = jnp.array( + [ + [1.0, 1.2, 1, 1], + [2, 2, 2.6, 2.1], + [3.1, 3.11, 3.48, 3.09], + [4, 4.1, 4.4, 4.6], + [5, 5.2, 5.88, 5.11], + [6, 6.4, 6.2, 6.4], + [7, 7, 7.2, 7], + ] + ) + n_bins = np.array([2, 3, 4, 5]) + sample_weight = jnp.array([1, 1, 3, 1, 2, 1, 1]) transformer = preprocessing.KBinsDiscretizer( n_bins=n_bins, encode='ordinal', strategy='quantile', subsample=None @@ -838,63 +884,6 @@ def kbinsdiscretize(X): sk_inv_transformed, spu_inv_transformed, rtol=0, atol=1e-4 ) - # def test_kbinsdiscretizer_kmeans_diverse_n_bins(self): - # sim = spsim.Simulator.simple( - # 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - # ) - - # def kbinsdiscretize(X, n_bins): - # transformer = KBinsDiscretizer(n_bins=3, diverse_n_bins=n_bins, strategy='kmeans') - # transformer.fit(X, remove_bin=True) - # transformed = transformer.transform(X) - # inv_transformed = transformer.inverse_transform(transformed) - # return transformed, inv_transformed - - # X = jnp.array([[0, 0, 0, 0], [0, 1, 1, 0], [1, 2, 2, 1], [1, 2, 2, 2]]) - # n_bins = jnp.array([2, 3, 3, 3]) - - # transformer = preprocessing.KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='kmeans', subsample=None) - # sk_transformed = transformer.fit_transform(X) - # sk_inv_transformed = transformer.inverse_transform(sk_transformed) - # # print("sklearn:\n", sk_transformed) - # # print("sklearn:\n", sk_inv_transformed) - - # spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, kbinsdiscretize)(X, n_bins) - # # print("result\n", spu_transformed) - # # print("result\n", spu_inv_transformed) - - # np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=1e-4) - # np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=1e-4) - - # def test_kbinsdiscretizer_kmeans_diverse_n_bins2(self): - # sim = spsim.Simulator.simple( - # 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - # ) - - # transformer_spu = KBinsDiscretizer(n_bins=4, diverse_n_bins=np.array([2, 4, 4, 4]), strategy='kmeans') - # def kbinsdiscretize(X, n_bins): - # # transformer_spu = KBinsDiscretizer(n_bins=4, diverse_n_bins=np.array([2, 4, 4, 4]), strategy='kmeans') - # transformer_spu.fit(X) - # transformed = transformer_spu.transform(X) - # inv_transformed = transformer_spu.inverse_transform(transformed) - # return transformed, inv_transformed - - # X = jnp.array([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]) - # n_bins = jnp.array([2, 4, 4, 4]) - - # transformer = preprocessing.KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='kmeans', subsample=None) - # sk_transformed = transformer.fit_transform(X) - # sk_inv_transformed = transformer.inverse_transform(sk_transformed) - # # print("sklearn:\n", sk_transformed) - # # print("sklearn:\n", sk_inv_transformed) - - # spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, kbinsdiscretize)(X, n_bins) - # # print("result\n", spu_transformed) - # # print("result\n", spu_inv_transformed) - - # np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=1e-4) - # np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=1e-4) - if __name__ == "__main__": unittest.main() diff --git a/sml/tree/tree.py b/sml/tree/tree.py index f946909a..b6b598a5 100644 --- a/sml/tree/tree.py +++ b/sml/tree/tree.py @@ -53,8 +53,8 @@ def __init__(self, criterion, splitter, max_depth, n_labels): self.max_depth = max_depth self.n_labels = n_labels - def fit(self, X, y): - self.T, self.F = odtt(X, y, self.max_depth, self.n_labels) + def fit(self, X, y, sample_weight=None): + self.T, self.F = odtt(X, y, self.max_depth, self.n_labels, sample_weight) return self def predict(self, X): @@ -115,7 +115,7 @@ def oaa_elementwise(array, index_array): # def oblivious_learning(X, y, T, F, M, Cn, h): -def oblivious_learning(X, y, T, F, M, h, Cn, n_labels): +def oblivious_learning(X, y, T, F, M, h, Cn, n_labels, sample_weight=None): '''partition the data and count the number of data samples. params: @@ -135,27 +135,65 @@ def oblivious_learning(X, y, T, F, M, h, Cn, n_labels): Dval = oaae(X, Tval) M = 2 * M + Dval + 1 - # (n_leaves) LCidx = jnp.arange(0, n_h) isLeaf = jnp.equal(F[n_h - 1 : 2 * n_h - 1], jnp.ones(n_h)) - # (n_samples, n_leaves) LCF = jnp.equal(M[:, jnp.newaxis] - n_h + 1, LCidx) LCF = LCF * isLeaf - # (n_samples, n_leaves, n_labels, 2 * n_features) + Cd = jnp.zeros((n_d, n_h, n_labels + 1, 2 * n_f)) - Cd = Cd.at[:, :, 0, 0::2].set(jnp.tile((1 - X)[:, jnp.newaxis, :], (1, n_h, 1))) - Cd = Cd.at[:, :, 0, 1::2].set(jnp.tile((X)[:, jnp.newaxis, :], (1, n_h, 1))) - for i in range(n_labels): - Cd = Cd.at[:, :, i + 1, 0::2].set( + if sample_weight is not None: + Cd = Cd.at[:, :, 0, 0::2].set( jnp.tile( - ((1 - X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :], (1, n_h, 1) + (1 - X)[:, jnp.newaxis, :] * sample_weight[:, jnp.newaxis, jnp.newaxis], + (1, n_h, 1), ) ) - Cd = Cd.at[:, :, i + 1, 1::2].set( - jnp.tile(((X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :], (1, n_h, 1)) + Cd = Cd.at[:, :, 0, 1::2].set( + jnp.tile( + (X)[:, jnp.newaxis, :] * sample_weight[:, jnp.newaxis, jnp.newaxis], + (1, n_h, 1), + ) ) + else: + Cd = Cd.at[:, :, 0, 0::2].set(jnp.tile((1 - X)[:, jnp.newaxis, :], (1, n_h, 1))) + Cd = Cd.at[:, :, 0, 1::2].set(jnp.tile((X)[:, jnp.newaxis, :], (1, n_h, 1))) + + for i in range(n_labels): + if sample_weight is not None: + Cd = Cd.at[:, :, i + 1, 0::2].set( + jnp.tile( + ( + (1 - X)[:, jnp.newaxis, :] + * (i == y)[:, jnp.newaxis, jnp.newaxis] + * sample_weight[:, jnp.newaxis, jnp.newaxis] + ), + (1, n_h, 1), + ) + ) + Cd = Cd.at[:, :, i + 1, 1::2].set( + jnp.tile( + ( + (X)[:, jnp.newaxis, :] + * (i == y)[:, jnp.newaxis, jnp.newaxis] + * sample_weight[:, jnp.newaxis, jnp.newaxis] + ), + (1, n_h, 1), + ) + ) + else: + Cd = Cd.at[:, :, i + 1, 0::2].set( + jnp.tile( + ((1 - X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :], (1, n_h, 1) + ) + ) + Cd = Cd.at[:, :, i + 1, 1::2].set( + jnp.tile( + ((X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :], (1, n_h, 1) + ) + ) + Cd = Cd * LCF[:, :, jnp.newaxis, jnp.newaxis] - # (n_leaves, n_labels+1, 2*n_features) + new_Cn = jnp.sum(Cd, axis=0) if h != 0: @@ -221,7 +259,7 @@ def oblivious_node_split(SD, T, F, Cn, h, max_depth): return T, Cn -def oblivious_DT_training(X, y, max_depth, n_labels): +def oblivious_DT_training(X, y, max_depth, n_labels, sample_weight=None): n_samples, n_features = X.shape T = jnp.zeros((2 ** (max_depth + 1) - 1)) F = jnp.ones((2**max_depth - 1)) @@ -231,7 +269,10 @@ def oblivious_DT_training(X, y, max_depth, n_labels): h = 0 while h < max_depth: - Cn, M = ol(X, y, T, F, M, h, Cn, n_labels) + if sample_weight is not None: + Cn, M = ol(X, y, T, F, M, h, Cn, n_labels, sample_weight) + else: + Cn, M = ol(X, y, T, F, M, h, Cn, n_labels) SD, gamma, F = ohc(Cn, gamma, F, h, n_labels) diff --git a/spu/experimental/drop_cached_var_impl.py b/spu/experimental/drop_cached_var_impl.py index 0222a27c..3831ce85 100644 --- a/spu/experimental/drop_cached_var_impl.py +++ b/spu/experimental/drop_cached_var_impl.py @@ -26,9 +26,9 @@ # Public facing interface -def drop_cached_var(input, *dependences): +def drop_cached_var(input, *dependencies): # Add necessary preprocessing code - return _drop_cached_var_prim.bind(input, *dependences) + return _drop_cached_var_prim.bind(input, *dependencies) # ********************************* @@ -38,12 +38,12 @@ def drop_cached_var(input, *dependences): # For JIT compilation we need a function to evaluate the shape and dtype of the # outputs of our op for some given inputs -def _drop_cached_var_abstract(input, *dependences): +def _drop_cached_var_abstract(input, *dependencies): return core.ShapedArray(input.shape, input.dtype) # We also need a lowering rule to provide an MLIR "lowering" of out primitive. -def _drop_cached_var_lowering(ctx, input, *dependences): +def _drop_cached_var_lowering(ctx, input, *dependencies): # The inputs and outputs all have the same shape and memory layout # so let's predefine this specification dtype = mlir.ir.RankedTensorType(input.type) @@ -53,7 +53,7 @@ def _drop_cached_var_lowering(ctx, input, *dependences): # Output types result_types=[dtype], # The inputs: - operands=[input, *dependences], + operands=[input, *dependencies], has_side_effect=True, ).results @@ -70,8 +70,8 @@ def _drop_cached_var_lowering(ctx, input, *dependences): mlir.register_lowering(_drop_cached_var_prim, _drop_cached_var_lowering) -def _drop_cached_var_transpose(ct, input, *dependences): - return [ct] * (len(dependences) + 1) +def _drop_cached_var_transpose(ct, input, *dependencies): + return [ct] * (len(dependencies) + 1) # Connect the JVP and batching rules diff --git a/spu/libpsi.cc b/spu/libpsi.cc index fec0f233..0beb06eb 100644 --- a/spu/libpsi.cc +++ b/spu/libpsi.cc @@ -104,16 +104,30 @@ void BindLibs(py::module& m) { "Run UB PSI with v2 API.", NO_GIL); m.def( - "pir", + "apsi_send", [](const std::string& config_pb, const std::shared_ptr& lctx) -> py::bytes { - psi::PirConfig config; + psi::ApsiSenderConfig config; YACL_ENFORCE(config.ParseFromString(config_pb)); auto r = psi::RunPir(config, lctx); return r.SerializeAsString(); }, - py::arg("pir_config"), py::arg("link_context") = nullptr, "Run PIR."); + py::arg("pir_config"), py::arg("link_context") = nullptr, + "Run APSI sender operations."); + + m.def( + "apsi_receive", + [](const std::string& config_pb, + const std::shared_ptr& lctx) -> py::bytes { + psi::ApsiReceiverConfig config; + YACL_ENFORCE(config.ParseFromString(config_pb)); + + auto r = psi::RunPir(config, lctx); + return r.SerializeAsString(); + }, + py::arg("pir_config"), py::arg("link_context") = nullptr, + "Run APSI receiver operations."); } PYBIND11_MODULE(libpsi, m) { diff --git a/spu/psi.py b/spu/psi.py index 2796ab52..581bd3d3 100644 --- a/spu/psi.py +++ b/spu/psi.py @@ -19,7 +19,11 @@ from . import libpsi # type: ignore from .libpsi.libs import ProgressData from .libspu.link import Context # type: ignore -from .pir_pb2 import PirConfig, PirProtocol, PirResultReport # type: ignore +from .pir_pb2 import ( # type: ignore + ApsiReceiverConfig, + ApsiSenderConfig, + PirResultReport, +) from .psi_pb2 import ( # type: ignore BucketPsiConfig, CurveType, @@ -144,8 +148,16 @@ def ub_psi( return report -def pir(config: PirProtocol, link: Context = None) -> PirResultReport: - report_str = libpsi.libs.pir(config.SerializeToString(), link) +def apsi_send(config: ApsiSenderConfig, link: Context = None) -> PirResultReport: + report_str = libpsi.libs.apsi_send(config.SerializeToString(), link) + + report = PirResultReport() + report.ParseFromString(report_str) + return report + + +def apsi_receive(config: ApsiReceiverConfig, link: Context = None) -> PirResultReport: + report_str = libpsi.libs.apsi_receive(config.SerializeToString(), link) report = PirResultReport() report.ParseFromString(report_str) diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel index fd037279..d9e348cb 100644 --- a/spu/tests/BUILD.bazel +++ b/spu/tests/BUILD.bazel @@ -178,7 +178,6 @@ py_binary( srcs = ["jnp_debug.py"], deps = [ "//spu:api", - "//spu/intrinsic:all_intrinsics", "//spu/utils:simulation", ], ) diff --git a/spu/tests/data/BUILD.bazel b/spu/tests/data/BUILD.bazel index 1eadb581..647c6e0a 100644 --- a/spu/tests/data/BUILD.bazel +++ b/spu/tests/data/BUILD.bazel @@ -20,5 +20,12 @@ filegroup( "alice.csv", "bob.csv", "carol.csv", + "db.csv", + "ground_truth.csv", + "pir/100K-1-16.json", + "pir/db.csv", + "pir/ground_truth.csv", + "pir/query.csv", + "query.csv", ], ) diff --git a/spu/tests/data/db.csv b/spu/tests/data/db.csv new file mode 100644 index 00000000..207e86e7 --- /dev/null +++ b/spu/tests/data/db.csv @@ -0,0 +1,100 @@ +aPYaKgvvcESwAtfghnRUIAYYIZeCsGeaWbAAEUCXQzNrxGRPVOcACqMBJmdfiveq,LdQNbKmBMhlpCctB +eXeLfYoSlEntpRaKuddYhtaImMOdEhNTIolxElSrlPYMhgZoWxccpUTOjFciaRcD,wcsaUIrVHmxDxnXr +gOUgSoWKKYaLpmWiuTyuVNostvTJUHxBZjYWJZukOzqICDlmKdavyERZimgOkaHn,FgdZUDbsfINsODTe +BrImQzkUpelejHVFeatNDJTqdTuxCJmHKunkODDPsUtsgEfCkXyLNEAZXniJZgvf,PkxxKRwEEEwrcLbR +RhoATcXotsJDbAICIkIbAYdDUXqreTDbZVPoSyziyQJvBVVaZhYqWmHVQsqNbdIP,OTFagQCudOpRtKqT +bibVQsoLIFjXrSFIjPFOTWclPDhakarsWICsYNvbMtlNwBIKqnaWmvCPcFPesBVy,YysQfuPwTiXayxlt +CQZcsLKmugDFfsYAOihJENLpBwCYwoLuATHrMiIChXkycvhejgjjUCoxKDpTDMuZ,OIGhXryCxSXCZKBn +aizidqjurwmpHHqioJclEQvccvnWQKmZTMBNSGgmfxxqHilZwFPdjHlZElQaayJN,SjCyivgzSTLpXmnz +eNGJWxNcevvcthSXdXLseTORLpDVdVCAdnbQtcovUVAislbievgctCGMidsncXyL,rFJEAlFBMUMZNPrc +OQArDeXFlhxHBPERmlYuwZHQAgCQtRorzCkDenmMdnbFdGKJggqmhApWMrbtqDZw,saeWaHguQALRalnZ +AZKVktpAYSLSZTpjylrifvdItAuhXvzvDhtWnbjMlshYFTooHXtwNfCsmiuWgXLa,sntRGrWUapWvBwAM +xNTfGmPFpdczqsPvlLFMSzXLdqGkyupfwXepkfJHECoAIDZaRwqoIullNzMfWHLK,jVerCzWaKQoYYBeP +hKFRJSHErdbPbPFRZpEkigOemyqRgRDzWHMhiVDlHDtTAjKHZcQCogTvFlbwVVbA,ftLFWLGNrSkavyEX +tGYkoEeaQinZmEsPtFyhyJxXGiJreOCTHTxPfSWFFQwqChLKADXjeWiTsVnLyFTb,vKUHgdGvKkJKSAcJ +KSyQGEXCvzziOdoaIwQYynSjQlyPOkJYroZAxblcGRHOqhjWuByOVIUYfuacruUQ,NnsQMUvXYZOhswVL +WsQkdjbaExqvNkQlbMWqwxtChiWRBsmylFanbONarkbmmWAnKPvCdPLlhQtvFaXP,phEafGnFQIcoYCKG +rESDWnSxXeTvwvTDxSsSmwPswujcWVlrvlgVAuGUYjvyXFjCSGenqEZhfFHgZXIB,sHxPvGwZWDbJgCYn +UKgGYBiPUdYdfuQZrpKEtprGxofPaLVqhdiuSUzKrSDzXuCgYuOQJlyhsNHyLrCx,sdNdnkjuRNLyPhaB +eEXLhpsjwQQRFVXxPgqQnFMCQDAMOMxwoAeeubrJXwxKaiWgnilOwjzoEZuRUJLV,WXVgFQQFBEwRomjA +QVGpdCqJIpZStCSUOEKyEoOJJurZSWmQZDCnIrANHGJYhpbfxAhsPvrVZnVrhQKn,bUIqpQPVvtiPhItz +UOmpkfibQXxQlYJQzZgdfoIHckIWHWiVBcoLaLSQnlpnIBQZcrnCEXCfHTDwsFDX,OdckqFMcdfnVnBoB +TDzDtfzzMugZZxhNGhmwYsMrOCFvCUWmmOUOLNGAYMRMnZVGuOMSXZZgaTufrqXK,DAToiYOldpgNGOqn +zmRKIEQFtIjCYUXaFgyAvZZEDIukHAwYlzUwxbttWndcAGFEoRzGyAUuLsKnbfZi,rfwdiePuXVtvKgat +vmtpIcBkPJyFRqKWIYHWcecKdgCoUShJwkhYvjHZPdhwmcdBGwQDDVynyOwSZcYj,PDNsnMKRZubVpMRT +EjQMHoawzxMREpZaJFKJBNsnKdzQTWeGmAMkhsuSfEzoDpQfdUWUeTFKvKClRNPz,rauheCdowFiOAMFk +mjjTWkjovIsCsMuZfdtXIKVZEcwuspLRUtCVPKpMdkkaGQtUUmFrXaZHaDuPKvsa,wCscOftxAHuBnsSW +MjvTMaePwIVFpEbspToomYGFAOmpGuKlmgJvIOhtVoHNgWaHReuMELUapHWAaZjL,MeizqMvAktGZLkCH +EwJZCBgPDuKRnTTZwuJRKfkznXpHGdbfMOZTnVjixKGciMLkdLSzWBXkBhMGzwSS,RhpjUzFsJtSSXund +HSaBXSGcBxYSUIXYnlFnYrdTclIehDdMhKqIRJuAYebfViJttknfMmCqbyYOJAXE,NIZPwgQebsKBehaN +fLOyQDLsIUaWZUjwzsrxlGHlGTYNWVZyTEWJZenWqZiMqHEpLWAvGojmOQvteOqS,XGTuKgtLshqQUtfr +wvKVPbYksmYXTsRqvJETrjXJethrvgmBLIwMQhJBCMTfLGOFKHwxrrBcGQqdMjZe,ICnEQAovJhrWaIiY +WpbdxcjMqKMkLdSlUBowCTWDGVtRJLiEDQytMenWEIkWFWLKByiEhrvIpCncUQDS,xEGgXHFKHYDlGdYF +AVSBpQelmdheyUZPdmRrhrEqHmKowFAIDNjxzphVCoLgSypBfHNtVuDgoIVqCoLF,ZLFAfNImdiEwcupl +XtndzLknyWTeElGFXjZfbrHGqYzqHcTzEtXquKkpckuwhkQPcCkmXIhfCnLYCrVG,UHbNwRmAJFMalnbt +lvLHBgacmSdZJqpzrezjTYTfWIBFUDIaMcGwErtmnAwgjDXwmIHxMDqYTrJvjUyq,NoWnhJQWDJVCtExB +SKenZwjdFvsAiARRmpBzTAXGWtByjJcIniiAhovlsAHLZXQJCmDyRJxKevZBttDa,mHAfCtCExIyoRuWG +vKgaRAkJVOSESTgmEVXpQIVXDADJiHmFFaAwxtjwUyFVrQouyJcZeDwhMUZPROkA,CiuvQKMCsVGLkoBM +cdxMDaSfvmlcpSfFqvfzgNIyUcmkDEyVswXcCKCJfYyAqrSCGWBGIEQlBxKTWSCj,LTIsYYovMEnpawzQ +bnfGLKRBqwOCQBNWSRAbEWVLqyUAlzrYNiJRWUZnmGnXtjMFBQhLHBhVJBygIrGU,vLYmsxCGzRXOMiqB +bxidiEYIOjbFDbHRqnaYXZuQcZJbNxynsmjNPCpzujEKzATaBeTrUchoylhvqLjx,AXlyMrMWiWGMqoIs +eSuMvTbZgMRDrEIxJwgFYdpNWkQzEzrsyyybeaJlUPEhEZZBWpPwQFqImIGnLFar,BTfGxXWpxvvPzcDS +uONlUmHMFQiPkCfdPrqqUDleaUBHKnxuQbFJAqfMqzoSpqPzawdOIvtVQMSWfCRv,uFDJTeJjZNFjIxKT +IEWyZnggNMulCyYklMdZaMYiIsqQNtbzbcpMHBUfPeOKaoSCMeezBSqcQwVXNJho,zpDlpaygXXZyslHj +ItlLEykbpzIeTErDcbaxxfzntBAYcPHVcLFneOzhNhxYYgwsbKZEmHHuHTnPnhSW,nKPAjZgLwcVTMNGi +uTQQIgVItVwPUSygrxoUwrLuFbAbahqbnixUuRIKRnJConAViRHYsRerKyEieFYI,fhPHVglAfXtHjNae +etMRJJVQSxIMgvzdSoCPsVpGJcKtjpXtMqtzXgaGQDryplTJifNvOFYGWChHLOUo,WJrEYWYGMOqveCgd +BYHllipDRYZVvMQYYhIRzLHabgftPTSnFbUCRmZejFUoeLLQoZtrZPJrTjVqWfNO,EztmadvCnpbgQtBl +ZEkjAbQOpxQrbtEDVlDhKgChCNsxTxSQtUXARrEeVJQrzPuVPkYHuuoMXjVCyeCk,kcgMJGMwDDDiFDSP +WdhxSOyjKNLItJZiXZtkKBdIcjGLuLbZHPmSfJCzlvqBxnrobjDTPxsFRXhEInhh,OPuAWqKpmpZHbvIY +wSGlgcmRasbLYVClIhhCppYezjZrWIhhaiASQcDrCDxdsGJIJjNmTWwbsFuKlbAt,VToAPASiTJHUGIEk +CFzAfFrDFyhndtOJalJiSiufNlCjWcwxwQjnjRgbqFlaAlIzXgrwVJmwISEAKHxx,jpylGyqZUinEIiVo +BwluSThaftzdTOrrzjBfqmHdbXqLUDEPPqmduYZESLtaSQAWLOKeuECRKPDEumJA,vgocckgnbQZCgpSq +OUguorluoGFpgVuXujAFBOIkBsIAaNCXgcywsWjutvEcrJrrDBRAHgcKwnfNLpXr,vwsRUvSJpupXvGta +TXQWeuibBvqDmmTaLAPZNsEHosjhcLBsixvomJaiAPmLmBDemETNOZMwrwREVRir,EIgmXjWGNYpcfdOs +BbQiYWirQnKUzBachIeJZgJWvmeZkpBUphzYEGrGcGpxUwWvSSnSYpQzBRsOFuzC,oUBiVRJNarkuUdRV +dugMhlZPXHEnWTZjaOFaxmmZgIHdVBmzUsIfUUPZkdVzDvCLRZyTBPrHhDVrYAVk,OdxERipabMlyoyEa +aKPvrmNwyclBMcPMEAgGItoshSJVrSonWVrcHMWiBXxqTpdHjuGKrRLTaHhQyunk,TmBLACtwsObAkgoz +LSndhzUuoIbprGCzRDfZryKqdcqwLpLWYfHoOHgBJDSkZRoYMQxmIoVCUdBSxHsZ,qvhvMIqFvXELhWqK +lInwjXRJTuZgrvcbDEvFHgPFGpqlnSuIJWtNRJYizWEfZJbZtfLexmEQMnGLxNlW,MURWsZhFTeLoqeAd +DOlyDBVEZFruhsBwHZgrnWTXckcVJcVzrwniSnJFEYUNiFgkIyukstlbdrluVhag,vUuhogLghMFjNjyK +zgDxGOAAYMGtAOoMviwtSQDLEpHluVuqFsqisVvoKCLfnMdPVTKgKCrchKrAmlmz,YaojDSYxfntenXUp +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv,eweOmajnCQLZOgBp +nJJyhHkDkQRttGsYjkMBBuGeQuPPDHQEQQdnGQmMbOdRFifRDpZVUqdfqeskxngR,nyikwrUYIuWbTawJ +ZcwBhgoFFiWyDeZbeMpliSvhbfsAkccEQQhYreLTGdfVHuNLpmCsduhkIlRKMNkx,nWeXVNOUmsoakAMC +BCJYdYFkkLRWUxhnnwpJGbXJchPvVCtbcaMdkArTcNLdRmopwncgdgOLGhJkZOnC,JWAVSCWqKUcmfoqk +NkxNbLyBjixoCzClTdwshhuZcFRjJJdDdWgCfiQIttZWQWqouBkYyMGpampLdUAr,nkuaUoDVcCUBUoOv +WgwvZoDsjspDAWHEflQMWzlbqnssWiBElmABhLmhgDPqFbNmAHSnzQrbAqSVAmWS,IQIlCVVcjVcTdshJ +yXelVXMUEuAtfNgzPrhjvYOpiAVEMZuqPfsQEUQoshjSIekxxzkFxdftfqFzfzpa,AyTbbGVAXKCYUUln +sAtWqpSPxPSkDtmIJKfNvlKjgStnYMOmrLsQnzmIFAusETPPzLDTjKcBASKWNRAJ,bDQFllamogAjBEPU +RAOMDMMZkezCBWxQDWLjvHLkbpvFyrbUbDDEekWciXejYwKifSVumcsocUmMkmpa,MnwKAECsMVGVLIZM +LMizeIoxMxHCKwikjqOSSPbuiqXWDAmbTLMBOXpyorUmpunjWFTNLVqvHNcCNHrN,zDzyAkfeTYZzaxhG +IzLqciYsaKtWrsrjOQldeIvEqavoIZEYnupJZLizJVeOhoLqtLQFaoNRdvZMWSQH,klLegifptLAxnhha +VkyQsnlXcIGjGhcSJUcZKeQyiDUbIcgIWSbHaEsbfEydSTHqRlxImGdYGEurZczg,jyzyCGlvuBdKwyIX +dkSRcFpHPXWjNIHrCpWlOPaIkVqjtyPRhlJeYMksjieDxYhiUcGhbuvamVlrMDDx,FATEzwNXGerINvHD +KravfnzVbNOhLhstPcaLLVWpbqYzXckQGbuALlEiXqqhFfUyThFZFhzSLhjldPMB,sowcmsLQTFiKpNXy +ITZWxxjqnizlxuRWMlPQLUnBopyDtOxNfcaoFDbRKetIpVxKLSRoJOauSCcDwWUP,NmqaomxmRQGqiKCV +guhzFIXdUEACMMHObJptkrZqJgbclDcdRxCPYSvuAdITaKgHfJaKLNHFzdRmpHni,DiONbLbxHfUHhoTU +AnZjaIBmnsEBplkEmHBstdggPnYmhblyQQttVqYzxxNtOXwlNQetkvCOySSXRUpw,FCdUtLDyvAqszerb +wkCcCXGKcZJEDwTkzOoNRkMbxHdNciQlVruGSKcJrHpokspcZIVfupcTxapISupH,pzEvnKzLQbzNDSQN +ppieExXmNHqBXVgLFhjlHHHhHSAddipMCmPXhXDfHZVTtNhqcMMVauyjKOFGBHPe,tXpiHGkKGTzzMluO +lnVknQNrrYyqFbEKYPxsQWNPKLpsEVmUGtbbWMWDThMuScSByeZRwuusLYzKPbHE,HzxgCCtiIFYvgwWO +oYSnlwjpsWaNzYunBnhNLwiICrmAEFiZRczbdHYpQgwSrrMQCixgtjfCGOptTkmd,IYmbhaQueIKQvcBc +saWvQlIiiYqAPmcEDGsVXNAIJNNGTyZKhrMMKYHXJQnniGVuIClgwvAEXeIPGeFN,epReKmWNANFpINhn +PZIEJMwirPArOGCfJJAfdwGydRDBGGQojUzWFJtVoJZTFAFYwaDOuLFruvRjolHq,yCMGUjViZoTPMTtg +bMCOdGAYjXaDSPyZyGegyuRnnwYSySrRzbLbvtgBjfFXfCMPIVIGFTagRyBpiKLa,zRTXXYnUXmkIMDoP +JQWdUrNElPWswpQvnVqCmboMEjMhebKISRcmznakzemGxBjUughzOVbctPzmVTLW,CdcnGSQhRbdsoQrg +OdAvNbDdCQpTrAbJWrUrVprpgVIXwvvSStooIVwzUfDIThtvdBHldyUFFkvabfyj,ueAqmPurXOjNtvWr +zjIkKNEzyUFiubvlxYWNXdjoIIEwZavalnqwSCgDgcZUldjZOkzhKXuRciwSTNJg,MWwWigLZKqLgZkLp +WTaCVgYrnoyEoShtBDUmrRHeRSYIAjvUpZnVAUTxTyaIGzvQIdwcPafAnkIbplSq,mdIEdIajBbeAyPCk +tNdAhhdqVzLdGfPoctgRkehzEOIRvjEwDpmAQrMjbWtfRQGjeUiVJNafrhVKFieX,IxHvvUzKrlMpWhpR +vhgPWqsRnvDRMFIYHppovDbKlWPzEFwbBXSihpYbwCYpkeXFIXbIYdWSLfcHpnWX,qhwFSRGRKwcPlLJs +ARxEUJokZaGDgXHGxwPiSqqvNSmoowUxRDDkozqbvcUvQuPtdNaeaKOKykMIUkmR,XqmCMQPKPtAPzBZd +RwztiezZCzbSLKzIyYqfEMjDTcLpASCiGWoaseuxBWpvSVutmtdEgdZornGkHrQf,YcRFZNgodJFPNoop +YTcHYrADMhlKAnvdGBdQBXWBqcftxkNpFceODelYVRXwFOZTHdXkVGAfJTzZcyhD,tBGtrQaLFgACGOEE +fHFCvDLRGGhYZWSnxaIqKTgvNbCPLzyvOnpHyAhrKEAsApdPgkxAptCTtgYAnmEq,vxGOPFzvJOVBEblg +zckpuLjSVdhSFnhTqPfDoHdJdjpfZBDdlzGbYgzVbKgDMJQDBGCHZSJBdtzlvHro,TeeGbXAcEbwzglGf +muAQTPuNCQTZurKTDlYzTQgvlWNyRXOlKizgsnGSrKdYWCSBlQtOvIyEWVthaYhO,ZnYBDVQYoJOoTMlS +UQswwuiprHWAbguGNZgOAdFrgEIdsDRImrqXXTmbqppVgnJrjjiOdZaNUpIQGcTR,VwugWpNMzEKHAFqo +GDRPaAUIAymOEEksSqccGOqpUYvGUyvBKjfRqKSTAyNadpaMYnMYboPOrEEfXVWf,noDbJmsjYCgqHsBu +cVjSBnCUnKfKXwETABIPvavwLXMGSLSpoVylUSCRlRCzpDvDVjfNAIrSiRWNHJZS,OszhlCboIvNdCTYH diff --git a/spu/tests/data/ground_truth.csv b/spu/tests/data/ground_truth.csv new file mode 100644 index 00000000..fa97d0fa --- /dev/null +++ b/spu/tests/data/ground_truth.csv @@ -0,0 +1 @@ +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv,eweOmajnCQLZOgBp diff --git a/spu/tests/data/pir/100K-1-16.json b/spu/tests/data/pir/100K-1-16.json new file mode 100644 index 00000000..6ebf1819 --- /dev/null +++ b/spu/tests/data/pir/100K-1-16.json @@ -0,0 +1,19 @@ +{ + "table_params": { + "hash_func_count": 1, + "table_size": 409, + "max_items_per_bin": 42 + }, + "item_params": { + "felts_per_item": 5 + }, + "query_params": { + "ps_low_degree": 0, + "query_powers": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42 ] + }, + "seal_params": { + "plain_modulus": 65537, + "poly_modulus_degree": 2048, + "coeff_modulus_bits": [ 48 ] + } +} diff --git a/spu/tests/data/pir/db.csv b/spu/tests/data/pir/db.csv new file mode 100644 index 00000000..daf120eb --- /dev/null +++ b/spu/tests/data/pir/db.csv @@ -0,0 +1,101 @@ +key,value +aPYaKgvvcESwAtfghnRUIAYYIZeCsGeaWbAAEUCXQzNrxGRPVOcACqMBJmdfiveq,LdQNbKmBMhlpCctB +eXeLfYoSlEntpRaKuddYhtaImMOdEhNTIolxElSrlPYMhgZoWxccpUTOjFciaRcD,wcsaUIrVHmxDxnXr +gOUgSoWKKYaLpmWiuTyuVNostvTJUHxBZjYWJZukOzqICDlmKdavyERZimgOkaHn,FgdZUDbsfINsODTe +BrImQzkUpelejHVFeatNDJTqdTuxCJmHKunkODDPsUtsgEfCkXyLNEAZXniJZgvf,PkxxKRwEEEwrcLbR +RhoATcXotsJDbAICIkIbAYdDUXqreTDbZVPoSyziyQJvBVVaZhYqWmHVQsqNbdIP,OTFagQCudOpRtKqT +bibVQsoLIFjXrSFIjPFOTWclPDhakarsWICsYNvbMtlNwBIKqnaWmvCPcFPesBVy,YysQfuPwTiXayxlt +CQZcsLKmugDFfsYAOihJENLpBwCYwoLuATHrMiIChXkycvhejgjjUCoxKDpTDMuZ,OIGhXryCxSXCZKBn +aizidqjurwmpHHqioJclEQvccvnWQKmZTMBNSGgmfxxqHilZwFPdjHlZElQaayJN,SjCyivgzSTLpXmnz +eNGJWxNcevvcthSXdXLseTORLpDVdVCAdnbQtcovUVAislbievgctCGMidsncXyL,rFJEAlFBMUMZNPrc +OQArDeXFlhxHBPERmlYuwZHQAgCQtRorzCkDenmMdnbFdGKJggqmhApWMrbtqDZw,saeWaHguQALRalnZ +AZKVktpAYSLSZTpjylrifvdItAuhXvzvDhtWnbjMlshYFTooHXtwNfCsmiuWgXLa,sntRGrWUapWvBwAM +xNTfGmPFpdczqsPvlLFMSzXLdqGkyupfwXepkfJHECoAIDZaRwqoIullNzMfWHLK,jVerCzWaKQoYYBeP +hKFRJSHErdbPbPFRZpEkigOemyqRgRDzWHMhiVDlHDtTAjKHZcQCogTvFlbwVVbA,ftLFWLGNrSkavyEX +tGYkoEeaQinZmEsPtFyhyJxXGiJreOCTHTxPfSWFFQwqChLKADXjeWiTsVnLyFTb,vKUHgdGvKkJKSAcJ +KSyQGEXCvzziOdoaIwQYynSjQlyPOkJYroZAxblcGRHOqhjWuByOVIUYfuacruUQ,NnsQMUvXYZOhswVL +WsQkdjbaExqvNkQlbMWqwxtChiWRBsmylFanbONarkbmmWAnKPvCdPLlhQtvFaXP,phEafGnFQIcoYCKG +rESDWnSxXeTvwvTDxSsSmwPswujcWVlrvlgVAuGUYjvyXFjCSGenqEZhfFHgZXIB,sHxPvGwZWDbJgCYn +UKgGYBiPUdYdfuQZrpKEtprGxofPaLVqhdiuSUzKrSDzXuCgYuOQJlyhsNHyLrCx,sdNdnkjuRNLyPhaB +eEXLhpsjwQQRFVXxPgqQnFMCQDAMOMxwoAeeubrJXwxKaiWgnilOwjzoEZuRUJLV,WXVgFQQFBEwRomjA +QVGpdCqJIpZStCSUOEKyEoOJJurZSWmQZDCnIrANHGJYhpbfxAhsPvrVZnVrhQKn,bUIqpQPVvtiPhItz +UOmpkfibQXxQlYJQzZgdfoIHckIWHWiVBcoLaLSQnlpnIBQZcrnCEXCfHTDwsFDX,OdckqFMcdfnVnBoB +TDzDtfzzMugZZxhNGhmwYsMrOCFvCUWmmOUOLNGAYMRMnZVGuOMSXZZgaTufrqXK,DAToiYOldpgNGOqn +zmRKIEQFtIjCYUXaFgyAvZZEDIukHAwYlzUwxbttWndcAGFEoRzGyAUuLsKnbfZi,rfwdiePuXVtvKgat +vmtpIcBkPJyFRqKWIYHWcecKdgCoUShJwkhYvjHZPdhwmcdBGwQDDVynyOwSZcYj,PDNsnMKRZubVpMRT +EjQMHoawzxMREpZaJFKJBNsnKdzQTWeGmAMkhsuSfEzoDpQfdUWUeTFKvKClRNPz,rauheCdowFiOAMFk +mjjTWkjovIsCsMuZfdtXIKVZEcwuspLRUtCVPKpMdkkaGQtUUmFrXaZHaDuPKvsa,wCscOftxAHuBnsSW +MjvTMaePwIVFpEbspToomYGFAOmpGuKlmgJvIOhtVoHNgWaHReuMELUapHWAaZjL,MeizqMvAktGZLkCH +EwJZCBgPDuKRnTTZwuJRKfkznXpHGdbfMOZTnVjixKGciMLkdLSzWBXkBhMGzwSS,RhpjUzFsJtSSXund +HSaBXSGcBxYSUIXYnlFnYrdTclIehDdMhKqIRJuAYebfViJttknfMmCqbyYOJAXE,NIZPwgQebsKBehaN +fLOyQDLsIUaWZUjwzsrxlGHlGTYNWVZyTEWJZenWqZiMqHEpLWAvGojmOQvteOqS,XGTuKgtLshqQUtfr +wvKVPbYksmYXTsRqvJETrjXJethrvgmBLIwMQhJBCMTfLGOFKHwxrrBcGQqdMjZe,ICnEQAovJhrWaIiY +WpbdxcjMqKMkLdSlUBowCTWDGVtRJLiEDQytMenWEIkWFWLKByiEhrvIpCncUQDS,xEGgXHFKHYDlGdYF +AVSBpQelmdheyUZPdmRrhrEqHmKowFAIDNjxzphVCoLgSypBfHNtVuDgoIVqCoLF,ZLFAfNImdiEwcupl +XtndzLknyWTeElGFXjZfbrHGqYzqHcTzEtXquKkpckuwhkQPcCkmXIhfCnLYCrVG,UHbNwRmAJFMalnbt +lvLHBgacmSdZJqpzrezjTYTfWIBFUDIaMcGwErtmnAwgjDXwmIHxMDqYTrJvjUyq,NoWnhJQWDJVCtExB +SKenZwjdFvsAiARRmpBzTAXGWtByjJcIniiAhovlsAHLZXQJCmDyRJxKevZBttDa,mHAfCtCExIyoRuWG +vKgaRAkJVOSESTgmEVXpQIVXDADJiHmFFaAwxtjwUyFVrQouyJcZeDwhMUZPROkA,CiuvQKMCsVGLkoBM +cdxMDaSfvmlcpSfFqvfzgNIyUcmkDEyVswXcCKCJfYyAqrSCGWBGIEQlBxKTWSCj,LTIsYYovMEnpawzQ +bnfGLKRBqwOCQBNWSRAbEWVLqyUAlzrYNiJRWUZnmGnXtjMFBQhLHBhVJBygIrGU,vLYmsxCGzRXOMiqB +bxidiEYIOjbFDbHRqnaYXZuQcZJbNxynsmjNPCpzujEKzATaBeTrUchoylhvqLjx,AXlyMrMWiWGMqoIs +eSuMvTbZgMRDrEIxJwgFYdpNWkQzEzrsyyybeaJlUPEhEZZBWpPwQFqImIGnLFar,BTfGxXWpxvvPzcDS +uONlUmHMFQiPkCfdPrqqUDleaUBHKnxuQbFJAqfMqzoSpqPzawdOIvtVQMSWfCRv,uFDJTeJjZNFjIxKT +IEWyZnggNMulCyYklMdZaMYiIsqQNtbzbcpMHBUfPeOKaoSCMeezBSqcQwVXNJho,zpDlpaygXXZyslHj +ItlLEykbpzIeTErDcbaxxfzntBAYcPHVcLFneOzhNhxYYgwsbKZEmHHuHTnPnhSW,nKPAjZgLwcVTMNGi +uTQQIgVItVwPUSygrxoUwrLuFbAbahqbnixUuRIKRnJConAViRHYsRerKyEieFYI,fhPHVglAfXtHjNae +etMRJJVQSxIMgvzdSoCPsVpGJcKtjpXtMqtzXgaGQDryplTJifNvOFYGWChHLOUo,WJrEYWYGMOqveCgd +BYHllipDRYZVvMQYYhIRzLHabgftPTSnFbUCRmZejFUoeLLQoZtrZPJrTjVqWfNO,EztmadvCnpbgQtBl +ZEkjAbQOpxQrbtEDVlDhKgChCNsxTxSQtUXARrEeVJQrzPuVPkYHuuoMXjVCyeCk,kcgMJGMwDDDiFDSP +WdhxSOyjKNLItJZiXZtkKBdIcjGLuLbZHPmSfJCzlvqBxnrobjDTPxsFRXhEInhh,OPuAWqKpmpZHbvIY +wSGlgcmRasbLYVClIhhCppYezjZrWIhhaiASQcDrCDxdsGJIJjNmTWwbsFuKlbAt,VToAPASiTJHUGIEk +CFzAfFrDFyhndtOJalJiSiufNlCjWcwxwQjnjRgbqFlaAlIzXgrwVJmwISEAKHxx,jpylGyqZUinEIiVo +BwluSThaftzdTOrrzjBfqmHdbXqLUDEPPqmduYZESLtaSQAWLOKeuECRKPDEumJA,vgocckgnbQZCgpSq +OUguorluoGFpgVuXujAFBOIkBsIAaNCXgcywsWjutvEcrJrrDBRAHgcKwnfNLpXr,vwsRUvSJpupXvGta +TXQWeuibBvqDmmTaLAPZNsEHosjhcLBsixvomJaiAPmLmBDemETNOZMwrwREVRir,EIgmXjWGNYpcfdOs +BbQiYWirQnKUzBachIeJZgJWvmeZkpBUphzYEGrGcGpxUwWvSSnSYpQzBRsOFuzC,oUBiVRJNarkuUdRV +dugMhlZPXHEnWTZjaOFaxmmZgIHdVBmzUsIfUUPZkdVzDvCLRZyTBPrHhDVrYAVk,OdxERipabMlyoyEa +aKPvrmNwyclBMcPMEAgGItoshSJVrSonWVrcHMWiBXxqTpdHjuGKrRLTaHhQyunk,TmBLACtwsObAkgoz +LSndhzUuoIbprGCzRDfZryKqdcqwLpLWYfHoOHgBJDSkZRoYMQxmIoVCUdBSxHsZ,qvhvMIqFvXELhWqK +lInwjXRJTuZgrvcbDEvFHgPFGpqlnSuIJWtNRJYizWEfZJbZtfLexmEQMnGLxNlW,MURWsZhFTeLoqeAd +DOlyDBVEZFruhsBwHZgrnWTXckcVJcVzrwniSnJFEYUNiFgkIyukstlbdrluVhag,vUuhogLghMFjNjyK +zgDxGOAAYMGtAOoMviwtSQDLEpHluVuqFsqisVvoKCLfnMdPVTKgKCrchKrAmlmz,YaojDSYxfntenXUp +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv,eweOmajnCQLZOgBp +nJJyhHkDkQRttGsYjkMBBuGeQuPPDHQEQQdnGQmMbOdRFifRDpZVUqdfqeskxngR,nyikwrUYIuWbTawJ +ZcwBhgoFFiWyDeZbeMpliSvhbfsAkccEQQhYreLTGdfVHuNLpmCsduhkIlRKMNkx,nWeXVNOUmsoakAMC +BCJYdYFkkLRWUxhnnwpJGbXJchPvVCtbcaMdkArTcNLdRmopwncgdgOLGhJkZOnC,JWAVSCWqKUcmfoqk +NkxNbLyBjixoCzClTdwshhuZcFRjJJdDdWgCfiQIttZWQWqouBkYyMGpampLdUAr,nkuaUoDVcCUBUoOv +WgwvZoDsjspDAWHEflQMWzlbqnssWiBElmABhLmhgDPqFbNmAHSnzQrbAqSVAmWS,IQIlCVVcjVcTdshJ +yXelVXMUEuAtfNgzPrhjvYOpiAVEMZuqPfsQEUQoshjSIekxxzkFxdftfqFzfzpa,AyTbbGVAXKCYUUln +sAtWqpSPxPSkDtmIJKfNvlKjgStnYMOmrLsQnzmIFAusETPPzLDTjKcBASKWNRAJ,bDQFllamogAjBEPU +RAOMDMMZkezCBWxQDWLjvHLkbpvFyrbUbDDEekWciXejYwKifSVumcsocUmMkmpa,MnwKAECsMVGVLIZM +LMizeIoxMxHCKwikjqOSSPbuiqXWDAmbTLMBOXpyorUmpunjWFTNLVqvHNcCNHrN,zDzyAkfeTYZzaxhG +IzLqciYsaKtWrsrjOQldeIvEqavoIZEYnupJZLizJVeOhoLqtLQFaoNRdvZMWSQH,klLegifptLAxnhha +VkyQsnlXcIGjGhcSJUcZKeQyiDUbIcgIWSbHaEsbfEydSTHqRlxImGdYGEurZczg,jyzyCGlvuBdKwyIX +dkSRcFpHPXWjNIHrCpWlOPaIkVqjtyPRhlJeYMksjieDxYhiUcGhbuvamVlrMDDx,FATEzwNXGerINvHD +KravfnzVbNOhLhstPcaLLVWpbqYzXckQGbuALlEiXqqhFfUyThFZFhzSLhjldPMB,sowcmsLQTFiKpNXy +ITZWxxjqnizlxuRWMlPQLUnBopyDtOxNfcaoFDbRKetIpVxKLSRoJOauSCcDwWUP,NmqaomxmRQGqiKCV +guhzFIXdUEACMMHObJptkrZqJgbclDcdRxCPYSvuAdITaKgHfJaKLNHFzdRmpHni,DiONbLbxHfUHhoTU +AnZjaIBmnsEBplkEmHBstdggPnYmhblyQQttVqYzxxNtOXwlNQetkvCOySSXRUpw,FCdUtLDyvAqszerb +wkCcCXGKcZJEDwTkzOoNRkMbxHdNciQlVruGSKcJrHpokspcZIVfupcTxapISupH,pzEvnKzLQbzNDSQN +ppieExXmNHqBXVgLFhjlHHHhHSAddipMCmPXhXDfHZVTtNhqcMMVauyjKOFGBHPe,tXpiHGkKGTzzMluO +lnVknQNrrYyqFbEKYPxsQWNPKLpsEVmUGtbbWMWDThMuScSByeZRwuusLYzKPbHE,HzxgCCtiIFYvgwWO +oYSnlwjpsWaNzYunBnhNLwiICrmAEFiZRczbdHYpQgwSrrMQCixgtjfCGOptTkmd,IYmbhaQueIKQvcBc +saWvQlIiiYqAPmcEDGsVXNAIJNNGTyZKhrMMKYHXJQnniGVuIClgwvAEXeIPGeFN,epReKmWNANFpINhn +PZIEJMwirPArOGCfJJAfdwGydRDBGGQojUzWFJtVoJZTFAFYwaDOuLFruvRjolHq,yCMGUjViZoTPMTtg +bMCOdGAYjXaDSPyZyGegyuRnnwYSySrRzbLbvtgBjfFXfCMPIVIGFTagRyBpiKLa,zRTXXYnUXmkIMDoP +JQWdUrNElPWswpQvnVqCmboMEjMhebKISRcmznakzemGxBjUughzOVbctPzmVTLW,CdcnGSQhRbdsoQrg +OdAvNbDdCQpTrAbJWrUrVprpgVIXwvvSStooIVwzUfDIThtvdBHldyUFFkvabfyj,ueAqmPurXOjNtvWr +zjIkKNEzyUFiubvlxYWNXdjoIIEwZavalnqwSCgDgcZUldjZOkzhKXuRciwSTNJg,MWwWigLZKqLgZkLp +WTaCVgYrnoyEoShtBDUmrRHeRSYIAjvUpZnVAUTxTyaIGzvQIdwcPafAnkIbplSq,mdIEdIajBbeAyPCk +tNdAhhdqVzLdGfPoctgRkehzEOIRvjEwDpmAQrMjbWtfRQGjeUiVJNafrhVKFieX,IxHvvUzKrlMpWhpR +vhgPWqsRnvDRMFIYHppovDbKlWPzEFwbBXSihpYbwCYpkeXFIXbIYdWSLfcHpnWX,qhwFSRGRKwcPlLJs +ARxEUJokZaGDgXHGxwPiSqqvNSmoowUxRDDkozqbvcUvQuPtdNaeaKOKykMIUkmR,XqmCMQPKPtAPzBZd +RwztiezZCzbSLKzIyYqfEMjDTcLpASCiGWoaseuxBWpvSVutmtdEgdZornGkHrQf,YcRFZNgodJFPNoop +YTcHYrADMhlKAnvdGBdQBXWBqcftxkNpFceODelYVRXwFOZTHdXkVGAfJTzZcyhD,tBGtrQaLFgACGOEE +fHFCvDLRGGhYZWSnxaIqKTgvNbCPLzyvOnpHyAhrKEAsApdPgkxAptCTtgYAnmEq,vxGOPFzvJOVBEblg +zckpuLjSVdhSFnhTqPfDoHdJdjpfZBDdlzGbYgzVbKgDMJQDBGCHZSJBdtzlvHro,TeeGbXAcEbwzglGf +muAQTPuNCQTZurKTDlYzTQgvlWNyRXOlKizgsnGSrKdYWCSBlQtOvIyEWVthaYhO,ZnYBDVQYoJOoTMlS +UQswwuiprHWAbguGNZgOAdFrgEIdsDRImrqXXTmbqppVgnJrjjiOdZaNUpIQGcTR,VwugWpNMzEKHAFqo +GDRPaAUIAymOEEksSqccGOqpUYvGUyvBKjfRqKSTAyNadpaMYnMYboPOrEEfXVWf,noDbJmsjYCgqHsBu +cVjSBnCUnKfKXwETABIPvavwLXMGSLSpoVylUSCRlRCzpDvDVjfNAIrSiRWNHJZS,OszhlCboIvNdCTYH diff --git a/spu/tests/data/pir/ground_truth.csv b/spu/tests/data/pir/ground_truth.csv new file mode 100644 index 00000000..6160abfa --- /dev/null +++ b/spu/tests/data/pir/ground_truth.csv @@ -0,0 +1,2 @@ +key,value +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv,eweOmajnCQLZOgBp diff --git a/spu/tests/data/pir/query.csv b/spu/tests/data/pir/query.csv new file mode 100644 index 00000000..5cb2f046 --- /dev/null +++ b/spu/tests/data/pir/query.csv @@ -0,0 +1,2 @@ +key +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv diff --git a/spu/tests/data/query.csv b/spu/tests/data/query.csv new file mode 100644 index 00000000..79dbd224 --- /dev/null +++ b/spu/tests/data/query.csv @@ -0,0 +1 @@ +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv diff --git a/spu/tests/jnp_debug.py b/spu/tests/jnp_debug.py index ff970dbf..4757c555 100644 --- a/spu/tests/jnp_debug.py +++ b/spu/tests/jnp_debug.py @@ -15,7 +15,6 @@ import jax.numpy as jnp import numpy as np -import spu.intrinsic as si import spu.spu_pb2 as spu_pb2 import spu.utils.simulation as ppsim @@ -31,9 +30,9 @@ copts.disable_div_sqrt_rewrite = True x = np.random.randn(3, 4) - y = np.random.randn(5, 6) - fn = lambda x, y: si.example_binary(x, y) - # fn = lambda x, y: jnp.matmul(x, y) + y = np.random.randn(4, 5) + fn = lambda x, y: jnp.matmul(x, y) + spu_fn = ppsim.sim_jax(sim, fn, copts=copts) z = spu_fn(x, y) diff --git a/spu/tests/jnp_semi2k_r128_test.py b/spu/tests/jnp_semi2k_r128_test.py index 10295fed..7ce5c825 100644 --- a/spu/tests/jnp_semi2k_r128_test.py +++ b/spu/tests/jnp_semi2k_r128_test.py @@ -30,5 +30,19 @@ def setUp(self): self._rng = np.random.RandomState() +class JnpTestSemi2kFM128TwoParty(JnpTests.JnpTestBase): + def setUp(self): + config = spu_pb2.RuntimeConfig( + protocol=spu_pb2.ProtocolKind.SEMI2K, field=spu_pb2.FieldType.FM128 + ) + config.experimental_enable_exp_prime = True + config.experimental_exp_prime_enable_upper_bound = True + config.experimental_exp_prime_offset = 13 + config.experimental_exp_prime_disable_lower_bound = False + config.fxp_exp_mode = spu_pb2.RuntimeConfig.ExpMode.EXP_PRIME + self._sim = ppsim.Simulator(2, config) + self._rng = np.random.RandomState() + + if __name__ == "__main__": unittest.main() diff --git a/spu/tests/legacy_psi_test.py b/spu/tests/legacy_psi_test.py index eae7e7c8..477e964c 100644 --- a/spu/tests/legacy_psi_test.py +++ b/spu/tests/legacy_psi_test.py @@ -194,252 +194,6 @@ def test_dppsi_2pc(self): 2, inputs, outputs, selected_fields, psi.PsiType.DP_PSI_2PC ) - def test_ecdh_oprf_unbalanced(self): - print("----------test_ecdh_oprf_unbalanced-------------") - - offline_path = ["", "spu/tests/data/bob.csv"] - online_path = ["spu/tests/data/alice.csv", "spu/tests/data/bob.csv"] - outputs = ["./alice-ecdh-unbalanced.csv", "./bob-ecdh-unbalanced.csv"] - preprocess_path = ["./alice-preprocess.csv", ""] - secret_key_path = ["", "./secret_key.bin"] - selected_fields = ["id", "idx"] - - with open(secret_key_path[1], 'wb') as f: - f.write( - bytes.fromhex( - "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" - ) - ) - - time_stamp = time.time() - lctx_desc = link.Desc() - lctx_desc.id = str(round(time_stamp * 1000)) - - for rank in range(2): - port = get_free_port() - lctx_desc.add_party(f"id_{rank}", f"127.0.0.1:{port}") - - receiver_rank = 0 - server_rank = 1 - client_rank = 0 - # one-way PSI, just one party get result - broadcast_result = False - - precheck_input = False - server_cache_path = "server_cache.bin" - - def wrap( - rank, - offline_path, - online_path, - out_path, - preprocess_path, - ub_secret_key_path, - ): - link_ctx = link.create_brpc(lctx_desc, rank) - - if receiver_rank != link_ctx.rank: - print("===== gen cache phase =====") - print(f"{offline_path}, {server_cache_path}") - - gen_cache_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_GEN_CACHE'), - input_params=psi.InputParams( - path=offline_path, - select_fields=selected_fields, - precheck=False, - ), - output_params=psi.OutputParams( - path=server_cache_path, need_sort=False - ), - bucket_size=1000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ecdh_secret_key_path=ub_secret_key_path, - ) - - start = time.time() - - gen_cache_report = psi.gen_cache_for_2pc_ub_psi(gen_cache_config) - - server_source_count = wc_count(offline_path) - self.assertEqual( - gen_cache_report.original_count, server_source_count - 1 - ) - - print(f"offline cost time: {time.time() - start}") - print( - f"offline: rank: {rank} original_count: {gen_cache_report.original_count}" - ) - - print("===== transfer cache phase =====") - transfer_cache_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_TRANSFER_CACHE'), - broadcast_result=broadcast_result, - receiver_rank=receiver_rank, - input_params=psi.InputParams( - path=offline_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - bucket_size=1000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if receiver_rank == link_ctx.rank: - transfer_cache_config.preprocess_path = preprocess_path - else: - transfer_cache_config.input_params.path = server_cache_path - - print( - f"rank:{link_ctx.rank} file:{transfer_cache_config.input_params.path}" - ) - - start = time.time() - transfer_cache_report = psi.bucket_psi(link_ctx, transfer_cache_config) - - if receiver_rank != link_ctx.rank: - server_source_count = wc_count(offline_path) - self.assertEqual( - transfer_cache_report.original_count, server_source_count - 1 - ) - - print(f"transfer cache cost time: {time.time() - start}") - print( - f"transfer cache: rank: {rank} original_count: {transfer_cache_report.original_count}" - ) - - print("===== shuffle online phase =====") - shuffle_online_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_SHUFFLE_ONLINE'), - broadcast_result=False, - receiver_rank=server_rank, - input_params=psi.InputParams( - path=online_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - output_params=psi.OutputParams(path=out_path, need_sort=False), - bucket_size=10000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if client_rank == link_ctx.rank: - shuffle_online_config.preprocess_path = preprocess_path - else: - shuffle_online_config.preprocess_path = server_cache_path - shuffle_online_config.ecdh_secret_key_path = ub_secret_key_path - - print( - f"rank:{link_ctx.rank} file:{shuffle_online_config.input_params.path}" - ) - - start = time.time() - shuffle_online_report = psi.bucket_psi(link_ctx, shuffle_online_config) - - if server_rank == link_ctx.rank: - server_source_count = wc_count(offline_path) - self.assertEqual( - shuffle_online_report.original_count, server_source_count - 1 - ) - - print(f"shuffle online cost time: {time.time() - start}") - print( - f"shuffle online: rank: {rank} original_count: {shuffle_online_report.original_count}" - ) - print( - f"shuffle online: rank: {rank} intersection: {shuffle_online_report.intersection_count}" - ) - - print("===== offline phase =====") - offline_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_OFFLINE'), - broadcast_result=broadcast_result, - receiver_rank=client_rank, - input_params=psi.InputParams( - path=offline_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - output_params=psi.OutputParams(path="fake.out", need_sort=False), - bucket_size=1000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if client_rank == link_ctx.rank: - offline_config.preprocess_path = preprocess_path - offline_config.input_params.path = "dummy.csv" - else: - offline_config.ecdh_secret_key_path = ub_secret_key_path - - start = time.time() - offline_report = psi.bucket_psi(link_ctx, offline_config) - - if receiver_rank != link_ctx.rank: - server_source_count = wc_count(offline_path) - self.assertEqual(offline_report.original_count, server_source_count - 1) - - print(f"offline cost time: {time.time() - start}") - print( - f"offline: rank: {rank} original_count: {offline_report.original_count}" - ) - print( - f"offline: rank: {rank} intersection_count: {offline_report.intersection_count}" - ) - - print("===== online phase =====") - online_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_ONLINE'), - broadcast_result=broadcast_result, - receiver_rank=client_rank, - input_params=psi.InputParams( - path=online_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - output_params=psi.OutputParams(path=out_path, need_sort=False), - bucket_size=300000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if receiver_rank == link_ctx.rank: - online_config.preprocess_path = preprocess_path - else: - online_config.ecdh_secret_key_path = ub_secret_key_path - online_config.input_params.path = "dummy.csv" - - start = time.time() - report_online = psi.bucket_psi(link_ctx, online_config) - - if receiver_rank == link_ctx.rank: - client_source_count = wc_count(online_path) - self.assertEqual(report_online.original_count, client_source_count - 1) - - print(f"online cost time: {time.time() - start}") - print(f"online: rank:{rank} original_count: {report_online.original_count}") - print(f"intersection_count: {report_online.intersection_count}") - - link_ctx.stop_link() - - # launch with multiprocess - jobs = [ - multiprocess.Process( - target=wrap, - args=( - rank, - offline_path[rank], - online_path[rank], - outputs[rank], - preprocess_path[rank], - secret_key_path[rank], - ), - ) - for rank in range(2) - ] - [job.start() for job in jobs] - for job in jobs: - job.join() - self.assertEqual(job.exitcode, 0) - if __name__ == '__main__': unittest.main() diff --git a/spu/tests/pir_test.py b/spu/tests/pir_test.py index 45574044..3bf1cd68 100644 --- a/spu/tests/pir_test.py +++ b/spu/tests/pir_test.py @@ -13,122 +13,88 @@ # limitations under the License. import json +import tempfile import unittest -from tempfile import TemporaryDirectory import multiprocess -from google.protobuf import json_format - import spu.libspu.link as link import spu.psi as psi -from spu.tests.utils import create_link_desc, wc_count +from google.protobuf import json_format +from spu.tests.utils import create_link_desc class UnitTests(unittest.TestCase): - def setUp(self) -> None: - self.tempdir_ = TemporaryDirectory() - return super().setUp() - - def tearDown(self) -> None: - self.tempdir_.cleanup() - return super().tearDown() def test_pir(self): - # setup stage - server_setup_config = f''' - {{ - "mode": "MODE_SERVER_SETUP", - "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI", - "pir_server_config": {{ - "input_path": "spu/tests/data/alice.csv", - "setup_path": "{self.tempdir_.name}/spu_test_pir_pir_server_setup", - "key_columns": [ - "id" - ], - "label_columns": [ - "y" - ], - "label_max_len": 288, - "bucket_size": 1000000, - "apsi_server_config": {{ - "oprf_key_path": "{self.tempdir_.name}/spu_test_pir_server_secret_key.bin", - "num_per_query": 1, - "compressed": false - }} + with tempfile.TemporaryDirectory() as temp_dir: + # setup stage + sender_setup_config_json = f''' + {{ + "source_file": "spu/tests/data/pir/db.csv", + "params_file": "spu/tests/data/pir/100K-1-16.json", + "sdb_out_file": "{temp_dir}/sdb", + "save_db_only": true }} - }} - ''' - - with open( - f"{self.tempdir_.name}/spu_test_pir_server_secret_key.bin", 'wb' - ) as f: - f.write( - bytes.fromhex( - "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" + ''' + + psi.apsi_send( + json_format.ParseDict( + json.loads(sender_setup_config_json), psi.ApsiSenderConfig() ) ) - psi.pir(json_format.ParseDict(json.loads(server_setup_config), psi.PirConfig())) - - server_online_config = f''' - {{ - "mode": "MODE_SERVER_ONLINE", - "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI", - "pir_server_config": {{ - "setup_path": "{self.tempdir_.name}/spu_test_pir_pir_server_setup" + sender_online_config_json = f''' + {{ + "db_file": "{temp_dir}/sdb" }} - }} - ''' - - client_online_config = f''' - {{ - "mode": "MODE_CLIENT", - "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI", - "pir_client_config": {{ - "input_path": "{self.tempdir_.name}/spu_test_pir_pir_client.csv", - "key_columns": [ - "id" - ], - "output_path": "{self.tempdir_.name}/spu_test_pir_pir_output.csv" + ''' + + receiver_online_config_json = f''' + {{ + "query_file": "spu/tests/data/pir/query.csv", + "output_file": "{temp_dir}/result.csv", + "params_file": "spu/tests/data/pir/100K-1-16.json" }} - }} - ''' + ''' - pir_client_input_content = '''id -user808 -xxx -''' + sender_online_config = json_format.ParseDict( + json.loads(sender_online_config_json), psi.ApsiSenderConfig() + ) + + receiver_online_config = json_format.ParseDict( + json.loads(receiver_online_config_json), psi.ApsiReceiverConfig() + ) - with open(f"{self.tempdir_.name}/spu_test_pir_pir_client.csv", 'w') as f: - f.write(pir_client_input_content) + link_desc = create_link_desc(2) - configs = [ - json_format.ParseDict(json.loads(server_online_config), psi.PirConfig()), - json_format.ParseDict(json.loads(client_online_config), psi.PirConfig()), - ] + def sender_wrap(rank, link_desc, config): + link_ctx = link.create_brpc(link_desc, rank) + psi.apsi_send(config, link_ctx) - link_desc = create_link_desc(2) + def receiver_wrap(rank, link_desc, config): + link_ctx = link.create_brpc(link_desc, rank) + psi.apsi_receive(config, link_ctx) - def wrap(rank, link_desc, configs): - link_ctx = link.create_brpc(link_desc, rank) - psi.pir(configs[rank], link_ctx) + jobs = [ + multiprocess.Process( + target=sender_wrap, args=(0, link_desc, sender_online_config) + ), + multiprocess.Process( + target=receiver_wrap, args=(1, link_desc, receiver_online_config) + ), + ] - jobs = [ - multiprocess.Process( - target=wrap, - args=(rank, link_desc, configs), - ) - for rank in range(2) - ] - [job.start() for job in jobs] - for job in jobs: - job.join() - self.assertEqual(job.exitcode, 0) - - # including title, actual matched item cnt is 1. - self.assertEqual( - wc_count(f"{self.tempdir_.name}/spu_test_pir_pir_output.csv"), 2 - ) + [job.start() for job in jobs] + for job in jobs: + job.join() + self.assertEqual(job.exitcode, 0) + + import pandas as pd + + df1 = pd.read_csv(f'{temp_dir}/result.csv') + df2 = pd.read_csv('spu/tests/data/pir/ground_truth.csv') + + self.assertTrue(df1.equals(df2)) if __name__ == '__main__': diff --git a/spu/tests/ub_psi_test.py b/spu/tests/ub_psi_test.py index 2728e1b6..2a5323e1 100644 --- a/spu/tests/ub_psi_test.py +++ b/spu/tests/ub_psi_test.py @@ -43,12 +43,12 @@ def test_ub_psi(self): "role": "ROLE_SERVER", "cache_path": "{self.tempdir_.name}/spu_test_ub_psi_server_cache", "input_config": {{ + "type" : "IO_TYPE_FILE_CSV", "path": "spu/tests/data/alice.csv" }}, "keys": [ "id" - ], - "server_secret_key_path": "{self.tempdir_.name}/spu_test_ub_psi_server_secret_key.key" + ] }} ''' @@ -60,15 +60,6 @@ def test_ub_psi(self): }} ''' - with open( - f"{self.tempdir_.name}/spu_test_ub_psi_server_secret_key.key", 'wb' - ) as f: - f.write( - bytes.fromhex( - "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" - ) - ) - configs = [ json_format.ParseDict(json.loads(server_offline_config), psi.UbPsiConfig()), json_format.ParseDict(json.loads(client_offline_config), psi.UbPsiConfig()), @@ -95,8 +86,10 @@ def wrap(rank, link_desc, configs): {{ "mode": "MODE_ONLINE", "role": "ROLE_SERVER", - "server_secret_key_path": "{self.tempdir_.name}/spu_test_ub_psi_server_secret_key.key", - "cache_path": "{self.tempdir_.name}/spu_test_ub_psi_server_cache" + "cache_path": "{self.tempdir_.name}/spu_test_ub_psi_server_cache", + "output_config": {{ + "type" : "IO_TYPE_FILE_CSV" + }} }} ''' @@ -105,9 +98,11 @@ def wrap(rank, link_desc, configs): "mode": "MODE_ONLINE", "role": "ROLE_CLIENT", "input_config": {{ + "type" : "IO_TYPE_FILE_CSV", "path": "spu/tests/data/bob.csv" }}, "output_config": {{ + "type" : "IO_TYPE_FILE_CSV", "path": "{self.tempdir_.name}/spu_test_ubpsi_bob_psi_ouput.csv" }}, "keys": [ diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py index 592ccaff..6a6c220a 100644 --- a/spu/utils/simulation.py +++ b/spu/utils/simulation.py @@ -22,6 +22,7 @@ import jax.extend.linear_util as jax_lu except ImportError: import jax.linear_util as jax_lu # fallback + import jax.numpy as jnp import numpy as np from jax._src import api_util as japi_util @@ -69,6 +70,7 @@ def simple(cls, wsize: int, prot: spu_pb2.ProtocolKind, field: spu_pb2.FieldType A SPU Simulator """ config = spu_pb2.RuntimeConfig(protocol=prot, field=field) + if prot == spu_pb2.ProtocolKind.CHEETAH: # config.cheetah_2pc_config.enable_mul_lsb_error = True # config.cheetah_2pc_config.ot_kind = spu_pb2.CheetahOtKind.YACL_Softspoken