diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 96efa4382..4972c7ff9 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -189,7 +189,7 @@ fn try_message(input: TokenStream) -> Result { #struct_name match tag { #(#merge)* - _ => ::prost::encoding::skip_field(wire_type, tag, buf), + _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx), } } diff --git a/src/encoding.rs b/src/encoding.rs index faa52019e..7fc02cba7 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -381,10 +381,11 @@ where Ok(()) } -pub fn skip_field(wire_type: WireType, tag: u32, buf: &mut B) -> Result<(), DecodeError> +pub fn skip_field(wire_type: WireType, tag: u32, buf: &mut B, ctx: DecodeContext) -> Result<(), DecodeError> where B: Buf, { + ctx.limit_reached()?; let len = match wire_type { WireType::Varint => decode_varint(buf).map(|_| 0)?, WireType::ThirtyTwoBit => 4, @@ -399,7 +400,7 @@ where } break 0; } - _ => skip_field(inner_wire_type, inner_tag, buf)?, + _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?, } }, WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")), @@ -1179,7 +1180,7 @@ macro_rules! map { match tag { 1 => key_merge(wire_type, key, buf, ctx), 2 => val_merge(wire_type, val, buf, ctx), - _ => skip_field(wire_type, tag, buf), + _ => skip_field(wire_type, tag, buf, ctx), } }, )?; diff --git a/src/types.rs b/src/types.rs index 6156e297f..b40e23241 100644 --- a/src/types.rs +++ b/src/types.rs @@ -34,7 +34,7 @@ impl Message for bool { if tag == 1 { bool::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -72,7 +72,7 @@ impl Message for u32 { if tag == 1 { uint32::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -110,7 +110,7 @@ impl Message for u64 { if tag == 1 { uint64::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -148,7 +148,7 @@ impl Message for i32 { if tag == 1 { int32::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -186,7 +186,7 @@ impl Message for i64 { if tag == 1 { int64::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -224,7 +224,7 @@ impl Message for f32 { if tag == 1 { float::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -262,7 +262,7 @@ impl Message for f64 { if tag == 1 { double::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -300,7 +300,7 @@ impl Message for String { if tag == 1 { string::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -338,7 +338,7 @@ impl Message for Vec { if tag == 1 { bytes::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -365,12 +365,12 @@ impl Message for () { tag: u32, wire_type: WireType, buf: &mut B, - _ctx: DecodeContext, + ctx: DecodeContext, ) -> Result<(), DecodeError> where B: Buf, { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } fn encoded_len(&self) -> usize { 0 diff --git a/tests/src/lib.rs b/tests/src/lib.rs index ca2b83332..8b0da7c37 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -462,6 +462,16 @@ mod tests { }; } + #[test] + fn test_267_regression() { + // Checks that skip_field will error appropriately when given a big stack of StartGroup + // tags. + // + // https://github.com/danburkert/prost/issues/267 + let buf = vec![b'C'; 1 << 20]; + <() as Message>::decode(&buf[..]).err().unwrap(); + } + #[test] fn test_default_enum() { let msg = default_enum_value::Test::default();