Skip to content

Commit

Permalink
#Rust Add setter for string fields
Browse files Browse the repository at this point in the history
  • Loading branch information
candysonya committed Sep 27, 2023
1 parent 5bdeb2f commit 7e5b92b
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 3 deletions.
1 change: 1 addition & 0 deletions rust/reflection/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ edition = "2021"
[dependencies]
flatbuffers = { path = "../flatbuffers"}
escape_string = "0.1.2"
stdint = "0.2.0"
269 changes: 266 additions & 3 deletions rust/reflection/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use reflection_generated::reflection::*;
use core::mem::size_of;
use escape_string::escape;
use std::fmt::Debug;
use stdint::uintmax_t;

/// Gets the root table for a trusted Flatbuffer.
///
Expand Down Expand Up @@ -272,6 +273,80 @@ pub fn set_field<'a, T: EndianScalar>(
true
}

/// Sets a string field to a new value. Returns false if the field doesn't point to a valid buffer location or is not of string type.
///
/// # Safety
///
/// The buffer should be valid and conform to the schema.
pub unsafe fn set_string<'a>(
buf: &mut Vec<u8>,
field_type: BaseType,
field_loc: Option<usize>,
v: String,
schema: &Schema,
root_loc: usize,
) -> bool {
if field_type != BaseType::String {
return false;
}
if field_loc.is_none() {
return v == "";
}

let field_loc = field_loc.unwrap();
if buf.len() < field_loc + get_type_size(BaseType::String) {
return false;
}
// Safe because the index was verified above.
let string_loc = unsafe { deref_uoffset(buf, field_loc) };

if buf.len() < string_loc + SIZE_UOFFSET {
return false;
}
// Safe because the index was verified above.
let len_old = unsafe { read_uoffset(buf, string_loc) };

if buf.len() < string_loc + SIZE_UOFFSET + len_old {
return false;
}

let len_new = v.len();
let delta = len_new as i32 - len_old as i32;
let mut bytes_to_insert = v.into_bytes();

if delta != 0 {
// Rounds the delta up to the nearest multiple of the maximum int size to keep the types after the insersion point aligned.
let mask = (size_of::<uintmax_t>() - 1) as i32;
let offset = (delta + mask) & !mask;
let visited_vec = &mut Vec::<bool>::new();
visited_vec.resize_with(buf.len(), Default::default);

if offset != 0 {
update_offset(
buf,
root_loc,
visited_vec,
&schema.root_table().unwrap(),
schema,
string_loc,
offset,
);
// Sets the new length.
emplace_scalar::<SOffsetT>(
&mut buf[string_loc..string_loc + SIZE_UOFFSET],
len_new as SOffsetT,
);
}

// Pads the bytes vector with 0 if `offset` doesn't equal `delta`.
bytes_to_insert.resize(bytes_to_insert.len() + (offset - delta) as usize, 0);
}

// Replaces the data.
buf.splice(string_loc + SIZE_SOFFSET..string_loc + SIZE_UOFFSET + len_old, bytes_to_insert);
true
}

/// Returns the size of a scalar type in the `BaseType` enum. In the case of structs, returns the size of their offset (`UOffsetT`) in the buffer.
fn get_type_size(base_type: BaseType) -> usize {
match base_type {
Expand Down Expand Up @@ -400,7 +475,7 @@ unsafe fn get_any_value_string(
/// # Safety
///
/// Caller must ensure `buf.len() >= loc + size_of::<T>()`.
pub unsafe fn set_any_value_integer(base_type: BaseType, buf: &mut [u8], v: i64) -> bool {
unsafe fn set_any_value_integer(base_type: BaseType, buf: &mut [u8], v: i64) -> bool {
match base_type {
BaseType::UType | BaseType::UByte => emplace_scalar::<u8>(buf, (v as u8).into()),
BaseType::Bool => emplace_scalar::<bool>(buf, v != 0),
Expand All @@ -423,7 +498,7 @@ pub unsafe fn set_any_value_integer(base_type: BaseType, buf: &mut [u8], v: i64)
/// # Safety
///
/// Caller must ensure `buf.len() >= loc + size_of::<T>()`.
pub unsafe fn set_any_value_float(base_type: BaseType, buf: &mut [u8], v: f64) -> bool {
unsafe fn set_any_value_float(base_type: BaseType, buf: &mut [u8], v: f64) -> bool {
match base_type {
BaseType::Float => emplace_scalar::<f32>(buf, (v as f32).into()),
BaseType::Double => emplace_scalar::<f64>(buf, v),
Expand All @@ -437,7 +512,7 @@ pub unsafe fn set_any_value_float(base_type: BaseType, buf: &mut [u8], v: f64) -
/// # Safety
///
/// Caller must ensure `buf.len() >= loc + size_of::<T>()`.
pub unsafe fn set_any_value_string(base_type: BaseType, buf: &mut [u8], v: &str) -> bool {
unsafe fn set_any_value_string(base_type: BaseType, buf: &mut [u8], v: &str) -> bool {
match base_type {
BaseType::Float | BaseType::Double => {
set_any_value_float(base_type, buf, v.parse::<f64>().unwrap_or_default())
Expand All @@ -449,3 +524,191 @@ pub unsafe fn set_any_value_string(base_type: BaseType, buf: &mut [u8], v: &str)
fn is_scalar(base_type: BaseType) -> bool {
return base_type <= BaseType::Double;
}

/// Iterates through the buffer and updates all the relative offsets affected by the insertion.
///
/// # Safety
///
/// Caller must ensure `buf.len() >= loc + size_of::<T>()` at any access layer.
unsafe fn update_offset(
buf: &mut Vec<u8>,
table_loc: usize,
updated: &mut Vec<bool>,
object: &Object,
schema: &Schema,
insertion_loc: usize,
offset: i32,
) {
if updated[table_loc] {
return;
}

let slice = &mut buf[table_loc..table_loc + SIZE_SOFFSET];
let vtable_offset = read_scalar::<SOffsetT>(slice);
let vtable_loc = (table_loc as i32 - vtable_offset) as usize;

if insertion_loc <= table_loc {
// Checks if insertion point is between the table and a vtable that
// precedes it.
if is_update(insertion_loc, vtable_loc, table_loc) {
emplace_scalar::<SOffsetT>(slice, vtable_offset + offset);
updated[table_loc] = true;
}

// Early out: since all fields inside the table must point forwards in
// memory, if the insertion point is before the table we can stop here.
return;
}

for field in object.fields() {
let field_type = field.type_().base_type();
if is_scalar(field_type) {
continue;
}

let field_offset = VOffsetT::follow(buf, vtable_loc + field.offset() as usize) as usize;
if field_offset == 0 {
continue;
}

let field_loc = table_loc + field_offset;
if updated[field_loc] {
continue;
}

if field_type == BaseType::Obj
&& schema.objects().get(field.type_().index() as usize).is_struct()
{
continue;
}

// Updates the relative offset from table to actual data if needed
let slice = &mut buf[field_loc..field_loc + SIZE_UOFFSET];
let field_value_offset = read_scalar::<UOffsetT>(slice);
let field_value_loc = field_loc + field_value_offset as usize;
if is_update(insertion_loc, field_loc, field_value_loc) {
emplace_scalar::<UOffsetT>(slice, (field_value_offset as i32 + offset) as UOffsetT);
updated[field_loc] = true;
}

match field_type {
BaseType::Obj => {
let field_obj = schema.objects().get(field.type_().index() as usize);
update_offset(
buf,
field_value_loc,
updated,
&field_obj,
schema,
insertion_loc,
offset,
);
}
BaseType::Vector => {
let elem_type = field.type_().element();
if elem_type != BaseType::Obj || elem_type != BaseType::String {
continue;
}
if elem_type == BaseType::Obj
&& schema.objects().get(field.type_().index() as usize).is_struct()
{
continue;
}
let vec_size = read_uoffset(buf, field_value_loc);
for index in 0..vec_size {
let elem_loc = field_value_loc + SIZE_UOFFSET + index * SIZE_UOFFSET;
if updated[elem_loc] {
continue;
}
let slice = &mut buf[elem_loc..elem_loc + SIZE_UOFFSET];
let elem_value_offset = read_scalar::<UOffsetT>(slice);
let elem_value_loc = elem_loc + elem_value_offset as usize;
if is_update(insertion_loc, elem_loc, elem_value_loc) {
emplace_scalar::<UOffsetT>(
slice,
(elem_value_offset as i32 + offset) as UOffsetT,
);
updated[elem_loc] = true;
}

if elem_type == BaseType::Obj {
let elem_obj = schema.objects().get(field.type_().index() as usize);
update_offset(
buf,
elem_value_loc,
updated,
&elem_obj,
schema,
insertion_loc,
offset,
);
}
}
}
BaseType::Union => {
let union_enum = schema.enums().get(field.type_().index() as usize);
let union_type = object
.fields()
.lookup_by_key(field.name().to_string() + "_type", |field, key| {
field.key_compare_with_value(key)
})
.unwrap();
let union_type_loc = vtable_loc + union_type.offset() as usize;
let union_type_offset = VOffsetT::follow(buf, union_type_loc) as usize;
let union_type_value = u8::follow(buf, table_loc + union_type_offset);
let union_enum_value = union_enum
.values()
.lookup_by_key(union_type_value as i64, |value, key| {
value.key_compare_with_value(*key)
})
.unwrap();
let union_object =
schema.objects().get(union_enum_value.union_type().unwrap().index() as usize);
update_offset(
buf,
field_value_loc,
updated,
&union_object,
schema,
insertion_loc,
offset,
);
}
_ => (),
}
}

// Checks if the vtable offset points beyond the insertion point.
if is_update(insertion_loc, table_loc, vtable_loc) {
let slice = &mut buf[table_loc..table_loc + SIZE_SOFFSET];
emplace_scalar::<SOffsetT>(slice, vtable_offset - offset);
updated[table_loc] = true;
}
}

/// Returns true if the insertion point is in between `left` (which contains the relative offset) and `right` (where the relative offset points to).
fn is_update(insertion_loc: usize, left: usize, right: usize) -> bool {
if insertion_loc >= left && insertion_loc < right {
return true;
}
false
}

/// Returns the absolute location of the data (e.g. string) in the buffer when the field contains relative offset (`UOffsetT`) to the data.
///
/// # Safety
///
/// The value of the corresponding slot must have type `UOffsetT`.
unsafe fn deref_uoffset<'a>(buf: &'a [u8], field_loc: usize) -> usize {
field_loc + read_uoffset(buf, field_loc)
}

/// Reads the value of `UOffsetT` at the give location.
///
/// # Safety
///
/// The value of the corresponding slot must have type `UOffsetT`.
unsafe fn read_uoffset<'a>(buf: &'a [u8], loc: usize) -> usize {
let slice = &buf[loc..loc + SIZE_UOFFSET];
read_scalar::<UOffsetT>(slice) as usize
}
22 changes: 22 additions & 0 deletions tests/rust_reflection_test/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,28 @@ fn test_schema() {

// Resets it, for further tests.
assert_eq!(set_field::<u16>(flatbuf, hp_field.type_().base_type(), hp_field_loc, 80), true);
let root_table = unsafe { get_any_root(&mutable_buf) };

// Resets string field with a longer string
let name_field =
fields.lookup_by_key("name", |field, key| field.key_compare_with_value(key)).unwrap();
let name_field_address = get_field_loc(&root_table, &name_field);
assert_eq!(
unsafe {
set_string(
mutable_buf,
name_field.type_().base_type(),
name_field_address,
String::from("totally new string"),
&schema,
root_table.loc(),
)
},
true
);
let root_table = unsafe { get_any_root(&mutable_buf) };
let name_field_value = unsafe { get_field_string(&root_table, &name_field) }.unwrap();
assert_eq!(name_field_value, "totally new string");
}

// Builds up a serialized buffer algorithmically.
Expand Down

0 comments on commit 7e5b92b

Please sign in to comment.