From c04a462e5c2a4105e0f10f0dab2cb621578e08c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Thu, 14 Nov 2024 11:24:26 +0100 Subject: [PATCH 1/2] feat(optimizer): add generic keyset info generation --- .../include/concretelang/Common/Keysets.h | 5 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 87 +++++++++ .../compiler/lib/Common/Keysets.cpp | 94 ++++++++++ .../src/concrete-optimizer.rs | 30 ++++ .../src/cpp/concrete-optimizer.cpp | 72 ++++++++ .../src/cpp/concrete-optimizer.hpp | 17 ++ .../multi_parameters/generic_generation.rs | 169 ++++++++++++++++++ .../optimization/dag/multi_parameters/mod.rs | 1 + .../dag/multi_parameters/partitions.rs | 2 +- .../tests/compilation/test_restrictions.py | 32 +++- 10 files changed, 507 insertions(+), 2 deletions(-) create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h index 465ee76fd..b6d47a20b 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keysets.h @@ -6,6 +6,7 @@ #ifndef CONCRETELANG_COMMON_KEYSETS_H #define CONCRETELANG_COMMON_KEYSETS_H +#include "concrete-optimizer.hpp" #include "concrete-protocol.capnp.h" #include "concretelang/Common/Csprng.h" #include "concretelang/Common/Error.h" @@ -92,6 +93,10 @@ class KeysetCache { KeysetCache() = default; }; +Message generate_generic_keyset_info( + std::vector partitions, + bool generate_fks); + } // namespace keysets } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index c1183a812..e834d2fc3 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -18,6 +18,8 @@ #include "concretelang/Support/Error.h" #include "concretelang/Support/V0Parameters.h" #include "concretelang/Support/logging.h" +#include <_types/_uint8_t.h> +#include #include #include #include @@ -645,6 +647,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( output.append(")"); return output; } + + bool operator==(LweSecretKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left == right; + } + + bool operator!=(LweSecretKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left != right; + } }; pybind11::class_(m, "LweSecretKeyParam") .def( @@ -659,6 +673,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](pybind11::object key) { return pybind11::hash(pybind11::repr(key)); }) + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) .doc() = "Parameters of an LWE Secret Key."; // ------------------------------------------------------------------------------// @@ -689,6 +705,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( output.append(")"); return output; } + + bool operator==(BootstrapKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left == right; + } + + bool operator!=(BootstrapKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left != right; + } }; pybind11::class_(m, "BootstrapKeyParam") .def( @@ -745,6 +773,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](pybind11::object key) { return pybind11::hash(pybind11::repr(key)); }) + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) .doc() = "Parameters of a Bootstrap key."; // ------------------------------------------------------------------------------// @@ -766,6 +796,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( output.append(")"); return output; } + + bool operator==(KeyswitchKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left == right; + } + + bool operator!=(KeyswitchKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left != right; + } }; pybind11::class_(m, "KeyswitchKeyParam") .def( @@ -804,6 +846,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](pybind11::object key) { return pybind11::hash(pybind11::repr(key)); }) + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) .doc() = "Parameters of a keyswitch key."; // ------------------------------------------------------------------------------// @@ -834,6 +878,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( output.append(")"); return output; } + + bool operator==(PackingKeyswitchKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left == right; + } + + bool operator!=(PackingKeyswitchKeyParam const &other) const { + capnp::AnyStruct::Reader left = this->info.asReader().getParams(); + capnp::AnyStruct::Reader right = other.info.asReader().getParams(); + return left != right; + } }; pybind11::class_(m, "PackingKeyswitchKeyParam") .def( @@ -892,13 +948,44 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](pybind11::object key) { return pybind11::hash(pybind11::repr(key)); }) + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) .doc() = "Parameters of a packing keyswitch key."; + // ------------------------------------------------------------------------------// + // PARTITION DEFINITION // + // ------------------------------------------------------------------------------// + // + pybind11::class_( + m, "PartitionDefinition") + .def(init([](uint8_t precision, double norm2) + -> concrete_optimizer::utils::PartitionDefinition { + return concrete_optimizer::utils::PartitionDefinition{precision, + norm2}; + }), + arg("precision"), arg("norm2")) + .doc() = "Definition of a partition (in terms of precision in bits and " + "norm2 in value)."; + // ------------------------------------------------------------------------------// // KEYSET INFO // // ------------------------------------------------------------------------------// typedef Message KeysetInfo; pybind11::class_(m, "KeysetInfo") + .def_static( + "generate_generic", + [](std::vector + partitions, + bool generateFks) -> KeysetInfo { + if (partitions.size() < 2) { + throw std::runtime_error("Need at least two partition defs to " + "generate a generic keyset info."); + } + return ::concretelang::keysets::generate_generic_keyset_info( + partitions, generateFks); + }, + arg("partition_defs"), arg("generate_fks"), + "Generate a generic keyset info for a set of partition definitions") .def( "secret_keys", [](KeysetInfo &keysetInfo) { diff --git a/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp b/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp index 90d7f1d43..dabdcd052 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp @@ -6,6 +6,7 @@ #include "concretelang/Common/Keysets.h" #include "capnp/message.h" #include "concrete-cpu.h" +#include "concrete-optimizer.hpp" #include "concrete-protocol.capnp.h" #include "concretelang/Common/Csprng.h" #include "concretelang/Common/Error.h" @@ -417,5 +418,98 @@ KeysetCache::getKeyset(const Message &keysetInfo, return std::move(keyset); } +Message generate_generic_keyset_info( + std::vector partitionDefs, + bool generateFks) { + auto output = Message{}; + rust::Vec rustPartitionDefs{}; + for (auto def : partitionDefs) { + rustPartitionDefs.push_back(def); + } + auto parameters = concrete_optimizer::utils::generate_generic_keyset_info( + rustPartitionDefs, generateFks); + + auto skLen = (int)parameters.secret_keys.size(); + auto skBuilder = output.asBuilder().initLweSecretKeys(skLen); + for (int i = 0; i < skLen; i++) { + auto output = Message(); + auto sk = parameters.secret_keys[i]; + output.asBuilder().setId(sk.identifier); + output.asBuilder().getParams().setIntegerPrecision(64); + output.asBuilder().getParams().setLweDimension(sk.polynomial_size * + sk.glwe_dimension); + output.asBuilder().getParams().setKeyType( + ::concreteprotocol::KeyType::BINARY); + skBuilder.setWithCaveats(i, output.asReader()); + } + + auto bskLen = (int)parameters.bootstrap_keys.size(); + auto bskBuilder = output.asBuilder().initLweBootstrapKeys(bskLen); + for (int i = 0; i < bskLen; i++) { + auto output = Message(); + auto bsk = parameters.bootstrap_keys[i]; + output.asBuilder().setId(bsk.identifier); + output.asBuilder().setInputId(bsk.input_key.identifier); + output.asBuilder().setOutputId(bsk.output_key.identifier); + output.asBuilder().getParams().setLevelCount( + bsk.br_decomposition_parameter.level); + output.asBuilder().getParams().setBaseLog( + bsk.br_decomposition_parameter.log2_base); + output.asBuilder().getParams().setGlweDimension( + bsk.output_key.glwe_dimension); + output.asBuilder().getParams().setPolynomialSize( + bsk.output_key.polynomial_size); + output.asBuilder().getParams().setInputLweDimension( + bsk.input_key.polynomial_size); + output.asBuilder().getParams().setIntegerPrecision(64); + output.asBuilder().getParams().setKeyType( + concreteprotocol::KeyType::BINARY); + bskBuilder.setWithCaveats(i, output.asReader()); + } + + auto kskLen = (int)parameters.keyswitch_keys.size(); + auto ckskLen = (int)parameters.conversion_keyswitch_keys.size(); + auto kskBuilder = output.asBuilder().initLweKeyswitchKeys(kskLen + ckskLen); + for (int i = 0; i < kskLen; i++) { + auto output = Message(); + auto ksk = parameters.keyswitch_keys[i]; + output.asBuilder().setId(ksk.identifier); + output.asBuilder().setInputId(ksk.input_key.identifier); + output.asBuilder().setOutputId(ksk.output_key.identifier); + output.asBuilder().getParams().setLevelCount( + ksk.ks_decomposition_parameter.level); + output.asBuilder().getParams().setBaseLog( + ksk.ks_decomposition_parameter.log2_base); + output.asBuilder().getParams().setIntegerPrecision(64); + output.asBuilder().getParams().setInputLweDimension( + ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size); + output.asBuilder().getParams().setOutputLweDimension( + ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size); + output.asBuilder().getParams().setKeyType( + concreteprotocol::KeyType::BINARY); + kskBuilder.setWithCaveats(i, output.asReader()); + } + for (int i = 0; i < ckskLen; i++) { + auto output = Message(); + auto ksk = parameters.conversion_keyswitch_keys[i]; + output.asBuilder().setId(ksk.identifier); + output.asBuilder().setInputId(ksk.input_key.identifier); + output.asBuilder().setOutputId(ksk.output_key.identifier); + output.asBuilder().getParams().setLevelCount( + ksk.ks_decomposition_parameter.level); + output.asBuilder().getParams().setBaseLog( + ksk.ks_decomposition_parameter.log2_base); + output.asBuilder().getParams().setIntegerPrecision(64); + output.asBuilder().getParams().setInputLweDimension( + ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size); + output.asBuilder().getParams().setOutputLweDimension( + ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size); + output.asBuilder().getParams().setKeyType( + concreteprotocol::KeyType::BINARY); + kskBuilder.setWithCaveats(i + kskLen, output.asReader()); + } + return output; +} + } // namespace keysets } // namespace concretelang diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index b195d1242..c76d31eab 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -11,6 +11,7 @@ use concrete_optimizer::dag::operator::{ }; use concrete_optimizer::dag::unparametrized; use concrete_optimizer::optimization::config::{Config, SearchSpace}; +use concrete_optimizer::optimization::dag::multi_parameters::generic_generation::generate_generic_parameters; use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::CircuitSolution; use concrete_optimizer::optimization::dag::multi_parameters::optimize::{ KeysetRestriction, MacroParameters, NoSearchSpaceRestriction, RangeRestriction, @@ -913,6 +914,22 @@ fn location_from_string(string: &str) -> Box { } } +fn generate_generic_keyset_info( + inputs: Vec, + generate_fks: bool, +) -> ffi::CircuitKeys { + generate_generic_parameters( + inputs + .into_iter() + .map( + |ffi::PartitionDefinition { precision, norm2 }| concrete_optimizer::optimization::dag::multi_parameters::generic_generation::PartitionDefinition { precision, norm2 }, + ) + .collect(), + generate_fks, + ) + .into() +} + pub struct Weights(operator::Weights); fn vector(weights: &[i64]) -> Box { @@ -981,6 +998,12 @@ mod ffi { #[namespace = "concrete_optimizer::utils"] fn location_from_string(string: &str) -> Box; + #[namespace = "concrete_optimizer::utils"] + fn generate_generic_keyset_info( + partitions: Vec, + generate_fks: bool, + ) -> CircuitKeys; + #[namespace = "concrete_optimizer::utils"] fn get_external_partition( name: String, @@ -1359,6 +1382,13 @@ mod ffi { pub struct KeysetRestriction { pub info: KeysetInfo, } + + #[namespace = "concrete_optimizer::utils"] + #[derive(Debug, Clone)] + pub struct PartitionDefinition { + pub precision: u8, + pub norm2: f64, + } } fn processing_unit(options: &ffi::Options) -> ProcessingUnit { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 6c6c8075c..fe34c2444 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -929,6 +929,13 @@ struct operator_new { }; } // namespace detail +template +union ManuallyDrop { + T value; + ManuallyDrop(T &&value) : value(::std::move(value)) {} + ~ManuallyDrop() {} +}; + template union MaybeUninit { T value; @@ -974,6 +981,9 @@ namespace concrete_optimizer { struct KeysetInfo; struct KeysetRestriction; } + namespace utils { + struct PartitionDefinition; + } } namespace concrete_optimizer { @@ -1387,6 +1397,18 @@ struct KeysetRestriction final { #endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction } // namespace restriction +namespace utils { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +#define CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +struct PartitionDefinition final { + ::std::uint8_t precision; + double norm2; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +} // namespace utils + namespace v0 { extern "C" { ::concrete_optimizer::v0::Solution concrete_optimizer$v0$cxxbridge1$optimize_bootstrap(::std::uint64_t precision, double noise_factor, ::concrete_optimizer::Options const &options) noexcept; @@ -1418,6 +1440,8 @@ ::concrete_optimizer::Location *concrete_optimizer$utils$cxxbridge1$location_unk ::concrete_optimizer::Location *concrete_optimizer$utils$cxxbridge1$location_from_string(::rust::Str string) noexcept; +void concrete_optimizer$utils$cxxbridge1$generate_generic_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *partitions, bool generate_fks, ::CircuitKeys *return$) noexcept; + ::concrete_optimizer::ExternalPartition *concrete_optimizer$utils$cxxbridge1$get_external_partition(::rust::String *name, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t internal_dim, double max_variance, double variance) noexcept; double concrete_optimizer$utils$cxxbridge1$get_noise_br(::concrete_optimizer::Options const &options, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t lwe_dim, ::std::uint64_t pbs_level, ::std::uint64_t pbs_log2_base) noexcept; @@ -1560,6 +1584,13 @@ ::rust::Box<::concrete_optimizer::Location> location_from_string(::rust::Str str return ::rust::Box<::concrete_optimizer::Location>::from_raw(concrete_optimizer$utils$cxxbridge1$location_from_string(string)); } +::CircuitKeys generate_generic_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> partitions, bool generate_fks) noexcept { + ::rust::ManuallyDrop<::rust::Vec<::concrete_optimizer::utils::PartitionDefinition>> partitions$(::std::move(partitions)); + ::rust::MaybeUninit<::CircuitKeys> return$; + concrete_optimizer$utils$cxxbridge1$generate_generic_keyset_info(&partitions$.value, generate_fks, &return$.value); + return ::std::move(return$.value); +} + ::rust::Box<::concrete_optimizer::ExternalPartition> get_external_partition(::rust::String name, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t internal_dim, double max_variance, double variance) noexcept { return ::rust::Box<::concrete_optimizer::ExternalPartition>::from_raw(concrete_optimizer$utils$cxxbridge1$get_external_partition(&name, log2_polynomial_size, glwe_dimension, internal_dim, max_variance, variance)); } @@ -1713,6 +1744,15 @@ ::concrete_optimizer::Location *cxxbridge1$box$concrete_optimizer$Location$alloc void cxxbridge1$box$concrete_optimizer$Location$dealloc(::concrete_optimizer::Location *) noexcept; void cxxbridge1$box$concrete_optimizer$Location$drop(::rust::Box<::concrete_optimizer::Location> *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$new(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$drop(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$len(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> const *ptr) noexcept; +::std::size_t cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$capacity(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> const *ptr) noexcept; +::concrete_optimizer::utils::PartitionDefinition const *cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$data(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> const *ptr) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$reserve_total(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *ptr, ::std::size_t new_cap) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$set_len(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *ptr, ::std::size_t len) noexcept; +void cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$truncate(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> *ptr, ::std::size_t len) noexcept; + ::concrete_optimizer::ExternalPartition *cxxbridge1$box$concrete_optimizer$ExternalPartition$alloc() noexcept; void cxxbridge1$box$concrete_optimizer$ExternalPartition$dealloc(::concrete_optimizer::ExternalPartition *) noexcept; void cxxbridge1$box$concrete_optimizer$ExternalPartition$drop(::rust::Box<::concrete_optimizer::ExternalPartition> *ptr) noexcept; @@ -1884,6 +1924,38 @@ void Box<::concrete_optimizer::Location>::drop() noexcept { cxxbridge1$box$concrete_optimizer$Location$drop(this); } template <> +Vec<::concrete_optimizer::utils::PartitionDefinition>::Vec() noexcept { + cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$new(this); +} +template <> +void Vec<::concrete_optimizer::utils::PartitionDefinition>::drop() noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$drop(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::utils::PartitionDefinition>::size() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$len(this); +} +template <> +::std::size_t Vec<::concrete_optimizer::utils::PartitionDefinition>::capacity() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$capacity(this); +} +template <> +::concrete_optimizer::utils::PartitionDefinition const *Vec<::concrete_optimizer::utils::PartitionDefinition>::data() const noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$data(this); +} +template <> +void Vec<::concrete_optimizer::utils::PartitionDefinition>::reserve_total(::std::size_t new_cap) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$reserve_total(this, new_cap); +} +template <> +void Vec<::concrete_optimizer::utils::PartitionDefinition>::set_len(::std::size_t len) noexcept { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$set_len(this, len); +} +template <> +void Vec<::concrete_optimizer::utils::PartitionDefinition>::truncate(::std::size_t len) { + return cxxbridge1$rust_vec$concrete_optimizer$utils$PartitionDefinition$truncate(this, len); +} +template <> ::concrete_optimizer::ExternalPartition *Box<::concrete_optimizer::ExternalPartition>::allocation::alloc() noexcept { return cxxbridge1$box$concrete_optimizer$ExternalPartition$alloc(); } diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 336faa20a..662493a2c 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -955,6 +955,9 @@ namespace concrete_optimizer { struct KeysetInfo; struct KeysetRestriction; } + namespace utils { + struct PartitionDefinition; + } } namespace concrete_optimizer { @@ -1368,6 +1371,18 @@ struct KeysetRestriction final { #endif // CXXBRIDGE1_STRUCT_concrete_optimizer$restriction$KeysetRestriction } // namespace restriction +namespace utils { +#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +#define CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +struct PartitionDefinition final { + ::std::uint8_t precision; + double norm2; + + using IsRelocatable = ::std::true_type; +}; +#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$utils$PartitionDefinition +} // namespace utils + namespace v0 { ::concrete_optimizer::v0::Solution optimize_bootstrap(::std::uint64_t precision, double noise_factor, ::concrete_optimizer::Options const &options) noexcept; } // namespace v0 @@ -1381,6 +1396,8 @@ ::rust::Box<::concrete_optimizer::Location> location_unknown() noexcept; ::rust::Box<::concrete_optimizer::Location> location_from_string(::rust::Str string) noexcept; +::CircuitKeys generate_generic_keyset_info(::rust::Vec<::concrete_optimizer::utils::PartitionDefinition> partitions, bool generate_fks) noexcept; + ::rust::Box<::concrete_optimizer::ExternalPartition> get_external_partition(::rust::String name, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t internal_dim, double max_variance, double variance) noexcept; double get_noise_br(::concrete_optimizer::Options const &options, ::std::uint64_t log2_polynomial_size, ::std::uint64_t glwe_dimension, ::std::uint64_t lwe_dim, ::std::uint64_t pbs_level, ::std::uint64_t pbs_log2_base) noexcept; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs new file mode 100644 index 000000000..0e0c4fc08 --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/generic_generation.rs @@ -0,0 +1,169 @@ +use crate::{ + computing_cost::cpu::CpuComplexity, + config::ProcessingUnit, + dag::{ + operator::{FunctionTable, LevelledComplexity, Precision, Shape}, + unparametrized, + }, + optimization::{ + config::{Config, SearchSpace}, + decomposition::{self}, + }, +}; + +use super::{ + keys_spec::{CircuitKeys, ExpandedCircuitKeys}, + optimize::{optimize, NoSearchSpaceRestriction}, + partition_cut::PartitionCut, + PartitionIndex, +}; + +const _4_SIGMA: f64 = 0.000_063_342_483_999_973; + +#[derive(Debug, Clone, PartialEq)] +pub struct PartitionDefinition { + pub precision: Precision, + pub norm2: f64, +} + +impl PartialOrd for PartitionDefinition { + fn partial_cmp(&self, other: &Self) -> Option { + match self.precision.cmp(&other.precision) { + std::cmp::Ordering::Equal => self.norm2.partial_cmp(&other.norm2), + ordering => Some(ordering), + } + } +} + +pub fn generate_generic_parameters( + partitions: Vec, + generate_fks: bool, +) -> CircuitKeys { + let mut dag = unparametrized::Dag::new(); + + for def_a in partitions.iter() { + for def_b in partitions.iter() { + if def_a == def_b { + continue; + } + let inp_a = dag.add_input(def_a.precision, Shape::number()); + let lut_a = dag.add_lut(inp_a, FunctionTable::UNKWOWN, def_a.precision); + let _weighted_a = dag.add_linear_noise( + [lut_a], + LevelledComplexity::ZERO, + [def_a.norm2.sqrt()], + Shape::number(), + "", + ); + + let inp_b = dag.add_input(def_b.precision, Shape::number()); + let lut_b = dag.add_lut(inp_b, FunctionTable::UNKWOWN, def_b.precision); + let weighted_b = dag.add_linear_noise( + [lut_b], + LevelledComplexity::ZERO, + [def_b.norm2.sqrt()], + Shape::number(), + "", + ); + + dag.add_composition(weighted_b, inp_a); + + if generate_fks && def_a > def_b { + let inp_a = dag.add_input(def_a.precision, Shape::number()); + let lut_a = dag.add_lut(inp_a, FunctionTable::UNKWOWN, def_a.precision); + let _weighted_a = dag.add_linear_noise( + [lut_a], + LevelledComplexity::ZERO, + [def_a.norm2.sqrt()], + Shape::number(), + "", + ); + + let inp_b = dag.add_input(def_b.precision, Shape::number()); + let lut_b = dag.add_lut(inp_b, FunctionTable::UNKWOWN, def_b.precision); + let _weighted_b = dag.add_linear_noise( + [lut_b], + LevelledComplexity::ZERO, + [def_b.norm2.sqrt()], + Shape::number(), + "", + ); + + let _ = dag.add_linear_noise( + [lut_a, lut_b], + LevelledComplexity::ZERO, + [0., 0.], + Shape::number(), + "", + ); + } + } + } + + let precisions: Vec<_> = partitions.iter().map(|def| def.precision).collect(); + let n_partitions = precisions.len(); + let p_cut = PartitionCut::maximal_partitionning(&dag); + let config = Config { + security_level: 128, + maximum_acceptable_error_probability: _4_SIGMA, + key_sharing: true, + ciphertext_modulus_log: 64, + fft_precision: 53, + complexity_model: &CpuComplexity::default(), + }; + let search_space = SearchSpace::default_cpu(); + let cache = decomposition::cache(128, ProcessingUnit::Cpu, None, true, 64, 53); + let parameters = optimize( + &dag, + config, + &search_space, + &NoSearchSpaceRestriction, + &cache, + &Some(p_cut), + PartitionIndex(0), + ) + .map_or(None, |v| Some(v.1)) + .unwrap(); + + for i in 0..n_partitions { + for j in 0..n_partitions { + assert!( + parameters.micro_params.ks[i][j].is_some(), + "Ksk[{i},{j}] missing." + ); + if i > j { + assert!( + parameters.micro_params.fks[i][j].is_some(), + "Fksk[{i},{j}] missing." + ); + } + } + } + ExpandedCircuitKeys::of(¶meters).compacted() +} + +#[cfg(test)] +mod test { + use super::{generate_generic_parameters, PartitionDefinition}; + + #[test] + fn test_generate_generic_parameters() { + let _ = generate_generic_parameters( + vec![ + PartitionDefinition { + precision: 3, + norm2: 1., + }, + PartitionDefinition { + precision: 3, + norm2: 100., + }, + PartitionDefinition { + precision: 3, + norm2: 1000., + }, + ], + true, + ); + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs index e0b66a59e..7e88ea7ba 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs @@ -2,6 +2,7 @@ pub(crate) mod analyze; mod complexity; mod fast_keyswitch; mod feasible; +pub mod generic_generation; pub mod keys_spec; pub mod optimize; pub mod optimize_generic; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs index 18dd8050d..72a7e14b8 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs @@ -9,7 +9,7 @@ use crate::dag::operator::OperatorIndex; use super::partition_cut::PartitionCut; #[derive(Clone, Debug, PartialEq, Eq, Default, PartialOrd, Ord, Hash, Copy)] -pub struct PartitionIndex(pub(crate) usize); +pub struct PartitionIndex(pub usize); impl PartitionIndex { pub const FIRST: Self = Self(0); diff --git a/frontends/concrete-python/tests/compilation/test_restrictions.py b/frontends/concrete-python/tests/compilation/test_restrictions.py index de96cb85b..ab42c5a52 100644 --- a/frontends/concrete-python/tests/compilation/test_restrictions.py +++ b/frontends/concrete-python/tests/compilation/test_restrictions.py @@ -4,7 +4,12 @@ import numpy as np import pytest -from mlir._mlir_libs._concretelang._compiler import KeysetRestriction, RangeRestriction +from mlir._mlir_libs._concretelang._compiler import ( + KeysetInfo, + KeysetRestriction, + PartitionDefinition, + RangeRestriction, +) from concrete import fhe @@ -96,3 +101,28 @@ def inc(x): restricted_keyset_info = restricted_module.keys.specs.program_info.get_keyset_info() assert big_keyset_info == restricted_keyset_info assert small_keyset_info != restricted_keyset_info + + +def test_generic_restriction(): + """ + Test that compiling a module works. + """ + + generic_keyset_info = KeysetInfo.generate_generic( + [PartitionDefinition(8, 10.0), PartitionDefinition(10, 10000.0)], True + ) + + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return (x + 1) % 200 + + inputset = [np.random.randint(1, 200, size=()) for _ in range(100)] + restricted_module = Module.compile( + {"inc": inputset}, + enable_unsafe_features=True, + keyset_restriction=generic_keyset_info.get_restriction(), + ) + compiled_keyset_info = restricted_module.keys.specs.program_info.get_keyset_info() + assert all([k in generic_keyset_info.secret_keys() for k in compiled_keyset_info.secret_keys()]) From ca3fb40e9d6594e49923a7eb3984e75b8de02c16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Mon, 18 Nov 2024 11:08:39 +0100 Subject: [PATCH 2/2] doc(frontend): document parameter restrictions --- docs/SUMMARY.md | 1 + ...rameter_compatibility_with_restrictions.md | 80 +++++++++++++++++++ docs/guides/configure.md | 6 ++ .../tests/compilation/test_restrictions.py | 8 +- 4 files changed, 90 insertions(+), 5 deletions(-) create mode 100644 docs/compilation/parameter_compatibility_with_restrictions.md diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 5ecb2fedf..17f6329e4 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -29,6 +29,7 @@ * [Multi parameters](compilation/multi_parameters.md) * [Compression](compilation/compression.md) * [Reusing arguments](compilation/reuse_arguments.md) +* [Parameter compatibility with restrictions](compilation/parameter_compatibility_with_restrictions.md) * [Common errors](compilation/common_errors.md) ## Execution / Analysis diff --git a/docs/compilation/parameter_compatibility_with_restrictions.md b/docs/compilation/parameter_compatibility_with_restrictions.md new file mode 100644 index 000000000..83bb293ad --- /dev/null +++ b/docs/compilation/parameter_compatibility_with_restrictions.md @@ -0,0 +1,80 @@ +# Parameters compatibility with restrictions + +When compiling a module, the optimizer analyzes the circuits and the expected probability of error, to find the fastest crypto-parameters suiting those constraints. Depending on the crypto-parameters found, the size of the keys (and the ciphertexts) will differ. This means that if an existing module is used in production (using a certain set of crypto-parameters), there is no guarantee that a compilation of a second (different) module will yield compatible crypto-parameters. + +Concrete provides a way to ensure that a compilation is going to yield compatible crypto-parameters, thanks to _restrictions_. Restrictions are going to restrict the search-space walked by the optimizer to ensure that only compatible parameters can be returnedyielded. As of now, we support two major restrictions: + ++ __Keyset restriction__ : Restricts the crypto-parameters to an existing keyset. This restriction is suited for users that already have a module in production, and want to compile a compatible module. ++ __Ranges restriction__ : Restricts the crypto-parameters ranges allowed in the optimizer. This restriction is suited to users targetting a specific backend which does not support the breadth of parameters available on CPU. + +## Keyset restriction + +The keyset restriction can be generated directly form an existing keyset: + +```python +@fhe.module() +class Big: + @fhe.function({"x": "encrypted"}) + def inc(x): + return (x + 1) % 200 + +big_inputset = [np.random.randint(1, 200, size=()) for _ in range(100)] +big_module = Big.compile( + {"inc": big_inputset}, +) +big_keyset_info = big_module.keys.specs.program_info.get_keyset_info() + +# We get the restriction from the existing keyset +restriction = big_keyset_info.get_restriction() + +@fhe.module() +class Small: + @fhe.function({"x": "encrypted"}) + def inc(x): + return (x + 1) % 20 + +small_inputset = [np.random.randint(1, 20, size=()) for _ in range(100)] +small_module = Small.compile( + {"inc": small_inputset}, + # We pass the keyset restriction as an extra compilation option + keyset_restriction=restriction +) +restricted_keyset_info = restricted_module.keys.specs.program_info.get_keyset_info() +assert big_keyset_info == restricted_keyset_info +``` + +## Ranges restriction + +A ranges restriction can be built by adding available values: +```python +@fhe.module() +class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return (x + 1) % 20 + +inputset = [np.random.randint(1, 20, size=()) for _ in range(100)] + +## We generate a range restriction +range_restriction = RangeRestriction() + +## Make 999 and 200 available as internal lwe dimensions +range_restriction.add_available_internal_lwe_dimension(999) +range_restriction.add_available_internal_lwe_dimension(200) + +## Setting other restrictions +range_restriction.add_available_glwe_log_polynomial_size(12) +range_restriction.add_available_glwe_dimension(2) +range_restriction.add_available_pbs_level_count(3) +range_restriction.add_available_pbs_base_log(11) +range_restriction.add_available_ks_level_count(3) +range_restriction.add_available_ks_base_log(6) + +module = Module.compile( + {"inc": inputset}, + # We pass the range restriction as an extra compilation option. + range_restriction=range_restriction +) +``` + +Note that if no available parameters are set for one of the parameter ranges (say `ks_base_log`), it is assumed that the default range is available. diff --git a/docs/guides/configure.md b/docs/guides/configure.md index 88e0b5b3f..521764b9e 100644 --- a/docs/guides/configure.md +++ b/docs/guides/configure.md @@ -201,6 +201,12 @@ When options are specified both in the `configuration` and as kwargs in the `com #### single_precision: bool = False - Use single precision for the whole circuit. +#### range_restriction: Optional[RangeRestriction] = None +- A range restriction to pass to the optimizer to restrict the available crypto-parameters. + +#### keyset_restriction: Optional[KeysetRestriction] = None +- A keyset restriction to pass to the optimizer to restrict the available crypto-parameters. + #### use_gpu: bool = False - Enable generating code for GPU in the compiler. diff --git a/frontends/concrete-python/tests/compilation/test_restrictions.py b/frontends/concrete-python/tests/compilation/test_restrictions.py index ab42c5a52..b9836b24e 100644 --- a/frontends/concrete-python/tests/compilation/test_restrictions.py +++ b/frontends/concrete-python/tests/compilation/test_restrictions.py @@ -46,7 +46,7 @@ def inc(x): ks_base_log = 6 range_restriction.add_available_ks_base_log(ks_base_log) module = Module.compile( - {"inc": inputset}, enable_unsafe_features=True, range_restriction=range_restriction + {"inc": inputset}, range_restriction=range_restriction ) keyset_info = module.keys.specs.program_info.get_keyset_info() assert keyset_info.bootstrap_keys()[0].polynomial_size() == 2**glwe_log_polynomial_size @@ -83,20 +83,19 @@ def inc(x): big_module = Big.compile( {"inc": big_inputset}, - enable_unsafe_features=True, ) big_keyset_info = big_module.keys.specs.program_info.get_keyset_info() small_module = Small.compile( {"inc": small_inputset}, - enable_unsafe_features=True, ) small_keyset_info = small_module.keys.specs.program_info.get_keyset_info() assert big_keyset_info != small_keyset_info restriction = big_keyset_info.get_restriction() restricted_module = Small.compile( - {"inc": small_inputset}, enable_unsafe_features=True, keyset_restriction=restriction + {"inc": small_inputset}, + keyset_restriction=restriction ) restricted_keyset_info = restricted_module.keys.specs.program_info.get_keyset_info() assert big_keyset_info == restricted_keyset_info @@ -121,7 +120,6 @@ def inc(x): inputset = [np.random.randint(1, 200, size=()) for _ in range(100)] restricted_module = Module.compile( {"inc": inputset}, - enable_unsafe_features=True, keyset_restriction=generic_keyset_info.get_restriction(), ) compiled_keyset_info = restricted_module.keys.specs.program_info.get_keyset_info()