Skip to content

Commit

Permalink
server_test.rs: Huge simplification.
Browse files Browse the repository at this point in the history
  • Loading branch information
egrimley-arm committed Jan 3, 2023
1 parent 2bab6e7 commit 1a8ddcd
Showing 1 changed file with 25 additions and 268 deletions.
293 changes: 25 additions & 268 deletions tests/tests/server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@ use std::{
error::Error,
io::{Read, Write},
path::Path,
sync::{
atomic::{AtomicBool, Ordering},
mpsc::{channel, Receiver, Sender},
Arc, Mutex,
},
thread::{self, JoinHandle},
sync::Arc,
time::{Duration, Instant},
vec::Vec,
};
Expand Down Expand Up @@ -521,81 +516,10 @@ fn performance_set_intersection_sum() {
struct TestExecutor {
// The policy for the runtime.
policy: Policy,
// The hash of the policy, that is used in attestation
// The hash of the policy, that is used in attestation.
policy_hash: String,
// The emulated TLS connect from client to server.
client_tls_receiver: Receiver<Vec<u8>>,
client_tls_sender: Sender<Vec<u8>>,
// Paths to client certification and private key.
// Note that we only have one client in all tests.
client_connection: mbedtls::ssl::Context<InsecureConnection>,
// Read and write buffers shared with InsecureConnection.
shared_buffers: Arc<Mutex<Buffers>>,
// A alive flag. This is to solve the problem where the server thread still in loop while
// client thread is terminated.
alive_flag: Arc<AtomicBool>,
// Hold the server thread. The test will join the thread in the end to check the server
// state.
server_thread: JoinHandle<Result<()>>,
}

struct Buffers {
// Read buffer used by mbedtls for cyphertext.
read_buffer: Vec<u8>,
// Write buffer used by mbedtls for cyphertext.
write_buffer: Option<Vec<u8>>,
}

/// This is the structure given to mbedtls and used for reading and
/// writing cyphertext, using the standard Read and Write traits.
struct InsecureConnection {
// Read and write buffers shared with Session.
shared_buffers: Arc<Mutex<Buffers>>,
}

// To convert any error to a std::io error:
fn std_err(error_text: &str) -> std::io::Error {
std::io::Error::new(std::io::ErrorKind::Other, error_text)
}

impl Read for InsecureConnection {
fn read(&mut self, data: &mut [u8]) -> Result<usize, std::io::Error> {
// Return as much data from the read_buffer as fits.
let mut shared_buffers = self
.shared_buffers
.lock()
.map_err(|_| std_err("lock failed"))?;
let n = std::cmp::min(data.len(), shared_buffers.read_buffer.len());
if n == 0 {
Err(std::io::Error::new(
std::io::ErrorKind::WouldBlock,
"InsecureConnection Read",
))
} else {
data[0..n].clone_from_slice(&shared_buffers.read_buffer[0..n]);
shared_buffers.read_buffer = shared_buffers.read_buffer[n..].to_vec();
Ok(n)
}
}
}

impl Write for InsecureConnection {
fn write(&mut self, data: &[u8]) -> Result<usize, std::io::Error> {
// Append to write buffer.
let mut shared_buffers = self
.shared_buffers
.lock()
.map_err(|_| std_err("lock failed"))?;
match &mut shared_buffers.write_buffer {
None => shared_buffers.write_buffer = Some(data.to_vec()),
Some(x) => x.extend_from_slice(data),
}
// Return value to indicate that we handled all the data.
Ok(data.len())
}
fn flush(&mut self) -> Result<(), std::io::Error> {
Ok(())
}
client_connection: mbedtls::ssl::Context<VeracruzSession>,
}

impl TestExecutor {
Expand Down Expand Up @@ -654,127 +578,36 @@ impl TestExecutor {
&env::var("VERACRUZ_DATA_DIR").unwrap_or("../test-collateral".to_string()),
);

info!("Create simulated connection channels.");
// Create two channel, simulating the connecting channels.
let (server_tls_sender, client_tls_receiver) = channel::<Vec<u8>>();
let (client_tls_sender, server_tls_receiver) = channel::<Vec<u8>>();
info!("Initialise Veracruz runtime.");
// Create the server
let mut veracruz_server =
VeracruzServer::new(&policy_json).map_err(|e| anyhow!("{:?}", e))?;

let shared_buffers = Arc::new(Mutex::new(Buffers {
read_buffer: vec![],
write_buffer: None,
}));
// Create the client tls session.
let veracruz_session = veracruz_server
.new_session()
.map_err(|e| anyhow!("{:?}", e))?;

info!("Initialise a client with its certificate and key.");
// Create a fake client session which only ends to the simulated connecting channel.
let client_connection = create_client_test_connection(
client_cert_path,
client_key_path,
&policy.ciphersuite(),
Arc::clone(&shared_buffers),
veracruz_session.clone(),
)?;

info!("Initialise Veracruz runtime.");
// Create the server
let mut veracruz_server =
VeracruzServer::new(&policy_json).map_err(|e| anyhow!("{:?}", e))?;

// Create the client tls session. Note that we need the session id.
let mut veracruz_session = veracruz_server
.new_session()
.map_err(|e| anyhow!("{:?}", e))?;

info!("Spawn server thread.");
// Create the sever loop, it is the end of the previous created channels.
let alive_flag = Arc::new(AtomicBool::new(true));
let init_flag = Arc::new(AtomicBool::new(false));
// Create a clone which passes to server thread.
let alive_flag_clone = alive_flag.clone();
let init_flag_clone = init_flag.clone();
let server_thread = thread::spawn(move || {
if let Err(e) = TestExecutor::simulated_server(
&mut veracruz_session,
server_tls_sender,
server_tls_receiver,
alive_flag_clone.clone(),
init_flag_clone,
) {
alive_flag_clone.store(false, Ordering::SeqCst);
Err(e)
} else {
Ok(())
}
});
info!("A new test executor is created.");

// Block until the init_flag is set by the server thread.
while !init_flag.load(Ordering::SeqCst) {}

Ok(TestExecutor {
policy,
policy_hash,
client_connection,
shared_buffers,
client_tls_sender,
client_tls_receiver,
alive_flag,
server_thread,
})
}

/// This function simulating a Veracruz server, it should run on a separate thread.
fn simulated_server(
veracruz_session: &mut veracruz_server::VeracruzSession,
sender: Sender<Vec<u8>>,
receiver: Receiver<Vec<u8>>,
test_alive_flag: Arc<AtomicBool>,
test_init_flag: Arc<AtomicBool>,
) -> Result<()> {
info!("Server: simulated server loop starts...");

test_init_flag.store(true, Ordering::SeqCst);

let mut veracruz_session_clone = veracruz_session.clone();
let test_alive_flag_clone = test_alive_flag.clone();
thread::spawn(move || {
while test_alive_flag_clone.load(Ordering::SeqCst) {
let received = receiver.recv();
let received_buffer = received.map_err(|e| anyhow!("Server: {:?}", e)).unwrap();
veracruz_session_clone.write_all(&received_buffer).unwrap();
}
});

let mut veracruz_session_clone = veracruz_session.clone();
let test_alive_flag_clone = test_alive_flag.clone();
thread::spawn(move || {
while test_alive_flag_clone.load(Ordering::SeqCst) {
let mut buf = vec![0; 100000];
let n = veracruz_session_clone.read(&mut buf).unwrap();
if n == 0 {
break;
}
sender.send(buf[0..n].to_vec()).unwrap();
}
});

Ok(())
}

/// Execute this test. The client sends messages though the channel to the server
/// thread driven by `events`. It consumes the ownership of `self`,
/// because it will join server thread at the end.
fn execute(mut self, events: Vec<TestEvent>, timeout: Duration) -> anyhow::Result<bool> {
// Spawn a thread that will send the timeout signal by killing alive flag.
let alive_flag_clone = self.alive_flag.clone();
thread::spawn(move || {
thread::sleep(timeout);
if alive_flag_clone.load(Ordering::SeqCst) {
error!(
"--->>> Force timeout. It is very likely to trigger error on the test. <<<---"
);
}
alive_flag_clone.store(false, Ordering::SeqCst);
});

fn execute(mut self, events: Vec<TestEvent>, _timeout: Duration) -> anyhow::Result<bool> {
let mut error_occurred = false;

// process test events
Expand All @@ -783,7 +616,6 @@ impl TestExecutor {
let time_init = Instant::now();
let response = self.process_event(&event).map_err(|e| {
error!("Client: {:?}", e);
self.alive_flag.store(false, Ordering::SeqCst);
e
})?;
if response.get_status() != transport_protocol::ResponseStatus::SUCCESS {
Expand All @@ -797,11 +629,6 @@ impl TestExecutor {
);
}

// Wait the server to finish.
self.server_thread
.join()
.map_err(|e| anyhow!("server thread failed with error {:?}", e))?
.map_err(|e| anyhow!("{:?}", e))?;
Ok(!error_occurred)
}

Expand Down Expand Up @@ -937,85 +764,16 @@ impl TestExecutor {

/// The client sends TLS packages via the simulated channel.
fn client_send(&mut self, send_data: &[u8]) -> Result<Vec<u8>> {
info!(
"Client: client send with length of data {:?}",
send_data.len()
);
let connection = &mut self.client_connection;
let mut write_all_succeeded = false;
while self.alive_flag.load(Ordering::SeqCst) {
// connection.write_all
if !write_all_succeeded {
match connection.write_all(&send_data[..]) {
Ok(()) => write_all_succeeded = true,
Err(err) => {
if err.kind() == std::io::ErrorKind::WouldBlock {
()
} else {
return Err(anyhow!(
"Failed to send all data. Error produced: {:?}.",
err
));
}
}
}
}

// write_buffer.take
let taken = self
.shared_buffers
.lock()
.map_err(|_| anyhow!("lock failed"))?
.write_buffer
.take();
match taken {
None => (),
Some(output) => {
// client_tls_sender.send
self.client_tls_sender
.send(output)
.map_err(|e| {
anyhow!(
"Failed to send data on TX channel. Error produced: {:?}.",
e
)
})?;

// client_tls_receiver.recv
let received = self.client_tls_receiver.recv()?;

// read_buffer.extend_from_slice
self.shared_buffers
.lock()
.map_err(|_| anyhow!("lock failed"))?
.read_buffer
.extend_from_slice(&received);
}
}

// connection.read_to_end
let mut received_buffer: Vec<u8> = Vec::new();
let res = connection.read_to_end(&mut received_buffer);
if received_buffer.len() > 0 {
return Ok(received_buffer);
}
match res {
Ok(_) => (),
Err(err) => {
if err.kind() == std::io::ErrorKind::WouldBlock {
()
} else {
return Err(anyhow!(
"Failed to read data to end. Error produced: {:?}.",
err
));
}
}
}
}

// If reach here, it means the server crashed.
Err(anyhow!("Terminate due to server crash"))
connection.write_all(&send_data)?;
const PREFLEN: usize = transport_protocol::LENGTH_PREFIX_SIZE;
let mut length_buffer = [0; PREFLEN];
connection.read_exact(&mut length_buffer)?;
let length = PREFLEN + u64::from_be_bytes(length_buffer) as usize;
let mut response = length_buffer.to_vec();
response.resize(length, 0);
connection.read_exact(&mut response[PREFLEN..length])?;
Ok(response)
}
}

Expand Down Expand Up @@ -1064,8 +822,8 @@ fn create_client_test_connection<P: AsRef<Path>, Q: AsRef<Path>>(
client_cert_filename: P,
client_key_filename: Q,
ciphersuite_str: &str,
shared_buffers: Arc<Mutex<Buffers>>,
) -> Result<mbedtls::ssl::Context<InsecureConnection>> {
session: VeracruzSession,
) -> Result<mbedtls::ssl::Context<VeracruzSession>> {
let client_cert = read_cert_file(client_cert_filename)?;

let client_priv_key = read_priv_key_file(client_key_filename)?;
Expand Down Expand Up @@ -1101,8 +859,7 @@ fn create_client_test_connection<P: AsRef<Path>, Q: AsRef<Path>>(
config.set_ca_list(Arc::new(root_store), None);
config.push_cert(Arc::new(client_cert), Arc::new(client_priv_key))?;
let mut ctx = mbedtls::ssl::Context::new(Arc::new(config));
let conn = InsecureConnection { shared_buffers };
let _ = ctx.establish(conn, None);
let _ = ctx.establish(session, None);
Ok(ctx)
}

Expand Down

0 comments on commit 1a8ddcd

Please sign in to comment.