diff --git a/crates/benches/bin/prover.rs b/crates/benches/bin/prover.rs index 48166e49b6..ea9623143c 100644 --- a/crates/benches/bin/prover.rs +++ b/crates/benches/bin/prover.rs @@ -117,11 +117,20 @@ async fn run_instance let start_time = Instant::now(); - let protocol_config = ProtocolConfig::builder() - .max_sent_data(upload_size + 256) - .max_recv_data(download_size + 256) - .build() - .unwrap(); + let protocol_config = if defer_decryption { + ProtocolConfig::builder() + .max_sent_data(upload_size + 256) + .max_recv_data(download_size + 256) + .build() + .unwrap() + } else { + ProtocolConfig::builder() + .max_sent_data(upload_size + 256) + .max_recv_data(download_size + 256) + .max_recv_data_online(download_size + 256) + .build() + .unwrap() + }; let prover = Prover::new( ProverConfig::builder() @@ -129,6 +138,7 @@ async fn run_instance .server_dns(SERVER_DOMAIN) .root_cert_store(root_store()) .protocol_config(protocol_config) + .defer_decryption_from_start(defer_decryption) .build() .context("invalid prover config")?, ) @@ -137,7 +147,6 @@ async fn run_instance let (mut mpc_tls_connection, prover_fut) = prover.connect(client_conn.compat()).await.unwrap(); - let prover_ctrl = prover_fut.control(); let prover_task = tokio::spawn(prover_fut); let request = format!( @@ -146,10 +155,6 @@ async fn run_instance String::from_utf8(vec![0x42u8; upload_size]).unwrap(), ); - if defer_decryption { - prover_ctrl.defer_decryption().await?; - } - mpc_tls_connection.write_all(request.as_bytes()).await?; mpc_tls_connection.close().await?; diff --git a/crates/common/src/config.rs b/crates/common/src/config.rs index 038583a9e9..1ae074d031 100644 --- a/crates/common/src/config.rs +++ b/crates/common/src/config.rs @@ -7,16 +7,14 @@ use std::error::Error; use crate::Role; -/// Default for the maximum number of bytes that can be sent (4KB). -pub const DEFAULT_MAX_SENT_LIMIT: usize = 1 << 12; -/// Default for the maximum number of bytes that can be received (16KB). -pub const DEFAULT_MAX_RECV_LIMIT: usize = 1 << 14; - // Extra cushion room, eg. for sharing J0 blocks. const EXTRA_OTS: usize = 16384; + const OTS_PER_BYTE_SENT: usize = 8; + // Without deferred decryption we use 16, with it we use 8. -const OTS_PER_BYTE_RECV: usize = 16; +const OTS_PER_BYTE_RECV_ONLINE: usize = 16; +const OTS_PER_BYTE_RECV_DEFER: usize = 8; // Current version that is running. static VERSION: Lazy = Lazy::new(|| { @@ -27,18 +25,32 @@ static VERSION: Lazy = Lazy::new(|| { /// Protocol configuration to be set up initially by prover and verifier. #[derive(derive_builder::Builder, Clone, Debug, Deserialize, Serialize)] +#[builder(build_fn(validate = "Self::validate"))] pub struct ProtocolConfig { /// Maximum number of bytes that can be sent. - #[builder(default = "DEFAULT_MAX_SENT_LIMIT")] max_sent_data: usize, + /// Maximum number of bytes that can be decrypted online, i.e. while the MPC-TLS connection is + /// active. + #[builder(default = "0")] + max_recv_data_online: usize, /// Maximum number of bytes that can be received. - #[builder(default = "DEFAULT_MAX_RECV_LIMIT")] max_recv_data: usize, /// Version that is being run by prover/verifier. #[builder(setter(skip), default = "VERSION.clone()")] version: Version, } +impl ProtocolConfigBuilder { + fn validate(&self) -> Result<(), String> { + if self.max_recv_data_online > self.max_recv_data { + return Err( + "max_recv_data_online must be smaller or equal to max_recv_data".to_string(), + ); + } + Ok(()) + } +} + impl Default for ProtocolConfig { fn default() -> Self { Self::builder().build().unwrap() @@ -56,6 +68,11 @@ impl ProtocolConfig { self.max_sent_data } + /// Returns the maximum number of bytes that can be decrypted online. + pub fn max_recv_data_online(&self) -> usize { + self.max_recv_data_online + } + /// Returns the maximum number of bytes that can be received. pub fn max_recv_data(&self) -> usize { self.max_recv_data @@ -63,12 +80,22 @@ impl ProtocolConfig { /// Returns OT sender setup count. pub fn ot_sender_setup_count(&self, role: Role) -> usize { - ot_send_estimate(role, self.max_sent_data, self.max_recv_data) + ot_send_estimate( + role, + self.max_sent_data, + self.max_recv_data_online, + self.max_recv_data, + ) } /// Returns OT receiver setup count. pub fn ot_receiver_setup_count(&self, role: Role) -> usize { - ot_recv_estimate(role, self.max_sent_data, self.max_recv_data) + ot_recv_estimate( + role, + self.max_sent_data, + self.max_recv_data_online, + self.max_recv_data, + ) } } @@ -77,22 +104,14 @@ impl ProtocolConfig { #[derive(derive_builder::Builder, Clone, Debug)] pub struct ProtocolConfigValidator { /// Maximum number of bytes that can be sent. - #[builder(default = "DEFAULT_MAX_SENT_LIMIT")] max_sent_data: usize, /// Maximum number of bytes that can be received. - #[builder(default = "DEFAULT_MAX_RECV_LIMIT")] max_recv_data: usize, /// Version that is being run by checker. #[builder(setter(skip), default = "VERSION.clone()")] version: Version, } -impl Default for ProtocolConfigValidator { - fn default() -> Self { - Self::builder().build().unwrap() - } -} - impl ProtocolConfigValidator { /// Creates a new builder for `ProtocolConfigValidator`. pub fn builder() -> ProtocolConfigValidatorBuilder { @@ -208,20 +227,36 @@ enum ErrorKind { } /// Returns an estimate of the number of OTs that will be sent. -pub fn ot_send_estimate(role: Role, max_sent_data: usize, max_recv_data: usize) -> usize { +pub fn ot_send_estimate( + role: Role, + max_sent_data: usize, + max_recv_data_online: usize, + max_recv_data: usize, +) -> usize { match role { Role::Prover => EXTRA_OTS, Role::Verifier => { - EXTRA_OTS + (max_sent_data * OTS_PER_BYTE_SENT) + (max_recv_data * OTS_PER_BYTE_RECV) + EXTRA_OTS + + (max_sent_data * OTS_PER_BYTE_SENT) + + (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE) + + ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER) } } } /// Returns an estimate of the number of OTs that will be received. -pub fn ot_recv_estimate(role: Role, max_sent_data: usize, max_recv_data: usize) -> usize { +pub fn ot_recv_estimate( + role: Role, + max_sent_data: usize, + max_recv_data_online: usize, + max_recv_data: usize, +) -> usize { match role { Role::Prover => { - EXTRA_OTS + (max_sent_data * OTS_PER_BYTE_SENT) + (max_recv_data * OTS_PER_BYTE_RECV) + EXTRA_OTS + + (max_sent_data * OTS_PER_BYTE_SENT) + + (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE) + + ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER) } Role::Verifier => EXTRA_OTS, } @@ -232,16 +267,23 @@ mod test { use super::*; use rstest::{fixture, rstest}; + const TEST_MAX_SENT_LIMIT: usize = 1 << 12; + const TEST_MAX_RECV_LIMIT: usize = 1 << 14; + #[fixture] #[once] fn config_validator() -> ProtocolConfigValidator { - ProtocolConfigValidator::builder().build().unwrap() + ProtocolConfigValidator::builder() + .max_sent_data(TEST_MAX_SENT_LIMIT) + .max_recv_data(TEST_MAX_RECV_LIMIT) + .build() + .unwrap() } #[rstest] - #[case::same_max_sent_recv_data(DEFAULT_MAX_SENT_LIMIT, DEFAULT_MAX_RECV_LIMIT)] - #[case::smaller_max_sent_data(1 << 11, DEFAULT_MAX_RECV_LIMIT)] - #[case::smaller_max_recv_data(DEFAULT_MAX_SENT_LIMIT, 1 << 13)] + #[case::same_max_sent_recv_data(TEST_MAX_SENT_LIMIT, TEST_MAX_RECV_LIMIT)] + #[case::smaller_max_sent_data(1 << 11, TEST_MAX_RECV_LIMIT)] + #[case::smaller_max_recv_data(TEST_MAX_SENT_LIMIT, 1 << 13)] #[case::smaller_max_sent_recv_data(1 << 7, 1 << 9)] fn test_check_success( config_validator: &ProtocolConfigValidator, @@ -258,7 +300,7 @@ mod test { } #[rstest] - #[case::bigger_max_sent_data(1 << 13, DEFAULT_MAX_RECV_LIMIT)] + #[case::bigger_max_sent_data(1 << 13, TEST_MAX_RECV_LIMIT)] #[case::bigger_max_recv_data(1 << 10, 1 << 16)] #[case::bigger_max_sent_recv_data(1 << 14, 1 << 21)] fn test_check_fail( diff --git a/crates/examples/discord/discord_dm.rs b/crates/examples/discord/discord_dm.rs index 722bc17ac6..3f4bffee8f 100644 --- a/crates/examples/discord/discord_dm.rs +++ b/crates/examples/discord/discord_dm.rs @@ -21,10 +21,6 @@ const SERVER_DOMAIN: &str = "discord.com"; const NOTARY_HOST: &str = "127.0.0.1"; const NOTARY_PORT: u16 = 7047; -// P/S: If the following limits are increased, please ensure max-transcript-size of -// the notary server's config (../../notary/server) is increased too, where -// max-transcript-size = MAX_SENT_DATA + MAX_RECV_DATA -// // Maximum number of bytes that can be sent from prover to server const MAX_SENT_DATA: usize = 1 << 12; // Maximum number of bytes that can be received by prover from server diff --git a/crates/examples/interactive/interactive.rs b/crates/examples/interactive/interactive.rs index acf0bdc8b7..7ad9be4b22 100644 --- a/crates/examples/interactive/interactive.rs +++ b/crates/examples/interactive/interactive.rs @@ -1,6 +1,7 @@ use http_body_util::Empty; use hyper::{body::Bytes, Request, StatusCode, Uri}; use hyper_util::rt::TokioIo; +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; use tlsn_core::{proof::SessionInfo, Direction, RedactedTranscript}; use tlsn_prover::tls::{state::Prove, Prover, ProverConfig}; use tlsn_verifier::tls::{Verifier, VerifierConfig}; @@ -11,6 +12,11 @@ use tracing::instrument; const SECRET: &str = "TLSNotary's private key 🤡"; const SERVER_DOMAIN: &str = "example.com"; +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); @@ -53,6 +59,13 @@ async fn prover( ProverConfig::builder() .id(id) .server_dns(server_domain) + .protocol_config( + ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(), + ) .build() .unwrap(), ) @@ -69,9 +82,6 @@ async fn prover( let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await.unwrap(); - // Grab a controller for the Prover so we can enable deferred decryption. - let ctrl = prover_fut.control(); - // Wrap the connection in a TokioIo compatibility layer to use it with hyper. let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat()); @@ -87,10 +97,6 @@ async fn prover( // Spawn the connection to run in the background. tokio::spawn(connection); - // Enable deferred decryption. This speeds up the proving time, but doesn't - // let us see the decrypted data until after the connection is closed. - ctrl.defer_decryption().await.unwrap(); - // MPC-TLS: Send Request and wait for Response. let request = Request::builder() .uri(uri.clone()) @@ -120,7 +126,17 @@ async fn verifier( id: &str, ) -> (RedactedTranscript, RedactedTranscript, SessionInfo) { // Setup Verifier. - let verifier_config = VerifierConfig::builder().id(id).build().unwrap(); + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let verifier_config = VerifierConfig::builder() + .id(id) + .protocol_config_validator(config_validator) + .build() + .unwrap(); let verifier = Verifier::new(verifier_config); // Verify MPC-TLS and wait for (redacted) data. diff --git a/crates/examples/src/lib.rs b/crates/examples/src/lib.rs index a90b707bad..ff59f313c0 100644 --- a/crates/examples/src/lib.rs +++ b/crates/examples/src/lib.rs @@ -1,7 +1,13 @@ use elliptic_curve::pkcs8::DecodePrivateKey; use futures::{AsyncRead, AsyncWrite}; +use tlsn_common::config::ProtocolConfigValidator; use tlsn_verifier::tls::{Verifier, VerifierConfig}; +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + /// Runs a simple Notary with the provided connection to the Prover. pub async fn run_notary(conn: T) { // Load the notary signing key @@ -11,9 +17,19 @@ pub async fn run_notary(conn .unwrap(); let signing_key = p256::ecdsa::SigningKey::from_pkcs8_pem(signing_key_str).unwrap(); - // Setup default config. Normally a different ID would be generated + // Setup the config. Normally a different ID would be generated // for each notarization. - let config = VerifierConfig::builder().id("example").build().unwrap(); + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let config = VerifierConfig::builder() + .id("example") + .protocol_config_validator(config_validator) + .build() + .unwrap(); Verifier::new(config) .notarize::<_, p256::ecdsa::Signature>(conn, &signing_key) diff --git a/crates/examples/twitter/twitter_dm.rs b/crates/examples/twitter/twitter_dm.rs index 3a61e22e01..ccee197a0b 100644 --- a/crates/examples/twitter/twitter_dm.rs +++ b/crates/examples/twitter/twitter_dm.rs @@ -7,6 +7,7 @@ use hyper::{body::Bytes, Request, StatusCode}; use hyper_util::rt::TokioIo; use notary_client::{Accepted, NotarizationRequest, NotaryClient}; use std::{env, str}; +use tlsn_common::config::ProtocolConfig; use tlsn_core::{commitment::CommitmentKind, proof::TlsProof}; use tlsn_prover::tls::{Prover, ProverConfig}; use tokio::io::AsyncWriteExt as _; @@ -22,6 +23,11 @@ const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KH const NOTARY_HOST: &str = "127.0.0.1"; const NOTARY_PORT: u16 = 7047; +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); @@ -44,7 +50,11 @@ async fn main() { .unwrap(); // Send requests for configuration and notarization to the notary server. - let notarization_request = NotarizationRequest::builder().build().unwrap(); + let notarization_request = NotarizationRequest::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); let Accepted { io: notary_connection, @@ -59,6 +69,13 @@ async fn main() { let prover_config = ProverConfig::builder() .id(session_id) .server_dns(SERVER_DOMAIN) + .protocol_config( + ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(), + ) .build() .unwrap(); @@ -77,9 +94,6 @@ async fn main() { let (tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); let tls_connection = TokioIo::new(tls_connection.compat()); - // Grab a control handle to the Prover - let prover_ctrl = prover_fut.control(); - // Spawn the Prover to be run concurrently let prover_task = tokio::spawn(prover_fut); @@ -115,10 +129,6 @@ async fn main() { debug!("Sending request"); - // Because we don't need to decrypt the response right away, we can defer decryption - // until after the connection is closed. This will speed up the proving process! - prover_ctrl.defer_decryption().await.unwrap(); - let response = request_sender.send_request(request).await.unwrap(); debug!("Sent request"); diff --git a/crates/notary/client/src/client.rs b/crates/notary/client/src/client.rs index 1ff07388d6..4db0dce6a0 100644 --- a/crates/notary/client/src/client.rs +++ b/crates/notary/client/src/client.rs @@ -12,7 +12,6 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tlsn_common::config::{DEFAULT_MAX_RECV_LIMIT, DEFAULT_MAX_SENT_LIMIT}; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, net::TcpStream, @@ -30,10 +29,8 @@ use crate::error::{ClientError, ErrorKind}; #[derive(Debug, Clone, derive_builder::Builder)] pub struct NotarizationRequest { /// Maximum number of bytes that can be sent. - #[builder(default = "DEFAULT_MAX_SENT_LIMIT")] max_sent_data: usize, /// Maximum number of bytes that can be received. - #[builder(default = "DEFAULT_MAX_RECV_LIMIT")] max_recv_data: usize, } diff --git a/crates/prover/src/tls/config.rs b/crates/prover/src/tls/config.rs index 9e01bcd264..4892cd3122 100644 --- a/crates/prover/src/tls/config.rs +++ b/crates/prover/src/tls/config.rs @@ -18,6 +18,12 @@ pub struct ProverConfig { /// Protocol configuration to be checked with the verifier. #[builder(default)] protocol_config: ProtocolConfig, + /// Whether the `deferred decryption` feature is toggled on from the start of the MPC-TLS + /// connection. + /// + /// See `defer_decryption_from_start` in [tls_mpc::MpcTlsLeaderConfig]. + #[builder(default = "true")] + defer_decryption_from_start: bool, } impl ProverConfig { @@ -41,6 +47,12 @@ impl ProverConfig { &self.protocol_config } + /// Returns whether the `deferred decryption` feature is toggled on from the start of the MPC-TLS + /// connection. + pub fn defer_decryption_from_start(&self) -> bool { + self.defer_decryption_from_start + } + pub(crate) fn build_mpc_tls_config(&self) -> MpcTlsLeaderConfig { MpcTlsLeaderConfig::builder() .common( @@ -48,13 +60,17 @@ impl ProverConfig { .id(format!("{}/mpc_tls", &self.id)) .tx_config( TranscriptConfig::default_tx() - .max_size(self.protocol_config.max_sent_data()) + .max_online_size(self.protocol_config.max_sent_data()) .build() .unwrap(), ) .rx_config( TranscriptConfig::default_rx() - .max_size(self.protocol_config.max_recv_data()) + .max_online_size(self.protocol_config.max_recv_data_online()) + .max_offline_size( + self.protocol_config.max_recv_data() + - self.protocol_config.max_recv_data_online(), + ) .build() .unwrap(), ) diff --git a/crates/tests-integration/Cargo.toml b/crates/tests-integration/Cargo.toml index 7c9c5bccd8..8ed307eb45 100644 --- a/crates/tests-integration/Cargo.toml +++ b/crates/tests-integration/Cargo.toml @@ -6,6 +6,7 @@ publish = false [dev-dependencies] tlsn-core = { workspace = true } +tlsn-common = { workspace = true } tlsn-prover = { workspace = true } tlsn-server-fixture = { workspace = true } tlsn-server-fixture-certs = { workspace = true } diff --git a/crates/tests-integration/tests/defer_decryption.rs b/crates/tests-integration/tests/defer_decryption.rs index 13db889adf..0740775c9e 100644 --- a/crates/tests-integration/tests/defer_decryption.rs +++ b/crates/tests-integration/tests/defer_decryption.rs @@ -1,3 +1,4 @@ +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; use tlsn_prover::tls::{Prover, ProverConfig}; use tlsn_server_fixture::bind; use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; @@ -8,6 +9,11 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::compat::TokioAsyncReadCompatExt; use tracing::instrument; +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + #[tokio::test] #[ignore] async fn test_defer_decryption() { @@ -33,6 +39,13 @@ async fn prover(notary_socke ProverConfig::builder() .id("test") .server_dns(SERVER_DOMAIN) + .protocol_config( + ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(), + ) .root_cert_store(root_store) .build() .unwrap(), @@ -42,12 +55,8 @@ async fn prover(notary_socke .unwrap(); let (mut tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - let prover_ctrl = prover_fut.control(); let prover_task = tokio::spawn(prover_fut); - // Defer decryption until after the server closes the connection. - prover_ctrl.defer_decryption().await.unwrap(); - tls_connection .write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") .await @@ -74,7 +83,19 @@ async fn prover(notary_socke #[instrument(skip(socket))] async fn notary(socket: T) { - let verifier = Verifier::new(VerifierConfig::builder().id("test").build().unwrap()); + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let verifier = Verifier::new( + VerifierConfig::builder() + .id("test") + .protocol_config_validator(config_validator) + .build() + .unwrap(), + ); let signing_key = p256::ecdsa::SigningKey::from_bytes(&[1u8; 32].into()).unwrap(); _ = verifier diff --git a/crates/tests-integration/tests/notarize.rs b/crates/tests-integration/tests/notarize.rs index 8338e1119c..583514e89d 100644 --- a/crates/tests-integration/tests/notarize.rs +++ b/crates/tests-integration/tests/notarize.rs @@ -1,3 +1,4 @@ +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; use tlsn_prover::tls::{Prover, ProverConfig}; use tlsn_server_fixture::bind; use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; @@ -10,6 +11,11 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tracing::instrument; +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + #[tokio::test] #[ignore] async fn notarize() { @@ -31,11 +37,20 @@ async fn prover(notary_socke .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) .unwrap(); + let protocol_config = ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .max_recv_data_online(MAX_RECV_DATA) + .build() + .unwrap(); + let prover = Prover::new( ProverConfig::builder() .id("test") .server_dns(SERVER_DOMAIN) .root_cert_store(root_store) + .defer_decryption_from_start(false) + .protocol_config(protocol_config) .build() .unwrap(), ) @@ -86,7 +101,19 @@ async fn prover(notary_socke #[instrument(skip(socket))] async fn notary(socket: T) { - let verifier = Verifier::new(VerifierConfig::builder().id("test").build().unwrap()); + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let verifier = Verifier::new( + VerifierConfig::builder() + .id("test") + .protocol_config_validator(config_validator) + .build() + .unwrap(), + ); let signing_key = p256::ecdsa::SigningKey::from_bytes(&[1u8; 32].into()).unwrap(); _ = verifier diff --git a/crates/tests-integration/tests/verify.rs b/crates/tests-integration/tests/verify.rs index 64f21da8ca..f924771359 100644 --- a/crates/tests-integration/tests/verify.rs +++ b/crates/tests-integration/tests/verify.rs @@ -1,4 +1,5 @@ use tls_core::{anchors::RootCertStore, verify::WebPkiVerifier}; +use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; use tlsn_core::{proof::SessionInfo, Direction, RedactedTranscript}; use tlsn_prover::tls::{Prover, ProverConfig}; use tlsn_server_fixture::bind; @@ -13,6 +14,11 @@ use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tracing::instrument; use utils::range::RangeSet; +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + #[tokio::test] #[ignore] async fn verify() { @@ -43,11 +49,20 @@ async fn prover(notary_socke .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) .unwrap(); + let protocol_config = ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .max_recv_data_online(MAX_RECV_DATA) + .build() + .unwrap(); + let prover = Prover::new( ProverConfig::builder() .id("test") .server_dns(SERVER_DOMAIN) .root_cert_store(root_store) + .defer_decryption_from_start(false) + .protocol_config(protocol_config) .build() .unwrap(), ) @@ -105,8 +120,15 @@ async fn verifier( .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) .unwrap(); + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + let verifier_config = VerifierConfig::builder() .id("test") + .protocol_config_validator(config_validator) .cert_verifier(WebPkiVerifier::new(root_store, None)) .build() .unwrap(); diff --git a/crates/tls/mpc/src/config.rs b/crates/tls/mpc/src/config.rs index 8567efaccf..6cc8bc4b95 100644 --- a/crates/tls/mpc/src/config.rs +++ b/crates/tls/mpc/src/config.rs @@ -14,8 +14,12 @@ pub struct TranscriptConfig { /// The "opaque" transcript id, used for parts of the transcript that are not /// part of the application data. opaque_id: String, - /// The maximum length of the transcript in bytes. - max_size: usize, + /// The maximum number of bytes that can be written to the transcript during the **online** + /// phase, i.e. while the MPC-TLS connection is active. + max_online_size: usize, + /// The maximum number of bytes that can be written to the transcript during the **offline** + /// phase, i.e. after the MPC-TLS connection was closed. + max_offline_size: usize, } impl TranscriptConfig { @@ -26,7 +30,8 @@ impl TranscriptConfig { builder .id(DEFAULT_TX_TRANSCRIPT_ID.to_string()) .opaque_id(DEFAULT_OPAQUE_TX_TRANSCRIPT_ID.to_string()) - .max_size(DEFAULT_TRANSCRIPT_MAX_SIZE); + .max_online_size(DEFAULT_TRANSCRIPT_MAX_SIZE) + .max_offline_size(0); builder } @@ -38,7 +43,8 @@ impl TranscriptConfig { builder .id(DEFAULT_RX_TRANSCRIPT_ID.to_string()) .opaque_id(DEFAULT_OPAQUE_RX_TRANSCRIPT_ID.to_string()) - .max_size(DEFAULT_TRANSCRIPT_MAX_SIZE); + .max_online_size(0) + .max_offline_size(DEFAULT_TRANSCRIPT_MAX_SIZE); builder } @@ -58,9 +64,16 @@ impl TranscriptConfig { &self.opaque_id } - /// Returns the maximum length of the transcript in bytes. - pub fn max_size(&self) -> usize { - self.max_size + /// Returns the maximum number of bytes that can be written to the transcript during the **online** + /// phase, i.e. while the MPC-TLS connection is active. + pub fn max_online_size(&self) -> usize { + self.max_online_size + } + + /// Returns the maximum number of bytes that can be written to the transcript during the **offline** + /// phase, i.e. after the MPC-TLS connection was closed. + pub fn max_offline_size(&self) -> usize { + self.max_offline_size } } @@ -121,6 +134,18 @@ impl MpcTlsCommonConfig { #[derive(Debug, Clone, Builder)] pub struct MpcTlsLeaderConfig { common: MpcTlsCommonConfig, + /// Whether the `deferred decryption` feature is toggled on from the start of the MPC-TLS + /// connection. + /// + /// The received data will be decrypted locally without MPC, thus improving + /// bandwidth usage and performance. + /// + /// Decryption of the data received while `deferred decryption` is toggled on will be deferred + /// until after the MPC-TLS connection is closed. + /// If you need to decrypt some subset of data received from the TLS peer while the MPC-TLS + /// connection is active, you must toggle `deferred decryption` **off** for that subset of data. + #[builder(default = "true")] + defer_decryption_from_start: bool, } impl MpcTlsLeaderConfig { @@ -133,6 +158,12 @@ impl MpcTlsLeaderConfig { pub fn common(&self) -> &MpcTlsCommonConfig { &self.common } + + /// Returns whether the `deferred decryption` feature is toggled on from the start of the MPC-TLS + /// connection. + pub fn defer_decryption_from_start(&self) -> bool { + self.defer_decryption_from_start + } } /// Configuration for the follower diff --git a/crates/tls/mpc/src/follower.rs b/crates/tls/mpc/src/follower.rs index 27a74486ed..393240cf89 100644 --- a/crates/tls/mpc/src/follower.rs +++ b/crates/tls/mpc/src/follower.rs @@ -148,11 +148,12 @@ impl MpcTlsFollower { self.ke.preprocess().await?; self.prf.preprocess().await?; + let preprocess_encrypt = self.config.common().tx_config().max_online_size(); + let preprocess_decrypt = self.config.common().rx_config().max_online_size(); + futures::try_join!( - self.encrypter - .preprocess(self.config.common().tx_config().max_size()), - // For now we just preprocess enough for the handshake - self.decrypter.preprocess(256), + self.encrypter.preprocess(preprocess_encrypt), + self.decrypter.preprocess(preprocess_decrypt), )?; self.prf.set_client_random(None).await?; @@ -212,7 +213,7 @@ impl MpcTlsFollower { match direction { Direction::Sent => { let new_len = self.encrypter.sent_bytes() + len; - let max_size = self.config.common().tx_config().max_size(); + let max_size = self.config.common().tx_config().max_online_size(); if new_len > max_size { return Err(MpcTlsError::new( Kind::Config, @@ -225,7 +226,8 @@ impl MpcTlsFollower { } Direction::Recv => { let new_len = self.decrypter.recv_bytes() + len; - let max_size = self.config.common().rx_config().max_size(); + let max_size = self.config.common().rx_config().max_online_size() + + self.config.common().rx_config().max_offline_size(); if new_len > max_size { return Err(MpcTlsError::new( Kind::Config, diff --git a/crates/tls/mpc/src/leader.rs b/crates/tls/mpc/src/leader.rs index f6a9845d81..312cdb8df7 100644 --- a/crates/tls/mpc/src/leader.rs +++ b/crates/tls/mpc/src/leader.rs @@ -101,6 +101,7 @@ impl MpcTlsLeader { config.common().rx_config().id().to_string(), config.common().rx_config().opaque_id().to_string(), ); + let is_decrypting = !config.defer_decryption_from_start(); Self { config, @@ -111,7 +112,7 @@ impl MpcTlsLeader { encrypter, decrypter, notifier: BackendNotifier::new(), - is_decrypting: true, + is_decrypting, buffer: VecDeque::new(), committed: false, } @@ -134,11 +135,12 @@ impl MpcTlsLeader { self.ke.preprocess().await?; self.prf.preprocess().await?; + let preprocess_encrypt = self.config.common().tx_config().max_online_size(); + let preprocess_decrypt = self.config.common().rx_config().max_online_size(); + futures::try_join!( - self.encrypter - .preprocess(self.config.common().tx_config().max_size()), - // For now we just preprocess enough for the handshake - self.decrypter.preprocess(256), + self.encrypter.preprocess(preprocess_encrypt), + self.decrypter.preprocess(preprocess_decrypt), )?; self.prf @@ -178,7 +180,7 @@ impl MpcTlsLeader { match direction { Direction::Sent => { let new_len = self.encrypter.sent_bytes() + len; - let max_size = self.config.common().tx_config().max_size(); + let max_size = self.config.common().tx_config().max_online_size(); if new_len > max_size { return Err(MpcTlsError::new( Kind::Config, @@ -191,7 +193,8 @@ impl MpcTlsLeader { } Direction::Recv => { let new_len = self.decrypter.recv_bytes() + len; - let max_size = self.config.common().rx_config().max_size(); + let max_size = self.config.common().rx_config().max_online_size() + + self.config.common().rx_config().max_offline_size(); if new_len > max_size { return Err(MpcTlsError::new( Kind::Config, diff --git a/crates/tls/mpc/tests/test.rs b/crates/tls/mpc/tests/test.rs index 3f7868e9db..cb9b3a37c2 100644 --- a/crates/tls/mpc/tests/test.rs +++ b/crates/tls/mpc/tests/test.rs @@ -115,6 +115,7 @@ async fn leader(config: MpcTlsCommonConfig, mux: TestFramedMux) { let mut leader = MpcTlsLeader::new( MpcTlsLeaderConfig::builder() .common(config) + .defer_decryption_from_start(false) .build() .unwrap(), Box::new(StreamExt::compat_stream( diff --git a/crates/verifier/src/tls/config.rs b/crates/verifier/src/tls/config.rs index 82d0393213..2b05c3b4aa 100644 --- a/crates/verifier/src/tls/config.rs +++ b/crates/verifier/src/tls/config.rs @@ -12,7 +12,6 @@ use tlsn_core::proof::default_cert_verifier; pub struct VerifierConfig { #[builder(setter(into))] id: String, - #[builder(default)] protocol_config_validator: ProtocolConfigValidator, #[builder( pattern = "owned", @@ -94,13 +93,17 @@ impl VerifierConfig { .id(format!("{}/mpc_tls", &self.id)) .tx_config( TranscriptConfig::default_tx() - .max_size(protocol_config.max_sent_data()) + .max_online_size(protocol_config.max_sent_data()) .build() .unwrap(), ) .rx_config( TranscriptConfig::default_rx() - .max_size(protocol_config.max_recv_data()) + .max_online_size(protocol_config.max_recv_data_online()) + .max_offline_size( + protocol_config.max_recv_data() + - protocol_config.max_recv_data_online(), + ) .build() .unwrap(), ) diff --git a/crates/wasm-test-runner/src/tlsn_fixture.rs b/crates/wasm-test-runner/src/tlsn_fixture.rs index 2a1c068a83..3f76d4a72a 100644 --- a/crates/wasm-test-runner/src/tlsn_fixture.rs +++ b/crates/wasm-test-runner/src/tlsn_fixture.rs @@ -150,12 +150,8 @@ async fn handle_prover(io: TcpStream) -> Result<()> { let client_socket = TcpStream::connect((addr, port)).await.unwrap(); let (mut tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap(); - let prover_ctrl = prover_fut.control(); let prover_task = tokio::spawn(prover_fut); - // Defer decryption until after the server closes the connection. - prover_ctrl.defer_decryption().await.unwrap(); - tls_connection .write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") .await diff --git a/crates/wasm/src/prover/config.rs b/crates/wasm/src/prover/config.rs index 957e83878e..c54bcc4bc1 100644 --- a/crates/wasm/src/prover/config.rs +++ b/crates/wasm/src/prover/config.rs @@ -7,29 +7,36 @@ use tsify_next::Tsify; pub struct ProverConfig { pub id: String, pub server_dns: String, - pub max_sent_data: Option, - pub max_recv_data: Option, + pub max_sent_data: usize, + pub max_recv_data_online: Option, + pub max_recv_data: usize, + pub defer_decryption_from_start: Option, } impl From for tlsn_prover::tls::ProverConfig { fn from(value: ProverConfig) -> Self { let mut builder = ProtocolConfig::builder(); - if let Some(value) = value.max_sent_data { - builder.max_sent_data(value); - } + builder.max_sent_data(value.max_sent_data); - if let Some(value) = value.max_recv_data { - builder.max_recv_data(value); + if let Some(value) = value.max_recv_data_online { + builder.max_recv_data_online(value); } + builder.max_recv_data(value.max_recv_data); + let protocol_config = builder.build().unwrap(); - tlsn_prover::tls::ProverConfig::builder() + let mut builder = tlsn_prover::tls::ProverConfig::builder(); + builder .id(value.id) .server_dns(value.server_dns) - .protocol_config(protocol_config) - .build() - .unwrap() + .protocol_config(protocol_config); + + if let Some(value) = value.defer_decryption_from_start { + builder.defer_decryption_from_start(value); + } + + builder.build().unwrap() } } diff --git a/crates/wasm/src/prover/mod.rs b/crates/wasm/src/prover/mod.rs index 2c02939f70..9fa3c49cf1 100644 --- a/crates/wasm/src/prover/mod.rs +++ b/crates/wasm/src/prover/mod.rs @@ -83,16 +83,12 @@ impl JsProver { info!("connected to server"); let (tls_conn, prover_fut) = prover.connect(server_conn.into_io()).await?; - let prover_ctrl = prover_fut.control(); info!("sending request"); let (response, prover) = futures::try_join!( - async move { - prover_ctrl.defer_decryption().await?; - send_request(tls_conn, request).await - }, - prover_fut.map_err(Into::into), + send_request(tls_conn, request), + prover_fut.map_err(Into::into) )?; info!("response received"); diff --git a/crates/wasm/src/verifier/config.rs b/crates/wasm/src/verifier/config.rs index 1300ba6364..7e95f9f187 100644 --- a/crates/wasm/src/verifier/config.rs +++ b/crates/wasm/src/verifier/config.rs @@ -6,21 +6,16 @@ use tsify_next::Tsify; #[tsify(from_wasm_abi)] pub struct VerifierConfig { pub id: String, - pub max_sent_data: Option, - pub max_received_data: Option, + pub max_sent_data: usize, + pub max_received_data: usize, } impl From for tlsn_verifier::tls::VerifierConfig { fn from(value: VerifierConfig) -> Self { let mut builder = ProtocolConfigValidator::builder(); - if let Some(value) = value.max_sent_data { - builder.max_sent_data(value); - } - - if let Some(value) = value.max_received_data { - builder.max_recv_data(value); - } + builder.max_sent_data(value.max_sent_data); + builder.max_recv_data(value.max_received_data); let config_validator = builder.build().unwrap();