Tfhers ml example #1151

Merged
merged 4 commits into from
Dec 18, 2024
108 changes: 108 additions & 0 deletions frontends/concrete-python/examples/tfhers-ml/README.md
@@ -0,0 +1,108 @@
# TFHE-rs interoperability example

This example is similar to the [first TFHE-rs example](../tfhers/), except that it uses tensors and runs a linear ML model. It also uses quantization.

## Make tmpdir

We first set up a temporary working directory:

```sh
export TDIR=`mktemp -d`
```

## KeyGen

First we need to build the TFHE-rs utility in [this directory](../../tests/tfhers-utils/) by running the following:

```sh
cd ../../tests/tfhers-utils/
make build
cd -
```

Then we can generate keys in two different ways. You only need to run one of the following methods.

#### Generate the Secret Key in Concrete

We start by doing keygen in Concrete:

```sh
python example.py keygen -o $TDIR/concrete_sk -k $TDIR/concrete_keyset
```

Then we do a partial keygen in TFHE-rs:

```sh
../../tests/tfhers-utils/target/release/tfhers_utils keygen --lwe-sk $TDIR/concrete_sk --output-lwe-sk $TDIR/tfhers_sk -c $TDIR/tfhers_client_key -s $TDIR/tfhers_server_key
```

#### Generate the Secret Key in TFHE-rs

We start by doing keygen in TFHE-rs:

```sh
../../tests/tfhers-utils/target/release/tfhers_utils keygen --output-lwe-sk $TDIR/tfhers_sk -c $TDIR/tfhers_client_key -s $TDIR/tfhers_server_key
```

Then we do a partial keygen in Concrete:

```sh
python example.py keygen -s $TDIR/tfhers_sk -o $TDIR/concrete_sk -k $TDIR/concrete_keyset
```

## Quantize values

We need to quantize the floating-point inputs using a pre-built quantizer for our ML model:

```sh
../../tests/tfhers-utils/target/release/tfhers_utils quantize --value=5.1,3.5,1.4,0.2,4.9,3,1.4,0.2,4.7,3.2,1.3,0.2,4.6,3.1,1.5,0.2,5,3.6,1.4,0.2 --config ./input_quantizer.json -o $TDIR/quantized_values
```
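For intuition, the quantization step can be sketched in plain Python. This is a minimal sketch under the assumption that the quantizer applies the usual uniform affine mapping `q = round(x / scale + zero_point)`, clipped to the signed 8-bit range. The `scale` and `zero_point` values are copied from the input quantizer configuration; the function name `quantize` is hypothetical and is not part of `tfhers_utils`.

```python
# Values copied from the input quantizer configuration
SCALE = 0.03058823529411765
ZERO_POINT = -131
N_BITS = 8  # signed 8-bit target range


def quantize(values):
    """Map floats to signed 8-bit integers (assumed uniform affine quantization)."""
    q_min, q_max = -(1 << (N_BITS - 1)), (1 << (N_BITS - 1)) - 1
    quantized = []
    for x in values:
        q = round(x / SCALE + ZERO_POINT)
        # Clip to the representable range
        quantized.append(max(q_min, min(q_max, q)))
    return quantized


print(quantize([5.1, 3.5, 1.4, 0.2]))
```

Under that assumption, the first sample maps to `[36, -17, -85, -124]`, which matches the first row of the input set used in `example.py`.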

## Encrypt in TFHE-rs

```sh
../../tests/tfhers-utils/target/release/tfhers_utils encrypt-with-key --signed --value=$(cat $TDIR/quantized_values) --ciphertext $TDIR/tfhers_ct --client-key $TDIR/tfhers_client_key
```

## Run in Concrete

```sh
python example.py run -k $TDIR/concrete_keyset -c $TDIR/tfhers_ct -o $TDIR/tfhers_ct_out
```
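To see what the circuit evaluates, the integer arithmetic of `ml_inference` in `example.py` can be reproduced in the clear for a single sample. This is only a plain-Python sketch of the arithmetic before the rounding step: `linear_accumulators` is a hypothetical helper name, and nothing here is encrypted.

```python
# Model parameters copied from example.py
Q_WEIGHTS = [[-25, 21, -10], [42, -20, -37], [-128, -15, 127], [-58, -51, 94]]
Q_BIAS = [35167, 9417, -44584]
WEIGHT_ZERO_POINT = -5


def linear_accumulators(q_x):
    """Return the integer accumulators for one quantized sample, before rounding."""
    row_sum = sum(q_x)
    outputs = []
    for col in range(len(Q_BIAS)):
        dot = sum(x * w_row[col] for x, w_row in zip(q_x, Q_WEIGHTS))
        # Subtracting zero_point * sum(x) is the same as multiplying by (W - zero_point)
        outputs.append(dot - WEIGHT_ZERO_POINT * row_sum + Q_BIAS[col])
    return outputs


print(linear_accumulators([36, -17, -85, -124]))
```

For the first row of the input set this gives `[50675, 17162, -67716]`. These accumulators need about 18 signed bits, which is consistent with keeping 8 MSBs and removing 10 LSBs.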

## Decrypt in TFHE-rs

```sh
../../tests/tfhers-utils/target/release/tfhers_utils decrypt-with-key --tensor --signed --ciphertext $TDIR/tfhers_ct_out --client-key $TDIR/tfhers_client_key --plaintext $TDIR/result_plaintext
```

## Rescale Output

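At the end of the circuit, the result is rounded to its 8 most significant bits and the remaining least significant bits are discarded. Since `lsbs_to_remove=10` in this example, we shift the decrypted values left by 10 bits to restore the original scale (the discarded bits come back as zeros).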

```sh
python -c "print(','.join(map(lambda x: str(x << 10), [$(cat $TDIR/result_plaintext)])))" > $TDIR/rescaled_plaintext
```
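The effect of rounding and rescaling on a single integer can be sketched as follows. This assumes that the rounding step rounds to the nearest multiple of `2**lsbs_to_remove` by adding half of that step before shifting; the helper names are hypothetical, and the sample integers are simply the bias values from `example.py`.

```python
LSBS_TO_REMOVE = 10  # value stated above for this example


def round_and_shift(value, lsbs=LSBS_TO_REMOVE):
    """Round to the nearest multiple of 2**lsbs, then drop the low bits."""
    return (value + (1 << (lsbs - 1))) >> lsbs


def rescale(value, lsbs=LSBS_TO_REMOVE):
    """Shift the low bits back in as zeros."""
    return value << lsbs


for accumulator in (35167, -44584):
    kept = round_and_shift(accumulator)
    print(accumulator, kept, rescale(kept))
```

Here `35167` becomes `34` and is rescaled to `34816`, while `-44584` becomes `-44` and is rescaled to `-45056`. Because `compute_error.py` rounds the expected values down rather than to the nearest multiple, a difference of one unit of `2^10` between output and expected value is plausible and is tolerated by that script.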


## Dequantize values

We need to dequantize the integer outputs using the pre-built output quantizer for our ML model:

```sh
../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --shape=5,3 --config ./output_quantizer.json
```
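Dequantization can be sketched the same way. This assumes the standard affine inverse `x = scale * (q - zero_point)` with one zero point per output column; `scale` and the zero points are copied from the output quantizer configuration, the helper name `dequantize` is hypothetical, and the input integers are illustrative.

```python
# Values copied from the output quantizer configuration
SCALE = 0.0006288117860507253
ZERO_POINTS = [39038, 11790, -50828]  # one zero point per output column


def dequantize(flat_values, n_cols=len(ZERO_POINTS)):
    """Turn a flat list of rescaled integers into rows of floats."""
    rows = [flat_values[i : i + n_cols] for i in range(0, len(flat_values), n_cols)]
    return [[SCALE * (q - zp) for q, zp in zip(row, ZERO_POINTS)] for row in rows]


# Illustrative rescaled integers for a single sample
print(dequantize([50176, 17408, -67584]))
```

With `--shape=5,3`, the 15 comma-separated integers are read as 5 rows of 3 class scores each, which is what the row splitting above mimics.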

## Compute error

We compare the output to the expected result:

```sh
python compute_error.py --plaintext-file "$TDIR/rescaled_plaintext" --quantized-predictions-file "test_values.json"
```

## Clean tmpdir

```sh
rm -rf $TDIR
```
63 changes: 63 additions & 0 deletions frontends/concrete-python/examples/tfhers-ml/compute_error.py
@@ -0,0 +1,63 @@
import json

import click


@click.command()
@click.option(
    "--plaintext-file", "-p", required=True, help="Path to the rescaled plaintext values file."
)
@click.option(
    "--quantized-predictions-file",
    "-q",
    required=True,
    help="Path to the test_values.json file containing quantized predictions.",
)
def compute_error(plaintext_file, quantized_predictions_file):
    """Compute the error between decrypted rescaled values and quantized predictions."""
    # Read rescaled plaintext values from plaintext_file
    with open(plaintext_file) as f:
        rescaled_plaintext_values = [int(x) for x in f.read().strip().split(",")]

    # Read quantized_predictions from quantized_predictions_file
    with open(quantized_predictions_file) as f:
        data = json.load(f)
        quantized_predictions = data["quantized_predictions"]

    # Flatten quantized_predictions
    quantized_predictions_flat = [int(x) for sublist in quantized_predictions for x in sublist]

    # Round down 10 bits using (x // (1 << 10)) * (1 << 10)
    rounded_quantized_predictions = [
        (x // (1 << 10)) * (1 << 10) for x in quantized_predictions_flat
    ]

    # Compare rescaled_plaintext_values with rounded_quantized_predictions
    num_differences = 0
    total_values = len(rescaled_plaintext_values)
    errors = []
    for i in range(total_values):
        a = rescaled_plaintext_values[i]
        b = rounded_quantized_predictions[i]
        print(f"output: {a}, expected: {b}")
        if a != b:
            num_differences += 1
            error_in_units = round((a - b) / (1 << 10))
            errors.append((i, error_in_units))

    print("Number of differing values: {}".format(num_differences))
    print("Total values compared: {}".format(total_values))
    if num_differences > 0:
        print("Differences (index, error in units of 2^10):")
        for idx, error_in_units in errors:
            print("Index {}: error = {}".format(idx, error_in_units))

    # Success means no value is off by more than one unit of 2^10, in either direction.
    # Exit explicitly: click ignores a command's return value in standalone mode.
    if any(abs(error_in_units) > 1 for _, error_in_units in errors):
        raise SystemExit(1)


if __name__ == "__main__":
    compute_error()
179 changes: 179 additions & 0 deletions frontends/concrete-python/examples/tfhers-ml/example.py
@@ -0,0 +1,179 @@
import os
import typing
from functools import partial

import click
import numpy as np

import concrete.fhe as fhe
from concrete.fhe import tfhers

### Options ###########################
# These parameters were saved by running the tfhers_utils utility:
# tfhers_utils save-params tfhers_params.json
TFHERS_PARAMS_FILE = "tfhers_params.json"
FHEUINT_PRECISION = 8
IS_SIGNED = True
#######################################

tfhers_type = tfhers.get_type_from_params(
    TFHERS_PARAMS_FILE,
    is_signed=IS_SIGNED,
    precision=FHEUINT_PRECISION,
)
tfhers_int = partial(tfhers.TFHERSInteger, tfhers_type)

#### Model Parameters ##################
q_weights = np.array([[-25, 21, -10], [42, -20, -37], [-128, -15, 127], [-58, -51, 94]])
q_bias = np.array([[35167, 9417, -44584]])
weight_quantizer_zero_point = -5
########################################

rounder = fhe.AutoRounder(target_msbs=8) # We want to keep 8 MSBs


@typing.no_type_check
def ml_inference(q_X: np.ndarray) -> np.ndarray:
    y_pred = q_X @ q_weights - weight_quantizer_zero_point * np.sum(q_X, axis=1, keepdims=True)
    y_pred += q_bias
    y_pred = fhe.round_bit_pattern(y_pred, rounder)
    y_pred = y_pred >> rounder.lsbs_to_remove
    return y_pred


def compute(tfhers_x):
    ####### TFHE-rs to Concrete #########

    # tfhers_x is expected to be a TFHE-rs value.
    # to_native uses its type information to do
    # a correct conversion from TFHE-rs to Concrete
    concrete_x = tfhers.to_native(tfhers_x)
    ####### TFHE-rs to Concrete #########

    ####### Concrete Computation ########
    concrete_res = ml_inference(concrete_x)
    ####### Concrete Computation ########

    ####### Concrete to TFHE-rs #########
    tfhers_res = tfhers.from_native(
        concrete_res, tfhers_type
    )  # we have to specify the type we want to convert to
    ####### Concrete to TFHE-rs #########
    return tfhers_res


def ccompilee():
    compiler = fhe.Compiler(
        compute,
        {
            "tfhers_x": "encrypted",
        },
    )

    inputset = [
        (
            tfhers_int(
                np.array(
                    [
                        [36, -17, -85, -124],
                        [29, -33, -85, -124],
                        [23, -26, -88, -124],
                        [19, -30, -82, -124],
                        [32, -13, -85, -124],
                    ]
                )
            ),
        )
    ]

    # Add the auto-adjustment before compilation
    fhe.AutoRounder.adjust(compute, inputset)

    # Print the number of bits rounded
    print(f"lsbs_to_remove: {rounder.lsbs_to_remove}")

    circuit = compiler.compile(inputset)

    tfhers_bridge = tfhers.new_bridge(circuit=circuit)
    return circuit, tfhers_bridge


@click.group()
def cli():
    pass


@cli.command()
@click.option("-s", "--secret-key", type=str, required=False)
@click.option("-o", "--output-secret-key", type=str, required=True)
@click.option("-k", "--concrete-keyset-path", type=str, required=True)
def keygen(output_secret_key: str, secret_key: str, concrete_keyset_path: str):
    """Concrete Key Generation"""

    circuit, tfhers_bridge = ccompilee()

    if os.path.exists(concrete_keyset_path):
        print(f"removing old keyset at '{concrete_keyset_path}'")
        os.remove(concrete_keyset_path)

    if secret_key:
        print(f"partial keygen from sk at '{secret_key}'")
        # load the initial secret key to use for keygen
        with open(
            secret_key,
            "rb",
        ) as f:
            buff = f.read()
        input_idx_to_key = {0: buff}
        tfhers_bridge.keygen_with_initial_keys(input_idx_to_key_buffer=input_idx_to_key)
    else:
        print("full keygen")
        circuit.keygen()

    print("saving Concrete keyset")
    circuit.client.keys.save(concrete_keyset_path)
    print(f"saved Concrete keyset to '{concrete_keyset_path}'")

    sk: bytes = tfhers_bridge.serialize_input_secret_key(input_idx=0)
    print(f"writing secret key of size {len(sk)} to '{output_secret_key}'")
    with open(output_secret_key, "wb") as f:
        f.write(sk)


@cli.command()
@click.option("-c", "--rust-ct", type=str, required=True)
@click.option("-o", "--output-rust-ct", type=str, required=False)
@click.option("-k", "--concrete-keyset-path", type=str, required=True)
def run(rust_ct: str, output_rust_ct: str, concrete_keyset_path: str):
    """Run circuit"""
    circuit, tfhers_bridge = ccompilee()

    if not os.path.exists(concrete_keyset_path):
        raise RuntimeError("cannot find keys, you should run keygen before")
    print(f"loading keys from '{concrete_keyset_path}'")
    circuit.client.keys.load(concrete_keyset_path)

    # read tfhers int from file
    with open(rust_ct, "rb") as f:
        buff = f.read()
    # import fheuint8 and get its description
    tfhers_uint8_x = tfhers_bridge.import_value(buff, input_idx=0)

    print("Homomorphic evaluation...")
    encrypted_result = circuit.run(tfhers_uint8_x)

    if output_rust_ct:
        print("exporting Rust ciphertexts")
        # export fheuint8
        buff = tfhers_bridge.export_value(encrypted_result, output_idx=0)
        # write it to file
        with open(output_rust_ct, "wb") as f:
            f.write(buff)
    else:
        result = circuit.decrypt(encrypted_result)
        decoded = tfhers_type.decode(result)
        print(f"Concrete decryption result: raw({result}), decoded({decoded})")


if __name__ == "__main__":
    cli()
@@ -0,0 +1 @@
{"type_name": "UniformQuantizer", "serialized_value": {"n_bits": 8, "is_signed": true, "is_symmetric": false, "is_qat": false, "is_narrow": false, "is_precomputed_qat": false, "rmax": {"type_name": "numpy_float", "serialized_value": 7.9, "dtype": "float64"}, "rmin": {"type_name": "numpy_float", "serialized_value": 0.1, "dtype": "float64"}, "scale": {"type_name": "numpy_float", "serialized_value": 0.03058823529411765, "dtype": "float64"}, "zero_point": {"type_name": "numpy_integer", "serialized_value": -131, "dtype": "int64"}, "offset": 128, "no_clipping": false}}
@@ -0,0 +1 @@
{"type_name": "UniformQuantizer", "serialized_value": {"is_signed": false, "is_symmetric": false, "is_qat": false, "is_narrow": false, "is_precomputed_qat": false, "rmax": null, "rmin": null, "scale": {"type_name": "numpy_float", "serialized_value": 0.0006288117860507253, "dtype": "float64"}, "zero_point": {"type_name": "numpy_array", "serialized_value": [[39038, 11790, -50828]], "dtype": "int64"}, "offset": 0, "no_clipping": true}}
8 changes: 8 additions & 0 deletions frontends/concrete-python/examples/tfhers-ml/test.sh
@@ -0,0 +1,8 @@
#!/bin/sh

# This file tests that the example is working

shell_blocks=$(sed -n '/^```sh/,/^```/ p' < README.md | sed '/^```sh/d' | sed '/^```/d')

set -e
output=$(eval "$shell_blocks" 2>&1) || { echo "$output"; exit 1; }