diff --git a/.github/workflows/compiler_build_and_test_cpu.yml b/.github/workflows/compiler_build_and_test_cpu.yml index a5f0a9c773..cac78b9d44 100644 --- a/.github/workflows/compiler_build_and_test_cpu.yml +++ b/.github/workflows/compiler_build_and_test_cpu.yml @@ -81,6 +81,7 @@ jobs: shell: bash run: | rustup toolchain install nightly-2024-09-30 + pip install mypy set -e cd /concrete/compilers/concrete-compiler/compiler rm -rf /build/* @@ -137,6 +138,7 @@ jobs: rustup toolchain install nightly-2024-09-30 cd /concrete/compilers/concrete-compiler/compiler pip install pytest + pip install mypy dnf install -y libzstd libzstd-devel sed "s/pytest/python -m pytest/g" -i Makefile mkdir -p /tmp/concrete_compiler/gpu_tests/ diff --git a/.github/workflows/compiler_build_and_test_gpu.yml b/.github/workflows/compiler_build_and_test_gpu.yml index 92c7e89444..688b0a3889 100644 --- a/.github/workflows/compiler_build_and_test_gpu.yml +++ b/.github/workflows/compiler_build_and_test_gpu.yml @@ -81,6 +81,7 @@ jobs: shell: bash run: | rustup toolchain install nightly-2024-09-30 + pip install mypy set -e cd /concrete/compilers/concrete-compiler/compiler rm -rf /build/* diff --git a/.github/workflows/compiler_macos_build_and_test.yml b/.github/workflows/compiler_macos_build_and_test.yml index 874eba80c3..9bc6c2588c 100644 --- a/.github/workflows/compiler_macos_build_and_test.yml +++ b/.github/workflows/compiler_macos_build_and_test.yml @@ -37,6 +37,7 @@ jobs: brew install ninja ccache pip3.10 install numpy pybind11==2.8 wheel delocate pip3.10 install pytest + pip3.10 install mypy - name: Cache compilation (push) if: github.event_name == 'push' diff --git a/.github/workflows/concrete_python_benchmark.yml b/.github/workflows/concrete_python_benchmark.yml index 1f94004553..6fbfb3f0ca 100644 --- a/.github/workflows/concrete_python_benchmark.yml +++ b/.github/workflows/concrete_python_benchmark.yml @@ -57,6 +57,7 @@ jobs: set -e rustup toolchain install nightly-2024-09-30 + pip install mypy rm -rf /build/* export PYTHON=${{ format('python{0}', matrix.python-version) }} diff --git a/.github/workflows/concrete_python_release.yml b/.github/workflows/concrete_python_release.yml index 85f9029bba..433de8cd81 100644 --- a/.github/workflows/concrete_python_release.yml +++ b/.github/workflows/concrete_python_release.yml @@ -102,6 +102,7 @@ jobs: set -e rustup toolchain install nightly-2024-09-30 + pip install mypy rm -rf /build/* export PYTHON=${{ format('python{0}', matrix.python-version) }} diff --git a/.github/workflows/concrete_python_release_gpu.yml b/.github/workflows/concrete_python_release_gpu.yml index a97b7bfe12..19fa2f4054 100644 --- a/.github/workflows/concrete_python_release_gpu.yml +++ b/.github/workflows/concrete_python_release_gpu.yml @@ -91,6 +91,7 @@ jobs: set -e rustup toolchain install nightly-2024-09-30 + pip install mypy rm -rf /build/* export PYTHON=${{ format('python{0}', matrix.python-version) }} diff --git a/.github/workflows/concrete_python_tests_linux.yml b/.github/workflows/concrete_python_tests_linux.yml index 9d0a54a186..7090bc891f 100644 --- a/.github/workflows/concrete_python_tests_linux.yml +++ b/.github/workflows/concrete_python_tests_linux.yml @@ -77,6 +77,7 @@ jobs: shell: bash run: | rustup toolchain install nightly-2024-09-30 + pip install mypy set -e rm -rf /build/* diff --git a/compilers/concrete-compiler/compiler/Makefile b/compilers/concrete-compiler/compiler/Makefile index e10e88a498..c67d5dd2d8 100644 --- a/compilers/concrete-compiler/compiler/Makefile +++ b/compilers/concrete-compiler/compiler/Makefile @@ -21,6 +21,7 @@ INSTALL_PREFIX?=$(abspath $(BUILD_DIR))/install INSTALL_PATH=$(abspath $(INSTALL_PREFIX))/concretecompiler/ MAKEFILE_ROOT_DIR=$(shell pwd) MINIMAL_TESTS?=OFF +STUBGEN=$(shell $(Python3_EXECUTABLE) -c "import sysconfig; sp = sysconfig.get_paths()['scripts']; print(f\"{sp}/stubgen\")") KEYSETCACHEDEV=/tmp/KeySetCache KEYSETCACHECI ?= ../KeySetCache @@ -186,6 +187,7 @@ concretecompiler: build-initialized python-bindings: build-initialized cmake --build $(BUILD_DIR) --target ConcretelangMLIRPythonModules cmake --build $(BUILD_DIR) --target ConcretelangPythonModules + PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so $(STUBGEN) -m mlir._mlir_libs._concretelang._compiler --include-docstrings -o $(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core clientlib: build-initialized cmake --build $(BUILD_DIR) --target ConcretelangClientLib diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h index ea622e2ad9..56dba93763 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h @@ -47,31 +47,46 @@ class ClientCircuit { public: static Result - create(const Message &info, - const ClientKeyset &keyset, - std::shared_ptr csprng, - bool useSimulation = false); + createEncrypted(const Message &info, + const ClientKeyset &keyset, + std::shared_ptr csprng); + + static Result + createSimulated(const Message &info, + std::shared_ptr csprng); Result prepareInput(Value arg, size_t pos); Result processOutput(TransportValue result, size_t pos); + Result simulatePrepareInput(Value arg, size_t pos); + + Result simulateProcessOutput(TransportValue result, size_t pos); + std::string getName(); const Message &getCircuitInfo(); + bool isSimulated(); + private: ClientCircuit() = delete; ClientCircuit(const Message &circuitInfo, std::vector inputTransformers, - std::vector outputTransformers) + std::vector outputTransformers, + bool simulated) : circuitInfo(circuitInfo), inputTransformers(inputTransformers), - outputTransformers(outputTransformers){}; + outputTransformers(outputTransformers), simulated(simulated){}; + static Result + create(const Message &info, + const ClientKeyset &keyset, + std::shared_ptr csprng, bool useSimulation); private: Message circuitInfo; std::vector inputTransformers; std::vector outputTransformers; + bool simulated; }; /// Contains all the context to generate inputs for a server call by the @@ -80,10 +95,14 @@ class ClientProgram { public: /// Generates a fresh client program with fresh keyset on the first use. static Result - create(const Message &info, - const ClientKeyset &keyset, - std::shared_ptr csprng, - bool useSimulation = false); + createEncrypted(const Message &info, + const ClientKeyset &keyset, + std::shared_ptr csprng); + + /// Generates a fresh client program with empty keyset for simulation. + static Result + createSimulated(const Message &info, + std::shared_ptr csprng); /// Returns a reference to the named client circuit if it exists. Result getClientCircuit(std::string circuitName); diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h deleted file mode 100644 index cc2579931d..0000000000 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h +++ /dev/null @@ -1,504 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt -// for license information. -// -// NOTE: -// ----- -// To limit the size of the refactoring, we chose to not propagate the new -// client/server lib to the exterior APIs after it was finalized. This file only -// serves as a compatibility layer for exterior (python/rust/c) apis, for the -// time being. - -#ifndef CONCRETELANG_COMMON_COMPAT -#define CONCRETELANG_COMMON_COMPAT - -#include "capnp/serialize-packed.h" -#include "concrete-protocol.capnp.h" -#include "concretelang/ClientLib/ClientLib.h" -#include "concretelang/Common/Keys.h" -#include "concretelang/Common/Keysets.h" -#include "concretelang/Common/Protocol.h" -#include "concretelang/Common/Values.h" -#include "concretelang/ServerLib/ServerLib.h" -#include "concretelang/Support/CompilerEngine.h" -#include "concretelang/Support/Error.h" -#include "kj/io.h" -#include "kj/std/iostream.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/ExecutionEngine/OptUtils.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include -#include - -using concretelang::clientlib::ClientCircuit; -using concretelang::clientlib::ClientProgram; -using concretelang::keysets::Keyset; -using concretelang::keysets::KeysetCache; -using concretelang::keysets::ServerKeyset; -using concretelang::serverlib::ServerCircuit; -using concretelang::serverlib::ServerProgram; -using concretelang::values::TransportValue; -using concretelang::values::Value; - -#define CONCAT(a, b) CONCAT_INNER(a, b) -#define CONCAT_INNER(a, b) a##b - -#define GET_OR_THROW_EXPECTED_(VARNAME, RESULT, MAYBE) \ - auto MAYBE = RESULT; \ - if (auto err = MAYBE.takeError()) { \ - throw std::runtime_error(llvm::toString(std::move(err))); \ - } \ - VARNAME = std::move(*MAYBE); - -#define GET_OR_THROW_EXPECTED(VARNAME, RESULT) \ - GET_OR_THROW_EXPECTED_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__)) - -#define GET_OR_THROW_RESULT_(VARNAME, RESULT, MAYBE) \ - auto MAYBE = RESULT; \ - if (MAYBE.has_failure()) { \ - throw std::runtime_error(MAYBE.as_failure().error().mesg); \ - } \ - VARNAME = MAYBE.value(); - -#define GET_OR_THROW_RESULT(VARNAME, RESULT) \ - GET_OR_THROW_RESULT_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__)) - -#define EXPECTED_TRY_(lhs, rhs, maybe) \ - auto maybe = rhs; \ - if (auto err = maybe.takeError()) { \ - return std::move(err); \ - } \ - lhs = *maybe; - -#define EXPECTED_TRY(lhs, rhs) \ - EXPECTED_TRY_(lhs, rhs, CONCAT(maybe, __COUNTER__)) - -template llvm::Expected outcomeToExpected(Result outcome) { - if (outcome.has_failure()) { - return mlir::concretelang::StreamStringError( - outcome.as_failure().error().mesg); - } else { - return outcome.value(); - } -} - -// Every number sent by python through the API has a type `int64` that must be -// turned into the proper type expected by the ArgTransformers. This allows to -// get an extra transformer executed right before the ArgTransformer gets -// called. -std::function -getPythonTypeTransformer(const Message &info) { - if (info.asReader().getTypeInfo().hasIndex()) { - return [=](Value input) { - Tensor tensorInput = input.getTensor().value(); - return Value{(Tensor)tensorInput}; - }; - } else if (info.asReader().getTypeInfo().hasPlaintext()) { - if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <= - 8) { - return [=](Value input) { - Tensor tensorInput = input.getTensor().value(); - return Value{(Tensor)tensorInput}; - }; - } - if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <= - 16) { - return [=](Value input) { - Tensor tensorInput = input.getTensor().value(); - return Value{(Tensor)tensorInput}; - }; - } - if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <= - 32) { - return [=](Value input) { - Tensor tensorInput = input.getTensor().value(); - return Value{(Tensor)tensorInput}; - }; - } - if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <= - 64) { - return [=](Value input) { - Tensor tensorInput = input.getTensor().value(); - return Value{(Tensor)tensorInput}; - }; - } - assert(false); - } else if (info.asReader().getTypeInfo().hasLweCiphertext()) { - if (info.asReader() - .getTypeInfo() - .getLweCiphertext() - .getEncoding() - .hasInteger() && - info.asReader() - .getTypeInfo() - .getLweCiphertext() - .getEncoding() - .getInteger() - .getIsSigned()) { - return [=](Value input) { return input; }; - } else { - return [=](Value input) { - Tensor tensorInput = input.getTensor().value(); - return Value{(Tensor)tensorInput}; - }; - } - } else { - assert(false); - } -}; - -namespace concretelang { -namespace serverlib { -/// A transition structure that preserver the current API of the library -/// support. -struct ServerLambda { - ServerCircuit circuit; - bool isSimulation; -}; -} // namespace serverlib - -namespace clientlib { - -/// A transition structure that preserver the current API of the library -/// support. -struct LweSecretKeyParam { - Message info; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct BootstrapKeyParam { - Message info; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct KeyswitchKeyParam { - Message info; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct PackingKeyswitchKeyParam { - Message info; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct Encoding { - Message circuit; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct EncryptionGate { - Message gateInfo; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct CircuitGate { - Message gateInfo; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct ValueExporter { - ClientCircuit circuit; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct SimulatedValueExporter { - ClientCircuit circuit; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct ValueDecrypter { - ClientCircuit circuit; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct SimulatedValueDecrypter { - ClientCircuit circuit; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct PublicArguments { - std::vector values; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct PublicResult { - std::vector values; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct SharedScalarOrTensorData { - TransportValue value; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct ClientParameters { - Message programInfo; - std::vector secretKeys; - std::vector bootstrapKeys; - std::vector keyswitchKeys; - std::vector packingKeyswitchKeys; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct EvaluationKeys { - ServerKeyset keyset; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct KeySetCache { - KeysetCache keysetCache; -}; - -/// A transition structure that preserver the current API of the library -/// support. -struct KeySet { - Keyset keyset; -}; - -} // namespace clientlib -} // namespace concretelang - -namespace mlir { -namespace concretelang { - -/// A transition structure that preserves the current API of the library -/// support. -struct LambdaArgument { - ::concretelang::values::Value value; -}; - -/// LibraryCompilationResult is the result of a compilation to a library. -struct LibraryCompilationResult { - /// The output directory path where the compilation artifacts have been - /// generated. - std::string outputDirPath; -}; - -class LibrarySupport { - -public: - LibrarySupport(std::string outputPath, std::string runtimeLibraryPath = "", - bool generateSharedLib = true, bool generateStaticLib = true, - bool generateClientParameters = true, - bool generateCompilationFeedback = true) - : outputPath(outputPath), runtimeLibraryPath(runtimeLibraryPath), - generateSharedLib(generateSharedLib), - generateStaticLib(generateStaticLib), - generateClientParameters(generateClientParameters), - generateCompilationFeedback(generateCompilationFeedback) {} - - llvm::Expected> - compile(llvm::SourceMgr &program, CompilationOptions options) { - // Setup the compiler engine - auto context = CompilationContext::createShared(); - concretelang::CompilerEngine engine(context); - engine.setCompilationOptions(options); - - // Compile to a library - auto library = - engine.compile(program, outputPath, runtimeLibraryPath, - generateSharedLib, generateStaticLib, - generateClientParameters, generateCompilationFeedback); - if (auto err = library.takeError()) { - return std::move(err); - } - - auto result = std::make_unique(); - result->outputDirPath = outputPath; - return std::move(result); - } - - llvm::Expected> - compile(llvm::StringRef s, CompilationOptions options) { - std::unique_ptr mb = - llvm::MemoryBuffer::getMemBuffer(s); - llvm::SourceMgr sm; - sm.AddNewSourceBuffer(std::move(mb), llvm::SMLoc()); - return this->compile(sm, options); - } - - llvm::Expected> - compile(mlir::ModuleOp &program, - std::shared_ptr &context, - CompilationOptions options) { - - // Setup the compiler engine - concretelang::CompilerEngine engine(context); - engine.setCompilationOptions(options); - - // Compile to a library - auto library = - engine.compile(program, outputPath, runtimeLibraryPath, - generateSharedLib, generateStaticLib, - generateClientParameters, generateCompilationFeedback); - if (auto err = library.takeError()) { - return std::move(err); - } - - auto result = std::make_unique(); - result->outputDirPath = outputPath; - return std::move(result); - } - - /// Load the server lambda from the compilation result. - llvm::Expected<::concretelang::serverlib::ServerLambda> - loadServerLambda(LibraryCompilationResult &result, std::string circuitName, - bool useSimulation) { - EXPECTED_TRY(auto programInfo, getProgramInfo()); - EXPECTED_TRY(ServerProgram serverProgram, - outcomeToExpected(ServerProgram::load(programInfo.asReader(), - getSharedLibPath(), - useSimulation))); - EXPECTED_TRY( - ServerCircuit serverCircuit, - outcomeToExpected(serverProgram.getServerCircuit(circuitName))); - return ::concretelang::serverlib::ServerLambda{serverCircuit, - useSimulation}; - } - - llvm::Expected - loadServerProgram(LibraryCompilationResult &result, bool useSimulation) { - EXPECTED_TRY(auto programInfo, getProgramInfo()); - return outcomeToExpected(ServerProgram::load( - programInfo.asReader(), getSharedLibPath(), useSimulation)); - } - - /// Load the client parameters from the compilation result. - llvm::Expected<::concretelang::clientlib::ClientParameters> - loadClientParameters(LibraryCompilationResult &result) { - EXPECTED_TRY(auto programInfo, getProgramInfo()); - auto secretKeys = - std::vector<::concretelang::clientlib::LweSecretKeyParam>(); - for (auto key : programInfo.asReader().getKeyset().getLweSecretKeys()) { - secretKeys.push_back(::concretelang::clientlib::LweSecretKeyParam{key}); - } - auto boostrapKeys = - std::vector<::concretelang::clientlib::BootstrapKeyParam>(); - for (auto key : programInfo.asReader().getKeyset().getLweBootstrapKeys()) { - boostrapKeys.push_back(::concretelang::clientlib::BootstrapKeyParam{key}); - } - auto keyswitchKeys = - std::vector<::concretelang::clientlib::KeyswitchKeyParam>(); - for (auto key : programInfo.asReader().getKeyset().getLweKeyswitchKeys()) { - keyswitchKeys.push_back( - ::concretelang::clientlib::KeyswitchKeyParam{key}); - } - auto packingKeyswitchKeys = - std::vector<::concretelang::clientlib::PackingKeyswitchKeyParam>(); - for (auto key : - programInfo.asReader().getKeyset().getPackingKeyswitchKeys()) { - packingKeyswitchKeys.push_back( - ::concretelang::clientlib::PackingKeyswitchKeyParam{key}); - } - return ::concretelang::clientlib::ClientParameters{ - programInfo, secretKeys, boostrapKeys, keyswitchKeys, - packingKeyswitchKeys}; - } - - llvm::Expected> getProgramInfo() { - auto path = CompilerEngine::Library::getProgramInfoPath(outputPath); - std::ifstream file(path); - std::string content((std::istreambuf_iterator(file)), - (std::istreambuf_iterator())); - if (file.fail()) { - return StreamStringError("Cannot read file: ") << path; - } - auto output = Message(); - if (output.readJsonFromString(content).has_failure()) { - return StreamStringError("Cannot read json string."); - } - return output; - } - - /// Load the the compilation result if circuit already compiled - llvm::Expected> - loadCompilationResult() { - auto result = std::make_unique(); - result->outputDirPath = outputPath; - return std::move(result); - } - - llvm::Expected - loadCompilationFeedback(LibraryCompilationResult &result) { - auto path = CompilerEngine::Library::getCompilationFeedbackPath( - result.outputDirPath); - auto feedback = ProgramCompilationFeedback::load(path); - if (feedback.has_error()) { - return StreamStringError(feedback.error().mesg); - } - return feedback.value(); - } - - /// Call the lambda with the public arguments. - llvm::Expected> - serverCall(::concretelang::serverlib::ServerLambda lambda, - ::concretelang::clientlib::PublicArguments &args, - ::concretelang::clientlib::EvaluationKeys &evaluationKeys) { - if (lambda.isSimulation) { - return mlir::concretelang::StreamStringError( - "Tried to perform server call on simulation lambda."); - } - EXPECTED_TRY(auto output, outcomeToExpected(lambda.circuit.call( - evaluationKeys.keyset, args.values))); - ::concretelang::clientlib::PublicResult res{output}; - return std::make_unique<::concretelang::clientlib::PublicResult>( - std::move(res)); - } - - /// Call the lambda with the public arguments. - llvm::Expected> - simulate(::concretelang::serverlib::ServerLambda lambda, - ::concretelang::clientlib::PublicArguments &args) { - if (!lambda.isSimulation) { - return mlir::concretelang::StreamStringError( - "Tried to perform simulation on execution lambda."); - } - EXPECTED_TRY(auto output, - outcomeToExpected(lambda.circuit.simulate(args.values))); - ::concretelang::clientlib::PublicResult res{output}; - return std::make_unique<::concretelang::clientlib::PublicResult>( - std::move(res)); - } - - /// Get path to shared library - std::string getSharedLibPath() { - return CompilerEngine::Library::getSharedLibraryPath(outputPath); - } - - /// Get path to client parameters file - std::string getProgramInfoPath() { - return CompilerEngine::Library::getProgramInfoPath(outputPath); - } - -private: - std::string outputPath; - std::string runtimeLibraryPath; - /// Flags to select generated artifacts - bool generateSharedLib; - bool generateStaticLib; - bool generateClientParameters; - bool generateCompilationFeedback; -}; - -} // namespace concretelang -} // namespace mlir - -#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h index a9fbfcc97b..bc02e7bdb3 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h @@ -187,7 +187,7 @@ class CompilerEngine { std::string runtimeLibraryPath; bool cleanUp; mlir::concretelang::ProgramCompilationFeedback compilationFeedback; - Message programInfo; + std::optional> programInfo; public: /// Create a library instance on which you can add compilation results. @@ -209,22 +209,22 @@ class CompilerEngine { std::string staticLibraryPath; /// Returns the program info of the library. - Message getProgramInfo() const; + Result> getProgramInfo(); /// Returns the path to the output dir. const std::string &getOutputDirPath() const; /// Returns the path of the shared library - static std::string getSharedLibraryPath(std::string outputDirPath); + std::string getSharedLibraryPath() const; /// Returns the path of the static library - static std::string getStaticLibraryPath(std::string outputDirPath); + std::string getStaticLibraryPath() const; /// Returns the path of the program info - static std::string getProgramInfoPath(std::string outputDirPath); + std::string getProgramInfoPath() const; /// Returns the path of the compilation feedback - static std::string getCompilationFeedbackPath(std::string outputDirPath); + std::string getCompilationFeedbackPath() const; // For advanced use const static std::string OBJECT_EXT, LINKER, LINKER_SHARED_OPT, AR, diff --git a/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h b/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h index 5e3be16373..26945db0e4 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/TestLib/TestProgram.h @@ -80,15 +80,16 @@ class TestProgram { return outcome::success(); } OUTCOME_TRY(auto lib, getLibrary()); + OUTCOME_TRY(auto programInfo, lib.getProgramInfo()); if (tryCache) { OUTCOME_TRY(keyset, getTestKeySetCachePtr()->getKeyset( - lib.getProgramInfo().asReader().getKeyset(), - secretSeed, encryptionSeed)); + programInfo.asReader().getKeyset(), secretSeed, + encryptionSeed)); } else { auto encryptionCsprng = csprng::EncryptionCSPRNG(encryptionSeed); auto secretCsprng = csprng::SecretCSPRNG(secretSeed); Message keysetInfo = - lib.getProgramInfo().asReader().getKeyset(); + programInfo.asReader().getKeyset(); keyset = Keyset(keysetInfo, secretCsprng, encryptionCsprng); } return outcome::success(); @@ -114,6 +115,27 @@ class TestProgram { return processedOutputs; } + Result> simulate(std::vector inputs, + std::string name = "main") { + // preprocess arguments + auto preparedArgs = std::vector(); + OUTCOME_TRY(auto clientCircuit, getClientCircuit(name)); + for (size_t i = 0; i < inputs.size(); i++) { + OUTCOME_TRY(auto preparedInput, + clientCircuit.simulatePrepareInput(inputs[i], i)); + preparedArgs.push_back(preparedInput); + } + // Call server + OUTCOME_TRY(auto returns, callServer(preparedArgs, name)); + // postprocess arguments + std::vector processedOutputs(returns.size()); + for (size_t i = 0; i < processedOutputs.size(); i++) { + OUTCOME_TRY(processedOutputs[i], + clientCircuit.simulateProcessOutput(returns[i], i)); + } + return processedOutputs; + } + Result> compose_n_times(std::vector inputs, size_t n, std::string name = "main") { @@ -152,28 +174,35 @@ class TestProgram { Result getClientCircuit(std::string name = "main") { OUTCOME_TRY(auto lib, getLibrary()); Keyset ks{}; + OUTCOME_TRY(auto programInfo, lib.getProgramInfo()); + if (!isSimulation()) { OUTCOME_TRY(ks, getKeyset()); + OUTCOME_TRY(auto clientProgram, + ClientProgram::createEncrypted(programInfo, ks.client, + encryptionCsprng)); + OUTCOME_TRY(auto clientCircuit, clientProgram.getClientCircuit(name)); + return clientCircuit; + } else { + OUTCOME_TRY(auto clientProgram, ClientProgram::createSimulated( + programInfo, encryptionCsprng)); + OUTCOME_TRY(auto clientCircuit, clientProgram.getClientCircuit(name)); + return clientCircuit; } - auto programInfo = lib.getProgramInfo(); - OUTCOME_TRY(auto clientProgram, - ClientProgram::create(programInfo, ks.client, encryptionCsprng, - isSimulation())); - OUTCOME_TRY(auto clientCircuit, clientProgram.getClientCircuit(name)); - return clientCircuit; } Result getServerCircuit(std::string name = "main") { OUTCOME_TRY(auto lib, getLibrary()); - auto programInfo = lib.getProgramInfo(); + OUTCOME_TRY(auto programInfo, lib.getProgramInfo()); OUTCOME_TRY(auto serverProgram, - ServerProgram::load(programInfo, - lib.getSharedLibraryPath(artifactDirectory), + ServerProgram::load(programInfo, lib.getSharedLibraryPath(), isSimulation())); OUTCOME_TRY(auto serverCircuit, serverProgram.getServerCircuit(name)); return serverCircuit; } + bool isSimulation() { return compiler.getCompilationOptions().simulate; } + private: std::string getArtifactDirectory() { return artifactDirectory; } @@ -191,8 +220,6 @@ class TestProgram { return *keyset; } - bool isSimulation() { return compiler.getCompilationOptions().simulate; } - std::string artifactDirectory; mlir::concretelang::CompilerEngine compiler; std::optional library; diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt index 2766fa6883..b3e0639cea 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt @@ -35,31 +35,10 @@ declare_mlir_python_sources( "${CMAKE_CURRENT_SOURCE_DIR}" SOURCES concrete/compiler/__init__.py - concrete/compiler/client_parameters.py - concrete/compiler/client_support.py - concrete/compiler/compilation_context.py concrete/compiler/compilation_feedback.py - concrete/compiler/compilation_options.py - concrete/compiler/key_set_cache.py - concrete/compiler/key_set.py - concrete/compiler/lambda_argument.py - concrete/compiler/library_compilation_result.py - concrete/compiler/library_support.py - concrete/compiler/library_lambda.py - concrete/compiler/lwe_secret_key.py - concrete/compiler/parameter.py - concrete/compiler/public_arguments.py - concrete/compiler/public_result.py - concrete/compiler/server_circuit.py - concrete/compiler/server_program.py - concrete/compiler/evaluation_keys.py - concrete/compiler/simulated_value_decrypter.py - concrete/compiler/simulated_value_exporter.py + concrete/compiler/compilation_context.py concrete/compiler/tfhers_int.py concrete/compiler/utils.py - concrete/compiler/value.py - concrete/compiler/value_decrypter.py - concrete/compiler/value_exporter.py concrete/compiler/wrapper.py concrete/__init__.py concrete/lang/__init__.py diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 4965f64b68..a05c2a1196 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -7,21 +7,25 @@ #include "concrete-optimizer.hpp" #include "concrete-protocol.capnp.h" #include "concretelang/ClientLib/ClientLib.h" -#include "concretelang/Common/Compat.h" #include "concretelang/Common/Csprng.h" +#include "concretelang/Common/Keys.h" #include "concretelang/Common/Keysets.h" -#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc" +#include "concretelang/Common/Values.h" #include "concretelang/Runtime/DFRuntime.hpp" #include "concretelang/Runtime/GPUDFG.hpp" #include "concretelang/ServerLib/ServerLib.h" +#include "concretelang/Support/CompilerEngine.h" +#include "concretelang/Support/Error.h" +#include "concretelang/Support/V0Parameters.h" #include "concretelang/Support/logging.h" -#include +#include +#include #include #include -#include -#include -#include +#include +#include +#include #include #include #include @@ -29,9 +33,60 @@ #include #include +using concretelang::clientlib::ClientCircuit; +using concretelang::clientlib::ClientProgram; +using concretelang::keysets::Keyset; +using concretelang::keysets::KeysetCache; +using concretelang::keysets::ServerKeyset; +using concretelang::serverlib::ServerCircuit; +using concretelang::serverlib::ServerProgram; +using concretelang::values::TransportValue; +using concretelang::values::Value; using mlir::concretelang::CompilationOptions; -using mlir::concretelang::LambdaArgument; +#define CONCAT(a, b) CONCAT_INNER(a, b) +#define CONCAT_INNER(a, b) a##b + +#define GET_OR_THROW_EXPECTED_(VARNAME, RESULT, MAYBE) \ + auto MAYBE = RESULT; \ + if (auto err = MAYBE.takeError()) { \ + throw std::runtime_error(llvm::toString(std::move(err))); \ + } \ + VARNAME = std::move(*MAYBE); + +#define GET_OR_THROW_EXPECTED(VARNAME, RESULT) \ + GET_OR_THROW_EXPECTED_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__)) + +#define GET_OR_THROW_RESULT_(VARNAME, RESULT, MAYBE) \ + auto MAYBE = RESULT; \ + if (MAYBE.has_failure()) { \ + throw std::runtime_error(MAYBE.as_failure().error().mesg); \ + } \ + VARNAME = MAYBE.value(); + +#define GET_OR_THROW_RESULT(VARNAME, RESULT) \ + GET_OR_THROW_RESULT_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__)) + +#define EXPECTED_TRY_(lhs, rhs, maybe) \ + auto maybe = rhs; \ + if (auto err = maybe.takeError()) { \ + return std::move(err); \ + } \ + lhs = *maybe; + +#define EXPECTED_TRY(lhs, rhs) \ + EXPECTED_TRY_(lhs, rhs, CONCAT(maybe, __COUNTER__)) + +template llvm::Expected outcomeToExpected(Result outcome) { + if (outcome.has_failure()) { + return mlir::concretelang::StreamStringError( + outcome.as_failure().error().mesg); + } else { + return outcome.value(); + } +} + +namespace { class SignalGuard { public: SignalGuard() { previousHandler = signal(SIGINT, SignalGuard::handler); } @@ -46,417 +101,6 @@ class SignalGuard { } }; -/// Wrapper of the mlir::concretelang::LambdaArgument -struct lambdaArgument { - std::shared_ptr ptr; -}; -typedef struct lambdaArgument lambdaArgument; - -/// Hold a list of lambdaArgument to represent execution arguments -struct executionArguments { - lambdaArgument *data; - size_t size; -}; -typedef struct executionArguments executionArguments; - -// Library Support bindings /////////////////////////////////////////////////// - -struct LibrarySupport_Py { - mlir::concretelang::LibrarySupport support; -}; -typedef struct LibrarySupport_Py LibrarySupport_Py; - -LibrarySupport_Py -library_support(const char *outputPath, const char *runtimeLibraryPath, - bool generateSharedLib, bool generateStaticLib, - bool generateClientParameters, bool generateCompilationFeedback, - bool generateCppHeader) { - return LibrarySupport_Py{mlir::concretelang::LibrarySupport( - outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib, - generateClientParameters, generateCompilationFeedback)}; -} - -std::unique_ptr -library_compile(LibrarySupport_Py support, const char *module, - mlir::concretelang::CompilationOptions options) { - llvm::SourceMgr sm; - sm.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(module), - llvm::SMLoc()); - GET_OR_THROW_EXPECTED(auto compilationResult, - support.support.compile(sm, options)); - return compilationResult; -} - -std::unique_ptr -library_compile_module( - LibrarySupport_Py support, mlir::ModuleOp module, - mlir::concretelang::CompilationOptions options, - std::shared_ptr cctx) { - GET_OR_THROW_EXPECTED(auto compilationResult, - support.support.compile(module, cctx, options)); - return compilationResult; -} - -concretelang::clientlib::ClientParameters library_load_client_parameters( - LibrarySupport_Py support, - mlir::concretelang::LibraryCompilationResult &result) { - GET_OR_THROW_EXPECTED(auto clientParameters, - support.support.loadClientParameters(result)); - return clientParameters; -} - -mlir::concretelang::ProgramCompilationFeedback -library_load_compilation_feedback( - LibrarySupport_Py support, - mlir::concretelang::LibraryCompilationResult &result) { - GET_OR_THROW_EXPECTED(auto compilationFeedback, - support.support.loadCompilationFeedback(result)); - return compilationFeedback; -} - -concretelang::serverlib::ServerLambda -library_load_server_lambda(LibrarySupport_Py support, - mlir::concretelang::LibraryCompilationResult &result, - std::string circuitName, bool useSimulation) { - GET_OR_THROW_EXPECTED( - auto serverLambda, - support.support.loadServerLambda(result, circuitName, useSimulation)); - return serverLambda; -} - -std::unique_ptr -library_server_call(LibrarySupport_Py support, - concretelang::serverlib::ServerLambda lambda, - concretelang::clientlib::PublicArguments &args, - concretelang::clientlib::EvaluationKeys &evaluationKeys) { - GET_OR_THROW_EXPECTED(auto publicResult, support.support.serverCall( - lambda, args, evaluationKeys)); - return publicResult; -} - -std::unique_ptr -library_simulate(LibrarySupport_Py support, - concretelang::serverlib::ServerLambda lambda, - concretelang::clientlib::PublicArguments &args) { - GET_OR_THROW_EXPECTED(auto publicResult, - support.support.simulate(lambda, args)); - return publicResult; -} - -std::string library_get_shared_lib_path(LibrarySupport_Py support) { - return support.support.getSharedLibPath(); -} - -std::string library_get_program_info_path(LibrarySupport_Py support) { - return support.support.getProgramInfoPath(); -} - -// Client Support bindings /////////////////////////////////////////////////// - -std::unique_ptr -key_set(concretelang::clientlib::ClientParameters clientParameters, - std::optional cache, - std::map lweSecretKeys, uint64_t secretSeedMsb, - uint64_t secretSeedLsb, uint64_t encSeedMsb, uint64_t encSeedLsb) { - auto secretSeed = (((__uint128_t)secretSeedMsb) << 64) | secretSeedLsb; - auto encryptionSeed = (((__uint128_t)encSeedMsb) << 64) | encSeedLsb; - - if (cache.has_value()) { - GET_OR_THROW_RESULT(Keyset keyset, - (*cache).keysetCache.getKeyset( - clientParameters.programInfo.asReader().getKeyset(), - secretSeed, encryptionSeed, lweSecretKeys)); - concretelang::clientlib::KeySet output{keyset}; - return std::make_unique(std::move(output)); - } else { - concretelang::csprng::SecretCSPRNG secCsprng(secretSeed); - concretelang::csprng::EncryptionCSPRNG encCsprng(encryptionSeed); - auto keyset = Keyset(clientParameters.programInfo.asReader().getKeyset(), - secCsprng, encCsprng, lweSecretKeys); - concretelang::clientlib::KeySet output{keyset}; - return std::make_unique(std::move(output)); - } -} - -std::unique_ptr -encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, - concretelang::clientlib::KeySet &keySet, - llvm::ArrayRef args, - const std::string &circuitName) { - auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( - clientParameters.programInfo.asReader(), keySet.keyset.client, - std::make_shared<::concretelang::csprng::EncryptionCSPRNG>( - ::concretelang::csprng::EncryptionCSPRNG(0)), - false); - if (maybeProgram.has_failure()) { - throw std::runtime_error(maybeProgram.as_failure().error().mesg); - } - auto circuit = maybeProgram.value().getClientCircuit(circuitName).value(); - std::vector output; - for (size_t i = 0; i < args.size(); i++) { - auto info = circuit.getCircuitInfo().asReader().getInputs()[i]; - auto typeTransformer = getPythonTypeTransformer(info); - auto input = typeTransformer(args[i]->value); - auto maybePrepared = circuit.prepareInput(input, i); - - if (maybePrepared.has_failure()) { - throw std::runtime_error(maybePrepared.as_failure().error().mesg); - } - output.push_back(maybePrepared.value()); - } - concretelang::clientlib::PublicArguments publicArgs{output}; - return std::make_unique( - std::move(publicArgs)); -} - -std::vector -decrypt_result(concretelang::clientlib::ClientParameters clientParameters, - concretelang::clientlib::KeySet &keySet, - concretelang::clientlib::PublicResult &publicResult, - const std::string &circuitName) { - auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( - clientParameters.programInfo.asReader(), keySet.keyset.client, - std::make_shared<::concretelang::csprng::EncryptionCSPRNG>( - ::concretelang::csprng::EncryptionCSPRNG(0)), - false); - if (maybeProgram.has_failure()) { - throw std::runtime_error(maybeProgram.as_failure().error().mesg); - } - auto circuit = maybeProgram.value().getClientCircuit(circuitName).value(); - std::vector results; - for (auto e : llvm::enumerate(publicResult.values)) { - auto maybeProcessed = circuit.processOutput(e.value(), e.index()); - if (maybeProcessed.has_failure()) { - throw std::runtime_error(maybeProcessed.as_failure().error().mesg); - } - - mlir::concretelang::LambdaArgument out{maybeProcessed.value()}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - results.push_back(tensor_arg); - } - return results; -} - -std::unique_ptr -publicArgumentsUnserialize( - concretelang::clientlib::ClientParameters &clientParameters, - const std::string &buffer) { - auto publicArgumentsProto = Message(); - if (publicArgumentsProto.readBinaryFromString(buffer).has_failure()) { - throw std::runtime_error("Failed to deserialize public arguments."); - } - std::vector values; - for (auto arg : publicArgumentsProto.asReader().getArgs()) { - values.push_back(arg); - } - concretelang::clientlib::PublicArguments output{values}; - return std::make_unique( - std::move(output)); -} - -std::string publicArgumentsSerialize( - concretelang::clientlib::PublicArguments &publicArguments) { - auto publicArgumentsProto = Message(); - auto argBuilder = - publicArgumentsProto.asBuilder().initArgs(publicArguments.values.size()); - for (size_t i = 0; i < publicArguments.values.size(); i++) { - argBuilder.setWithCaveats(i, publicArguments.values[i].asReader()); - } - auto maybeBuffer = publicArgumentsProto.writeBinaryToString(); - if (maybeBuffer.has_failure()) { - throw std::runtime_error("Failed to serialize public arguments."); - } - return maybeBuffer.value(); -} - -std::unique_ptr publicResultUnserialize( - concretelang::clientlib::ClientParameters &clientParameters, - const std::string &buffer) { - auto publicResultsProto = Message(); - if (publicResultsProto.readBinaryFromString(buffer).has_failure()) { - throw std::runtime_error("Failed to deserialize public results."); - } - std::vector values; - for (auto res : publicResultsProto.asReader().getResults()) { - values.push_back(res); - } - concretelang::clientlib::PublicResult output{values}; - return std::make_unique( - std::move(output)); -} - -std::string -publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) { - std::string buffer; - auto publicResultsProto = Message(); - auto resBuilder = - publicResultsProto.asBuilder().initResults(publicResult.values.size()); - for (size_t i = 0; i < publicResult.values.size(); i++) { - resBuilder.setWithCaveats(i, publicResult.values[i].asReader()); - } - auto maybeBuffer = publicResultsProto.writeBinaryToString(); - if (maybeBuffer.has_failure()) { - throw std::runtime_error("Failed to serialize public results."); - } - return maybeBuffer.value(); -} - -concretelang::clientlib::EvaluationKeys -evaluationKeysUnserialize(const std::string &buffer) { - auto serverKeysetProto = Message(); - auto maybeError = serverKeysetProto.readBinaryFromString( - buffer, mlir::concretelang::python::DESER_OPTIONS); - if (maybeError.has_failure()) { - throw std::runtime_error("Failed to deserialize server keyset." + - maybeError.as_failure().error().mesg); - } - auto serverKeyset = - concretelang::keysets::ServerKeyset::fromProto(serverKeysetProto); - concretelang::clientlib::EvaluationKeys output{serverKeyset}; - return output; -} - -std::string evaluationKeysSerialize( - concretelang::clientlib::EvaluationKeys &evaluationKeys) { - auto serverKeysetProto = evaluationKeys.keyset.toProto(); - auto maybeBuffer = serverKeysetProto.writeBinaryToString(); - if (maybeBuffer.has_failure()) { - throw std::runtime_error("Failed to serialize evaluation keys."); - } - return maybeBuffer.value(); -} - -std::unique_ptr -keySetUnserialize(const std::string &buffer) { - auto keysetProto = Message(); - auto maybeError = keysetProto.readBinaryFromString( - buffer, mlir::concretelang::python::DESER_OPTIONS); - if (maybeError.has_failure()) { - throw std::runtime_error("Failed to deserialize keyset." + - maybeError.as_failure().error().mesg); - } - auto keyset = concretelang::keysets::Keyset::fromProto(keysetProto); - concretelang::clientlib::KeySet output{keyset}; - return std::make_unique(std::move(output)); -} - -std::string keySetSerialize(concretelang::clientlib::KeySet &keySet) { - auto keysetProto = keySet.keyset.toProto(); - auto maybeBuffer = keysetProto.writeBinaryToString(); - if (maybeBuffer.has_failure()) { - throw std::runtime_error("Failed to serialize keys."); - } - return maybeBuffer.value(); -} - -concretelang::clientlib::SharedScalarOrTensorData -valueUnserialize(const std::string &buffer) { - auto inner = TransportValue(); - if (inner - .readBinaryFromString(buffer, - mlir::concretelang::python::DESER_OPTIONS) - .has_failure()) { - throw std::runtime_error("Failed to deserialize Value"); - } - return {inner}; -} - -std::string -valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value) { - auto maybeString = value.value.writeBinaryToString(); - if (maybeString.has_failure()) { - throw std::runtime_error("Failed to serialize Value"); - } - return maybeString.value(); -} - -concretelang::clientlib::ValueExporter -createValueExporter(concretelang::clientlib::KeySet &keySet, - concretelang::clientlib::ClientParameters &clientParameters, - const std::string &circuitName) { - auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( - clientParameters.programInfo.asReader(), keySet.keyset.client, - std::make_shared<::concretelang::csprng::EncryptionCSPRNG>( - ::concretelang::csprng::EncryptionCSPRNG(0)), - false); - if (maybeProgram.has_failure()) { - throw std::runtime_error(maybeProgram.as_failure().error().mesg); - } - auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName); - return ::concretelang::clientlib::ValueExporter{maybeCircuit.value()}; -} - -concretelang::clientlib::SimulatedValueExporter createSimulatedValueExporter( - concretelang::clientlib::ClientParameters &clientParameters, - const std::string &circuitName) { - - auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( - clientParameters.programInfo, ::concretelang::keysets::ClientKeyset(), - std::make_shared<::concretelang::csprng::EncryptionCSPRNG>( - ::concretelang::csprng::EncryptionCSPRNG(0)), - true); - if (maybeProgram.has_failure()) { - throw std::runtime_error(maybeProgram.as_failure().error().mesg); - } - auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName); - return ::concretelang::clientlib::SimulatedValueExporter{ - maybeCircuit.value()}; -} - -concretelang::clientlib::ValueDecrypter createValueDecrypter( - concretelang::clientlib::KeySet &keySet, - concretelang::clientlib::ClientParameters &clientParameters, - const std::string &circuitName) { - - auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( - clientParameters.programInfo.asReader(), keySet.keyset.client, - std::make_shared<::concretelang::csprng::EncryptionCSPRNG>( - ::concretelang::csprng::EncryptionCSPRNG(0)), - false); - if (maybeProgram.has_failure()) { - throw std::runtime_error(maybeProgram.as_failure().error().mesg); - } - auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName); - return ::concretelang::clientlib::ValueDecrypter{maybeCircuit.value()}; -} - -concretelang::clientlib::SimulatedValueDecrypter createSimulatedValueDecrypter( - concretelang::clientlib::ClientParameters &clientParameters, - const std::string &circuitName) { - - auto maybeProgram = ::concretelang::clientlib::ClientProgram::create( - clientParameters.programInfo.asReader(), - ::concretelang::keysets::ClientKeyset(), - std::make_shared<::concretelang::csprng::EncryptionCSPRNG>( - ::concretelang::csprng::EncryptionCSPRNG(0)), - true); - if (maybeProgram.has_failure()) { - throw std::runtime_error(maybeProgram.as_failure().error().mesg); - } - auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName); - return ::concretelang::clientlib::SimulatedValueDecrypter{ - maybeCircuit.value()}; -} - -concretelang::clientlib::ClientParameters -clientParametersUnserialize(const std::string &json) { - auto programInfo = Message(); - if (programInfo.readJsonFromString(json).has_failure()) { - throw std::runtime_error("Failed to deserialize client parameters"); - } - return concretelang::clientlib::ClientParameters{programInfo, {}, {}, {}, {}}; -} - -std::string -clientParametersSerialize(concretelang::clientlib::ClientParameters ¶ms) { - auto maybeJson = params.programInfo.writeJsonToString(); - if (maybeJson.has_failure()) { - throw std::runtime_error("Failed to serialize client parameters"); - } - return maybeJson.value(); -} - void terminateDataflowParallelization() { _dfr_terminate(); } void initDataflowParallelization() { @@ -491,226 +135,98 @@ std::string roundTrip(const char *module) { return os.str(); } -bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) { - return !lambda_arg.ptr->value.isScalar(); -} - -std::vector lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) { - if (auto tensor = lambda_arg.ptr->value.getTensor(); tensor) { - Tensor out = (Tensor)tensor.value(); - return out.values; - } else if (auto tensor = lambda_arg.ptr->value.getTensor(); - tensor) { - Tensor out = (Tensor)tensor.value(); - return out.values; - } else if (auto tensor = lambda_arg.ptr->value.getTensor(); - tensor) { - Tensor out = (Tensor)tensor.value(); - return out.values; - } else if (auto tensor = lambda_arg.ptr->value.getTensor(); - tensor) { - return tensor.value().values; - } else { - throw std::invalid_argument( - "LambdaArgument isn't a tensor or has an unsupported bitwidth"); - } -} - -std::vector -lambdaArgumentGetSignedTensorData(lambdaArgument &lambda_arg) { - if (auto tensor = lambda_arg.ptr->value.getTensor(); tensor) { - Tensor out = (Tensor)tensor.value(); - return out.values; - } else if (auto tensor = lambda_arg.ptr->value.getTensor(); tensor) { - Tensor out = (Tensor)tensor.value(); - return out.values; - } else if (auto tensor = lambda_arg.ptr->value.getTensor(); tensor) { - Tensor out = (Tensor)tensor.value(); - return out.values; - } else if (auto tensor = lambda_arg.ptr->value.getTensor(); tensor) { - return tensor.value().values; - } else { - throw std::invalid_argument( - "LambdaArgument isn't a tensor or has an unsupported bitwidth"); - } -} - -std::vector -lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) { - std::vector dims = lambda_arg.ptr->value.getDimensions(); - return {dims.begin(), dims.end()}; -} - -bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) { - return lambda_arg.ptr->value.isScalar(); -} - -bool lambdaArgumentIsSigned(lambdaArgument &lambda_arg) { - return lambda_arg.ptr->value.isSigned(); -} - -uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) { - if (lambda_arg.ptr->value.isScalar() && - lambda_arg.ptr->value.hasElementType()) { - return lambda_arg.ptr->value.getTensor()->values[0]; +// Every number sent by python through the API has a type `int64` that must be +// turned into the proper type expected by the ArgTransformers. This allows to +// get an extra transformer executed right before the ArgTransformer gets +// called. +std::function +getPythonTypeTransformer(const Message &info) { + if (info.asReader().getTypeInfo().hasIndex()) { + return [=](Value input) { + Tensor tensorInput = input.getTensor().value(); + return Value{(Tensor)tensorInput}; + }; + } else if (info.asReader().getTypeInfo().hasPlaintext()) { + if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <= + 8) { + return [=](Value input) { + Tensor tensorInput = input.getTensor().value(); + return Value{(Tensor)tensorInput}; + }; + } + if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <= + 16) { + return [=](Value input) { + Tensor tensorInput = input.getTensor().value(); + return Value{(Tensor)tensorInput}; + }; + } + if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <= + 32) { + return [=](Value input) { + Tensor tensorInput = input.getTensor().value(); + return Value{(Tensor)tensorInput}; + }; + } + if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <= + 64) { + return [=](Value input) { + Tensor tensorInput = input.getTensor().value(); + return Value{(Tensor)tensorInput}; + }; + } + assert(false); + } else if (info.asReader().getTypeInfo().hasLweCiphertext()) { + if (info.asReader() + .getTypeInfo() + .getLweCiphertext() + .getEncoding() + .hasInteger() && + info.asReader() + .getTypeInfo() + .getLweCiphertext() + .getEncoding() + .getInteger() + .getIsSigned()) { + return [=](Value input) { return input; }; + } else { + return [=](Value input) { + Tensor tensorInput = input.getTensor().value(); + return Value{(Tensor)tensorInput}; + }; + } } else { - throw std::invalid_argument("LambdaArgument isn't a scalar, should " - "be an IntLambdaArgument"); + assert(false); } -} +}; -int64_t lambdaArgumentGetSignedScalar(lambdaArgument &lambda_arg) { - if (lambda_arg.ptr->value.isScalar() && - lambda_arg.ptr->value.hasElementType()) { - return lambda_arg.ptr->value.getTensor()->values[0]; - } else { - throw std::invalid_argument("LambdaArgument isn't a scalar, should " - "be an IntLambdaArgument"); +template Tensor arrayToTensor(pybind11::array &input) { + auto data_ptr = (const T *)input.data(); + std::vector data = std::vector(data_ptr, data_ptr + input.size()); + auto dims = std::vector(input.ndim(), 0); + for (ssize_t i = 0; i < input.ndim(); i++) { + dims[i] = input.shape(i); } + return std::move(Tensor(std::move(data), std::move(dims))); } -lambdaArgument lambdaArgumentFromTensorU8(std::vector data, - std::vector dimensions) { - std::vector dims(dimensions.begin(), dimensions.end()); - - auto val = Value{((Tensor)Tensor(data, dims))}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorI8(std::vector data, - std::vector dimensions) { - std::vector dims(dimensions.begin(), dimensions.end()); - auto val = Value{((Tensor)Tensor(data, dims))}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorU16(std::vector data, - std::vector dimensions) { - std::vector dims(dimensions.begin(), dimensions.end()); - auto val = Value{((Tensor)Tensor(data, dims))}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorI16(std::vector data, - std::vector dimensions) { - std::vector dims(dimensions.begin(), dimensions.end()); - auto val = Value{((Tensor)Tensor(data, dims))}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorU32(std::vector data, - std::vector dimensions) { - std::vector dims(dimensions.begin(), dimensions.end()); - auto val = Value{((Tensor)Tensor(data, dims))}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorI32(std::vector data, - std::vector dimensions) { - std::vector dims(dimensions.begin(), dimensions.end()); - auto val = Value{((Tensor)Tensor(data, dims))}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorU64(std::vector data, - std::vector dimensions) { - std::vector dims(dimensions.begin(), dimensions.end()); - auto val = Value{((Tensor)Tensor(data, dims))}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorI64(std::vector data, - std::vector dimensions) { - std::vector dims(dimensions.begin(), dimensions.end()); - auto val = Value{((Tensor)Tensor(data, dims))}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument tensor_arg{ - std::make_shared(std::move(out))}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) { - auto val = Value{((Tensor)Tensor(scalar))}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument scalar_arg{ - std::make_shared(std::move(out))}; - return scalar_arg; -} - -lambdaArgument lambdaArgumentFromSignedScalar(int64_t scalar) { - auto val = Value{Tensor(scalar)}; - mlir::concretelang::LambdaArgument out{val}; - lambdaArgument scalar_arg{ - std::make_shared(std::move(out))}; - return scalar_arg; -} - -concreteprotocol::LweCiphertextEncryptionInfo::Reader -clientParametersInputEncryptionAt( - ::concretelang::clientlib::ClientParameters &clientParameters, - size_t inputIdx, std::string circuitName) { - auto reader = clientParameters.programInfo.asReader(); - if (!reader.hasCircuits()) { - throw std::runtime_error("can't get keyid: no circuit info"); - } - auto circuits = reader.getCircuits(); - for (auto circuit : circuits) { - if (circuit.hasName() && - circuitName.compare(circuit.getName().cStr()) == 0) { - if (!circuit.hasInputs()) { - throw std::runtime_error("can't get keyid: no input"); - } - auto inputs = circuit.getInputs(); - if (inputIdx >= inputs.size()) { - throw std::runtime_error( - "can't get keyid: inputIdx bigger than number of inputs"); - } - auto input = inputs[inputIdx]; - if (!input.hasTypeInfo()) { - throw std::runtime_error("can't get keyid: input don't have typeInfo"); - } - auto typeInfo = input.getTypeInfo(); - if (!typeInfo.hasLweCiphertext()) { - throw std::runtime_error("can't get keyid: typeInfo don't " - "have lwe ciphertext info"); - } - auto lweCt = typeInfo.getLweCiphertext(); - if (!lweCt.hasEncryption()) { - throw std::runtime_error("can't get keyid: lwe ciphertext " - "don't have encryption info"); - } - return lweCt.getEncryption(); - } - } - - throw std::runtime_error("can't get keyid: no circuit with name " + - circuitName); +template pybind11::array tensorToArray(Tensor input) { + return pybind11::array(pybind11::array::ShapeContainer(input.dimensions), + input.values.data()); } +} // namespace /// Populate the compiler API python module. void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::module &m) { + + using ::concretelang::csprng::EncryptionCSPRNG; + using ::concretelang::values::Value; + using pybind11::arg; + using pybind11::array; + using pybind11::init; + using Library = CompilerEngine::Library; + m.doc() = "Concretelang compiler python API"; m.def("round_trip", @@ -727,46 +243,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( m.def("check_gpu_runtime_enabled", &checkGPURuntimeEnabled); m.def("check_cuda_device_available", &checkCudaDeviceAvailable); - m.def("import_tfhers_fheuint8", - [](const pybind11::bytes &serialized_fheuint, - TfhersFheIntDescription info, uint32_t encryptionKeyId, - double encryptionVariance) { - const std::string &buffer_str = serialized_fheuint; - std::vector buffer(buffer_str.begin(), buffer_str.end()); - auto arrayRef = llvm::ArrayRef(buffer); - auto valueOrError = ::concretelang::clientlib::importTfhersFheUint8( - arrayRef, info, encryptionKeyId, encryptionVariance); - if (valueOrError.has_error()) { - throw std::runtime_error(valueOrError.error().mesg); - } - return ::concretelang::clientlib::SharedScalarOrTensorData{ - valueOrError.value()}; - }); - - m.def("export_tfhers_fheuint8", - [](::concretelang::clientlib::SharedScalarOrTensorData fheuint, - TfhersFheIntDescription info) { - auto result = ::concretelang::clientlib::exportTfhersFheUint8( - fheuint.value, info); - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); - } - return result.value(); - }); - - m.def("get_tfhers_fheuint8_description", - [](const pybind11::bytes &serialized_fheuint) { - const std::string &buffer_str = serialized_fheuint; - std::vector buffer(buffer_str.begin(), buffer_str.end()); - auto arrayRef = llvm::ArrayRef(buffer); - auto info = - ::concretelang::clientlib::getTfhersFheUint8Description(arrayRef); - if (info.has_error()) { - throw std::runtime_error(info.error().mesg); - } - return info.value(); - }); - pybind11::class_(m, "TfhersFheIntDescription") .def(pybind11::init([](size_t width, bool is_signed, size_t message_modulus, size_t carry_modulus, @@ -839,8 +315,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( }); pybind11::enum_(m, "Backend") - .value("CPU", mlir::concretelang::Backend::CPU) - .value("GPU", mlir::concretelang::Backend::GPU) + .value("CPU", mlir::concretelang::Backend::CPU, + "Circuit codegen targets cpu.") + .value("GPU", mlir::concretelang::Backend::GPU, + "Circuit codegen tartgets gpu.") .export_values(); pybind11::enum_(m, "OptimizerStrategy") @@ -864,122 +342,194 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::class_(m, "CompilationOptions") .def(pybind11::init([](mlir::concretelang::Backend backend) { - return CompilationOptions(backend); - })) - .def("set_verify_diagnostics", - [](CompilationOptions &options, bool b) { - options.verifyDiagnostics = b; - }) - .def("set_auto_parallelize", [](CompilationOptions &options, - bool b) { options.autoParallelize = b; }) - .def("set_loop_parallelize", [](CompilationOptions &options, - bool b) { options.loopParallelize = b; }) - .def("set_dataflow_parallelize", - [](CompilationOptions &options, bool b) { - options.dataflowParallelize = b; - }) - .def("set_compress_evaluation_keys", - [](CompilationOptions &options, bool b) { - options.compressEvaluationKeys = b; - }) - .def("set_compress_input_ciphertexts", - [](CompilationOptions &options, bool b) { - options.compressInputCiphertexts = b; - }) - .def("set_optimize_concrete", [](CompilationOptions &options, - bool b) { options.optimizeTFHE = b; }) - .def("set_p_error", - [](CompilationOptions &options, double p_error) { - options.optimizerConfig.p_error = p_error; - }) - .def("set_display_optimizer_choice", - [](CompilationOptions &options, bool display) { - options.optimizerConfig.display = display; - }) - .def("set_optimizer_strategy", - [](CompilationOptions &options, optimizer::Strategy strategy) { - options.optimizerConfig.strategy = strategy; - }) - .def("set_optimizer_multi_parameter_strategy", - [](CompilationOptions &options, - concrete_optimizer::MultiParamStrategy strategy) { - options.optimizerConfig.multi_param_strategy = strategy; - }) - .def("set_global_p_error", - [](CompilationOptions &options, double global_p_error) { - options.optimizerConfig.global_p_error = global_p_error; - }) - .def("add_composition", - [](CompilationOptions &options, std::string from_func, - size_t from_pos, std::string to_func, size_t to_pos) { - options.optimizerConfig.composition_rules.push_back( - {from_func, from_pos, to_func, to_pos}); - }) - .def("set_composable", - [](CompilationOptions &options, bool composable) { - options.optimizerConfig.composable = composable; - }) - .def("set_security_level", - [](CompilationOptions &options, int security_level) { - options.optimizerConfig.security = security_level; - }) - .def("set_v0_parameter", - [](CompilationOptions &options, size_t glweDimension, - size_t logPolynomialSize, size_t nSmall, size_t brLevel, - size_t brLogBase, size_t ksLevel, size_t ksLogBase) { - options.v0Parameter = {glweDimension, logPolynomialSize, nSmall, - brLevel, brLogBase, ksLevel, - ksLogBase, std::nullopt}; - }) - .def("set_v0_parameter", - [](CompilationOptions &options, size_t glweDimension, - size_t logPolynomialSize, size_t nSmall, size_t brLevel, - size_t brLogBase, size_t ksLevel, size_t ksLogBase, - mlir::concretelang::CRTDecomposition crtDecomposition, - size_t cbsLevel, size_t cbsLogBase, size_t pksLevel, - size_t pksLogBase, size_t pksInputLweDimension, - size_t pksOutputPolynomialSize) { - mlir::concretelang::PackingKeySwitchParameter pksParam = { - pksInputLweDimension, pksOutputPolynomialSize, pksLevel, - pksLogBase}; - mlir::concretelang::CitcuitBoostrapParameter crbParam = { - cbsLevel, cbsLogBase}; - mlir::concretelang::WopPBSParameter wopPBSParam = {pksParam, - crbParam}; - mlir::concretelang::LargeIntegerParameter largeIntegerParam = { - crtDecomposition, wopPBSParam}; - options.v0Parameter = {glweDimension, logPolynomialSize, nSmall, - brLevel, brLogBase, ksLevel, - ksLogBase, largeIntegerParam}; - }) - .def("force_encoding", - [](CompilationOptions &options, - concrete_optimizer::Encoding encoding) { - options.optimizerConfig.encoding = encoding; - }) - .def("simulation", [](CompilationOptions &options, - bool simulate) { options.simulate = simulate; }) - .def("set_emit_gpu_ops", - [](CompilationOptions &options, bool emit_gpu_ops) { - options.emitGPUOps = emit_gpu_ops; - }) - .def("set_batch_tfhe_ops", - [](CompilationOptions &options, bool batch_tfhe_ops) { - options.batchTFHEOps = batch_tfhe_ops; - }) - .def("set_enable_tlu_fusing", - [](CompilationOptions &options, bool enableTluFusing) { - options.enableTluFusing = enableTluFusing; - }) - .def("set_print_tlu_fusing", - [](CompilationOptions &options, bool printTluFusing) { - options.printTluFusing = printTluFusing; - }) - .def("set_enable_overflow_detection_in_simulation", - [](CompilationOptions &options, bool enableOverflowDetection) { - options.enableOverflowDetectionInSimulation = - enableOverflowDetection; - }); + return CompilationOptions(backend); + }), + arg("backend")) + .def( + "set_verify_diagnostics", + [](CompilationOptions &options, bool b) { + options.verifyDiagnostics = b; + }, + "Set option for diagnostics verification.", arg("verify_diagnostics")) + .def( + "set_auto_parallelize", + [](CompilationOptions &options, bool b) { + options.autoParallelize = b; + }, + "Set option for auto parallelization.", arg("auto_parallelize")) + .def( + "set_loop_parallelize", + [](CompilationOptions &options, bool b) { + options.loopParallelize = b; + }, + "Set option for loop parallelization.", arg("loop_parallelize")) + .def( + "set_dataflow_parallelize", + [](CompilationOptions &options, bool b) { + options.dataflowParallelize = b; + }, + "Set option for dataflow parallelization.", + arg("dataflow_parallelize")) + .def( + "set_compress_evaluation_keys", + [](CompilationOptions &options, bool b) { + options.compressEvaluationKeys = b; + }, + "Set option for compression of evaluation keys.", + arg("compress_evaluation_keys")) + .def( + "set_compress_input_ciphertexts", + [](CompilationOptions &options, bool b) { + options.compressInputCiphertexts = b; + }, + "Set option for compression of input ciphertexts.", + arg("compress_input_ciphertexts")) + .def( + "set_optimize_concrete", + [](CompilationOptions &options, bool b) { options.optimizeTFHE = b; }, + "Set flag to enable/disable optimization of concrete intermediate " + "representation.", + arg("optimize")) + .def( + "set_p_error", + [](CompilationOptions &options, double p_error) { + options.optimizerConfig.p_error = p_error; + }, + "Set error probability for shared by each pbs.", arg("p_error")) + .def( + "set_display_optimizer_choice", + [](CompilationOptions &options, bool display) { + options.optimizerConfig.display = display; + }, + "Set display flag of optimizer choices.", arg("display")) + .def( + "set_optimizer_strategy", + [](CompilationOptions &options, optimizer::Strategy strategy) { + options.optimizerConfig.strategy = strategy; + }, + "Set the strategy of the optimizer.", arg("strategy")) + .def( + "set_optimizer_multi_parameter_strategy", + [](CompilationOptions &options, + concrete_optimizer::MultiParamStrategy strategy) { + options.optimizerConfig.multi_param_strategy = strategy; + }, + "Set the strategy of the optimizer for multi-parameter.", + arg("strategy")) + .def( + "set_global_p_error", + [](CompilationOptions &options, double global_p_error) { + options.optimizerConfig.global_p_error = global_p_error; + }, + "Set global error probability for the full circuit.", + arg("global_p_error")) + .def( + "add_composition", + [](CompilationOptions &options, std::string from_func, + size_t from_pos, std::string to_func, size_t to_pos) { + options.optimizerConfig.composition_rules.push_back( + {from_func, from_pos, to_func, to_pos}); + }, + "Add a composition rule.", arg("from_func"), arg("from_pos"), + arg("to_func"), arg("to_pos")) + .def( + "set_composable", + [](CompilationOptions &options, bool composable) { + options.optimizerConfig.composable = composable; + }, + "Set composable flag.", arg("composable")) + .def( + "set_security_level", + [](CompilationOptions &options, int security_level) { + options.optimizerConfig.security = security_level; + }, + "Set security level.", arg("security_level")) + .def( + "set_v0_parameter", + [](CompilationOptions &options, size_t glweDimension, + size_t logPolynomialSize, size_t nSmall, size_t brLevel, + size_t brLogBase, size_t ksLevel, size_t ksLogBase) { + options.v0Parameter = {glweDimension, logPolynomialSize, nSmall, + brLevel, brLogBase, ksLevel, + ksLogBase, std::nullopt}; + }, + "Set the basic V0 parameters.", arg("glwe_dimension"), + arg("log_poly_size"), arg("n_small"), arg("br_level"), + arg("br_log_base"), arg("ks_level"), arg("ks_log_base")) + .def( + "set_all_v0_parameter", + [](CompilationOptions &options, size_t glweDimension, + size_t logPolynomialSize, size_t nSmall, size_t brLevel, + size_t brLogBase, size_t ksLevel, size_t ksLogBase, + std::vector crtDecomposition, size_t cbsLevel, + size_t cbsLogBase, size_t pksLevel, size_t pksLogBase, + size_t pksInputLweDimension, size_t pksOutputPolynomialSize) { + mlir::concretelang::PackingKeySwitchParameter pksParam = { + pksInputLweDimension, pksOutputPolynomialSize, pksLevel, + pksLogBase}; + mlir::concretelang::CitcuitBoostrapParameter crbParam = { + cbsLevel, cbsLogBase}; + mlir::concretelang::WopPBSParameter wopPBSParam = {pksParam, + crbParam}; + mlir::concretelang::LargeIntegerParameter largeIntegerParam = { + crtDecomposition, wopPBSParam}; + options.v0Parameter = {glweDimension, logPolynomialSize, nSmall, + brLevel, brLogBase, ksLevel, + ksLogBase, largeIntegerParam}; + }, + "Set all the V0 parameters.", arg("glwe_dimension"), + arg("log_poly_size"), arg("n_small"), arg("br_level"), + arg("br_log_base"), arg("ks_level"), arg("ks_log_base"), + arg("crt_decomp"), arg("cbs_level"), arg("cbs_log_base"), + arg("pks_level"), arg("pks_log_base"), arg("pks_input_lwe_dim"), + arg("pks_output_poly_size")) + .def( + "force_encoding", + [](CompilationOptions &options, + concrete_optimizer::Encoding encoding) { + options.optimizerConfig.encoding = encoding; + }, + "Force the compiler to use a specific encoding.", arg("encoding")) + .def( + "simulation", + [](CompilationOptions &options, bool simulate) { + options.simulate = simulate; + }, + "Enable or disable simulation.", arg("simulate")) + .def( + "set_emit_gpu_ops", + [](CompilationOptions &options, bool emit_gpu_ops) { + options.emitGPUOps = emit_gpu_ops; + }, + "Set flag that allows gpu ops to be emitted.", arg("emit_gpu_ops")) + .def( + "set_batch_tfhe_ops", + [](CompilationOptions &options, bool batch_tfhe_ops) { + options.batchTFHEOps = batch_tfhe_ops; + }, + "Set flag that triggers the batching of scalar TFHE operations.", + arg("batch_tfhe_ops")) + .def( + "set_enable_tlu_fusing", + [](CompilationOptions &options, bool enableTluFusing) { + options.enableTluFusing = enableTluFusing; + }, + "Enable or disable tlu fusing.", arg("enable_tlu_fusing")) + .def( + "set_print_tlu_fusing", + [](CompilationOptions &options, bool printTluFusing) { + options.printTluFusing = printTluFusing; + }, + "Enable or disable printing tlu fusing.", arg("print_tlu_fusing")) + .def( + "set_enable_overflow_detection_in_simulation", + [](CompilationOptions &options, bool enableOverflowDetection) { + options.enableOverflowDetectionInSimulation = + enableOverflowDetection; + }, + "Enable or disable overflow detection during simulation.", + arg("enable_overflow_detection")) + .doc() = "Holds different flags and options of the compilation process."; pybind11::enum_(m, "PrimitiveOperation") @@ -1010,7 +560,27 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def_readonly("keys", &mlir::concretelang::Statistic::keys) .def_readonly("count", &mlir::concretelang::Statistic::count); - pybind11::class_( + pybind11::class_( + m, "CircuitCompilationFeedback") + .def_readonly("name", + &mlir::concretelang::CircuitCompilationFeedback::name) + .def_readonly( + "total_inputs_size", + &mlir::concretelang::CircuitCompilationFeedback::totalInputsSize) + .def_readonly( + "total_output_size", + &mlir::concretelang::CircuitCompilationFeedback::totalOutputsSize) + .def_readonly("crt_decompositions_of_outputs", + &mlir::concretelang::CircuitCompilationFeedback:: + crtDecompositionsOfOutputs) + .def_readonly("statistics", + &mlir::concretelang::CircuitCompilationFeedback::statistics) + .def_readonly( + "memory_usage_per_location", + &mlir::concretelang::CircuitCompilationFeedback::memoryUsagePerLoc) + .doc() = "Compilation feedback for a single circuit."; + + pybind11::class_( m, "ProgramCompilationFeedback") .def_readonly("complexity", &mlir::concretelang::ProgramCompilationFeedback::complexity) @@ -1030,26 +600,21 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( totalKeyswitchKeysSize) .def_readonly( "circuit_feedbacks", - &mlir::concretelang::ProgramCompilationFeedback::circuitFeedbacks); - - pybind11::class_( - m, "CircuitCompilationFeedback") - .def_readonly("name", - &mlir::concretelang::CircuitCompilationFeedback::name) - .def_readonly( - "total_inputs_size", - &mlir::concretelang::CircuitCompilationFeedback::totalInputsSize) - .def_readonly( - "total_output_size", - &mlir::concretelang::CircuitCompilationFeedback::totalOutputsSize) - .def_readonly("crt_decompositions_of_outputs", - &mlir::concretelang::CircuitCompilationFeedback:: - crtDecompositionsOfOutputs) - .def_readonly("statistics", - &mlir::concretelang::CircuitCompilationFeedback::statistics) - .def_readonly( - "memory_usage_per_location", - &mlir::concretelang::CircuitCompilationFeedback::memoryUsagePerLoc); + &mlir::concretelang::ProgramCompilationFeedback::circuitFeedbacks) + .def( + "get_circuit_feedback", + [](mlir::concretelang::ProgramCompilationFeedback &feedback, + std::string function) { + for (auto circuit : feedback.circuitFeedbacks) { + if (circuit.name == function) { + return circuit; + } + } + throw std::runtime_error( + "Circuit feedback not found fo passed function."); + }, + "Return the circuit feedback for `function`.", arg("function")) + .doc() = "Compilation feedback for a whole program."; pybind11::class_>( @@ -1064,332 +629,586 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( mlirPythonContextToCapsule(wrap(mlirCtx))); }); - pybind11::class_( - m, "LibraryCompilationResult") - .def(pybind11::init([](std::string outputDirPath) { - return mlir::concretelang::LibraryCompilationResult{outputDirPath}; - })); - pybind11::class_<::concretelang::serverlib::ServerLambda>(m, "LibraryLambda"); - pybind11::class_(m, "LibrarySupport") - .def(pybind11::init( - [](std::string outputPath, std::string runtimeLibraryPath, - bool generateSharedLib, bool generateStaticLib, - bool generateClientParameters, bool generateCompilationFeedback, - bool generateCppHeader) { - return library_support( - outputPath.c_str(), runtimeLibraryPath.c_str(), - generateSharedLib, generateStaticLib, generateClientParameters, - generateCompilationFeedback, generateCppHeader); - })) - .def("compile", - [](LibrarySupport_Py &support, std::string mlir_program, - mlir::concretelang::CompilationOptions options) { - SignalGuard signalGuard; - return library_compile(support, mlir_program.c_str(), options); - }) - .def("compile", - [](LibrarySupport_Py &support, pybind11::object mlir_module, - mlir::concretelang::CompilationOptions options, - std::shared_ptr cctx) { - SignalGuard signalGuard; - return library_compile_module( - support, - unwrap(mlirPythonCapsuleToModule(mlir_module.ptr())).clone(), - options, cctx); - }) - .def("load_client_parameters", - [](LibrarySupport_Py &support, - mlir::concretelang::LibraryCompilationResult &result) { - return library_load_client_parameters(support, result); - }) - .def("load_compilation_feedback", - [](LibrarySupport_Py &support, - mlir::concretelang::LibraryCompilationResult &result) { - return library_load_compilation_feedback(support, result); - }) - .def( - "load_server_lambda", - [](LibrarySupport_Py &support, - mlir::concretelang::LibraryCompilationResult &result, - std::string circuitName, bool useSimulation) { - return library_load_server_lambda(support, result, circuitName, - useSimulation); - }, - pybind11::return_value_policy::reference) - .def("server_call", - [](LibrarySupport_Py &support, - ::concretelang::serverlib::ServerLambda lambda, - ::concretelang::clientlib::PublicArguments &publicArguments, - ::concretelang::clientlib::EvaluationKeys &evaluationKeys) { - pybind11::gil_scoped_release release; - SignalGuard signalGuard; - return library_server_call(support, lambda, publicArguments, - evaluationKeys); - }) - .def("simulate", - [](LibrarySupport_Py &support, - ::concretelang::serverlib::ServerLambda lambda, - ::concretelang::clientlib::PublicArguments &publicArguments) { - pybind11::gil_scoped_release release; - return library_simulate(support, lambda, publicArguments); - }) - .def("get_shared_lib_path", - [](LibrarySupport_Py &support) { - return library_get_shared_lib_path(support); - }) - .def("get_program_info_path", [](LibrarySupport_Py &support) { - return library_get_program_info_path(support); - }); + // ------------------------------------------------------------------------------// + // LWE SECRET KEY PARAM // + // ------------------------------------------------------------------------------// - class ClientSupport {}; - pybind11::class_(m, "ClientSupport") - .def(pybind11::init()) - .def_static( - "key_set", - [](::concretelang::clientlib::ClientParameters clientParameters, - ::concretelang::clientlib::KeySetCache *cache, - uint64_t secretSeedMsb, uint64_t secretSeedLsb, - uint64_t encSeedMsb, uint64_t encSeedLsb, - std::map initialLweSecretKeys) { - SignalGuard signalGuard; - auto optCache = - cache == nullptr - ? std::nullopt - : std::optional<::concretelang::clientlib::KeySetCache>( - *cache); - return key_set(clientParameters, optCache, initialLweSecretKeys, - secretSeedMsb, secretSeedLsb, encSeedMsb, - encSeedLsb); - }, - pybind11::arg().none(false), pybind11::arg().none(true), - pybind11::arg("secretSeedMsb") = 0, - pybind11::arg("secretSeedLsb") = 0, pybind11::arg("encSeedMsb") = 0, - pybind11::arg("encSeedLsb") = 0, - pybind11::arg("initialLweSecretKeys") = - std::map()) - .def_static( - "encrypt_arguments", - [](::concretelang::clientlib::ClientParameters clientParameters, - ::concretelang::clientlib::KeySet &keySet, - std::vector args, const std::string &circuitName) { - std::vector argsRef; - for (auto i = 0u; i < args.size(); i++) { - argsRef.push_back(args[i].ptr.get()); + struct LweSecretKeyParam { + Message info; + + std::string toString() { + std::string output = "LweSecretKeyParam(dimension="; + output.append( + std::to_string(info.asReader().getParams().getLweDimension())); + output.append(")"); + return output; + } + }; + pybind11::class_(m, "LweSecretKeyParam") + .def( + "dimension", + [](LweSecretKeyParam &key) { + return key.info.asReader().getParams().getLweDimension(); + }, + "Return the associated LWE dimension.") + .def("__str__", [](LweSecretKeyParam &key) { return key.toString(); }) + .def("__repr__", [](LweSecretKeyParam &key) { return key.toString(); }) + .def("__hash__", + [](pybind11::object key) { + return pybind11::hash(pybind11::repr(key)); + }) + .doc() = "Parameters of an LWE Secret Key."; + + // ------------------------------------------------------------------------------// + // BOOTSTRAP KEY PARAM // + // ------------------------------------------------------------------------------// + + struct BootstrapKeyParam { + Message info; + + std::string toString() { + std::string output = "BootstrapKeyParam("; + output.append("polynomial_size="); + output.append( + std::to_string(info.asReader().getParams().getPolynomialSize())); + output.append(", glwe_dimension="); + output.append( + std::to_string(info.asReader().getParams().getGlweDimension())); + output.append(", input_lwe_dimension="); + output.append( + std::to_string(info.asReader().getParams().getInputLweDimension())); + output.append(", level="); + output.append( + std::to_string(info.asReader().getParams().getLevelCount())); + output.append(", base_log="); + output.append(std::to_string(info.asReader().getParams().getBaseLog())); + output.append(", variance="); + output.append(std::to_string(info.asReader().getParams().getVariance())); + output.append(")"); + return output; + } + }; + pybind11::class_(m, "BootstrapKeyParam") + .def( + "input_secret_key_id", + [](BootstrapKeyParam &key) { + return key.info.asReader().getInputId(); + }, + "Return the key id of the associated input key.") + .def( + "output_secret_key_id", + [](BootstrapKeyParam &key) { + return key.info.asReader().getOutputId(); + }, + "Return the key id of the associated output key.") + .def( + "level", + [](BootstrapKeyParam &key) { + return key.info.asReader().getParams().getLevelCount(); + }, + "Return the associated number of levels.") + .def( + "base_log", + [](BootstrapKeyParam &key) { + return key.info.asReader().getParams().getBaseLog(); + }, + "Return the associated base log.") + .def( + "glwe_dimension", + [](BootstrapKeyParam &key) { + return key.info.asReader().getParams().getGlweDimension(); + }, + "Return the associated GLWE dimension.") + .def( + "variance", + [](BootstrapKeyParam &key) { + return key.info.asReader().getParams().getVariance(); + }, + "Return the associated noise variance.") + .def( + "polynomial_size", + [](BootstrapKeyParam &key) { + return key.info.asReader().getParams().getPolynomialSize(); + }, + "Return the associated polynomial size.") + .def( + "input_lwe_dimension", + [](BootstrapKeyParam &key) { + return key.info.asReader().getParams().getInputLweDimension(); + }, + "Return the associated input lwe dimension.") + .def("__str__", [](BootstrapKeyParam &key) { return key.toString(); }) + .def("__repr__", [](BootstrapKeyParam &key) { return key.toString(); }) + .def("__hash__", + [](pybind11::object key) { + return pybind11::hash(pybind11::repr(key)); + }) + .doc() = "Parameters of a Bootstrap key."; + + // ------------------------------------------------------------------------------// + // KEYSWITCH KEY PARAM // + // ------------------------------------------------------------------------------// + + struct KeyswitchKeyParam { + Message info; + + std::string toString() { + std::string output = "KeyswitchKeyParam("; + output.append("level="); + output.append( + std::to_string(info.asReader().getParams().getLevelCount())); + output.append(", base_log="); + output.append(std::to_string(info.asReader().getParams().getBaseLog())); + output.append(", variance="); + output.append(std::to_string(info.asReader().getParams().getVariance())); + output.append(")"); + return output; + } + }; + pybind11::class_(m, "KeyswitchKeyParam") + .def( + "input_secret_key_id", + [](KeyswitchKeyParam &key) { + return key.info.asReader().getInputId(); + }, + "Return the key id of the associated input key.") + .def( + "output_secret_key_id", + [](KeyswitchKeyParam &key) { + return key.info.asReader().getOutputId(); + }, + "Return the key id of the associated output key.") + .def( + "level", + [](KeyswitchKeyParam &key) { + return key.info.asReader().getParams().getLevelCount(); + }, + "Return the associated number of levels.") + .def( + "base_log", + [](KeyswitchKeyParam &key) { + return key.info.asReader().getParams().getBaseLog(); + }, + "Return the associated base log.") + .def( + "variance", + [](KeyswitchKeyParam &key) { + return key.info.asReader().getParams().getVariance(); + }, + "Return the associated noise variance.") + .def("__str__", [](KeyswitchKeyParam &key) { return key.toString(); }) + .def("__repr__", [](KeyswitchKeyParam &key) { return key.toString(); }) + .def("__hash__", + [](pybind11::object key) { + return pybind11::hash(pybind11::repr(key)); + }) + .doc() = "Parameters of a keyswitch key."; + + // ------------------------------------------------------------------------------// + // PACKING KEYSWITCH KEY PARAM // + // ------------------------------------------------------------------------------// + + struct PackingKeyswitchKeyParam { + Message info; + + std::string toString() { + std::string output = "PackingKeyswitchKeyParam("; + output.append("polynomial_size="); + output.append( + std::to_string(info.asReader().getParams().getPolynomialSize())); + output.append(", glwe_dimension="); + output.append( + std::to_string(info.asReader().getParams().getGlweDimension())); + output.append(", input_lwe_dimension="); + output.append( + std::to_string(info.asReader().getParams().getInputLweDimension())); + output.append(", level="); + output.append( + std::to_string(info.asReader().getParams().getLevelCount())); + output.append(", base_log="); + output.append(std::to_string(info.asReader().getParams().getBaseLog())); + output.append(", variance="); + output.append(std::to_string(info.asReader().getParams().getVariance())); + output.append(")"); + return output; + } + }; + pybind11::class_(m, "PackingKeyswitchKeyParam") + .def( + "input_secret_key_id", + [](PackingKeyswitchKeyParam &key) { + return key.info.asReader().getInputId(); + }, + "Return the key id of the associated input key.") + .def( + "output_secret_key_id", + [](PackingKeyswitchKeyParam &key) { + return key.info.asReader().getOutputId(); + }, + "Return the key id of the associated output key.") + .def( + "level", + [](PackingKeyswitchKeyParam &key) { + return key.info.asReader().getParams().getLevelCount(); + }, + "Return the associated number of levels.") + .def( + "base_log", + [](PackingKeyswitchKeyParam &key) { + return key.info.asReader().getParams().getBaseLog(); + }, + "Return the associated base log.") + .def( + "glwe_dimension", + [](PackingKeyswitchKeyParam &key) { + return key.info.asReader().getParams().getGlweDimension(); + }, + "Return the associated GLWE dimension.") + .def( + "polynomial_size", + [](PackingKeyswitchKeyParam &key) { + return key.info.asReader().getParams().getPolynomialSize(); + }, + "Return the associated polynomial size.") + .def( + "input_lwe_dimension", + [](PackingKeyswitchKeyParam &key) { + return key.info.asReader().getParams().getInputLweDimension(); + }, + "Return the associated input LWE dimension.") + .def( + "variance", + [](PackingKeyswitchKeyParam &key) { + return key.info.asReader().getParams().getVariance(); + }, + "Return the associated noise variance.") + .def("__str__", + [](PackingKeyswitchKeyParam &key) { return key.toString(); }) + .def("__repr__", + [](PackingKeyswitchKeyParam &key) { return key.toString(); }) + .def("__hash__", + [](pybind11::object key) { + return pybind11::hash(pybind11::repr(key)); + }) + .doc() = "Parameters of a packing keyswitch key."; + + // ------------------------------------------------------------------------------// + // TYPE INFO // + // ------------------------------------------------------------------------------// + typedef Message TypeInfo; + pybind11::class_(m, "TypeInfo") + .def( + "is_plaintext", + [](TypeInfo &type) { return type.asReader().hasPlaintext(); }, + "Return true if the type is plaintext") + .doc() = "Informations describing the type of a gate."; + + // ------------------------------------------------------------------------------// + // RAW INFO // + // ------------------------------------------------------------------------------// + typedef Message RawInfo; + pybind11::class_(m, "RawInfo") + .def( + "get_shape", + [](RawInfo &raw) { + auto output = std::vector(); + for (auto dim : raw.asReader().getShape().getDimensions()) { + output.push_back({dim}); } - return encrypt_arguments(clientParameters, keySet, argsRef, - circuitName); - }) - .def_static( - "decrypt_result", - [](::concretelang::clientlib::ClientParameters clientParameters, - ::concretelang::clientlib::KeySet &keySet, - ::concretelang::clientlib::PublicResult &publicResult, - const std::string &circuitName) { - return decrypt_result(clientParameters, keySet, publicResult, - circuitName); - }); - pybind11::class_<::concretelang::clientlib::KeySetCache>(m, "KeySetCache") - .def(pybind11::init()); - - pybind11::class_<::concretelang::clientlib::LweSecretKeyParam>( - m, "LweSecretKeyParam") - .def("dimension", [](::concretelang::clientlib::LweSecretKeyParam &key) { - return key.info.asReader().getParams().getLweDimension(); - }); - - pybind11::class_<::concretelang::clientlib::BootstrapKeyParam>( - m, "BootstrapKeyParam") - .def("input_secret_key_id", - [](::concretelang::clientlib::BootstrapKeyParam &key) { - return key.info.asReader().getInputId(); - }) - .def("output_secret_key_id", - [](::concretelang::clientlib::BootstrapKeyParam &key) { - return key.info.asReader().getOutputId(); - }) - .def("level", - [](::concretelang::clientlib::BootstrapKeyParam &key) { - return key.info.asReader().getParams().getLevelCount(); - }) - .def("base_log", - [](::concretelang::clientlib::BootstrapKeyParam &key) { - return key.info.asReader().getParams().getBaseLog(); - }) - .def("glwe_dimension", - [](::concretelang::clientlib::BootstrapKeyParam &key) { - return key.info.asReader().getParams().getGlweDimension(); - }) - .def("variance", - [](::concretelang::clientlib::BootstrapKeyParam &key) { - return key.info.asReader().getParams().getVariance(); - }) - .def("polynomial_size", - [](::concretelang::clientlib::BootstrapKeyParam &key) { - return key.info.asReader().getParams().getPolynomialSize(); - }) - .def("input_lwe_dimension", - [](::concretelang::clientlib::BootstrapKeyParam &key) { - return key.info.asReader().getParams().getInputLweDimension(); - }); + return output; + }, + "Return the shape associated to the raw info.") + .def( + "get_integer_precision", + [](RawInfo &raw) { return raw.asReader().getIntegerPrecision(); }, + "Return the integer precision associated to the raw info.") + .def( + "get_signedness", + [](RawInfo &raw) { return raw.asReader().getIsSigned(); }, + "Return the signedness associated to the raw info.") + .doc() = "Informations describing a raw type of gate."; + + // ------------------------------------------------------------------------------// + // GATE INFO // + // ------------------------------------------------------------------------------// + typedef Message GateInfo; + pybind11::class_(m, "GateInfo") + .def( + "get_type_info", + [](GateInfo &gate) -> TypeInfo { + return {gate.asReader().getTypeInfo()}; + }, + "Return the type associated to the gate.") + .def( + "get_raw_info", + [](GateInfo &gate) -> RawInfo { + return {gate.asReader().getRawInfo()}; + }, + "Return the raw type associated to the gate.") + .doc() = "Informations describing a circuit gate (input or output)."; + + // ------------------------------------------------------------------------------// + // CIRCUIT INFO // + // ------------------------------------------------------------------------------// + typedef Message CircuitInfo; + pybind11::class_(m, "CircuitInfo") + .def( + "get_name", + [](CircuitInfo &circuit) { + return circuit.asReader().getName().cStr(); + }, + "Return the name of the circuit") + .def( + "get_inputs", + [](CircuitInfo &circuit) -> std::vector { + auto output = std::vector(); + for (auto gate : circuit.asReader().getInputs()) { + output.push_back({gate}); + } + return output; + }, + "Return the input gates") + .def( + "get_outputs", + [](CircuitInfo &circuit) -> std::vector { + auto output = std::vector(); + for (auto gate : circuit.asReader().getOutputs()) { + output.push_back({gate}); + } + return output; + }, + "Return the output gates") + .doc() = "Informations describing a compiled circuit."; - pybind11::class_<::concretelang::clientlib::KeyswitchKeyParam>( - m, "KeyswitchKeyParam") - .def("input_secret_key_id", - [](::concretelang::clientlib::KeyswitchKeyParam &key) { - return key.info.asReader().getInputId(); - }) - .def("output_secret_key_id", - [](::concretelang::clientlib::KeyswitchKeyParam &key) { - return key.info.asReader().getOutputId(); - }) - .def("level", - [](::concretelang::clientlib::KeyswitchKeyParam &key) { - return key.info.asReader().getParams().getLevelCount(); - }) - .def("base_log", - [](::concretelang::clientlib::KeyswitchKeyParam &key) { - return key.info.asReader().getParams().getBaseLog(); - }) - .def("variance", [](::concretelang::clientlib::KeyswitchKeyParam &key) { - return key.info.asReader().getParams().getVariance(); - }); - - pybind11::class_<::concretelang::clientlib::PackingKeyswitchKeyParam>( - m, "PackingKeyswitchKeyParam") - .def("input_secret_key_id", - [](::concretelang::clientlib::PackingKeyswitchKeyParam &key) { - return key.info.asReader().getInputId(); - }) - .def("output_secret_key_id", - [](::concretelang::clientlib::PackingKeyswitchKeyParam &key) { - return key.info.asReader().getOutputId(); - }) - .def("level", - [](::concretelang::clientlib::PackingKeyswitchKeyParam &key) { - return key.info.asReader().getParams().getLevelCount(); - }) - .def("base_log", - [](::concretelang::clientlib::PackingKeyswitchKeyParam &key) { - return key.info.asReader().getParams().getBaseLog(); - }) - .def("glwe_dimension", - [](::concretelang::clientlib::PackingKeyswitchKeyParam &key) { - return key.info.asReader().getParams().getGlweDimension(); - }) - .def("polynomial_size", - [](::concretelang::clientlib::PackingKeyswitchKeyParam &key) { - return key.info.asReader().getParams().getPolynomialSize(); - }) - .def("input_lwe_dimension", - [](::concretelang::clientlib::PackingKeyswitchKeyParam &key) { - return key.info.asReader().getParams().getInputLweDimension(); - }) - .def("variance", - [](::concretelang::clientlib::PackingKeyswitchKeyParam &key) { - return key.info.asReader().getParams().getVariance(); - }); + // ------------------------------------------------------------------------------// + // PROGRAM INFO // + // ------------------------------------------------------------------------------// - pybind11::class_<::concretelang::clientlib::ClientParameters>( - m, "ClientParameters") - .def_static("deserialize", - [](const pybind11::bytes &buffer) { - return clientParametersUnserialize(buffer); - }) - .def("serialize", - [](::concretelang::clientlib::ClientParameters &clientParameters) { - return pybind11::bytes( - clientParametersSerialize(clientParameters)); - }) - .def("lwe_secret_key_param", - [](::concretelang::clientlib::ClientParameters &clientParameters, - size_t keyId) { - if (keyId >= clientParameters.secretKeys.size()) { - throw std::runtime_error("keyId bigger than the number of keys"); - } - return clientParameters.secretKeys[keyId]; - }) - .def("input_keyid_at", - [](::concretelang::clientlib::ClientParameters &clientParameters, - size_t inputIdx, std::string circuitName) { - auto encryption = clientParametersInputEncryptionAt( - clientParameters, inputIdx, circuitName); - return encryption.getKeyId(); - }) - .def("input_variance_at", - [](::concretelang::clientlib::ClientParameters &clientParameters, - size_t inputIdx, std::string circuitName) { - auto encryption = clientParametersInputEncryptionAt( - clientParameters, inputIdx, circuitName); - return encryption.getVariance(); - }) + struct ProgramInfo { + Message programInfo; + + concreteprotocol::LweCiphertextEncryptionInfo::Reader + inputEncryptionAt(size_t inputId, std::string circuitName) { + auto reader = programInfo.asReader(); + if (!reader.hasCircuits()) { + throw std::runtime_error("can't get keyid: no circuit info"); + } + auto circuits = reader.getCircuits(); + for (auto circuit : circuits) { + if (circuit.hasName() && + circuitName.compare(circuit.getName().cStr()) == 0) { + if (!circuit.hasInputs()) { + throw std::runtime_error("can't get keyid: no input"); + } + auto inputs = circuit.getInputs(); + if (inputId >= inputs.size()) { + throw std::runtime_error( + "can't get keyid: inputId bigger than number of inputs"); + } + auto input = inputs[inputId]; + if (!input.hasTypeInfo()) { + throw std::runtime_error( + "can't get keyid: input don't have typeInfo"); + } + auto typeInfo = input.getTypeInfo(); + if (!typeInfo.hasLweCiphertext()) { + throw std::runtime_error("can't get keyid: typeInfo don't " + "have lwe ciphertext info"); + } + auto lweCt = typeInfo.getLweCiphertext(); + if (!lweCt.hasEncryption()) { + throw std::runtime_error("can't get keyid: lwe ciphertext " + "don't have encryption info"); + } + return lweCt.getEncryption(); + } + } + + throw std::runtime_error("can't get keyid: no circuit with name " + + circuitName); + } + }; + pybind11::class_(m, "ProgramInfo") + .def_static( + "deserialize", + [](const pybind11::bytes &buffer) { + auto programInfo = Message(); + if (programInfo.readJsonFromString(buffer).has_failure()) { + throw std::runtime_error("Failed to deserialize program info"); + } + return ProgramInfo{programInfo}; + }, + "Deserialize a ProgramInfo from bytes.", arg("bytes")) + .def( + "serialize", + [](ProgramInfo &programInfo) { + auto programInfoSerialize = [](ProgramInfo ¶ms) { + auto maybeJson = params.programInfo.writeJsonToString(); + if (maybeJson.has_failure()) { + throw std::runtime_error("Failed to serialize program info"); + } + return maybeJson.value(); + }; + return pybind11::bytes(programInfoSerialize(programInfo)); + }, + "Serialize a ProgramInfo to bytes.") + .def( + "input_keyid_at", + [](ProgramInfo &programInfo, size_t pos, std::string circuitName) { + auto encryption = programInfo.inputEncryptionAt(pos, circuitName); + return encryption.getKeyId(); + }, + "Return the key id associated to the argument `pos` of circuit " + "`circuit_name`.", + arg("pos"), arg("circuit_name")) + .def( + "input_variance_at", + [](ProgramInfo &programInfo, size_t pos, std::string circuitName) { + auto encryption = programInfo.inputEncryptionAt(pos, circuitName); + return encryption.getVariance(); + }, + "Return the noise variance associated to the argument `pos` of " + "circuit `circuit_name`.", + arg("pos"), arg("circuit_name")) .def("function_list", - [](::concretelang::clientlib::ClientParameters &clientParameters) { + [](ProgramInfo &programInfo) { std::vector result; for (auto circuit : - clientParameters.programInfo.asReader().getCircuits()) { + programInfo.programInfo.asReader().getCircuits()) { result.push_back(circuit.getName()); } return result; }) - .def("output_signs", - [](::concretelang::clientlib::ClientParameters &clientParameters) { - std::vector result; - for (auto output : clientParameters.programInfo.asReader() - .getCircuits()[0] - .getOutputs()) { - if (output.getTypeInfo().hasLweCiphertext() && - output.getTypeInfo() - .getLweCiphertext() - .getEncoding() - .hasInteger()) { - result.push_back(output.getTypeInfo() - .getLweCiphertext() - .getEncoding() - .getInteger() - .getIsSigned()); - } else { - result.push_back(true); - } - } - return result; - }) - .def("input_signs", - [](::concretelang::clientlib::ClientParameters &clientParameters) { - std::vector result; - for (auto input : clientParameters.programInfo.asReader() + .def( + "output_signs", + [](ProgramInfo &programInfo) { + std::vector result; + for (auto output : programInfo.programInfo.asReader() .getCircuits()[0] - .getInputs()) { - if (input.getTypeInfo().hasLweCiphertext() && - input.getTypeInfo() - .getLweCiphertext() - .getEncoding() - .hasInteger()) { - result.push_back(input.getTypeInfo() - .getLweCiphertext() - .getEncoding() - .getInteger() - .getIsSigned()); - } else { - result.push_back(true); - } - } - return result; - }) - .def_readonly("secret_keys", - &::concretelang::clientlib::ClientParameters::secretKeys) - .def_readonly("bootstrap_keys", - &::concretelang::clientlib::ClientParameters::bootstrapKeys) - .def_readonly("keyswitch_keys", - &::concretelang::clientlib::ClientParameters::keyswitchKeys) - .def_readonly( + .getOutputs()) { + if (output.getTypeInfo().hasLweCiphertext() && + output.getTypeInfo() + .getLweCiphertext() + .getEncoding() + .hasInteger()) { + result.push_back(output.getTypeInfo() + .getLweCiphertext() + .getEncoding() + .getInteger() + .getIsSigned()); + } else { + result.push_back(true); + } + } + return result; + }, + "Return the signedness of the output of the first circuit.") + .def( + "input_signs", + [](ProgramInfo &programInfo) { + std::vector result; + for (auto input : programInfo.programInfo.asReader() + .getCircuits()[0] + .getInputs()) { + if (input.getTypeInfo().hasLweCiphertext() && + input.getTypeInfo() + .getLweCiphertext() + .getEncoding() + .hasInteger()) { + result.push_back(input.getTypeInfo() + .getLweCiphertext() + .getEncoding() + .getInteger() + .getIsSigned()); + } else { + result.push_back(true); + } + } + return result; + }, + "Return the signedness of the input of the first circuit.") + .def( + "secret_keys", + [](ProgramInfo &programInfo) { + auto secretKeys = std::vector(); + for (auto key : programInfo.programInfo.asReader() + .getKeyset() + .getLweSecretKeys()) { + secretKeys.push_back(LweSecretKeyParam{key}); + } + return secretKeys; + }, + "Return the parameters of the secret keys for this program.") + .def( + "bootstrap_keys", + [](ProgramInfo &programInfo) { + auto bootstrapKeys = std::vector(); + for (auto key : programInfo.programInfo.asReader() + .getKeyset() + .getLweBootstrapKeys()) { + bootstrapKeys.push_back(BootstrapKeyParam{key}); + } + return bootstrapKeys; + }, + "Return the parameters of the bootstrap keys for this program.") + .def( + "keyswitch_keys", + [](ProgramInfo &programInfo) { + auto keyswitchKeys = std::vector(); + for (auto key : programInfo.programInfo.asReader() + .getKeyset() + .getLweKeyswitchKeys()) { + keyswitchKeys.push_back(KeyswitchKeyParam{key}); + } + return keyswitchKeys; + }, + "Return the parameters of the keyswitch keys for this program.") + .def( "packing_keyswitch_keys", - &::concretelang::clientlib::ClientParameters::packingKeyswitchKeys); + [](ProgramInfo &programInfo) { + auto packingKeyswitchKeys = std::vector(); + for (auto key : programInfo.programInfo.asReader() + .getKeyset() + .getPackingKeyswitchKeys()) { + packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{key}); + } + return packingKeyswitchKeys; + }, + "Return the parameters of the packing keyswitch keys for this " + "program.") + .def( + "get_circuits", + [](ProgramInfo &programInfo) { + auto output = std::vector(); + for (auto circuit : + programInfo.programInfo.asReader().getCircuits()) { + output.push_back(circuit); + } + return output; + }, + "Return the circuits associated to the program.") + .def( + "get_circuit", + [](ProgramInfo &programInfo, std::string &name) -> CircuitInfo { + for (auto circuit : + programInfo.programInfo.asReader().getCircuits()) { + if (circuit.getName() == name) { + return circuit; + } + } + throw std::runtime_error("couldn't find circuit."); + }, + "Return the circuit associated to the program with given name.") + .doc() = "Informations describing a compiled program."; + + // ------------------------------------------------------------------------------// + // LWE SECRET KEY // + // ------------------------------------------------------------------------------// pybind11::class_(m, "LweSecretKey") .def_static( "deserialize", - [](pybind11::bytes buffer, - ::concretelang::clientlib::LweSecretKeyParam ¶ms) { + [](pybind11::bytes buffer, LweSecretKeyParam ¶ms) { std::string buffer_str = buffer; auto lwe_dim = params.info.asReader().getParams().getLweDimension(); auto lwe_size = concrete_cpu_lwe_secret_key_size_u64(lwe_dim); @@ -1401,27 +1220,30 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return LweSecretKey( std::make_shared>(std::move(lwe_sk)), params.info); - }) - .def("serialize", - [](LweSecretKey &lweSk) { - auto skBuffer = lweSk.getBuffer(); - auto lwe_dimension = - lweSk.getInfo().asReader().getParams().getLweDimension(); - auto buffer_size = - concrete_cpu_lwe_secret_key_buffer_size_u64(lwe_dimension); - std::vector buffer(buffer_size, 0); - buffer_size = concrete_cpu_serialize_lwe_secret_key_u64( - skBuffer.data(), lwe_dimension, buffer.data(), buffer_size); - if (buffer_size == 0) { - throw std::runtime_error("couldn't serialize the secret key"); - } - auto bytes = pybind11::bytes((char *)buffer.data(), buffer_size); - return bytes; - }) + }, + "Deserialize an LweSecretKet from bytes and associated parameters.", + arg("buffer"), arg("params")) + .def( + "serialize", + [](LweSecretKey &lweSk) { + auto skBuffer = lweSk.getBuffer(); + auto lwe_dimension = + lweSk.getInfo().asReader().getParams().getLweDimension(); + auto buffer_size = + concrete_cpu_lwe_secret_key_buffer_size_u64(lwe_dimension); + std::vector buffer(buffer_size, 0); + buffer_size = concrete_cpu_serialize_lwe_secret_key_u64( + skBuffer.data(), lwe_dimension, buffer.data(), buffer_size); + if (buffer_size == 0) { + throw std::runtime_error("couldn't serialize the secret key"); + } + auto bytes = pybind11::bytes((char *)buffer.data(), buffer_size); + return bytes; + }, + "Serialize an LweSecretKey to bytes.") .def_static( "deserialize_from_glwe", - [](pybind11::bytes buffer, - ::concretelang::clientlib::LweSecretKeyParam ¶ms) { + [](pybind11::bytes buffer, LweSecretKeyParam ¶ms) { std::string buffer_str = buffer; auto lwe_dim = params.info.asReader().getParams().getLweDimension(); auto glwe_sk_size = concrete_cpu_lwe_secret_key_size_u64(lwe_dim); @@ -1433,380 +1255,667 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return LweSecretKey( std::make_shared>(std::move(glwe_sk)), params.info); - }) - .def("serialize_as_glwe", - [](LweSecretKey &lweSk, size_t glwe_dimension, - size_t polynomial_size) { - auto skBuffer = lweSk.getBuffer(); - auto buffer_size = concrete_cpu_glwe_secret_key_buffer_size_u64( - glwe_dimension, polynomial_size); - std::vector buffer(buffer_size, 0); - buffer_size = concrete_cpu_serialize_glwe_secret_key_u64( - skBuffer.data(), glwe_dimension, polynomial_size, - buffer.data(), buffer_size); - if (buffer_size == 0) { - throw std::runtime_error("couldn't serialize the secret key"); - } - auto bytes = pybind11::bytes((char *)buffer.data(), buffer_size); - return bytes; - }) - .def_property_readonly("param", [](LweSecretKey &lweSk) { - return ::concretelang::clientlib::LweSecretKeyParam{lweSk.getInfo()}; - }); - - pybind11::class_<::concretelang::clientlib::KeySet>(m, "KeySet") - .def_static("deserialize", - [](const pybind11::bytes &buffer) { - std::unique_ptr<::concretelang::clientlib::KeySet> result = - keySetUnserialize(buffer); - return result; - }) - .def("serialize", - [](::concretelang::clientlib::KeySet &keySet) { - return pybind11::bytes(keySetSerialize(keySet)); - }) - .def("get_lwe_secret_key", - [](::concretelang::clientlib::KeySet &keySet, size_t keyIndex) { - auto secretKeys = keySet.keyset.client.lweSecretKeys; - if (keyIndex >= secretKeys.size()) { - throw std::runtime_error( - "keyIndex is bigger than the number of keys"); - } - return secretKeys[keyIndex]; - }) - .def("get_evaluation_keys", - [](::concretelang::clientlib::KeySet &keySet) { - return ::concretelang::clientlib::EvaluationKeys{ - keySet.keyset.server}; - }); - - pybind11::class_<::concretelang::clientlib::SharedScalarOrTensorData>(m, - "Value") - .def_static("deserialize", - [](const pybind11::bytes &buffer) { - return valueUnserialize(buffer); - }) + }, + "Deserialize an LweSecretKey from glwe encoded (tfhe-rs compatible) " + "bytes and associated parameters.", + arg("buffer"), arg("params")) .def( - "serialize", - [](const ::concretelang::clientlib::SharedScalarOrTensorData &value) { - return pybind11::bytes(valueSerialize(value)); - }); + "serialize_as_glwe", + [](LweSecretKey &lweSk, size_t glwe_dimension, + size_t polynomial_size) { + auto skBuffer = lweSk.getBuffer(); + auto buffer_size = concrete_cpu_glwe_secret_key_buffer_size_u64( + glwe_dimension, polynomial_size); + std::vector buffer(buffer_size, 0); + buffer_size = concrete_cpu_serialize_glwe_secret_key_u64( + skBuffer.data(), glwe_dimension, polynomial_size, buffer.data(), + buffer_size); + if (buffer_size == 0) { + throw std::runtime_error("couldn't serialize the secret key"); + } + auto bytes = pybind11::bytes((char *)buffer.data(), buffer_size); + return bytes; + }, + "Serialize an LweSecretKey to glwe encoded (tfhe-rs compatible) " + "bytes and associated parameters.", + arg("glwe_dimension"), arg("polynomial_size")) + .def_property_readonly( + "param", + [](LweSecretKey &lweSk) { + return LweSecretKeyParam{lweSk.getInfo()}; + }, + "Parameters associated to the key.") + .doc() = "Lwe secret key."; - pybind11::class_(m, "ServerProgram") - .def_static("load", - [](LibrarySupport_Py &support, bool useSimulation) { - GET_OR_THROW_EXPECTED(auto programInfo, - support.support.getProgramInfo()); - auto sharedLibPath = support.support.getSharedLibPath(); - GET_OR_THROW_RESULT( - auto result, - ServerProgram::load(programInfo.asReader(), - sharedLibPath, useSimulation)); - return result; - }) - .def("get_server_circuit", [](ServerProgram &program, - const std::string &circuitName) { - GET_OR_THROW_RESULT(auto result, program.getServerCircuit(circuitName)); - return result; - }); + // ------------------------------------------------------------------------------// + // KEYSET CACHE // + // ------------------------------------------------------------------------------// - pybind11::class_(m, "ServerCircuit") - .def("call", - [](ServerCircuit &circuit, - ::concretelang::clientlib::PublicArguments &publicArguments, - ::concretelang::clientlib::EvaluationKeys &evaluationKeys) { - SignalGuard signalGuard; - pybind11::gil_scoped_release release; - auto keyset = evaluationKeys.keyset; - auto values = publicArguments.values; - GET_OR_THROW_RESULT(auto output, circuit.call(keyset, values)); - ::concretelang::clientlib::PublicResult res{output}; - return std::make_unique<::concretelang::clientlib::PublicResult>( - std::move(res)); - }) - .def("simulate", - [](ServerCircuit &circuit, - ::concretelang::clientlib::PublicArguments &publicArguments) { - pybind11::gil_scoped_release release; - auto values = publicArguments.values; - GET_OR_THROW_RESULT(auto output, circuit.simulate(values)); - ::concretelang::clientlib::PublicResult res{output}; - return std::make_unique<::concretelang::clientlib::PublicResult>( - std::move(res)); - }); + pybind11::class_(m, "KeysetCache") + .def(pybind11::init(), arg("backing_directory_path")) + .doc() = "Local keysets cache."; - pybind11::class_<::concretelang::clientlib::ValueExporter>(m, "ValueExporter") + // ------------------------------------------------------------------------------// + // SERVER KEYSET // + // ------------------------------------------------------------------------------// + pybind11::class_(m, "ServerKeyset") .def_static( - "create", - [](::concretelang::clientlib::KeySet &keySet, - ::concretelang::clientlib::ClientParameters &clientParameters, - const std::string &circuitName) { - return createValueExporter(keySet, clientParameters, circuitName); - }) - .def("export_scalar", - [](::concretelang::clientlib::ValueExporter &exporter, - size_t position, int64_t value) { - SignalGuard signalGuard; - pybind11::gil_scoped_release release; - - auto info = exporter.circuit.getCircuitInfo() - .asReader() - .getInputs()[position]; - auto typeTransformer = getPythonTypeTransformer(info); - auto result = exporter.circuit.prepareInput( - typeTransformer({Tensor(value)}), position); - - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); + "deserialize", + [](const pybind11::bytes &buffer) { + auto serverKeysetProto = Message(); + auto maybeError = serverKeysetProto.readBinaryFromString( + buffer, mlir::concretelang::python::DESER_OPTIONS); + if (maybeError.has_failure()) { + throw std::runtime_error("Failed to deserialize server keyset." + + maybeError.as_failure().error().mesg); + } + return ServerKeyset::fromProto(serverKeysetProto); + }, + "Deserialize a ServerKeyset from bytes.", arg("bytes")) + .def( + "serialize", + [](ServerKeyset &serverKeyset) { + auto serverKeysetSerialize = [](ServerKeyset &serverKeyset) { + auto serverKeysetProto = serverKeyset.toProto(); + auto maybeBuffer = serverKeysetProto.writeBinaryToString(); + if (maybeBuffer.has_failure()) { + throw std::runtime_error("Failed to serialize server keyset."); + } + return maybeBuffer.value(); + }; + return pybind11::bytes(serverKeysetSerialize(serverKeyset)); + }, + "Serialize a ServerKeyset to bytes.") + .doc() = "Server-side / Evaluation keyset"; + + // ------------------------------------------------------------------------------// + // CLIENT KEYSET // + // ------------------------------------------------------------------------------// + pybind11::class_(m, "ClientKeyset") + .def("get_secret_keys", + [](ClientKeyset &keyset) { return keyset.lweSecretKeys; }) + .doc() = "Client-side / Encryption keyset"; + + // ------------------------------------------------------------------------------// + // KEYSET // + // ------------------------------------------------------------------------------// + + pybind11::class_(m, "Keyset") + .def(init([](ProgramInfo programInfo, std::optional cache, + uint64_t secretSeedMsb, uint64_t secretSeedLsb, + uint64_t encSeedMsb, uint64_t encSeedLsb, + std::optional> + initialLweSecretKeys) { + SignalGuard const signalGuard; + + auto secretSeed = + (((__uint128_t)secretSeedMsb) << 64) | secretSeedLsb; + auto encryptionSeed = + (((__uint128_t)encSeedMsb) << 64) | encSeedLsb; + + if (!initialLweSecretKeys.has_value()) { + initialLweSecretKeys = std::map(); } - return ::concretelang::clientlib::SharedScalarOrTensorData{ - result.value()}; - }) - .def("export_tensor", [](::concretelang::clientlib::ValueExporter - &exporter, - size_t position, std::vector values, - std::vector shape) { - SignalGuard signalGuard; - pybind11::gil_scoped_release release; - std::vector dimensions(shape.begin(), shape.end()); - auto info = - exporter.circuit.getCircuitInfo().asReader().getInputs()[position]; - auto typeTransformer = getPythonTypeTransformer(info); - auto result = exporter.circuit.prepareInput( - typeTransformer({Tensor(values, dimensions)}), position); - - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); - } + if (cache) { + GET_OR_THROW_RESULT( + Keyset keyset, + (*cache).getKeyset( + programInfo.programInfo.asReader().getKeyset(), + secretSeed, encryptionSeed, + initialLweSecretKeys.value())); + return std::make_unique(std::move(keyset)); + } else { + ::concretelang::csprng::SecretCSPRNG secCsprng(secretSeed); + ::concretelang::csprng::EncryptionCSPRNG encCsprng( + encryptionSeed); + auto keyset = + Keyset(programInfo.programInfo.asReader().getKeyset(), + secCsprng, encCsprng, initialLweSecretKeys.value()); + return std::make_unique(std::move(keyset)); + } + }), + arg("program_info"), arg("keyset_cache"), arg("secret_seed_msb") = 0, + arg("secret_seed_lsb") = 0, arg("encryption_seed_msb") = 0, + arg("encryption_seed_lsb") = 0, + arg("initial_lwe_secret_keys") = std::nullopt) + .def_static( + "deserialize", + [](const pybind11::bytes &buffer) { + auto keysetProto = Message(); + auto maybeError = keysetProto.readBinaryFromString( + buffer, mlir::concretelang::python::DESER_OPTIONS); + if (maybeError.has_failure()) { + throw std::runtime_error("Failed to deserialize keyset." + + maybeError.as_failure().error().mesg); + } + auto keyset = Keyset::fromProto(keysetProto); + return std::make_unique(std::move(keyset)); + }, + "Deserialize a Keyset from bytes.", arg("bytes")) + .def( + "serialize", + [](Keyset &keySet) { + auto keySetSerialize = [](Keyset &keyset) { + auto keysetProto = keyset.toProto(); + auto maybeBuffer = keysetProto.writeBinaryToString(); + if (maybeBuffer.has_failure()) { + throw std::runtime_error("Failed to serialize keys."); + } + return maybeBuffer.value(); + }; + return pybind11::bytes(keySetSerialize(keySet)); + }, + "Serialize a Keyset to bytes.") + .def( + "serialize_lwe_secret_key_as_glwe", + [](Keyset &keyset, size_t keyIndex, size_t glwe_dimension, + size_t polynomial_size) { + auto secretKeys = keyset.client.lweSecretKeys; + if (keyIndex >= secretKeys.size()) { + throw std::runtime_error( + "keyIndex is bigger than the number of keys"); + } + auto secretKey = secretKeys[keyIndex]; + auto skBuffer = secretKey.getBuffer(); + auto buffer_size = concrete_cpu_glwe_secret_key_buffer_size_u64( + glwe_dimension, polynomial_size); + std::vector buffer(buffer_size, 0); + buffer_size = concrete_cpu_serialize_glwe_secret_key_u64( + skBuffer.data(), glwe_dimension, polynomial_size, buffer.data(), + buffer_size); + if (buffer_size == 0) { + throw std::runtime_error("couldn't serialize the secret key"); + } + auto bytes = pybind11::bytes((char *)buffer.data(), buffer_size); + return bytes; + }, + "Serialize the `key_id` secret key as a tfhe-rs GLWE key with " + "parameters `glwe_dim` and `poly_size`.", + arg("key_id"), arg("glwe_dim"), arg("poly_size")) + .def( + "get_server_keys", [](Keyset &keyset) { return keyset.server; }, + "Return the associated ServerKeyset.") + .def( + "get_client_keys", [](Keyset &keyset) { return keyset.client; }, + "Return the associated ClientKeyset.") + .doc() = + "Complete keyset containing both client-side and server-side keys."; + + // ------------------------------------------------------------------------------// + // LIBRARY // + // ------------------------------------------------------------------------------// + + pybind11::class_(m, "Library") + .def(init([](std::string output_dir_path) -> Library { + return Library(output_dir_path); + }), + arg("output_dir_path")) + .def( + "get_program_info", + [](Library &library) { + GET_OR_THROW_RESULT(auto pi, library.getProgramInfo()); + return ProgramInfo{pi}; + }, + "Return the program info associated to the library.") + .def( + "get_output_dir_path", + [](Library &library) { return library.getOutputDirPath(); }, + "Return the path to library output directory.") + .def( + "get_shared_lib_path", + [](Library &library) { return library.getSharedLibraryPath(); }, + "Return the path to the shared library.") + .def( + "get_program_info_path", + [](Library &library) { return library.getProgramInfoPath(); }, + "Return the path to the program info file.") + .def( + "get_program_compilation_feedback", + [](Library &library) { + auto path = library.getCompilationFeedbackPath(); + GET_OR_THROW_RESULT(auto feedback, + ProgramCompilationFeedback::load(path)); + return feedback; + }, + "Return the associated program compilation feedback.") + .doc() = "Library object representing the output of a compilation."; + + // ------------------------------------------------------------------------------// + // COMPILER // + // ------------------------------------------------------------------------------// + + struct Compiler { + std::string outputPath; + std::string runtimeLibraryPath; + bool generateSharedLib; + bool generateStaticLib; + bool generateClientParameters; + bool generateCompilationFeedback; + }; + pybind11::class_(m, "Compiler") + .def(init([](std::string outputPath, std::string runtimeLibraryPath, + bool generateSharedLib, bool generateStaticLib, + bool generateProgramInfo, bool generateCompilationFeedback) { + if (!std::filesystem::exists(outputPath)) { + std::filesystem::create_directory(outputPath); + } + return Compiler{outputPath.c_str(), runtimeLibraryPath.c_str(), + generateSharedLib, generateStaticLib, + generateProgramInfo, generateCompilationFeedback}; + }), + arg("output_path"), arg("runtime_lib_path"), + arg("generate_shared_lib") = true, arg("generate_static_lib") = true, + arg("generate_program_info") = true, + arg("generate_compilation_feedback") = true) + .def( + "compile", + [](Compiler &support, std::string mlir_program, + mlir::concretelang::CompilationOptions options) { + SignalGuard signalGuard; + llvm::SourceMgr sm; + sm.AddNewSourceBuffer( + llvm::MemoryBuffer::getMemBuffer(mlir_program.c_str()), + llvm::SMLoc()); + + // Setup the compiler engine + auto context = CompilationContext::createShared(); + concretelang::CompilerEngine engine(context); + engine.setCompilationOptions(options); + + // Compile to a library + GET_OR_THROW_EXPECTED( + auto library, + engine.compile( + sm, support.outputPath, support.runtimeLibraryPath, + support.generateSharedLib, support.generateStaticLib, + support.generateClientParameters, + support.generateCompilationFeedback)); + return library; + }, + "Compile `mlir_program` using the `options` compilation options.", + arg("mlir_program"), arg("options")) + .def( + "compile", + [](Compiler &support, pybind11::object mlir_module, + mlir::concretelang::CompilationOptions options, + std::shared_ptr context) { + SignalGuard signalGuard; + mlir::ModuleOp module = + unwrap(mlirPythonCapsuleToModule(mlir_module.ptr())).clone(); + + // Setup the compiler engine + concretelang::CompilerEngine engine(context); + engine.setCompilationOptions(options); + + // Compile to a library + GET_OR_THROW_EXPECTED( + auto library, + engine.compile( + module, support.outputPath, support.runtimeLibraryPath, + support.generateSharedLib, support.generateStaticLib, + support.generateClientParameters, + support.generateCompilationFeedback)); + return library; + }, + "Compile the `mlir_module` module with `options` compilation " + "options, under the `context` compilation context.", + arg("mlir_module"), arg("options"), arg("context").none(false)) + .doc() = "Provides compilation facility."; - return ::concretelang::clientlib::SharedScalarOrTensorData{ - result.value()}; - }); + // ------------------------------------------------------------------------------// + // TRANSPORT VALUE // + // ------------------------------------------------------------------------------// - pybind11::class_<::concretelang::clientlib::SimulatedValueExporter>( - m, "SimulatedValueExporter") + pybind11::class_(m, "TransportValue") .def_static( - "create", - [](::concretelang::clientlib::ClientParameters &clientParameters, - const std::string &circuitName) { - return createSimulatedValueExporter(clientParameters, circuitName); - }) - .def("export_scalar", - [](::concretelang::clientlib::SimulatedValueExporter &exporter, - size_t position, int64_t value) { - SignalGuard signalGuard; - auto info = exporter.circuit.getCircuitInfo() - .asReader() - .getInputs()[position]; - auto typeTransformer = getPythonTypeTransformer(info); - auto result = exporter.circuit.prepareInput( - typeTransformer({Tensor(value)}), position); - - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); - } + "deserialize", + [](const pybind11::bytes &buffer) { + auto inner = TransportValue(); + if (inner + .readBinaryFromString( + buffer, mlir::concretelang::python::DESER_OPTIONS) + .has_failure()) { + throw std::runtime_error("Failed to deserialize TransportValue"); + } + return TransportValue{inner}; + }, + "Deserialize a TransportValue from bytes.", arg("bytes")) + .def( + "serialize", + [](const TransportValue &value) { + auto valueSerialize = [](const TransportValue &value) { + auto maybeString = value.writeBinaryToString(); + if (maybeString.has_failure()) { + throw std::runtime_error("Failed to serialize TransportValue"); + } + return maybeString.value(); + }; + return pybind11::bytes(valueSerialize(value)); + }, + "Serialize a TransportValue to bytes") + .doc() = "Public/Transportable value."; - return ::concretelang::clientlib::SharedScalarOrTensorData{ - result.value()}; - }) - .def("export_tensor", [](::concretelang::clientlib::SimulatedValueExporter - &exporter, - size_t position, std::vector values, - std::vector shape) { - SignalGuard signalGuard; - std::vector dimensions(shape.begin(), shape.end()); - auto info = - exporter.circuit.getCircuitInfo().asReader().getInputs()[position]; - auto typeTransformer = getPythonTypeTransformer(info); - auto result = exporter.circuit.prepareInput( - typeTransformer({Tensor(values, dimensions)}), position); - - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); - } + // ------------------------------------------------------------------------------// + // VALUE // + // ------------------------------------------------------------------------------// - return ::concretelang::clientlib::SharedScalarOrTensorData{ - result.value()}; - }); + typedef std::variant + PyValType; - pybind11::class_<::concretelang::clientlib::ValueDecrypter>(m, - "ValueDecrypter") - .def_static( - "create", - [](::concretelang::clientlib::KeySet &keySet, - ::concretelang::clientlib::ClientParameters &clientParameters, - const std::string &circuitName) { - return createValueDecrypter(keySet, clientParameters, circuitName); - }) - .def("decrypt", - [](::concretelang::clientlib::ValueDecrypter &decrypter, - size_t position, - ::concretelang::clientlib::SharedScalarOrTensorData &value) { - SignalGuard signalGuard; - pybind11::gil_scoped_release release; - - auto result = - decrypter.circuit.processOutput(value.value, position); - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); + pybind11::class_(m, "Value") + .def(init([](int64_t scalar) { return Value{Tensor(scalar)}; }), + arg("input")) + .def( + init([](uint64_t scalar) { return Value{Tensor(scalar)}; }), + arg("input")) + .def(init([](int32_t scalar) { return Value{Tensor(scalar)}; }), + arg("input")) + .def( + init([](uint32_t scalar) { return Value{Tensor(scalar)}; }), + arg("input")) + .def(init([](int16_t scalar) { return Value{Tensor(scalar)}; }), + arg("input")) + .def( + init([](uint16_t scalar) { return Value{Tensor(scalar)}; }), + arg("input")) + .def(init([](int8_t scalar) { return Value{Tensor(scalar)}; }), + arg("input")) + .def(init([](uint8_t scalar) { return Value{Tensor(scalar)}; }), + arg("input")) + .def(init([](array input) -> Value { + if (input.dtype().kind() == 'i') { + if (input.dtype().itemsize() == 1) { + return Value{::arrayToTensor(input)}; + } + if (input.dtype().itemsize() == 2) { + return Value{::arrayToTensor(input)}; + } + if (input.dtype().itemsize() == 4) { + return Value{::arrayToTensor(input)}; + } + if (input.dtype().itemsize() == 8) { + return Value{::arrayToTensor(input)}; + } } + if (input.dtype().kind() == 'u') { + if (input.dtype().itemsize() == 1) { + return Value{::arrayToTensor(input)}; + } + if (input.dtype().itemsize() == 2) { + return Value{::arrayToTensor(input)}; + } + if (input.dtype().itemsize() == 4) { + return Value{::arrayToTensor(input)}; + } + if (input.dtype().itemsize() == 8) { + return Value{::arrayToTensor(input)}; + } + } + throw std::runtime_error( + "Values can only be constructed from arrays " + "of signed and unsigned integers."); + }), + arg("input")) + .def(init([](pybind11::object object) -> Value { + std::string message = + "Failed to create value from input. Incompatible type: "; + message.append(std::string{pybind11::str(object.get_type())}); + message.append(" . Expected int or np.ndarray."); + throw std::runtime_error(message); + }), + arg("input")) + .def( + "to_py_val", + [](Value &value) -> PyValType { + if (value.isScalar()) { + if (value.hasElementType()) { + return {value.getTensor()->values[0]}; + } + if (value.hasElementType()) { + return {value.getTensor()->values[0]}; + } + if (value.hasElementType()) { + return {value.getTensor()->values[0]}; + } + if (value.hasElementType()) { + return {value.getTensor()->values[0]}; + } + if (value.hasElementType()) { + return {value.getTensor()->values[0]}; + } + if (value.hasElementType()) { + return {value.getTensor()->values[0]}; + } + if (value.hasElementType()) { + return {value.getTensor()->values[0]}; + } + if (value.hasElementType()) { + return {value.getTensor()->values[0]}; + } + } else { + if (value.hasElementType()) { + return {tensorToArray(value.getTensor().value())}; + } + if (value.hasElementType()) { + return {tensorToArray(value.getTensor().value())}; + } + if (value.hasElementType()) { + return {tensorToArray(value.getTensor().value())}; + } + if (value.hasElementType()) { + return {tensorToArray(value.getTensor().value())}; + } + if (value.hasElementType()) { + return {tensorToArray(value.getTensor().value())}; + } + if (value.hasElementType()) { + return {tensorToArray(value.getTensor().value())}; + } + if (value.hasElementType()) { + return {tensorToArray(value.getTensor().value())}; + } + if (value.hasElementType()) { + return {tensorToArray(value.getTensor().value())}; + } + } + throw std::invalid_argument("Value has insupported scalar type."); + }, + "Return the inner value as a python type.") + .def( + "is_scalar", [](Value &val) { return val.isScalar(); }, + "Return whether the value is a scalar.") + .def( + "is_tensor", [](Value &val) { return !val.isScalar(); }, + "Return whether the value is a tensor.") + .def( + "get_shape", [](Value &val) { return val.getDimensions(); }, + "Return the shape of the value.") + .doc() = "Private / Runtime value."; - return lambdaArgument{ - std::make_shared( - mlir::concretelang::LambdaArgument{result.value()})}; - }); + // ------------------------------------------------------------------------------// + // SERVER CIRCUIT // + // ------------------------------------------------------------------------------// - pybind11::class_<::concretelang::clientlib::SimulatedValueDecrypter>( - m, "SimulatedValueDecrypter") - .def_static( - "create", - [](::concretelang::clientlib::ClientParameters &clientParameters, - const std::string &circuitName) { - return createSimulatedValueDecrypter(clientParameters, circuitName); - }) - .def("decrypt", - [](::concretelang::clientlib::SimulatedValueDecrypter &decrypter, - size_t position, - ::concretelang::clientlib::SharedScalarOrTensorData &value) { - SignalGuard signalGuard; - - auto result = - decrypter.circuit.processOutput(value.value, position); - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); - } + pybind11::class_(m, "ServerCircuit") + .def( + "call", + [](ServerCircuit &circuit, std::vector args, + ServerKeyset keyset) { + SignalGuard signalGuard; + pybind11::gil_scoped_release release; + GET_OR_THROW_RESULT(auto output, circuit.call(keyset, args)); + return output; + }, + "Perform circuit call with `args` arguments using the `keyset` " + "ServerKeyset.", + arg("args"), arg("keyset")) + .def( + "simulate", + [](ServerCircuit &circuit, std::vector &args) { + pybind11::gil_scoped_release release; + GET_OR_THROW_RESULT(auto output, circuit.simulate(args)); + return output; + }, + "Perform circuit simulation with `args` arguments.", arg("args")) + .doc() = "Server-side / Evaluation circuit."; - return lambdaArgument{ - std::make_shared( - mlir::concretelang::LambdaArgument{result.value()})}; - }); + // ------------------------------------------------------------------------------// + // SERVER PROGRAM // + // ------------------------------------------------------------------------------// - pybind11::class_<::concretelang::clientlib::PublicArguments, - std::unique_ptr<::concretelang::clientlib::PublicArguments>>( - m, "PublicArguments") - .def_static( - "create", - [](const ::concretelang::clientlib::ClientParameters - &clientParameters, - std::vector<::concretelang::clientlib::SharedScalarOrTensorData> - &buffers) { - std::vector vals; - for (auto buf : buffers) { - vals.push_back(buf.value); + pybind11::class_(m, "ServerProgram") + .def(init([](Library &library, bool useSimulation) { + auto sharedLibPath = library.getSharedLibraryPath(); + + GET_OR_THROW_RESULT(auto pi, library.getProgramInfo()); + GET_OR_THROW_RESULT( + auto result, ServerProgram::load(pi.asReader(), sharedLibPath, + useSimulation)); + return result; + }), + arg("library"), arg("use_simulation")) + .def( + "get_server_circuit", + [](ServerProgram &program, const std::string &circuitName) { + GET_OR_THROW_RESULT(auto result, + program.getServerCircuit(circuitName)); + return result; + }, + "Return the `circuit` ServerCircuit.", arg("circuit")) + .doc() = "Server-side / Evaluation program."; + + // ------------------------------------------------------------------------------// + // CLIENT CIRCUIT // + // ------------------------------------------------------------------------------// + + pybind11::class_(m, "ClientCircuit") + .def( + "prepare_input", + [](ClientCircuit &circuit, Value arg, size_t pos) { + if (pos > circuit.getCircuitInfo().asReader().getInputs().size()) { + throw std::runtime_error("Unknown position."); } - return ::concretelang::clientlib::PublicArguments{vals}; - }) + auto info = circuit.getCircuitInfo().asReader().getInputs()[pos]; + auto typeTransformer = getPythonTypeTransformer(info); + GET_OR_THROW_RESULT( + auto ok, circuit.prepareInput(typeTransformer(arg), pos)); + return ok; + }, + "Prepare a `pos` positional arguments `arg` to be sent to server. ", + arg("arg"), arg("pos")) + .def( + "process_output", + [](ClientCircuit &circuit, TransportValue result, size_t pos) { + GET_OR_THROW_RESULT(auto ok, circuit.processOutput(result, pos)); + return ok; + }, + "Process a `pos` positional result `result` retrieved from server. ", + arg("result"), arg("pos")) + .def( + "simulate_prepare_input", + [](ClientCircuit &circuit, Value arg, size_t pos) { + if (pos > circuit.getCircuitInfo().asReader().getInputs().size()) { + throw std::runtime_error("Unknown position."); + } + auto info = circuit.getCircuitInfo().asReader().getInputs()[pos]; + auto typeTransformer = getPythonTypeTransformer(info); + GET_OR_THROW_RESULT(auto ok, circuit.simulatePrepareInput( + typeTransformer(arg), pos)); + return ok; + }, + "SIMULATE preparation of `pos` positional argument `arg` to be sent " + "to server. DOES NOT NCRYPT.", + arg("arg"), arg("pos")) + .def( + "simulate_process_output", + [](ClientCircuit &circuit, TransportValue result, size_t pos) { + GET_OR_THROW_RESULT(auto ok, + circuit.simulateProcessOutput(result, pos)); + return ok; + }, + "SIMULATE processing of `pos` positional result `result` retrieved " + "from server.", + arg("result"), arg("pos")) + .doc() = "Client-side / Encryption circuit."; + + // ------------------------------------------------------------------------------// + // CLIENT PROGRAM // + // ------------------------------------------------------------------------------// + + pybind11::class_(m, "ClientProgram") .def_static( - "deserialize", - [](::concretelang::clientlib::ClientParameters &clientParameters, - const pybind11::bytes &buffer) { - return publicArgumentsUnserialize(clientParameters, buffer); - }) - .def("serialize", - [](::concretelang::clientlib::PublicArguments &publicArgument) { - return pybind11::bytes(publicArgumentsSerialize(publicArgument)); - }); - pybind11::class_<::concretelang::clientlib::PublicResult>(m, "PublicResult") + "create_encrypted", + [](ProgramInfo programInfo, Keyset keyset) { + GET_OR_THROW_RESULT( + auto clientProgram, + ClientProgram::createEncrypted( + programInfo.programInfo, keyset.client, + std::make_shared(EncryptionCSPRNG(0)))); + return clientProgram; + }, + "Create an encrypted (as opposed to simulated) ClientProgram.", + arg("program_info"), arg("keyset")) .def_static( - "deserialize", - [](::concretelang::clientlib::ClientParameters &clientParameters, - const pybind11::bytes &buffer) { - return publicResultUnserialize(clientParameters, buffer); - }) - .def("serialize", - [](::concretelang::clientlib::PublicResult &publicResult) { - return pybind11::bytes(publicResultSerialize(publicResult)); - }) - .def("n_values", - [](const ::concretelang::clientlib::PublicResult &publicResult) { - return publicResult.values.size(); - }) - .def("get_value", - [](::concretelang::clientlib::PublicResult &publicResult, - size_t position) { - if (position >= publicResult.values.size()) { - throw std::runtime_error("Failed to get public result value."); - } - return ::concretelang::clientlib::SharedScalarOrTensorData{ - publicResult.values[position]}; - }); + "create_simulated", + [](ProgramInfo &programInfo) { + GET_OR_THROW_RESULT( + auto clientProgram, + ClientProgram::createSimulated( + programInfo.programInfo, + std::make_shared(EncryptionCSPRNG(0)))); + return clientProgram; + }, + "Create a simulated (as opposed to encrypted) ClientProgram. DOES " + "NOT PERFORM ENCRYPTION OF VALUES.", + arg("program_info")) + .def( + "get_client_circuit", + [](ClientProgram &program, + const std::string &circuitName) -> ClientCircuit { + GET_OR_THROW_RESULT(auto result, + program.getClientCircuit(circuitName)); + return result; + }, + "Return the `circuit` ClientCircuit.", arg("circuit")) + .doc() = "Client-side / Encryption program"; - pybind11::class_<::concretelang::clientlib::EvaluationKeys>(m, - "EvaluationKeys") - .def_static("deserialize", - [](const pybind11::bytes &buffer) { - return evaluationKeysUnserialize(buffer); - }) - .def("serialize", - [](::concretelang::clientlib::EvaluationKeys &evaluationKeys) { - return pybind11::bytes(evaluationKeysSerialize(evaluationKeys)); - }); + m.def("import_tfhers_fheuint8", + [](const pybind11::bytes &serialized_fheuint, + TfhersFheIntDescription info, uint32_t encryptionKeyId, + double encryptionVariance) { + const std::string &buffer_str = serialized_fheuint; + std::vector buffer(buffer_str.begin(), buffer_str.end()); + auto arrayRef = llvm::ArrayRef(buffer); + auto valueOrError = ::concretelang::clientlib::importTfhersFheUint8( + arrayRef, info, encryptionKeyId, encryptionVariance); + if (valueOrError.has_error()) { + throw std::runtime_error(valueOrError.error().mesg); + } + return TransportValue{valueOrError.value()}; + }); - pybind11::class_(m, "LambdaArgument") - .def_static("from_tensor_u8", - [](std::vector tensor, std::vector dims) { - return lambdaArgumentFromTensorU8(tensor, dims); - }) - .def_static("from_tensor_u16", - [](std::vector tensor, std::vector dims) { - return lambdaArgumentFromTensorU16(tensor, dims); - }) - .def_static("from_tensor_u32", - [](std::vector tensor, std::vector dims) { - return lambdaArgumentFromTensorU32(tensor, dims); - }) - .def_static("from_tensor_u64", - [](std::vector tensor, std::vector dims) { - return lambdaArgumentFromTensorU64(tensor, dims); - }) - .def_static("from_tensor_i8", - [](std::vector tensor, std::vector dims) { - return lambdaArgumentFromTensorI8(tensor, dims); - }) - .def_static("from_tensor_i16", - [](std::vector tensor, std::vector dims) { - return lambdaArgumentFromTensorI16(tensor, dims); - }) - .def_static("from_tensor_i32", - [](std::vector tensor, std::vector dims) { - return lambdaArgumentFromTensorI32(tensor, dims); - }) - .def_static("from_tensor_i64", - [](std::vector tensor, std::vector dims) { - return lambdaArgumentFromTensorI64(tensor, dims); - }) - .def_static("from_scalar", lambdaArgumentFromScalar) - .def_static("from_signed_scalar", lambdaArgumentFromSignedScalar) - .def("is_tensor", - [](lambdaArgument &lambda_arg) { - return lambdaArgumentIsTensor(lambda_arg); - }) - .def("get_tensor_data", - [](lambdaArgument &lambda_arg) { - return lambdaArgumentGetTensorData(lambda_arg); - }) - .def("get_signed_tensor_data", - [](lambdaArgument &lambda_arg) { - return lambdaArgumentGetSignedTensorData(lambda_arg); - }) - .def("get_tensor_shape", - [](lambdaArgument &lambda_arg) { - return lambdaArgumentGetTensorDimensions(lambda_arg); - }) - .def("is_scalar", - [](lambdaArgument &lambda_arg) { - return lambdaArgumentIsScalar(lambda_arg); - }) - .def("is_signed", - [](lambdaArgument &lambda_arg) { - return lambdaArgumentIsSigned(lambda_arg); - }) - .def("get_scalar", - [](lambdaArgument &lambda_arg) { - return lambdaArgumentGetScalar(lambda_arg); - }) - .def("get_signed_scalar", [](lambdaArgument &lambda_arg) { - return lambdaArgumentGetSignedScalar(lambda_arg); - }); + m.def("export_tfhers_fheuint8", + [](TransportValue fheuint, TfhersFheIntDescription info) { + auto result = + ::concretelang::clientlib::exportTfhersFheUint8(fheuint, info); + if (result.has_error()) { + throw std::runtime_error(result.error().mesg); + } + return result.value(); + }); + + m.def("get_tfhers_fheuint8_description", + [](const pybind11::bytes &serialized_fheuint) { + const std::string &buffer_str = serialized_fheuint; + std::vector buffer(buffer_str.begin(), buffer_str.end()); + auto arrayRef = llvm::ArrayRef(buffer); + auto info = + ::concretelang::clientlib::getTfhersFheUint8Description(arrayRef); + if (info.has_error()) { + throw std::runtime_error(info.error().mesg); + } + return info.value(); + }); } diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index db3e51c05e..ca8dd267e4 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -3,48 +3,58 @@ """Compiler submodule.""" import atexit +from typing import Union # pylint: disable=no-name-in-module,import-error from mlir._mlir_libs._concretelang._compiler import ( + LweSecretKeyParam, + BootstrapKeyParam, + KeyswitchKeyParam, + PackingKeyswitchKeyParam, + ProgramInfo, + CompilationOptions, + LweSecretKey, + KeysetCache, + ServerKeyset, + Keyset, + Compiler, + TransportValue, + Value, + ServerProgram, + ServerCircuit, + ClientProgram, + ClientCircuit, + Backend, + KeyType, + OptimizerMultiParameterStrategy, + OptimizerStrategy, + PrimitiveOperation, + Library, + ProgramCompilationFeedback, + CircuitCompilationFeedback, terminate_df_parallelization as _terminate_df_parallelization, init_df_parallelization as _init_df_parallelization, check_gpu_runtime_enabled as _check_gpu_runtime_enabled, check_cuda_device_available as _check_cuda_device_available, -) -from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip -from mlir._mlir_libs._concretelang._compiler import ( + round_trip as _round_trip, set_llvm_debug_flag, set_compiler_logging, ) # pylint: enable=no-name-in-module,import-error -from .compilation_options import CompilationOptions, Encoding +from .utils import lookup_runtime_lib +from .compilation_feedback import MoreCircuitCompilationFeedback from .compilation_context import CompilationContext -from .key_set_cache import KeySetCache -from .client_parameters import ClientParameters -from .compilation_feedback import ProgramCompilationFeedback, CircuitCompilationFeedback -from .key_set import KeySet -from .public_result import PublicResult -from .public_arguments import PublicArguments -from .lambda_argument import LambdaArgument -from .library_compilation_result import LibraryCompilationResult -from .library_lambda import LibraryLambda -from .client_support import ClientSupport -from .library_support import LibrarySupport -from .lwe_secret_key import LweSecretKey, LweSecretKeyParam -from .evaluation_keys import EvaluationKeys + from .tfhers_int import ( TfhersExporter, TfhersFheIntDescription, ) -from .value import Value -from .value_decrypter import ValueDecrypter -from .value_exporter import ValueExporter -from .simulated_value_decrypter import SimulatedValueDecrypter -from .simulated_value_exporter import SimulatedValueExporter -from .parameter import Parameter -from .server_program import ServerProgram + +Parameter = Union[ + LweSecretKeyParam, BootstrapKeyParam, KeyswitchKeyParam, PackingKeyswitchKeyParam +] def init_dfr(): diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py deleted file mode 100644 index bf2fa1659a..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py +++ /dev/null @@ -1,148 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""Client parameters.""" - -from typing import List - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - ClientParameters as _ClientParameters, -) - -# pylint: enable=no-name-in-module,import-error - -from .lwe_secret_key import LweSecretKeyParam -from .wrapper import WrapperCpp - - -class ClientParameters(WrapperCpp): - """ClientParameters are public parameters used for key generation. - - It's a compilation artifact that describes which and how public and private keys should be generated, - and used to encrypt arguments of the compiled function. - """ - - def __init__(self, client_parameters: _ClientParameters): - """Wrap the native Cpp object. - - Args: - client_parameters (_ClientParameters): object to wrap - - Raises: - TypeError: if client_parameters is not of type _ClientParameters - """ - if not isinstance(client_parameters, _ClientParameters): - raise TypeError( - f"client_parameters must be of type _ClientParameters, not {type(client_parameters)}" - ) - super().__init__(client_parameters) - - def lwe_secret_key_param(self, key_id: int) -> LweSecretKeyParam: - """Get the parameters of a selected LWE secret key. - - Args: - key_id (int): keyid to get parameters from - - Raises: - TypeError: if arguments aren't of expected types - - Returns: - LweSecretKeyParam: LWE secret key parameters - """ - if not isinstance(key_id, int): - raise TypeError(f"key_id must be of type int, not {type(key_id)}") - return LweSecretKeyParam.wrap(self.cpp().lwe_secret_key_param(key_id)) - - def input_keyid_at(self, input_idx: int, circuit_name: str = "") -> int: - """Get the keyid of a selected encrypted input in a given circuit. - - Args: - input_idx (int): index of the input in the circuit. - circuit_name (str): name of the circuit containing the desired input. - - Raises: - TypeError: if arguments aren't of expected types - - Returns: - int: keyid - """ - if not isinstance(input_idx, int): - raise TypeError(f"input_idx must be of type int, not {type(input_idx)}") - if not isinstance(circuit_name, str): - raise TypeError( - f"circuit_name must be of type str, not {type(circuit_name)}" - ) - return self.cpp().input_keyid_at(input_idx, circuit_name) - - def input_variance_at(self, input_idx: int, circuit_name: str) -> float: - """Get the variance of a selected encrypted input in a given circuit. - - Args: - input_idx (int): index of the input in the circuit. - circuit_name (str): name of the circuit containing the desired input. - - Raises: - TypeError: if arguments aren't of expected types - - Returns: - float: variance - """ - if not isinstance(input_idx, int): - raise TypeError(f"input_idx must be of type int, not {type(input_idx)}") - if not isinstance(circuit_name, str): - raise TypeError( - f"circuit_name must be of type str, not {type(circuit_name)}" - ) - return self.cpp().input_variance_at(input_idx, circuit_name) - - def input_signs(self) -> List[bool]: - """Return the sign information of inputs. - - Returns: - List[bool]: list of booleans to indicate whether the inputs are signed or not - """ - return self.cpp().input_signs() - - def output_signs(self) -> List[bool]: - """Return the sign information of outputs. - - Returns: - List[bool]: list of booleans to indicate whether the outputs are signed or not - """ - return self.cpp().output_signs() - - def function_list(self) -> List[str]: - """Return the list of function names. - - Returns: - List[str]: list of the names of the functions. - """ - return self.cpp().function_list() - - def serialize(self) -> bytes: - """Serialize the ClientParameters. - - Returns: - bytes: serialized object - """ - return self.cpp().serialize() - - @staticmethod - def deserialize(serialized_params: bytes) -> "ClientParameters": - """Unserialize ClientParameters from bytes of serialized_params. - - Args: - serialized_params (bytes): previously serialized ClientParameters - - Raises: - TypeError: if serialized_params is not of type bytes - - Returns: - ClientParameters: deserialized object - """ - if not isinstance(serialized_params, bytes): - raise TypeError( - f"serialized_params must be of type bytes, not {type(serialized_params)}" - ) - return ClientParameters.wrap(_ClientParameters.deserialize(serialized_params)) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py deleted file mode 100644 index a4fc7822aa..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_support.py +++ /dev/null @@ -1,322 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""Client support.""" -from typing import List, Optional, Union, Dict -import numpy as np - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport - -# pylint: enable=no-name-in-module,import-error - -from .public_result import PublicResult -from .key_set import KeySet -from .key_set_cache import KeySetCache -from .client_parameters import ClientParameters -from .public_arguments import PublicArguments -from .lambda_argument import LambdaArgument -from .lwe_secret_key import LweSecretKey -from .wrapper import WrapperCpp -from .utils import ACCEPTED_INTS, ACCEPTED_NUMPY_UINTS, ACCEPTED_TYPES - - -class ClientSupport(WrapperCpp): - """Client interface for doing key generation and encryption. - - It provides features that are needed on the client side: - - Generation of public and private keys required for the encrypted computation - - Encryption and preparation of public arguments, used later as input to the computation - - Decryption of public result returned after execution - """ - - def __init__(self, client_support: _ClientSupport): - """Wrap the native Cpp object. - - Args: - client_support (_ClientSupport): object to wrap - - Raises: - TypeError: if client_support is not of type _ClientSupport - """ - if not isinstance(client_support, _ClientSupport): - raise TypeError( - f"client_support must be of type _ClientSupport, not {type(client_support)}" - ) - super().__init__(client_support) - - # pylint: disable=arguments-differ - @staticmethod - def new() -> "ClientSupport": - """Build a ClientSupport. - - Returns: - ClientSupport - """ - return ClientSupport.wrap(_ClientSupport()) - - # pylint: enable=arguments-differ - - @staticmethod - def key_set( - client_parameters: ClientParameters, - keyset_cache: Optional[KeySetCache] = None, - secret_seed: Optional[int] = None, - encryption_seed: Optional[int] = None, - initial_lwe_secret_keys: Optional[Dict[int, LweSecretKey]] = None, - ) -> KeySet: - """Generate a key set according to the client parameters. - - If the cache is set, and include equivalent keys as specified by the client parameters, - the keyset is loaded, otherwise, a new keyset is generated and saved in the cache. - If keygen is required, it will first initialize the secret keys provided, if any. - - Args: - client_parameters (ClientParameters): client parameters specification - keyset_cache (Optional[KeySetCache], optional): keyset cache. Defaults to None. - secret_seed (Optional[int]): secret seed, must be a positive 128 bits integer - encryption_seed (Optional[int]): encryption seed, must be a positive 128 bits integer - initial_lwe_secret_keys (Optional[Dict[int, LweSecretKey]]): keys to init the keyset - with before keygen. It maps keyid to secret key - - Raises: - TypeError: if client_parameters is not of type ClientParameters - TypeError: if keyset_cache is not of type KeySetCache - AssertionError: if seed components is not uint64 - - Returns: - KeySet: generated or loaded keyset - """ - secret_seed = 0 if secret_seed is None else secret_seed - encryption_seed = 0 if encryption_seed is None else encryption_seed - - if secret_seed < 0 or secret_seed >= 2**128: - raise ValueError("secret_seed must be a positive 128 bits integer") - if encryption_seed < 0 or encryption_seed >= 2**128: - raise ValueError("encryption_seed must be a positive 128 bits integer") - secret_seed_msb = (secret_seed >> 64) & 0xFFFFFFFFFFFFFFFF - secret_seed_lsb = (secret_seed) & 0xFFFFFFFFFFFFFFFF - encryption_seed_msb = (encryption_seed >> 64) & 0xFFFFFFFFFFFFFFFF - encryption_seed_lsb = (encryption_seed) & 0xFFFFFFFFFFFFFFFF - - if keyset_cache is not None and not isinstance(keyset_cache, KeySetCache): - raise TypeError( - f"keyset_cache must be None or of type KeySetCache, not {type(keyset_cache)}" - ) - - if initial_lwe_secret_keys is not None: - intial_sks = {i: sk.cpp() for i, sk in initial_lwe_secret_keys.items()} - else: - intial_sks = {} - - cpp_cache = None if keyset_cache is None else keyset_cache.cpp() - return KeySet.wrap( - _ClientSupport.key_set( - client_parameters.cpp(), - cpp_cache, - secret_seed_msb, - secret_seed_lsb, - encryption_seed_msb, - encryption_seed_lsb, - intial_sks, - ), - ) - - @staticmethod - def encrypt_arguments( - client_parameters: ClientParameters, - keyset: KeySet, - args: List[Union[int, np.ndarray]], - circuit_name: str, - ) -> PublicArguments: - """Prepare arguments for encrypted computation. - - Pack public arguments by encrypting the ones that requires encryption, and leaving the rest as plain. - It also pack public materials (public keys) that are required during the computation. - - Args: - client_parameters (ClientParameters): client parameters specification - keyset (KeySet): keyset used to encrypt arguments that require encryption - args (List[Union[int, np.ndarray]]): list of scalar or tensor arguments - circuit_name(str): the name of the circuit for which to encrypt - - Raises: - TypeError: if client_parameters is not of type ClientParameters - TypeError: if keyset is not of type KeySet - TypeError: if circuit_name is not of type str - - Returns: - PublicArguments: public arguments for execution - """ - if not isinstance(client_parameters, ClientParameters): - raise TypeError( - f"client_parameters must be of type ClientParameters, not {type(client_parameters)}" - ) - if not isinstance(keyset, KeySet): - raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}") - if not isinstance(circuit_name, str): - raise TypeError( - f"circuit_name must be of type str, not {type(circuit_name)}" - ) - - signs = client_parameters.input_signs() - if len(signs) != len(args): - raise RuntimeError( - f"function has arity {len(signs)} but is applied to too many arguments" - ) - - lambda_arguments = [ - ClientSupport._create_lambda_argument(arg, signed) - for arg, signed in zip(args, signs) - ] - return PublicArguments.wrap( - _ClientSupport.encrypt_arguments( - client_parameters.cpp(), - keyset.cpp(), - [arg.cpp() for arg in lambda_arguments], - circuit_name, - ) - ) - - @staticmethod - def decrypt_result( - client_parameters: ClientParameters, - keyset: KeySet, - public_result: PublicResult, - circuit_name: str, - ) -> Union[int, np.ndarray]: - """Decrypt a public result using the keyset. - - Args: - client_parameters (ClientParameters): client parameters for decryption - keyset (KeySet): keyset used for decryption - public_result (PublicResult): public result to decrypt - circuit_name (str): name of the circuit for which to decrypt - - Raises: - TypeError: if keyset is not of type KeySet - TypeError: if public_result is not of type PublicResult - TypeError: if circuit_name is not of type str - RuntimeError: if the result is of an unknown type - - Returns: - Union[int, np.ndarray]: plain result - """ - if not isinstance(keyset, KeySet): - raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}") - if not isinstance(public_result, PublicResult): - raise TypeError( - f"public_result must be of type PublicResult, not {type(public_result)}" - ) - if not isinstance(circuit_name, str): - raise TypeError( - f"circuit_name must be of type str, not {type(circuit_name)}" - ) - results = _ClientSupport.decrypt_result( - client_parameters.cpp(), keyset.cpp(), public_result.cpp(), circuit_name - ) - - def process_result(result): - lambda_arg = LambdaArgument.wrap(result) - is_signed = lambda_arg.is_signed() - if lambda_arg.is_scalar(): - return ( - lambda_arg.get_signed_scalar() - if is_signed - else lambda_arg.get_scalar() - ) - - if lambda_arg.is_tensor(): - return np.array( - ( - lambda_arg.get_signed_tensor_data() - if is_signed - else lambda_arg.get_tensor_data() - ), - dtype=(np.int64 if is_signed else np.uint64), - ).reshape(lambda_arg.get_tensor_shape()) - - raise RuntimeError("unknown return type") - - processed_results = tuple(map(process_result, results)) - if len(processed_results) == 1: - return processed_results[0] - return processed_results - - @staticmethod - def _create_lambda_argument( - value: Union[int, np.ndarray], signed: bool - ) -> LambdaArgument: - """Create a lambda argument holding either an int or tensor value. - - Args: - value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array - signed (bool): whether the value is signed - - Raises: - TypeError: if the values aren't in the expected range, or using a wrong type - - Returns: - LambdaArgument: lambda argument holding the appropriate value - """ - - # pylint: disable=too-many-return-statements,too-many-branches - if not isinstance(value, ACCEPTED_TYPES): - raise TypeError( - "value of lambda argument must be either int, numpy.array or numpy.(u)int{8,16,32,64}" - ) - if isinstance(value, ACCEPTED_INTS): - if ( - isinstance(value, int) - and not np.iinfo(np.int64).min <= value < np.iinfo(np.uint64).max - ): - raise TypeError( - "single integer must be in the range [-2**63, 2**64 - 1]" - ) - if signed: - return LambdaArgument.from_signed_scalar(value) - return LambdaArgument.from_scalar(value) - assert isinstance(value, np.ndarray) - if value.dtype not in ACCEPTED_NUMPY_UINTS: - raise TypeError("numpy.array must be of dtype (u)int{8,16,32,64}") - if value.shape == (): - if isinstance(value, np.ndarray): - # extract the single element - value = value.max() - # should be a single uint here - if signed: - return LambdaArgument.from_signed_scalar(value) - return LambdaArgument.from_scalar(value) - if value.dtype == np.uint8: - return LambdaArgument.from_tensor_u8( - value.flatten().tolist(), list(value.shape) - ) - if value.dtype == np.uint16: - return LambdaArgument.from_tensor_u16( - value.flatten().tolist(), list(value.shape) - ) - if value.dtype == np.uint32: - return LambdaArgument.from_tensor_u32( - value.flatten().tolist(), list(value.shape) - ) - if value.dtype == np.uint64: - return LambdaArgument.from_tensor_u64( - value.flatten().tolist(), list(value.shape) - ) - if value.dtype == np.int8: - return LambdaArgument.from_tensor_i8( - value.flatten().tolist(), list(value.shape) - ) - if value.dtype == np.int16: - return LambdaArgument.from_tensor_i16( - value.flatten().tolist(), list(value.shape) - ) - if value.dtype == np.int32: - return LambdaArgument.from_tensor_i32( - value.flatten().tolist(), list(value.shape) - ) - if value.dtype == np.int64: - return LambdaArgument.from_tensor_i64( - value.flatten().tolist(), list(value.shape) - ) - raise TypeError("numpy.array must be of dtype (u)int{8,16,32,64}") diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_context.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_context.py index d3e8384033..d79d1b88d1 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_context.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_context.py @@ -6,60 +6,27 @@ CompilationContext holds the MLIR Context supposed to be used during IR generation. """ -# pylint: disable=no-name-in-module,import-error +# pylint: disable=no-name-in-module,import-error,too-many-instance-attributes,protected-access from mlir._mlir_libs._concretelang._compiler import ( CompilationContext as _CompilationContext, ) from mlir.ir import Context as MlirContext -# pylint: enable=no-name-in-module,import-error -from .wrapper import WrapperCpp - -class CompilationContext(WrapperCpp): - """Support class for compilation context. - - CompilationContext is meant to outlive mlir_context(). - Do not use the mlir_context after deleting the CompilationContext. +class CompilationContext(_CompilationContext): + """ + Compilation context. """ - def __init__(self, compilation_context: _CompilationContext): - """Wrap the native Cpp object. - - Args: - compilation_context (_CompilationContext): object to wrap - - Raises: - TypeError: if compilation_context is not of type _CompilationContext - """ - if not isinstance(compilation_context, _CompilationContext): - raise TypeError( - f"compilation_context must be of type _CompilationContext, not " - f"{type(compilation_context)}" - ) - super().__init__(compilation_context) - - # pylint: disable=arguments-differ @staticmethod def new() -> "CompilationContext": - """Build a CompilationContext. - - Returns: - CompilationContext """ - return CompilationContext.wrap(_CompilationContext()) - - def mlir_context( - self, - ) -> MlirContext: + Creates a new CompilationContext. """ - Get the MLIR context used by the compilation context. + return CompilationContext() - The Compilation Context should outlive the mlir_context. - - Returns: - MlirContext: MLIR context of the compilation context + def mlir_context(self) -> MlirContext: + """ + Returns the associated mlir context. """ - # pylint: disable=protected-access - return MlirContext._CAPICreate(self.cpp().mlir_context()) - # pylint: enable=protected-access + return MlirContext._CAPICreate(_CompilationContext.mlir_context(self)) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py index f073150ccf..294c36e118 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py @@ -8,19 +8,14 @@ # pylint: disable=no-name-in-module,import-error,too-many-instance-attributes from mlir._mlir_libs._concretelang._compiler import ( - ProgramCompilationFeedback as _ProgramCompilationFeedback, - CircuitCompilationFeedback as _CircuitCompilationFeedback, + CircuitCompilationFeedback, KeyType, PrimitiveOperation, + ProgramInfo, ) # pylint: enable=no-name-in-module,import-error -from .client_parameters import ClientParameters -from .parameter import Parameter -from .wrapper import WrapperCpp - - # matches (@tag, separator( | ), filename) REGEX_LOCATION = re.compile(r"loc\(\"(@[\w\.]+)?( \| )?(.+)\"") @@ -40,38 +35,17 @@ def tag_from_location(location): return tag -class CircuitCompilationFeedback(WrapperCpp): - """CircuitCompilationFeedback is a set of hint computed by the compiler engine for a circuit.""" - - def __init__(self, circuit_compilation_feedback: _CircuitCompilationFeedback): - """Wrap the native Cpp object. - - Args: - circuit_compilation_feeback (_CircuitCompilationFeedback): object to wrap - - Raises: - TypeError: if circuit_compilation_feedback is not of type _CircuitCompilationFeedback - """ - if not isinstance(circuit_compilation_feedback, _CircuitCompilationFeedback): - raise TypeError( - "circuit_compilation_feedback must be of type " - f"_CircuitCompilationFeedback, not {type(circuit_compilation_feedback)}" - ) - - self.name = circuit_compilation_feedback.name - self.total_inputs_size = circuit_compilation_feedback.total_inputs_size - self.total_output_size = circuit_compilation_feedback.total_output_size - self.crt_decompositions_of_outputs = ( - circuit_compilation_feedback.crt_decompositions_of_outputs - ) - self.statistics = circuit_compilation_feedback.statistics - self.memory_usage_per_location = ( - circuit_compilation_feedback.memory_usage_per_location - ) - - super().__init__(circuit_compilation_feedback) +class MoreCircuitCompilationFeedback: + """ + Helper class for compilation feedback. + """ - def count(self, *, operations: Set[PrimitiveOperation]) -> int: + @staticmethod + def count( + circuit_feedback: CircuitCompilationFeedback, + *, + operations: Set[PrimitiveOperation], + ) -> int: """ Count the amount of specified operations in the program. @@ -86,17 +60,18 @@ def count(self, *, operations: Set[PrimitiveOperation]) -> int: return sum( statistic.count - for statistic in self.statistics + for statistic in circuit_feedback.statistics if statistic.operation in operations ) + @staticmethod def count_per_parameter( - self, + circuit_feedback: CircuitCompilationFeedback, *, operations: Set[PrimitiveOperation], key_types: Set[KeyType], - client_parameters: ClientParameters, - ) -> Dict[Parameter, int]: + program_info: ProgramInfo, + ) -> Dict["Parameter", int]: """ Count the amount of specified operations in the program and group by parameters. @@ -107,8 +82,8 @@ def count_per_parameter( key_types (Set[KeyType]): set of key types used to filter the statistics - client_parameters (ClientParameters): - client parameters required for grouping by parameters + program_info (ProgramInfo): + program info required for grouping by parameters Returns: Dict[Parameter, int]: @@ -116,7 +91,7 @@ def count_per_parameter( """ result = {} - for statistic in self.statistics: + for statistic in circuit_feedback.statistics: if statistic.operation not in operations: continue @@ -124,14 +99,28 @@ def count_per_parameter( if key_type not in key_types: continue - parameter = Parameter(client_parameters, key_type, key_index) + if key_type == KeyType.SECRET: + parameter = program_info.secret_keys()[key_index] + elif key_type == KeyType.BOOTSTRAP: + parameter = program_info.bootstrap_keys()[key_index] + elif key_type == KeyType.KEY_SWITCH: + parameter = program_info.keyswitch_keys()[key_index] + elif key_type == KeyType.PACKING_KEY_SWITCH: + parameter = program_info.packing_keyswitch_keys()[key_index] + else: + assert False if parameter not in result: result[parameter] = 0 result[parameter] += statistic.count return result - def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int]: + @staticmethod + def count_per_tag( + circuit_feedback: CircuitCompilationFeedback, + *, + operations: Set[PrimitiveOperation], + ) -> Dict[str, int]: """ Count the amount of specified operations in the program and group by tags. @@ -145,7 +134,7 @@ def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int """ result = {} - for statistic in self.statistics: + for statistic in circuit_feedback.statistics: if statistic.operation not in operations: continue @@ -164,13 +153,15 @@ def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int return result + @staticmethod + # pylint: disable=too-many-branches def count_per_tag_per_parameter( - self, + circuit_feedback: CircuitCompilationFeedback, *, operations: Set[PrimitiveOperation], key_types: Set[KeyType], - client_parameters: ClientParameters, - ) -> Dict[str, Dict[Parameter, int]]: + program_info: ProgramInfo, + ) -> Dict[str, Dict["Parameter", int]]: """ Count the amount of specified operations in the program and group by tags and parameters. @@ -181,8 +172,8 @@ def count_per_tag_per_parameter( key_types (Set[KeyType]): set of key types used to filter the statistics - client_parameters (ClientParameters): - client parameters required for grouping by parameters + program_info (ProgramInfo): + program info required for grouping by parameters Returns: Dict[str, Dict[Parameter, int]]: @@ -190,7 +181,7 @@ def count_per_tag_per_parameter( """ result: Dict[str, Dict[int, int]] = {} - for statistic in self.statistics: + for statistic in circuit_feedback.statistics: if statistic.operation not in operations: continue @@ -209,74 +200,19 @@ def count_per_tag_per_parameter( if key_type not in key_types: continue - parameter = Parameter(client_parameters, key_type, key_index) + if key_type == KeyType.SECRET: + parameter = program_info.secret_keys()[key_index] + elif key_type == KeyType.BOOTSTRAP: + parameter = program_info.bootstrap_keys()[key_index] + elif key_type == KeyType.KEY_SWITCH: + parameter = program_info.keyswitch_keys()[key_index] + elif key_type == KeyType.PACKING_KEY_SWITCH: + parameter = program_info.packing_keyswitch_keys()[key_index] + else: + assert False + if parameter not in result[current_tag]: result[current_tag][parameter] = 0 result[current_tag][parameter] += statistic.count return result - - -class ProgramCompilationFeedback(WrapperCpp): - """CompilationFeedback is a set of hint computed by the compiler engine.""" - - def __init__(self, program_compilation_feedback: _ProgramCompilationFeedback): - """Wrap the native Cpp object. - - Args: - compilation_feeback (_CompilationFeedback): object to wrap - - Raises: - TypeError: if program_compilation_feedback is not of type _CompilationFeedback - """ - if not isinstance(program_compilation_feedback, _ProgramCompilationFeedback): - raise TypeError( - "program_compilation_feedback must be of type " - f"_CompilationFeedback, not {type(program_compilation_feedback)}" - ) - - self.complexity = program_compilation_feedback.complexity - self.p_error = program_compilation_feedback.p_error - self.global_p_error = program_compilation_feedback.global_p_error - self.total_secret_keys_size = ( - program_compilation_feedback.total_secret_keys_size - ) - self.total_bootstrap_keys_size = ( - program_compilation_feedback.total_bootstrap_keys_size - ) - self.total_keyswitch_keys_size = ( - program_compilation_feedback.total_keyswitch_keys_size - ) - self.circuit_feedbacks = [ - CircuitCompilationFeedback(c) - for c in program_compilation_feedback.circuit_feedbacks - ] - - super().__init__(program_compilation_feedback) - - def circuit(self, circuit_name: str) -> CircuitCompilationFeedback: - """ - Returns the feedback for the circuit circuit_name. - - Args: - circuit_name (str): - the name of the circuit. - - Returns: - CircuitCompilationFeedback: - the feedback for the circuit. - - Raises: - TypeError: if the circuit_name is not a string - ValueError: if there is no circuit with name circuit_name - """ - if not isinstance(circuit_name, str): - raise TypeError( - f"circuit_name must be of type str, not {type(circuit_name)}" - ) - - for circuit in self.circuit_feedbacks: - if circuit.name == circuit_name: - return circuit - - raise ValueError(f"no circuit with name {circuit_name} found") diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py deleted file mode 100644 index 63e73c27ca..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py +++ /dev/null @@ -1,522 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""CompilationOptions.""" - -from typing import List - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - CompilationOptions as _CompilationOptions, - OptimizerStrategy as _OptimizerStrategy, - OptimizerMultiParameterStrategy as _OptimizerMultiParameterStrategy, - Encoding, - Backend as _Backend, -) -from .wrapper import WrapperCpp - - -# pylint: enable=no-name-in-module,import-error - - -class CompilationOptions(WrapperCpp): - """CompilationOptions holds different flags and options of the compilation process. - - It controls different parallelization flags, diagnostic verification, and also the name of entrypoint - function. - """ - - def __init__(self, compilation_options: _CompilationOptions): - """Wrap the native Cpp object. - - Args: - compilation_options (_CompilationOptions): object to wrap - - Raises: - TypeError: if compilation_options is not of type _CompilationOptions - """ - if not isinstance(compilation_options, _CompilationOptions): - raise TypeError( - f"_compilation_options must be of type _CompilationOptions, not {type(compilation_options)}" - ) - super().__init__(compilation_options) - - @staticmethod - # pylint: disable=arguments-differ - def new(backend=_Backend.CPU) -> "CompilationOptions": - """Build a CompilationOptions. - - Args: - backend (_Backend): backend to use. - - Raises: - TypeError: if function_name is not an str - - Returns: - CompilationOptions - """ - if not isinstance(backend, _Backend): - raise TypeError(f"backend must be of type Backend not {type(backend)}") - return CompilationOptions.wrap(_CompilationOptions(backend)) - - # pylint: enable=arguments-differ - - def add_composition(self, from_func: str, from_pos: int, to_func: str, to_pos: int): - """Adds a composition rule. - - Args: - from_func(str): the name of the circuit the output comes from. - from_pos(int): the return position of the output. - to_func(str): the name of the circuit the input targets. - to_pos(int): the argument position of the input. - - Raises: - TypeError: if the inputs do not have the proper type. - """ - if not isinstance(from_func, str): - raise TypeError("expected `from_func` to be (str)") - if not isinstance(from_pos, int): - raise TypeError("expected `from_pos` to be (int)") - if not isinstance(to_func, str): - raise TypeError("expected `to_func` to be (str)") - if not isinstance(from_pos, int): - raise TypeError("expected `to_pos` to be (int)") - self.cpp().add_composition(from_func, from_pos, to_func, to_pos) - - def set_composable(self, composable: bool): - """Set composable flag. - - Args: - composable(bool): the composable flag. - - Raises: - TypeError: if the inputs do not have the proper type. - """ - if not isinstance(composable, bool): - raise TypeError("expected `composable` to be (bool)") - self.cpp().set_composable(composable) - - def set_auto_parallelize(self, auto_parallelize: bool): - """Set option for auto parallelization. - - Args: - auto_parallelize (bool): whether to turn it on or off - - Raises: - TypeError: if the value to set is not boolean - """ - if not isinstance(auto_parallelize, bool): - raise TypeError("can't set the option to a non-boolean value") - self.cpp().set_auto_parallelize(auto_parallelize) - - def set_loop_parallelize(self, loop_parallelize: bool): - """Set option for loop parallelization. - - Args: - loop_parallelize (bool): whether to turn it on or off - - Raises: - TypeError: if the value to set is not boolean - """ - if not isinstance(loop_parallelize, bool): - raise TypeError("can't set the option to a non-boolean value") - self.cpp().set_loop_parallelize(loop_parallelize) - - def set_compress_evaluation_keys(self, compress_evaluation_keys: bool): - """Set option for compression of evaluation keys. - - Args: - compress_evaluation_keys (bool): whether to turn it on or off - - Raises: - TypeError: if the value to set is not boolean - """ - if not isinstance(compress_evaluation_keys, bool): - raise TypeError("can't set the option to a non-boolean value") - self.cpp().set_compress_evaluation_keys(compress_evaluation_keys) - - def set_compress_input_ciphertexts(self, compress_input_ciphertexts: bool): - """Set option for compression of input ciphertexts. - - Args: - compress_input_ciphertexts (bool): whether to turn it on or off - - Raises: - TypeError: if the value to set is not boolean - """ - if not isinstance(compress_input_ciphertexts, bool): - raise TypeError("can't set the option to a non-boolean value") - self.cpp().set_compress_input_ciphertexts(compress_input_ciphertexts) - - def set_verify_diagnostics(self, verify_diagnostics: bool): - """Set option for diagnostics verification. - - Args: - verify_diagnostics (bool): whether to turn it on or off - - Raises: - TypeError: if the value to set is not boolean - """ - if not isinstance(verify_diagnostics, bool): - raise TypeError("can't set the option to a non-boolean value") - self.cpp().set_verify_diagnostics(verify_diagnostics) - - def set_dataflow_parallelize(self, dataflow_parallelize: bool): - """Set option for dataflow parallelization. - - Args: - dataflow_parallelize (bool): whether to turn it on or off - - Raises: - TypeError: if the value to set is not boolean - """ - if not isinstance(dataflow_parallelize, bool): - raise TypeError("can't set the option to a non-boolean value") - self.cpp().set_dataflow_parallelize(dataflow_parallelize) - - def set_optimize_concrete(self, optimize: bool): - """Set flag to enable/disable optimization of concrete intermediate representation. - - Args: - optimize (bool): whether to turn it on or off - - Raises: - TypeError: if the value to set is not boolean - """ - if not isinstance(optimize, bool): - raise TypeError("can't set the option to a non-boolean value") - self.cpp().set_optimize_concrete(optimize) - - def set_funcname(self, funcname: str): - """Set entrypoint function name. - - Args: - funcname (str): name of the entrypoint function - - Raises: - TypeError: if the value to set is not str - """ - if not isinstance(funcname, str): - raise TypeError("can't set the option to a non-str value") - self.cpp().set_funcname(funcname) - - def set_p_error(self, p_error: float): - """Set error probability for shared by each pbs. - - Args: - p_error (float): probability of error for each lut - - Raises: - TypeError: if the value to set is not float - ValueError: if the value to set is not in interval ]0; 1] - """ - if not isinstance(p_error, float): - raise TypeError("can't set p_error to a non-float value") - if p_error == 0.0: - raise ValueError("p_error cannot be 0") - if not 0.0 <= p_error <= 1.0: - raise ValueError("p_error should be a probability in ]0; 1]") - self.cpp().set_p_error(p_error) - - def set_display_optimizer_choice(self, display: bool): - """Set display flag of optimizer choices. - - Args: - display (bool): if true the compiler display optimizer choices - - Raises: - TypeError: if the value is not a bool - """ - if not isinstance(display, bool): - raise TypeError("display should be a bool") - self.cpp().set_display_optimizer_choice(display) - - def set_optimizer_strategy(self, strategy: _OptimizerStrategy): - """Set the strategy of the optimizer. - - Args: - strategy (OptimizerStrategy): Use the specified optmizer strategy. - - Raises: - TypeError: if the value is not an OptimizerStrategy - """ - if not isinstance(strategy, _OptimizerStrategy): - raise TypeError("enable should be a bool") - self.cpp().set_optimizer_strategy(strategy) - - def set_optimizer_multi_parameter_strategy( - self, strategy: _OptimizerMultiParameterStrategy - ): - """Set the strategy of the optimizer for multi-parameter. - - Args: - strategy (OptimizerMultiParameterStrategy): Use the specified optmizer multi-parameter strategy. - - Raises: - TypeError: if the value is not a OptimizerMultiParameterStrategy - """ - if not isinstance(strategy, _OptimizerMultiParameterStrategy): - raise TypeError("enable should be a bool") - self.cpp().set_optimizer_multi_parameter_strategy(strategy) - - def set_global_p_error(self, global_p_error: float): - """Set global error probability for the full circuit. - - Args: - global_p_error (float): probability of error for the full circuit - - Raises: - TypeError: if the value to set is not float - ValueError: if the value to set is not in interval ]0; 1] - """ - if not isinstance(global_p_error, float): - raise TypeError("can't set global_p_error to a non-float value") - if global_p_error == 0.0: - raise ValueError("global_p_error cannot be 0") - if not 0.0 <= global_p_error <= 1.0: - raise ValueError("global_p_error be a probability in ]0; 1]") - self.cpp().set_global_p_error(global_p_error) - - def set_security_level(self, security_level: int): - """Set security level. - - Args: - security_level (int): the target number of bits of security to compile the circuit - - Raises: - TypeError: if the value to set is not int - ValueError: if the value to set is not in interval ]0; 1] - """ - if not isinstance(security_level, int): - raise TypeError("can't set security_level to a non-int value") - self.cpp().set_security_level(security_level) - - def set_v0_parameter( - self, - glwe_dim: int, - log_poly_size: int, - n_small: int, - br_level: int, - br_log_base: int, - ks_level: int, - ks_log_base: int, - ): - """Set the basic V0 parameters. - - Args: - glwe_dim (int): GLWE dimension - log_poly_size (int): log of polynomial size - n_small (int): n - br_level (int): bootstrap level - br_log_base (int): bootstrap base log - ks_level (int): keyswitch level - ks_log_base (int): keyswitch base log - - Raises: - TypeError: if parameters are not of type int - """ - if not isinstance(glwe_dim, int): - raise TypeError("glwe_dim need to be an integer") - if not isinstance(log_poly_size, int): - raise TypeError("log_poly_size need to be an integer") - if not isinstance(n_small, int): - raise TypeError("n_small need to be an integer") - if not isinstance(br_level, int): - raise TypeError("br_level need to be an integer") - if not isinstance(br_log_base, int): - raise TypeError("br_log_base need to be an integer") - if not isinstance(ks_level, int): - raise TypeError("ks_level need to be an integer") - if not isinstance(ks_log_base, int): - raise TypeError("ks_log_base need to be an integer") - self.cpp().set_v0_parameter( - glwe_dim, - log_poly_size, - n_small, - br_level, - br_log_base, - ks_level, - ks_log_base, - ) - - # pylint: disable=too-many-arguments,too-many-branches - - def set_all_v0_parameter( - self, - glwe_dim: int, - log_poly_size: int, - n_small: int, - br_level: int, - br_log_base: int, - ks_level: int, - ks_log_base: int, - crt_decomp: List[int], - cbs_level: int, - cbs_log_base: int, - pks_level: int, - pks_log_base: int, - pks_input_lwe_dim: int, - pks_output_poly_size: int, - ): - """Set all the V0 parameters. - - Args: - glwe_dim (int): GLWE dimension - log_poly_size (int): log of polynomial size - n_small (int): n - br_level (int): bootstrap level - br_log_base (int): bootstrap base log - ks_level (int): keyswitch level - ks_log_base (int): keyswitch base log - crt_decomp (List[int]): CRT decomposition vector - cbs_level (int): circuit bootstrap level - cbs_log_base (int): circuit bootstrap base log - pks_level (int): packing keyswitch level - pks_log_base (int): packing keyswitch base log - pks_input_lwe_dim (int): packing keyswitch input LWE dimension - pks_output_poly_size (int): packing keyswitch output polynomial size - - Raises: - TypeError: if parameters are not of type int - """ - if not isinstance(glwe_dim, int): - raise TypeError("glwe_dim need to be an integer") - if not isinstance(log_poly_size, int): - raise TypeError("log_poly_size need to be an integer") - if not isinstance(n_small, int): - raise TypeError("n_small need to be an integer") - if not isinstance(br_level, int): - raise TypeError("br_level need to be an integer") - if not isinstance(br_log_base, int): - raise TypeError("br_log_base need to be an integer") - if not isinstance(ks_level, int): - raise TypeError("ks_level need to be an integer") - if not isinstance(ks_log_base, int): - raise TypeError("ks_log_base need to be an integer") - if not isinstance(crt_decomp, list): - raise TypeError("crt_decomp need to be a list of integers") - if not isinstance(cbs_level, int): - raise TypeError("cbs_level need to be an integer") - if not isinstance(cbs_log_base, int): - raise TypeError("cbs_log_base need to be an integer") - if not isinstance(pks_level, int): - raise TypeError("pks_level need to be an integer") - if not isinstance(pks_log_base, int): - raise TypeError("pks_log_base need to be an integer") - if not isinstance(pks_input_lwe_dim, int): - raise TypeError("pks_input_lwe_dim need to be an integer") - if not isinstance(pks_output_poly_size, int): - raise TypeError("pks_output_poly_size need to be an integer") - self.cpp().set_v0_parameter( - glwe_dim, - log_poly_size, - n_small, - br_level, - br_log_base, - ks_level, - ks_log_base, - crt_decomp, - cbs_level, - cbs_log_base, - pks_level, - pks_log_base, - pks_input_lwe_dim, - pks_output_poly_size, - ) - - # pylint: enable=too-many-arguments,too-many-branches - - def force_encoding(self, encoding: Encoding): - """Force the compiler to use a specific encoding. - - Args: - encoding (Encoding): the encoding to force the compiler to use - - Raises: - TypeError: if encoding is not of type Encoding - """ - if not isinstance(encoding, Encoding): - raise TypeError("encoding need to be of type Encoding") - self.cpp().force_encoding(encoding) - - def simulation(self, simulate: bool): - """Enable or disable simulation. - - Args: - simulate (bool): flag to enable or disable simulation - - Raises: - TypeError: if the value to set is not bool - """ - if not isinstance(simulate, bool): - raise TypeError("need to pass a boolean value") - self.cpp().simulation(simulate) - - def set_emit_gpu_ops(self, emit_gpu_ops: bool): - """Set flag that allows gpu ops to be emitted. - - Args: - emit_gpu_ops (bool): whether to emit gpu ops. - - Raises: - TypeError: if the value to set is not bool - """ - if not isinstance(emit_gpu_ops, bool): - raise TypeError("emit_gpu_ops must be boolean") - self.cpp().set_emit_gpu_ops(emit_gpu_ops) - - def set_batch_tfhe_ops(self, batch_tfhe_ops: bool): - """Set flag that triggers the batching of scalar TFHE operations. - - Args: - batch_tfhe_ops (bool): whether to batch tfhe ops. - - Raises: - TypeError: if the value to set is not bool - """ - if not isinstance(batch_tfhe_ops, bool): - raise TypeError("batch_tfhe_ops must be boolean") - self.cpp().set_batch_tfhe_ops(batch_tfhe_ops) - - def set_enable_tlu_fusing(self, enable_tlu_fusing: bool): - """Enable or disable tlu fusing. - - Args: - enable_tlu_fusing (bool): flag to enable or disable tlu fusing - - Raises: - TypeError: if the value to set is not bool - """ - if not isinstance(enable_tlu_fusing, bool): - raise TypeError("need to pass a boolean value") - self.cpp().set_enable_tlu_fusing(enable_tlu_fusing) - - def set_print_tlu_fusing(self, print_tlu_fusing: bool): - """Enable or disable printing tlu fusing. - - Args: - print_tlu_fusing (bool): flag to enable or disable printing tlu fusing - - Raises: - TypeError: if the value to set is not bool - """ - if not isinstance(print_tlu_fusing, bool): - raise TypeError("need to pass a boolean value") - self.cpp().set_print_tlu_fusing(print_tlu_fusing) - - def set_enable_overflow_detection_in_simulation( - self, enable_overflow_detection: bool - ): - """Enable or disable overflow detection during simulation. - - Args: - enable_overflow_detection (bool): flag to enable or disable overflow detection - - Raises: - TypeError: if the value to set is not bool - """ - if not isinstance(enable_overflow_detection, bool): - raise TypeError("need to pass a boolean value") - self.cpp().set_enable_overflow_detection_in_simulation( - enable_overflow_detection - ) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/evaluation_keys.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/evaluation_keys.py deleted file mode 100644 index 9aa375523d..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/evaluation_keys.py +++ /dev/null @@ -1,63 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""EvaluationKeys.""" - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - EvaluationKeys as _EvaluationKeys, -) - -# pylint: enable=no-name-in-module,import-error -from .wrapper import WrapperCpp - - -class EvaluationKeys(WrapperCpp): - """ - EvaluationKeys required for execution. - """ - - def __init__(self, evaluation_keys: _EvaluationKeys): - """Wrap the native Cpp object. - - Args: - evaluation_keys (_EvaluationKeys): object to wrap - - Raises: - TypeError: if evaluation_keys is not of type _EvaluationKeys - """ - if not isinstance(evaluation_keys, _EvaluationKeys): - raise TypeError( - f"evaluation_keys must be of type _EvaluationKeys, not {type(evaluation_keys)}" - ) - super().__init__(evaluation_keys) - - def serialize(self) -> bytes: - """Serialize the EvaluationKeys. - - Returns: - bytes: serialized object - """ - return self.cpp().serialize() - - @staticmethod - def deserialize(serialized_evaluation_keys: bytes) -> "EvaluationKeys": - """Unserialize EvaluationKeys from bytes. - - Args: - serialized_evaluation_keys (bytes): previously serialized EvaluationKeys - - Raises: - TypeError: if serialized_evaluation_keys is not of type bytes - - Returns: - EvaluationKeys: deserialized object - """ - if not isinstance(serialized_evaluation_keys, bytes): - raise TypeError( - f"serialized_evaluation_keys must be of type bytes, " - f"not {type(serialized_evaluation_keys)}" - ) - return EvaluationKeys.wrap( - _EvaluationKeys.deserialize(serialized_evaluation_keys) - ) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/key_set.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/key_set.py deleted file mode 100644 index 0cc14335f9..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/key_set.py +++ /dev/null @@ -1,91 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - - -"""KeySet. - -Store for the different keys required for an encrypted computation. -""" - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - KeySet as _KeySet, -) - -# pylint: enable=no-name-in-module,import-error -from .lwe_secret_key import LweSecretKey -from .wrapper import WrapperCpp -from .evaluation_keys import EvaluationKeys - - -class KeySet(WrapperCpp): - """KeySet stores the different keys required for an encrypted computation. - - Holds private keys (secret key) used for encryption/decryption, and public keys used for computation. - """ - - def __init__(self, keyset: _KeySet): - """Wrap the native Cpp object. - - Args: - keyset (_KeySet): object to wrap - - Raises: - TypeError: if keyset is not of type _KeySet - """ - if not isinstance(keyset, _KeySet): - raise TypeError(f"keyset must be of type _KeySet, not {type(keyset)}") - super().__init__(keyset) - - def serialize(self) -> bytes: - """Serialize the KeySet. - - Returns: - bytes: serialized object - """ - return self.cpp().serialize() - - @staticmethod - def deserialize(serialized_key_set: bytes) -> "KeySet": - """Deserialize KeySet from bytes. - - Args: - serialized_key_set (bytes): previously serialized KeySet - - Raises: - TypeError: if serialized_key_set is not of type bytes - - Returns: - KeySet: deserialized object - """ - if not isinstance(serialized_key_set, bytes): - raise TypeError( - f"serialized_key_set must be of type bytes, not {type(serialized_key_set)}" - ) - return KeySet.wrap(_KeySet.deserialize(serialized_key_set)) - - def get_lwe_secret_key(self, keyid: int) -> LweSecretKey: - """Get a specific LweSecretKey. - - Args: - keyid (int): id of the key to get - - Raises: - TypeError: if wrong types for input arguments - - Returns: - bytes: LweSecretKey - """ - if not isinstance(keyid, int): - raise TypeError(f"keyid must be of type int, not {type(keyid)}") - return LweSecretKey.wrap(self.cpp().get_lwe_secret_key(keyid)) - - def get_evaluation_keys(self) -> EvaluationKeys: - """ - Get evaluation keys for execution. - - Returns: - EvaluationKeys: - evaluation keys for execution - """ - return EvaluationKeys(self.cpp().get_evaluation_keys()) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/key_set_cache.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/key_set_cache.py deleted file mode 100644 index d5cdd56681..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/key_set_cache.py +++ /dev/null @@ -1,59 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""KeySetCache. - -Cache for keys to avoid generating similar keys multiple times. -""" - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - KeySetCache as _KeySetCache, -) - -# pylint: enable=no-name-in-module,import-error -from .wrapper import WrapperCpp - - -class KeySetCache(WrapperCpp): - """KeySetCache is a cache for KeySet to avoid generating similar keys multiple times. - - Keys get cached and can be later used instead of generating a new keyset which can take a lot of time. - """ - - def __init__(self, keyset_cache: _KeySetCache): - """Wrap the native Cpp object. - - Args: - keyset_cache (_KeySetCache): object to wrap - - Raises: - TypeError: if keyset_cache is not of type _KeySetCache - """ - if not isinstance(keyset_cache, _KeySetCache): - raise TypeError( - f"key_set_cache must be of type _KeySetCache, not {type(keyset_cache)}" - ) - super().__init__(keyset_cache) - - @staticmethod - # pylint: disable=arguments-differ - def new(cache_path: str) -> "KeySetCache": - """Build a KeySetCache located at cache_path. - - Args: - cache_path (str): path to the cache - - Raises: - TypeError: if the path is not of type str. - - Returns: - KeySetCache - """ - if not isinstance(cache_path, str): - raise TypeError( - f"cache_path must to be of type str, not {type(cache_path)}" - ) - return KeySetCache.wrap(_KeySetCache(cache_path)) - - # pylint: enable=arguments-differ diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py deleted file mode 100644 index ff8fc86bec..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py +++ /dev/null @@ -1,250 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""LambdaArgument.""" -from typing import List - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - LambdaArgument as _LambdaArgument, -) - -# pylint: enable=no-name-in-module,import-error -from .utils import ACCEPTED_INTS -from .wrapper import WrapperCpp - - -class LambdaArgument(WrapperCpp): - """LambdaArgument holds scalar or tensor values.""" - - def __init__(self, lambda_argument: _LambdaArgument): - """Wrap the native Cpp object. - - Args: - lambda_argument (_LambdaArgument): object to wrap - - Raises: - TypeError: if lambda_argument is not of type _LambdaArgument - """ - if not isinstance(lambda_argument, _LambdaArgument): - raise TypeError( - f"lambda_argument must be of type _LambdaArgument, not {type(lambda_argument)}" - ) - super().__init__(lambda_argument) - - @staticmethod - def new(*args, **kwargs): - """Use from_scalar or from_tensor instead. - - Raises: - RuntimeError - """ - raise RuntimeError( - "you should call from_scalar or from_tensor according to the argument type" - ) - - @staticmethod - def from_scalar(scalar: int) -> "LambdaArgument": - """Build a LambdaArgument containing the given scalar value. - - Args: - scalar (int or numpy.uint): scalar value to embed in LambdaArgument - - Raises: - TypeError: if scalar is not of type int or numpy.uint - - Returns: - LambdaArgument - """ - if not isinstance(scalar, ACCEPTED_INTS): - raise TypeError( - f"scalar must be of type int or numpy.int, not {type(scalar)}" - ) - return LambdaArgument.wrap(_LambdaArgument.from_scalar(scalar)) - - @staticmethod - def from_signed_scalar(scalar: int) -> "LambdaArgument": - """Build a LambdaArgument containing the given scalar value. - - Args: - scalar (int or numpy.int): scalar value to embed in LambdaArgument - - Raises: - TypeError: if scalar is not of type int or numpy.uint - - Returns: - LambdaArgument - """ - if not isinstance(scalar, ACCEPTED_INTS): - raise TypeError( - f"scalar must be of type int or numpy.uint, not {type(scalar)}" - ) - return LambdaArgument.wrap(_LambdaArgument.from_signed_scalar(scalar)) - - @staticmethod - def from_tensor_u8(data: List[int], shape: List[int]) -> "LambdaArgument": - """Build a LambdaArgument containing the given tensor. - - Args: - data (List[int]): flattened tensor data - shape (List[int]): shape of original tensor before flattening - - Returns: - LambdaArgument - """ - return LambdaArgument.wrap(_LambdaArgument.from_tensor_u8(data, shape)) - - @staticmethod - def from_tensor_u16(data: List[int], shape: List[int]) -> "LambdaArgument": - """Build a LambdaArgument containing the given tensor. - - Args: - data (List[int]): flattened tensor data - shape (List[int]): shape of original tensor before flattening - - Returns: - LambdaArgument - """ - return LambdaArgument.wrap(_LambdaArgument.from_tensor_u16(data, shape)) - - @staticmethod - def from_tensor_u32(data: List[int], shape: List[int]) -> "LambdaArgument": - """Build a LambdaArgument containing the given tensor. - - Args: - data (List[int]): flattened tensor data - shape (List[int]): shape of original tensor before flattening - - Returns: - LambdaArgument - """ - return LambdaArgument.wrap(_LambdaArgument.from_tensor_u32(data, shape)) - - @staticmethod - def from_tensor_u64(data: List[int], shape: List[int]) -> "LambdaArgument": - """Build a LambdaArgument containing the given tensor. - - Args: - data (List[int]): flattened tensor data - shape (List[int]): shape of original tensor before flattening - - Returns: - LambdaArgument - """ - return LambdaArgument.wrap(_LambdaArgument.from_tensor_u64(data, shape)) - - @staticmethod - def from_tensor_i8(data: List[int], shape: List[int]) -> "LambdaArgument": - """Build a LambdaArgument containing the given tensor. - - Args: - data (List[int]): flattened tensor data - shape (List[int]): shape of original tensor before flattening - - Returns: - LambdaArgument - """ - return LambdaArgument.wrap(_LambdaArgument.from_tensor_i8(data, shape)) - - @staticmethod - def from_tensor_i16(data: List[int], shape: List[int]) -> "LambdaArgument": - """Build a LambdaArgument containing the given tensor. - - Args: - data (List[int]): flattened tensor data - shape (List[int]): shape of original tensor before flattening - - Returns: - LambdaArgument - """ - return LambdaArgument.wrap(_LambdaArgument.from_tensor_i16(data, shape)) - - @staticmethod - def from_tensor_i32(data: List[int], shape: List[int]) -> "LambdaArgument": - """Build a LambdaArgument containing the given tensor. - - Args: - data (List[int]): flattened tensor data - shape (List[int]): shape of original tensor before flattening - - Returns: - LambdaArgument - """ - return LambdaArgument.wrap(_LambdaArgument.from_tensor_i32(data, shape)) - - @staticmethod - def from_tensor_i64(data: List[int], shape: List[int]) -> "LambdaArgument": - """Build a LambdaArgument containing the given tensor. - - Args: - data (List[int]): flattened tensor data - shape (List[int]): shape of original tensor before flattening - - Returns: - LambdaArgument - """ - return LambdaArgument.wrap(_LambdaArgument.from_tensor_i64(data, shape)) - - def is_signed(self) -> bool: - """Check if the contained argument is signed. - - Returns: - bool - """ - return self.cpp().is_signed() - - def is_scalar(self) -> bool: - """Check if the contained argument is a scalar. - - Returns: - bool - """ - return self.cpp().is_scalar() - - def get_scalar(self) -> int: - """Return the contained scalar value. - - Returns: - int - """ - return self.cpp().get_scalar() - - def get_signed_scalar(self) -> int: - """Return the contained scalar value. - - Returns: - int - """ - return self.cpp().get_signed_scalar() - - def is_tensor(self) -> bool: - """Check if the contained argument is a tensor. - - Returns: - bool - """ - return self.cpp().is_tensor() - - def get_tensor_shape(self) -> List[int]: - """Return the shape of the contained tensor. - - Returns: - List[int]: tensor shape - """ - return self.cpp().get_tensor_shape() - - def get_tensor_data(self) -> List[int]: - """Return the contained flattened tensor data. - - Returns: - List[int] - """ - return self.cpp().get_tensor_data() - - def get_signed_tensor_data(self) -> List[int]: - """Return the contained flattened tensor data. - - Returns: - List[int] - """ - return self.cpp().get_signed_tensor_data() diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py deleted file mode 100644 index 1c76488be4..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_compilation_result.py +++ /dev/null @@ -1,54 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""LibraryCompilationResult.""" - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - LibraryCompilationResult as _LibraryCompilationResult, -) - -# pylint: enable=no-name-in-module,import-error -from .wrapper import WrapperCpp - - -class LibraryCompilationResult(WrapperCpp): - """LibraryCompilationResult holds the result of the library compilation.""" - - def __init__(self, library_compilation_result: _LibraryCompilationResult): - """Wrap the native Cpp object. - - Args: - library_compilation_result (_LibraryCompilationResult): object to wrap - - Raises: - TypeError: if library_compilation_result is not of type _LibraryCompilationResult - """ - if not isinstance(library_compilation_result, _LibraryCompilationResult): - raise TypeError( - f"library_compilation_result must be of type _LibraryCompilationResult, not " - f"{type(library_compilation_result)}" - ) - super().__init__(library_compilation_result) - - @staticmethod - # pylint: disable=arguments-differ - def new(output_dir_path: str) -> "LibraryCompilationResult": - """Build a LibraryCompilationResult at output_dir_path. - - Args: - output_dir_path (str): path to the compilation artifacts - - Raises: - TypeError: if output_dir_path is not of type str - - Returns: - LibraryCompilationResult - """ - if not isinstance(output_dir_path, str): - raise TypeError( - f"output_dir_path must be of type str, not {type(output_dir_path)}" - ) - return LibraryCompilationResult.wrap(_LibraryCompilationResult(output_dir_path)) - - # pylint: enable=arguments-differ diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_lambda.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_lambda.py deleted file mode 100644 index 446f37e3d2..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_lambda.py +++ /dev/null @@ -1,31 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""LibraryLambda.""" - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - LibraryLambda as _LibraryLambda, -) - -# pylint: enable=no-name-in-module,import-error -from .wrapper import WrapperCpp - - -class LibraryLambda(WrapperCpp): - """LibraryLambda reference a compiled library and can be ran using LibrarySupport.""" - - def __init__(self, library_lambda: _LibraryLambda): - """Wrap the native Cpp object. - - Args: - library_lambda (_LibraryLambda): object to wrap - - Raises: - TypeError: if library_lambda is not of type _LibraryLambda - """ - if not isinstance(library_lambda, _LibraryLambda): - raise TypeError( - f"library_lambda must be of type _LibraryLambda, not {type(library_lambda)}" - ) - super().__init__(library_lambda) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_support.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_support.py deleted file mode 100644 index 4280296ec3..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/library_support.py +++ /dev/null @@ -1,362 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""LibrarySupport. - -Library support provides a way to compile an MLIR program into a library that can be later loaded -to execute the compiled code. -""" -import os -from typing import Optional, Union - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - LibrarySupport as _LibrarySupport, -) -from mlir.ir import Module as MlirModule - -# pylint: enable=no-name-in-module,import-error -from .compilation_options import CompilationOptions -from .compilation_context import CompilationContext -from .library_compilation_result import LibraryCompilationResult -from .public_arguments import PublicArguments -from .library_lambda import LibraryLambda -from .public_result import PublicResult -from .client_parameters import ClientParameters -from .compilation_feedback import ProgramCompilationFeedback -from .wrapper import WrapperCpp -from .utils import lookup_runtime_lib -from .evaluation_keys import EvaluationKeys - - -# Default output path for compilation artifacts -DEFAULT_OUTPUT_PATH = os.path.abspath( - os.path.join(os.path.curdir, "concrete-compiler_compilation_artifacts") -) - - -class LibrarySupport(WrapperCpp): - """Support class for library compilation and execution.""" - - def __init__(self, library_support: _LibrarySupport): - """Wrap the native Cpp object. - - Args: - library_support (_LibrarySupport): object to wrap - - Raises: - TypeError: if library_support is not of type _LibrarySupport - """ - if not isinstance(library_support, _LibrarySupport): - raise TypeError( - f"library_support must be of type _LibrarySupport, not " - f"{type(library_support)}" - ) - super().__init__(library_support) - self.output_dir_path = DEFAULT_OUTPUT_PATH - - @property - def output_dir_path(self) -> str: - """Path where to store compilation artifacts.""" - return self._output_dir_path - - @output_dir_path.setter - def output_dir_path(self, path: str): - if not isinstance(path, str): - raise TypeError(f"path must be of type str, not {type(path)}") - self._output_dir_path = path - - @staticmethod - # pylint: disable=arguments-differ - def new( - output_path: str = DEFAULT_OUTPUT_PATH, - runtime_library_path: Optional[str] = None, - generateSharedLib: bool = True, - generateStaticLib: bool = False, - generateClientParameters: bool = True, - generateCompilationFeedback: bool = True, - generateCppHeader: bool = False, - ) -> "LibrarySupport": - """Build a LibrarySupport. - - Args: - output_path (str, optional): path where to store compilation artifacts. - Defaults to DEFAULT_OUTPUT_PATH. - runtime_library_path (Optional[str], optional): path to the runtime library. Defaults to None. - generateSharedLib (bool): whether to emit shared library or not. Default to True. - generateStaticLib (bool): whether to emit static library or not. Default to False. - generateClientParameters (bool): whether to emit client parameters or not. Default to True. - generateCppHeader (bool): whether to emit cpp header or not. Default to False. - - Raises: - TypeError: if output_path is not of type str - TypeError: if runtime_library_path is not of type str - TypeError: if one of the generation flags is not of type bool - - Returns: - LibrarySupport - """ - if runtime_library_path is None: - runtime_library_path = lookup_runtime_lib() - if not isinstance(output_path, str): - raise TypeError(f"output_path must be of type str, not {type(output_path)}") - if not isinstance(runtime_library_path, str): - raise TypeError( - f"runtime_library_path must be of type str, not {type(runtime_library_path)}" - ) - for name, value in [ - ("generateSharedLib", generateSharedLib), - ("generateStaticLib", generateStaticLib), - ("generateClientParameters", generateClientParameters), - ("generateCompilationFeedback", generateCompilationFeedback), - ("generateCppHeader", generateCppHeader), - ]: - if not isinstance(value, bool): - raise TypeError(f"{name} must be of type bool, not {type(value)}") - library_support = LibrarySupport.wrap( - _LibrarySupport( - output_path, - runtime_library_path, - generateSharedLib, - generateStaticLib, - generateClientParameters, - generateCompilationFeedback, - generateCppHeader, - ) - ) - if not os.path.isdir(output_path): - os.makedirs(output_path) - library_support.output_dir_path = output_path - return library_support - - def compile( - self, - mlir_program: Union[str, MlirModule], - options: CompilationOptions = CompilationOptions.new(), - compilation_context: Optional[CompilationContext] = None, - ) -> LibraryCompilationResult: - """Compile an MLIR program using Concrete dialects into a library. - - Args: - mlir_program (Union[str, MlirModule]): mlir program to compile (textual or in-memory) - options (CompilationOptions): compilation options - - Raises: - TypeError: if mlir_program is not of type str or MlirModule - TypeError: if options is not of type CompilationOptions - - Returns: - LibraryCompilationResult: the result of the library compilation - """ - if not isinstance(mlir_program, (str, MlirModule)): - raise TypeError( - f"mlir_program must be of type str or MlirModule, not {type(mlir_program)}" - ) - if not isinstance(options, CompilationOptions): - raise TypeError( - f"options must be of type CompilationOptions, not {type(options)}" - ) - # get the PyCapsule of the module - if isinstance(mlir_program, MlirModule): - if compilation_context is None: - raise ValueError( - "compilation_context must be provided when compiling a module object" - ) - if not isinstance(compilation_context, CompilationContext): - raise TypeError( - f"compilation_context must be of type CompilationContext, not " - f"{type(compilation_context)}" - ) - # pylint: disable=protected-access - return LibraryCompilationResult.wrap( - self.cpp().compile( - mlir_program._CAPIPtr, options.cpp(), compilation_context.cpp() - ) - ) - # pylint: enable=protected-access - return LibraryCompilationResult.wrap( - self.cpp().compile(mlir_program, options.cpp()) - ) - - def reload(self) -> LibraryCompilationResult: - """Reload the library compilation result from the output_dir_path. - - Returns: - LibraryCompilationResult: loaded library - """ - return LibraryCompilationResult.new(self.output_dir_path) - - def load_client_parameters( - self, library_compilation_result: LibraryCompilationResult - ) -> ClientParameters: - """Load the client parameters from the library compilation result. - - Args: - library_compilation_result (LibraryCompilationResult): compilation result of the library - - Raises: - TypeError: if library_compilation_result is not of type LibraryCompilationResult - - Returns: - ClientParameters: appropriate client parameters for the compiled library - """ - if not isinstance(library_compilation_result, LibraryCompilationResult): - raise TypeError( - f"library_compilation_result must be of type LibraryCompilationResult, not " - f"{type(library_compilation_result)}" - ) - - return ClientParameters.wrap( - self.cpp().load_client_parameters(library_compilation_result.cpp()) - ) - - def load_compilation_feedback( - self, compilation_result: LibraryCompilationResult - ) -> ProgramCompilationFeedback: - """Load the compilation feedback from the compilation result. - - Args: - compilation_result (LibraryCompilationResult): result of the compilation - - Raises: - TypeError: if compilation_result is not of type LibraryCompilationResult - - Returns: - ProgramCompilationFeedback: the compilation feedback for the compiled program - """ - if not isinstance(compilation_result, LibraryCompilationResult): - raise TypeError( - f"compilation_result must be of type LibraryCompilationResult, not {type(compilation_result)}" - ) - return ProgramCompilationFeedback.wrap( - self.cpp().load_compilation_feedback(compilation_result.cpp()) - ) - - def load_server_lambda( - self, - library_compilation_result: LibraryCompilationResult, - simulation: bool, - circuit_name: str, - ) -> LibraryLambda: - """Load the server lambda for a given circuit from the library compilation result. - - Args: - library_compilation_result (LibraryCompilationResult): compilation result of the library - simulation (bool): use simulation for execution - circuit_name (str): name of the circuit to be loaded - - Raises: - TypeError: if library_compilation_result is not of type LibraryCompilationResult, if - circuit_name is not of type str or - - Returns: - LibraryLambda: executable reference to the library - """ - if not isinstance(library_compilation_result, LibraryCompilationResult): - raise TypeError( - f"library_compilation_result must be of type LibraryCompilationResult, not " - f"{type(library_compilation_result)}" - ) - if not isinstance(circuit_name, str): - raise TypeError( - f"circuit_name must be of type str, not " f"{type(circuit_name)}" - ) - if not isinstance(simulation, bool): - raise TypeError( - f"simulation must be of type bool, not " f"{type(simulation)}" - ) - return LibraryLambda.wrap( - self.cpp().load_server_lambda( - library_compilation_result.cpp(), circuit_name, simulation - ) - ) - - def server_call( - self, - library_lambda: LibraryLambda, - public_arguments: PublicArguments, - evaluation_keys: EvaluationKeys, - ) -> PublicResult: - """Call the library with public_arguments. - - Args: - library_lambda (LibraryLambda): reference to the compiled library - public_arguments (PublicArguments): arguments to use for execution - evaluation_keys (EvaluationKeys): evaluation keys to use for execution - - Raises: - TypeError: if library_lambda is not of type LibraryLambda - TypeError: if public_arguments is not of type PublicArguments - TypeError: if evaluation_keys is not of type EvaluationKeys - - Returns: - PublicResult: result of the execution - """ - if not isinstance(library_lambda, LibraryLambda): - raise TypeError( - f"library_lambda must be of type LibraryLambda, not {type(library_lambda)}" - ) - if not isinstance(public_arguments, PublicArguments): - raise TypeError( - f"public_arguments must be of type PublicArguments, not {type(public_arguments)}" - ) - if not isinstance(evaluation_keys, EvaluationKeys): - raise TypeError( - f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}" - ) - return PublicResult.wrap( - self.cpp().server_call( - library_lambda.cpp(), - public_arguments.cpp(), - evaluation_keys.cpp(), - ) - ) - - def simulate( - self, - library_lambda: LibraryLambda, - public_arguments: PublicArguments, - ) -> PublicResult: - """Call the library with public_arguments in simulation mode. - - Args: - library_lambda (LibraryLambda): reference to the compiled library - public_arguments (PublicArguments): arguments to use for execution - - Raises: - TypeError: if library_lambda is not of type LibraryLambda - TypeError: if public_arguments is not of type PublicArguments - - Returns: - PublicResult: result of the execution - """ - if not isinstance(library_lambda, LibraryLambda): - raise TypeError( - f"library_lambda must be of type LibraryLambda, not {type(library_lambda)}" - ) - if not isinstance(public_arguments, PublicArguments): - raise TypeError( - f"public_arguments must be of type PublicArguments, not {type(public_arguments)}" - ) - return PublicResult.wrap( - self.cpp().simulate( - library_lambda.cpp(), - public_arguments.cpp(), - ) - ) - - def get_shared_lib_path(self) -> str: - """Get the path where the shared library is expected to be. - - Returns: - str: path to the shared library - """ - return self.cpp().get_shared_lib_path() - - def get_program_info_path(self) -> str: - """Get the path where the program info file is expected to be. - - Returns: - str: path to the program info file - """ - return self.cpp().get_program_info_path() diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/lwe_secret_key.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/lwe_secret_key.py deleted file mode 100644 index 704b1fa027..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/lwe_secret_key.py +++ /dev/null @@ -1,141 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - - -"""LweSecretKey.""" - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - LweSecretKey as _LweSecretKey, - LweSecretKeyParam as _LweSecretKeyParam, -) - -# pylint: enable=no-name-in-module,import-error -from .wrapper import WrapperCpp - - -class LweSecretKeyParam(WrapperCpp): - """LWE Secret Key Parameters""" - - def __init__(self, lwe_secret_key_param: _LweSecretKeyParam): - """Wrap the native Cpp object. - - Args: - lwe_secret_key_param (_LweSecretKeyParam): object to wrap - - Raises: - TypeError: if lwe_secret_key_param is not of type _LweSecretKeyParam - """ - if not isinstance(lwe_secret_key_param, _LweSecretKeyParam): - raise TypeError( - "lwe_secret_key_param must be of type _LweSecretKeyParam, " - f"not {type(lwe_secret_key_param)}" - ) - super().__init__(lwe_secret_key_param) - - @property - def dimension(self) -> int: - """LWE dimension""" - return self.cpp().dimension - - -class LweSecretKey(WrapperCpp): - """An LweSecretKey.""" - - def __init__(self, lwe_secret_key: _LweSecretKey): - """Wrap the native Cpp object. - - Args: - lwe_secret_key (_LweSecretKey): object to wrap - - Raises: - TypeError: if lwe_secret_key is not of type _LweSecretKey - """ - if not isinstance(lwe_secret_key, _LweSecretKey): - raise TypeError( - f"lwe_secret_key must be of type _LweSecretKey, not {type(lwe_secret_key)}" - ) - super().__init__(lwe_secret_key) - - def serialize(self) -> bytes: - """Serialize key. - - Returns: - bytes: serialized key - """ - - return self.cpp().serialize() - - @staticmethod - def deserialize(serialized_key: bytes, param: LweSecretKeyParam) -> "LweSecretKey": - """Deserialize LweSecretKey from bytes. - - Args: - serialized_key (bytes): previously serialized secret key - - Raises: - TypeError: if wrong types for input arguments - - Returns: - LweSecretKey: deserialized object - """ - if not isinstance(serialized_key, bytes): - raise TypeError( - f"serialized_key must be of type bytes, not {type(serialized_key)}" - ) - if not isinstance(param, LweSecretKeyParam): - raise TypeError( - f"param must be of type LweSecretKeyParam, not {type(param)}" - ) - return LweSecretKey.wrap(_LweSecretKey.deserialize(serialized_key, param.cpp())) - - def serialize_as_glwe(self, glwe_dim: int, poly_size: int) -> bytes: - """Serialize key as a glwe secret key. - - Args: - glwe_dim (int): glwe dimension of the key - poly_size (int): polynomial size of the key - - Raises: - TypeError: if wrong types for input arguments - - Returns: - bytes: serialized key - """ - if not isinstance(glwe_dim, int): - raise TypeError(f"glwe_dim must be of type int, not {type(glwe_dim)}") - if not isinstance(poly_size, int): - raise TypeError(f"poly_size must be of type int, not {type(poly_size)}") - return self.cpp().serialize_as_glwe(glwe_dim, poly_size) - - @staticmethod - def deserialize_from_glwe( - serialized_glwe_key: bytes, param: LweSecretKeyParam - ) -> "LweSecretKey": - """Deserialize LweSecretKey from glwe secret key bytes. - - Args: - serialized_glwe_key (bytes): previously serialized glwe secret key - - Raises: - TypeError: if wrong types for input arguments - - Returns: - LweSecretKey: deserialized object - """ - if not isinstance(serialized_glwe_key, bytes): - raise TypeError( - f"serialized_glwe_key must be of type bytes, not {type(serialized_glwe_key)}" - ) - if not isinstance(param, LweSecretKeyParam): - raise TypeError( - f"param must be of type LweSecretKeyParam, not {type(param)}" - ) - return LweSecretKey.wrap( - _LweSecretKey.deserialize_from_glwe(serialized_glwe_key, param.cpp()) - ) - - @property - def param(self) -> LweSecretKeyParam: - """LWE Secret Key Parameters""" - return LweSecretKeyParam.wrap(self.cpp().param) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/parameter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/parameter.py deleted file mode 100644 index 266bf2e36a..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/parameter.py +++ /dev/null @@ -1,105 +0,0 @@ -""" -Parameter. -""" - -# pylint: disable=no-name-in-module,import-error - -from typing import Union - -from mlir._mlir_libs._concretelang._compiler import ( - LweSecretKeyParam, - BootstrapKeyParam, - KeyswitchKeyParam, - PackingKeyswitchKeyParam, - KeyType, -) - -from .client_parameters import ClientParameters - -# pylint: enable=no-name-in-module,import-error - - -class Parameter: - """ - An FHE parameter. - """ - - _inner: Union[ - LweSecretKeyParam, - BootstrapKeyParam, - KeyswitchKeyParam, - PackingKeyswitchKeyParam, - ] - - def __init__( - self, - client_parameters: ClientParameters, - key_type: KeyType, - key_index: int, - ): - if key_type == KeyType.SECRET: - self._inner = client_parameters.cpp().secret_keys[key_index] - elif key_type == KeyType.BOOTSTRAP: - self._inner = client_parameters.cpp().bootstrap_keys[key_index] - elif key_type == KeyType.KEY_SWITCH: - self._inner = client_parameters.cpp().keyswitch_keys[key_index] - elif key_type == KeyType.PACKING_KEY_SWITCH: - self._inner = client_parameters.cpp().packing_keyswitch_keys[key_index] - else: - raise ValueError("invalid key type") - - def __getattr__(self, item): - return getattr(self._inner, item) - - def __repr__(self): - param = self._inner - - if isinstance(param, LweSecretKeyParam): - result = f"LweSecretKeyParam(" f"dimension={param.dimension()}" f")" - - elif isinstance(param, BootstrapKeyParam): - result = ( - f"BootstrapKeyParam(" - f"polynomial_size={param.polynomial_size()}, " - f"glwe_dimension={param.glwe_dimension()}, " - f"input_lwe_dimension={param.input_lwe_dimension()}, " - f"level={param.level()}, " - f"base_log={param.base_log()}, " - f"variance={param.variance()}" - f")" - ) - - elif isinstance(param, KeyswitchKeyParam): - result = ( - f"KeyswitchKeyParam(" - f"level={param.level()}, " - f"base_log={param.base_log()}, " - f"variance={param.variance()}" - f")" - ) - - elif isinstance(param, PackingKeyswitchKeyParam): - result = ( - f"PackingKeyswitchKeyParam(" - f"polynomial_size={param.polynomial_size()}, " - f"glwe_dimension={param.glwe_dimension()}, " - f"input_lwe_dimension={param.input_lwe_dimension()}" - f"level={param.level()}, " - f"base_log={param.base_log()}, " - f"variance={param.variance()}" - f")" - ) - - else: - assert False - - return result - - def __str__(self): - return repr(self) - - def __hash__(self): - return hash(str(self)) - - def __eq__(self, other): - return str(self) == str(other) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py deleted file mode 100644 index 2c0799ef4d..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py +++ /dev/null @@ -1,91 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""PublicArguments.""" - -from typing import List - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - PublicArguments as _PublicArguments, -) - -# pylint: enable=no-name-in-module,import-error -from .client_parameters import ClientParameters -from .value import Value -from .wrapper import WrapperCpp - - -class PublicArguments(WrapperCpp): - """PublicArguments holds encrypted and plain arguments, as well as public materials. - - An encrypted computation may require both encrypted and plain arguments, PublicArguments holds both - types, but also other public materials, such as public keys, which are required for private computation. - """ - - def __init__(self, public_arguments: _PublicArguments): - """Wrap the native Cpp object. - - Args: - public_arguments (_PublicArguments): object to wrap - - Raises: - TypeError: if public_arguments is not of type _PublicArguments - """ - if not isinstance(public_arguments, _PublicArguments): - raise TypeError( - f"public_arguments must be of type _PublicArguments, not {type(public_arguments)}" - ) - super().__init__(public_arguments) - - @staticmethod - # pylint: disable=arguments-differ - def new( - client_parameters: ClientParameters, values: List[Value] - ) -> "PublicArguments": - """ - Create public arguments from individual values. - """ - return PublicArguments( - _PublicArguments.create( - client_parameters.cpp(), - [value.cpp() for value in values], - ) - ) - - def serialize(self) -> bytes: - """Serialize the PublicArguments. - - Returns: - bytes: serialized object - """ - return self.cpp().serialize() - - @staticmethod - def deserialize( - client_parameters: ClientParameters, serialized_args: bytes - ) -> "PublicArguments": - """Unserialize PublicArguments from bytes of serialized_args. - - Args: - client_parameters (ClientParameters): client parameters of the compiled circuit - serialized_args (bytes): previously serialized PublicArguments - - Raises: - TypeError: if client_parameters is not of type ClientParameters - TypeError: if serialized_args is not of type bytes - - Returns: - PublicArguments: deserialized object - """ - if not isinstance(client_parameters, ClientParameters): - raise TypeError( - f"client_parameters must be of type ClientParameters, not {type(client_parameters)}" - ) - if not isinstance(serialized_args, bytes): - raise TypeError( - f"serialized_args must be of type bytes, not {type(serialized_args)}" - ) - return PublicArguments.wrap( - _PublicArguments.deserialize(client_parameters.cpp(), serialized_args) - ) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_result.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_result.py deleted file mode 100644 index 655369e377..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_result.py +++ /dev/null @@ -1,82 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""PublicResult.""" - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - PublicResult as _PublicResult, -) -from .client_parameters import ClientParameters - -# pylint: enable=no-name-in-module,import-error -from .value import Value -from .wrapper import WrapperCpp - - -class PublicResult(WrapperCpp): - """PublicResult holds the result of an encrypted execution and can be decrypted using ClientSupport.""" - - def __init__(self, public_result: _PublicResult): - """Wrap the native Cpp object. - - Args: - public_result (_PublicResult): object to wrap - - Raises: - TypeError: if public_result is not of type _PublicResult - """ - if not isinstance(public_result, _PublicResult): - raise TypeError( - f"public_result must be of type _PublicResult, not {type(public_result)}" - ) - super().__init__(public_result) - - def n_values(self) -> int: - """ - Get number of values in the result. - """ - return self.cpp().n_values() - - def get_value(self, position: int) -> Value: - """ - Get a specific value in the result. - """ - return Value(self.cpp().get_value(position)) - - def serialize(self) -> bytes: - """Serialize the PublicResult. - - Returns: - bytes: serialized object - """ - return self.cpp().serialize() - - @staticmethod - def deserialize( - client_parameters: ClientParameters, serialized_result: bytes - ) -> "PublicResult": - """Unserialize PublicResult from bytes of serialized_result. - - Args: - client_parameters (ClientParameters): client parameters of the compiled circuit - serialized_result (bytes): previously serialized PublicResult - - Raises: - TypeError: if client_parameters is not of type ClientParameters - TypeError: if serialized_result is not of type bytes - - Returns: - PublicResult: deserialized object - """ - if not isinstance(client_parameters, ClientParameters): - raise TypeError( - f"client_parameters must be of type ClientParameters, not {type(client_parameters)}" - ) - if not isinstance(serialized_result, bytes): - raise TypeError( - f"serialized_result must be of type bytes, not {type(serialized_result)}" - ) - return PublicResult.wrap( - _PublicResult.deserialize(client_parameters.cpp(), serialized_result) - ) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py deleted file mode 100644 index cfd2be2686..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py +++ /dev/null @@ -1,88 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information. - -"""ServerCircuit.""" - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - ServerCircuit as _ServerCircuit, -) - -# pylint: enable=no-name-in-module,import-error -from .wrapper import WrapperCpp -from .public_arguments import PublicArguments -from .public_result import PublicResult -from .evaluation_keys import EvaluationKeys - - -class ServerCircuit(WrapperCpp): - """ServerCircuit references a circuit that can be called for execution and simulation.""" - - def __init__(self, server_circuit: _ServerCircuit): - """Wrap the native Cpp object. - - Args: - server_circuit (_ServerCircuit): object to wrap - - Raises: - TypeError: if server_circuit is not of type _ServerCircuit - """ - if not isinstance(server_circuit, _ServerCircuit): - raise TypeError( - f"server_circuit must be of type _ServerCircuit, not {type(server_circuit)}" - ) - super().__init__(server_circuit) - - def call( - self, - public_arguments: PublicArguments, - evaluation_keys: EvaluationKeys, - ) -> PublicResult: - """Executes the circuit on the public arguments. - - Args: - public_arguments (PublicArguments): public arguments to execute on - execution_keys (EvaluationKeys): evaluation keys to use for execution. - - Raises: - TypeError: if public_arguments is not of type PublicArguments, or if evaluation_keys is - not of type EvaluationKeys - - Returns: - PublicResult: A public result object containing the results. - """ - if not isinstance(public_arguments, PublicArguments): - raise TypeError( - f"public_arguments must be of type PublicArguments, not " - f"{type(public_arguments)}" - ) - if not isinstance(evaluation_keys, EvaluationKeys): - raise TypeError( - f"simulation must be of type EvaluationKeys, not " - f"{type(evaluation_keys)}" - ) - return PublicResult.wrap( - self.cpp().call(public_arguments.cpp(), evaluation_keys.cpp()) - ) - - def simulate( - self, - public_arguments: PublicArguments, - ) -> PublicResult: - """Simulates the circuit on the public arguments. - - Args: - public_arguments (PublicArguments): public arguments to execute on - - Raises: - TypeError: if public_arguments is not of type PublicArguments - - Returns: - PublicResult: A public result object containing the results. - """ - if not isinstance(public_arguments, PublicArguments): - raise TypeError( - f"public_arguments must be of type PublicArguments, not " - f"{type(public_arguments)}" - ) - return PublicResult.wrap(self.cpp().simulate(public_arguments.cpp())) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py deleted file mode 100644 index 16b6138761..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py +++ /dev/null @@ -1,80 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information. - -"""ServerProgram.""" - -# pylint: disable=no-name-in-module,import-error -from mlir._mlir_libs._concretelang._compiler import ( - ServerProgram as _ServerProgram, -) - -# pylint: enable=no-name-in-module,import-error -from .wrapper import WrapperCpp -from .library_support import LibrarySupport -from .server_circuit import ServerCircuit - - -class ServerProgram(WrapperCpp): - """ServerProgram references compiled circuit objects.""" - - def __init__(self, server_program: _ServerProgram): - """Wrap the native Cpp object. - - Args: - server_program (_ServerProgram): object to wrap - - Raises: - TypeError: if server_program is not of type _ServerProgram - """ - if not isinstance(server_program, _ServerProgram): - raise TypeError( - f"server_program must be of type _ServerProgram, not {type(server_program)}" - ) - super().__init__(server_program) - - @staticmethod - def load( - library_support: LibrarySupport, - simulation: bool, - ) -> "ServerProgram": - """Loads the server program from a library support. - - Args: - library_support (LibrarySupport): library support - simulation (bool): use simulation for execution - - Raises: - TypeError: if library_support is not of type LibrarySupport, or if simulation is not of type bool - - Returns: - ServerProgram: A server program object containing references to circuits for calls. - """ - if not isinstance(library_support, LibrarySupport): - raise TypeError( - f"library_support must be of type LibrarySupport, not " - f"{type(library_support)}" - ) - if not isinstance(simulation, bool): - raise TypeError( - f"simulation must be of type bool, not " f"{type(simulation)}" - ) - return ServerProgram.wrap( - _ServerProgram.load(library_support.cpp(), simulation) - ) - - def get_server_circuit(self, circuit_name: str) -> ServerCircuit: - """Returns a given circuit if it is part of the program. - - Args: - circuit_name (str): name of the circuit to retrieve. - - Raises: - TypeError: if circuit_name is not of type str - RuntimeError: if the circuit is not part of the program - """ - if not isinstance(circuit_name, str): - raise TypeError( - f"circuit_name must be of type str, not {type(circuit_name)}" - ) - - return ServerCircuit.wrap(self.cpp().get_server_circuit(circuit_name)) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_decrypter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_decrypter.py deleted file mode 100644 index b03bc6a8af..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_decrypter.py +++ /dev/null @@ -1,80 +0,0 @@ -"""SimulatedValueDecrypter.""" - -# pylint: disable=no-name-in-module,import-error - -from typing import Union - -import numpy as np -from mlir._mlir_libs._concretelang._compiler import ( - SimulatedValueDecrypter as _SimulatedValueDecrypter, -) - -from .client_parameters import ClientParameters -from .value import Value -from .wrapper import WrapperCpp - -# pylint: enable=no-name-in-module,import-error - - -class SimulatedValueDecrypter(WrapperCpp): - """A helper class to decrypt `Value`s.""" - - def __init__(self, value_decrypter: _SimulatedValueDecrypter): - """ - Wrap the native C++ object. - - Args: - value_decrypter (_SimulatedValueDecrypter): - object to wrap - - Raises: - TypeError: - if `value_decrypter` is not of type `_SimulatedValueDecrypter` - """ - - if not isinstance(value_decrypter, _SimulatedValueDecrypter): - raise TypeError( - f"value_decrypter must be of type _SimulatedValueDecrypter, not {type(value_decrypter)}" - ) - - super().__init__(value_decrypter) - - @staticmethod - # pylint: disable=arguments-differ - def new(client_parameters: ClientParameters, circuit_name: str): - """ - Create a value decrypter. - """ - return SimulatedValueDecrypter( - _SimulatedValueDecrypter.create(client_parameters.cpp(), circuit_name) - ) - - def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]: - """ - Decrypt value. - - Args: - position (int): - position of the argument within the circuit - - value (Value): - value to decrypt - - Returns: - Union[int, np.ndarray]: - decrypted value - """ - - lambda_arg = self.cpp().decrypt(position, value.cpp()) - is_signed = lambda_arg.is_signed() - if lambda_arg.is_scalar(): - return ( - lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar() - ) - - shape = lambda_arg.get_tensor_shape() - return ( - np.array(lambda_arg.get_signed_tensor_data()).reshape(shape) - if is_signed - else np.array(lambda_arg.get_tensor_data()).reshape(shape) - ) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_exporter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_exporter.py deleted file mode 100644 index f6a1e9b651..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/simulated_value_exporter.py +++ /dev/null @@ -1,92 +0,0 @@ -"""SimulatedValueExporter.""" - -# pylint: disable=no-name-in-module,import-error - -from typing import List - -from mlir._mlir_libs._concretelang._compiler import ( - SimulatedValueExporter as _SimulatedValueExporter, -) - -from .client_parameters import ClientParameters -from .value import Value -from .wrapper import WrapperCpp - -# pylint: enable=no-name-in-module,import-error - - -class SimulatedValueExporter(WrapperCpp): - """A helper class to create `Value`s.""" - - def __init__(self, value_exporter: _SimulatedValueExporter): - """ - Wrap the native C++ object. - - Args: - value_exporter (_SimulatedValueExporter): - object to wrap - - Raises: - TypeError: - if `value_exporter` is not of type `_SimulatedValueExporter` - """ - - if not isinstance(value_exporter, _SimulatedValueExporter): - raise TypeError( - f"value_exporter must be of type _SimulatedValueExporter, not {type(value_exporter)}" - ) - - super().__init__(value_exporter) - - @staticmethod - # pylint: disable=arguments-differ - def new( - client_parameters: ClientParameters, circuitName: str - ) -> "SimulatedValueExporter": - """ - Create a value exporter. - """ - return SimulatedValueExporter( - _SimulatedValueExporter.create(client_parameters.cpp(), circuitName) - ) - - def export_scalar(self, position: int, value: int) -> Value: - """ - Export scalar. - - Args: - position (int): - position of the argument within the circuit - - value (int): - scalar to export - - Returns: - Value: - exported scalar - """ - - return Value(self.cpp().export_scalar(position, value)) - - def export_tensor( - self, position: int, values: List[int], shape: List[int] - ) -> Value: - """ - Export tensor. - - Args: - position (int): - position of the argument within the circuit - - values (List[int]): - tensor elements to export - - shape (List[int]): - tensor shape to export - - Returns: - Value: - exported tensor - """ - - return Value(self.cpp().export_tensor(position, values, shape)) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py index 68855c9aa3..3a16cb379b 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py @@ -1,14 +1,14 @@ """Import and export TFHErs integers into Concrete.""" -# pylint: disable=no-name-in-module,import-error +# pylint: disable=no-name-in-module,import-error, from mlir._mlir_libs._concretelang._compiler import ( import_tfhers_fheuint8 as _import_tfhers_fheuint8, export_tfhers_fheuint8 as _export_tfhers_fheuint8, get_tfhers_fheuint8_description as _get_tfhers_fheuint8_description, TfhersFheIntDescription as _TfhersFheIntDescription, + TransportValue, ) -from .value import Value from .wrapper import WrapperCpp # pylint: enable=no-name-in-module,import-error @@ -200,7 +200,7 @@ class TfhersExporter: """A helper class to import and export TFHErs big integers.""" @staticmethod - def export_fheuint8(value: Value, info: TfhersFheIntDescription) -> bytes: + def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes: """Convert Concrete value to TFHErs and serialize it. Args: @@ -213,18 +213,18 @@ def export_fheuint8(value: Value, info: TfhersFheIntDescription) -> bytes: Returns: bytes: converted and serialized fheuint8 """ - if not isinstance(value, Value): - raise TypeError(f"value must be of type Value, not {type(value)}") + if not isinstance(value, TransportValue): + raise TypeError(f"value must be of type TransportValue, not {type(value)}") if not isinstance(info, TfhersFheIntDescription): raise TypeError( f"info must be of type TfhersFheIntDescription, not {type(info)}" ) - return bytes(_export_tfhers_fheuint8(value.cpp(), info.cpp())) + return bytes(_export_tfhers_fheuint8(value, info.cpp())) @staticmethod def import_fheuint8( buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float - ) -> Value: + ) -> TransportValue: """Unserialize and convert from TFHErs to Concrete value. Args: @@ -237,7 +237,7 @@ def import_fheuint8( TypeError: if wrong input types Returns: - Value: unserialized and converted value + TransportValue: unserialized and converted value """ if not isinstance(buffer, bytes): raise TypeError(f"buffer must be of type bytes, not {type(buffer)}") @@ -249,4 +249,4 @@ def import_fheuint8( raise TypeError(f"keyid must be of type int, not {type(keyid)}") if not isinstance(variance, float): raise TypeError(f"variance must be of type float, not {type(variance)}") - return Value.wrap(_import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance)) + return _import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value.py deleted file mode 100644 index 5a7e60867c..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Value.""" - -# pylint: disable=no-name-in-module,import-error - -from mlir._mlir_libs._concretelang._compiler import ( - Value as _Value, -) - -from .wrapper import WrapperCpp - -# pylint: enable=no-name-in-module,import-error - - -class Value(WrapperCpp): - """An encrypted/clear value which can be scalar/tensor.""" - - def __init__(self, value: _Value): - """ - Wrap the native C++ object. - - Args: - value (_Value): - object to wrap - - Raises: - TypeError: - if `value` is not of type `_Value` - """ - - if not isinstance(value, _Value): - raise TypeError(f"value must be of type _Value, not {type(value)}") - - super().__init__(value) - - def serialize(self) -> bytes: - """ - Serialize value into bytes. - - Returns: - bytes: serialized value - """ - - return self.cpp().serialize() - - @staticmethod - def deserialize(serialized_value: bytes) -> "Value": - """ - Deserialize value from bytes. - - Args: - serialized_value (bytes): - previously serialized value - - Returns: - Value: - deserialized value - - Raises: - TypeError: - if `serialized_value` is not of type `bytes` - """ - - if not isinstance(serialized_value, bytes): - raise TypeError( - f"serialized_value must be of type bytes, not {type(serialized_value)}" - ) - - return Value.wrap(_Value.deserialize(serialized_value)) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py deleted file mode 100644 index d424136e30..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py +++ /dev/null @@ -1,83 +0,0 @@ -"""ValueDecrypter.""" - -# pylint: disable=no-name-in-module,import-error - -from typing import Union - -import numpy as np -from mlir._mlir_libs._concretelang._compiler import ( - ValueDecrypter as _ValueDecrypter, -) - -from .client_parameters import ClientParameters -from .key_set import KeySet -from .value import Value -from .wrapper import WrapperCpp - -# pylint: enable=no-name-in-module,import-error - - -class ValueDecrypter(WrapperCpp): - """A helper class to decrypt `Value`s.""" - - def __init__(self, value_decrypter: _ValueDecrypter): - """ - Wrap the native C++ object. - - Args: - value_decrypter (_ValueDecrypter): - object to wrap - - Raises: - TypeError: - if `value_decrypter` is not of type `_ValueDecrypter` - """ - - if not isinstance(value_decrypter, _ValueDecrypter): - raise TypeError( - f"value_decrypter must be of type _ValueDecrypter, not {type(value_decrypter)}" - ) - - super().__init__(value_decrypter) - - @staticmethod - # pylint: disable=arguments-differ - def new( - keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main" - ): - """ - Create a value decrypter. - """ - return ValueDecrypter( - _ValueDecrypter.create(keyset.cpp(), client_parameters.cpp(), circuit_name) - ) - - def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]: - """ - Decrypt value. - - Args: - position (int): - position of the argument within the circuit - - value (Value): - value to decrypt - - Returns: - Union[int, np.ndarray]: - decrypted value - """ - - lambda_arg = self.cpp().decrypt(position, value.cpp()) - is_signed = lambda_arg.is_signed() - if lambda_arg.is_scalar(): - return ( - lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar() - ) - - shape = lambda_arg.get_tensor_shape() - return ( - np.array(lambda_arg.get_signed_tensor_data()).reshape(shape) - if is_signed - else np.array(lambda_arg.get_tensor_data()).reshape(shape) - ) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py deleted file mode 100644 index b23508e45d..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py +++ /dev/null @@ -1,93 +0,0 @@ -"""ValueExporter.""" - -# pylint: disable=no-name-in-module,import-error - -from typing import List - -from mlir._mlir_libs._concretelang._compiler import ( - ValueExporter as _ValueExporter, -) - -from .client_parameters import ClientParameters -from .key_set import KeySet -from .value import Value -from .wrapper import WrapperCpp - -# pylint: enable=no-name-in-module,import-error - - -class ValueExporter(WrapperCpp): - """A helper class to create `Value`s.""" - - def __init__(self, value_exporter: _ValueExporter): - """ - Wrap the native C++ object. - - Args: - value_exporter (_ValueExporter): - object to wrap - - Raises: - TypeError: - if `value_exporter` is not of type `_ValueExporter` - """ - - if not isinstance(value_exporter, _ValueExporter): - raise TypeError( - f"value_exporter must be of type _ValueExporter, not {type(value_exporter)}" - ) - - super().__init__(value_exporter) - - @staticmethod - # pylint: disable=arguments-differ - def new( - keyset: KeySet, client_parameters: ClientParameters, circuit_name: str - ) -> "ValueExporter": - """ - Create a value exporter. - """ - return ValueExporter( - _ValueExporter.create(keyset.cpp(), client_parameters.cpp(), circuit_name) - ) - - def export_scalar(self, position: int, value: int) -> Value: - """ - Export scalar. - - Args: - position (int): - position of the argument within the circuit - - value (int): - scalar to export - - Returns: - Value: - exported scalar - """ - - return Value(self.cpp().export_scalar(position, value)) - - def export_tensor( - self, position: int, values: List[int], shape: List[int] - ) -> Value: - """ - Export tensor. - - Args: - position (int): - position of the argument within the circuit - - values (List[int]): - tensor elements to export - - shape (List[int]): - tensor shape to export - - Returns: - Value: - exported tensor - """ - - return Value(self.cpp().export_tensor(position, values, shape)) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/requirements_dev.txt b/compilers/concrete-compiler/compiler/lib/Bindings/Python/requirements_dev.txt index a5d6647ca4..021be4705c 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/requirements_dev.txt +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/requirements_dev.txt @@ -1,2 +1,3 @@ black==24.4.0 pylint==2.11.1 +mypy==1.11.2 diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp index 6086b4722c..77e823cce2 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp @@ -33,6 +33,8 @@ using concretelang::values::Value; namespace concretelang { namespace clientlib { +bool ClientCircuit::isSimulated() { return simulated; } + Result ClientCircuit::create(const Message &info, const ClientKeyset &keyset, @@ -79,10 +81,27 @@ ClientCircuit::create(const Message &info, outputTransformers.push_back(transformer); } - return ClientCircuit(info, inputTransformers, outputTransformers); + return ClientCircuit(info, inputTransformers, outputTransformers, + useSimulation); +} + +Result ClientCircuit::createEncrypted( + const Message &info, + const ClientKeyset &keyset, + std::shared_ptr csprng) { + return ClientCircuit::create(info, keyset, csprng, false); +} + +Result ClientCircuit::createSimulated( + const Message &info, + std::shared_ptr csprng) { + return ClientCircuit::create(info, ClientKeyset(), csprng, true); } Result ClientCircuit::prepareInput(Value arg, size_t pos) { + if (simulated) { + return StringError("Called prepareInput on simulated client circuit."); + } if (pos >= inputTransformers.size()) { return StringError("Tried to prepare a Value for incorrect position."); } @@ -90,6 +109,34 @@ Result ClientCircuit::prepareInput(Value arg, size_t pos) { } Result ClientCircuit::processOutput(TransportValue result, size_t pos) { + if (simulated) { + return StringError("Called processOutput on simulated client circuit."); + } + if (pos >= outputTransformers.size()) { + return StringError( + "Tried to process a TransportValue for incorrect position."); + } + return outputTransformers[pos](result); +} + +Result ClientCircuit::simulatePrepareInput(Value arg, + size_t pos) { + if (!simulated) { + return StringError( + "Called simulatePrepareInput on encrypted client circuit."); + } + if (pos >= inputTransformers.size()) { + return StringError("Tried to prepare a Value for incorrect position."); + } + return inputTransformers[pos](arg); +} + +Result ClientCircuit::simulateProcessOutput(TransportValue result, + size_t pos) { + if (!simulated) { + return StringError( + "Called simulateProcessOutput on encrypted client circuit."); + } if (pos >= outputTransformers.size()) { return StringError( "Tried to process a TransportValue for incorrect position."); @@ -105,16 +152,26 @@ const Message &ClientCircuit::getCircuitInfo() { return circuitInfo; } -Result -ClientProgram::create(const Message &info, - const ClientKeyset &keyset, - std::shared_ptr csprng, - bool useSimulation) { +Result ClientProgram::createEncrypted( + const Message &info, + const ClientKeyset &keyset, + std::shared_ptr csprng) { + ClientProgram output; + for (auto circuitInfo : info.asReader().getCircuits()) { + OUTCOME_TRY(const ClientCircuit clientCircuit, + ClientCircuit::createEncrypted(circuitInfo, keyset, csprng)); + output.circuits.push_back(clientCircuit); + } + return output; +} + +Result ClientProgram::createSimulated( + const Message &info, + std::shared_ptr csprng) { ClientProgram output; for (auto circuitInfo : info.asReader().getCircuits()) { - OUTCOME_TRY( - ClientCircuit clientCircuit, - ClientCircuit::create(circuitInfo, keyset, csprng, useSimulation)); + OUTCOME_TRY(const ClientCircuit clientCircuit, + ClientCircuit::createSimulated(circuitInfo, csprng)); output.circuits.push_back(clientCircuit); } return output; @@ -195,12 +252,12 @@ exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription desc) { if (!tensorOrError.has_value()) { return StringError("couldn't get tensor from value"); } - size_t buffer_size = + const size_t bufferSize = concrete_cpu_tfhers_fheint_buffer_size_u64(desc.lwe_size, desc.n_cts); - std::vector buffer(buffer_size, 0); - auto flat_data = tensorOrError.value().values; + std::vector buffer(bufferSize, 0); + auto flatData = tensorOrError.value().values; auto size = concrete_cpu_lwe_array_to_tfhers_uint8( - flat_data.data(), buffer.data(), buffer.size(), desc); + flatData.data(), buffer.data(), buffer.size(), desc); if (size == 0) { return StringError("couldn't convert lwe array to fheuint8"); } diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index b2eefd365e..9b9724f70b 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -65,6 +65,42 @@ #include "concretelang/Support/Utils.h" #include +namespace { +/// Returns the path of the shared library +std::string getSharedLibraryPath(std::string outputDirPath) { + llvm::SmallString<0> sharedLibraryPath(outputDirPath); + llvm::sys::path::append( + sharedLibraryPath, + "sharedlib" + + mlir::concretelang::CompilerEngine::Library::DOT_SHARED_LIB_EXT); + return sharedLibraryPath.str().str(); +} + +/// Returns the path of the static library +std::string getStaticLibraryPath(std::string outputDirPath) { + llvm::SmallString<0> staticLibraryPath(outputDirPath); + llvm::sys::path::append( + staticLibraryPath, + "staticlib" + + mlir::concretelang::CompilerEngine::Library::DOT_STATIC_LIB_EXT); + return staticLibraryPath.str().str(); +} + +/// Returns the path of the client parameter +std::string getProgramInfoPath(std::string outputDirPath) { + llvm::SmallString<0> programInfoPath(outputDirPath); + llvm::sys::path::append(programInfoPath, "program_info.concrete.params.json"); + return programInfoPath.str().str(); +} + +/// Returns the path of the compiler feedback +std::string getCompilationFeedbackPath(std::string outputDirPath) { + llvm::SmallString<0> compilationFeedbackPath(outputDirPath); + llvm::sys::path::append(compilationFeedbackPath, "compilation_feedback.json"); + return compilationFeedbackPath.str().str(); +} +} // namespace + namespace mlir { namespace concretelang { @@ -787,38 +823,6 @@ CompilerEngine::compile(mlir::ModuleOp module, std::string outputDirPath, generateStaticLib, generateClientParameters, generateCompilationFeedback); } -/// Returns the path of the shared library -std::string -CompilerEngine::Library::getSharedLibraryPath(std::string outputDirPath) { - llvm::SmallString<0> sharedLibraryPath(outputDirPath); - llvm::sys::path::append(sharedLibraryPath, "sharedlib" + DOT_SHARED_LIB_EXT); - return sharedLibraryPath.str().str(); -} - -/// Returns the path of the static library -std::string -CompilerEngine::Library::getStaticLibraryPath(std::string outputDirPath) { - llvm::SmallString<0> staticLibraryPath(outputDirPath); - llvm::sys::path::append(staticLibraryPath, "staticlib" + DOT_STATIC_LIB_EXT); - return staticLibraryPath.str().str(); -} - -/// Returns the path of the client parameter -std::string -CompilerEngine::Library::getProgramInfoPath(std::string outputDirPath) { - llvm::SmallString<0> programInfoPath(outputDirPath); - llvm::sys::path::append(programInfoPath, "program_info.concrete.params.json"); - return programInfoPath.str().str(); -} - -/// Returns the path of the compiler feedback -std::string -CompilerEngine::Library::getCompilationFeedbackPath(std::string outputDirPath) { - llvm::SmallString<0> compilationFeedbackPath(outputDirPath); - llvm::sys::path::append(compilationFeedbackPath, "compilation_feedback.json"); - return compilationFeedbackPath.str().str(); -} - const std::string CompilerEngine::Library::OBJECT_EXT = ".o"; const std::string CompilerEngine::Library::LINKER = "ld"; #ifdef __APPLE__ @@ -844,20 +848,53 @@ void CompilerEngine::Library::addExtraObjectFilePath(std::string path) { objectsPath.push_back(path); } -Message -CompilerEngine::Library::getProgramInfo() const { - return programInfo; +Result> +CompilerEngine::Library::getProgramInfo() { + if (!programInfo.has_value()) { + programInfo = Message(); + auto path = this->getProgramInfoPath(); + std::ifstream file(path); + std::string content((std::istreambuf_iterator(file)), + (std::istreambuf_iterator())); + if (file.fail()) { + return StringError("Cannot read program info file..."); + } + if (programInfo->readJsonFromString(content).has_failure()) { + return StringError("Program info file corrupted..."); + } + } + return programInfo.value(); } const std::string &CompilerEngine::Library::getOutputDirPath() const { return outputDirPath; } +/// Returns the path of the shared library +std::string CompilerEngine::Library::getSharedLibraryPath() const { + return ::getSharedLibraryPath(getOutputDirPath()); +} + +/// Returns the path of the static library +std::string CompilerEngine::Library::getStaticLibraryPath() const { + return ::getStaticLibraryPath(getOutputDirPath()); +}; + +/// Returns the path of the program info +std::string CompilerEngine::Library::getProgramInfoPath() const { + return ::getProgramInfoPath(getOutputDirPath()); +}; + +/// Returns the path of the compilation feedback +std::string CompilerEngine::Library::getCompilationFeedbackPath() const { + return ::getCompilationFeedbackPath(getOutputDirPath()); +}; + llvm::Expected CompilerEngine::Library::emitProgramInfoJSON() { - auto programInfoPath = getProgramInfoPath(outputDirPath); + auto programInfoPath = ::getProgramInfoPath(outputDirPath); std::error_code error; llvm::raw_fd_ostream out(programInfoPath, error); - auto maybeJson = programInfo.writeJsonToString(); + auto maybeJson = programInfo->writeJsonToString(); if (maybeJson.has_failure()) { return StreamStringError(maybeJson.as_failure().error().mesg); } @@ -870,7 +907,7 @@ llvm::Expected CompilerEngine::Library::emitProgramInfoJSON() { llvm::Expected CompilerEngine::Library::emitCompilationFeedbackJSON() { - auto path = getCompilationFeedbackPath(outputDirPath); + auto path = ::getCompilationFeedbackPath(outputDirPath); llvm::json::Value value(compilationFeedback); std::error_code error; llvm::raw_fd_ostream out(path, error); @@ -993,7 +1030,7 @@ llvm::Expected CompilerEngine::Library::emitShared() { } #endif } - auto path = emit(getSharedLibraryPath(outputDirPath), DOT_SHARED_LIB_EXT, + auto path = emit(::getSharedLibraryPath(outputDirPath), DOT_SHARED_LIB_EXT, LINKER + LINKER_SHARED_OPT, extraArgs); if (path) { sharedLibraryPath = path.get(); @@ -1024,7 +1061,7 @@ llvm::Expected CompilerEngine::Library::emitShared() { } llvm::Expected CompilerEngine::Library::emitStatic() { - auto path = emit(getStaticLibraryPath(outputDirPath), DOT_STATIC_LIB_EXT, + auto path = emit(::getStaticLibraryPath(outputDirPath), DOT_STATIC_LIB_EXT, AR + AR_STATIC_OPT); if (path) { staticLibraryPath = path.get(); diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_benchmark.cpp b/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_benchmark.cpp index cadc37eff5..9db5532800 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_benchmark.cpp +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_benchmarks/end_to_end_benchmark.cpp @@ -1,5 +1,4 @@ #include "../end_to_end_tests/end_to_end_test.h" -#include "concretelang/Common/Compat.h" #include "concretelang/TestLib/TestProgram.h" #include diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc index dd0a895e27..2597448554 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc @@ -499,6 +499,6 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { return %arg0: !FHE.eint<3> } )XXX")); - ASSERT_ASSIGN_OUTCOME_VALUE(result, circuit.call({Tensor(7)})); + ASSERT_ASSIGN_OUTCOME_VALUE(result, circuit.simulate({Tensor(7)})); ASSERT_EQ(result[0].getTensor().value()[0], (uint64_t)(7)); } diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc index 65ecf0a7dc..3e8353d379 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc @@ -58,7 +58,9 @@ class EndToEndTest : public ::testing::Test { void testOnce() { for (auto tests_rep = 0; tests_rep <= options.numberOfRetry; tests_rep++) { // We execute the circuit. - auto maybeRes = testCircuit->call(args); + auto maybeRes = testCircuit->isSimulation() ? testCircuit->simulate(args) + : testCircuit->call(args); + ASSERT_OUTCOME_HAS_VALUE(maybeRes); auto result = maybeRes.value(); @@ -96,7 +98,8 @@ class EndToEndTest : public ::testing::Test { auto nbError = 0; for (size_t i = 0; i < errorRate->nb_repetition; i++) { // We execute the circuit. - auto maybeRes = (*testCircuit).call(args); + auto maybeRes = testCircuit->isSimulation() ? testCircuit->simulate(args) + : testCircuit->call(args); ASSERT_OUTCOME_HAS_VALUE(maybeRes); auto result = maybeRes.value(); diff --git a/compilers/concrete-compiler/compiler/tests/python/conftest.py b/compilers/concrete-compiler/compiler/tests/python/conftest.py index 9dae9d3e88..c021ca8cbf 100644 --- a/compilers/concrete-compiler/compiler/tests/python/conftest.py +++ b/compilers/concrete-compiler/compiler/tests/python/conftest.py @@ -1,7 +1,7 @@ import os import tempfile import pytest -from concrete.compiler import KeySetCache +from concrete.compiler import KeysetCache KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache") @@ -12,7 +12,7 @@ def pytest_configure(config): @pytest.fixture(scope="session") def keyset_cache(): - return KeySetCache.new(KEY_SET_CACHE_PATH) + return KeysetCache(KEY_SET_CACHE_PATH) @pytest.fixture(scope="session") diff --git a/compilers/concrete-compiler/compiler/tests/python/overflow.py b/compilers/concrete-compiler/compiler/tests/python/overflow.py index b8ba3eaa64..044f59e596 100644 --- a/compilers/concrete-compiler/compiler/tests/python/overflow.py +++ b/compilers/concrete-compiler/compiler/tests/python/overflow.py @@ -1,7 +1,8 @@ import sys import shutil import numpy as np -from concrete.compiler import LibrarySupport +import json +from concrete.compiler import Compiler, lookup_runtime_lib from test_simulation import compile_run_assert @@ -14,14 +15,12 @@ mlir_input = f.read() artifact_dir = "./py_test_lib_compile_and_run" - engine = LibrarySupport.new(artifact_dir) - args = list(map(int, sys.argv[2:-1])) - expected_result = int(sys.argv[-1]) - args_and_shape = [] - for arg in args: - if isinstance(arg, int): - args_and_shape.append((arg, None)) - else: # np.array - args_and_shape.append((arg.flatten().tolist(), list(arg.shape))) - compile_run_assert(engine, mlir_input, args_and_shape, expected_result) + engine = Compiler( + artifact_dir, + lookup_runtime_lib(), + ) + args_and_res = json.loads(sys.argv[2]) + args = args_and_res[0] + expected_results = args_and_res[1] + compile_run_assert(engine, mlir_input, args, expected_results) shutil.rmtree(artifact_dir) diff --git a/compilers/concrete-compiler/compiler/tests/python/test_argument_support.py b/compilers/concrete-compiler/compiler/tests/python/test_argument_support.py index 7210b509d2..605cf2d7c9 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_argument_support.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_argument_support.py @@ -1,7 +1,7 @@ import pytest import numpy as np from concrete.compiler.utils import ACCEPTED_NUMPY_UINTS -from concrete.compiler import ClientSupport +from concrete.compiler import Value @pytest.mark.parametrize( @@ -16,8 +16,8 @@ ], ) def test_invalid_arg_type(garbage): - with pytest.raises(TypeError): - ClientSupport._create_lambda_argument(garbage, signed=False) + with pytest.raises(RuntimeError): + Value(garbage) @pytest.mark.parametrize( @@ -32,11 +32,11 @@ def test_invalid_arg_type(garbage): ) def test_accepted_ints(value): try: - arg = ClientSupport._create_lambda_argument(value, signed=False) + arg = Value(value) except Exception: pytest.fail(f"value of type {type(value)} should be supported") assert arg.is_scalar(), "should have been a scalar" - assert arg.get_signed_scalar() == value + assert arg.to_py_val() == value # TODO: #495 @@ -52,16 +52,16 @@ def test_accepted_ints(value): def test_accepted_ndarray(dtype, maxvalue): value = np.array([0, 1, 2, maxvalue], dtype=dtype) try: - arg = ClientSupport._create_lambda_argument(value, signed=False) + arg = Value(value) except Exception: pytest.fail(f"value of type {type(value)} should be supported") assert arg.is_tensor(), "should have been a tensor" - assert np.all(np.equal(arg.get_tensor_shape(), value.shape)) + assert np.all(np.equal(arg.get_shape(), value.shape)) assert np.all( np.equal( - value.astype(np.int64), - np.array(arg.get_signed_tensor_data()).reshape(arg.get_tensor_shape()), + value, + np.array(arg.to_py_val()), ) ) @@ -69,8 +69,8 @@ def test_accepted_ndarray(dtype, maxvalue): def test_accepted_array_as_scalar(): value = np.array(7, dtype=np.uint16) try: - arg = ClientSupport._create_lambda_argument(value, signed=False) + arg = Value(value) except Exception: pytest.fail(f"value of type {type(value)} should be supported") assert arg.is_scalar(), "should have been a scalar" - assert arg.get_signed_scalar() == value + assert arg.to_py_val() == value diff --git a/compilers/concrete-compiler/compiler/tests/python/test_client_server.py b/compilers/concrete-compiler/compiler/tests/python/test_client_server.py index 84072b1ac5..3bf6378788 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_client_server.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_client_server.py @@ -4,16 +4,23 @@ import tempfile from concrete.compiler import ( - ClientSupport, - EvaluationKeys, - LibrarySupport, - PublicArguments, - PublicResult, + Library, + Compiler, + lookup_runtime_lib, + CompilationOptions, + Backend, + Keyset, + ClientProgram, + ServerKeyset, + Value, + ServerProgram, + TransportValue, + CompilationContext, ) @pytest.mark.parametrize( - "mlir, args, expected_result", + "mlir, args, expected_results", [ pytest.param( """ @@ -25,9 +32,8 @@ """, (5, 7), - 12, + (12,), id="enc_plain_int_args", - marks=pytest.mark.xfail, ), pytest.param( """ @@ -39,12 +45,12 @@ """, (5, 7), - 12, + (12,), id="enc_enc_int_args", ), pytest.param( """ - + func.func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint<5> { %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : (tensor<4x!FHE.eint<5>>, tensor<4xi6>) -> !FHE.eint<5> return %ret : !FHE.eint<5> @@ -52,12 +58,11 @@ """, ( - np.array([1, 2, 3, 4], dtype=np.uint64), - np.array([4, 3, 2, 1], dtype=np.uint8), + np.array([1, 2, 3, 4], dtype=np.int64), + np.array([4, 3, 2, 1], dtype=np.int64), ), - 20, + (20,), id="enc_plain_ndarray_args", - marks=pytest.mark.xfail, ), pytest.param( """ @@ -69,10 +74,10 @@ """, ( - np.array([1, 2, 3, 4], dtype=np.uint64), - np.array([7, 0, 1, 5], dtype=np.uint64), + np.array([1, 2, 3, 4], dtype=np.int64), + np.array([7, 0, 1, 5], dtype=np.int64), ), - np.array([8, 2, 4, 9]), + (np.array([8, 2, 4, 9]),), id="enc_enc_ndarray_args", ), pytest.param( @@ -99,38 +104,48 @@ ), ], ) -def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache): +def test_client_server_end_to_end(mlir, args, expected_results, keyset_cache): with tempfile.TemporaryDirectory() as tmpdirname: - support = LibrarySupport.new(str(tmpdirname)) - compilation_result = support.compile(mlir) - server_lambda = support.load_server_lambda(compilation_result, False) + support = Compiler( + str(tmpdirname), lookup_runtime_lib(), generate_shared_lib=True + ) + library = support.compile(mlir, CompilationOptions(Backend.CPU)) - client_parameters = support.load_client_parameters(compilation_result) - keyset = ClientSupport.key_set(client_parameters, keyset_cache) + program_info = library.get_program_info() + keyset = Keyset(program_info, keyset_cache) - evaluation_keys = keyset.get_evaluation_keys() + evaluation_keys = keyset.get_server_keys() evaluation_keys_serialized = evaluation_keys.serialize() - evaluation_keys_deserialized = EvaluationKeys.deserialize( + evaluation_keys_deserialized = ServerKeyset.deserialize( evaluation_keys_serialized ) - args = ClientSupport.encrypt_arguments(client_parameters, keyset, args) - args_serialized = args.serialize() - args_deserialized = PublicArguments.deserialize( - client_parameters, args_serialized - ) + client_program = ClientProgram.create_encrypted(program_info, keyset) + client_circuit = client_program.get_client_circuit("main") + args_serialized = [ + client_circuit.prepare_input(Value(arg), i).serialize() + for (i, arg) in enumerate(args) + ] + args_deserialized = [TransportValue.deserialize(arg) for arg in args_serialized] - result = support.server_call( - server_lambda, + server_program = ServerProgram(library, False) + server_circuit = server_program.get_server_circuit("main") + + results = server_circuit.call( args_deserialized, evaluation_keys_deserialized, ) - result_serialized = result.serialize() - result_deserialized = PublicResult.deserialize( - client_parameters, result_serialized - ) - - output = ClientSupport.decrypt_result( - client_parameters, keyset, result_deserialized + results_serialized = [result.serialize() for result in results] + results_deserialized = [ + client_circuit.process_output( + TransportValue.deserialize(result), i + ).to_py_val() + for (i, result) in enumerate(results_serialized) + ] + + assert all( + [ + np.all(result == expected) + for (result, expected) in zip(results_deserialized, expected_results) + ] ) - assert np.array_equal(output, expected_result) diff --git a/compilers/concrete-compiler/compiler/tests/python/test_compilation.py b/compilers/concrete-compiler/compiler/tests/python/test_compilation.py index 464c464723..5e42c702cc 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_compilation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_compilation.py @@ -4,32 +4,38 @@ import shutil import numpy as np from concrete.compiler import ( - LibrarySupport, - ClientSupport, + Compiler, CompilationOptions, ProgramCompilationFeedback, CircuitCompilationFeedback, + Backend, + lookup_runtime_lib, + Keyset, + Library, + ServerKeyset, + ServerProgram, + ClientProgram, + TransportValue, + Value, ) -def assert_result(result, expected_result): +def assert_result(results, expected_results): """Assert that result and expected result are equal. result and expected_result can be integers on numpy arrays. """ - assert type(expected_result) == type(result) - if isinstance(expected_result, int): - assert result == expected_result - else: + for result, expected_result in zip(results, expected_results): + assert type(expected_result) == type(result) assert np.all(result == expected_result) -def run(engine, args, compilation_result, keyset_cache, circuit_name): +def run(library: Library, args, keyset_cache, circuit_name): """Execute engine on the given arguments. Perform required loading, encryption, execution, and decryption.""" # Dev - compilation_feedback = engine.load_compilation_feedback(compilation_result) + compilation_feedback = library.get_program_compilation_feedback() assert isinstance(compilation_feedback, ProgramCompilationFeedback) assert isinstance(compilation_feedback.complexity, float) assert isinstance(compilation_feedback.p_error, float) @@ -37,42 +43,54 @@ def run(engine, args, compilation_result, keyset_cache, circuit_name): assert isinstance(compilation_feedback.total_secret_keys_size, int) assert isinstance(compilation_feedback.total_bootstrap_keys_size, int) assert isinstance(compilation_feedback.circuit_feedbacks, list) - circuit_feedback = next( - filter(lambda x: x.name == circuit_name, compilation_feedback.circuit_feedbacks) - ) + circuit_feedback = compilation_feedback.get_circuit_feedback(circuit_name) assert isinstance(circuit_feedback, CircuitCompilationFeedback) assert isinstance(circuit_feedback.total_inputs_size, int) assert isinstance(circuit_feedback.total_output_size, int) - # Client - client_parameters = engine.load_client_parameters(compilation_result) - key_set = ClientSupport.key_set(client_parameters, keyset_cache) - public_arguments = ClientSupport.encrypt_arguments( - client_parameters, key_set, args, circuit_name - ) - # Server - server_lambda = engine.load_server_lambda(compilation_result, False, circuit_name) - evaluation_keys = key_set.get_evaluation_keys() - public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys) - # Client - result = ClientSupport.decrypt_result( - client_parameters, key_set, public_result, circuit_name + program_info = library.get_program_info() + keyset = Keyset(program_info, keyset_cache) + + evaluation_keys = keyset.get_server_keys() + evaluation_keys_serialized = evaluation_keys.serialize() + evaluation_keys_deserialized = ServerKeyset.deserialize(evaluation_keys_serialized) + + client_program = ClientProgram.create_encrypted(program_info, keyset) + client_circuit = client_program.get_client_circuit(circuit_name) + args_serialized = [ + client_circuit.prepare_input(Value(arg), i).serialize() + for (i, arg) in enumerate(args) + ] + args_deserialized = [TransportValue.deserialize(arg) for arg in args_serialized] + + server_program = ServerProgram(library, False) + server_circuit = server_program.get_server_circuit(circuit_name) + + results = server_circuit.call( + args_deserialized, + evaluation_keys_deserialized, ) - return result + results_serialized = [result.serialize() for result in results] + results_deserialized = [ + client_circuit.process_output(TransportValue.deserialize(result), i).to_py_val() + for (i, result) in enumerate(results_serialized) + ] + + return results_deserialized def compile_run_assert( - engine, + compiler, mlir_input, args, expected_result, keyset_cache, - options=CompilationOptions.new(), + options=CompilationOptions(Backend.CPU), circuit_name="main", ): """Compile run and assert result.""" - compilation_result = engine.compile(mlir_input, options) - result = run(engine, args, compilation_result, keyset_cache, circuit_name) + library = compiler.compile(mlir_input, options) + result = run(library, args, keyset_cache, circuit_name) assert_result(result, expected_result) @@ -85,7 +103,7 @@ def compile_run_assert( } """, (5, 7), - 12, + (12,), id="add_eint_int", ), pytest.param( @@ -95,8 +113,8 @@ def compile_run_assert( return %1: !FHE.eint<7> } """, - (np.array(4, dtype=np.int64), np.array(5, dtype=np.uint8)), - 9, + (np.array(4), np.array(5)), + (9,), id="add_eint_int_with_ndarray_as_scalar", ), pytest.param( @@ -108,7 +126,7 @@ def compile_run_assert( } """, (73,), - 73, + (73,), id="apply_lookup_table", ), pytest.param( @@ -121,10 +139,10 @@ def compile_run_assert( } """, ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([4, 3, 2, 1], dtype=np.uint8), + np.array([1, 2, 3, 4]), + np.array([4, 3, 2, 1]), ), - 20, + (20,), id="dot_eint_int_uint8", ), pytest.param( @@ -135,10 +153,10 @@ def compile_run_assert( } """, ( - np.array([31, 6, 12, 9], dtype=np.uint8), - np.array([32, 9, 2, 3], dtype=np.uint8), + np.array([31, 6, 12, 9]), + np.array([32, 9, 2, 3]), ), - np.array([63, 15, 14, 12]), + (np.array([63, 15, 14, 12]),), id="add_eint_int_1D", ), pytest.param( @@ -149,7 +167,7 @@ def compile_run_assert( } """, (5,), - -5, + (-5,), id="neg_eint_signed", ), pytest.param( @@ -159,8 +177,12 @@ def compile_run_assert( return %0: tensor<2x!FHE.esint<7>> } """, - (np.array([-5, 3]),), - np.array([5, -3]), + ( + np.array( + [-5, 3], + ), + ), + (np.array([5, -3]),), id="neg_eint_signed_2", ), ] @@ -177,10 +199,10 @@ def compile_run_assert( } """, ( - np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8), - np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]], dtype=np.uint8), + np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]]), + np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]]), ), - np.array([[52, 36], [31, 34], [42, 52]]), + (np.array([[52, 36], [31, 34], [42, 52]]),), id="matmul_eint_int_uint8", ), pytest.param( @@ -193,12 +215,12 @@ def compile_run_assert( } """, ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([9, 8, 6, 5], dtype=np.uint8), - np.array([3, 2, 7, 0], dtype=np.uint8), - np.array([1, 4, 2, 11], dtype=np.uint8), + np.array([1, 2, 3, 4]), + np.array([9, 8, 6, 5]), + np.array([3, 2, 7, 0]), + np.array([1, 4, 2, 11]), ), - np.array([14, 16, 18, 20]), + (np.array([14, 16, 18, 20]),), id="add_eint_int_1D", ), ] @@ -207,19 +229,21 @@ def compile_run_assert( @pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) def test_lib_compile_and_run(mlir_input, args, expected_result, keyset_cache): artifact_dir = "./py_test_lib_compile_and_run" - engine = LibrarySupport.new(artifact_dir) - compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache) + compiler = Compiler(artifact_dir, lookup_runtime_lib()) + compile_run_assert(compiler, mlir_input, args, expected_result, keyset_cache) shutil.rmtree(artifact_dir) @pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_cache): artifact_dir = "./test_lib_compile_reload_and_run" - engine = LibrarySupport.new(artifact_dir) + library = Library(artifact_dir) # Here don't save compilation result, reload - engine.compile(mlir_input) - compilation_result = engine.reload() - result = run(engine, args, compilation_result, keyset_cache, "main") + library = Compiler(artifact_dir, lookup_runtime_lib()).compile( + mlir_input, CompilationOptions(Backend.CPU) + ) + compilation_result = library.get_program_compilation_feedback() + result = run(library, args, keyset_cache, "main") # Check result assert_result(result, expected_result) shutil.rmtree(artifact_dir) @@ -233,13 +257,14 @@ def test_lib_compilation_artifacts(): } """ artifact_dir = "./test_artifacts" - engine = LibrarySupport.new(artifact_dir) - engine.compile(mlir_str) - assert os.path.exists(engine.get_program_info_path()) - assert os.path.exists(engine.get_shared_lib_path()) + library = Compiler(artifact_dir, lookup_runtime_lib()).compile( + mlir_str, CompilationOptions(Backend.CPU) + ) + assert os.path.exists(library.get_program_info_path()) + assert os.path.exists(library.get_shared_lib_path()) shutil.rmtree(artifact_dir) - assert not os.path.exists(engine.get_program_info_path()) - assert not os.path.exists(engine.get_shared_lib_path()) + assert not os.path.exists(library.get_program_info_path()) + assert not os.path.exists(library.get_shared_lib_path()) def test_multi_circuits(keyset_cache): @@ -256,16 +281,29 @@ def test_multi_circuits(keyset_cache): } """ args = (10, 3) - expected_add_result = 13 - expected_sub_result = 7 - engine = LibrarySupport.new("./py_test_multi_circuits") - options = CompilationOptions.new() + expected_add_result = (13,) + expected_sub_result = (7,) + artifact_dir = "./py_test_multi_circuits" + options = CompilationOptions(Backend.CPU) options.set_optimizer_strategy(OptimizerStrategy.V0) + compiler = Compiler(artifact_dir, lookup_runtime_lib()) compile_run_assert( - engine, mlir_str, args, expected_add_result, keyset_cache, options, "add" + compiler, + mlir_str, + args, + expected_add_result, + keyset_cache, + options, + circuit_name="add", ) compile_run_assert( - engine, mlir_str, args, expected_sub_result, keyset_cache, options, "sub" + compiler, + mlir_str, + args, + expected_sub_result, + keyset_cache, + options, + circuit_name="sub", ) @@ -278,21 +316,25 @@ def _test_lib_compile_and_run_with_options(keyset_cache, options): } """ args = (73,) - expected_result = 73 - engine = LibrarySupport.new("./py_test_lib_compile_and_run_custom_perror") + expected_result = (73,) + compiler = Compiler( + "./py_test_lib_compile_and_run_custom_perror", lookup_runtime_lib() + ) - compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options) + compile_run_assert( + compiler, mlir_input, args, expected_result, keyset_cache, options + ) def test_lib_compile_and_run_p_error(keyset_cache): - options = CompilationOptions.new() + options = CompilationOptions(Backend.CPU) options.set_p_error(0.00001) options.set_display_optimizer_choice(True) _test_lib_compile_and_run_with_options(keyset_cache, options) def test_lib_compile_and_run_global_p_error(keyset_cache): - options = CompilationOptions.new() + options = CompilationOptions(Backend.CPU) options.set_global_p_error(0.00001) options.set_display_optimizer_choice(True) _test_lib_compile_and_run_with_options(keyset_cache, options) @@ -306,43 +348,14 @@ def test_compile_and_run_auto_parallelize( mlir_input, args, expected_result, keyset_cache ): artifact_dir = "./py_test_compile_and_run_auto_parallelize" - engine = LibrarySupport.new(artifact_dir) - options = CompilationOptions.new() + options = CompilationOptions(Backend.CPU) options.set_auto_parallelize(True) + engine = Compiler(artifact_dir, lookup_runtime_lib()) compile_run_assert( engine, mlir_input, args, expected_result, keyset_cache, options=options ) -# This test was running in JIT mode at first. Problem is now, it does not work with the library -# support. It is not clear to me why, but the dataflow runtime seems to have stuffs dedicated to -# the dropped JIT support... I am cancelling it until further explored. -# -# # FIXME #51 -# @pytest.mark.xfail( -# platform.system() == "Darwin", -# reason="MacOS have issues with translating Cpp exceptions", -# ) -# @pytest.mark.parametrize( -# "mlir_input, args, expected_result", end_to_end_parallel_fixture -# ) -# def test_compile_dataflow_and_fail_run( -# mlir_input, args, expected_result, keyset_cache, no_parallel -# ): -# if no_parallel: -# artifact_dir = "./py_test_compile_dataflow_and_fail_run" -# engine = LibrarySupport.new(artifact_dir) -# options = CompilationOptions.new() -# options.set_auto_parallelize(True) -# with pytest.raises( -# RuntimeError, -# match="call: current runtime doesn't support dataflow execution", -# ): -# compile_run_assert( -# engine, mlir_input, args, expected_result, keyset_cache, options=options -# ) - - @pytest.mark.parametrize( "mlir_input, args, expected_result", [ @@ -354,8 +367,8 @@ def test_compile_and_run_auto_parallelize( return %0 : tensor<3x2x!FHE.eint<7>> } """, - (np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8),), - np.array([[26, 18], [15, 16], [21, 26]]), + (np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]]),), + (np.array([[26, 18], [15, 16], [21, 26]]),), id="matmul_eint_int_uint8", ), ], @@ -364,11 +377,11 @@ def test_compile_and_run_loop_parallelize( mlir_input, args, expected_result, keyset_cache ): artifact_dir = "./py_test_compile_and_run_loop_parallelize" - engine = LibrarySupport.new(artifact_dir) - options = CompilationOptions.new() + compiler = Compiler(artifact_dir, lookup_runtime_lib()) + options = CompilationOptions(Backend.CPU) options.set_loop_parallelize(True) compile_run_assert( - engine, mlir_input, args, expected_result, keyset_cache, options=options + compiler, mlir_input, args, expected_result, keyset_cache, options=options ) @@ -377,7 +390,7 @@ def test_compile_and_run_loop_parallelize( [ pytest.param( """ - func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + func.func @main%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } @@ -389,11 +402,9 @@ def test_compile_and_run_loop_parallelize( ) def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache): artifact_dir = "./py_test_compile_and_run_invalid_arg_number" - engine = LibrarySupport.new(artifact_dir) - with pytest.raises( - RuntimeError, match=r"function has arity 2 but is applied to too many arguments" - ): - compile_run_assert(engine, mlir_input, args, None, keyset_cache) + compiler = Compiler(artifact_dir, lookup_runtime_lib()) + with pytest.raises(RuntimeError): + compile_run_assert(compiler, mlir_input, args, None, keyset_cache) def test_crt_decomposition_feedback(): @@ -408,9 +419,9 @@ def test_crt_decomposition_feedback(): """ artifact_dir = "./py_test_crt_decomposition_feedback" - engine = LibrarySupport.new(artifact_dir) - compilation_result = engine.compile(mlir, options=CompilationOptions.new()) - compilation_feedback = engine.load_compilation_feedback(compilation_result) + compiler = Compiler(artifact_dir, lookup_runtime_lib()) + library = compiler.compile(mlir, options=CompilationOptions(Backend.CPU)) + compilation_feedback = library.get_program_compilation_feedback() assert isinstance(compilation_feedback, ProgramCompilationFeedback) assert isinstance(compilation_feedback.complexity, float) @@ -445,7 +456,7 @@ def test_crt_decomposition_feedback(): } """, # 4*4*4097*8 (input1) + 4*2 (input2) + 4*2*4097*8 + 4097*3*8 + 4096*8 + 869*8 (temporary buffers) + 4*2*4097*8 (output buffer) + 64*8 (constant TLU) - {'loc("some/random/location.py":10:2)': 1187400}, + {'loc("some/random/location.py":10:2)': 1187584}, id="single location", ), pytest.param( @@ -461,7 +472,7 @@ def test_crt_decomposition_feedback(): # 4*4*4097*8 (input1) + 4*2 (input2) + 4*2*4097*8 (matmul result buffer) + 4097*2*8 (temporary buffers) 'loc("@matmul some/random/location.py":10:2)': 852184, # 4*2*4097*8 (matmul result buffer) + 4*2*4097*8 (result buffer) + 4097*8 + 4096*8 + 869*8 (temporary buffers) + 64*8 (constant TLU) - 'loc("@lut some/random/location.py":11:2)': 597424, + 'loc("@lut some/random/location.py":11:2)': 597608, # 4*2*4097*8 (result buffer) 'loc("@return some/random/location.py":12:2)': 262208, }, @@ -471,9 +482,9 @@ def test_crt_decomposition_feedback(): ) def test_memory_usage(mlir: str, expected_memory_usage_per_loc: dict): artifact_dir = "./test_memory_usage" - engine = LibrarySupport.new(artifact_dir) - compilation_result = engine.compile(mlir) - compilation_feedback = engine.load_compilation_feedback(compilation_result) + compiler = Compiler(artifact_dir, lookup_runtime_lib()) + library = compiler.compile(mlir, CompilationOptions(Backend.CPU)) + compilation_feedback = library.get_program_compilation_feedback() assert isinstance(compilation_feedback, ProgramCompilationFeedback) assert ( diff --git a/compilers/concrete-compiler/compiler/tests/python/test_keyset_serialization.py b/compilers/concrete-compiler/compiler/tests/python/test_keyset_serialization.py index ca4b4e2817..06db84cc04 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_keyset_serialization.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_keyset_serialization.py @@ -4,12 +4,16 @@ import tempfile from concrete.compiler import ( - ClientSupport, - EvaluationKeys, - KeySet, - LibrarySupport, - PublicArguments, - PublicResult, + Compiler, + CompilationOptions, + lookup_runtime_lib, + Keyset, + ServerKeyset, + ClientProgram, + ServerProgram, + TransportValue, + Value, + Backend, ) @@ -27,26 +31,51 @@ def test_keyset_serialization(): """.strip() with tempfile.TemporaryDirectory() as tmpdirname: - support = LibrarySupport.new(str(tmpdirname)) - compilation_result = support.compile(mlir) - server_lambda = support.load_server_lambda(compilation_result, False) - client_parameters = support.load_client_parameters(compilation_result) + args = (5,) + expected_results = (25,) - keyset = ClientSupport.key_set(client_parameters) - evaluation_keys = keyset.get_evaluation_keys() + support = Compiler( + str(tmpdirname), lookup_runtime_lib(), generate_shared_lib=True + ) + library = support.compile(mlir, CompilationOptions(Backend.CPU)) + + program_info = library.get_program_info() + keyset = Keyset(program_info, None) + keyset = Keyset.deserialize(keyset.serialize()) - arg = 5 - encrypted_args = ClientSupport.encrypt_arguments( - client_parameters, keyset, [arg] + evaluation_keys = keyset.get_server_keys() + evaluation_keys_serialized = evaluation_keys.serialize() + evaluation_keys_deserialized = ServerKeyset.deserialize( + evaluation_keys_serialized ) - result = support.server_call(server_lambda, encrypted_args, evaluation_keys) + client_program = ClientProgram.create_encrypted(program_info, keyset) + client_circuit = client_program.get_client_circuit("main") + args_serialized = [ + client_circuit.prepare_input(Value(arg), i).serialize() + for (i, arg) in enumerate(args) + ] + args_deserialized = [TransportValue.deserialize(arg) for arg in args_serialized] - serialized_keyset = keyset.serialize() - deserialized_keyset = KeySet.deserialize(serialized_keyset) + server_program = ServerProgram(library, False) + server_circuit = server_program.get_server_circuit("main") + + results = server_circuit.call( + args_deserialized, + evaluation_keys_deserialized, + ) + results_serialized = [result.serialize() for result in results] + results_deserialized = [ + client_circuit.process_output( + TransportValue.deserialize(result), i + ).to_py_val() + for (i, result) in enumerate(results_serialized) + ] - output = ClientSupport.decrypt_result( - client_parameters, deserialized_keyset, result + assert all( + [ + np.all(result == expected) + for (result, expected) in zip(results_deserialized, expected_results) + ] ) - assert output == arg**2 diff --git a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py index 6d36f8aa0c..ce76ba59e4 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py @@ -1,3 +1,4 @@ +import json import subprocess import sys import os @@ -6,60 +7,88 @@ import shutil import numpy as np from concrete.compiler import ( - LibrarySupport, - PublicArguments, - SimulatedValueExporter, - SimulatedValueDecrypter, + Compiler, CompilationOptions, + ProgramCompilationFeedback, + CircuitCompilationFeedback, + Backend, + lookup_runtime_lib, + Keyset, + Library, + ServerKeyset, + ServerProgram, + ClientProgram, + TransportValue, + Value, ) -def assert_result(result, expected_result): +def assert_result(results, expected_results): """Assert that result and expected result are equal. result and expected_result can be integers on numpy arrays. """ - assert type(expected_result) == type(result) - if isinstance(expected_result, int): - assert result == expected_result, f"{result} != {expected_result}" - else: + for result, expected_result in zip(results, expected_results): + assert type(expected_result) == type(result) assert np.all(result == expected_result) -def run_simulated(engine, args_and_shape, compilation_result): - client_parameters = engine.load_client_parameters(compilation_result) - sim_value_exporter = SimulatedValueExporter.new(client_parameters) - values = [] - pos = 0 - for arg, shape in args_and_shape: - if shape is None: - assert isinstance(arg, int) - values.append(sim_value_exporter.export_scalar(pos, arg)) - else: - assert isinstance(arg, list) - assert isinstance(shape, list) - values.append(sim_value_exporter.export_tensor(pos, arg, shape)) - pos += 1 - public_arguments = PublicArguments.new(client_parameters, values) - server_lambda = engine.load_server_lambda(compilation_result, True) - public_result = engine.simulate(server_lambda, public_arguments) - sim_value_decrypter = SimulatedValueDecrypter.new(client_parameters) - result = sim_value_decrypter.decrypt(0, public_result.get_value(0)) - return result +def run_simulated(library: Library, args, circuit_name): + """Execute engine on the given arguments. + + Perform required loading, encryption, execution, and decryption.""" + # Dev + compilation_feedback = library.get_program_compilation_feedback() + assert isinstance(compilation_feedback, ProgramCompilationFeedback) + assert isinstance(compilation_feedback.complexity, float) + assert isinstance(compilation_feedback.p_error, float) + assert isinstance(compilation_feedback.global_p_error, float) + assert isinstance(compilation_feedback.total_secret_keys_size, int) + assert isinstance(compilation_feedback.total_bootstrap_keys_size, int) + assert isinstance(compilation_feedback.circuit_feedbacks, list) + circuit_feedback = compilation_feedback.get_circuit_feedback(circuit_name) + assert isinstance(circuit_feedback, CircuitCompilationFeedback) + assert isinstance(circuit_feedback.total_inputs_size, int) + assert isinstance(circuit_feedback.total_output_size, int) + + program_info = library.get_program_info() + + client_program = ClientProgram.create_simulated(program_info) + client_circuit = client_program.get_client_circuit(circuit_name) + args_serialized = [ + client_circuit.simulate_prepare_input(Value(arg), i).serialize() + for (i, arg) in enumerate(args) + ] + args_deserialized = [TransportValue.deserialize(arg) for arg in args_serialized] + + server_program = ServerProgram(library, True) + server_circuit = server_program.get_server_circuit(circuit_name) + + results = server_circuit.simulate(args_deserialized) + results_serialized = [result.serialize() for result in results] + results_deserialized = [ + client_circuit.simulate_process_output( + TransportValue.deserialize(result), i + ).to_py_val() + for (i, result) in enumerate(results_serialized) + ] + + return results_deserialized def compile_run_assert( - engine, + compiler, mlir_input, - args_and_shape, + args, expected_result, - options=CompilationOptions.new(), + options=CompilationOptions(Backend.CPU), + circuit_name="main", ): - # compile with simulation + """Compile run and assert result.""" options.simulation(True) options.set_enable_overflow_detection_in_simulation(True) - compilation_result = engine.compile(mlir_input, options) - result = run_simulated(engine, args_and_shape, compilation_result) + library = compiler.compile(mlir_input, options) + result = run_simulated(library, args, circuit_name) assert_result(result, expected_result) @@ -72,7 +101,7 @@ def compile_run_assert( } """, (5, 7), - 12, + (12,), id="add_eint_int", ), pytest.param( @@ -84,7 +113,7 @@ def compile_run_assert( } """, (73,), - 73, + (73,), id="apply_lookup_table", ), pytest.param( @@ -97,10 +126,10 @@ def compile_run_assert( } """, ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([4, 3, 2, 1], dtype=np.uint8), + np.array([1, 2, 3, 4]), + np.array([4, 3, 2, 1]), ), - 20, + (20,), id="dot_eint_int_uint8", ), pytest.param( @@ -111,10 +140,10 @@ def compile_run_assert( } """, ( - np.array([31, 6, 12, 9], dtype=np.uint8), - np.array([32, 9, 2, 3], dtype=np.uint8), + np.array([31, 6, 12, 9]), + np.array([32, 9, 2, 3]), ), - np.array([63, 15, 14, 12]), + (np.array([63, 15, 14, 12]),), id="add_eint_int_1D", ), pytest.param( @@ -125,7 +154,7 @@ def compile_run_assert( } """, (5,), - -5, + (-5,), id="neg_eint_signed", ), pytest.param( @@ -136,7 +165,7 @@ def compile_run_assert( } """, (np.array([-5, 3]),), - np.array([5, -3]), + (np.array([5, -3]),), id="neg_eint_signed_2", ), pytest.param( @@ -153,26 +182,28 @@ def compile_run_assert( np.array([i // 4 for i in range(16)]).reshape((4, 4)), np.array([i // 4 for i in range(15, -1, -1)]).reshape((4, 4)), ), - np.array( - [ - 0, - 0, - 0, - 0, - 1296, - 1296, - 1296, - 1296, - 2592, - 2592, - 2592, - 2592, - 3888, - 3888, - 3888, - 3888, - ] - ).reshape((4, 4)), + ( + np.array( + [ + 0, + 0, + 0, + 0, + 1296, + 1296, + 1296, + 1296, + 2592, + 2592, + 2592, + 2592, + 3888, + 3888, + 3888, + 3888, + ] + ).reshape((4, 4)), + ), id="matul_chain_with_crt", ), pytest.param( @@ -186,9 +217,9 @@ def compile_run_assert( """, ( 81, - np.array(range(16384), dtype=np.uint64), + np.array(range(16384)), ), - 96, + (96,), id="add_lut_crt", ), ] @@ -205,10 +236,10 @@ def compile_run_assert( } """, ( - np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8), - np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]], dtype=np.uint8), + np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]]), + np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]]), ), - np.array([[52, 36], [31, 34], [42, 52]]), + (np.array([[52, 36], [31, 34], [42, 52]]),), id="matmul_eint_int_uint8", ), pytest.param( @@ -221,12 +252,12 @@ def compile_run_assert( } """, ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([9, 8, 6, 5], dtype=np.uint8), - np.array([3, 2, 7, 0], dtype=np.uint8), - np.array([1, 4, 2, 11], dtype=np.uint8), + np.array([1, 2, 3, 4]), + np.array([9, 8, 6, 5]), + np.array([3, 2, 7, 0]), + np.array([1, 4, 2, 11]), ), - np.array([14, 16, 18, 20]), + (np.array([14, 16, 18, 20]),), id="add_eint_int_1D", ), ] @@ -235,14 +266,8 @@ def compile_run_assert( @pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): artifact_dir = "./py_test_lib_compile_and_run" - engine = LibrarySupport.new(artifact_dir) - args_and_shape = [] - for arg in args: - if isinstance(arg, int): - args_and_shape.append((arg, None)) - else: # np.array - args_and_shape.append((arg.flatten().tolist(), list(arg.shape))) - compile_run_assert(engine, mlir_input, args_and_shape, expected_result) + compiler = Compiler(artifact_dir, lookup_runtime_lib()) + compile_run_assert(compiler, mlir_input, args, expected_result) shutil.rmtree(artifact_dir) @@ -255,7 +280,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (120, 30), - 150, + (150,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="add_eint_int", ), @@ -267,7 +292,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (-1, -2), - -3, + (-3,), b"", id="add_eint_int_signed", ), @@ -279,7 +304,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (-60, -20), - -80, + (-80,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="add_eint_int_signed_underflow", ), @@ -291,7 +316,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (60, 20), - -48, + (-48,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="add_eint_int_signed_overflow", ), @@ -303,7 +328,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (81, 73), - 154, + (154,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="add_eint", ), @@ -315,7 +340,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (-81, 73), - -8, + (-8,), b"", id="add_eint_signed", ), @@ -327,7 +352,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (-60, -20), - -80, # undefined behavior + (-80,), # undefined behavior b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="add_eint_signed_underflow", ), @@ -339,7 +364,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (81, 73), - -102, + (-102,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="add_eint_signed_overflow", ), @@ -351,7 +376,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (4, 7), - 256 - 3, + (256 - 3,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="sub_eint_int", ), @@ -363,7 +388,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (4, 7), - -3, + (-3,), b"", id="sub_eint_int_signed", ), @@ -375,7 +400,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (-37, 40), - -77, + (-77,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="sub_eint_int_signed_underflow", ), @@ -387,7 +412,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (33, -40), - -55, + (-55,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="sub_eint_int_signed_overflow", ), @@ -399,7 +424,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (11, 18), - 256 - 7, + (256 - 7,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="sub_eint", ), @@ -411,7 +436,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (11, 18), - -7, + (-7,), b"", id="sub_eint_signed", ), @@ -423,7 +448,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (-44, 32), - -76, # undefined behavior + (-76,), # undefined behavior b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="sub_eint_signed_underflow", ), @@ -435,7 +460,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (61, -25), - -42, # undefined behavior + (-42,), # undefined behavior b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="sub_eint_signed_overflow", ), @@ -447,7 +472,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (20, 10), - 200, + (200,), b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', id="mul_eint_int", ), @@ -460,7 +485,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (5, 10), - 256 - 50, + (256 - 50,), b'WARNING at loc("-":3:22): overflow happened during addition in simulation\nWARNING at loc("-":4:22): overflow happened during multiplication in simulation\n', id="sub_mul_eint_int", ), @@ -472,7 +497,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (5, -2), - -10, + (-10,), b"", id="mul_eint_int_signed", ), @@ -484,7 +509,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (-33, 5), - -37, # undefined behavior + (-37,), # undefined behavior b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', id="mul_eint_int_signed_underflow", ), @@ -496,7 +521,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (-33, -5), - -91, + (-91,), b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', id="mul_eint_int_signed_overflow", ), @@ -509,7 +534,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (1,), - 140, + (140,), b'WARNING at loc("-":4:22): overflow (padding bit) happened during LUT in simulation\nWARNING at loc("-":4:22): overflow (original value didn\'t fit, so a modulus was applied) happened during LUT in simulation\n', id="apply_lookup_table_big_value", ), @@ -522,7 +547,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (2,), - -2, + (-2,), b"", id="apply_lookup_table_signed", ), @@ -535,7 +560,7 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): } """, (1,), - -8, + (-8,), b'WARNING at loc("-":4:22): overflow (original value didn\'t fit, so a modulus was applied) happened during LUT in simulation\n', id="apply_lookup_table_signed_big_value", ), @@ -557,8 +582,7 @@ def test_lib_compile_and_run_simulation_with_overflow( # prepare cmd and run script_path = os.path.join(os.path.dirname(__file__), "overflow.py") cmd = [sys.executable, script_path, mlir_file.name] - cmd.extend(map(str, args)) - cmd.append(str(expected_result)) + cmd.append(json.dumps((args, expected_result))) out = subprocess.check_output(cmd, env=os.environ) # close/remove tmp file diff --git a/compilers/concrete-compiler/compiler/tests/python/test_statistics.py b/compilers/concrete-compiler/compiler/tests/python/test_statistics.py index 1c5be1e3e4..e84576739a 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_statistics.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_statistics.py @@ -4,14 +4,14 @@ import tempfile from concrete.compiler import ( - ClientSupport, - EvaluationKeys, - KeySet, - LibrarySupport, - PublicArguments, - PublicResult, + Compiler, + KeyType, + PrimitiveOperation, + lookup_runtime_lib, + Backend, + CompilationOptions, + MoreCircuitCompilationFeedback, ) -from mlir._mlir_libs._concretelang._compiler import KeyType, PrimitiveOperation def test_statistics(): @@ -28,50 +28,52 @@ def test_statistics(): """.strip() with tempfile.TemporaryDirectory() as tmpdirname: - support = LibrarySupport.new(str(tmpdirname)) - compilation_result = support.compile(mlir) + support = Compiler(str(tmpdirname), lookup_runtime_lib()) + library = support.compile(mlir, CompilationOptions(Backend.CPU)) - client_parameters = support.load_client_parameters(compilation_result) - program_compilation_feedback = support.load_compilation_feedback( - compilation_result - ) - compilation_feedback = program_compilation_feedback.circuit("main") + program_info = library.get_program_info() + program_compilation_feedback = library.get_program_compilation_feedback() + compilation_feedback = program_compilation_feedback.get_circuit_feedback("main") - pbs_count = compilation_feedback.count( + pbs_count = MoreCircuitCompilationFeedback.count( + compilation_feedback, operations={ PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS, - } + }, ) assert pbs_count == 1 - pbs_counts_per_parameter = compilation_feedback.count_per_parameter( + pbs_counts_per_parameter = MoreCircuitCompilationFeedback.count_per_parameter( + compilation_feedback, operations={ PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS, }, key_types={KeyType.BOOTSTRAP}, - client_parameters=client_parameters, + program_info=program_info, ) assert len(pbs_counts_per_parameter) == 1 assert pbs_counts_per_parameter[list(pbs_counts_per_parameter.keys())[0]] == 1 - pbs_counts_per_tag = compilation_feedback.count_per_tag( + pbs_counts_per_tag = MoreCircuitCompilationFeedback.count_per_tag( + compilation_feedback, operations={ PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS, - } + }, ) assert pbs_counts_per_tag == {} pbs_counts_per_tag_per_parameter = ( - compilation_feedback.count_per_tag_per_parameter( + MoreCircuitCompilationFeedback.count_per_tag_per_parameter( + compilation_feedback, operations={ PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS, }, key_types={KeyType.BOOTSTRAP}, - client_parameters=client_parameters, + program_info=program_info, ) ) assert pbs_counts_per_tag_per_parameter == {} diff --git a/compilers/concrete-compiler/compiler/tests/python/test_wrappers.py b/compilers/concrete-compiler/compiler/tests/python/test_wrappers.py deleted file mode 100644 index 20f338374c..0000000000 --- a/compilers/concrete-compiler/compiler/tests/python/test_wrappers.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest -from concrete.compiler import ( - ClientParameters, - ClientSupport, - CompilationOptions, - KeySetCache, - KeySet, - LambdaArgument, - LibraryCompilationResult, - LibraryLambda, - LibrarySupport, - PublicArguments, - PublicResult, -) - - -@pytest.mark.parametrize("garbage", ["string here", 23, None]) -@pytest.mark.parametrize( - "WrapperClass", - [ - pytest.param(ClientParameters, id="ClientParameters"), - pytest.param(ClientSupport, id="ClientSupport"), - pytest.param(CompilationOptions, id="CompilationOptions"), - pytest.param(KeySetCache, id="KeySetCache"), - pytest.param(KeySet, id="KeySet"), - pytest.param(LambdaArgument, id="LambdaArgument"), - pytest.param(LibraryCompilationResult, id="LibraryCompilationResult"), - pytest.param(LibraryLambda, id="LibraryLambda"), - pytest.param(LibrarySupport, id="LibrarySupport"), - pytest.param(PublicArguments, id="PublicArguments"), - pytest.param(PublicResult, id="PublicResult"), - ], -) -def test_invalid_wrapping(WrapperClass, garbage): - with pytest.raises( - TypeError, - match=f"\.* must be of type _{WrapperClass.__name__}, not {type(garbage)}", - ): - WrapperClass(garbage) diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index 70c13ed30b..88a1bbf2a8 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -4,8 +4,6 @@ # pylint: disable=import-error,no-name-in-module -from concrete.compiler import EvaluationKeys, Parameter, PublicArguments, PublicResult - from .compilation import ( DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, @@ -23,6 +21,7 @@ Configuration, DebugArtifacts, EncryptionStatus, + EvaluationKeys, Exactness, ) from .compilation import FheFunction as Function diff --git a/frontends/concrete-python/concrete/fhe/compilation/__init__.py b/frontends/concrete-python/concrete/fhe/compilation/__init__.py index 88355d91af..78c8750abe 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/compilation/__init__.py @@ -20,12 +20,13 @@ MultivariateStrategy, ParameterSelectionStrategy, ) +from .evaluation_keys import EvaluationKeys from .keys import Keys from .module import FheFunction, FheModule from .module_compiler import FunctionDef, ModuleCompiler from .server import Server from .specs import ClientSpecs from .status import EncryptionStatus -from .utils import get_terminal_size, inputset +from .utils import inputset from .value import Value from .wiring import AllComposable, AllInputs, AllOutputs, Input, NotComposable, Output, Wire, Wired diff --git a/frontends/concrete-python/concrete/fhe/compilation/artifacts.py b/frontends/concrete-python/concrete/fhe/compilation/artifacts.py index e38f571cad..8fdeed185f 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/artifacts.py +++ b/frontends/concrete-python/concrete/fhe/compilation/artifacts.py @@ -339,7 +339,7 @@ def client_parameters(self) -> Optional[bytes]: """ return ( - self._execution_runtime.val.client.specs.client_parameters.serialize() + self._execution_runtime.val.client.specs.program_info.serialize() if self._execution_runtime is not None else None ) diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index 34fda45c3b..a1711be835 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -41,6 +41,9 @@ def _function(self) -> FheFunction: @property def function_name(self) -> str: + """ + Return the name of the circuit. + """ return self._name def __str__(self): @@ -112,10 +115,11 @@ def simulate(self, *args: Any) -> Any: Any: result of the simulation """ + return self._function.simulate(*args) @property - def keys(self) -> Keys: + def keys(self) -> Optional[Keys]: """ Get the keys of the circuit. """ @@ -271,14 +275,14 @@ def size_of_outputs(self) -> int: return self._function.size_of_outputs # pragma: no cover @property - def p_error(self) -> int: + def p_error(self) -> float: """ Get probability of error for each simple TLU (on a scalar). """ return self._module.p_error # pragma: no cover @property - def global_p_error(self) -> int: + def global_p_error(self) -> float: """ Get the probability of having at least one simple TLU error during the entire execution. """ @@ -292,7 +296,7 @@ def complexity(self) -> float: return self._module.complexity # pragma: no cover @property - def memory_usage_per_location(self) -> Dict[str, int]: + def memory_usage_per_location(self) -> Dict[str, Optional[int]]: """ Get the memory usage of operations in the circuit per location. """ diff --git a/frontends/concrete-python/concrete/fhe/compilation/client.py b/frontends/concrete-python/concrete/fhe/compilation/client.py index e9dc45ff34..02ae08c288 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/client.py +++ b/frontends/concrete-python/concrete/fhe/compilation/client.py @@ -10,8 +10,10 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np -from concrete.compiler import EvaluationKeys, LweSecretKey, ValueDecrypter, ValueExporter +from concrete.compiler import ClientProgram, LweSecretKey +from concrete.compiler import Value as Value_ +from .evaluation_keys import EvaluationKeys from .keys import Keys from .specs import ClientSpecs from .utils import validate_input_args @@ -25,16 +27,19 @@ class Client: Client class, which can be used to manage keys, encrypt arguments and decrypt results. """ - specs: ClientSpecs - _keys: Keys + _client_specs: ClientSpecs + _keys: Optional[Keys] def __init__( self, client_specs: ClientSpecs, keyset_cache_directory: Optional[Union[str, Path]] = None, + is_simulated: bool = False, ): - self.specs = client_specs - self._keys = Keys(client_specs, keyset_cache_directory) + self._client_specs = client_specs + self._keys = None + if not is_simulated: + self._keys = Keys(client_specs, keyset_cache_directory) def save(self, path: Union[str, Path]): """ @@ -47,7 +52,7 @@ def save(self, path: Union[str, Path]): with tempfile.TemporaryDirectory() as tmp_dir: with open(Path(tmp_dir) / "client.specs.json", "wb") as f: - f.write(self.specs.serialize()) + f.write(self._client_specs.serialize()) path = str(path) if path.endswith(".zip"): @@ -59,6 +64,7 @@ def save(self, path: Union[str, Path]): def load( path: Union[str, Path], keyset_cache_directory: Optional[Union[str, Path]] = None, + is_simulated: bool = False, ) -> "Client": """ Load the client from the given path in zip format. @@ -70,6 +76,9 @@ def load( keyset_cache_directory (Optional[Union[str, Path]], default = None): keyset cache directory to use + is_simulated (bool, default = False): + should perform + Returns: Client: client loaded from the filesystem @@ -80,10 +89,24 @@ def load( with open(Path(tmp_dir) / "client.specs.json", "rb") as f: client_specs = ClientSpecs.deserialize(f.read()) - return Client(client_specs, keyset_cache_directory) + return Client(client_specs, keyset_cache_directory, is_simulated) + + @property + def specs(self) -> ClientSpecs: + """ + Get the client specs for the client. + """ + return self._client_specs + + @specs.setter + def specs(self, new_spec: ClientSpecs): + """ + Get the spec for the client. + """ + self._client_specs = new_spec @property - def keys(self) -> Keys: + def keys(self) -> Optional[Keys]: """ Get the keys for the client. """ @@ -94,14 +117,14 @@ def keys(self, new_keys: Keys): """ Set the keys for the client. """ - # TODO: implement verification for compatibility with keyset. - + assert self._keys is not None, "Tried to set keys on simulated client." + assert new_keys.are_generated, "Keyset is not generated." self._keys = new_keys def keygen( self, force: bool = False, - seed: Optional[int] = None, + secret_seed: Optional[int] = None, encryption_seed: Optional[int] = None, initial_keys: Optional[Dict[int, LweSecretKey]] = None, ): @@ -112,7 +135,7 @@ def keygen( force (bool, default = False): whether to generate new keys even if keys are already generated - seed (Optional[int], default = None): + secret_seed (Optional[int], default = None): seed for private keys randomness encryption_seed (Optional[int], default = None): @@ -122,9 +145,10 @@ def keygen( initial keys to set before keygen """ - self.keys.generate( + assert self._keys is not None, "Tried to generate keys on simulated client." + self._keys.generate( force=force, - seed=seed, + secret_seed=secret_seed, encryption_seed=encryption_seed, initial_keys=initial_keys, ) @@ -148,8 +172,12 @@ def encrypt( encrypted argument(s) for evaluation """ + assert self._keys is not None, "Tried to encrypt on a simulated client." + if not self._keys.are_generated: + self._keys.generate() + if function_name is None: - functions = self.specs.client_parameters.function_list() + functions = self.specs.program_info.function_list() if len(functions) == 1: function_name = functions[0] else: # pragma: no cover @@ -157,21 +185,62 @@ def encrypt( Provide a `function_name` keyword argument to disambiguate." raise TypeError(msg) - ordered_sanitized_args = validate_input_args(self.specs, *args, function_name=function_name) + ordered_sanitized_args = validate_input_args( + self._client_specs, *args, function_name=function_name + ) + client_program = ClientProgram.create_encrypted( + self._client_specs.program_info, self._keys._keyset # pylint: disable=protected-access + ) + client_circuit = client_program.get_client_circuit(function_name) - self.keygen(force=False) - keyset = self.keys._keyset # pylint: disable=protected-access + exported = [ + (None if arg is None else Value(client_circuit.prepare_input(Value_(arg), position))) + for position, arg in enumerate(ordered_sanitized_args) + ] + + return tuple(exported) if len(exported) != 1 else exported[0] + + def simulate_encrypt( + self, + *args: Optional[Union[int, np.ndarray, List]], + function_name: Optional[str] = None, + ) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]: + """ + Simulate encryption of argument(s) for evaluation. + + Args: + *args (Optional[Union[int, np.ndarray, List]]): + argument(s) for evaluation + function_name (str): + name of the function to encrypt + + Returns: + Optional[Union[Value, Tuple[Optional[Value], ...]]]: + encrypted argument(s) for evaluation + """ + + assert self._keys is None, "Tried to simulate encryption on an encrypted client." + + if function_name is None: # pragma: no cover + functions = self.specs.program_info.function_list() + if len(functions) == 1: + function_name = functions[0] + else: + msg = "The client contains more than one functions. \ +Provide a `function_name` keyword argument to disambiguate." + raise TypeError(msg) + + ordered_sanitized_args = validate_input_args( + self._client_specs, *args, function_name=function_name + ) + client_program = ClientProgram.create_simulated(self._client_specs.program_info) + client_circuit = client_program.get_client_circuit(function_name) - exporter = ValueExporter.new(keyset, self.specs.client_parameters, function_name) exported = [ ( None if arg is None - else Value( - exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape)) - if isinstance(arg, np.ndarray) and arg.shape != () - else exporter.export_scalar(position, int(arg)) - ) + else Value(client_circuit.simulate_prepare_input(Value_(arg), position)) ) for position, arg in enumerate(ordered_sanitized_args) ] @@ -197,8 +266,8 @@ def decrypt( decrypted result(s) of evaluation """ - if function_name is None: - functions = self.specs.client_parameters.function_list() + if function_name is None: # pragma: no cover + functions = self.specs.program_info.function_list() if len(functions) == 1: function_name = functions[0] else: # pragma: no cover @@ -214,12 +283,68 @@ def decrypt( else: flattened_results.append(result) - self.keygen(force=False) - keyset = self.keys._keyset # pylint: disable=protected-access + assert self._keys is not None, "Tried to decrypt on a simulated client." + assert self._keys.are_generated + + client_program = ClientProgram.create_encrypted( + self._client_specs.program_info, self._keys._keyset # pylint: disable=protected-access + ) + client_circuit = client_program.get_client_circuit(function_name) + + decrypted = tuple( + client_circuit.process_output( + result._inner, position # pylint: disable=protected-access + ).to_py_val() + for position, result in enumerate(flattened_results) + ) + + return decrypted if len(decrypted) != 1 else decrypted[0] + + def simulate_decrypt( + self, + *results: Union[Value, Tuple[Value, ...]], + function_name: Optional[str] = None, + ) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: + """ + Simulate decryption of result(s) of evaluation. + + Args: + *results (Union[Value, Tuple[Value, ...]]): + result(s) of evaluation + function_name (str): + name of the function to decrypt for + + Returns: + Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: + decrypted result(s) of evaluation + """ + + if function_name is None: # pragma: no cover + functions = self.specs.program_info.function_list() + if len(functions) == 1: + function_name = functions[0] + else: # pragma: no cover + msg = "The client contains more than one functions. \ +Provide a `function_name` keyword argument to disambiguate." + raise TypeError(msg) + + flattened_results: List[Value] = [] + for result in results: + if isinstance(result, tuple): # pragma: no cover + # this branch is impossible to cover without multiple outputs + flattened_results.extend(result) + else: + flattened_results.append(result) + + assert self._keys is None, "Tried to simulate decryption on an encrypted client." + + client_program = ClientProgram.create_simulated(self._client_specs.program_info) + client_circuit = client_program.get_client_circuit(function_name) - decrypter = ValueDecrypter.new(keyset, self.specs.client_parameters, function_name) decrypted = tuple( - decrypter.decrypt(position, result.inner) + client_circuit.simulate_process_output( + result._inner, position # pylint: disable=protected-access + ).to_py_val() for position, result in enumerate(flattened_results) ) @@ -235,5 +360,6 @@ def evaluation_keys(self) -> EvaluationKeys: evaluation keys for encrypted computation """ + assert self._keys is not None, "Tried to get evaluation keys from simulated client." self.keygen(force=False) - return self.keys.evaluation + return self._keys.evaluation diff --git a/frontends/concrete-python/concrete/fhe/compilation/compiler.py b/frontends/concrete-python/concrete/fhe/compilation/compiler.py index 6ad80365f4..196b087b6a 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/compiler.py +++ b/frontends/concrete-python/concrete/fhe/compilation/compiler.py @@ -65,7 +65,7 @@ def assemble( compiler = Compiler( function, { - name: "encrypted" if value.is_encrypted else "clear" + name: EncryptionStatus.ENCRYPTED if value.is_encrypted else EncryptionStatus.CLEAR for name, value in parameter_values.items() }, composition=( diff --git a/frontends/concrete-python/concrete/fhe/compilation/evaluation_keys.py b/frontends/concrete-python/concrete/fhe/compilation/evaluation_keys.py new file mode 100644 index 0000000000..d7cff08a92 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/compilation/evaluation_keys.py @@ -0,0 +1,28 @@ +""" +Declaration of `EvaluationKeys`. +""" + +# pylint: disable=import-error,no-member,no-name-in-module +from concrete.compiler import ServerKeyset +from typing_extensions import NamedTuple + + +class EvaluationKeys(NamedTuple): + """ + EvaluationKeys required for execution. + """ + + server_keyset: ServerKeyset + + def serialize(self) -> bytes: + """ + Serialize the evaluation keys. + """ + return self.server_keyset.serialize() + + @staticmethod + def deserialize(buffer: bytes) -> "EvaluationKeys": + """ + Deserialize evaluation keys from bytes. + """ + return EvaluationKeys(ServerKeyset.deserialize(buffer)) diff --git a/frontends/concrete-python/concrete/fhe/compilation/keys.py b/frontends/concrete-python/concrete/fhe/compilation/keys.py index 6e9613bce6..72942cb4c1 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/keys.py +++ b/frontends/concrete-python/concrete/fhe/compilation/keys.py @@ -8,8 +8,9 @@ from pathlib import Path from typing import Dict, Optional, Union -from concrete.compiler import ClientSupport, EvaluationKeys, KeySet, KeySetCache, LweSecretKey +from concrete.compiler import Keyset, KeysetCache, LweSecretKey +from .evaluation_keys import EvaluationKeys from .specs import ClientSpecs # pylint: enable=import-error,no-name-in-module @@ -23,26 +24,19 @@ class Keys: Be careful when serializing/saving keys! """ - client_specs: Optional[ClientSpecs] - cache_directory: Optional[Union[str, Path]] - - _keyset_cache: Optional[KeySetCache] - _keyset: Optional[KeySet] + _cache: Optional[KeysetCache] + _specs: Optional[ClientSpecs] + _keyset: Optional[Keyset] def __init__( self, - client_specs: Optional[ClientSpecs], + specs: Optional[ClientSpecs], cache_directory: Optional[Union[str, Path]] = None, ): - self.client_specs = client_specs - self.cache_directory = cache_directory - - self._keyset_cache = None + self._cache = KeysetCache(str(cache_directory)) if cache_directory is not None else None + self._specs = specs self._keyset = None - if cache_directory is not None: - self._keyset_cache = KeySetCache.new(str(cache_directory)) - @property def are_generated(self) -> bool: """ @@ -54,7 +48,7 @@ def are_generated(self) -> bool: def generate( self, force: bool = False, - seed: Optional[int] = None, + secret_seed: Optional[int] = None, encryption_seed: Optional[int] = None, initial_keys: Optional[Dict[int, LweSecretKey]] = None, ): @@ -65,7 +59,7 @@ def generate( force (bool, default = False): whether to generate new keys even if keys are already generated/loaded - seed (Optional[int], default = None): + secret_seed (Optional[int], default = None): seed for private keys randomness encryption_seed (Optional[int], default = None): @@ -76,14 +70,30 @@ def generate( """ if self._keyset is None or force: - if self.client_specs is None: # pragma: no cover + if self._specs is None: # pragma: no cover message = "Tried to generate Keys without client specs." raise ValueError(message) - self._keyset = ClientSupport.key_set( - self.client_specs.client_parameters, - self._keyset_cache, - seed, - encryption_seed, + + secret_seed = 0 if secret_seed is None else secret_seed + encryption_seed = 0 if encryption_seed is None else encryption_seed + if secret_seed < 0 or secret_seed >= 2**128: + message = "secret_seed must be a positive 128 bits integer" + raise ValueError(message) + if encryption_seed < 0 or encryption_seed >= 2**128: + message = "encryption_seed must be a positive 128 bits integer" + raise ValueError(message) + secret_seed_msb = (secret_seed >> 64) & 0xFFFFFFFFFFFFFFFF + secret_seed_lsb = (secret_seed) & 0xFFFFFFFFFFFFFFFF + encryption_seed_msb = (encryption_seed >> 64) & 0xFFFFFFFFFFFFFFFF + encryption_seed_lsb = (encryption_seed) & 0xFFFFFFFFFFFFFFFF + + self._keyset = Keyset( + self._specs.program_info, + self._cache, + secret_seed_msb, + secret_seed_lsb, + encryption_seed_msb, + encryption_seed_lsb, initial_keys, ) @@ -125,18 +135,16 @@ def load(self, location: Union[str, Path]): keys = Keys.deserialize(bytes(location.read_bytes())) - self.client_specs = None - self.cache_directory = None - # pylint: disable=protected-access - self._keyset_cache = None + self._specs = None + self._cache = None self._keyset = keys._keyset # pylint: enable=protected-access def load_if_exists_generate_and_save_otherwise( self, location: Union[str, Path], - seed: Optional[int] = None, + secret_seed: Optional[int] = None, ): """ Load keys from a location if they exist, else generate new keys and save to that location. @@ -145,7 +153,7 @@ def load_if_exists_generate_and_save_otherwise( location (Union[str, Path]): location to load from or save to - seed (Optional[int], default = None): + secret_seed (Optional[int], default = None): seed for randomness in case keys need to be generated """ @@ -155,7 +163,7 @@ def load_if_exists_generate_and_save_otherwise( if location.exists(): self.load(location) else: - self.generate(seed=seed) + self.generate(secret_seed=secret_seed) self.save(location) def serialize(self) -> bytes: @@ -190,7 +198,7 @@ def deserialize(serialized_keys: bytes) -> "Keys": deserialized keys """ - keyset = KeySet.deserialize(serialized_keys) + keyset = Keyset.deserialize(serialized_keys) # pylint: disable=protected-access result = Keys(None) @@ -199,6 +207,13 @@ def deserialize(serialized_keys: bytes) -> "Keys": return result + @property + def specs(self) -> Optional[ClientSpecs]: + """ + Return the associated client specs if any. + """ + return self._specs # pragma: no cover + @property def evaluation(self) -> EvaluationKeys: """ @@ -207,5 +222,4 @@ def evaluation(self) -> EvaluationKeys: self.generate(force=False) assert self._keyset is not None - - return self._keyset.get_evaluation_keys() + return EvaluationKeys(self._keyset.get_server_keys()) diff --git a/frontends/concrete-python/concrete/fhe/compilation/module.py b/frontends/concrete-python/concrete/fhe/compilation/module.py index 8b8de2bfb0..279c26dff6 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/module.py +++ b/frontends/concrete-python/concrete/fhe/compilation/module.py @@ -8,13 +8,7 @@ from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union import numpy as np -from concrete.compiler import ( - CompilationContext, - LweSecretKey, - Parameter, - SimulatedValueDecrypter, - SimulatedValueExporter, -) +from concrete.compiler import CompilationContext, LweSecretKey, Parameter from mlir.ir import Module as MlirModule from ..internal.utils import assert_that @@ -24,7 +18,7 @@ from .configuration import Configuration from .keys import Keys from .server import Server -from .utils import Lazy, validate_input_args +from .utils import Lazy from .value import Value # pylint: enable=import-error,no-member,no-name-in-module @@ -44,6 +38,7 @@ class SimulationRt(NamedTuple): Runtime object class for simulation. """ + client: Client server: Server @@ -124,6 +119,29 @@ def __str__(self): def __repr__(self) -> str: return f"FheFunction(name={self.name})" + def _simulate_encrypt( + self, + *args: Optional[Union[int, np.ndarray, List]], + ) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]: + + return self.simulation_runtime.val.client.simulate_encrypt(*args, function_name=self.name) + + def _simulate_run( + self, + *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], + ) -> Union[Value, Tuple[Value, ...]]: + + return self.simulation_runtime.val.server.run(*args, function_name=self.name) + + def _simulate_decrypt( + self, + *results: Union[Value, Tuple[Value, ...]], + ) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: + + return self.simulation_runtime.val.client.simulate_decrypt( + *results, function_name=self.name + ) + def simulate(self, *args: Any) -> Any: """ Simulate execution of the function. @@ -137,42 +155,11 @@ def simulate(self, *args: Any) -> Any: result of the simulation """ - ordered_validated_args = validate_input_args( - self.simulation_runtime.val.server.client_specs, - *args, - function_name=self.name, - ) - - exporter = SimulatedValueExporter.new( - self.simulation_runtime.val.server.client_specs.client_parameters, self.name - ) - exported = [ - ( - None - if arg is None - else Value( - exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape)) - if isinstance(arg, np.ndarray) and arg.shape != () - else exporter.export_scalar(position, int(arg)) - ) - ) - for position, arg in enumerate(ordered_validated_args) - ] - - results = self.simulation_runtime.val.server.run(*exported, function_name=self.name) - if not isinstance(results, tuple): - results = (results,) - - decrypter = SimulatedValueDecrypter.new( - self.simulation_runtime.val.server.client_specs.client_parameters, self.name - ) - decrypted = tuple( - decrypter.decrypt(position, result.inner) for position, result in enumerate(results) - ) - return decrypted if len(decrypted) != 1 else decrypted[0] + return self._simulate_decrypt(self._simulate_run(self._simulate_encrypt(*args))) def encrypt( - self, *args: Optional[Union[int, np.ndarray, List]] + self, + *args: Optional[Union[int, np.ndarray, List]], ) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]: """ Encrypt argument(s) to for evaluation. @@ -187,8 +174,7 @@ def encrypt( """ if self.configuration.simulate_encrypt_run_decrypt: - return args if len(args) != 1 else args[0] # type: ignore - + return tuple(args) if len(args) > 1 else args[0] # type: ignore return self.execution_runtime.val.client.encrypt(*args, function_name=self.name) def run( @@ -208,8 +194,7 @@ def run( """ if self.configuration.simulate_encrypt_run_decrypt: - return self.simulate(*args) - + return self._simulate_decrypt(self._simulate_run(*args)) # type: ignore return self.execution_runtime.val.server.run( *args, evaluation_keys=self.execution_runtime.val.client.evaluation_keys, @@ -233,7 +218,7 @@ def decrypt( """ if self.configuration.simulate_encrypt_run_decrypt: - return results if len(results) != 1 else results[0] # type: ignore + return tuple(results) if len(results) > 1 else results[0] # type: ignore return self.execution_runtime.val.client.decrypt(*results, function_name=self.name) @@ -612,7 +597,8 @@ def init_simulation(): is_simulated=True, compilation_context=self.compilation_context, ) - return SimulationRt(simulation_server) + simulation_client = Client(simulation_server.client_specs, is_simulated=True) + return SimulationRt(simulation_client, simulation_server) self.simulation_runtime = Lazy(init_simulation) if configuration.fhe_simulation: @@ -624,13 +610,16 @@ def init_execution(): self.configuration.fork(fhe_simulation=False), compilation_context=self.compilation_context, composition_rules=composition_rules, + is_simulated=False, ) keyset_cache_directory = None if self.configuration.use_insecure_key_cache: assert_that(self.configuration.enable_unsafe_features) assert_that(self.configuration.insecure_key_cache_location is not None) keyset_cache_directory = self.configuration.insecure_key_cache_location - execution_client = Client(execution_server.client_specs, keyset_cache_directory) + execution_client = Client( + execution_server.client_specs, keyset_cache_directory, is_simulated=False + ) return ExecutionRt(execution_client, execution_server) self.execution_runtime = Lazy(init_execution) @@ -647,7 +636,7 @@ def mlir(self) -> str: return str(self.mlir_module).strip() @property - def keys(self) -> Keys: + def keys(self) -> Optional[Keys]: """ Get the keys of the module. """ @@ -717,14 +706,14 @@ def size_of_keyswitch_keys(self) -> int: return self.execution_runtime.val.server.size_of_keyswitch_keys # pragma: no cover @property - def p_error(self) -> int: + def p_error(self) -> float: """ Get probability of error for each simple TLU (on a scalar). """ return self.execution_runtime.val.server.p_error # pragma: no cover @property - def global_p_error(self) -> int: + def global_p_error(self) -> float: """ Get the probability of having at least one simple TLU error during the entire execution. """ @@ -799,7 +788,7 @@ def function_count(self) -> int: """ return len(self.graphs) - def __getattr__(self, item): + def __getattr__(self, item) -> FheFunction: if item not in list(self.graphs.keys()): error = f"No attribute {item}" raise AttributeError(error) diff --git a/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py b/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py index aca0f005ec..fe814572d9 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py +++ b/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py @@ -343,7 +343,7 @@ def __init__(self, functions: List[FunctionDef], composition: CompositionPolicy) parameter_selection_strategy="multi", ) self.functions = {function.name: function for function in functions} - self.compilation_context = CompilationContext.new() + self.compilation_context = CompilationContext() self.composition = composition def wire_pipeline(self, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): @@ -467,7 +467,7 @@ def compile( # pylint: enable=too-many-branches,too-many-statements - def __getattr__(self, item): + def __getattr__(self, item) -> FunctionDef: if item not in list(self.functions.keys()): error = f"No attribute {item}" raise AttributeError(error) diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index c22750db38..81430f93e6 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -8,36 +8,32 @@ import shutil import tempfile from pathlib import Path -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union # mypy: disable-error-code=attr-defined import concrete.compiler import jsonpickle import numpy as np from concrete.compiler import ( + Backend, + ClientProgram, CompilationContext, CompilationOptions, - EvaluationKeys, - LibraryCompilationResult, - LibrarySupport, - Parameter, - ProgramCompilationFeedback, - PublicArguments, - ServerProgram, - SimulatedValueExporter, - set_compiler_logging, - set_llvm_debug_flag, -) -from mlir._mlir_libs._concretelang._compiler import ( - Backend, + Compiler, KeyType, + Library, + MoreCircuitCompilationFeedback, OptimizerMultiParameterStrategy, OptimizerStrategy, + Parameter, PrimitiveOperation, + ProgramInfo, + ServerProgram, ) +from concrete.compiler import Value as Value_ +from concrete.compiler import lookup_runtime_lib, set_compiler_logging, set_llvm_debug_flag from mlir.ir import Module as MlirModule -from ..internal.utils import assert_that from .composition import CompositionClause, CompositionRule from .configuration import ( DEFAULT_GLOBAL_P_ERROR, @@ -46,8 +42,9 @@ MultiParameterStrategy, ParameterSelectionStrategy, ) +from .evaluation_keys import EvaluationKeys from .specs import ClientSpecs -from .utils import friendly_type_format +from .utils import Lazy, friendly_type_format from .value import Value # pylint: enable=import-error,no-member,no-name-in-module @@ -58,64 +55,29 @@ class Server: Server class, which can be used to perform homomorphic computation. """ - client_specs: ClientSpecs is_simulated: bool - - _output_dir: Union[None, str, Path] - _support: LibrarySupport - _compilation_result: LibraryCompilationResult - _compilation_feedback: ProgramCompilationFeedback - _server_program: ServerProgram - + _library: Library _mlir: Optional[str] _configuration: Optional[Configuration] _composition_rules: Optional[List[CompositionRule]] - _clear_input_indices: Dict[str, Set[int]] - _clear_input_shapes: Dict[str, Dict[int, Tuple[int, ...]]] - def __init__( self, - client_specs: ClientSpecs, - output_dir: Union[None, str, Path], - support: LibrarySupport, - compilation_result: LibraryCompilationResult, - server_program: ServerProgram, + library: Library, is_simulated: bool, composition_rules: Optional[List[CompositionRule]], ): - self.client_specs = client_specs self.is_simulated = is_simulated - - self._output_dir = output_dir - self._support = support - self._compilation_result = compilation_result - self._compilation_feedback = self._support.load_compilation_feedback(compilation_result) - self._server_program = server_program + self._library = library self._mlir = None self._composition_rules = composition_rules - self._clear_input_indices = {} - self._clear_input_shapes = {} - - functions_parameters = json.loads(client_specs.client_parameters.serialize())["circuits"] - for function_parameters in functions_parameters: - name = function_parameters["name"] - self._clear_input_indices[name] = { - index - for index, input_spec in enumerate(function_parameters["inputs"]) - if "plaintext" in input_spec["typeInfo"] - } - self._clear_input_shapes[name] = { - index: tuple(input_spec["rawInfo"]["shape"]["dimensions"]) - for index, input_spec in enumerate(function_parameters["inputs"]) - if "plaintext" in input_spec["typeInfo"] - } - - assert_that( - support.load_client_parameters(compilation_result).serialize() - == client_specs.client_parameters.serialize() - ) + @property + def client_specs(self) -> ClientSpecs: + """ + Return the associated client specs. + """ + return ClientSpecs(self._library.get_program_info()) @staticmethod def create( @@ -146,10 +108,8 @@ def create( """ backend = Backend.GPU if configuration.use_gpu else Backend.CPU - options = CompilationOptions.new(backend) - + options = CompilationOptions(backend) options.simulation(is_simulated) - options.set_loop_parallelize(configuration.loop_parallelize) options.set_dataflow_parallelize(configuration.dataflow_parallelize) options.set_auto_parallelize(configuration.auto_parallelize) @@ -158,7 +118,6 @@ def create( options.set_enable_overflow_detection_in_simulation( configuration.detect_overflow_in_simulation ) - options.set_composable(configuration.composable) composition_rules = list(composition_rules) if composition_rules else [] for rule in composition_rules: @@ -231,33 +190,30 @@ def create( output_dir = tempfile.mkdtemp() output_dir_path = Path(output_dir) - support = LibrarySupport.new( - str(output_dir_path), generateCppHeader=False, generateStaticLib=False + compiler = Compiler( + str(output_dir_path), + lookup_runtime_lib(), + generate_shared_lib=True, + generate_program_info=True, + generate_compilation_feedback=True, ) if isinstance(mlir, str): - compilation_result = support.compile(mlir, options) + library = compiler.compile(mlir, options) else: # MlirModule assert ( compilation_context is not None ), "must provide compilation context when compiling MlirModule" - compilation_result = support.compile(mlir, options, compilation_context) - server_program = ServerProgram.load(support, is_simulated) + library = compiler.compile( + mlir._CAPIPtr, options, compilation_context # pylint: disable=protected-access + ) finally: set_llvm_debug_flag(False) set_compiler_logging(False) - client_parameters = support.load_client_parameters(compilation_result) - client_specs = ClientSpecs(client_parameters) composition_rules = composition_rules if composition_rules else None result = Server( - client_specs, - output_dir, - support, - compilation_result, - server_program, - is_simulated, - composition_rules, + library=library, is_simulated=is_simulated, composition_rules=composition_rules ) # pylint: disable=protected-access @@ -306,20 +262,24 @@ def save(self, path: Union[str, Path], via_mlir: bool = False): return - if self._output_dir is None: # pragma: no cover - message = "Output directory must be provided" - raise RuntimeError(message) - - with open(Path(self._output_dir) / "client.specs.json", "wb") as f: + # Note that the shared library, program info and more are already in the output directory. + # We just add a few things related to concrete-python here. + with open(Path(self._library.get_output_dir_path()) / "client.specs.json", "wb") as f: f.write(self.client_specs.serialize()) - with open(Path(self._output_dir) / "is_simulated", "w", encoding="utf-8") as f: + with open( + Path(self._library.get_output_dir_path()) / "is_simulated", "w", encoding="utf-8" + ) as f: f.write("1" if self.is_simulated else "0") - with open(Path(self._output_dir) / "composition_rules.json", "w", encoding="utf-8") as f: + with open( + Path(self._library.get_output_dir_path()) / "composition_rules.json", + "w", + encoding="utf-8", + ) as f: f.write(json.dumps(self._composition_rules)) - shutil.make_archive(path, "zip", self._output_dir) + shutil.make_archive(path, "zip", self._library.get_output_dir_path()) @staticmethod def load(path: Union[str, Path], **kwargs) -> "Server": @@ -376,23 +336,10 @@ def load(path: Union[str, Path], **kwargs) -> "Server": mlir, configuration, is_simulated, composition_rules=composition_rules ) - with open(output_dir_path / "client.specs.json", "rb") as f: - client_specs = ClientSpecs.deserialize(f.read()) - - support = LibrarySupport.new( - str(output_dir_path), - generateCppHeader=False, - generateStaticLib=False, - ) - compilation_result = support.reload() - server_program = ServerProgram.load(support, is_simulated) + library = Library(str(output_dir_path)) return Server( - client_specs, - output_dir, - support, - compilation_result, - server_program, + library, is_simulated, composition_rules, ) @@ -422,11 +369,11 @@ def run( """ if function_name is None: - functions = self.client_specs.client_parameters.function_list() - if len(functions) == 1: - function_name = functions[0] + circuits = self.program_info.get_circuits() + if len(circuits) == 1: + function_name = circuits[0].get_name() else: # pragma: no cover - msg = "The client contains more than one functions. \ + msg = "The server contains more than one functions. \ Provide a `function_name` keyword argument to disambiguate." raise TypeError(msg) @@ -441,118 +388,142 @@ def run( else: flattened_args.append(arg) - buffers = [] - for i, arg in enumerate(flattened_args): - if arg is None: - message = f"Expected argument {i} to be an fhe.Value but it's None" - raise ValueError(message) - - if not isinstance(arg, Value): - if i not in self._clear_input_indices[function_name]: - message = ( - f"Expected argument {i} to be an fhe.Value " - f"but it's {friendly_type_format(type(arg))}" - ) + if not self.is_simulated: + for i, arg in enumerate(flattened_args): + if arg is None: + message = f"Expected argument {i} to be an fhe.Value but it's None" raise ValueError(message) - # Simulated value exporter can be used here - # as "clear" fhe.Values have the same - # internal representation as "simulation" fhe.Values + if not isinstance(arg, Value): + if ( + not self.client_specs.program_info.get_circuit(function_name) + .get_inputs()[i] + .get_type_info() + .is_plaintext() + ): + message = ( + f"Expected argument {i} to be an fhe.Value " + f"but it's {friendly_type_format(type(arg))}" + ) + raise ValueError(message) - exporter = SimulatedValueExporter.new( - self.client_specs.client_parameters, - function_name, - ) + server_program = ServerProgram(self._library, self.is_simulated) + server_circuit = server_program.get_server_circuit(function_name) - if isinstance(arg, (int, np.integer)): - arg = exporter.export_scalar(i, arg) - else: - arg = np.array(arg) - arg = exporter.export_tensor(i, arg.flatten().tolist(), arg.shape) + def init_simulated_client_circuit(): + client_program = ClientProgram.create_simulated(self.client_specs.program_info) + return client_program.get_client_circuit(function_name) + simulated_client_circuit = Lazy(init_simulated_client_circuit) + + unwrapped_args = [] + for i, arg in enumerate(flattened_args): if isinstance(arg, Value): - buffers.append(arg.inner) + unwrapped_args.append(arg._inner) # pylint: disable=protected-access + elif isinstance(arg, list): + unwrapped_args.append( + simulated_client_circuit.val.simulate_prepare_input(Value_(np.array(arg)), i) + ) else: - buffers.append(arg) - - public_args = PublicArguments.new(self.client_specs.client_parameters, buffers) - server_circuit = self._server_program.get_server_circuit(function_name) + unwrapped_args.append( + simulated_client_circuit.val.simulate_prepare_input(Value_(arg), i) + ) if self.is_simulated: - public_result = server_circuit.simulate(public_args) + result = server_circuit.simulate(unwrapped_args) else: - public_result = server_circuit.call(public_args, evaluation_keys) + assert evaluation_keys is not None + result = server_circuit.call(unwrapped_args, evaluation_keys.server_keyset) - result = tuple(Value(public_result.get_value(i)) for i in range(public_result.n_values())) - return result if len(result) > 1 else result[0] + result = [Value(r) for r in result] + return tuple(result) if len(result) > 1 else result[0] def cleanup(self): """ Cleanup the temporary library output directory. """ - if self._output_dir is not None: - shutil.rmtree(Path(self._output_dir).resolve()) + # if self._output_dir is not None: + # shutil.rmtree(Path(self._output_dir).resolve()) + + @property + def program_info(self) -> ProgramInfo: + """ + The program info associated with the server. + """ + return self._library.get_program_info() @property def size_of_secret_keys(self) -> int: """ Get size of the secret keys of the compiled program. """ - return self._compilation_feedback.total_secret_keys_size + return self._library.get_program_compilation_feedback().total_secret_keys_size @property def size_of_bootstrap_keys(self) -> int: """ Get size of the bootstrap keys of the compiled program. """ - return self._compilation_feedback.total_bootstrap_keys_size + return self._library.get_program_compilation_feedback().total_bootstrap_keys_size @property def size_of_keyswitch_keys(self) -> int: """ Get size of the key switch keys of the compiled program. """ - return self._compilation_feedback.total_keyswitch_keys_size + return self._library.get_program_compilation_feedback().total_keyswitch_keys_size @property - def p_error(self) -> int: + def p_error(self) -> float: """ Get the probability of error for each simple TLU (on a scalar). """ - return self._compilation_feedback.p_error + return self._library.get_program_compilation_feedback().p_error @property - def global_p_error(self) -> int: + def global_p_error(self) -> float: """ Get the probability of having at least one simple TLU error during the entire execution. """ - return self._compilation_feedback.global_p_error + return self._library.get_program_compilation_feedback().global_p_error @property def complexity(self) -> float: """ Get complexity of the compiled program. """ - return self._compilation_feedback.complexity + return self._library.get_program_compilation_feedback().complexity - def memory_usage_per_location(self, function: str) -> Dict[str, int]: + def memory_usage_per_location(self, function: str) -> Dict[str, Optional[int]]: """ Get the memory usage of operations per location. """ - return self._compilation_feedback.circuit(function).memory_usage_per_location + return ( + self._library.get_program_compilation_feedback() + .get_circuit_feedback(function) + .memory_usage_per_location + ) def size_of_inputs(self, function: str) -> int: """ Get size of the inputs of the compiled program. """ - return self._compilation_feedback.circuit(function).total_inputs_size + return ( + self._library.get_program_compilation_feedback() + .get_circuit_feedback(function) + .total_inputs_size + ) def size_of_outputs(self, function: str) -> int: """ Get size of the outputs of the compiled program. """ - return self._compilation_feedback.circuit(function).total_output_size + return ( + self._library.get_program_compilation_feedback() + .get_circuit_feedback(function) + .total_output_size + ) # Programmable Bootstrap Statistics @@ -560,7 +531,8 @@ def programmable_bootstrap_count(self, function: str) -> int: """ Get the number of programmable bootstraps in the compiled program. """ - return self._compilation_feedback.circuit(function).count( + return MoreCircuitCompilationFeedback.count( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, ) @@ -568,17 +540,19 @@ def programmable_bootstrap_count_per_parameter(self, function: str) -> Dict[Para """ Get the number of programmable bootstraps per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_parameter( + return MoreCircuitCompilationFeedback.count_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, key_types={KeyType.BOOTSTRAP}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) def programmable_bootstrap_count_per_tag(self, function: str) -> Dict[str, int]: """ Get the number of programmable bootstraps per tag in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag( + return MoreCircuitCompilationFeedback.count_per_tag( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, ) @@ -588,10 +562,11 @@ def programmable_bootstrap_count_per_tag_per_parameter( """ Get the number of programmable bootstraps per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( + return MoreCircuitCompilationFeedback.count_per_tag_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, key_types={KeyType.BOOTSTRAP}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) # Key Switch Statistics @@ -600,7 +575,8 @@ def key_switch_count(self, function: str) -> int: """ Get the number of key switches in the compiled program. """ - return self._compilation_feedback.circuit(function).count( + return MoreCircuitCompilationFeedback.count( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, ) @@ -608,17 +584,19 @@ def key_switch_count_per_parameter(self, function: str) -> Dict[Parameter, int]: """ Get the number of key switches per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_parameter( + return MoreCircuitCompilationFeedback.count_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, key_types={KeyType.KEY_SWITCH}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) def key_switch_count_per_tag(self, function: str) -> Dict[str, int]: """ Get the number of key switches per tag in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag( + return MoreCircuitCompilationFeedback.count_per_tag( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, ) @@ -628,10 +606,11 @@ def key_switch_count_per_tag_per_parameter( """ Get the number of key switches per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( + return MoreCircuitCompilationFeedback.count_per_tag_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, key_types={KeyType.KEY_SWITCH}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) # Packing Key Switch Statistics @@ -640,26 +619,29 @@ def packing_key_switch_count(self, function: str) -> int: """ Get the number of packing key switches in the compiled program. """ - return self._compilation_feedback.circuit(function).count( - operations={PrimitiveOperation.WOP_PBS} + return MoreCircuitCompilationFeedback.count( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), + operations={PrimitiveOperation.WOP_PBS}, ) def packing_key_switch_count_per_parameter(self, function: str) -> Dict[Parameter, int]: """ Get the number of packing key switches per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_parameter( + return MoreCircuitCompilationFeedback.count_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.WOP_PBS}, key_types={KeyType.PACKING_KEY_SWITCH}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) def packing_key_switch_count_per_tag(self, function: str) -> Dict[str, int]: """ Get the number of packing key switches per tag in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag( - operations={PrimitiveOperation.WOP_PBS} + return MoreCircuitCompilationFeedback.count_per_tag( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), + operations={PrimitiveOperation.WOP_PBS}, ) def packing_key_switch_count_per_tag_per_parameter( @@ -668,10 +650,11 @@ def packing_key_switch_count_per_tag_per_parameter( """ Get the number of packing key switches per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( + return MoreCircuitCompilationFeedback.count_per_tag_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.WOP_PBS}, key_types={KeyType.PACKING_KEY_SWITCH}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) # Clear Addition Statistics @@ -680,25 +663,28 @@ def clear_addition_count(self, function: str) -> int: """ Get the number of clear additions in the compiled program. """ - return self._compilation_feedback.circuit(function).count( - operations={PrimitiveOperation.CLEAR_ADDITION} + return MoreCircuitCompilationFeedback.count( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), + operations={PrimitiveOperation.CLEAR_ADDITION}, ) def clear_addition_count_per_parameter(self, function: str) -> Dict[Parameter, int]: """ Get the number of clear additions per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_parameter( + return MoreCircuitCompilationFeedback.count_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.CLEAR_ADDITION}, key_types={KeyType.SECRET}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) def clear_addition_count_per_tag(self, function: str) -> Dict[str, int]: """ Get the number of clear additions per tag in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag( + return MoreCircuitCompilationFeedback.count_per_tag( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.CLEAR_ADDITION}, ) @@ -708,10 +694,11 @@ def clear_addition_count_per_tag_per_parameter( """ Get the number of clear additions per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( + return MoreCircuitCompilationFeedback.count_per_tag_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.CLEAR_ADDITION}, key_types={KeyType.SECRET}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) # Encrypted Addition Statistics @@ -720,25 +707,28 @@ def encrypted_addition_count(self, function: str) -> int: """ Get the number of encrypted additions in the compiled program. """ - return self._compilation_feedback.circuit(function).count( - operations={PrimitiveOperation.ENCRYPTED_ADDITION} + return MoreCircuitCompilationFeedback.count( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), + operations={PrimitiveOperation.ENCRYPTED_ADDITION}, ) def encrypted_addition_count_per_parameter(self, function: str) -> Dict[Parameter, int]: """ Get the number of encrypted additions per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_parameter( + return MoreCircuitCompilationFeedback.count_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.ENCRYPTED_ADDITION}, key_types={KeyType.SECRET}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) def encrypted_addition_count_per_tag(self, function: str) -> Dict[str, int]: """ Get the number of encrypted additions per tag in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag( + return MoreCircuitCompilationFeedback.count_per_tag( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.ENCRYPTED_ADDITION}, ) @@ -748,10 +738,11 @@ def encrypted_addition_count_per_tag_per_parameter( """ Get the number of encrypted additions per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( + return MoreCircuitCompilationFeedback.count_per_tag_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.ENCRYPTED_ADDITION}, key_types={KeyType.SECRET}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) # Clear Multiplication Statistics @@ -760,7 +751,8 @@ def clear_multiplication_count(self, function: str) -> int: """ Get the number of clear multiplications in the compiled program. """ - return self._compilation_feedback.circuit(function).count( + return MoreCircuitCompilationFeedback.count( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, ) @@ -768,17 +760,19 @@ def clear_multiplication_count_per_parameter(self, function: str) -> Dict[Parame """ Get the number of clear multiplications per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_parameter( + return MoreCircuitCompilationFeedback.count_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, key_types={KeyType.SECRET}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) def clear_multiplication_count_per_tag(self, function: str) -> Dict[str, int]: """ Get the number of clear multiplications per tag in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag( + return MoreCircuitCompilationFeedback.count_per_tag( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, ) @@ -788,10 +782,11 @@ def clear_multiplication_count_per_tag_per_parameter( """ Get the number of clear multiplications per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( + return MoreCircuitCompilationFeedback.count_per_tag_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, key_types={KeyType.SECRET}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) # Encrypted Negation Statistics @@ -800,25 +795,28 @@ def encrypted_negation_count(self, function: str) -> int: """ Get the number of encrypted negations in the compiled program. """ - return self._compilation_feedback.circuit(function).count( - operations={PrimitiveOperation.ENCRYPTED_NEGATION} + return MoreCircuitCompilationFeedback.count( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), + operations={PrimitiveOperation.ENCRYPTED_NEGATION}, ) def encrypted_negation_count_per_parameter(self, function: str) -> Dict[Parameter, int]: """ Get the number of encrypted negations per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_parameter( + return MoreCircuitCompilationFeedback.count_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.ENCRYPTED_NEGATION}, key_types={KeyType.SECRET}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) def encrypted_negation_count_per_tag(self, function: str) -> Dict[str, int]: """ Get the number of encrypted negations per tag in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag( + return MoreCircuitCompilationFeedback.count_per_tag( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.ENCRYPTED_NEGATION}, ) @@ -828,8 +826,9 @@ def encrypted_negation_count_per_tag_per_parameter( """ Get the number of encrypted negations per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( + return MoreCircuitCompilationFeedback.count_per_tag_per_parameter( + self._library.get_program_compilation_feedback().get_circuit_feedback(function), operations={PrimitiveOperation.ENCRYPTED_NEGATION}, key_types={KeyType.SECRET}, - client_parameters=self.client_specs.client_parameters, + program_info=self.program_info, ) diff --git a/frontends/concrete-python/concrete/fhe/compilation/specs.py b/frontends/concrete-python/concrete/fhe/compilation/specs.py index 4d79132bc2..7033bd1893 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/specs.py +++ b/frontends/concrete-python/concrete/fhe/compilation/specs.py @@ -7,7 +7,7 @@ from typing import Any # mypy: disable-error-code=attr-defined -from concrete.compiler import ClientParameters +from concrete.compiler import ProgramInfo # pylint: enable=import-error,no-member,no-name-in-module @@ -17,13 +17,13 @@ class ClientSpecs: ClientSpecs class, to create Client objects. """ - client_parameters: ClientParameters + program_info: ProgramInfo - def __init__(self, client_parameters: ClientParameters): - self.client_parameters = client_parameters + def __init__(self, program_info: ProgramInfo): + self.program_info = program_info def __eq__(self, other: Any): # pragma: no cover - return self.client_parameters.serialize() == other.client_parameters.serialize() + return self.program_info.serialize() == other.program_info.serialize() def serialize(self) -> bytes: """ @@ -34,7 +34,7 @@ def serialize(self) -> bytes: serialized client specs """ - return self.client_parameters.serialize() + return self.program_info.serialize() @staticmethod def deserialize(serialized_client_specs: bytes) -> "ClientSpecs": @@ -50,5 +50,5 @@ def deserialize(serialized_client_specs: bytes) -> "ClientSpecs": deserialized client specs """ - client_parameters = ClientParameters.deserialize(serialized_client_specs) - return ClientSpecs(client_parameters) + program_info = ProgramInfo.deserialize(serialized_client_specs) + return ClientSpecs(program_info) diff --git a/frontends/concrete-python/concrete/fhe/compilation/utils.py b/frontends/concrete-python/concrete/fhe/compilation/utils.py index 3303cf152c..0c36994589 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/utils.py +++ b/frontends/concrete-python/concrete/fhe/compilation/utils.py @@ -136,7 +136,7 @@ def validate_input_args( List[Optional[Union[int, np.ndarray]]]: ordered validated args """ - functions_parameters = json.loads(client_specs.client_parameters.serialize())["circuits"] + functions_parameters = json.loads(client_specs.program_info.serialize())["circuits"] for function_parameters in functions_parameters: if function_parameters["name"] == function_name: client_parameters_json = function_parameters diff --git a/frontends/concrete-python/concrete/fhe/compilation/value.py b/frontends/concrete-python/concrete/fhe/compilation/value.py index cff393c67f..3d42bf274b 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/value.py +++ b/frontends/concrete-python/concrete/fhe/compilation/value.py @@ -4,44 +4,28 @@ # pylint: disable=import-error,no-name-in-module -from concrete.compiler import Value as NativeValue - -# pylint: enable=import-error,no-name-in-module +from concrete.compiler import TransportValue class Value: """ - Value class, to store scalar or tensor values which can be encrypted or clear. + A public value object that can be sent between client and server. """ - inner: NativeValue + _inner: TransportValue - def __init__(self, inner: NativeValue): - self.inner = inner + def __init__(self, inner: TransportValue): + self._inner = inner - def serialize(self) -> bytes: + @staticmethod + def deserialize(buffer: bytes) -> "Value": """ - Serialize data into bytes. - - Returns: - bytes: - serialized data + Deserialize a Value from bytes. """ + return Value(TransportValue.deserialize(buffer)) - return self.inner.serialize() - - @staticmethod - def deserialize(serialized_data: bytes) -> "Value": + def serialize(self) -> bytes: """ - Deserialize data from bytes. - - Args: - serialized_data (bytes): - previously serialized data - - Returns: - Value: - deserialized data + Serialize a Value to bytes. """ - - return Value(NativeValue.deserialize(serialized_data)) + return self._inner.serialize() diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index c895859282..ff4c97e223 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -18,11 +18,10 @@ from mlir.ir import Location as MlirLocation from mlir.ir import Module as MlirModule -from .. import tfhers from ..compilation.composition import CompositionRule from ..compilation.configuration import Configuration, Exactness, ParameterSelectionStrategy from ..representation import Graph, GraphProcessor, MultiGraphProcessor, Node, Operation -from ..tfhers import TFHERSIntegerType +from ..tfhers.dtypes import TFHERSIntegerType from .context import Context from .conversion import Conversion from .processors import * # pylint: disable=wildcard-import @@ -933,7 +932,7 @@ def zeros(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 1 tfhers_int = preds[0] - dtype: tfhers.TFHERSIntegerType = node.properties["attributes"]["type"] + dtype: TFHERSIntegerType = node.properties["attributes"]["type"] result_bit_width, carry_width, msg_width = ( dtype.bit_width, dtype.carry_width, @@ -965,7 +964,7 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 1 - dtype: tfhers.TFHERSIntegerType = node.properties["attributes"]["type"] + dtype: TFHERSIntegerType = node.properties["attributes"]["type"] input_bit_width, carry_width, msg_width = ( dtype.bit_width, dtype.carry_width, diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index af8dd1701f..52ffed0af6 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -2,11 +2,13 @@ Declaration of `tfhers.Bridge` class. """ -from typing import Dict, List, Optional, Union +# pylint: disable=import-error,no-member,no-name-in-module +from typing import Dict, List, Optional from concrete.compiler import LweSecretKey, TfhersExporter, TfhersFheIntDescription from concrete import fhe +from concrete.fhe.compilation.value import Value from .dtypes import EncryptionKeyChoice, TFHERSIntegerType @@ -22,6 +24,7 @@ class Bridge: circuit: "fhe.Circuit" input_types: List[Optional[TFHERSIntegerType]] + output_types: List[Optional[TFHERSIntegerType]] def __init__( @@ -57,7 +60,7 @@ def _output_type(self, output_idx: int) -> Optional[TFHERSIntegerType]: return self.output_types[output_idx] def _input_keyid(self, input_idx: int) -> int: - return self.circuit.client.specs.client_parameters.input_keyid_at( + return self.circuit.client.specs.program_info.input_keyid_at( input_idx, self.circuit.function_name ) @@ -99,7 +102,7 @@ def _description_from_type( ks_first, ) - def import_value(self, buffer: bytes, input_idx: int) -> "fhe.Value": + def import_value(self, buffer: bytes, input_idx: int) -> Value: """Import a serialized TFHErs integer as a Value. Args: @@ -107,7 +110,7 @@ def import_value(self, buffer: bytes, input_idx: int) -> "fhe.Value": input_idx (int): the index of the input expecting this value Returns: - fhe.Value: imported value + fhe.TransportValue: imported value """ input_type = self._input_type(input_idx) if input_type is None: # pragma: no cover @@ -122,9 +125,7 @@ def import_value(self, buffer: bytes, input_idx: int) -> "fhe.Value": if not signed: keyid = self._input_keyid(input_idx) variance = self._input_variance(input_idx) - return fhe.Value( - TfhersExporter.import_fheuint8(buffer, fheint_desc, keyid, variance) - ) + return Value(TfhersExporter.import_fheuint8(buffer, fheint_desc, keyid, variance)) msg = ( # pragma: no cover f"importing {'signed' if signed else 'unsigned'} integers of {bit_width}bits is not" @@ -132,11 +133,11 @@ def import_value(self, buffer: bytes, input_idx: int) -> "fhe.Value": ) raise NotImplementedError(msg) # pragma: no cover - def export_value(self, value: "fhe.Value", output_idx: int) -> bytes: + def export_value(self, value: Value, output_idx: int) -> bytes: """Export a value as a serialized TFHErs integer. Args: - value (fhe.Value): value to export + value (TransportValue): value to export output_idx (int): the index corresponding to this output Returns: @@ -153,7 +154,9 @@ def export_value(self, value: "fhe.Value", output_idx: int) -> bytes: signed = output_type.is_signed if bit_width == 8: if not signed: - return TfhersExporter.export_fheuint8(value.inner, fheint_desc) + return TfhersExporter.export_fheuint8( + value._inner, fheint_desc # pylint: disable=protected-access + ) msg = ( # pragma: no cover f"exporting value to {'signed' if signed else 'unsigned'} integers of {bit_width}bits" @@ -172,7 +175,9 @@ def serialize_input_secret_key(self, input_idx: int) -> bytes: """ keyid = self._input_keyid(input_idx) # pylint: disable=protected-access - secret_key = self.circuit.client.keys._keyset.get_lwe_secret_key(keyid) # type: ignore + keys = self.circuit.client.keys + assert keys is not None + secret_key = keys._keyset.get_client_keys().get_secret_keys()[keyid] # type: ignore # pylint: enable=protected-access return secret_key.serialize() @@ -200,9 +205,6 @@ def keygen_with_initial_keys( Raises: RuntimeError: if failed to deserialize the key """ - client_specs = self.circuit.keys.client_specs - assert isinstance(client_specs, fhe.ClientSpecs) - initial_keys: Dict[int, LweSecretKey] = {} for input_idx in input_idx_to_key_buffer: key_id = self._input_keyid(input_idx) @@ -211,7 +213,7 @@ def keygen_with_initial_keys( continue key_buffer = input_idx_to_key_buffer[input_idx] - param = client_specs.client_parameters.lwe_secret_key_param(key_id) + param = self.circuit.client.specs.program_info.secret_keys()[key_id] try: initial_keys[key_id] = LweSecretKey.deserialize(key_buffer, param) except Exception as e: # pragma: no cover diff --git a/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py b/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py index 5bf95c67c4..a4fd342e52 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py @@ -201,7 +201,7 @@ def decode(self, value: Union[list, np.ndarray]) -> Union[int, np.ndarray]: if len(value.shape) == 1: # lsb first - return sum(v << i * msg_width for i, v in enumerate(value)) + return sum(int(v) << (i * msg_width) for i, v in enumerate(value)) cts = value.reshape((-1, expected_ct_shape)) return np.array([self.decode(ct) for ct in cts]).reshape(value.shape[:-1]) diff --git a/frontends/concrete-python/examples/key_value_database/static_size.py b/frontends/concrete-python/examples/key_value_database/static_size.py index c82bdb2d92..31cd3df436 100644 --- a/frontends/concrete-python/examples/key_value_database/static_size.py +++ b/frontends/concrete-python/examples/key_value_database/static_size.py @@ -261,7 +261,7 @@ def initialize(self, initial_state: Optional[Union[List, np.ndarray]] = None): initial_state_clear = self.module.reset.encrypt(initial_state) initial_state_encrypted = self.module.reset.run(initial_state_clear) - self.state = initial_state_encrypted + self.state = initial_state_encrypted # type: ignore def decode_entry(self, entry: np.ndarray) -> Optional[Tuple[int, int]]: if entry[0] == 0: diff --git a/frontends/concrete-python/examples/prime-match/prime-match-semi-honest.py b/frontends/concrete-python/examples/prime-match/prime-match-semi-honest.py index ef32c7ea26..d1c9bb1577 100644 --- a/frontends/concrete-python/examples/prime-match/prime-match-semi-honest.py +++ b/frontends/concrete-python/examples/prime-match/prime-match-semi-honest.py @@ -58,6 +58,7 @@ def prime_match( print() start = time.time() +assert circuit.keys is not None circuit.keys.generate() end = time.time() print(f"Key generation took: {end - start:.3f} seconds") diff --git a/frontends/concrete-python/examples/prime-match/prime-match.py b/frontends/concrete-python/examples/prime-match/prime-match.py index 1f18b8a894..828b16a3e9 100644 --- a/frontends/concrete-python/examples/prime-match/prime-match.py +++ b/frontends/concrete-python/examples/prime-match/prime-match.py @@ -183,6 +183,7 @@ def prime_match( print() start = time.time() +assert circuit.keys is not None circuit.keys.generate() end = time.time() print(f"Key generation took: {end - start:.3f} seconds") diff --git a/frontends/concrete-python/tests/compilation/test_keys.py b/frontends/concrete-python/tests/compilation/test_keys.py index d1d9fbe24e..94d56ef858 100644 --- a/frontends/concrete-python/tests/compilation/test_keys.py +++ b/frontends/concrete-python/tests/compilation/test_keys.py @@ -98,6 +98,39 @@ def f(x): assert circuit2.decrypt(evaluation) == 25 +def test_keys_bad_seed(helpers): + """ + Test serializing and deserializing keys. + """ + + @fhe.compiler({"x": "encrypted"}) + def f(x): + return x**2 + + inputset = range(10) + + circuit = f.compile(inputset, helpers.configuration()) + server = circuit.server + + client1 = fhe.Client(server.client_specs) + + with pytest.raises(ValueError) as excinfo: + client1.keys.generate(secret_seed=-1) + assert str(excinfo.value) == "secret_seed must be a positive 128 bits integer" + + with pytest.raises(ValueError) as excinfo: + client1.keys.generate(secret_seed=2**128) + assert str(excinfo.value) == "secret_seed must be a positive 128 bits integer" + + with pytest.raises(ValueError) as excinfo: + client1.keys.generate(encryption_seed=2**128) + assert str(excinfo.value) == "encryption_seed must be a positive 128 bits integer" + + with pytest.raises(ValueError) as excinfo: + client1.keys.generate(encryption_seed=-1) + assert str(excinfo.value) == "encryption_seed must be a positive 128 bits integer" + + def test_keys_serialize_deserialize(helpers): """ Test serializing and deserializing keys. diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index 8589cdb724..c56f26ee84 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -56,7 +56,7 @@ def is_input_and_output_tfhers( tfhers_outs: List[int], ) -> bool: """Check if inputs and outputs description match tfhers parameters""" - params = json.loads(circuit.client.specs.client_parameters.serialize()) + params = json.loads(circuit.client.specs.serialize()) main_circuit = params["circuits"][0] # check all encrypted input/output have the correct lwe_dim ins = main_circuit["inputs"]