diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 54951d766..41d85f5d6 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -40,8 +40,9 @@ web-time = { workspace = true } webpki-roots = { workspace = true } [dev-dependencies] -rstest = { workspace = true } +bincode = { workspace = true } hex = { workspace = true } +rstest = { workspace = true } tlsn-data-fixtures = { workspace = true } [[test]] diff --git a/crates/core/src/transcript.rs b/crates/core/src/transcript.rs index 61e375875..0d6c07330 100644 --- a/crates/core/src/transcript.rs +++ b/crates/core/src/transcript.rs @@ -152,8 +152,8 @@ impl Transcript { PartialTranscript { sent, received, - sent_authed: sent_idx, - received_authed: recv_idx, + sent_authed_idx: sent_idx, + received_authed_idx: recv_idx, } } } @@ -163,16 +163,83 @@ impl Transcript { /// A partial transcript is a transcript which may not have all the data /// authenticated. #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(try_from = "validation::PartialTranscriptUnchecked")] +#[serde(try_from = "CompressedPartialTranscript")] +#[serde(into = "CompressedPartialTranscript")] +#[cfg_attr(test, derive(PartialEq))] pub struct PartialTranscript { /// Data sent from the Prover to the Server. sent: Vec, /// Data received by the Prover from the Server. received: Vec, /// Index of `sent` which have been authenticated. - sent_authed: Idx, + sent_authed_idx: Idx, /// Index of `received` which have been authenticated. - received_authed: Idx, + received_authed_idx: Idx, +} + +/// `PartialTranscript` in a compressed form. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(try_from = "validation::CompressedPartialTranscriptUnchecked")] +pub struct CompressedPartialTranscript { + /// Sent data which has been authenticated. + sent_authed: Vec, + /// Received data which has been authenticated. + received_authed: Vec, + /// Index of `sent_authed`. + sent_idx: Idx, + /// Index of `received_authed`. + recv_idx: Idx, + /// Total bytelength of sent data in the original partial transcript. + sent_total: usize, + /// Total bytelength of received data in the original partial transcript. + recv_total: usize, +} + +impl From for CompressedPartialTranscript { + fn from(uncompressed: PartialTranscript) -> Self { + Self { + sent_authed: uncompressed + .sent + .index_ranges(&uncompressed.sent_authed_idx.0), + received_authed: uncompressed + .received + .index_ranges(&uncompressed.received_authed_idx.0), + sent_idx: uncompressed.sent_authed_idx, + recv_idx: uncompressed.received_authed_idx, + sent_total: uncompressed.sent.len(), + recv_total: uncompressed.received.len(), + } + } +} + +impl From for PartialTranscript { + fn from(compressed: CompressedPartialTranscript) -> Self { + let mut sent = vec![0; compressed.sent_total]; + let mut received = vec![0; compressed.recv_total]; + + let mut offset = 0; + + for range in compressed.sent_idx.iter_ranges() { + sent[range.clone()] + .copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]); + offset += range.len(); + } + + let mut offset = 0; + + for range in compressed.recv_idx.iter_ranges() { + received[range.clone()] + .copy_from_slice(&compressed.received_authed[offset..offset + range.len()]); + offset += range.len(); + } + + Self { + sent, + received, + sent_authed_idx: compressed.sent_idx, + received_authed_idx: compressed.recv_idx, + } + } } impl PartialTranscript { @@ -186,8 +253,8 @@ impl PartialTranscript { Self { sent: vec![0; sent_len], received: vec![0; received_len], - sent_authed: Idx::default(), - received_authed: Idx::default(), + sent_authed_idx: Idx::default(), + received_authed_idx: Idx::default(), } } @@ -203,8 +270,8 @@ impl PartialTranscript { /// Returns whether the transcript is complete. pub fn is_complete(&self) -> bool { - self.sent_authed.len() == self.sent.len() - && self.received_authed.len() == self.received.len() + self.sent_authed_idx.len() == self.sent.len() + && self.received_authed_idx.len() == self.received.len() } /// Returns whether the index is in bounds of the transcript. @@ -239,29 +306,29 @@ impl PartialTranscript { /// Returns the index of sent data which have been authenticated. pub fn sent_authed(&self) -> &Idx { - &self.sent_authed + &self.sent_authed_idx } /// Returns the index of received data which have been authenticated. pub fn received_authed(&self) -> &Idx { - &self.received_authed + &self.received_authed_idx } /// Returns the index of sent data which haven't been authenticated. pub fn sent_unauthed(&self) -> Idx { - Idx(RangeSet::from(0..self.sent.len()).difference(&self.sent_authed.0)) + Idx(RangeSet::from(0..self.sent.len()).difference(&self.sent_authed_idx.0)) } /// Returns the index of received data which haven't been authenticated. pub fn received_unauthed(&self) -> Idx { - Idx(RangeSet::from(0..self.received.len()).difference(&self.received_authed.0)) + Idx(RangeSet::from(0..self.received.len()).difference(&self.received_authed_idx.0)) } /// Returns an iterator over the authenticated data in the transcript. pub fn iter(&self, direction: Direction) -> impl Iterator + '_ { let (data, authed) = match direction { - Direction::Sent => (&self.sent, &self.sent_authed), - Direction::Received => (&self.received, &self.received_authed), + Direction::Sent => (&self.sent, &self.sent_authed_idx), + Direction::Received => (&self.received, &self.received_authed_idx), }; authed.0.iter().map(|i| data[i]) @@ -285,25 +352,25 @@ impl PartialTranscript { ); for range in other - .sent_authed + .sent_authed_idx .0 - .difference(&self.sent_authed.0) + .difference(&self.sent_authed_idx.0) .iter_ranges() { self.sent[range.clone()].copy_from_slice(&other.sent[range]); } for range in other - .received_authed + .received_authed_idx .0 - .difference(&self.received_authed.0) + .difference(&self.received_authed_idx.0) .iter_ranges() { self.received[range.clone()].copy_from_slice(&other.received[range]); } - self.sent_authed = self.sent_authed.union(&other.sent_authed); - self.received_authed = self.received_authed.union(&other.received_authed); + self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx); + self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx); } /// Unions an authenticated subsequence into this transcript. @@ -315,11 +382,11 @@ impl PartialTranscript { match direction { Direction::Sent => { seq.copy_to(&mut self.sent); - self.sent_authed = self.sent_authed.union(&seq.idx); + self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx); } Direction::Received => { seq.copy_to(&mut self.received); - self.received_authed = self.received_authed.union(&seq.idx); + self.received_authed_idx = self.received_authed_idx.union(&seq.idx); } } } @@ -348,12 +415,12 @@ impl PartialTranscript { pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range) { match direction { Direction::Sent => { - for range in range.difference(&self.sent_authed.0).iter_ranges() { + for range in range.difference(&self.sent_authed_idx.0).iter_ranges() { self.sent[range].fill(value); } } Direction::Received => { - for range in range.difference(&self.received_authed.0).iter_ranges() { + for range in range.difference(&self.received_authed_idx.0).iter_ranges() { self.received[range].fill(value); } } @@ -549,51 +616,118 @@ mod validation { } } - /// Invalid partial transcript error. + /// Invalid compressed partial transcript error. #[derive(Debug, thiserror::Error)] - #[error("invalid partial transcript: {0}")] - pub struct InvalidPartialTranscript(&'static str); + #[error("invalid compressed partial transcript: {0}")] + pub struct InvalidCompressedPartialTranscript(&'static str); #[derive(Debug, Deserialize)] - pub(super) struct PartialTranscriptUnchecked { - sent: Vec, - received: Vec, - sent_authed: Idx, - received_authed: Idx, - } - - impl TryFrom for PartialTranscript { - type Error = InvalidPartialTranscript; - - fn try_from(unchecked: PartialTranscriptUnchecked) -> Result { - if unchecked.sent_authed.end() > unchecked.sent.len() - || unchecked.received_authed.end() > unchecked.received.len() + #[cfg_attr(test, derive(Serialize))] + pub(super) struct CompressedPartialTranscriptUnchecked { + sent_authed: Vec, + received_authed: Vec, + sent_idx: Idx, + recv_idx: Idx, + sent_total: usize, + recv_total: usize, + } + + impl TryFrom for CompressedPartialTranscript { + type Error = InvalidCompressedPartialTranscript; + + fn try_from(unchecked: CompressedPartialTranscriptUnchecked) -> Result { + if unchecked.sent_authed.len() != unchecked.sent_idx.len() + || unchecked.received_authed.len() != unchecked.recv_idx.len() { - return Err(InvalidPartialTranscript( - "authenticated ranges are not in bounds of the data", + return Err(InvalidCompressedPartialTranscript( + "lengths of index and data don't match", )); } - // Rewrite the data to ensure that unauthenticated data is zeroed out. - let mut sent = vec![0; unchecked.sent.len()]; - let mut received = vec![0; unchecked.received.len()]; - - for range in unchecked.sent_authed.iter_ranges() { - sent[range.clone()].copy_from_slice(&unchecked.sent[range]); - } - - for range in unchecked.received_authed.iter_ranges() { - received[range.clone()].copy_from_slice(&unchecked.received[range]); + if unchecked.sent_idx.end() > unchecked.sent_total + || unchecked.recv_idx.end() > unchecked.recv_total + { + return Err(InvalidCompressedPartialTranscript( + "ranges are not in bounds of the data", + )); } Ok(Self { - sent, - received, - sent_authed: unchecked.sent_authed, received_authed: unchecked.received_authed, + recv_idx: unchecked.recv_idx, + recv_total: unchecked.recv_total, + sent_authed: unchecked.sent_authed, + sent_idx: unchecked.sent_idx, + sent_total: unchecked.sent_total, }) } } + + #[cfg(test)] + mod tests { + use rstest::{fixture, rstest}; + + use super::*; + + #[fixture] + fn partial_transcript() -> CompressedPartialTranscriptUnchecked { + CompressedPartialTranscriptUnchecked { + received_authed: vec![1, 2, 3, 11, 12, 13], + sent_authed: vec![4, 5, 6, 14, 15, 16], + recv_idx: Idx(RangeSet::new(&[1..4, 11..14])), + sent_idx: Idx(RangeSet::new(&[4..7, 14..17])), + sent_total: 20, + recv_total: 20, + } + } + + #[rstest] + fn test_partial_transcript_valid(partial_transcript: CompressedPartialTranscriptUnchecked) { + let bytes = bincode::serialize(&partial_transcript).unwrap(); + let transcript: Result> = + bincode::deserialize(&bytes); + assert!(transcript.is_ok()); + } + + #[rstest] + // Expect to fail since the length of data and the length of the index do not + // match. + fn test_partial_transcript_invalid_lengths( + mut partial_transcript: CompressedPartialTranscriptUnchecked, + ) { + // Add an extra byte to the data. + let mut old = partial_transcript.sent_authed; + old.extend([1]); + partial_transcript.sent_authed = old; + + let bytes = bincode::serialize(&partial_transcript).unwrap(); + let transcript: Result> = + bincode::deserialize(&bytes); + assert!(transcript.is_err()); + } + + #[rstest] + // Expect to fail since the index is out of bounds. + fn test_partial_transcript_invalid_ranges( + mut partial_transcript: CompressedPartialTranscriptUnchecked, + ) { + // Change the total to be less than the last range's end bound. + let end = partial_transcript + .sent_idx + .0 + .iter_ranges() + .last() + .unwrap() + .end; + + partial_transcript.sent_total = end - 1; + + let bytes = bincode::serialize(&partial_transcript).unwrap(); + let transcript: Result> = + bincode::deserialize(&bytes); + assert!(transcript.is_err()); + } + } } #[cfg(test)] @@ -610,6 +744,14 @@ mod tests { ) } + #[fixture] + fn partial_transcript() -> PartialTranscript { + transcript().to_partial( + Idx::new(RangeSet::new(&[1..4, 6..9])), + Idx::new(RangeSet::new(&[2..5, 7..10])), + ) + } + #[rstest] fn test_transcript_get_subsequence(transcript: Transcript) { let subseq = transcript @@ -632,6 +774,13 @@ mod tests { assert_eq!(subseq, None); } + #[rstest] + fn test_partial_transcript_serialization_ok(partial_transcript: PartialTranscript) { + let bytes = bincode::serialize(&partial_transcript).unwrap(); + let deserialized_transcript: PartialTranscript = bincode::deserialize(&bytes).unwrap(); + assert_eq!(partial_transcript, deserialized_transcript); + } + #[rstest] fn test_transcript_to_partial_success(transcript: Transcript) { let partial = transcript.to_partial(Idx::new(0..2), Idx::new(3..7));