Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft/module auto schedule #966

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/guides/configure.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,11 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
* When this option is set to `True`, encrypt and decrypt are identity functions, and run is a wrapper around simulation. In other words, this option allows to switch off the encryption to quickly test if a function has expected semantic (without paying the price of FHE execution).
* This is extremely unsafe and should only be used during development.
* For this reason, it requires **enable\_unsafe\_features** to be set to `True`.
* **auto\_schedule\_run**: bool = False
* Enable automatic scheduling of `run` method calls. When enabled, fhe function are computated in parallel in a background threads pool. When several `run` are composed, they are automatically synchronized.
* For now, it only works for the `run` method of a `FheModule`, in that case you obtain a `Future[Value]` immediately instead of a `Value` when computation is finished.
* E.g. `my_module.f3.run( my_module.f1.run(a), my_module.f1.run(b) )` will runs `f1` and `f2` in parallel in the background and `f3` in background when both `f1` and `f2` intermediate results are available.
* If you want to manually synchronize on the termination of a full computation, e.g. you want to return the encrypted result, you can call explicitely `value.result()` to wait for the result. To simplify testing, decryption does it automatically.
* Automatic scheduling behavior can be override locally by calling directly a variant of `run`:
* `run_sync`: forces the fhe function to occur in the current thread, not in the background,
* `run_async`: forces the fhe function to occur in a background thread, returning immediately a `Future[Value]`
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ class Configuration:
dynamic_assignment_check_out_of_bounds: bool
simulate_encrypt_run_decrypt: bool
composable: bool
auto_schedule_run: bool

def __init__(
self,
Expand Down Expand Up @@ -1063,6 +1064,7 @@ def __init__(
dynamic_indexing_check_out_of_bounds: bool = True,
dynamic_assignment_check_out_of_bounds: bool = True,
simulate_encrypt_run_decrypt: bool = False,
auto_schedule_run: bool = False,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1170,6 +1172,8 @@ def __init__(

self.simulate_encrypt_run_decrypt = simulate_encrypt_run_decrypt

self.auto_schedule_run = auto_schedule_run

self._validate()

class Keep:
Expand Down Expand Up @@ -1245,6 +1249,7 @@ def fork(
dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP,
dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP,
simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP,
auto_schedule_run: Union[Keep, bool] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down
127 changes: 123 additions & 4 deletions frontends/concrete-python/concrete/fhe/compilation/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

# pylint: disable=import-error,no-member,no-name-in-module

import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from threading import Thread
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union

import numpy as np
Expand All @@ -29,13 +32,40 @@
# pylint: enable=import-error,no-member,no-name-in-module


class ExecutionRt(NamedTuple):
class ExecutionRt:
"""
Runtime object class for execution.
"""

client: Client
server: Server
auto_schedule_run: bool
fhe_executor_pool: ThreadPoolExecutor
fhe_waiter_loop: asyncio.BaseEventLoop
fhe_waiter_thread: Thread # daemon thread

def __init__(self, client, server, auto_schedule_run):
self.client = client
self.server = server
self.auto_schedule_run = auto_schedule_run
if auto_schedule_run:
self.fhe_executor_pool = ThreadPoolExecutor()
self.fhe_waiter_loop = asyncio.new_event_loop()

def loop_thread():
asyncio.set_event_loop(self.fhe_waiter_loop)
self.fhe_waiter_loop.run_forever()

self.fhe_waiter_thread = Thread(target=loop_thread, args=(), daemon=True)
self.fhe_waiter_thread.start()
else:
self.fhe_executor_pool = None
self.fhe_waiter_loop = None
self.fhe_waiter_thread = None

def __del__(self):
if self.fhe_waiter_loop:
self.fhe_waiter_loop.stop() # daemon cleanup


class SimulationRt(NamedTuple):
Expand Down Expand Up @@ -186,10 +216,48 @@ def encrypt(
assert isinstance(self.runtime, ExecutionRt)
return self.runtime.client.encrypt(*args, function_name=self.name)

def run(
def run_sync(
umut-sahin marked this conversation as resolved.
Show resolved Hide resolved
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...]]:
"""
Evaluate the function synchronuously.

Args:
*args (Value):
argument(s) for evaluation

Returns:
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""

return self._run(True, *args)

def run_async(
self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]]
) -> "Union[Future[Value], Future[Tuple[Value, ...]]]":
"""
Evaluate the function asynchronuously.

Args:
*args (Value):
argument(s) for evaluation

Returns:
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""
if isinstance(self.runtime, ExecutionRt) and not self.runtime.fhe_executor_pool:
self.runtime = ExecutionRt(self.runtime.client, self.runtime.server, True)
self.runtime.auto_schedule_run = False

return self._run(False, *args)

def run(
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...], Future]:
"""
Evaluate the function.

Expand All @@ -201,15 +269,65 @@ def run(
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""
if isinstance(self.runtime, ExecutionRt):
auto_schedule_run = self.runtime.auto_schedule_run
else:
auto_schedule_run = False
return self._run(not auto_schedule_run, *args)

def _run(
self,
sync,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...], Future]:
"""
Evaluate the function.

Args:
*args (Value):
argument(s) for evaluation

Returns:
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""
if self.configuration.simulate_encrypt_run_decrypt:
return self.simulate(*args)

assert isinstance(self.runtime, ExecutionRt)
return self.runtime.server.run(

fhe_work = lambda *args: self.runtime.server.run(
*args, evaluation_keys=self.runtime.client.evaluation_keys, function_name=self.name
)

def args_ready(args):
return [arg.result() if isinstance(arg, Future) else arg for arg in args]

if sync:
return fhe_work(*args_ready(args))

all_args_done = all(not isinstance(arg, Future) or arg.done() for arg in args)

fhe_work_future = lambda *args: self.runtime.fhe_executor_pool.submit(fhe_work, *args)
if all_args_done:
return fhe_work_future(*args_ready(args))

# waiting args to be ready with async coroutines
# it only required one thread to run unlimited waits vs unlimited sync threads
async def wait_async(arg):
if not isinstance(arg, Future):
return arg
if arg.done():
return arg.result()
return await asyncio.wrap_future(arg, loop=self.runtime.fhe_waiter_loop)

async def args_ready_and_submit(*args):
args = [await wait_async(arg) for arg in args]
return await wait_async(fhe_work_future(*args))

run_async = args_ready_and_submit(*args)
return asyncio.run_coroutine_threadsafe(run_async, self.runtime.fhe_waiter_loop)

def decrypt(
self,
*results: Union[Value, Tuple[Value, ...]],
Expand All @@ -230,6 +348,7 @@ def decrypt(
return results if len(results) != 1 else results[0] # type: ignore

assert isinstance(self.runtime, ExecutionRt)
results = [res.result() if isinstance(res, Future) else res for res in results]
return self.runtime.client.decrypt(*results, function_name=self.name)

def encrypt_run_decrypt(self, *args: Any) -> Any:
Expand Down Expand Up @@ -585,7 +704,7 @@ def __init__(
keyset_cache_directory = self.configuration.insecure_key_cache_location

client = Client(server.client_specs, keyset_cache_directory)
self.runtime = ExecutionRt(client, server)
self.runtime = ExecutionRt(client, server, self.configuration.auto_schedule_run)

@property
def mlir(self) -> str:
Expand Down
81 changes: 80 additions & 1 deletion frontends/concrete-python/tests/compilation/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import inspect
import re
import tempfile
from concurrent.futures import Future
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -325,7 +326,6 @@ def dec(x):
fhe_simulation=True,
)

assert module.client is None
assert module.keys is None
assert module.inc.simulate(5) == 6
assert module.dec.simulate(5) == 4
Expand Down Expand Up @@ -718,3 +718,82 @@ def function(x):
output = client.decrypt(deserialized_result, function_name="inc")

assert output == 11


class IncDec:
@fhe.module()
class Module:
@fhe.function({"x": "encrypted"})
def inc(x):
return fhe.refresh(x + 1)

@fhe.function({"x": "encrypted"})
def dec(x):
return fhe.refresh(x - 1)

precision = 4

inputset = list(range(1, 2**precision - 1))
to_compile = {"inc": inputset, "dec": inputset}


def test_run_async():
"""
Test `run_async` with `auto_schedule_run=False` configuration option.
"""

module = IncDec.Module.compile(IncDec.to_compile)

sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)

a = module.inc.run_async(encrypted_x)
assert isinstance(a, Future)

b = module.dec.run(a)
assert isinstance(b, type(encrypted_x))

result = module.inc.decrypt(b)
assert result == sample_x


def test_run_sync():
"""
Test `run_sync` with `auto_schedule_run=True` configuration option.
"""

conf = fhe.Configuration(auto_schedule_run=True)
module = IncDec.Module.compile(IncDec.to_compile, conf)

sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)

a = module.inc.run(encrypted_x)
assert isinstance(a, Future)

b = module.dec.run_sync(a)
assert isinstance(b, type(encrypted_x))

result = module.inc.decrypt(b)
assert result == sample_x


def test_run_auto_schedule():
"""
Test `run` with `auto_schedule_run=True` configuration option.
"""

conf = fhe.Configuration(auto_schedule_run=True)
module = IncDec.Module.compile(IncDec.to_compile, conf)

sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)

a = module.inc.run(encrypted_x)
assert isinstance(a, Future)

b = module.dec.run(a)
assert isinstance(b, Future)

result = module.inc.decrypt(b)
assert result == sample_x
Loading