From 6255fd51fd26b4611a92f0d6708f523af7b4ec1a Mon Sep 17 00:00:00 2001 From: Millione Date: Thu, 11 Apr 2024 11:14:27 +0800 Subject: [PATCH] feat: add `&mut [u8]` implementation for `TOutputProtocol` --- pilota/src/thrift/binary.rs | 183 +++++++++++++++++++++ pilota/src/thrift/binary_unsafe.rs | 245 +++++++++++++++++++++++++++++ pilota/src/thrift/rw_ext.rs | 7 +- 3 files changed, 433 insertions(+), 2 deletions(-) diff --git a/pilota/src/thrift/binary.rs b/pilota/src/thrift/binary.rs index 86746581..f49cd848 100644 --- a/pilota/src/thrift/binary.rs +++ b/pilota/src/thrift/binary.rs @@ -559,6 +559,189 @@ impl TOutputProtocol for TBinaryProtocol<&mut LinkedBytes> { } } +impl<'a> TOutputProtocol for TBinaryProtocol<&'a mut [u8]> { + type BufMut = &'a mut [u8]; + + #[inline] + fn write_message_begin( + &mut self, + identifier: &TMessageIdentifier, + ) -> Result<(), ThriftException> { + let msg_type_u8: u8 = identifier.message_type.into(); + let version = (VERSION_1 | msg_type_u8 as u32) as i32; + self.write_i32(version)?; + self.write_faststr(identifier.name.clone())?; + self.write_i32(identifier.sequence_number)?; + Ok(()) + } + + #[inline] + fn write_message_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_struct_begin(&mut self, _: &TStructIdentifier) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_struct_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_field_begin(&mut self, field_type: TType, id: i16) -> Result<(), ThriftException> { + let mut data: [u8; 3] = [0; 3]; + data[0] = field_type as u8; + let id = id.to_be_bytes(); + data[1] = id[0]; + data[2] = id[1]; + self.trans.write_slice(&data); + Ok(()) + } + + #[inline] + fn write_field_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_field_stop(&mut self) -> Result<(), ThriftException> { + self.write_byte(TType::Stop as u8) + } + + #[inline] + fn write_bool(&mut self, b: bool) -> Result<(), ThriftException> { + if b { + self.write_i8(1) + } else { + self.write_i8(0) + } + } + + #[inline] + fn write_bytes(&mut self, b: Bytes) -> Result<(), ThriftException> { + self.write_i32(b.len() as i32)?; + self.write_bytes_without_len(b) + } + + #[inline] + fn write_bytes_without_len(&mut self, b: Bytes) -> Result<(), ThriftException> { + self.trans.write_slice(&b); + Ok(()) + } + + #[inline] + fn write_byte(&mut self, b: u8) -> Result<(), ThriftException> { + self.trans.write_u8(b); + Ok(()) + } + + #[inline] + fn write_uuid(&mut self, u: [u8; 16]) -> Result<(), ThriftException> { + self.trans.write_slice(&u); + Ok(()) + } + + #[inline] + fn write_i8(&mut self, i: i8) -> Result<(), ThriftException> { + self.trans.write_i8(i); + Ok(()) + } + + #[inline] + fn write_i16(&mut self, i: i16) -> Result<(), ThriftException> { + self.trans.write_i16(i); + Ok(()) + } + + #[inline] + fn write_i32(&mut self, i: i32) -> Result<(), ThriftException> { + self.trans.write_i32(i); + Ok(()) + } + + #[inline] + fn write_i64(&mut self, i: i64) -> Result<(), ThriftException> { + self.trans.write_i64(i); + Ok(()) + } + + #[inline] + fn write_double(&mut self, d: f64) -> Result<(), ThriftException> { + self.trans.write_f64(d); + Ok(()) + } + + #[inline] + fn write_string(&mut self, s: &str) -> Result<(), ThriftException> { + self.write_i32(s.len() as i32)?; + self.trans.write_slice(s.as_bytes()); + Ok(()) + } + + #[inline] + fn write_faststr(&mut self, s: FastStr) -> Result<(), ThriftException> { + self.write_i32(s.len() as i32)?; + self.trans.write_slice(s.as_ref()); + Ok(()) + } + + #[inline] + fn write_list_begin(&mut self, identifier: TListIdentifier) -> Result<(), ThriftException> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_list_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_set_begin(&mut self, identifier: TSetIdentifier) -> Result<(), ThriftException> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_set_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_map_begin(&mut self, identifier: TMapIdentifier) -> Result<(), ThriftException> { + let key_type = identifier.key_type; + self.write_byte(key_type.into())?; + let val_type = identifier.value_type; + self.write_byte(val_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_map_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn flush(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), ThriftException> { + self.write_i32(b.len() as i32)?; + self.trans.write_slice(b); + Ok(()) + } + + #[inline] + fn buf_mut(&mut self) -> &mut Self::BufMut { + &mut self.trans + } +} + impl TInputProtocol for TBinaryProtocol<&mut Bytes> { type Buf = Bytes; diff --git a/pilota/src/thrift/binary_unsafe.rs b/pilota/src/thrift/binary_unsafe.rs index 87b5d87b..40971e69 100644 --- a/pilota/src/thrift/binary_unsafe.rs +++ b/pilota/src/thrift/binary_unsafe.rs @@ -730,6 +730,251 @@ impl TOutputProtocol for TBinaryUnsafeOutputProtocol<&mut LinkedBytes> { } } +impl<'a> TOutputProtocol for TBinaryUnsafeOutputProtocol<&'a mut [u8]> { + type BufMut = &'a mut [u8]; + + #[inline] + fn write_message_begin( + &mut self, + identifier: &TMessageIdentifier, + ) -> Result<(), ThriftException> { + let msg_type_u8: u8 = identifier.message_type.into(); + let version = (VERSION_1 | msg_type_u8 as u32) as i32; + self.write_i32(version)?; + self.write_faststr(identifier.name.clone())?; + self.write_i32(identifier.sequence_number)?; + Ok(()) + } + + #[inline] + fn write_message_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_struct_begin(&mut self, _: &TStructIdentifier) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_struct_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_field_begin(&mut self, field_type: TType, id: i16) -> Result<(), ThriftException> { + unsafe { + *self.trans.get_unchecked_mut(self.index) = field_type as u8; + let buf: &mut [u8; 2] = self + .trans + .get_unchecked_mut(self.index + 1..self.index + 3) + .try_into() + .unwrap_unchecked(); + *buf = id.to_be_bytes(); + self.index += 3; + } + Ok(()) + } + + #[inline] + fn write_field_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_field_stop(&mut self) -> Result<(), ThriftException> { + self.write_byte(TType::Stop as u8) + } + + #[inline] + fn write_bool(&mut self, b: bool) -> Result<(), ThriftException> { + if b { + self.write_i8(1) + } else { + self.write_i8(0) + } + } + + #[inline] + fn write_bytes(&mut self, b: Bytes) -> Result<(), ThriftException> { + self.write_i32(b.len() as i32)?; + self.write_bytes_without_len(b) + } + + #[inline] + fn write_bytes_without_len(&mut self, b: Bytes) -> Result<(), ThriftException> { + unsafe { + ptr::copy_nonoverlapping(b.as_ptr(), self.trans.as_mut_ptr().add(self.index), b.len()); + self.index += b.len(); + } + Ok(()) + } + + #[inline] + fn write_byte(&mut self, b: u8) -> Result<(), ThriftException> { + unsafe { + *self.trans.get_unchecked_mut(self.index) = b; + self.index += 1; + } + Ok(()) + } + + #[inline] + fn write_uuid(&mut self, u: [u8; 16]) -> Result<(), ThriftException> { + unsafe { + let buf: &mut [u8; 16] = self + .trans + .get_unchecked_mut(self.index..self.index + 16) + .try_into() + .unwrap_unchecked(); + *buf = u; + self.index += 16; + } + Ok(()) + } + + #[inline] + fn write_i8(&mut self, i: i8) -> Result<(), ThriftException> { + unsafe { + *self.trans.get_unchecked_mut(self.index) = *i.to_be_bytes().get_unchecked(0); + self.index += 1; + } + Ok(()) + } + + #[inline] + fn write_i16(&mut self, i: i16) -> Result<(), ThriftException> { + unsafe { + let buf: &mut [u8; 2] = self + .trans + .get_unchecked_mut(self.index..self.index + 2) + .try_into() + .unwrap_unchecked(); + *buf = i.to_be_bytes(); + self.index += 2; + } + Ok(()) + } + + #[inline] + fn write_i32(&mut self, i: i32) -> Result<(), ThriftException> { + unsafe { + let buf: &mut [u8; 4] = self + .trans + .get_unchecked_mut(self.index..self.index + 4) + .try_into() + .unwrap_unchecked(); + *buf = i.to_be_bytes(); + self.index += 4; + } + Ok(()) + } + + #[inline] + fn write_i64(&mut self, i: i64) -> Result<(), ThriftException> { + unsafe { + let buf: &mut [u8; 8] = self + .trans + .get_unchecked_mut(self.index..self.index + 8) + .try_into() + .unwrap_unchecked(); + *buf = i.to_be_bytes(); + self.index += 8; + } + Ok(()) + } + + #[inline] + fn write_double(&mut self, d: f64) -> Result<(), ThriftException> { + unsafe { + let buf: &mut [u8; 8] = self + .trans + .get_unchecked_mut(self.index..self.index + 8) + .try_into() + .unwrap_unchecked(); + *buf = d.to_bits().to_be_bytes(); + self.index += 8; + } + Ok(()) + } + + #[inline] + fn write_string(&mut self, s: &str) -> Result<(), ThriftException> { + self.write_i32(s.len() as i32)?; + unsafe { + ptr::copy_nonoverlapping(s.as_ptr(), self.trans.as_mut_ptr().add(self.index), s.len()); + self.index += s.len(); + } + Ok(()) + } + + #[inline] + fn write_faststr(&mut self, s: FastStr) -> Result<(), ThriftException> { + self.write_i32(s.len() as i32)?; + unsafe { + ptr::copy_nonoverlapping(s.as_ptr(), self.trans.as_mut_ptr().add(self.index), s.len()); + self.index += s.len(); + } + Ok(()) + } + + #[inline] + fn write_list_begin(&mut self, identifier: TListIdentifier) -> Result<(), ThriftException> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_list_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_set_begin(&mut self, identifier: TSetIdentifier) -> Result<(), ThriftException> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_set_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_map_begin(&mut self, identifier: TMapIdentifier) -> Result<(), ThriftException> { + let key_type = identifier.key_type; + self.write_byte(key_type.into())?; + let val_type = identifier.value_type; + self.write_byte(val_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_map_end(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn flush(&mut self) -> Result<(), ThriftException> { + Ok(()) + } + + #[inline] + fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), ThriftException> { + self.write_i32(b.len() as i32)?; + unsafe { + ptr::copy_nonoverlapping(b.as_ptr(), self.trans.as_mut_ptr().add(self.index), b.len()); + self.index += b.len(); + } + Ok(()) + } + + #[inline] + fn buf_mut(&mut self) -> &mut Self::BufMut { + &mut self.trans + } +} + pub struct TBinaryUnsafeInputProtocol<'a> { pub(crate) trans: &'a mut Bytes, pub(crate) buf: &'a [u8], diff --git a/pilota/src/thrift/rw_ext.rs b/pilota/src/thrift/rw_ext.rs index a0bdc8e7..c7e9e059 100644 --- a/pilota/src/thrift/rw_ext.rs +++ b/pilota/src/thrift/rw_ext.rs @@ -1,6 +1,6 @@ use std::mem; -use bytes::{Buf as _, BufMut, BytesMut}; +use bytes::{Buf as _, BufMut}; use super::{new_protocol_exception, ThriftException}; @@ -128,7 +128,10 @@ pub trait WriteExt { fn write_f64_le(&mut self, n: f64); } -impl WriteExt for BytesMut { +impl WriteExt for B +where + B: BufMut, +{ #[inline] fn write_slice(&mut self, src: &[u8]) { self.put_slice(src);