Skip to content

Commit

Permalink
refactor(compiler): generalize noise calculation in FHE/FHELinalg
Browse files Browse the repository at this point in the history
Instead of having one `getSQManp` implementation per op with a lot of repetition, the noise
calculation is now modular.

- Ops that implements`UnaryEint`/`BinaryInt`/`BinaryEint` interfaces share the operand noise
presence check.
- For many scalar ops no further calculation is needed. If it's not the case, an op can override
`sqMANP`.
- Integer operand types lookups are abstracted into `BinaryInt::operandIntType()`
- Finding largest operand value for a type is abstracted into `BinaryInt::operandMaxConstant`
- Noise calculation for matmul ops is simplified and it's now general enough to work for
`matmul_eint_int`, `matmul_int_eint` and `dot_eint_int` at once.
  • Loading branch information
mkmks authored and youben11 committed Sep 12, 2023
1 parent d71201f commit 73a992f
Show file tree
Hide file tree
Showing 31 changed files with 850 additions and 1,227 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Interfaces)
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
set(LLVM_TARGET_DEFINITIONS FHEInterfaces.td)
mlir_tablegen(FHETypesInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(FHETypesInterfaces.cpp.inc -gen-type-interface-defs)

set(LLVM_TARGET_DEFINITIONS FHEOps.td)
mlir_tablegen(FHEOps.h.inc -gen-op-decls)
mlir_tablegen(FHEOps.cpp.inc -gen-op-defs)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"

include "concretelang/Dialect/FHE/IR/FHEDialect.td"
include "concretelang/Dialect/FHE/IR/FHETypes.td"
include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td"

class FHE_Op<string mnemonic, list<Trait> traits = []> :
Op<FHE_Dialect, mnemonic, traits>;

def FHE_ZeroEintOp : FHE_Op<"zero", [Pure]> {
def FHE_ZeroEintOp : FHE_Op<"zero", [Pure, ConstantNoise]> {
let summary = "Returns a trivial encrypted integer of 0";

let description = [{
Expand All @@ -33,7 +34,7 @@ def FHE_ZeroEintOp : FHE_Op<"zero", [Pure]> {
let results = (outs FHE_AnyEncryptedInteger:$out);
}

def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure]> {
def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure, ConstantNoise]> {
let summary = "Creates a new tensor with all elements initialized to an encrypted zero.";

let description = [{
Expand All @@ -51,7 +52,7 @@ def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure]> {
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$tensor);
}

def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure]> {
def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods<Binary>]> {
let summary = "Adds an encrypted integer and a clear integer";

let description = [{
Expand Down Expand Up @@ -84,7 +85,7 @@ def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure]> {
let hasFolder = 1;
}

def FHE_AddEintOp : FHE_Op<"add_eint", [Pure]> {
def FHE_AddEintOp : FHE_Op<"add_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods<BinaryEint>]> {
let summary = "Adds two encrypted integers";

let description = [{
Expand Down Expand Up @@ -116,7 +117,7 @@ def FHE_AddEintOp : FHE_Op<"add_eint", [Pure]> {
let hasVerifier = 1;
}

def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure]> {
def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint]> {
let summary = "Subtract an encrypted integer from a clear integer";

let description = [{
Expand Down Expand Up @@ -148,7 +149,7 @@ def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure]> {
let hasVerifier = 1;
}

def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure]> {
def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods<Binary>]> {
let summary = "Subtract a clear integer from an encrypted integer";

let description = [{
Expand Down Expand Up @@ -181,7 +182,7 @@ def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure]> {
let hasFolder = 1;
}

def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure]> {
def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods<BinaryEint>]> {
let summary = "Subtract an encrypted integer from an encrypted integer";

let description = [{
Expand Down Expand Up @@ -213,7 +214,7 @@ def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure]> {
let hasVerifier = 1;
}

def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure]> {
def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure, UnaryEint, DeclareOpInterfaceMethods<UnaryEint>]> {

let summary = "Negates an encrypted integer";

Expand Down Expand Up @@ -243,7 +244,7 @@ def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure]> {
let hasVerifier = 1;
}

def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure]> {
def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods<Binary, ["sqMANP"]>]> {
let summary = "Multiply an encrypted integer with a clear integer";

let description = [{
Expand Down Expand Up @@ -277,7 +278,7 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure]> {
let hasCanonicalizer = 1;
}

def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure]> {
def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods<BinaryEint>]> {
let summary = "Multiplies two encrypted integers";

let description = [{
Expand Down Expand Up @@ -313,7 +314,7 @@ def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure]> {
let hasVerifier = 1;
}

def FHE_MaxEintOp : FHE_Op<"max_eint", [Pure]> {
def FHE_MaxEintOp : FHE_Op<"max_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods<BinaryEint>]> {
let summary = "Retrieve the maximum of two encrypted integers.";

let description = [{
Expand Down Expand Up @@ -348,7 +349,7 @@ def FHE_MaxEintOp : FHE_Op<"max_eint", [Pure]> {
let hasVerifier = 1;
}

def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure]> {
def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure, UnaryEint, DeclareOpInterfaceMethods<UnaryEint>]> {
let summary = "Cast an unsigned integer to a signed one";

let description = [{
Expand All @@ -366,14 +367,14 @@ def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure]> {
```
}];

let arguments = (ins FHE_EncryptedIntegerType:$input);
let arguments = (ins FHE_EncryptedUnsignedIntegerType:$input);
let results = (outs FHE_EncryptedSignedIntegerType);

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> {
def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure, UnaryEint, DeclareOpInterfaceMethods<UnaryEint>]> {
let summary = "Cast a signed integer to an unsigned one";

let description = [{
Expand All @@ -392,13 +393,13 @@ def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> {
}];

let arguments = (ins FHE_EncryptedSignedIntegerType:$input);
let results = (outs FHE_EncryptedIntegerType);
let results = (outs FHE_EncryptedUnsignedIntegerType);

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure]> {
def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure, ConstantNoise]> {

let summary = "Applies a clear lookup table to an encrypted integer";

Expand All @@ -424,7 +425,7 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure]> {
let hasVerifier = 1;
}

def FHE_RoundEintOp: FHE_Op<"round", [Pure]> {
def FHE_RoundEintOp: FHE_Op<"round", [Pure, UnaryEint, DeclareOpInterfaceMethods<UnaryEint, ["sqMANP"]>]> {

let summary = "Rounds a ciphertext to a smaller precision.";

Expand Down Expand Up @@ -556,7 +557,7 @@ def FHE_BoolXorOp : FHE_Op<"xor", [Pure]> {
let results = (outs FHE_EncryptedBooleanType);
}

def FHE_BoolNotOp : FHE_Op<"not", [Pure]> {
def FHE_BoolNotOp : FHE_Op<"not", [Pure, UnaryEint, DeclareOpInterfaceMethods<UnaryEint>]> {

let summary = "Applies a NOT gate to an encrypted boolean value";

Expand All @@ -571,7 +572,7 @@ def FHE_BoolNotOp : FHE_Op<"not", [Pure]> {
let results = (outs FHE_EncryptedBooleanType);
}

def FHE_ToBoolOp : FHE_Op<"to_bool", [Pure]> {
def FHE_ToBoolOp : FHE_Op<"to_bool", [Pure, UnaryEint]> {
let summary = "Cast an unsigned integer to a boolean";

let description = [{
Expand All @@ -589,13 +590,13 @@ def FHE_ToBoolOp : FHE_Op<"to_bool", [Pure]> {
```
}];

let arguments = (ins FHE_EncryptedIntegerType:$input);
let arguments = (ins FHE_EncryptedUnsignedIntegerType:$input);
let results = (outs FHE_EncryptedBooleanType);

let hasVerifier = 1;
}

def FHE_FromBoolOp : FHE_Op<"from_bool", [Pure]> {
def FHE_FromBoolOp : FHE_Op<"from_bool", [Pure, UnaryEint]> {
let summary = "Cast a boolean to an unsigned integer";

let description = [{
Expand All @@ -608,7 +609,7 @@ def FHE_FromBoolOp : FHE_Op<"from_bool", [Pure]> {
}];

let arguments = (ins FHE_EncryptedBooleanType:$input);
let results = (outs FHE_EncryptedIntegerType);
let results = (outs FHE_EncryptedUnsignedIntegerType);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>

#include "concretelang/Dialect/FHE/IR/FHETypesInterfaces.h.inc"
#include <mlir/Dialect/Arith/IR/Arith.h>

#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h"

#define GET_TYPEDEF_CLASSES
#include "concretelang/Dialect/FHE/IR/FHEOpsTypes.h.inc"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@
#define CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES

include "concretelang/Dialect/FHE/IR/FHEDialect.td"
include "concretelang/Dialect/FHE/IR/FHEInterfaces.td"
include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td"
include "mlir/IR/BuiltinTypes.td"

class FHE_Type<string name, list<Trait> traits = []> :
TypeDef<FHE_Dialect, name, traits> { }

def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger",
def FHE_EncryptedUnsignedIntegerType : FHE_Type<"EncryptedUnsignedInteger",
[MemRefElementTypeInterface, FheIntegerInterface]> {
let mnemonic = "eint";

let summary = "An encrypted integer";
let summary = "An encrypted unsigned integer";

let description = [{
An encrypted integer with `width` bits to performs FHE Operations.
An encrypted unsigned integer with `width` bits to performs FHE Operations.

Examples:
```mlir
Expand Down Expand Up @@ -73,7 +73,7 @@ def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger",
}

def FHE_AnyEncryptedInteger : Type<Or<[
FHE_EncryptedIntegerType.predicate,
FHE_EncryptedUnsignedIntegerType.predicate,
FHE_EncryptedSignedIntegerType.predicate
]>>;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
set(LLVM_TARGET_DEFINITIONS FHEInterfaces.td)
mlir_tablegen(FHETypesInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(FHETypesInterfaces.cpp.inc -gen-type-interface-defs)
mlir_tablegen(FHEOpsInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(FHEOpsInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRFHEInterfacesIncGen)
add_dependencies(mlir-generic-headers MLIRFHEInterfacesIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_DIALECT_FHE_INTERFACES_FHEINTERFACES_H
#define CONCRETELANG_DIALECT_FHE_INTERFACES_FHEINTERFACES_H

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"

#include "concretelang/Dialect/FHE/Interfaces/FHEOpsInterfaces.h.inc"
#include "concretelang/Dialect/FHE/Interfaces/FHETypesInterfaces.h.inc"

#endif
Loading

0 comments on commit 73a992f

Please sign in to comment.