Skip to content

Commit

Permalink
Merge pull request #41 from trussed-dev/main
Browse files Browse the repository at this point in the history
Merge upstream
  • Loading branch information
sosthene-nitrokey authored Jun 7, 2024
2 parents 720006d + a055e4f commit 40e3128
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 24 deletions.
31 changes: 30 additions & 1 deletion derive/examples/extension-dispatch.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use trussed::Error;

mod backends {
use super::extensions::{TestExtension, TestReply, TestRequest};
use super::extensions::{
SampleExtension, SampleReply, SampleRequest, TestExtension, TestReply, TestRequest,
};
use trussed::{
backend::Backend, platform::Platform, serde_extensions::ExtensionImpl,
service::ServiceResources, types::CoreContext, Error,
Expand All @@ -26,6 +28,18 @@ mod backends {
}
}

impl ExtensionImpl<SampleExtension> for ABackend {
fn extension_request<P: Platform>(
&mut self,
_core_ctx: &mut CoreContext,
_backend_ctx: &mut Self::Context,
_request: &SampleRequest,
_resources: &mut ServiceResources<P>,
) -> Result<SampleReply, Error> {
Ok(SampleReply)
}
}

#[derive(Default)]
pub struct BBackend;

Expand Down Expand Up @@ -88,6 +102,7 @@ mod extensions {

enum Backend {
A,
ASample,
B,
}

Expand All @@ -104,9 +119,19 @@ enum Extension {
Sample = "extensions::SampleExtension"
)]
struct Dispatch {
#[dispatch(no_core)]
#[extensions("Test")]
a: backends::ABackend,

#[dispatch(delegate_to = "a")]
#[extensions("Sample")]
a_sample: (),

b: backends::BBackend,

#[allow(unused)]
#[dispatch(skip)]
other: String,
}

fn main() {
Expand Down Expand Up @@ -135,5 +160,9 @@ fn main() {
&[BackendId::Custom(Backend::B)],
Some(Error::RequestNotAvailable),
);
run(
&[BackendId::Custom(Backend::ASample)],
Some(Error::RequestNotAvailable),
);
run(&[BackendId::Custom(Backend::A)], None);
}
228 changes: 207 additions & 21 deletions derive/src/extension_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct ExtensionDispatch {
dispatch_attrs: DispatchAttrs,
extension_attrs: ExtensionAttrs,
backends: Vec<Backend>,
delegated_backends: Vec<DelegatedBackend>,
}

impl ExtensionDispatch {
Expand All @@ -27,18 +28,38 @@ impl ExtensionDispatch {
};
let dispatch_attrs = DispatchAttrs::new(&input)?;
let extension_attrs = ExtensionAttrs::new(&input)?;
let backends = data_struct
.fields
.iter()
.enumerate()
.map(|(i, field)| Backend::new(i, field, &extension_attrs.extensions))
let mut raw_backends = Vec::new();
for field in &data_struct.fields {
if let Some(raw_backend) = RawBackend::new(field)? {
raw_backends.push(raw_backend);
}
}
let mut backends = Vec::new();
let mut delegated_backends = Vec::new();
for raw_backend in raw_backends {
if let Some(delegate_to) = raw_backend.delegate_to.clone() {
delegated_backends.push((raw_backend, delegate_to));
} else {
backends.push(Backend::new(
backends.len(),
raw_backend,
&extension_attrs.extensions,
)?);
}
}
let delegated_backends = delegated_backends
.into_iter()
.map(|(raw, delegate_to)| {
DelegatedBackend::new(raw, delegate_to, &backends, &extension_attrs.extensions)
})
.collect::<Result<_>>()?;
Ok(Self {
name: input.ident,
generics: input.generics,
dispatch_attrs,
extension_attrs,
backends,
delegated_backends,
})
}

Expand All @@ -49,7 +70,15 @@ impl ExtensionDispatch {
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
let context = self.backends.iter().map(Backend::context);
let requests = self.backends.iter().map(Backend::request);
let delegated_requests = self
.delegated_backends
.iter()
.map(DelegatedBackend::request);
let extension_requests = self.backends.iter().map(Backend::extension_request);
let delegated_extension_requests = self
.delegated_backends
.iter()
.map(DelegatedBackend::extension_request);
let extension_impls = self
.extension_attrs
.extensions
Expand All @@ -71,6 +100,7 @@ impl ExtensionDispatch {
) -> ::core::result::Result<::trussed::api::Reply, ::trussed::error::Error> {
match backend {
#(#requests)*
#(#delegated_requests)*
}
}

Expand All @@ -84,6 +114,7 @@ impl ExtensionDispatch {
) -> ::core::result::Result<::trussed::api::reply::SerdeExtension, ::trussed::error::Error> {
match backend {
#(#extension_requests)*
#(#delegated_extension_requests)*
}
}
}
Expand Down Expand Up @@ -165,16 +196,40 @@ impl ExtensionAttrs {
}
}

struct Backend {
struct RawBackend {
id: Ident,
field: Ident,
ty: Type,
index: Index,
extensions: Vec<Extension>,
no_core: bool,
delegate_to: Option<Ident>,
extensions: Vec<Ident>,
}

impl Backend {
fn new(i: usize, field: &Field, extension_types: &HashMap<Ident, Path>) -> Result<Self> {
impl RawBackend {
fn new(field: &Field) -> Result<Option<Self>> {
let mut delegate_to = None;
let mut no_core = false;
let mut skip = false;
for attr in util::get_attrs(&field.attrs, "dispatch") {
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("delegate_to") {
let s: LitStr = meta.value()?.parse()?;
delegate_to = Some(s.parse()?);
Ok(())
} else if meta.path.is_ident("no_core") {
no_core = true;
Ok(())
} else if meta.path.is_ident("skip") {
skip = true;
Ok(())
} else {
Err(meta.error("unsupported dispatch attribute"))
}
})?;
}
if skip {
return Ok(None);
}
let ident = field.ident.clone().ok_or_else(|| {
Error::new_spanned(
field,
Expand All @@ -184,14 +239,43 @@ impl Backend {
let mut extensions = Vec::new();
for attr in util::get_attrs(&field.attrs, "extensions") {
for s in attr.parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)? {
extensions.push(Extension::new(&s, extension_types)?);
extensions.push(s.parse()?);
}
}
Ok(Self {
Ok(Some(Self {
id: util::to_camelcase(&ident),
field: ident,
ty: field.ty.clone(),
no_core,
delegate_to,
extensions,
}))
}
}

#[derive(Clone)]
struct Backend {
id: Ident,
field: Ident,
ty: Type,
index: Index,
no_core: bool,
extensions: Vec<Extension>,
}

impl Backend {
fn new(i: usize, raw: RawBackend, extensions: &HashMap<Ident, Path>) -> Result<Self> {
let extensions = raw
.extensions
.into_iter()
.map(|i| Extension::new(i, extensions))
.collect::<Result<_>>()?;
Ok(Self {
id: raw.id,
field: raw.field,
ty: raw.ty,
index: Index::from(i),
no_core: raw.no_core,
extensions,
})
}
Expand All @@ -202,13 +286,23 @@ impl Backend {
}

fn request(&self) -> TokenStream {
let Self {
index, id, field, ..
} = self;
let id = &self.id;
let request = if self.no_core {
quote! {
Err(::trussed::Error::RequestNotAvailable)
}
} else {
let Self { index, field, .. } = self;
quote! {
::trussed::backend::Backend::request(
&mut self.#field, &mut ctx.core, &mut ctx.backends.#index, request, resources,
)
}
};
quote! {
Self::BackendId::#id => ::trussed::backend::Backend::request(
&mut self.#field, &mut ctx.core, &mut ctx.backends.#index, request, resources,
),
Self::BackendId::#id => {
#request
}
}
}

Expand All @@ -224,17 +318,109 @@ impl Backend {
}
}

struct DelegatedBackend {
id: Ident,
field: Ident,
backend: Backend,
no_core: bool,
extensions: Vec<Extension>,
}

impl DelegatedBackend {
fn new(
raw: RawBackend,
delegate_to: Ident,
backends: &[Backend],
extensions: &HashMap<Ident, Path>,
) -> Result<Self> {
match raw.ty {
Type::Tuple(tuple) if tuple.elems.is_empty() => (),
_ => {
return Err(Error::new_spanned(
&raw.ty,
"delegated backends must use the unit type ()",
));
}
}

let extensions = raw
.extensions
.into_iter()
.map(|i| Extension::new(i, extensions))
.collect::<Result<_>>()?;
let backend = backends
.iter()
.find(|backend| backend.field == delegate_to)
.ok_or_else(|| Error::new_spanned(delegate_to, "unknown backend"))?
.clone();
Ok(Self {
id: raw.id,
field: raw.field,
backend,
no_core: raw.no_core,
extensions,
})
}

fn request(&self) -> TokenStream {
let id = &self.id;
let request = if self.no_core {
quote! {
Err(::trussed::Error::RequestNotAvailable)
}
} else {
let Self { backend, field, .. } = self;
let Backend {
field: delegated_field,
index: delegated_index,
..
} = backend;
quote! {
let _ = self.#field;
::trussed::backend::Backend::request(
&mut self.#delegated_field, &mut ctx.core, &mut ctx.backends.#delegated_index, request, resources,
)
}
};
quote! {
Self::BackendId::#id => {
#request
}
}
}

fn extension_request(&self) -> TokenStream {
let Self {
id,
extensions,
backend,
field,
..
} = self;
let extension_requests = extensions.iter().map(|e| e.extension_request(backend));
quote! {
Self::BackendId::#id => {
let _ = self.#field;
match extension {
#(#extension_requests)*
_ => Err(::trussed::error::Error::RequestNotAvailable),
}
}
}
}
}

#[derive(Clone)]
struct Extension {
id: Ident,
ty: Path,
}

impl Extension {
fn new(s: &LitStr, extensions: &HashMap<Ident, Path>) -> Result<Self> {
let id = s.parse()?;
fn new(id: Ident, extensions: &HashMap<Ident, Path>) -> Result<Self> {
let ty = extensions
.get(&id)
.ok_or_else(|| Error::new_spanned(s, "unknown extension ID"))?
.ok_or_else(|| Error::new_spanned(&id, "unknown extension ID"))?
.clone();
Ok(Self { id, ty })
}
Expand Down
Loading

0 comments on commit 40e3128

Please sign in to comment.