Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom recursion limits at build time #785

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
10 changes: 10 additions & 0 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
@@ -188,6 +188,7 @@ impl<'a> CodeGenerator<'a> {
"#[derive(Clone, PartialEq, {}::Message)]\n",
self.config.prost_path.as_deref().unwrap_or("::prost")
));
self.append_recursion_limit(&fq_message_name);
self.push_indent();
self.buf.push_str("pub struct ");
self.buf.push_str(&to_upper_camel(&message_name));
@@ -271,6 +272,15 @@ impl<'a> CodeGenerator<'a> {
}
}

/// Writes a `#[RecursionLimit(N)]` attribute line for the given message, if a limit
/// has been configured for it.
///
/// Nothing is written when `self.config.recursion_limits` has no entry for
/// `fq_message_name`.
///
/// # Panics
///
/// Panics if `fq_message_name` is empty or does not start with `.`.
fn append_recursion_limit(&mut self, fq_message_name: &str) {
    assert_eq!(b'.', fq_message_name.as_bytes()[0]);

    let limit = match self.config.recursion_limits.get_first(fq_message_name) {
        Some(limit) => limit,
        // No custom limit configured for this message: emit nothing.
        None => return,
    };

    // The free `push_indent` function is used (as in the original) so that only the
    // `buf` and `depth` fields are touched while `limit` is still borrowed from `config`.
    push_indent(self.buf, self.depth);
    self.buf
        .push_str(&format!("#[RecursionLimit({})]\n", limit));
}

fn append_message_attributes(&mut self, fq_message_name: &str) {
assert_eq!(b'.', fq_message_name.as_bytes()[0]);
for attribute in self.config.message_attributes.get(fq_message_name) {
21 changes: 21 additions & 0 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
@@ -247,6 +247,7 @@ pub struct Config {
message_attributes: PathMap<String>,
enum_attributes: PathMap<String>,
field_attributes: PathMap<String>,
recursion_limits: PathMap<u32>,
prost_types: bool,
strip_enum_prefix: bool,
out_dir: Option<PathBuf>,
@@ -470,6 +471,25 @@ impl Config {
self
}

/// Sets a custom decode recursion limit for the messages matched by `path`.
///
/// Messages without a configured limit use the default of 100. Recursion limiting can
/// also be disabled entirely with the `no-recursion-limit` crate feature.
///
/// # Arguments
///
/// **`path`** - a path matching the messages to configure, e.g.
/// `my_messages.MyMessageType`.
///
/// **`limit`** - the recursion limit to emit for the matched messages.
///
/// # Example
///
/// ```rust
/// # let mut config = prost_build::Config::new();
/// config.recursion_limit("my_messages.MyMessageType", 1000);
/// ```
pub fn recursion_limit<P>(&mut self, path: P, limit: u32) -> &mut Self
where
    P: AsRef<str>,
{
    let path = path.as_ref().to_string();
    self.recursion_limits.insert(path, limit);
    self
}

/// Add additional attribute to matched messages.
///
/// # Arguments
@@ -1214,6 +1234,7 @@ impl default::Default for Config {
message_attributes: PathMap::default(),
enum_attributes: PathMap::default(),
field_attributes: PathMap::default(),
recursion_limits: PathMap::default(),
prost_types: true,
strip_enum_prefix: true,
out_dir: None,
20 changes: 19 additions & 1 deletion prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -23,6 +23,20 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {

let ident = input.ident;

let recursion_limit: u32 = if let Some(attr) = input
.attrs
.iter()
.find(|attr| attr.path.is_ident("RecursionLimit"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should make this one snake case, like serde does for its container attributes: https://serde.rs/container-attrs.html

so #[prost(recursion_limit = 5)] etc

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having trouble getting a bespoke attribute parser working. Would it be okay to pull in darling as a dependency? https://crates.io/crates/darling

Another option is to skip proper attribute parsing and only handle the single attribute we have for now.

{
if let syn::Lit::Int(attr) = attr.parse_args().unwrap() {
attr.base10_parse().unwrap()
} else {
panic!("unexpected RecursionLimit type: {:?}", attr)
}
} else {
100
};

let variant_data = match input.data {
Data::Struct(variant_data) => variant_data,
Data::Enum(..) => bail!("Message can not be derived for an enum"),
@@ -187,6 +201,10 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {

let expanded = quote! {
impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
fn recursion_limit() -> u32 {
#recursion_limit
}

#[allow(unused_variables)]
fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
#(#encode)*
@@ -238,7 +256,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
Ok(expanded.into())
}

/// Entry point of the `Message` derive macro.
///
/// Delegates to `try_message` and unwraps its result, so an error returned by
/// `try_message` surfaces as a panic during macro expansion.
// NOTE(review): the two `proc_macro_derive` lines below look like the before/after sides
// of a diff hunk; presumably only the second one (which also registers the
// `RecursionLimit` helper attribute) is meant to remain — confirm against the real file.
#[proc_macro_derive(Message, attributes(prost))]
#[proc_macro_derive(Message, attributes(prost, RecursionLimit))]
pub fn message(input: TokenStream) -> TokenStream {
try_message(input).unwrap()
}
27 changes: 12 additions & 15 deletions src/encoding.rs
Original file line number Diff line number Diff line change
@@ -195,29 +195,26 @@ pub struct DecodeContext {
/// How many times we can recurse in the current decode stack before we hit
/// the recursion limit.
///
/// The recursion limit is defined by `RECURSION_LIMIT` and cannot be
/// customized. The recursion limit can be ignored by building the Prost
/// crate with the `no-recursion-limit` feature.
/// It defaults to 100 and can be changed using `prost_build::Config::recursion_limit`,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this is hidden, this doc won't actually be readable, so we need to make sure the recursion limit is documented at the lib level of prost.

/// or it can be disabled entirely using the `no-recursion-limit` feature.
#[cfg(not(feature = "no-recursion-limit"))]
recurse_count: u32,
#[doc(hidden)]
pub recurse_count: u32,
}

#[cfg(not(feature = "no-recursion-limit"))]
impl Default for DecodeContext {
#[inline]
fn default() -> DecodeContext {
impl DecodeContext {
pub(crate) fn new(recursion_limit: u32) -> DecodeContext {
DecodeContext {
recurse_count: crate::RECURSION_LIMIT,
#[cfg(not(feature = "no-recursion-limit"))]
recurse_count: recursion_limit,
}
}
}

impl DecodeContext {
/// Call this function before recursively decoding.
///
/// There is no `exit` function since this function creates a new `DecodeContext`
/// to be used at the next level of recursion. Continue to use the old context
// at the previous level of recursion.
/// at the previous level of recursion.
#[cfg(not(feature = "no-recursion-limit"))]
#[inline]
pub(crate) fn enter_recursion(&self) -> DecodeContext {
@@ -1503,7 +1500,7 @@ mod test {
wire_type,
&mut roundtrip_value,
&mut buf,
DecodeContext::default(),
DecodeContext::new(100),
)
.map_err(|error| TestCaseError::fail(error.to_string()))?;

@@ -1575,7 +1572,7 @@ mod test {
wire_type,
&mut roundtrip_value,
&mut buf,
DecodeContext::default(),
DecodeContext::new(100),
)
.map_err(|error| TestCaseError::fail(error.to_string()))?;
}
@@ -1594,7 +1591,7 @@ mod test {
WireType::LengthDelimited,
&mut s,
&mut &buf[..],
DecodeContext::default(),
DecodeContext::new(100),
);
r.expect_err("must be an error");
assert!(s.is_empty());
5 changes: 0 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -23,11 +23,6 @@ use bytes::{Buf, BufMut};

use crate::encoding::{decode_varint, encode_varint, encoded_len_varint};

// See `encoding::DecodeContext` for more info.
// 100 is the default recursion limit in the C++ implementation.
#[cfg(not(feature = "no-recursion-limit"))]
const RECURSION_LIMIT: u32 = 100;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it sounds like this will never change, I opted to inline it everywhere it's used. That allows documentation to say explicitly that the default recursion limit is 100 instead of requiring that users look up this constant.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually kinda like having this as a constant. I wonder if we could just make a constants module that contains just this one. And we can then have all the deep dive docs on the recursion implementation there and then we just need to link to there from the lib doc page. I feel like that would make it easier to maintain down the line and centralize it a bit.


/// Encodes a length delimiter to the buffer.
///
/// See [Message.encode_length_delimited] for more info.
17 changes: 15 additions & 2 deletions src/message.rs
Original file line number Diff line number Diff line change
@@ -14,6 +14,16 @@ use crate::EncodeError;

/// A Protocol Buffers message.
pub trait Message: Debug + Send + Sync {
/// Returns the recursion limit applied when decoding this message type.
///
/// The provided implementation returns 100. A different value can be configured per
/// message from your `build.rs`, and the `no-recursion-limit` crate feature turns the
/// limit off.
fn recursion_limit() -> u32
where
    Self: Sized,
{
    // Default used by every message that does not override this method.
    const DEFAULT_RECURSION_LIMIT: u32 = 100;
    DEFAULT_RECURSION_LIMIT
}

/// Encodes the message to a buffer.
///
/// This method will panic if the buffer has insufficient capacity.
@@ -135,7 +145,7 @@ pub trait Message: Debug + Send + Sync {
B: Buf,
Self: Sized,
{
let ctx = DecodeContext::default();
let ctx = DecodeContext::new(Self::recursion_limit());
while buf.has_remaining() {
let (tag, wire_type) = decode_key(&mut buf)?;
self.merge_field(tag, wire_type, &mut buf, ctx.clone())?;
@@ -154,7 +164,7 @@ pub trait Message: Debug + Send + Sync {
WireType::LengthDelimited,
self,
&mut buf,
DecodeContext::default(),
DecodeContext::new(Self::recursion_limit()),
)
}

@@ -166,6 +176,9 @@ impl<M> Message for Box<M>
where
M: Message,
{
fn recursion_limit() -> u32 {
M::recursion_limit()
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
1 change: 1 addition & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
@@ -54,6 +54,7 @@ fn main() {
.compile_protos(&[src.join("ident_conversion.proto")], includes)
.unwrap();

config.recursion_limit("nesting.E", 200);
config
.compile_protos(&[src.join("nesting.proto")], includes)
.unwrap();
21 changes: 21 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -405,6 +405,27 @@ mod tests {
assert!(build_and_roundtrip(101).is_err());
}

#[test]
fn test_deep_nesting_with_custom_recursion_limit() {
    /// Wraps a default `E` in `depth` further layers of `E`, encodes the result and
    /// reports whether it can be decoded again.
    fn build_and_roundtrip(depth: usize) -> Result<(), prost::DecodeError> {
        use crate::nesting::E;

        // Start from the innermost message and wrap it once per level.
        let nested = (0..depth).fold(Box::new(E::default()), |inner, _| {
            let mut outer = Box::new(E::default());
            outer.e = Some(inner);
            outer
        });

        let mut encoded = Vec::new();
        nested.encode(&mut encoded).unwrap();
        E::decode(encoded.as_slice())?;
        Ok(())
    }

    // The build script configures `nesting.E` with a recursion limit of 200, so 200
    // levels must decode and one more must be rejected.
    assert!(build_and_roundtrip(200).is_ok());
    assert!(build_and_roundtrip(201).is_err());
}

#[test]
fn test_deep_nesting_oneof() {
fn build_and_roundtrip(depth: usize) -> Result<(), prost::DecodeError> {
10 changes: 10 additions & 0 deletions tests/src/nesting.proto
Original file line number Diff line number Diff line change
@@ -23,3 +23,13 @@ message C {
message D {
map<string, D> m = 1;
}

// Message that nests itself and `B` through singular, repeated and map fields.
// The tests' build script configures a custom recursion limit for `nesting.E`.
message E {
// Self-referential fields.
E e = 1;
repeated E repeated_e = 2;
map<int32, E> map_e = 3;

// Fields referencing message `B`.
B b = 4;
repeated B repeated_b = 5;
map<int32, B> map_b = 6;
}