Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Doc/restrictions #1150

Open
wants to merge 2 commits into
base: alex/optimizer_keyset_generation
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef CONCRETELANG_COMMON_KEYSETS_H
#define CONCRETELANG_COMMON_KEYSETS_H

#include "concrete-optimizer.hpp"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Error.h"
Expand Down Expand Up @@ -92,6 +93,10 @@ class KeysetCache {
KeysetCache() = default;
};

Message<concreteprotocol::KeysetInfo> generate_generic_keyset_info(
std::vector<concrete_optimizer::utils::PartitionDefinition> partitions,
bool generate_fks);

} // namespace keysets
} // namespace concretelang

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "concretelang/Support/Error.h"
#include "concretelang/Support/V0Parameters.h"
#include "concretelang/Support/logging.h"
#include <_types/_uint8_t.h>
#include <cstdint>
#include <filesystem>
#include <memory>
#include <mlir-c/Bindings/Python/Interop.h>
Expand Down Expand Up @@ -645,6 +647,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(LweSecretKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(LweSecretKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<LweSecretKeyParam>(m, "LweSecretKeyParam")
.def(
Expand All @@ -659,6 +673,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of an LWE Secret Key.";

// ------------------------------------------------------------------------------//
Expand Down Expand Up @@ -689,6 +705,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(BootstrapKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(BootstrapKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<BootstrapKeyParam>(m, "BootstrapKeyParam")
.def(
Expand Down Expand Up @@ -745,6 +773,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of a Bootstrap key.";

// ------------------------------------------------------------------------------//
Expand All @@ -766,6 +796,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(KeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(KeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<KeyswitchKeyParam>(m, "KeyswitchKeyParam")
.def(
Expand Down Expand Up @@ -804,6 +846,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of a keyswitch key.";

// ------------------------------------------------------------------------------//
Expand Down Expand Up @@ -834,6 +878,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
output.append(")");
return output;
}

bool operator==(PackingKeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left == right;
}

bool operator!=(PackingKeyswitchKeyParam const &other) const {
capnp::AnyStruct::Reader left = this->info.asReader().getParams();
capnp::AnyStruct::Reader right = other.info.asReader().getParams();
return left != right;
}
};
pybind11::class_<PackingKeyswitchKeyParam>(m, "PackingKeyswitchKeyParam")
.def(
Expand Down Expand Up @@ -892,13 +948,44 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](pybind11::object key) {
return pybind11::hash(pybind11::repr(key));
})
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self)
.doc() = "Parameters of a packing keyswitch key.";

// ------------------------------------------------------------------------------//
// PARTITION DEFINITION //
// ------------------------------------------------------------------------------//
//
pybind11::class_<concrete_optimizer::utils::PartitionDefinition>(
m, "PartitionDefinition")
.def(init([](uint8_t precision, double norm2)
-> concrete_optimizer::utils::PartitionDefinition {
return concrete_optimizer::utils::PartitionDefinition{precision,
norm2};
}),
arg("precision"), arg("norm2"))
.doc() = "Definition of a partition (in terms of precision in bits and "
"norm2 in value).";

// ------------------------------------------------------------------------------//
// KEYSET INFO //
// ------------------------------------------------------------------------------//
typedef Message<concreteprotocol::KeysetInfo> KeysetInfo;
pybind11::class_<KeysetInfo>(m, "KeysetInfo")
.def_static(
"generate_generic",
[](std::vector<concrete_optimizer::utils::PartitionDefinition>
partitions,
bool generateFks) -> KeysetInfo {
if (partitions.size() < 2) {
throw std::runtime_error("Need at least two partition defs to "
"generate a generic keyset info.");
}
return ::concretelang::keysets::generate_generic_keyset_info(
partitions, generateFks);
},
arg("partition_defs"), arg("generate_fks"),
"Generate a generic keyset info for a set of partition definitions")
.def(
"secret_keys",
[](KeysetInfo &keysetInfo) {
Expand Down
94 changes: 94 additions & 0 deletions compilers/concrete-compiler/compiler/lib/Common/Keysets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "concretelang/Common/Keysets.h"
#include "capnp/message.h"
#include "concrete-cpu.h"
#include "concrete-optimizer.hpp"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Error.h"
Expand Down Expand Up @@ -417,5 +418,98 @@ KeysetCache::getKeyset(const Message<concreteprotocol::KeysetInfo> &keysetInfo,
return std::move(keyset);
}

Message<concreteprotocol::KeysetInfo> generate_generic_keyset_info(
std::vector<concrete_optimizer::utils::PartitionDefinition> partitionDefs,
bool generateFks) {
auto output = Message<concreteprotocol::KeysetInfo>{};
rust::Vec<concrete_optimizer::utils::PartitionDefinition> rustPartitionDefs{};
for (auto def : partitionDefs) {
rustPartitionDefs.push_back(def);
}
auto parameters = concrete_optimizer::utils::generate_generic_keyset_info(
rustPartitionDefs, generateFks);

auto skLen = (int)parameters.secret_keys.size();
auto skBuilder = output.asBuilder().initLweSecretKeys(skLen);
for (int i = 0; i < skLen; i++) {
auto output = Message<concreteprotocol::LweSecretKeyInfo>();
auto sk = parameters.secret_keys[i];
output.asBuilder().setId(sk.identifier);
output.asBuilder().getParams().setIntegerPrecision(64);
output.asBuilder().getParams().setLweDimension(sk.polynomial_size *
sk.glwe_dimension);
output.asBuilder().getParams().setKeyType(
::concreteprotocol::KeyType::BINARY);
skBuilder.setWithCaveats(i, output.asReader());
}

auto bskLen = (int)parameters.bootstrap_keys.size();
auto bskBuilder = output.asBuilder().initLweBootstrapKeys(bskLen);
for (int i = 0; i < bskLen; i++) {
auto output = Message<concreteprotocol::LweBootstrapKeyInfo>();
auto bsk = parameters.bootstrap_keys[i];
output.asBuilder().setId(bsk.identifier);
output.asBuilder().setInputId(bsk.input_key.identifier);
output.asBuilder().setOutputId(bsk.output_key.identifier);
output.asBuilder().getParams().setLevelCount(
bsk.br_decomposition_parameter.level);
output.asBuilder().getParams().setBaseLog(
bsk.br_decomposition_parameter.log2_base);
output.asBuilder().getParams().setGlweDimension(
bsk.output_key.glwe_dimension);
output.asBuilder().getParams().setPolynomialSize(
bsk.output_key.polynomial_size);
output.asBuilder().getParams().setInputLweDimension(
bsk.input_key.polynomial_size);
output.asBuilder().getParams().setIntegerPrecision(64);
output.asBuilder().getParams().setKeyType(
concreteprotocol::KeyType::BINARY);
bskBuilder.setWithCaveats(i, output.asReader());
}

auto kskLen = (int)parameters.keyswitch_keys.size();
auto ckskLen = (int)parameters.conversion_keyswitch_keys.size();
auto kskBuilder = output.asBuilder().initLweKeyswitchKeys(kskLen + ckskLen);
for (int i = 0; i < kskLen; i++) {
auto output = Message<concreteprotocol::LweKeyswitchKeyInfo>();
auto ksk = parameters.keyswitch_keys[i];
output.asBuilder().setId(ksk.identifier);
output.asBuilder().setInputId(ksk.input_key.identifier);
output.asBuilder().setOutputId(ksk.output_key.identifier);
output.asBuilder().getParams().setLevelCount(
ksk.ks_decomposition_parameter.level);
output.asBuilder().getParams().setBaseLog(
ksk.ks_decomposition_parameter.log2_base);
output.asBuilder().getParams().setIntegerPrecision(64);
output.asBuilder().getParams().setInputLweDimension(
ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size);
output.asBuilder().getParams().setOutputLweDimension(
ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size);
output.asBuilder().getParams().setKeyType(
concreteprotocol::KeyType::BINARY);
kskBuilder.setWithCaveats(i, output.asReader());
}
for (int i = 0; i < ckskLen; i++) {
auto output = Message<concreteprotocol::LweKeyswitchKeyInfo>();
auto ksk = parameters.conversion_keyswitch_keys[i];
output.asBuilder().setId(ksk.identifier);
output.asBuilder().setInputId(ksk.input_key.identifier);
output.asBuilder().setOutputId(ksk.output_key.identifier);
output.asBuilder().getParams().setLevelCount(
ksk.ks_decomposition_parameter.level);
output.asBuilder().getParams().setBaseLog(
ksk.ks_decomposition_parameter.log2_base);
output.asBuilder().getParams().setIntegerPrecision(64);
output.asBuilder().getParams().setInputLweDimension(
ksk.input_key.glwe_dimension * ksk.input_key.polynomial_size);
output.asBuilder().getParams().setOutputLweDimension(
ksk.output_key.glwe_dimension * ksk.output_key.polynomial_size);
output.asBuilder().getParams().setKeyType(
concreteprotocol::KeyType::BINARY);
kskBuilder.setWithCaveats(i + kskLen, output.asReader());
}
return output;
}

} // namespace keysets
} // namespace concretelang
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use concrete_optimizer::dag::operator::{
};
use concrete_optimizer::dag::unparametrized;
use concrete_optimizer::optimization::config::{Config, SearchSpace};
use concrete_optimizer::optimization::dag::multi_parameters::generic_generation::generate_generic_parameters;
use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::CircuitSolution;
use concrete_optimizer::optimization::dag::multi_parameters::optimize::{
KeysetRestriction, MacroParameters, NoSearchSpaceRestriction, RangeRestriction,
Expand Down Expand Up @@ -913,6 +914,22 @@ fn location_from_string(string: &str) -> Box<Location> {
}
}

fn generate_generic_keyset_info(
inputs: Vec<ffi::PartitionDefinition>,
generate_fks: bool,
) -> ffi::CircuitKeys {
generate_generic_parameters(
inputs
.into_iter()
.map(
|ffi::PartitionDefinition { precision, norm2 }| concrete_optimizer::optimization::dag::multi_parameters::generic_generation::PartitionDefinition { precision, norm2 },
)
.collect(),
generate_fks,
)
.into()
}

pub struct Weights(operator::Weights);

fn vector(weights: &[i64]) -> Box<Weights> {
Expand Down Expand Up @@ -981,6 +998,12 @@ mod ffi {
#[namespace = "concrete_optimizer::utils"]
fn location_from_string(string: &str) -> Box<Location>;

#[namespace = "concrete_optimizer::utils"]
fn generate_generic_keyset_info(
partitions: Vec<PartitionDefinition>,
generate_fks: bool,
) -> CircuitKeys;

#[namespace = "concrete_optimizer::utils"]
fn get_external_partition(
name: String,
Expand Down Expand Up @@ -1359,6 +1382,13 @@ mod ffi {
pub struct KeysetRestriction {
pub info: KeysetInfo,
}

#[namespace = "concrete_optimizer::utils"]
#[derive(Debug, Clone)]
pub struct PartitionDefinition {
pub precision: u8,
pub norm2: f64,
}
}

fn processing_unit(options: &ffi::Options) -> ProcessingUnit {
Expand Down
Loading
Loading