Skip to content

Commit

Permalink
derive as feat only
Browse files Browse the repository at this point in the history
  • Loading branch information
stelzo committed May 7, 2024
1 parent 0e705ac commit 2c637c4
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 94 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ r2r_msg = ["dep:r2r"]
rayon = ["dep:rayon"]
derive = ["dep:rpcl2_derive"]

default = ["derive"]
default = ["derive"]
2 changes: 2 additions & 0 deletions benches/roundtrip.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![cfg(feature = "derive")]

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use ros_pointcloud2::{pcl_utils::PointXYZ, PointCloud2Msg};

Expand Down
103 changes: 55 additions & 48 deletions examples/custom_label_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// The use case is a segmentation point cloud where each point holds a label
// with a custom enum type we want to filter.

use ros_pointcloud2::{Fields, Point, PointCloud2Msg, PointConvertible};
use ros_pointcloud2::{Fields, Point};

#[cfg(not(feature = "derive"))]
use ros_pointcloud2::{PointCloud2Msg, PointConvertible};

#[derive(Debug, PartialEq, Clone, Default)]
enum Label {
Expand Down Expand Up @@ -83,54 +86,58 @@ impl Fields<5> for CustomPoint {
}

// We implemented everything that is needed so we declare it as a PointConvertible
#[cfg(not(feature = "derive"))]
impl PointConvertible<5> for CustomPoint {}

fn main() {
let cloud = vec![
CustomPoint {
x: 1.0,
y: 2.0,
z: 3.0,
intensity: 4.0,
my_custom_label: Label::Deer,
},
CustomPoint {
x: 4.0,
y: 5.0,
z: 6.0,
intensity: 7.0,
my_custom_label: Label::Car,
},
CustomPoint {
x: 7.0,
y: 8.0,
z: 9.0,
intensity: 10.0,
my_custom_label: Label::Human,
},
];

println!("Original cloud: {:?}", cloud);

let msg = PointCloud2Msg::try_from_iter(cloud.clone().into_iter()).unwrap();

println!("filtering by label == Deer");
let out = msg
.try_into_iter()
.unwrap()
.filter(|point: &CustomPoint| point.my_custom_label == Label::Deer)
.collect::<Vec<_>>();

println!("Filtered cloud: {:?}", out);

assert_eq!(
vec![CustomPoint {
x: 1.0,
y: 2.0,
z: 3.0,
intensity: 4.0,
my_custom_label: Label::Deer,
},],
out
);
#[cfg(not(feature = "derive"))]
{
let cloud = vec![
CustomPoint {
x: 1.0,
y: 2.0,
z: 3.0,
intensity: 4.0,
my_custom_label: Label::Deer,
},
CustomPoint {
x: 4.0,
y: 5.0,
z: 6.0,
intensity: 7.0,
my_custom_label: Label::Car,
},
CustomPoint {
x: 7.0,
y: 8.0,
z: 9.0,
intensity: 10.0,
my_custom_label: Label::Human,
},
];

println!("Original cloud: {:?}", cloud);

let msg = PointCloud2Msg::try_from_iter(cloud.clone().into_iter()).unwrap();

println!("filtering by label == Deer");
let out = msg
.try_into_iter()
.unwrap()
.filter(|point: &CustomPoint| point.my_custom_label == Label::Deer)
.collect::<Vec<_>>();

println!("Filtered cloud: {:?}", out);

assert_eq!(
vec![CustomPoint {
x: 1.0,
y: 2.0,
z: 3.0,
intensity: 4.0,
my_custom_label: Label::Deer,
},],
out
);
}
}
35 changes: 16 additions & 19 deletions rpcl2_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
extern crate proc_macro;

use std::{any::Any, collections::HashMap, fmt::Debug};
use std::collections::HashMap;

use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{parse_macro_input, DeriveInput};

fn get_allowed_types() -> HashMap<&'static str, usize> {
let mut allowed_datatypes = HashMap::<&'static str, usize>::new();
allowed_datatypes.insert("f32", 4);
allowed_datatypes.insert("f64", 8);
allowed_datatypes.insert("i32", 4);
allowed_datatypes.insert("u8", 1);
allowed_datatypes.insert("u16", 2);
allowed_datatypes.insert("u32", 4);
allowed_datatypes.insert("i8", 1);
allowed_datatypes.insert("i16", 2);
allowed_datatypes
}

/// Derive macro for the `Fields` trait.
///
/// Given the ordering from the source code of your struct, this macro will generate an array of field names.
Expand All @@ -19,15 +32,7 @@ pub fn ros_point_fields_derive(input: TokenStream) -> TokenStream {
_ => return syn::Error::new_spanned(input, "Only structs are supported").to_compile_error().into(),
};

let mut allowed_datatypes = HashMap::<&'static str, usize>::new();
allowed_datatypes.insert("f32", 4);
allowed_datatypes.insert("f64", 8);
allowed_datatypes.insert("i32", 4);
allowed_datatypes.insert("u8", 1);
allowed_datatypes.insert("u16", 2);
allowed_datatypes.insert("u32", 4);
allowed_datatypes.insert("i8", 1);
allowed_datatypes.insert("i16", 2);
let allowed_datatypes = get_allowed_types();

if fields.is_empty() {
return syn::Error::new_spanned(input, "No fields found").to_compile_error().into();
Expand Down Expand Up @@ -75,15 +80,7 @@ pub fn ros_point_derive(input: TokenStream) -> TokenStream {
_ => return syn::Error::new_spanned(input, "Only structs are supported").to_compile_error().into(),
};

let mut allowed_datatypes = HashMap::<&'static str, usize>::new();
allowed_datatypes.insert("f32", 4);
allowed_datatypes.insert("f64", 8);
allowed_datatypes.insert("i32", 4);
allowed_datatypes.insert("u8", 1);
allowed_datatypes.insert("u16", 2);
allowed_datatypes.insert("u32", 4);
allowed_datatypes.insert("i8", 1);
allowed_datatypes.insert("i16", 2);
let allowed_datatypes = get_allowed_types();

if fields.is_empty() {
return syn::Error::new_spanned(input, "No fields found").to_compile_error().into();
Expand Down
140 changes: 119 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ use crate::ros_types::{HeaderMsg, PointFieldMsg};
use convert::Endianness;
pub use convert::Fields;

use type_layout::TypeLayout;

#[cfg(feature = "derive")]
pub use rpcl2_derive::*;

#[cfg(feature = "derive")]
pub use type_layout::TypeLayout;

/// All errors that can occur while converting to or from the PointCloud2 message.
#[derive(Debug)]
pub enum ConversionError {
Expand Down Expand Up @@ -123,6 +124,7 @@ impl Default for PointCloud2Msg {
}

impl PointCloud2Msg {
#[cfg(feature = "derive")]
fn prepare_direct_copy<const N: usize, C>() -> Result<Self, ConversionError>
where
C: PointConvertible<N>,
Expand All @@ -134,7 +136,7 @@ impl PointCloud2Msg {
debug_assert!(meta_names.len() == N);

let mut offset: u32 = 0;
let layout = C::layout();
let layout = TypeLayoutInfo::try_from(C::type_layout())?;
let mut fields: Vec<PointFieldMsg> = Vec::with_capacity(layout.fields.len());
for f in layout.fields.into_iter() {
match f {
Expand Down Expand Up @@ -164,6 +166,53 @@ impl PointCloud2Msg {
})
}

#[cfg(feature = "derive")]
fn assert_byte_similarity<const N: usize, C>(&self) -> Result<bool, ConversionError>
where
C: PointConvertible<N>,
{
let point: Point<N> = C::default().into();
debug_assert!(point.fields.len() == N);

let meta_names = C::field_names_ordered();
debug_assert!(meta_names.len() == N);

let mut offset: u32 = 0;
let layout = TypeLayoutInfo::try_from(C::type_layout())?;
for (f, msg_f) in layout.fields.into_iter().zip(self.fields.iter()) {
match f {
PointField::Field {
name,
datatype,
size,
} => {
if msg_f.name != name {
return Err(ConversionError::FieldNotFound(vec![name.clone()]));
}

if msg_f.datatype != datatype {
return Err(ConversionError::InvalidFieldFormat);
}

if msg_f.offset != offset {
return Err(ConversionError::DataLengthMismatch);
}

if msg_f.count != 1 {
return Err(ConversionError::UnsupportedFieldType);
}

offset += size; // assume field_count 1
}
PointField::Padding(size) => {
offset += size; // assume field_count 1
}
}
}

Ok(true)
}

#[inline(always)]
fn prepare<const N: usize, C>() -> Result<Self, ConversionError>
where
Expand Down Expand Up @@ -250,27 +299,73 @@ impl PointCloud2Msg {
Ok(cloud)
}

#[cfg(feature = "derive")]
pub fn try_from_vec<const N: usize, C>(vec: Vec<C>) -> Result<Self, ConversionError>
where
C: PointConvertible<N>,
{
let mut cloud = Self::prepare_direct_copy::<N, C>()?;
let endianness = if cfg!(target_endian = "big") {
Endianness::Big
} else if cfg!(target_endian = "little") {
Endianness::Little
} else {
panic!("Unsupported endianness");
};

match endianness {
Endianness::Big => Self::try_from_iter(vec.into_iter()),
Endianness::Little => {
let mut cloud = Self::prepare_direct_copy::<N, C>()?;

let bytes_total = vec.len() * cloud.point_step as usize;
cloud.data.resize(bytes_total, u8::default());
let raw_data: *mut C = cloud.data.as_ptr() as *mut C;
unsafe {
std::ptr::copy_nonoverlapping(
vec.as_ptr() as *const u8,
raw_data as *mut u8,
bytes_total,
);
}

let bytes_total = vec.len() * cloud.point_step as usize;
cloud.data.resize(bytes_total, u8::default());
let raw_data: *mut C = cloud.data.as_ptr() as *mut C;
unsafe {
std::ptr::copy_nonoverlapping(
vec.as_ptr() as *const u8,
raw_data as *mut u8,
bytes_total,
);
cloud.width = vec.len() as u32;
cloud.row_step = cloud.width * cloud.point_step;

Ok(cloud)
}
}
}

cloud.width = vec.len() as u32;
cloud.row_step = cloud.width * cloud.point_step;
#[cfg(feature = "derive")]
pub fn try_into_vec<const N: usize, C>(self) -> Result<Vec<C>, ConversionError>
where
C: PointConvertible<N>,
{
let endianness = if cfg!(target_endian = "big") {
Endianness::Big
} else if cfg!(target_endian = "little") {
Endianness::Little
} else {
panic!("Unsupported endianness");
};

self.assert_byte_similarity::<N, C>()?;

match endianness {
Endianness::Big => Ok(self.try_into_iter()?.collect()),
Endianness::Little => {
let mut vec = Vec::with_capacity(self.width as usize);
let raw_data: *const C = self.data.as_ptr() as *const C;
unsafe {
for i in 0..self.width {
let point = raw_data.add(i as usize).read();
vec.push(point);
}
}

Ok(cloud)
Ok(vec)
}
}
}

pub fn try_into_iter<const N: usize, C>(
Expand Down Expand Up @@ -347,8 +442,15 @@ pub struct Point<const N: usize> {
///
/// impl PointConvertible<f32, {size_of!(f32)}, 3, 1> for MyPointXYZI {}
/// ```
#[cfg(not(feature = "derive"))]
pub trait PointConvertible<const N: usize>:
KnownLayout + From<Point<N>> + Into<Point<N>> + Fields<N> + Clone + 'static + Default
From<Point<N>> + Into<Point<N>> + Fields<N> + Clone + 'static + Default
{
}

#[cfg(feature = "derive")]
pub trait PointConvertible<const N: usize>:
type_layout::TypeLayout + From<Point<N>> + Into<Point<N>> + Fields<N> + Clone + 'static + Default
{
}

Expand Down Expand Up @@ -397,10 +499,6 @@ impl TryFrom<type_layout::TypeLayoutInfo> for TypeLayoutInfo {
}
}

trait KnownLayout {
fn layout() -> TypeLayoutInfo;
}

/// Metadata representation for a point.
///
/// This struct is used to store meta data in a fixed size byte buffer along the with the
Expand Down
Loading

0 comments on commit 2c637c4

Please sign in to comment.