Skip to content

Commit

Permalink
throw error when model is missing fields
Browse files Browse the repository at this point in the history
  • Loading branch information
m1guelpf committed Aug 31, 2023
1 parent c163ac0 commit 98d8abd
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions ensemble_derive/src/model/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ fn field_deserialize(column: &Rc<[Ident]>, enum_key: &Rc<[Ident]>) -> TokenStrea
}
}

#[allow(clippy::too_many_lines)]
fn visitor_deserialize(
name: &Ident,
visitor_name: &Ident,
Expand All @@ -197,6 +198,14 @@ fn visitor_deserialize(
.map(|f| &f.ident)
.collect::<Rc<_>>();

let needs_collect = fields.fields.iter().any(|f| {
let Some((relationship_type, _, _)) = f.relationship(primary_key) else {
return false;
};

matches!(relationship_type, Relationship::BelongsTo)
});

let required_checks = fields.fields.iter().filter_map(|f| {
let ident = &f.ident;
let column = f
Expand All @@ -213,6 +222,16 @@ fn visitor_deserialize(
Some(quote_spanned! {f.span()=> let #ident = #ident.ok_or_else(|| _serde::de::Error::missing_field(stringify!(#column)))?; })
});

let ensure_no_leftovers = if needs_collect {
quote! {
if let Some(key) = __collect.keys().next() {
return Err(_serde::de::Error::unknown_field(&key, FIELDS));
}
}
} else {
TokenStream::new()
};

let model_keys = fields.fields.iter().map(|f| {
let ident = &f.ident;

Expand All @@ -231,7 +250,7 @@ fn visitor_deserialize(
let key: &'static str = #relationship_expr.leak();

_serde::de::Deserialize::deserialize::<_serde::__private::de::ContentDeserializer<'_, _serde::de::value::Error>>(
__collect.get(key).ok_or_else(|| _serde::de::Error::missing_field(key))?.clone().into_deserializer()
__collect.remove(key).ok_or_else(|| _serde::de::Error::missing_field(key))?.clone().into_deserializer()
).unwrap()
}}
}, |key| quote_spanned! {f.span()=> #key });
Expand All @@ -242,7 +261,27 @@ fn visitor_deserialize(
});

let build_model = quote! {
Ok(#name { #(#model_keys),* })
let __model = #name { #(#model_keys),* };
#ensure_no_leftovers
Ok(__model)
};

let init_collect = if needs_collect {
quote! {
let mut __collect = ::std::collections::HashMap::<String, _serde::__private::de::Content>::new();
}
} else {
TokenStream::new()
};

let handle_unknown_field = if needs_collect {
quote! {
__collect.insert(name, map.next_value()?);
}
} else {
quote! {
return Err(_serde::de::Error::unknown_field(&name, FIELDS));
}
};

Ok(quote! {
Expand All @@ -255,7 +294,7 @@ fn visitor_deserialize(

fn visit_map<V: _serde::de::MapAccess<'de>>(self, mut map: V) -> Result<#name, V::Error> {
#(let mut #key = None;)*
let mut __collect = ::std::collections::HashMap::<String, _serde::__private::de::Content>::new();
#init_collect

while let Some(key) = map.next_key()? {
match key {
Expand All @@ -268,7 +307,7 @@ fn visitor_deserialize(
},
)*
Field::Other(name) => {
__collect.insert(name, map.next_value()?);
#handle_unknown_field
}
}
}
Expand Down

0 comments on commit 98d8abd

Please sign in to comment.