diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEAttrs.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEAttrs.td index 7fc336c9c..895e41b16 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEAttrs.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEAttrs.td @@ -28,10 +28,11 @@ def TFHE_KeyswitchKeyAttr: TFHE_Attr<"GLWEKeyswitchKey", "ksk"> { "mlir::concretelang::TFHE::GLWESecretKey":$outputKey, "int":$levels, "int":$baseLog, + DefaultValuedParameter<"int64_t", "-1">: $complexity, DefaultValuedParameter<"int", "-1">: $index ); - let assemblyFormat = " (`[` $index^ `]`)? `<` $inputKey `,` $outputKey `,` $levels `,` $baseLog `>`"; + let assemblyFormat = " (`[` $index^ `]`)? (`{` $complexity^ `}`)? `<` $inputKey `,` $outputKey `,` $levels `,` $baseLog `>`"; } def TFHE_BootstrapKeyAttr: TFHE_Attr<"GLWEBootstrapKey", "bsk"> { @@ -45,10 +46,11 @@ def TFHE_BootstrapKeyAttr: TFHE_Attr<"GLWEBootstrapKey", "bsk"> { "int":$glweDim, "int":$levels, "int":$baseLog, + DefaultValuedParameter<"int64_t", "-1">: $complexity, DefaultValuedParameter<"int", "-1">: $index ); - let assemblyFormat = "(`[` $index^ `]`)? `<` $inputKey `,` $outputKey `,` $polySize `,` $glweDim `,` $levels `,` $baseLog `>`"; + let assemblyFormat = "(`[` $index^ `]`)? (`{` $complexity^ `}`)? `<` $inputKey `,` $outputKey `,` $polySize `,` $glweDim `,` $levels `,` $baseLog `>`"; } def TFHE_PackingKeyswitchKeyAttr: TFHE_Attr<"GLWEPackingKeyswitchKey", "pksk"> { diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h index 13a5b5760..3627ecca6 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h @@ -46,6 +46,7 @@ struct Statistic { PrimitiveOperation operation; std::vector> keys; std::optional count; + double complexity; }; struct CircuitCompilationFeedback { diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 5388396ff..c01ec1d4a 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -829,7 +829,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def_readonly("operation", &mlir::concretelang::Statistic::operation) .def_readonly("location", &mlir::concretelang::Statistic::location) .def_readonly("keys", &mlir::concretelang::Statistic::keys) - .def_readonly("count", &mlir::concretelang::Statistic::count); + .def_readonly("count", &mlir::concretelang::Statistic::count) + .def_readonly("complexity", &mlir::concretelang::Statistic::complexity); pybind11::class_( m, "ProgramCompilationFeedback") diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py index f073150cc..91bef4d36 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py @@ -22,22 +22,23 @@ # matches (@tag, separator( | ), filename) -REGEX_LOCATION = re.compile(r"loc\(\"(@[\w\.]+)?( \| )?(.+)\"") +REGEX_LOCATION = re.compile(r"loc\(\"(%\d*) \| (@[\w\.]+)?( \| )?(.+)\"") -def tag_from_location(location): +def id_and_tag_from_location(location): """ Extract tag of the operation from its location. """ match = REGEX_LOCATION.match(location) if match is not None: - tag, _, _ = match.groups() + id_, tag, _, _ = match.groups() # remove the @ tag = tag[1:] if tag else "" else: + id_ = "" tag = "" - return tag + return id_, tag class CircuitCompilationFeedback(WrapperCpp): @@ -149,7 +150,7 @@ def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int if statistic.operation not in operations: continue - tag = tag_from_location(statistic.location) + _, tag = id_and_tag_from_location(statistic.location) tag_components = tag.split(".") for i in range(1, len(tag_components) + 1): @@ -194,7 +195,7 @@ def count_per_tag_per_parameter( if statistic.operation not in operations: continue - tag = tag_from_location(statistic.location) + _, tag = id_and_tag_from_location(statistic.location) tag_components = tag.split(".") for i in range(1, len(tag_components) + 1): @@ -216,6 +217,49 @@ def count_per_tag_per_parameter( return result + def complexity_per_tag(self) -> Dict[str, float]: + """ + Compute the complexity of each tag in the computation graph. + + Returns: + Dict[str, float]: + complexity per tag + """ + + result = {} + for statistic in self.statistics: + _, tag = id_and_tag_from_location(statistic.location) + + tag_components = tag.split(".") + for i in range(1, len(tag_components) + 1): + current_tag = ".".join(tag_components[0:i]) + if current_tag == "": + continue + + if current_tag not in result: + result[current_tag] = 0.0 + + result[current_tag] += statistic.complexity + + return result + + def complexity_per_node(self) -> Dict[str, float]: + """ + Compute the complexity of each node in the computation graph. + + Returns: + Dict[str, float]: + complexity per node + """ + + result = {} + for statistic in self.statistics: + node_id, _ = id_and_tag_from_location(statistic.location) + if node_id not in result: + result[node_id] = 0.0 + result[node_id] += statistic.complexity + return result + class ProgramCompilationFeedback(WrapperCpp): """CompilationFeedback is a set of hint computed by the compiler engine.""" diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index 88d38fbc3..54a5b5aa2 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -585,10 +585,10 @@ struct ApplyLookupTableEintOpPattern op.getLoc(), converter->convertType(op.getType()), adaptor.getA(), newLut, TFHE::GLWEKeyswitchKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(), - TFHE::GLWESecretKey(), -1, -1, -1), + TFHE::GLWESecretKey(), -1, -1, -1, -1), TFHE::GLWEBootstrapKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(), TFHE::GLWESecretKey(), -1, -1, -1, -1, - -1), + -1, -1), TFHE::GLWEPackingKeyswitchKeyAttr::get( op.getContext(), TFHE::GLWESecretKey(), TFHE::GLWESecretKey(), -1, -1, -1, -1, -1, -1), diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index ae88c00c3..a03dc05f0 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -386,7 +386,7 @@ struct ApplyLookupTableEintOpPattern op.getLoc(), getTypeConverter()->convertType(adaptor.getA().getType()), input, TFHE::GLWEKeyswitchKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(), - TFHE::GLWESecretKey(), -1, -1, -1)); + TFHE::GLWESecretKey(), -1, -1, -1, -1)); if (operatorIndexes != nullptr) { ksOp->setAttr("TFHE.OId", rewriter.getI32IntegerAttr( @@ -398,7 +398,7 @@ struct ApplyLookupTableEintOpPattern op, getTypeConverter()->convertType(op.getType()), ksOp, newLut, TFHE::GLWEBootstrapKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(), TFHE::GLWESecretKey(), -1, -1, -1, -1, - -1)); + -1, -1)); if (operatorIndexes != nullptr) { bsOp->setAttr("TFHE.OId", rewriter.getI32IntegerAttr( @@ -515,9 +515,9 @@ std::vector extractBitWithClearedLowerBits( auto context = op.getContext(); auto secretKey = TFHE::GLWESecretKey(); auto ksk = TFHE::GLWEKeyswitchKeyAttr::get(context, secretKey, secretKey, -1, - -1, -1); + -1, -1, -1); auto bsk = TFHE::GLWEBootstrapKeyAttr::get(context, secretKey, secretKey, -1, - -1, -1, -1, -1); + -1, -1, -1, -1, -1); auto keyswitched = rewriter.create( loc, cInputTy, shiftedRotatedInput, ksk); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 74d24dd03..1ef6a2f97 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -118,7 +118,7 @@ struct KeySwitchGLWEOpPattern auto newOutputKey = converter.getIntraPBSKey(); auto keyswitchKey = TFHE::GLWEKeyswitchKeyAttr::get( ksOp->getContext(), newInputKey, newOutputKey, cryptoParameters.ksLevel, - cryptoParameters.ksLogBase, -1); + cryptoParameters.ksLogBase, -1, -1); auto newOp = rewriter.replaceOpWithNewOp( ksOp, newOutputTy, ksOp.getCiphertext(), keyswitchKey); rewriter.startRootUpdate(newOp); @@ -156,7 +156,7 @@ struct BootstrapGLWEOpPattern auto bootstrapKey = TFHE::GLWEBootstrapKeyAttr::get( bsOp->getContext(), newInputKey, newOutputKey, cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension, - cryptoParameters.brLevel, cryptoParameters.brLogBase, -1); + cryptoParameters.brLevel, cryptoParameters.brLogBase, -1, -1); auto newOp = rewriter.replaceOpWithNewOp( bsOp, newOutputTy, bsOp.getCiphertext(), bsOp.getLookupTable(), bootstrapKey); @@ -193,11 +193,11 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern { auto intraKey = converter.getIntraPBSKey(); auto keyswitchKey = TFHE::GLWEKeyswitchKeyAttr::get( wopPBSOp->getContext(), interKey, intraKey, cryptoParameters.ksLevel, - cryptoParameters.ksLogBase, -1); + cryptoParameters.ksLogBase, -1, -1); auto bootstrapKey = TFHE::GLWEBootstrapKeyAttr::get( wopPBSOp->getContext(), intraKey, interKey, cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension, - cryptoParameters.brLevel, cryptoParameters.brLogBase, -1); + cryptoParameters.brLevel, cryptoParameters.brLogBase, -1, -1); auto packingKeyswitchKey = TFHE::GLWEPackingKeyswitchKeyAttr::get( wopPBSOp->getContext(), interKey, interKey, cryptoParameters.largeInteger->wopPBS.packingKeySwitch diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp index f6c59b3c1..349ee8d4d 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp @@ -51,7 +51,7 @@ class KeyConverter { bsk.getContext(), convertSecretKey(bsk.getInputKey()), convertSecretKey(bsk.getOutputKey()), bsk.getPolySize(), bsk.getGlweDim(), bsk.getLevels(), bsk.getBaseLog(), - circuitKeys.getBootstrapKeyIndex(bsk).value()); + bsk.getComplexity(), circuitKeys.getBootstrapKeyIndex(bsk).value()); } TFHE::GLWEKeyswitchKeyAttr @@ -59,7 +59,7 @@ class KeyConverter { return TFHE::GLWEKeyswitchKeyAttr::get( ksk.getContext(), convertSecretKey(ksk.getInputKey()), convertSecretKey(ksk.getOutputKey()), ksk.getLevels(), ksk.getBaseLog(), - circuitKeys.getKeyswitchKeyIndex(ksk).value()); + ksk.getComplexity(), circuitKeys.getKeyswitchKeyIndex(ksk).value()); } TFHE::GLWEPackingKeyswitchKeyAttr diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp index a68b3d015..623dde16a 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp @@ -1,15 +1,16 @@ -#include "concretelang/Support/CompilationFeedback.h" #include #include #include +#include +#include + +#include #include #include #include #include -#include - using namespace mlir::concretelang; using namespace mlir; @@ -33,6 +34,16 @@ namespace TFHE { } \ } +double levelledComplexity(GLWESecretKey key, std::optional count) { + double complexity = 0.0; + if (key.isNormalized()) { + complexity = key.getNormalized()->dimension; + } else if (key.isParameterized()) { + complexity = key.getParameterized()->dimension; + } + return complexity * count.value_or(1.0); +} + struct ExtractTFHEStatisticsPass : public PassWrapper>, public TripCountTracker { @@ -138,11 +149,15 @@ struct ExtractTFHEStatisticsPass std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); + double complexity = + levelledComplexity(op.getResult().getType().getKey(), count); + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, count, + complexity, }); return std::nullopt; @@ -165,11 +180,15 @@ struct ExtractTFHEStatisticsPass std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); + double complexity = + levelledComplexity(op.getResult().getType().getKey(), count); + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, count, + complexity, }); return std::nullopt; @@ -192,11 +211,15 @@ struct ExtractTFHEStatisticsPass std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex()); keys.push_back(key); + auto complexity = + (double)op.getKeyAttr().getComplexity() * (double)count.value_or(1); + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, count, + complexity, }); return std::nullopt; @@ -219,11 +242,15 @@ struct ExtractTFHEStatisticsPass std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex()); keys.push_back(key); + auto complexity = + (double)op.getKeyAttr().getComplexity() * (double)count.value_or(1); + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, count, + complexity, }); return std::nullopt; @@ -246,11 +273,15 @@ struct ExtractTFHEStatisticsPass std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); + double complexity = + levelledComplexity(op.getResult().getType().getKey(), count); + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, count, + complexity, }); return std::nullopt; @@ -273,11 +304,15 @@ struct ExtractTFHEStatisticsPass std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); + double complexity = + levelledComplexity(op.getResult().getType().getKey(), count); + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, count, + complexity, }); return std::nullopt; @@ -299,22 +334,33 @@ struct ExtractTFHEStatisticsPass std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); keys.push_back(key); + double complexity = + levelledComplexity(op.getResult().getType().getKey(), count); + + // TODO: I though subtraction was implemented like this but it's complexity + // seems to be the same as either `neg(encrypted)` or the addition, not + // both. What should we do here? + // clear - encrypted = clear + neg(encrypted) auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, count, + complexity, }); operation = PrimitiveOperation::CLEAR_ADDITION; + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, count, + complexity, }); return std::nullopt; @@ -345,11 +391,15 @@ struct ExtractTFHEStatisticsPass key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (int64_t)pksk.getIndex()); keys.push_back(key); + // TODO + double complexity = 0.0; + pass.circuitFeedback->statistics.push_back(concretelang::Statistic{ location, operation, keys, count, + complexity, }); return std::nullopt; diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp index 1094757ce..5e6b88c13 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp @@ -84,7 +84,7 @@ class CircuitSolutionWrapper { return TFHE::GLWEKeyswitchKeyAttr::get( ctx, toGLWESecretKey(ksk.input_key), toGLWESecretKey(ksk.output_key), ksk.ks_decomposition_parameter.level, - ksk.ks_decomposition_parameter.log2_base, -1); + ksk.ks_decomposition_parameter.log2_base, ksk.unitary_cost, -1); } // Returns a `GLWEKeyswitchKeyAttr` for the keyswitch key of an @@ -108,7 +108,7 @@ class CircuitSolutionWrapper { ctx, toGLWESecretKey(bsk.input_key), toGLWESecretKey(bsk.output_key), bsk.output_key.polynomial_size, bsk.output_key.glwe_dimension, bsk.br_decomposition_parameter.level, - bsk.br_decomposition_parameter.log2_base, -1); + bsk.br_decomposition_parameter.log2_base, bsk.unitary_cost, -1); } // Looks up the keyswitch key for an operation tagged with a given diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp index f5395fa8f..8592e115b 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp @@ -198,6 +198,7 @@ llvm::json::Object statisticToJson(const Statistic &statistic) { keysJson.push_back(std::move(keyJson)); } object.insert({"keys", std::move(keysJson)}); + object.insert({"complexity", statistic.complexity}); return object; } @@ -333,7 +334,8 @@ bool fromJSON(const llvm::json::Value j, mlir::concretelang::Statistic &v, return O && O.map("location", v.location) && O.map("operation", v.operation) && O.map("operation", v.operation) && - O.map("keys", v.keys) && O.map("count", v.count); + O.map("keys", v.keys) && O.map("count", v.count) && + O.map("complexity", v.complexity); } bool fromJSON(const llvm::json::Value j, diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index 1e61cc33e..e0913a1c2 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -243,6 +243,7 @@ fn convert_to_circuit_solution(sol: &ffi::DagSolution, dag: &OperationDag) -> ff log2_base: sol.ks_decomposition_base_log, }, description: "tlu keyswitch".into(), + unitary_cost: 0.0, }; let bootstrap_key = ffi::BootstrapKey { identifier: 0, @@ -253,6 +254,7 @@ fn convert_to_circuit_solution(sol: &ffi::DagSolution, dag: &OperationDag) -> ff log2_base: sol.br_decomposition_base_log, }, description: "tlu bootstrap".into(), + unitary_cost: 0.0, }; let circuit_bootstrap_keys = if sol.use_wop_pbs { vec![ffi::CircuitBoostrapKey { @@ -383,6 +385,7 @@ impl From for ffi::KeySwitchKey { input_key: v.input_key.into(), output_key: v.output_key.into(), ks_decomposition_parameter: v.ks_decomposition_parameter.into(), + unitary_cost: v.unitary_cost, description: v.description, } } @@ -397,6 +400,7 @@ impl From for ffi::ConversionKeySwitchKey { ks_decomposition_parameter: v.ks_decomposition_parameter.into(), description: v.description, fast_keyswitch: v.fast_keyswitch, + unitary_cost: v.unitary_cost, } } } @@ -408,6 +412,7 @@ impl From for ffi::BootstrapKey { input_key: v.input_key.into(), output_key: v.output_key.into(), br_decomposition_parameter: v.br_decomposition_parameter.into(), + unitary_cost: v.unitary_cost, description: v.description, } } @@ -425,7 +430,7 @@ impl From for ffi::CircuitBoostrapKey { } impl From - for ffi::PrivateFunctionalPackingBoostrapKey +for ffi::PrivateFunctionalPackingBoostrapKey { fn from(v: keys_spec::PrivateFunctionalPackingBoostrapKey) -> Self { Self { @@ -622,7 +627,7 @@ impl OperationDag { let encoding = options.encoding.into(); #[allow(clippy::wildcard_in_or_patterns)] - let p_cut = match options.multi_param_strategy { + let p_cut = match options.multi_param_strategy { ffi::MultiParamStrategy::ByPrecisionAndNorm2 => { PartitionCut::maximal_partitionning(&self.0) } @@ -682,7 +687,6 @@ impl Into for ffi::Encoding { mod ffi { #[namespace = "concrete_optimizer"] extern "Rust" { - #[namespace = "concrete_optimizer::v0"] fn optimize_bootstrap(precision: u64, noise_factor: f64, options: Options) -> Solution; @@ -785,14 +789,22 @@ mod ffi { #[namespace = "concrete_optimizer::v0"] #[derive(Debug, Clone, Copy, Default)] pub struct Solution { - pub input_lwe_dimension: u64, //n_big - pub internal_ks_output_lwe_dimension: u64, //n_small - pub ks_decomposition_level_count: u64, //l(KS) - pub ks_decomposition_base_log: u64, //b(KS) - pub glwe_polynomial_size: u64, //N - pub glwe_dimension: u64, //k - pub br_decomposition_level_count: u64, //l(BR) - pub br_decomposition_base_log: u64, //b(BR) + pub input_lwe_dimension: u64, + //n_big + pub internal_ks_output_lwe_dimension: u64, + //n_small + pub ks_decomposition_level_count: u64, + //l(KS) + pub ks_decomposition_base_log: u64, + //b(KS) + pub glwe_polynomial_size: u64, + //N + pub glwe_dimension: u64, + //k + pub br_decomposition_level_count: u64, + //l(BR) + pub br_decomposition_base_log: u64, + //b(BR) pub complexity: f64, pub noise_max: f64, pub p_error: f64, // error probability @@ -801,17 +813,26 @@ mod ffi { #[namespace = "concrete_optimizer::dag"] #[derive(Debug, Clone, Default)] pub struct DagSolution { - pub input_lwe_dimension: u64, //n_big - pub internal_ks_output_lwe_dimension: u64, //n_small - pub ks_decomposition_level_count: u64, //l(KS) - pub ks_decomposition_base_log: u64, //b(KS) - pub glwe_polynomial_size: u64, //N - pub glwe_dimension: u64, //k - pub br_decomposition_level_count: u64, //l(BR) - pub br_decomposition_base_log: u64, //b(BR) + pub input_lwe_dimension: u64, + //n_big + pub internal_ks_output_lwe_dimension: u64, + //n_small + pub ks_decomposition_level_count: u64, + //l(KS) + pub ks_decomposition_base_log: u64, + //b(KS) + pub glwe_polynomial_size: u64, + //N + pub glwe_dimension: u64, + //k + pub br_decomposition_level_count: u64, + //l(BR) + pub br_decomposition_base_log: u64, + //b(BR) pub complexity: f64, pub noise_max: f64, - pub p_error: f64, // error probability + pub p_error: f64, + // error probability pub global_p_error: f64, pub use_wop_pbs: bool, pub cb_decomposition_level_count: u64, @@ -875,6 +896,7 @@ mod ffi { pub input_key: SecretLweKey, pub output_key: SecretLweKey, pub br_decomposition_parameter: BrDecompositionParameters, + pub unitary_cost: f64, pub description: String, } @@ -885,6 +907,7 @@ mod ffi { pub input_key: SecretLweKey, pub output_key: SecretLweKey, pub ks_decomposition_parameter: KsDecompositionParameters, + pub unitary_cost: f64, pub description: String, } @@ -896,6 +919,7 @@ mod ffi { pub output_key: SecretLweKey, pub ks_decomposition_parameter: KsDecompositionParameters, pub fast_keyswitch: bool, + pub unitary_cost: f64, pub description: String, } diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 77a08cc30..29b119e70 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -1140,6 +1140,7 @@ struct BootstrapKey final { ::concrete_optimizer::dag::SecretLweKey input_key; ::concrete_optimizer::dag::SecretLweKey output_key; ::concrete_optimizer::dag::BrDecompositionParameters br_decomposition_parameter; + double unitary_cost; ::rust::String description; using IsRelocatable = ::std::true_type; @@ -1153,6 +1154,7 @@ struct KeySwitchKey final { ::concrete_optimizer::dag::SecretLweKey input_key; ::concrete_optimizer::dag::SecretLweKey output_key; ::concrete_optimizer::dag::KsDecompositionParameters ks_decomposition_parameter; + double unitary_cost; ::rust::String description; using IsRelocatable = ::std::true_type; @@ -1167,6 +1169,7 @@ struct ConversionKeySwitchKey final { ::concrete_optimizer::dag::SecretLweKey output_key; ::concrete_optimizer::dag::KsDecompositionParameters ks_decomposition_parameter; bool fast_keyswitch; + double unitary_cost; ::rust::String description; using IsRelocatable = ::std::true_type; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 4636598f5..15451efdc 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -1121,6 +1121,7 @@ struct BootstrapKey final { ::concrete_optimizer::dag::SecretLweKey input_key; ::concrete_optimizer::dag::SecretLweKey output_key; ::concrete_optimizer::dag::BrDecompositionParameters br_decomposition_parameter; + double unitary_cost; ::rust::String description; using IsRelocatable = ::std::true_type; @@ -1134,6 +1135,7 @@ struct KeySwitchKey final { ::concrete_optimizer::dag::SecretLweKey input_key; ::concrete_optimizer::dag::SecretLweKey output_key; ::concrete_optimizer::dag::KsDecompositionParameters ks_decomposition_parameter; + double unitary_cost; ::rust::String description; using IsRelocatable = ::std::true_type; @@ -1148,6 +1150,7 @@ struct ConversionKeySwitchKey final { ::concrete_optimizer::dag::SecretLweKey output_key; ::concrete_optimizer::dag::KsDecompositionParameters ks_decomposition_parameter; bool fast_keyswitch; + double unitary_cost; ::rust::String description; using IsRelocatable = ::std::true_type; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/keys_spec.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/keys_spec.rs index f2c0626a6..1215a4cb5 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/keys_spec.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/keys_spec.rs @@ -24,7 +24,13 @@ pub struct SecretLweKey { pub description: String, } -#[derive(Debug, Clone, PartialEq, Eq)] +impl SecretLweKey { + pub fn size(&self) -> u64 { + self.polynomial_size * self.glwe_dimension + } +} + +#[derive(Debug, Clone)] pub struct BootstrapKey { /* Public TLU bootstrap keys */ pub identifier: BootstrapKeyId, @@ -32,19 +38,21 @@ pub struct BootstrapKey { pub output_key: SecretLweKey, pub br_decomposition_parameter: BrDecompositionParameters, pub description: String, + pub unitary_cost: f64, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq)] pub struct KeySwitchKey { /* Public TLU keyswitch keys */ pub identifier: KeySwitchKeyId, pub input_key: SecretLweKey, pub output_key: SecretLweKey, pub ks_decomposition_parameter: KsDecompositionParameters, + pub unitary_cost: f64, pub description: String, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq)] pub struct ConversionKeySwitchKey { /* Public conversion to make compatible ciphertext with incompatible keys. It's currently only between two big secret keys. */ @@ -53,6 +61,7 @@ pub struct ConversionKeySwitchKey { pub output_key: SecretLweKey, pub ks_decomposition_parameter: KsDecompositionParameters, pub fast_keyswitch: bool, + pub unitary_cost: f64, pub description: String, } @@ -153,6 +162,7 @@ impl CircuitSolution { log2_base: sol.ks_decomposition_base_log, }, description: "tlu keyswitch".into(), + unitary_cost: 0.0, }; let bootstrap_key = BootstrapKey { identifier: 0, @@ -162,6 +172,7 @@ impl CircuitSolution { level: sol.br_decomposition_level_count, log2_base: sol.br_decomposition_base_log, }, + unitary_cost: 0.0, description: "tlu bootstrap".into(), }; let circuit_bootstrap_key = CircuitBoostrapKey { @@ -206,7 +217,7 @@ impl CircuitSolution { } else { "No crypto-parameters for the given constraints" } - .into(); + .into(); Self { circuit_keys, instructions_keys, @@ -226,7 +237,7 @@ impl CircuitSolution { } else { "No crypto-parameters for the given constraints" } - .into(); + .into(); let big_key = SecretLweKey { identifier: 0, polynomial_size: sol.glwe_polynomial_size, @@ -277,6 +288,7 @@ impl CircuitSolution { log2_base: sol.ks_decomposition_base_log, }, description: "tlu keyswitch".into(), + unitary_cost: 0.0, }; let bootstrap_key = BootstrapKey { identifier: 0, @@ -286,6 +298,7 @@ impl CircuitSolution { level: sol.br_decomposition_level_count, log2_base: sol.br_decomposition_base_log, }, + unitary_cost: 0.0, description: "tlu bootstrap".into(), }; let instruction_keys = InstructionKeys { @@ -373,13 +386,18 @@ impl ExpandedCircuitKeys { .iter() .enumerate() .map(|(i, v): (usize, &Option<_>)| { - let br_decomposition_parameter = v.unwrap().decomp; + let bs = v.unwrap(); + let br_decomposition_parameter = bs.decomp; + let input_key = small_secret_keys[i].clone(); + #[allow(clippy::cast_sign_loss)] + let unitary_cost = bs.complexity_br(input_key.size()); BootstrapKey { identifier: i as Id, - input_key: small_secret_keys[i].clone(), + input_key, output_key: big_secret_keys[i].clone(), br_decomposition_parameter, description: format!("pbs[{i}]"), + unitary_cost, } }) .collect(); @@ -399,12 +417,18 @@ impl ExpandedCircuitKeys { }; if let Some(ks) = params.micro_params.ks[src][dst] { let identifier = identifier_ks; + + let input_key = big_secret_keys[src].clone(); + #[allow(clippy::cast_sign_loss)] + let unitary_cost = ks.complexity(input_key.size()); + keyswitch_keys[src][dst] = Some(KeySwitchKey { identifier, input_key: big_secret_keys[src].clone(), output_key: small_secret_keys[dst].clone(), ks_decomposition_parameter: ks.decomp, description: cross_key("ks"), + unitary_cost, }); identifier_ks += 1; } @@ -416,6 +440,7 @@ impl ExpandedCircuitKeys { output_key: big_secret_keys[dst].clone(), ks_decomposition_parameter: fks.decomp, fast_keyswitch: REAL_FAST_KS, + unitary_cost: 0.0, description: cross_key("fks"), }); identifier_fks += 1; @@ -449,7 +474,7 @@ impl ExpandedCircuitKeys { let final_id_out = final_key_id(&key.output_key); let canon_key = (final_id_in, final_id_out, key.br_decomposition_parameter); #[allow(clippy::option_if_let_else)] - let final_bootstrap = + let final_bootstrap = if let Some(final_bootstrap) = canon_final_bootstraps.get(&canon_key) { final_bootstrap.clone() } else { @@ -489,7 +514,7 @@ impl ExpandedCircuitKeys { let final_id_out = final_key_id(&key.output_key); let canon_key = (final_id_in, final_id_out, key.ks_decomposition_parameter); #[allow(clippy::option_if_let_else)] - let final_keyswitch = if let Some(final_keyswitch) = + let final_keyswitch = if let Some(final_keyswitch) = canon_final_keyswitchs.get(&canon_key) { keyswitch_keys[i][j] = None; @@ -544,7 +569,7 @@ impl ExpandedCircuitKeys { } let canon_key = (final_id_in, final_id_out, key.ks_decomposition_parameter); #[allow(clippy::option_if_let_else)] - let final_c_keyswitch = if let Some(final_c_keyswitch) = + let final_c_keyswitch = if let Some(final_c_keyswitch) = canon_final_c_keyswitchs.get(&canon_key) { conversion_keyswitch_keys[i][j] = None; diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index d4fc79fb3..0e840155d 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -389,6 +389,26 @@ def complexity(self) -> float: """ return self._property("complexity") # pragma: no cover + def complexity_per_tag(self) -> Dict[str, float]: + """ + Get the complexity of each tag in the computation graph. + + Returns: + Dict[str, float]: + complexity per tag + """ + return self._property("complexity_per_tag") + + def complexity_per_node(self) -> Dict[str, float]: + """ + Get the complexity of each node in the computation graph. + + Returns: + Dict[str, float]: + complexity per node + """ + return self._property("complexity_per_node") + # Programmable Bootstrap Statistics @property diff --git a/frontends/concrete-python/concrete/fhe/compilation/compiler.py b/frontends/concrete-python/concrete/fhe/compilation/compiler.py index 9125ad112..a43053e29 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/compiler.py +++ b/frontends/concrete-python/concrete/fhe/compilation/compiler.py @@ -615,6 +615,13 @@ def pretty(d, indent=0): # pragma: no cover if isinstance(value, dict): pretty(value, indent + 1) + elif isinstance(value, int): + print(f"{value:_}") + elif isinstance(value, float): + if round(value) == value: + print(f"{int(value):_}") + else: + print(f"{value:_}") else: print(value) diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index 5a13fdef9..09e038505 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -432,6 +432,33 @@ def complexity(self) -> float: """ return self._compilation_feedback.complexity + @property + def complexity(self) -> float: + """ + Get complexity of the compiled program. + """ + return self._compilation_feedback.complexity + + def complexity_per_tag(self, function: str = "main") -> Dict[str, float]: + """ + Get the complexity of each tag in the computation graph. + + Returns: + Dict[str, float]: + complexity per tag + """ + return self._compilation_feedback.circuit(function).complexity_per_tag() + + def complexity_per_node(self, function: str = "main") -> Dict[str, float]: + """ + Get the complexity of each node in the computation graph. + + Returns: + Dict[str, float]: + complexity per node + """ + return self._compilation_feedback.circuit(function).complexity_per_node() + def size_of_inputs(self, function: str = "main") -> int: """ Get size of the inputs of the compiled program. @@ -772,6 +799,8 @@ def statistics(self) -> Dict: "encrypted_negation_count_per_parameter", "encrypted_negation_count_per_tag", "encrypted_negation_count_per_tag_per_parameter", + "complexity_per_tag", + "complexity_per_node", ] output = {attribute: getattr(self, attribute)() for attribute in attributes} output["size_of_secret_keys"] = self.size_of_secret_keys diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index ad553df28..d8cfe2e88 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -136,7 +136,7 @@ def location(self) -> MlirLocation: tag = "" if self.converting.tag == "" else f"@{self.converting.tag} | " return MlirLocation.file( - f"{tag}{path}", + f"{self.converting.properties['id']} | {tag}{path}", line=int(lineno), col=0, context=self.context, diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index dba0dee7f..7cdebca5e 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -232,6 +232,7 @@ def process(self, graphs: Dict[str, Graph]): ), ] + configuration.additional_post_processors + + [AssignNodeIds()] ) for processor in pipeline: diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/__init__.py b/frontends/concrete-python/concrete/fhe/mlir/processors/__init__.py index 93c43b2c3..ba4010c0c 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/__init__.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/__init__.py @@ -5,6 +5,7 @@ # pylint: disable=unused-import from .assign_bit_widths import AssignBitWidths +from .assign_node_ids import AssignNodeIds from .check_integer_only import CheckIntegerOnly from .process_rounding import ProcessRounding diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_node_ids.py b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_node_ids.py new file mode 100644 index 000000000..87e10f2fc --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_node_ids.py @@ -0,0 +1,28 @@ +""" +Declaration of `AssignNodeIds` graph processor. +""" + +from itertools import chain +from typing import Dict, List + +import z3 + +from ...compilation.configuration import ( + BitwiseStrategy, + ComparisonStrategy, + MinMaxStrategy, + MultivariateStrategy, +) +from ...dtypes import Integer +from ...representation import Graph, MultiGraphProcessor, Node, Operation + + +class AssignNodeIds(MultiGraphProcessor): + """ + AssignNodeIds graph processor, to assign node id (%0, %1, etc.) to node properties. + """ + + def apply_many(self, graphs: Dict[str, Graph]): + for graph_name, graph in graphs.items(): + for index, node in enumerate(graph.query_nodes(ordered=True)): + node.properties["id"] = f"%{index}" diff --git a/frontends/concrete-python/concrete/fhe/representation/graph.py b/frontends/concrete-python/concrete/fhe/representation/graph.py index 2ea83e4bf..e70cbcb7b 100644 --- a/frontends/concrete-python/concrete/fhe/representation/graph.py +++ b/frontends/concrete-python/concrete/fhe/representation/graph.py @@ -575,7 +575,9 @@ def format_bit_width_constraints(self) -> str: if len(node.bit_width_constraints) > 0: result += f"%{i}:\n" for constraint in node.bit_width_constraints: - result += f" {constraint.arg(0)} {constraint.decl()} {constraint.arg(1)}\n" + lhs_id = str(constraint.arg(0)).split(".")[-1] + rhs_id = str(constraint.arg(1)).split(".")[-1] + result += f" {lhs_id} {constraint.decl()} {rhs_id}\n" return result[:-1] def format_bit_width_assignments(self) -> str: @@ -591,10 +593,11 @@ def format_bit_width_assignments(self) -> str: for variable in self.bit_width_assignments.decls(): # type: ignore if variable.name().startswith(f"{self.name}.") or variable.name() == "input_output": width = self.bit_width_assignments.get_interp(variable) # type: ignore - lines.append(f"{variable} = {width}") + variable_id = variable.name().split(".")[-1] + lines.append(f"{variable_id} = {width}") def sorter(line: str) -> int: - if line.startswith(f"{self.name}.max"): + if line.startswith(f"max"): # we won't have 4 million nodes... return 2**32 if line.startswith("input_output"): @@ -602,7 +605,7 @@ def sorter(line: str) -> int: return 2**32 equals_position = line.find("=") - index = line[len(self.name) + 2 : equals_position - 1] + index = line[1 : equals_position - 1] return int(index) result = ""