Skip to content

Commit

Permalink
Add protocol version to message header (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
guilload authored Mar 14, 2024
1 parent 95b2edd commit d9049ab
Showing 1 changed file with 52 additions and 14 deletions.
66 changes: 52 additions & 14 deletions chitchat/src/message.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io::BufRead;

use anyhow::Context;
use anyhow::{bail, Context};

use crate::delta::Delta;
use crate::digest::Digest;
Expand All @@ -26,7 +26,26 @@ pub enum ChitchatMessage {
BadCluster,
}

#[derive(Copy, Clone)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
enum ProtocolVersion {
V0 = 0,
}

impl ProtocolVersion {
pub fn from_code(code: u8) -> Option<Self> {
match code {
0 => Some(Self::V0),
_ => None,
}
}

pub fn to_code(self) -> u8 {
self as u8
}
}

#[derive(Clone, Copy)]
#[repr(u8)]
enum MessageType {
Syn = 0,
Expand All @@ -45,13 +64,16 @@ impl MessageType {
_ => None,
}
}

pub fn to_code(self) -> u8 {
self as u8
}
}

impl Serializable for ChitchatMessage {
fn serialize(&self, buf: &mut Vec<u8>) {
ProtocolVersion::V0.to_code().serialize(buf);

match self {
ChitchatMessage::Syn { cluster_id, digest } => {
buf.push(MessageType::Syn.to_code());
Expand All @@ -74,7 +96,7 @@ impl Serializable for ChitchatMessage {
}

fn serialized_len(&self) -> usize {
match self {
1 + match self {
ChitchatMessage::Syn { cluster_id, digest } => {
1 + cluster_id.serialized_len() + digest.serialized_len()
}
Expand All @@ -89,13 +111,28 @@ impl Serializable for ChitchatMessage {

impl Deserializable for ChitchatMessage {
fn deserialize(buf: &mut &[u8]) -> anyhow::Result<Self> {
let code = buf
let protocol_version = buf
.first()
.copied()
.and_then(ProtocolVersion::from_code)
.context("invalid protocol version")?;

if protocol_version != ProtocolVersion::V0 {
bail!(
"unsupported protocol version `{}`",
protocol_version.to_code()
)
}
buf.consume(1);

let message_type = buf
.first()
.copied()
.and_then(MessageType::from_code)
.context("invalid message type")?;
buf.consume(1);
match code {

match message_type {
MessageType::Syn => {
let digest = Digest::deserialize(buf)?;
let cluster_id = String::deserialize(buf)?;
Expand Down Expand Up @@ -127,7 +164,7 @@ mod tests {
cluster_id: "cluster-a".to_string(),
digest: Digest::default(),
};
test_serdeser_aux(&syn, 14);
test_serdeser_aux(&syn, 15);
}
{
let mut digest = Digest::default();
Expand All @@ -138,7 +175,7 @@ mod tests {
cluster_id: "cluster-a".to_string(),
digest,
};
test_serdeser_aux(&syn, 65);
test_serdeser_aux(&syn, 66);
}
}

Expand All @@ -149,8 +186,8 @@ mod tests {
digest: Digest::default(),
delta: Delta::default(),
};
// 1 (message tag) + 2 (digest len) + 1 (delta end op)
test_serdeser_aux(&syn_ack, 4);
// 1 (protocol version) + 1 (message tag) + 2 (digest len) + 1 (delta end op)
test_serdeser_aux(&syn_ack, 5);
}
{
// 2 bytes.
Expand All @@ -173,8 +210,9 @@ mod tests {
delta.set_serialized_len(60);

let syn_ack = ChitchatMessage::SynAck { digest, delta };
// 1 bytes (syn ack message) + 45 bytes (digest) + 69 bytes (delta).
test_serdeser_aux(&syn_ack, 1 + 53 + 60);
// 1 byte (protocol version) + 1 byte (message tag) + 53 bytes (digest) + 60 bytes
// (delta).
test_serdeser_aux(&syn_ack, 1 + 1 + 53 + 60);
}
}

Expand All @@ -183,7 +221,7 @@ mod tests {
{
let delta = Delta::default();
let ack = ChitchatMessage::Ack { delta };
test_serdeser_aux(&ack, 2);
test_serdeser_aux(&ack, 3);
}
{
// 4 bytes.
Expand All @@ -195,12 +233,12 @@ mod tests {
delta.add_kv(&node, "key", "value", 0, true);
delta.set_serialized_len(60);
let ack = ChitchatMessage::Ack { delta };
test_serdeser_aux(&ack, 1 + 60);
test_serdeser_aux(&ack, 1 + 1 + 60);
}
}

#[test]
fn test_bad_cluster() {
test_serdeser_aux(&ChitchatMessage::BadCluster, 1);
test_serdeser_aux(&ChitchatMessage::BadCluster, 2);
}
}

0 comments on commit d9049ab

Please sign in to comment.