Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(compiler): complexity per node #788

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ def TFHE_KeyswitchKeyAttr: TFHE_Attr<"GLWEKeyswitchKey", "ksk"> {
"mlir::concretelang::TFHE::GLWESecretKey":$outputKey,
"int":$levels,
"int":$baseLog,
DefaultValuedParameter<"int64_t", "-1">: $complexity,
DefaultValuedParameter<"int", "-1">: $index
);

let assemblyFormat = " (`[` $index^ `]`)? `<` $inputKey `,` $outputKey `,` $levels `,` $baseLog `>`";
let assemblyFormat = " (`[` $index^ `]`)? (`{` $complexity^ `}`)? `<` $inputKey `,` $outputKey `,` $levels `,` $baseLog `>`";
}

def TFHE_BootstrapKeyAttr: TFHE_Attr<"GLWEBootstrapKey", "bsk"> {
Expand All @@ -45,10 +46,11 @@ def TFHE_BootstrapKeyAttr: TFHE_Attr<"GLWEBootstrapKey", "bsk"> {
"int":$glweDim,
"int":$levels,
"int":$baseLog,
DefaultValuedParameter<"int64_t", "-1">: $complexity,
DefaultValuedParameter<"int", "-1">: $index
);

let assemblyFormat = "(`[` $index^ `]`)? `<` $inputKey `,` $outputKey `,` $polySize `,` $glweDim `,` $levels `,` $baseLog `>`";
let assemblyFormat = "(`[` $index^ `]`)? (`{` $complexity^ `}`)? `<` $inputKey `,` $outputKey `,` $polySize `,` $glweDim `,` $levels `,` $baseLog `>`";
}

def TFHE_PackingKeyswitchKeyAttr: TFHE_Attr<"GLWEPackingKeyswitchKey", "pksk"> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ struct Statistic {
PrimitiveOperation operation;
std::vector<std::pair<KeyType, int64_t>> keys;
std::optional<int64_t> count;
double complexity;
};

struct CircuitCompilationFeedback {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def_readonly("operation", &mlir::concretelang::Statistic::operation)
.def_readonly("location", &mlir::concretelang::Statistic::location)
.def_readonly("keys", &mlir::concretelang::Statistic::keys)
.def_readonly("count", &mlir::concretelang::Statistic::count);
.def_readonly("count", &mlir::concretelang::Statistic::count)
.def_readonly("complexity", &mlir::concretelang::Statistic::complexity);

pybind11::class_<mlir::concretelang::ProgramCompilationFeedback>(
m, "ProgramCompilationFeedback")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,23 @@


# matches (@tag, separator( | ), filename)
REGEX_LOCATION = re.compile(r"loc\(\"(@[\w\.]+)?( \| )?(.+)\"")
REGEX_LOCATION = re.compile(r"loc\(\"(%\d*) \| (@[\w\.]+)?( \| )?(.+)\"")


def tag_from_location(location):
def id_and_tag_from_location(location):
"""
Extract tag of the operation from its location.
"""

match = REGEX_LOCATION.match(location)
if match is not None:
tag, _, _ = match.groups()
id_, tag, _, _ = match.groups()
# remove the @
tag = tag[1:] if tag else ""
else:
id_ = ""
tag = ""
return tag
return id_, tag


class CircuitCompilationFeedback(WrapperCpp):
Expand Down Expand Up @@ -149,7 +150,7 @@ def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int
if statistic.operation not in operations:
continue

tag = tag_from_location(statistic.location)
_, tag = id_and_tag_from_location(statistic.location)

tag_components = tag.split(".")
for i in range(1, len(tag_components) + 1):
Expand Down Expand Up @@ -194,7 +195,7 @@ def count_per_tag_per_parameter(
if statistic.operation not in operations:
continue

tag = tag_from_location(statistic.location)
_, tag = id_and_tag_from_location(statistic.location)

tag_components = tag.split(".")
for i in range(1, len(tag_components) + 1):
Expand All @@ -216,6 +217,49 @@ def count_per_tag_per_parameter(

return result

def complexity_per_tag(self) -> Dict[str, float]:
"""
Compute the complexity of each tag in the computation graph.

Returns:
Dict[str, float]:
complexity per tag
"""

result = {}
for statistic in self.statistics:
_, tag = id_and_tag_from_location(statistic.location)

tag_components = tag.split(".")
for i in range(1, len(tag_components) + 1):
current_tag = ".".join(tag_components[0:i])
if current_tag == "":
continue

if current_tag not in result:
result[current_tag] = 0.0

result[current_tag] += statistic.complexity

return result

def complexity_per_node(self) -> Dict[str, float]:
"""
Compute the complexity of each node in the computation graph.

Returns:
Dict[str, float]:
complexity per node
"""

result = {}
for statistic in self.statistics:
node_id, _ = id_and_tag_from_location(statistic.location)
if node_id not in result:
result[node_id] = 0.0
result[node_id] += statistic.complexity
return result


class ProgramCompilationFeedback(WrapperCpp):
"""CompilationFeedback is a set of hint computed by the compiler engine."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -585,10 +585,10 @@ struct ApplyLookupTableEintOpPattern
op.getLoc(), converter->convertType(op.getType()), adaptor.getA(),
newLut,
TFHE::GLWEKeyswitchKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
TFHE::GLWESecretKey(), -1, -1, -1),
TFHE::GLWESecretKey(), -1, -1, -1, -1),
TFHE::GLWEBootstrapKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
TFHE::GLWESecretKey(), -1, -1, -1, -1,
-1),
-1, -1),
TFHE::GLWEPackingKeyswitchKeyAttr::get(
op.getContext(), TFHE::GLWESecretKey(), TFHE::GLWESecretKey(), -1,
-1, -1, -1, -1, -1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ struct ApplyLookupTableEintOpPattern
op.getLoc(), getTypeConverter()->convertType(adaptor.getA().getType()),
input,
TFHE::GLWEKeyswitchKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
TFHE::GLWESecretKey(), -1, -1, -1));
TFHE::GLWESecretKey(), -1, -1, -1, -1));
if (operatorIndexes != nullptr) {
ksOp->setAttr("TFHE.OId",
rewriter.getI32IntegerAttr(
Expand All @@ -398,7 +398,7 @@ struct ApplyLookupTableEintOpPattern
op, getTypeConverter()->convertType(op.getType()), ksOp, newLut,
TFHE::GLWEBootstrapKeyAttr::get(op.getContext(), TFHE::GLWESecretKey(),
TFHE::GLWESecretKey(), -1, -1, -1, -1,
-1));
-1, -1));
if (operatorIndexes != nullptr) {
bsOp->setAttr("TFHE.OId",
rewriter.getI32IntegerAttr(
Expand Down Expand Up @@ -515,9 +515,9 @@ std::vector<mlir::Value> extractBitWithClearedLowerBits(
auto context = op.getContext();
auto secretKey = TFHE::GLWESecretKey();
auto ksk = TFHE::GLWEKeyswitchKeyAttr::get(context, secretKey, secretKey, -1,
-1, -1);
-1, -1, -1);
auto bsk = TFHE::GLWEBootstrapKeyAttr::get(context, secretKey, secretKey, -1,
-1, -1, -1, -1);
-1, -1, -1, -1, -1);

auto keyswitched = rewriter.create<TFHE::KeySwitchGLWEOp>(
loc, cInputTy, shiftedRotatedInput, ksk);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ struct KeySwitchGLWEOpPattern
auto newOutputKey = converter.getIntraPBSKey();
auto keyswitchKey = TFHE::GLWEKeyswitchKeyAttr::get(
ksOp->getContext(), newInputKey, newOutputKey, cryptoParameters.ksLevel,
cryptoParameters.ksLogBase, -1);
cryptoParameters.ksLogBase, -1, -1);
auto newOp = rewriter.replaceOpWithNewOp<TFHE::KeySwitchGLWEOp>(
ksOp, newOutputTy, ksOp.getCiphertext(), keyswitchKey);
rewriter.startRootUpdate(newOp);
Expand Down Expand Up @@ -156,7 +156,7 @@ struct BootstrapGLWEOpPattern
auto bootstrapKey = TFHE::GLWEBootstrapKeyAttr::get(
bsOp->getContext(), newInputKey, newOutputKey,
cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension,
cryptoParameters.brLevel, cryptoParameters.brLogBase, -1);
cryptoParameters.brLevel, cryptoParameters.brLogBase, -1, -1);
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
bsOp, newOutputTy, bsOp.getCiphertext(), bsOp.getLookupTable(),
bootstrapKey);
Expand Down Expand Up @@ -193,11 +193,11 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
auto intraKey = converter.getIntraPBSKey();
auto keyswitchKey = TFHE::GLWEKeyswitchKeyAttr::get(
wopPBSOp->getContext(), interKey, intraKey, cryptoParameters.ksLevel,
cryptoParameters.ksLogBase, -1);
cryptoParameters.ksLogBase, -1, -1);
auto bootstrapKey = TFHE::GLWEBootstrapKeyAttr::get(
wopPBSOp->getContext(), intraKey, interKey,
cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension,
cryptoParameters.brLevel, cryptoParameters.brLogBase, -1);
cryptoParameters.brLevel, cryptoParameters.brLogBase, -1, -1);
auto packingKeyswitchKey = TFHE::GLWEPackingKeyswitchKeyAttr::get(
wopPBSOp->getContext(), interKey, interKey,
cryptoParameters.largeInteger->wopPBS.packingKeySwitch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ class KeyConverter {
bsk.getContext(), convertSecretKey(bsk.getInputKey()),
convertSecretKey(bsk.getOutputKey()), bsk.getPolySize(),
bsk.getGlweDim(), bsk.getLevels(), bsk.getBaseLog(),
circuitKeys.getBootstrapKeyIndex(bsk).value());
bsk.getComplexity(), circuitKeys.getBootstrapKeyIndex(bsk).value());
}

TFHE::GLWEKeyswitchKeyAttr
convertKeyswitchKey(TFHE::GLWEKeyswitchKeyAttr ksk) {
return TFHE::GLWEKeyswitchKeyAttr::get(
ksk.getContext(), convertSecretKey(ksk.getInputKey()),
convertSecretKey(ksk.getOutputKey()), ksk.getLevels(), ksk.getBaseLog(),
circuitKeys.getKeyswitchKeyIndex(ksk).value());
ksk.getComplexity(), circuitKeys.getKeyswitchKeyIndex(ksk).value());
}

TFHE::GLWEPackingKeyswitchKeyAttr
Expand Down
Loading
Loading