Skip to content

Commit

Permalink
Remove separate validation array (#1522)
Browse files Browse the repository at this point in the history
Right now, we have two types with the same shape:

```rust
/// Results of validating a single block
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub(crate) enum Validation {
    /// The block has no hash / context and is empty
    Empty,
    /// For an unencrypted block, the result is the hash
    Unencrypted(u64),
    /// For an encrypted block, the result is the tag + nonce
    Encrypted(crucible_protocol::EncryptionContext),
}
```

```rust
#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)]
pub enum ReadBlockContext {
    Empty,
    Encrypted { ctx: EncryptionContext },
    Unencrypted { hash: u64 },
}
```

The `DownstairsIO` stores both of these types:
```rust
struct DownstairsIO {
    // other members elided
    data: Option<RawReadResponse>, // contains blocks: Vec<ReadBlockContext>
    pub hashes: Vec<Validation>,
}
```

By the time we write to `data`, the block contexts have _already_ been
validated, so storing them again in `hashes` adds unnecessary
duplication (and they're guaranteed to be the same).

This PR removes the duplicate `Validation` hashes, using the
`RawReadResponse::blocks` member instead.

There are small changes to `GuestBlockRes::transfer_and_notify` and
`Buffer::write_read_response`, because we can't take the entire
`Option<RawReadResponse>`; instead, we pass the `&mut BytesMut`
separately, so it can be sent back to the host.
  • Loading branch information
mkeeter authored Nov 5, 2024
1 parent 6277bf3 commit affcda4
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 139 deletions.
36 changes: 16 additions & 20 deletions upstairs/src/buffer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// Copyright 2023 Oxide Computer Company
use crate::RawReadResponse;
use bytes::{Bytes, BytesMut};
use crucible_protocol::ReadBlockContext;
use itertools::Itertools;
Expand Down Expand Up @@ -141,14 +140,17 @@ impl Buffer {
/// # Panics
/// The response data length must be the same as our buffer length (which
/// must be an even multiple of block size, ensured at construction)
pub(crate) fn write_read_response(&mut self, response: RawReadResponse) {
assert!(response.data.len() == self.data.len());
assert_eq!(response.data.len() % self.block_size, 0);
pub(crate) fn write_read_response(
&mut self,
blocks: &[ReadBlockContext],
data: &mut BytesMut,
) {
assert!(data.len() == self.data.len());
assert_eq!(data.len() % self.block_size, 0);
let bs = self.block_size;

// Build contiguous chunks which are all owned, to copy in bulk
for (empty, mut group) in &response
.blocks
for (empty, mut group) in &blocks
.iter()
.enumerate()
.chunk_by(|(_i, b)| matches!(b, ReadBlockContext::Empty))
Expand All @@ -164,16 +166,13 @@ impl Buffer {

// Special case: if the entire buffer is owned, then we swap it
// instead of copying element-by-element.
if count == response.blocks.len()
&& self.data.len() == response.data.len()
{
self.data = response.data;
if count == blocks.len() && self.data.len() == data.len() {
self.data = std::mem::take(data);
break;
} else {
// Otherwise, just copy the sub-region
self.data[(block * bs)..][..(count * bs)].copy_from_slice(
&response.data[(block * bs)..][..(count * bs)],
);
self.data[(block * bs)..][..(count * bs)]
.copy_from_slice(&data[(block * bs)..][..(count * bs)]);
}
}
}
Expand Down Expand Up @@ -493,7 +492,7 @@ mod test {
let mut rng = rand::thread_rng();
rng.fill_bytes(&mut data);

let blocks = (0..10)
let blocks: Vec<_> = (0..10)
.map(|i| {
if f(i) {
ReadBlockContext::Unencrypted { hash: 123 }
Expand All @@ -503,10 +502,7 @@ mod test {
})
.collect();

buf.write_read_response(RawReadResponse {
blocks,
data: data.clone(),
});
buf.write_read_response(&blocks, &mut data.clone());

for i in 0..10 {
let buf_chunk = &buf[i * 512..][..512];
Expand Down Expand Up @@ -564,12 +560,12 @@ mod test {
let mut rng = rand::thread_rng();
rng.fill_bytes(&mut data);

let blocks = (0..10)
let blocks: Vec<_> = (0..10)
.map(|_| ReadBlockContext::Unencrypted { hash: 123 })
.collect();

let prev_data_ptr = data.as_ptr();
buf.write_read_response(RawReadResponse { blocks, data });
buf.write_read_response(&blocks, &mut data);

assert_eq!(buf.data.as_ptr(), prev_data_ptr);
}
Expand Down
41 changes: 17 additions & 24 deletions upstairs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
ClientIOStateCount, ClientId, CrucibleDecoder, CrucibleError, DownstairsIO,
DsState, EncryptionContext, IOState, IOop, JobId, Message, RawReadResponse,
ReconcileIO, ReconcileIOState, RegionDefinitionStatus, RegionMetadata,
Validation,
};
use crucible_common::{x509::TLSContext, ExtentId, VerboseTimeout};
use crucible_protocol::{
Expand Down Expand Up @@ -1226,7 +1225,6 @@ impl DownstairsClient {
ds_id: JobId,
job: &mut DownstairsIO,
responses: Result<RawReadResponse, CrucibleError>,
read_validations: Vec<Validation>,
deactivate: bool,
extent_info: Option<ExtentInfo>,
) -> bool {
Expand Down Expand Up @@ -1357,7 +1355,7 @@ impl DownstairsClient {
*/
let read_data = responses.unwrap();
assert!(!read_data.blocks.is_empty());
if job.read_validations != read_validations {
if job.data.as_ref().unwrap().blocks != read_data.blocks {
// XXX This error needs to go to Nexus
// XXX This will become the "force all downstairs
// to stop and refuse to restart" mode.
Expand All @@ -1371,8 +1369,8 @@ impl DownstairsClient {
self.client_id,
ds_id,
self.cfg.session_id,
job.read_validations,
read_validations,
job.data.as_ref().unwrap().blocks,
read_data.blocks,
start_eid,
start_offset,
job.state,
Expand Down Expand Up @@ -1419,9 +1417,7 @@ impl DownstairsClient {
assert!(extent_info.is_none());
if jobs_completed_ok == 1 {
assert!(job.data.is_none());
assert!(job.read_validations.is_empty());
job.data = Some(read_data);
job.read_validations = read_validations;
assert!(!job.acked);
ackable = true;
debug!(self.log, "Read AckReady {}", ds_id.0);
Expand All @@ -1433,7 +1429,8 @@ impl DownstairsClient {
* that and verify they are the same.
*/
debug!(self.log, "Read already AckReady {ds_id}");
if job.read_validations != read_validations {
let job_blocks = &job.data.as_ref().unwrap().blocks;
if job_blocks != &read_data.blocks {
// XXX This error needs to go to Nexus
// XXX This will become the "force all downstairs
// to stop and refuse to restart" mode.
Expand All @@ -1444,8 +1441,8 @@ impl DownstairsClient {
job: {:?}",
self.client_id,
ds_id,
job.read_validations,
read_validations,
job_blocks,
read_data.blocks,
job,
);
}
Expand Down Expand Up @@ -2947,18 +2944,15 @@ fn update_net_done_probes(m: &Message, cid: ClientId) {
}

/// Returns:
/// - `Ok(Some(ctx))` for successfully decrypted data
/// - `Ok(None)` if there is no block context and the block is all 0
/// - `Ok(())` for successfully decrypted data, or if there is no block context
/// and the block is all 0s (i.e. a valid empty block)
/// - `Err(..)` otherwise
///
/// The return value of this will be stored with the job, and compared
/// between each read.
pub(crate) fn validate_encrypted_read_response(
block_context: Option<crucible_protocol::EncryptionContext>,
data: &mut [u8],
encryption_context: &EncryptionContext,
log: &Logger,
) -> Result<Validation, CrucibleError> {
) -> Result<(), CrucibleError> {
// XXX because we don't have block generation numbers, an attacker
// downstairs could:
//
Expand All @@ -2980,7 +2974,7 @@ pub(crate) fn validate_encrypted_read_response(
//
// XXX if it's not a blank block, we may be under attack?
if data.iter().all(|&x| x == 0) {
return Ok(Validation::Empty);
return Ok(());
} else {
error!(log, "got empty block context with non-blank block");
return Err(CrucibleError::MissingBlockContext);
Expand All @@ -3002,29 +2996,28 @@ pub(crate) fn validate_encrypted_read_response(
Tag::from_slice(&ctx.tag[..]),
);
if decryption_result.is_ok() {
Ok(Validation::Encrypted(ctx))
Ok(())
} else {
error!(log, "Decryption failed!");
Err(CrucibleError::DecryptionError)
}
}

/// Returns:
/// - Ok(Some(valid_hash)) where the integrity hash matches
/// - Ok(None) where there is no integrity hash in the response and the
/// block is all 0
/// - Ok(()) where the integrity hash matches (or the integrity hash is missing
/// and the block is all 0s, indicating an empty block)
/// - Err otherwise
pub(crate) fn validate_unencrypted_read_response(
block_hash: Option<u64>,
data: &mut [u8],
log: &Logger,
) -> Result<Validation, CrucibleError> {
) -> Result<(), CrucibleError> {
if let Some(hash) = block_hash {
// check integrity hashes - make sure it is correct
let computed_hash = integrity_hash(&[data]);

if computed_hash == hash {
Ok(Validation::Unencrypted(computed_hash))
Ok(())
} else {
// No integrity hash was correct for this response
error!(log, "No match computed hash:0x{:x}", computed_hash,);
Expand All @@ -3048,7 +3041,7 @@ pub(crate) fn validate_unencrypted_read_response(
//
// XXX if it's not a blank block, we may be under attack?
if data[..].iter().all(|&x| x == 0) {
Ok(Validation::Empty)
Ok(())
} else {
error!(log, "got empty block context with non-blank block");
Err(CrucibleError::MissingBlockContext)
Expand Down
51 changes: 24 additions & 27 deletions upstairs/src/deferred.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use std::sync::Arc;
use crate::{
backpressure::BackpressureGuard, client::ConnectionId,
upstairs::UpstairsConfig, BlockContext, BlockOp, ClientData, ClientId,
ImpactedBlocks, Message, RawWrite, Validation,
ImpactedBlocks, Message, RawWrite,
};
use bytes::BytesMut;
use crucible_common::{integrity_hash, CrucibleError, RegionDefinition};
use crucible_protocol::ReadBlockContext;
use crucible_protocol::{ReadBlockContext, ReadResponseHeader};
use futures::{
future::{ready, Either, Ready},
stream::FuturesOrdered,
Expand Down Expand Up @@ -192,11 +192,13 @@ impl DeferredWrite {

#[derive(Debug)]
pub(crate) struct DeferredMessage {
/// Message received from the client
///
/// If the deferred message was a read, then the data and context blocks in
/// this [Message::ReadResponse] has been validated (and decrypted if
/// necessary).
pub message: Message,

/// If this was a `ReadResponse`, then the validation result is stored here
pub hashes: Vec<Validation>,

pub client_id: ClientId,

/// See `DeferredRead::connection_id`
Expand All @@ -205,8 +207,8 @@ pub(crate) struct DeferredMessage {

/// Standalone data structure which can perform decryption
pub(crate) struct DeferredRead {
/// Message, which must be a `ReadResponse`
pub message: Message,
pub header: ReadResponseHeader,
pub data: BytesMut,

/// Unique ID for this particular connection to the downstairs
///
Expand All @@ -225,20 +227,16 @@ impl DeferredRead {
/// Consume the `DeferredRead` and perform decryption
///
/// If decryption fails, then the resulting `Message` has an error in the
/// `responses` field, and `hashes` is empty.
/// `responses` field.
pub fn run(mut self) -> DeferredMessage {
use crate::client::{
validate_encrypted_read_response,
validate_unencrypted_read_response,
};
let Message::ReadResponse { header, data } = &mut self.message else {
panic!("invalid DeferredRead");
};
let mut hashes = vec![];

if let Ok(rs) = header.blocks.as_mut() {
assert_eq!(data.len() % rs.len(), 0);
let block_size = data.len() / rs.len();
if let Ok(rs) = self.header.blocks.as_mut() {
assert_eq!(self.data.len() % rs.len(), 0);
let block_size = self.data.len() / rs.len();
for (i, r) in rs.iter_mut().enumerate() {
let v = if let Some(ctx) = &self.cfg.encryption_context {
match r {
Expand All @@ -256,7 +254,7 @@ impl DeferredRead {
.and_then(|r| {
validate_encrypted_read_response(
r,
&mut data[i * block_size..][..block_size],
&mut self.data[i * block_size..][..block_size],
ctx,
&self.log,
)
Expand All @@ -279,28 +277,27 @@ impl DeferredRead {
.and_then(|r| {
validate_unencrypted_read_response(
r,
&mut data[i * block_size..][..block_size],
&mut self.data[i * block_size..][..block_size],
&self.log,
)
})
};
match v {
Ok(hash) => hashes.push(hash),
Err(e) => {
error!(self.log, "decryption failure: {e:?}");
header.blocks = Err(e);
hashes.clear();
break;
}
if let Err(e) = v {
error!(self.log, "decryption failure: {e:?}");
self.header.blocks = Err(e);
break;
}
}
}

let message = Message::ReadResponse {
header: self.header,
data: self.data,
};
DeferredMessage {
client_id: self.client_id,
message: self.message,
message,
connection_id: self.connection_id,
hashes,
}
}
}
Loading

0 comments on commit affcda4

Please sign in to comment.