Skip to content

Commit

Permalink
feat: make defer-decryption default
Browse files Browse the repository at this point in the history
* feat: add `defer` option to inner configs

- `TranscriptConfig`
- `MpcTlsLeaderConfig`
- `MpcTlsFollowerConfig`

* adapt common `ProtocolConfig` to `deferred` feature

* Adapt prover and verifier configs.

* Adapt benches.

* Adapt examples.

* Adapt `crates/tests-integration`.

* Adapt notary integration test and wasm crates.

* Fix test.

* add clarifying comments

* Add feedback.

* Improve default handling for `max_deferred_size`.

* Use default handling instead of validation.

* Add feedback.

* fix: bugfix for `notarize.rs`

* Remove defaults for `ProtocolConfigValidator`

* Set `ProtocolConfigValidator` where needed.

* Only preprocess online part for transcript.

---------

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
th4s and Ubuntu authored Sep 11, 2024
1 parent 32df138 commit 9bbb2fb
Show file tree
Hide file tree
Showing 21 changed files with 326 additions and 123 deletions.
25 changes: 15 additions & 10 deletions crates/benches/bin/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,28 @@ async fn run_instance<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>

let start_time = Instant::now();

let protocol_config = ProtocolConfig::builder()
.max_sent_data(upload_size + 256)
.max_recv_data(download_size + 256)
.build()
.unwrap();
let protocol_config = if defer_decryption {
ProtocolConfig::builder()
.max_sent_data(upload_size + 256)
.max_recv_data(download_size + 256)
.build()
.unwrap()
} else {
ProtocolConfig::builder()
.max_sent_data(upload_size + 256)
.max_recv_data(download_size + 256)
.max_recv_data_online(download_size + 256)
.build()
.unwrap()
};

let prover = Prover::new(
ProverConfig::builder()
.id("test")
.server_dns(SERVER_DOMAIN)
.root_cert_store(root_store())
.protocol_config(protocol_config)
.defer_decryption_from_start(defer_decryption)
.build()
.context("invalid prover config")?,
)
Expand All @@ -137,7 +147,6 @@ async fn run_instance<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>

let (mut mpc_tls_connection, prover_fut) = prover.connect(client_conn.compat()).await.unwrap();

let prover_ctrl = prover_fut.control();
let prover_task = tokio::spawn(prover_fut);

let request = format!(
Expand All @@ -146,10 +155,6 @@ async fn run_instance<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
String::from_utf8(vec![0x42u8; upload_size]).unwrap(),
);

if defer_decryption {
prover_ctrl.defer_decryption().await?;
}

mpc_tls_connection.write_all(request.as_bytes()).await?;
mpc_tls_connection.close().await?;

Expand Down
96 changes: 69 additions & 27 deletions crates/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@ use std::error::Error;

use crate::Role;

/// Default for the maximum number of bytes that can be sent (4KB).
pub const DEFAULT_MAX_SENT_LIMIT: usize = 1 << 12;
/// Default for the maximum number of bytes that can be received (16KB).
pub const DEFAULT_MAX_RECV_LIMIT: usize = 1 << 14;

// Extra cushion room, eg. for sharing J0 blocks.
const EXTRA_OTS: usize = 16384;

const OTS_PER_BYTE_SENT: usize = 8;

// Without deferred decryption we use 16, with it we use 8.
const OTS_PER_BYTE_RECV: usize = 16;
const OTS_PER_BYTE_RECV_ONLINE: usize = 16;
const OTS_PER_BYTE_RECV_DEFER: usize = 8;

// Current version that is running.
static VERSION: Lazy<Version> = Lazy::new(|| {
Expand All @@ -27,18 +25,32 @@ static VERSION: Lazy<Version> = Lazy::new(|| {

/// Protocol configuration to be set up initially by prover and verifier.
#[derive(derive_builder::Builder, Clone, Debug, Deserialize, Serialize)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct ProtocolConfig {
/// Maximum number of bytes that can be sent.
#[builder(default = "DEFAULT_MAX_SENT_LIMIT")]
max_sent_data: usize,
/// Maximum number of bytes that can be decrypted online, i.e. while the MPC-TLS connection is
/// active.
#[builder(default = "0")]
max_recv_data_online: usize,
/// Maximum number of bytes that can be received.
#[builder(default = "DEFAULT_MAX_RECV_LIMIT")]
max_recv_data: usize,
/// Version that is being run by prover/verifier.
#[builder(setter(skip), default = "VERSION.clone()")]
version: Version,
}

impl ProtocolConfigBuilder {
fn validate(&self) -> Result<(), String> {
if self.max_recv_data_online > self.max_recv_data {
return Err(
"max_recv_data_online must be smaller or equal to max_recv_data".to_string(),
);
}
Ok(())
}
}

impl Default for ProtocolConfig {
fn default() -> Self {
Self::builder().build().unwrap()
Expand All @@ -56,19 +68,34 @@ impl ProtocolConfig {
self.max_sent_data
}

/// Returns the maximum number of bytes that can be decrypted online.
pub fn max_recv_data_online(&self) -> usize {
self.max_recv_data_online
}

/// Returns the maximum number of bytes that can be received.
pub fn max_recv_data(&self) -> usize {
self.max_recv_data
}

/// Returns OT sender setup count.
pub fn ot_sender_setup_count(&self, role: Role) -> usize {
ot_send_estimate(role, self.max_sent_data, self.max_recv_data)
ot_send_estimate(
role,
self.max_sent_data,
self.max_recv_data_online,
self.max_recv_data,
)
}

/// Returns OT receiver setup count.
pub fn ot_receiver_setup_count(&self, role: Role) -> usize {
ot_recv_estimate(role, self.max_sent_data, self.max_recv_data)
ot_recv_estimate(
role,
self.max_sent_data,
self.max_recv_data_online,
self.max_recv_data,
)
}
}

Expand All @@ -77,22 +104,14 @@ impl ProtocolConfig {
#[derive(derive_builder::Builder, Clone, Debug)]
pub struct ProtocolConfigValidator {
/// Maximum number of bytes that can be sent.
#[builder(default = "DEFAULT_MAX_SENT_LIMIT")]
max_sent_data: usize,
/// Maximum number of bytes that can be received.
#[builder(default = "DEFAULT_MAX_RECV_LIMIT")]
max_recv_data: usize,
/// Version that is being run by checker.
#[builder(setter(skip), default = "VERSION.clone()")]
version: Version,
}

impl Default for ProtocolConfigValidator {
fn default() -> Self {
Self::builder().build().unwrap()
}
}

impl ProtocolConfigValidator {
/// Creates a new builder for `ProtocolConfigValidator`.
pub fn builder() -> ProtocolConfigValidatorBuilder {
Expand Down Expand Up @@ -208,20 +227,36 @@ enum ErrorKind {
}

/// Returns an estimate of the number of OTs that will be sent.
pub fn ot_send_estimate(role: Role, max_sent_data: usize, max_recv_data: usize) -> usize {
pub fn ot_send_estimate(
role: Role,
max_sent_data: usize,
max_recv_data_online: usize,
max_recv_data: usize,
) -> usize {
match role {
Role::Prover => EXTRA_OTS,
Role::Verifier => {
EXTRA_OTS + (max_sent_data * OTS_PER_BYTE_SENT) + (max_recv_data * OTS_PER_BYTE_RECV)
EXTRA_OTS
+ (max_sent_data * OTS_PER_BYTE_SENT)
+ (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE)
+ ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER)
}
}
}

/// Returns an estimate of the number of OTs that will be received.
pub fn ot_recv_estimate(role: Role, max_sent_data: usize, max_recv_data: usize) -> usize {
pub fn ot_recv_estimate(
role: Role,
max_sent_data: usize,
max_recv_data_online: usize,
max_recv_data: usize,
) -> usize {
match role {
Role::Prover => {
EXTRA_OTS + (max_sent_data * OTS_PER_BYTE_SENT) + (max_recv_data * OTS_PER_BYTE_RECV)
EXTRA_OTS
+ (max_sent_data * OTS_PER_BYTE_SENT)
+ (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE)
+ ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER)
}
Role::Verifier => EXTRA_OTS,
}
Expand All @@ -232,16 +267,23 @@ mod test {
use super::*;
use rstest::{fixture, rstest};

const TEST_MAX_SENT_LIMIT: usize = 1 << 12;
const TEST_MAX_RECV_LIMIT: usize = 1 << 14;

#[fixture]
#[once]
fn config_validator() -> ProtocolConfigValidator {
ProtocolConfigValidator::builder().build().unwrap()
ProtocolConfigValidator::builder()
.max_sent_data(TEST_MAX_SENT_LIMIT)
.max_recv_data(TEST_MAX_RECV_LIMIT)
.build()
.unwrap()
}

#[rstest]
#[case::same_max_sent_recv_data(DEFAULT_MAX_SENT_LIMIT, DEFAULT_MAX_RECV_LIMIT)]
#[case::smaller_max_sent_data(1 << 11, DEFAULT_MAX_RECV_LIMIT)]
#[case::smaller_max_recv_data(DEFAULT_MAX_SENT_LIMIT, 1 << 13)]
#[case::same_max_sent_recv_data(TEST_MAX_SENT_LIMIT, TEST_MAX_RECV_LIMIT)]
#[case::smaller_max_sent_data(1 << 11, TEST_MAX_RECV_LIMIT)]
#[case::smaller_max_recv_data(TEST_MAX_SENT_LIMIT, 1 << 13)]
#[case::smaller_max_sent_recv_data(1 << 7, 1 << 9)]
fn test_check_success(
config_validator: &ProtocolConfigValidator,
Expand All @@ -258,7 +300,7 @@ mod test {
}

#[rstest]
#[case::bigger_max_sent_data(1 << 13, DEFAULT_MAX_RECV_LIMIT)]
#[case::bigger_max_sent_data(1 << 13, TEST_MAX_RECV_LIMIT)]
#[case::bigger_max_recv_data(1 << 10, 1 << 16)]
#[case::bigger_max_sent_recv_data(1 << 14, 1 << 21)]
fn test_check_fail(
Expand Down
4 changes: 0 additions & 4 deletions crates/examples/discord/discord_dm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ const SERVER_DOMAIN: &str = "discord.com";
const NOTARY_HOST: &str = "127.0.0.1";
const NOTARY_PORT: u16 = 7047;

// P/S: If the following limits are increased, please ensure max-transcript-size of
// the notary server's config (../../notary/server) is increased too, where
// max-transcript-size = MAX_SENT_DATA + MAX_RECV_DATA
//
// Maximum number of bytes that can be sent from prover to server
const MAX_SENT_DATA: usize = 1 << 12;
// Maximum number of bytes that can be received by prover from server
Expand Down
32 changes: 24 additions & 8 deletions crates/examples/interactive/interactive.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator};
use tlsn_core::{proof::SessionInfo, Direction, RedactedTranscript};
use tlsn_prover::tls::{state::Prove, Prover, ProverConfig};
use tlsn_verifier::tls::{Verifier, VerifierConfig};
Expand All @@ -11,6 +12,11 @@ use tracing::instrument;
const SECRET: &str = "TLSNotary's private key 🤡";
const SERVER_DOMAIN: &str = "example.com";

// Maximum number of bytes that can be sent from prover to server
const MAX_SENT_DATA: usize = 1 << 12;
// Maximum number of bytes that can be received by prover from server
const MAX_RECV_DATA: usize = 1 << 14;

#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
Expand Down Expand Up @@ -53,6 +59,13 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
ProverConfig::builder()
.id(id)
.server_dns(server_domain)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap(),
)
.build()
.unwrap(),
)
Expand All @@ -69,9 +82,6 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let (mpc_tls_connection, prover_fut) =
prover.connect(tls_client_socket.compat()).await.unwrap();

// Grab a controller for the Prover so we can enable deferred decryption.
let ctrl = prover_fut.control();

// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());

Expand All @@ -87,10 +97,6 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
// Spawn the connection to run in the background.
tokio::spawn(connection);

// Enable deferred decryption. This speeds up the proving time, but doesn't
// let us see the decrypted data until after the connection is closed.
ctrl.defer_decryption().await.unwrap();

// MPC-TLS: Send Request and wait for Response.
let request = Request::builder()
.uri(uri.clone())
Expand Down Expand Up @@ -120,7 +126,17 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
id: &str,
) -> (RedactedTranscript, RedactedTranscript, SessionInfo) {
// Setup Verifier.
let verifier_config = VerifierConfig::builder().id(id).build().unwrap();
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap();

let verifier_config = VerifierConfig::builder()
.id(id)
.protocol_config_validator(config_validator)
.build()
.unwrap();
let verifier = Verifier::new(verifier_config);

// Verify MPC-TLS and wait for (redacted) data.
Expand Down
20 changes: 18 additions & 2 deletions crates/examples/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use elliptic_curve::pkcs8::DecodePrivateKey;
use futures::{AsyncRead, AsyncWrite};
use tlsn_common::config::ProtocolConfigValidator;
use tlsn_verifier::tls::{Verifier, VerifierConfig};

// Maximum number of bytes that can be sent from prover to server
const MAX_SENT_DATA: usize = 1 << 12;
// Maximum number of bytes that can be received by prover from server
const MAX_RECV_DATA: usize = 1 << 14;

/// Runs a simple Notary with the provided connection to the Prover.
pub async fn run_notary<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(conn: T) {
// Load the notary signing key
Expand All @@ -11,9 +17,19 @@ pub async fn run_notary<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(conn
.unwrap();
let signing_key = p256::ecdsa::SigningKey::from_pkcs8_pem(signing_key_str).unwrap();

// Setup default config. Normally a different ID would be generated
// Setup the config. Normally a different ID would be generated
// for each notarization.
let config = VerifierConfig::builder().id("example").build().unwrap();
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap();

let config = VerifierConfig::builder()
.id("example")
.protocol_config_validator(config_validator)
.build()
.unwrap();

Verifier::new(config)
.notarize::<_, p256::ecdsa::Signature>(conn, &signing_key)
Expand Down
Loading

0 comments on commit 9bbb2fb

Please sign in to comment.