Skip to content

Commit

Permalink
fix(frontend): support higher bitwidth computation when using TFHE-rs
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Nov 14, 2024
1 parent 84fd038 commit 77406a6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
17 changes: 14 additions & 3 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,9 +985,20 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) ->
] * (2 ** (carry_width + msg_width) - 2 ** (msg_width - 1))
padding_bits_inc = ctx.tlu(result_type, msbs, padding_bit_table)
# set padding bits (where necessary) in the final result
return ctx.add(result_type, sum_result, padding_bits_inc)

return sum_result
result = ctx.add(result_type, sum_result, padding_bits_inc)
else:
result = sum_result

# even if TFHE-rs value are using non-variable bit-width, we want the output
# to be pluggable into the rest of the computation. For example, two 8bits TFHE-rs integers
# could be used in a 9bits addition. If we don't cast, it won't pass the bitwidth
# compatibility check.
output_bit_width = ctx.typeof(node).bit_width
casted_result_type = ctx.tensor(
ctx.esint(output_bit_width) if dtype.is_signed else ctx.eint(output_bit_width),
result_shape,
)
return ctx.cast(casted_result_type, result)

def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
Expand Down
10 changes: 10 additions & 0 deletions frontends/concrete-python/tests/execution/test_tfhers.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ def lut_add_lut(x, y):
TFHERS_UINT_8_3_2_4096,
id="x + y",
),
# make sure Concrete ciphertexts can use more than 8 bits
pytest.param(
lambda x, y: (x + y) % 213,
{
"x": {"range": [128, 255], "status": "encrypted"},
"y": {"range": [128, 255], "status": "encrypted"},
},
TFHERS_UINT_8_3_2_4096,
id="mod(x + y)",
),
pytest.param(
lambda x, y: x + y,
{
Expand Down

0 comments on commit 77406a6

Please sign in to comment.