Skip to content

Commit

Permalink
refactor(frontends): unify circuits and modules
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 committed Sep 24, 2024
1 parent 52636e4 commit d9b34f1
Show file tree
Hide file tree
Showing 33 changed files with 923 additions and 1,542 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,15 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
clientParameters, inputId, circuitName);
return encryption.getVariance();
})
.def("function_list",
[](::concretelang::clientlib::ClientParameters &clientParameters) {
std::vector<std::string> result;
for (auto circuit :
clientParameters.programInfo.asReader().getCircuits()) {
result.push_back(circuit.getName());
}
return result;
})
.def("output_signs",
[](::concretelang::clientlib::ClientParameters &clientParameters) {
std::vector<bool> result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,12 @@ def __init__(self, client_parameters: _ClientParameters):
)
super().__init__(client_parameters)

def input_keyid_at(self, input_idx: int, circuit_name: str = "main") -> int:
def input_keyid_at(self, input_idx: int, circuit_name: str) -> int:
"""Get the keyid of a selected encrypted input in a given circuit.
Args:
input_idx (int): index of the input in the circuit.
circuit_name (str, optional): name of the circuit containing the desired input.
Defaults to "main".
circuit_name (str): name of the circuit containing the desired input.
Raises:
TypeError: if arguments aren't of expected types
Expand All @@ -59,13 +58,12 @@ def input_keyid_at(self, input_idx: int, circuit_name: str = "main") -> int:
)
return self.cpp().input_keyid_at(input_idx, circuit_name)

def input_variance_at(self, input_idx: int, circuit_name: str = "main") -> float:
def input_variance_at(self, input_idx: int, circuit_name: str) -> float:
"""Get the variance of a selected encrypted input in a given circuit.
Args:
input_idx (int): index of the input in the circuit.
circuit_name (str, optional): name of the circuit containing the desired input.
Defaults to "main".
circuit_name (str): name of the circuit containing the desired input.
Raises:
TypeError: if arguments aren't of expected types
Expand Down Expand Up @@ -97,6 +95,14 @@ def output_signs(self) -> List[bool]:
"""
return self.cpp().output_signs()

def function_list(self) -> List[str]:
"""Return the list of function names.
Returns:
List[str]: list of the names of the functions.
"""
return self.cpp().function_list()

def serialize(self) -> bytes:
"""Serialize the ClientParameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def encrypt_arguments(
client_parameters: ClientParameters,
keyset: KeySet,
args: List[Union[int, np.ndarray]],
circuit_name: str = "main",
circuit_name: str,
) -> PublicArguments:
"""Prepare arguments for encrypted computation.
Expand Down Expand Up @@ -172,7 +172,7 @@ def decrypt_result(
client_parameters: ClientParameters,
keyset: KeySet,
public_result: PublicResult,
circuit_name: str = "main",
circuit_name: str,
) -> Union[int, np.ndarray]:
"""Decrypt a public result using the keyset.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def load_server_lambda(
self,
library_compilation_result: LibraryCompilationResult,
simulation: bool,
circuit_name: str = "main",
circuit_name: str,
) -> LibraryLambda:
"""Load the server lambda for a given circuit from the library compilation result.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, value_decrypter: _SimulatedValueDecrypter):

@staticmethod
# pylint: disable=arguments-differ
def new(client_parameters: ClientParameters, circuit_name: str = "main"):
def new(client_parameters: ClientParameters, circuit_name: str):
"""
Create a value decrypter.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, value_exporter: _SimulatedValueExporter):
@staticmethod
# pylint: disable=arguments-differ
def new(
client_parameters: ClientParameters, circuitName: str = "main"
client_parameters: ClientParameters, circuitName: str
) -> "SimulatedValueExporter":
"""
Create a value exporter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, value_exporter: _ValueExporter):
@staticmethod
# pylint: disable=arguments-differ
def new(
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main"
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str
) -> "ValueExporter":
"""
Create a value exporter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def assert_result(result, expected_result):
assert np.all(result == expected_result)


def run(engine, args, compilation_result, keyset_cache, circuit_name="main"):
def run(engine, args, compilation_result, keyset_cache, circuit_name):
"""Execute engine on the given arguments.
Perform required loading, encryption, execution, and decryption."""
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_ca
# Here don't save compilation result, reload
engine.compile(mlir_input)
compilation_result = engine.reload()
result = run(engine, args, compilation_result, keyset_cache)
result = run(engine, args, compilation_result, keyset_cache, "main")
# Check result
assert_result(result, expected_result)
shutil.rmtree(artifact_dir)
Expand Down Expand Up @@ -398,7 +398,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache):

def test_crt_decomposition_feedback():
mlir = """
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
%tlu = arith.constant dense<60000> : tensor<65536xi64>
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<16>, tensor<65536xi64>) -> (!FHE.eint<16>)
Expand Down
19 changes: 5 additions & 14 deletions frontends/concrete-python/concrete/fhe/compilation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .artifacts import DebugArtifacts, FunctionDebugArtifacts, ModuleDebugArtifacts
from .circuit import Circuit
from .client import Client
from .compiler import Compiler, EncryptionStatus
from .compiler import Compiler
from .composition import CompositionClause, CompositionPolicy, CompositionRule
from .configuration import (
DEFAULT_GLOBAL_P_ERROR,
Expand All @@ -22,19 +22,10 @@
)
from .keys import Keys
from .module import FheFunction, FheModule
from .module_compiler import (
AllComposable,
AllInputs,
AllOutputs,
FunctionDef,
Input,
ModuleCompiler,
NotComposable,
Output,
Wire,
Wired,
)
from .module_compiler import FunctionDef, ModuleCompiler
from .server import Server
from .specs import ClientSpecs
from .utils import inputset
from .status import EncryptionStatus
from .utils import get_terminal_size, inputset
from .value import Value
from .wiring import AllComposable, AllInputs, AllOutputs, Input, NotComposable, Output, Wire, Wired
Loading

0 comments on commit d9b34f1

Please sign in to comment.