Skip to content

Commit

Permalink
feat(frontend-python): module run are scheduled and parallelized in a…
Browse files Browse the repository at this point in the history
… worker pool
  • Loading branch information
rudy-6-4 committed Aug 16, 2024
1 parent 7f31f6d commit 371a873
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 5 deletions.
8 changes: 8 additions & 0 deletions docs/guides/configure.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,11 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
* When this option is set to `True`, encrypt and decrypt are identity functions, and run is a wrapper around simulation. In other words, this option allows to switch off the encryption to quickly test if a function has expected semantic (without paying the price of FHE execution).
* This is extremely unsafe and should only be used during development.
* For this reason, it requires **enable\_unsafe\_features** to be set to `True`.
* **auto\_schedule\_run**: bool = False
* Enable automatic scheduling of `run` method calls. When enabled, fhe function are computated in parallel in a background threads pool. When several `run` are composed, they are automatically synchronized.
* For now, it only works for the `run` method of a `FheModule`, in that case you obtain a `Future[Value]` immediately instead of a `Value` when computation is finished.
* E.g. `my_module.f3.run( my_module.f1.run(a), my_module.f1.run(b) )` will runs `f1` and `f2` in parallel in the background and `f3` in background when both `f1` and `f2` intermediate results are available.
* If you want to manually synchronize on the termination of a full computation, e.g. you want to return the encrypted result, you can call explicitely `value.result()` to wait for the result. To simplify testing, decryption does it automatically.
* Automatic scheduling behavior can be override locally by calling directly a variant of `run`:
* `run_sync`: forces the fhe function to occur in the current thread, not in the background,
* `run_async`: forces the fhe function to occur in a background thread, returning immediately a `Future[Value]`
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ class Configuration:
dynamic_assignment_check_out_of_bounds: bool
simulate_encrypt_run_decrypt: bool
composable: bool
auto_schedule_run: bool

def __init__(
self,
Expand Down Expand Up @@ -1063,6 +1064,7 @@ def __init__(
dynamic_indexing_check_out_of_bounds: bool = True,
dynamic_assignment_check_out_of_bounds: bool = True,
simulate_encrypt_run_decrypt: bool = False,
auto_schedule_run: bool = False,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1170,6 +1172,8 @@ def __init__(

self.simulate_encrypt_run_decrypt = simulate_encrypt_run_decrypt

self.auto_schedule_run = auto_schedule_run

self._validate()

class Keep:
Expand Down Expand Up @@ -1245,6 +1249,7 @@ def fork(
dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP,
dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP,
simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP,
auto_schedule_run: Union[Keep, bool] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down
122 changes: 118 additions & 4 deletions frontends/concrete-python/concrete/fhe/compilation/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

# pylint: disable=import-error,no-member,no-name-in-module

import asyncio
from pathlib import Path
from threading import Thread
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future

import numpy as np
from concrete.compiler import (
Expand All @@ -29,14 +33,38 @@
# pylint: enable=import-error,no-member,no-name-in-module


class ExecutionRt(NamedTuple):
class ExecutionRt:
"""
Runtime object class for execution.
"""

client: Client
server: Server
auto_schedule_run: bool
fhe_executor_pool: ThreadPoolExecutor
fhe_waiter_loop: asyncio.BaseEventLoop
fhe_waiter_thread: Thread # daemon thread

def __init__(self, client, server, auto_schedule_run):
self.client = client
self.server = server
self.auto_schedule_run = auto_schedule_run
if auto_schedule_run:
self.fhe_executor_pool = ThreadPoolExecutor()
self.fhe_waiter_loop = asyncio.new_event_loop()
def loop_thread():
asyncio.set_event_loop(self.fhe_waiter_loop)
self.fhe_waiter_loop.run_forever()
self.fhe_waiter_thread = Thread(target=loop_thread, args=(), daemon=True)
self.fhe_waiter_thread.start()
else:
self.fhe_executor_pool = None
self.fhe_waiter_loop = None
self.fhe_waiter_thread = None

def __del__(self):
if self.fhe_waiter_loop:
self.fhe_waiter_loop.stop() # daemon cleanup

class SimulationRt(NamedTuple):
"""
Expand Down Expand Up @@ -186,10 +214,47 @@ def encrypt(
assert isinstance(self.runtime, ExecutionRt)
return self.runtime.client.encrypt(*args, function_name=self.name)

def run(
def run_sync(
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...]]:
"""
Evaluate the function synchronuously.
Args:
*args (Value):
argument(s) for evaluation
Returns:
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""

return self._run(True, *args)

def run_async(self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]]
) -> 'Union[Future[Value], Future[Tuple[Value, ...]]]':
"""
Evaluate the function asynchronuously.
Args:
*args (Value):
argument(s) for evaluation
Returns:
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""
if isinstance(self.runtime, ExecutionRt) and not self.runtime.fhe_executor_pool:
self.runtime = ExecutionRt(self.runtime.client, self.runtime.server, True)
self.runtime.auto_schedule_run = False

return self._run(False, *args)

def run(
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...], Future]:
"""
Evaluate the function.
Expand All @@ -201,15 +266,63 @@ def run(
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""
if isinstance(self.runtime, ExecutionRt):
auto_schedule_run = self.runtime.auto_schedule_run
else:
auto_schedule_run = False
return self._run(not auto_schedule_run, *args)

def _run(
self,
sync,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...], Future]:
"""
Evaluate the function.
Args:
*args (Value):
argument(s) for evaluation
Returns:
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""
if self.configuration.simulate_encrypt_run_decrypt:
return self.simulate(*args)

assert isinstance(self.runtime, ExecutionRt)
return self.runtime.server.run(

fhe_work = lambda *args:self.runtime.server.run(
*args, evaluation_keys=self.runtime.client.evaluation_keys, function_name=self.name
)

def args_ready(args):
return [arg.result() if isinstance(arg, Future) else arg for arg in args]

if sync:
return fhe_work(*args_ready(args))

all_args_done = all(not isinstance(arg, Future) or arg.done() for arg in args)

fhe_work_future = lambda *args:self.runtime.fhe_executor_pool.submit(fhe_work, *args)
if all_args_done:
return fhe_work_future(*args_ready(args))

# waiting args to be ready with async coroutines
# it only required one thread to run unlimited waits vs unlimited sync threads
async def wait_async(arg):
if not isinstance(arg, Future):
return arg
if arg.done():
return arg.result()
return await asyncio.wrap_future(arg, loop=self.runtime.fhe_waiter_loop)
async def args_ready_and_submit(*args):
args = [await wait_async(arg) for arg in args]
return await wait_async(fhe_work_future(*args))
run_async = args_ready_and_submit(*args)
return asyncio.run_coroutine_threadsafe(run_async, self.runtime.fhe_waiter_loop)

def decrypt(
self,
*results: Union[Value, Tuple[Value, ...]],
Expand All @@ -230,6 +343,7 @@ def decrypt(
return results if len(results) != 1 else results[0] # type: ignore

assert isinstance(self.runtime, ExecutionRt)
results = [res.result() if isinstance(res, Future) else res for res in results]
return self.runtime.client.decrypt(*results, function_name=self.name)

def encrypt_run_decrypt(self, *args: Any) -> Any:
Expand Down Expand Up @@ -585,7 +699,7 @@ def __init__(
keyset_cache_directory = self.configuration.insecure_key_cache_location

client = Client(server.client_specs, keyset_cache_directory)
self.runtime = ExecutionRt(client, server)
self.runtime = ExecutionRt(client, server, self.configuration.auto_schedule_run)

@property
def mlir(self) -> str:
Expand Down
77 changes: 76 additions & 1 deletion frontends/concrete-python/tests/compilation/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Tests of everything related to modules.
"""

from concurrent.futures import Future
import inspect
import re
import tempfile
Expand Down Expand Up @@ -325,7 +326,6 @@ def dec(x):
fhe_simulation=True,
)

assert module.client is None
assert module.keys is None
assert module.inc.simulate(5) == 6
assert module.dec.simulate(5) == 4
Expand Down Expand Up @@ -718,3 +718,78 @@ def function(x):
output = client.decrypt(deserialized_result, function_name="inc")

assert output == 11

class IncDec:
@fhe.module()
class Module:
@fhe.function({"x": "encrypted"})
def inc(x):
return fhe.refresh(x + 1)

@fhe.function({"x": "encrypted"})
def dec(x):
return fhe.refresh(x - 1)

precision = 4

inputset = list(range(1, 2**precision-1))
to_compile = {"inc": inputset, "dec": inputset}

def test_run_async():
"""
Test `run_async` with `auto_schedule_run=False` configuration option.
"""

module = IncDec.Module.compile(IncDec.to_compile)

sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)

a = module.inc.run_async(encrypted_x)
assert isinstance(a, Future)

b = module.dec.run(a)
assert isinstance(b, type(encrypted_x))

result = module.inc.decrypt(b)
assert result == sample_x

def test_run_sync():
"""
Test `run_sync` with `auto_schedule_run=True` configuration option.
"""

conf = fhe.Configuration(auto_schedule_run=True)
module = IncDec.Module.compile(IncDec.to_compile, conf)

sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)

a = module.inc.run(encrypted_x)
assert isinstance(a, Future)

b = module.dec.run_sync(a)
assert isinstance(b, type(encrypted_x))

result = module.inc.decrypt(b)
assert result == sample_x

def test_run_auto_schedule():
"""
Test `run` with `auto_schedule_run=True` configuration option.
"""

conf = fhe.Configuration(auto_schedule_run=True)
module = IncDec.Module.compile(IncDec.to_compile, conf)

sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)

a = module.inc.run(encrypted_x)
assert isinstance(a, Future)

b = module.dec.run(a)
assert isinstance(b, Future)

result = module.inc.decrypt(b)
assert result == sample_x

0 comments on commit 371a873

Please sign in to comment.