diff --git a/serde/src/de/impls.rs b/serde/src/de/impls.rs index ea48d34e6..87aa6c684 100644 --- a/serde/src/de/impls.rs +++ b/serde/src/de/impls.rs @@ -15,6 +15,7 @@ use de::{Deserialize, Deserializer, EnumAccess, Error, SeqAccess, Unexpected, Va use de::MapAccess; use de::from_primitive::FromPrimitive; +use private::de::DeserializeFromSeed; #[cfg(any(feature = "std", feature = "alloc"))] use private::de::size_hint; @@ -51,6 +52,7 @@ impl<'de> Deserialize<'de> for () { struct BoolVisitor; + impl<'de> Visitor<'de> for BoolVisitor { type Value = bool; @@ -210,6 +212,8 @@ impl<'de> Deserialize<'de> for char { #[cfg(any(feature = "std", feature = "alloc"))] struct StringVisitor; +#[cfg(any(feature = "std", feature = "alloc"))] +struct StringFromVisitor<'a>(&'a mut String); #[cfg(any(feature = "std", feature = "alloc"))] impl<'de> Visitor<'de> for StringVisitor { @@ -254,6 +258,59 @@ impl<'de> Visitor<'de> for StringVisitor { } } +#[cfg(any(feature = "std", feature = "alloc"))] +impl<'a, 'de> Visitor<'de> for StringFromVisitor<'a> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string") + } + + fn visit_str(self, v: &str) -> Result<(), E> + where + E: Error, + { + self.0.clear(); + self.0.push_str(v); + Ok(()) + } + + fn visit_string(self, v: String) -> Result<(), E> + where + E: Error, + { + *self.0 = v; + Ok(()) + } + + fn visit_bytes(self, v: &[u8]) -> Result<(), E> + where + E: Error, + { + match str::from_utf8(v) { + Ok(s) => { + self.0.clear(); + self.0.push_str(s); + Ok(()) + } + Err(_) => Err(Error::invalid_value(Unexpected::Bytes(v), &self)), + } + } + + fn visit_byte_buf(self, v: Vec) -> Result<(), E> + where + E: Error, + { + match String::from_utf8(v) { + Ok(s) => { + *self.0 = s; + Ok(()) + } + Err(e) => Err(Error::invalid_value(Unexpected::Bytes(&e.into_bytes()), &self),), + } + } +} + #[cfg(any(feature = "std", feature = "alloc"))] impl<'de> Deserialize<'de> for String { fn deserialize(deserializer: D) -> Result @@ -262,6 +319,13 @@ impl<'de> Deserialize<'de> for String { { deserializer.deserialize_string(StringVisitor) } + + fn deserialize_from(&mut self, deserializer: D) -> Result<(), D::Error> + where + D: Deserializer<'de>, + { + deserializer.deserialize_string(StringFromVisitor(self)) + } } //////////////////////////////////////////////////////////////////////////////// @@ -467,6 +531,12 @@ where { deserializer.deserialize_option(OptionVisitor { marker: PhantomData }) } + + // The Some variant's repr is opaque, so we can't play cute tricks with its + // tag to have deserialize_from build the content in place unconditionally. + // + // FIXME: investigate whether branching on the old value being Some to + // deserialize_from the value is profitable (probably data-dependent?) } //////////////////////////////////////////////////////////////////////////////// @@ -509,7 +579,9 @@ macro_rules! seq_impl { $ty:ident < T $(: $tbound1:ident $(+ $tbound2:ident)*)* $(, $typaram:ident : $bound1:ident $(+ $bound2:ident)*)* >, $access:ident, $ctor:expr, + $clear:expr, $with_capacity:expr, + $reserve:expr, $insert:expr ) => { impl<'de, T $(, $typaram)*> Deserialize<'de> for $ty @@ -554,16 +626,59 @@ macro_rules! seq_impl { let visitor = SeqVisitor { marker: PhantomData }; deserializer.deserialize_seq(visitor) } + + fn deserialize_from(&mut self, deserializer: D) -> Result<(), D::Error> + where + D: Deserializer<'de>, + { + struct SeqVisitor<'a, T: 'a $(, $typaram: 'a)*>(&'a mut $ty); + + impl<'a, 'de, T $(, $typaram)*> Visitor<'de> for SeqVisitor<'a, T $(, $typaram)*> + where + T: Deserialize<'de> $(+ $tbound1 $(+ $tbound2)*)*, + $($typaram: $bound1 $(+ $bound2)*,)* + { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence") + } + + #[inline] + fn visit_seq(mut self, mut $access: A) -> Result<(), A::Error> + where + A: SeqAccess<'de>, + { + $clear(&mut self.0); + $reserve(&mut self.0, size_hint::cautious($access.size_hint())); + + // FIXME: try to overwrite old values here? (Vec, VecDeque, LinkedList) + while let Some(value) = try!($access.next_element()) { + $insert(&mut self.0, value); + } + + Ok(()) + } + } + + deserializer.deserialize_seq(SeqVisitor(self)) + } } } } +// Dummy impl of reserve +#[cfg(any(feature = "std", feature = "alloc"))] +fn nop_reserve(_seq: T, _n: usize) {} + #[cfg(any(feature = "std", feature = "alloc"))] seq_impl!( BinaryHeap, seq, BinaryHeap::new(), + BinaryHeap::clear, BinaryHeap::with_capacity(size_hint::cautious(seq.size_hint())), + BinaryHeap::reserve, BinaryHeap::push); #[cfg(any(feature = "std", feature = "alloc"))] @@ -571,7 +686,9 @@ seq_impl!( BTreeSet, seq, BTreeSet::new(), + BTreeSet::clear, BTreeSet::new(), + nop_reserve, BTreeSet::insert); #[cfg(any(feature = "std", feature = "alloc"))] @@ -579,7 +696,9 @@ seq_impl!( LinkedList, seq, LinkedList::new(), + LinkedList::clear, LinkedList::new(), + nop_reserve, LinkedList::push_back); #[cfg(feature = "std")] @@ -587,7 +706,9 @@ seq_impl!( HashSet, seq, HashSet::with_hasher(S::default()), + HashSet::clear, HashSet::with_capacity_and_hasher(size_hint::cautious(seq.size_hint()), S::default()), + HashSet::reserve, HashSet::insert); #[cfg(any(feature = "std", feature = "alloc"))] @@ -595,7 +716,9 @@ seq_impl!( Vec, seq, Vec::new(), + Vec::clear, Vec::with_capacity(size_hint::cautious(seq.size_hint())), + Vec::reserve, Vec::push); #[cfg(any(feature = "std", feature = "alloc"))] @@ -603,7 +726,9 @@ seq_impl!( VecDeque, seq, VecDeque::new(), + VecDeque::clear, VecDeque::with_capacity(size_hint::cautious(seq.size_hint())), + VecDeque::reserve, VecDeque::push_back); //////////////////////////////////////////////////////////////////////////////// @@ -611,6 +736,7 @@ seq_impl!( struct ArrayVisitor { marker: PhantomData, } +struct ArrayFromVisitor<'a, A: 'a>(&'a mut A); impl ArrayVisitor { fn new() -> Self { @@ -673,6 +799,35 @@ macro_rules! array_impls { } } + impl<'a, 'de, T> Visitor<'de> for ArrayFromVisitor<'a, [T; $len]> + where + T: Deserialize<'de>, + { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(concat!("an array of length ", $len)) + } + + #[inline] + fn visit_seq(self, mut seq: A) -> Result<(), A::Error> + where + A: SeqAccess<'de>, + { + let mut fail_idx = None; + for (idx, dest) in self.0[..].iter_mut().enumerate() { + if try!(seq.next_element_seed(DeserializeFromSeed(dest))).is_none() { + fail_idx = Some(idx); + break; + } + } + if let Some(idx) = fail_idx { + return Err(Error::invalid_length(idx, &self)); + } + Ok(()) + } + } + impl<'de, T> Deserialize<'de> for [T; $len] where T: Deserialize<'de>, @@ -683,6 +838,13 @@ macro_rules! array_impls { { deserializer.deserialize_tuple($len, ArrayVisitor::<[T; $len]>::new()) } + + fn deserialize_from(&mut self, deserializer: D) -> Result<(), D::Error> + where + D: Deserializer<'de>, + { + deserializer.deserialize_tuple($len, ArrayFromVisitor(self)) + } } )+ } @@ -726,49 +888,76 @@ array_impls! { //////////////////////////////////////////////////////////////////////////////// macro_rules! tuple_impls { - ($($len:tt $visitor:ident => ($($n:tt $name:ident)+))+) => { + ($($len:tt => ($($n:tt $name:ident)+))+) => { $( - struct $visitor<$($name,)+> { - marker: PhantomData<($($name,)+)>, - } + impl<'de, $($name: Deserialize<'de>),+> Deserialize<'de> for ($($name,)+) { + #[inline] + fn deserialize(deserializer: D) -> Result<($($name,)+), D::Error> + where + D: Deserializer<'de>, + { + struct TupleVisitor<$($name,)+> { + marker: PhantomData<($($name,)+)>, + } - impl<$($name,)+> $visitor<$($name,)+> { - fn new() -> Self { - $visitor { marker: PhantomData } - } - } + impl<'de, $($name: Deserialize<'de>),+> Visitor<'de> for TupleVisitor<$($name,)+> { + type Value = ($($name,)+); - impl<'de, $($name: Deserialize<'de>),+> Visitor<'de> for $visitor<$($name,)+> { - type Value = ($($name,)+); + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(concat!("a tuple of size ", $len)) + } - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str(concat!("a tuple of size ", $len)) - } + #[inline] + #[allow(non_snake_case)] + fn visit_seq(self, mut seq: A) -> Result<($($name,)+), A::Error> + where + A: SeqAccess<'de>, + { + $( + let $name = match try!(seq.next_element()) { + Some(value) => value, + None => return Err(Error::invalid_length($n, &self)), + }; + )+ - #[inline] - #[allow(non_snake_case)] - fn visit_seq(self, mut seq: A) -> Result<($($name,)+), A::Error> - where - A: SeqAccess<'de>, - { - $( - let $name = match try!(seq.next_element()) { - Some(value) => value, - None => return Err(Error::invalid_length($n, &self)), - }; - )+ + Ok(($($name,)+)) + } + } - Ok(($($name,)+)) + deserializer.deserialize_tuple($len, TupleVisitor { marker: PhantomData }) } - } - impl<'de, $($name: Deserialize<'de>),+> Deserialize<'de> for ($($name,)+) { #[inline] - fn deserialize(deserializer: D) -> Result<($($name,)+), D::Error> + fn deserialize_from(&mut self, deserializer: D) -> Result<(), D::Error> where D: Deserializer<'de>, { - deserializer.deserialize_tuple($len, $visitor::new()) + struct TupleVisitor<'a, $($name: 'a,)+>(&'a mut ($($name,)+)); + + impl<'a, 'de, $($name: Deserialize<'de>),+> Visitor<'de> for TupleVisitor<'a, $($name,)+> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(concat!("a tuple of size ", $len)) + } + + #[inline] + #[allow(non_snake_case)] + fn visit_seq(self, mut seq: A) -> Result<(), A::Error> + where + A: SeqAccess<'de>, + { + $( + if try!(seq.next_element_seed(DeserializeFromSeed(&mut (self.0).$n))).is_none() { + return Err(Error::invalid_length($n, &self)); + } + )+ + + Ok(()) + } + } + + deserializer.deserialize_tuple($len, TupleVisitor(self)) } } )+ @@ -776,22 +965,22 @@ macro_rules! tuple_impls { } tuple_impls! { - 1 TupleVisitor1 => (0 T0) - 2 TupleVisitor2 => (0 T0 1 T1) - 3 TupleVisitor3 => (0 T0 1 T1 2 T2) - 4 TupleVisitor4 => (0 T0 1 T1 2 T2 3 T3) - 5 TupleVisitor5 => (0 T0 1 T1 2 T2 3 T3 4 T4) - 6 TupleVisitor6 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5) - 7 TupleVisitor7 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6) - 8 TupleVisitor8 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7) - 9 TupleVisitor9 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8) - 10 TupleVisitor10 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9) - 11 TupleVisitor11 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10) - 12 TupleVisitor12 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11) - 13 TupleVisitor13 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12) - 14 TupleVisitor14 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13) - 15 TupleVisitor15 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14) - 16 TupleVisitor16 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15) + 1 => (0 T0) + 2 => (0 T0 1 T1) + 3 => (0 T0 1 T1 2 T2) + 4 => (0 T0 1 T1 2 T2 3 T3) + 5 => (0 T0 1 T1 2 T2 3 T3 4 T4) + 6 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5) + 7 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6) + 8 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7) + 9 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8) + 10 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9) + 11 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10) + 12 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11) + 13 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12) + 14 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13) + 15 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14) + 16 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15) } //////////////////////////////////////////////////////////////////////////////// diff --git a/serde/src/de/mod.rs b/serde/src/de/mod.rs index 848183c7b..219907bda 100644 --- a/serde/src/de/mod.rs +++ b/serde/src/de/mod.rs @@ -504,6 +504,34 @@ pub trait Deserialize<'de>: Sized { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>; + + /// Deserializes a value into `self` from the given Deserializer. + /// + /// The purpose of this method is to allow the deserializer to reuse + /// resources and avoid copies. As such, if this method returns an error, + /// `self` will be in an indeterminate state where some parts of the struct + /// have been overwritten. Although whatever state that is will be + /// memory-safe. + /// + /// This is generally useful when repeateadly deserializing values that + /// are processed one at a time, where the value of `self` doesn't matter + /// when the next deserialization occurs. + /// + /// If you manually implement this, your recursive deserializations should + /// use `deserialize_from`. + /// + /// This method is stable and an official public API, but hidden from the + /// documentation because it is almost never what newbies are looking for. + /// Showing it in rustdoc would cause it to be featured more prominently + /// than it deserves. + #[doc(hidden)] + fn deserialize_from(&mut self, deserializer: D) -> Result<(), D::Error> + where D: Deserializer<'de> + { + // Default implementation just delegates to `deserialize` impl. + *self = Deserialize::deserialize(deserializer)?; + Ok(()) + } } /// A data structure that can be deserialized without borrowing any data from diff --git a/serde/src/private/de.rs b/serde/src/private/de.rs index 7c98d9f93..a3134e69f 100644 --- a/serde/src/private/de.rs +++ b/serde/src/private/de.rs @@ -8,7 +8,7 @@ use lib::*; -use de::{Deserialize, Deserializer, IntoDeserializer, Error, Visitor}; +use de::{Deserialize, Deserializer, DeserializeSeed, IntoDeserializer, Error, Visitor}; #[cfg(any(feature = "std", feature = "alloc"))] use de::Unexpected; @@ -2009,3 +2009,20 @@ where map struct enum identifier ignored_any } } + +/// A DeserializeSeed helper for implementing deserialize_from Visitors. +/// +/// Wraps a mutable reference and calls deserialize_from on it. +pub struct DeserializeFromSeed<'a, T: 'a>(pub &'a mut T); + +impl<'a, 'de, T> DeserializeSeed<'de> for DeserializeFromSeed<'a, T> + where T: Deserialize<'de>, +{ + type Value = (); + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + self.0.deserialize_from(deserializer) + } +} diff --git a/serde_derive/Cargo.toml b/serde_derive/Cargo.toml index cf59b9c05..a93411c14 100644 --- a/serde_derive/Cargo.toml +++ b/serde_derive/Cargo.toml @@ -14,6 +14,10 @@ include = ["Cargo.toml", "src/**/*.rs", "README.md", "LICENSE-APACHE", "LICENSE- [badges] travis-ci = { repository = "serde-rs/serde" } +[features] +default = [] +deserialize_from = [] + [lib] name = "serde_derive" proc-macro = true diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 2012c10b0..dbe5a1934 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -40,6 +40,8 @@ pub fn expand_derive_deserialize(input: &syn::DeriveInput) -> Result for #ident #ty_generics #where_clause { @@ -48,6 +50,8 @@ pub fn expand_derive_deserialize(input: &syn::DeriveInput) -> Result Fragment { } } +#[cfg(feature = "deserialize_from")] +fn deserialize_from_body(cont: &Container, params: &Parameters) -> Option { + // Only remote derives have getters, and we do not generate deserialize_from + // for remote derives. + assert!(!params.has_getter); + + if cont.attrs.from_type().is_some() + || cont.attrs.identifier().is_some() + || cont.body.all_fields().all(|f| f.attrs.deserialize_with().is_some()) + { + return None; + } + + let code = match cont.body { + Body::Struct(Style::Struct, ref fields) => { + deserialize_from_struct(None, params, fields, &cont.attrs, None, Untagged::No) + } + Body::Struct(Style::Tuple, ref fields) | + Body::Struct(Style::Newtype, ref fields) => { + deserialize_from_tuple(None, params, fields, &cont.attrs, None) + } + Body::Enum(_) | Body::Struct(Style::Unit, _) => { + return None; + } + }; + + let delife = params.borrowed.de_lifetime(); + let stmts = Stmts(code); + + let fn_deserialize_from = quote_block! { + fn deserialize_from<__D>(&mut self, __deserializer: __D) -> _serde::export::Result<(), __D::Error> + where __D: _serde::Deserializer<#delife> + { + #stmts + } + }; + + Some(Stmts(fn_deserialize_from)) +} + +#[cfg(not(feature = "deserialize_from"))] +fn deserialize_from_body(_cont: &Container, _params: &Parameters) -> Option { + None +} + fn deserialize_from(from_type: &syn::Ty) -> Fragment { quote_block! { _serde::export::Result::map( @@ -376,6 +425,93 @@ fn deserialize_tuple( } } +#[cfg(feature = "deserialize_from")] +fn deserialize_from_tuple( + variant_ident: Option<&syn::Ident>, + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, + deserializer: Option, +) -> Fragment { + let this = ¶ms.this; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = split_with_de_lifetime(params,); + let delife = params.borrowed.de_lifetime(); + + let is_enum = variant_ident.is_some(); + let expecting = match variant_ident { + Some(variant_ident) => format!("tuple variant {}::{}", params.type_name(), variant_ident), + None => format!("tuple struct {}", params.type_name()), + }; + + let nfields = fields.len(); + + let visit_newtype_struct = if !is_enum && nfields == 1 { + Some(deserialize_from_newtype_struct(params, &fields[0])) + } else { + None + }; + + let visit_seq = Stmts(deserialize_from_seq(params, fields, cattrs)); + + let visitor_expr = quote! { + __Visitor { + dest: self, + lifetime: _serde::export::PhantomData, + } + }; + + let dispatch = if let Some(deserializer) = deserializer { + quote!(_serde::Deserializer::deserialize_tuple(#deserializer, #nfields, #visitor_expr)) + } else if is_enum { + quote!(_serde::de::VariantAccess::tuple_variant(__variant, #nfields, #visitor_expr)) + } else if nfields == 1 { + let type_name = cattrs.name().deserialize_name(); + quote!(_serde::Deserializer::deserialize_newtype_struct(__deserializer, #type_name, #visitor_expr)) + } else { + let type_name = cattrs.name().deserialize_name(); + quote!(_serde::Deserializer::deserialize_tuple_struct(__deserializer, #type_name, #nfields, #visitor_expr)) + }; + + let all_skipped = fields + .iter() + .all(|field| field.attrs.skip_deserializing()); + let visitor_var = if all_skipped { + quote!(_) + } else { + quote!(mut __seq) + }; + + let de_from_impl_generics = de_impl_generics.with_dest(); + let de_from_ty_generics = de_ty_generics.with_dest(); + let dest_life = dest_lifetime(); + + quote_block! { + struct __Visitor #de_from_impl_generics #where_clause { + dest: &#dest_life mut #this #ty_generics, + lifetime: _serde::export::PhantomData<&#delife ()>, + } + + impl #de_from_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_from_ty_generics #where_clause { + type Value = (); + + fn expecting(&self, formatter: &mut _serde::export::Formatter) -> _serde::export::fmt::Result { + _serde::export::Formatter::write_str(formatter, #expecting) + } + + #visit_newtype_struct + + #[inline] + fn visit_seq<__A>(self, #visitor_var: __A) -> _serde::export::Result + where __A: _serde::de::SeqAccess<#delife> + { + #visit_seq + } + } + + #dispatch + } +} + fn deserialize_seq( type_path: &Tokens, params: &Parameters, @@ -476,6 +612,98 @@ fn deserialize_seq( } } +#[cfg(feature = "deserialize_from")] +fn deserialize_from_seq( + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, +) -> Fragment { + let vars = (0..fields.len()).map(field_i as fn(_) -> _); + + let deserialized_count = fields + .iter() + .filter(|field| !field.attrs.skip_deserializing()) + .count(); + let expecting = format!("tuple of {} elements", deserialized_count); + + let mut index_in_seq = 0usize; + let write_values = vars.clone().zip(fields).enumerate() + .map(|(field_index, (_, field))| { + // If there's no field name, assume we're a tuple-struct and use a numeric index + let field_name = field.ident.clone() + .unwrap_or_else(|| Ident::new(field_index.to_string())); + + if field.attrs.skip_deserializing() { + let default = Expr(expr_is_missing(&field, cattrs)); + quote! { + self.dest.#field_name = #default; + } + } else { + let return_invalid_length = quote! { + return _serde::export::Err(_serde::de::Error::invalid_length(#index_in_seq, &#expecting)); + }; + let write = match field.attrs.deserialize_with() { + None => { + quote! { + if let _serde::export::None = try!(_serde::de::SeqAccess::next_element_seed(&mut __seq, + _serde::private::de::DeserializeFromSeed(&mut self.dest.#field_name))) + { + #return_invalid_length + } + } + } + Some(path) => { + let (wrapper, wrapper_ty) = wrap_deserialize_field_with( + params, field.ty, path); + quote!({ + #wrapper + match try!(_serde::de::SeqAccess::next_element::<#wrapper_ty>(&mut __seq)) { + _serde::export::Some(__wrap) => { + self.dest.#field_name = __wrap.value; + } + _serde::export::None => { + #return_invalid_length + } + } + }) + } + }; + index_in_seq += 1; + write + } + }); + + let this = ¶ms.this; + let (_, ty_generics, _) = params.generics.split_for_impl(); + let let_default = match *cattrs.default() { + attr::Default::Default => { + Some( + quote!( + let __default: #this #ty_generics = _serde::export::Default::default(); + ), + ) + } + attr::Default::Path(ref path) => { + Some( + quote!( + let __default: #this #ty_generics = #path(); + ), + ) + } + attr::Default::None => { + // We don't need the default value, to prevent an unused variable warning + // we'll leave the line empty. + None + } + }; + + quote_block! { + #let_default + #(#write_values)* + _serde::export::Ok(()) + } +} + fn deserialize_newtype_struct(type_path: &Tokens, params: &Parameters, field: &Field) -> Tokens { let delife = params.borrowed.de_lifetime(); @@ -513,6 +741,26 @@ fn deserialize_newtype_struct(type_path: &Tokens, params: &Parameters, field: &F } } +#[cfg(feature = "deserialize_from")] +fn deserialize_from_newtype_struct( + params: &Parameters, + field: &Field +) -> Tokens { + // We do not generate deserialize_from if every field has a deserialize_with. + assert!(field.attrs.deserialize_with().is_none()); + + let delife = params.borrowed.de_lifetime(); + + quote! { + #[inline] + fn visit_newtype_struct<__E>(self, __e: __E) -> _serde::export::Result + where __E: _serde::Deserializer<#delife> + { + _serde::Deserialize::deserialize_from(&mut self.dest.0, __e) + } + } +} + enum Untagged { Yes, No, @@ -635,6 +883,116 @@ fn deserialize_struct( } } +#[cfg(feature = "deserialize_from")] +fn deserialize_from_struct( + variant_ident: Option<&syn::Ident>, + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, + deserializer: Option, + untagged: Untagged, +) -> Fragment { + let is_enum = variant_ident.is_some(); + + let this = ¶ms.this; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = split_with_de_lifetime(params,); + let delife = params.borrowed.de_lifetime(); + + let expecting = match variant_ident { + Some(variant_ident) => format!("struct variant {}::{}", params.type_name(), variant_ident), + None => format!("struct {}", params.type_name()), + }; + + let visit_seq = Stmts(deserialize_from_seq(params, fields, cattrs)); + + let (field_visitor, fields_stmt, visit_map) = + deserialize_from_struct_visitor(params, fields, cattrs); + let field_visitor = Stmts(field_visitor); + let fields_stmt = Stmts(fields_stmt); + let visit_map = Stmts(visit_map); + + let visitor_expr = quote! { + __Visitor { + dest: self, + lifetime: _serde::export::PhantomData, + } + }; + let dispatch = if let Some(deserializer) = deserializer { + quote! { + _serde::Deserializer::deserialize_any(#deserializer, #visitor_expr) + } + } else if is_enum { + quote! { + _serde::de::VariantAccess::struct_variant(__variant, FIELDS, #visitor_expr) + } + } else { + let type_name = cattrs.name().deserialize_name(); + quote! { + _serde::Deserializer::deserialize_struct(__deserializer, #type_name, FIELDS, #visitor_expr) + } + }; + + + let all_skipped = fields + .iter() + .all(|field| field.attrs.skip_deserializing()); + let visitor_var = if all_skipped { + quote!(_) + } else { + quote!(mut __seq) + }; + + // untagged struct variants do not get a visit_seq method + let visit_seq = match untagged { + Untagged::Yes => None, + Untagged::No => { + Some(quote! { + #[inline] + fn visit_seq<__A>(self, #visitor_var: __A) -> _serde::export::Result + where __A: _serde::de::SeqAccess<#delife> + { + #visit_seq + } + }) + + } + }; + + let de_from_impl_generics = de_impl_generics.with_dest(); + let de_from_ty_generics = de_ty_generics.with_dest(); + let dest_life = dest_lifetime(); + + quote_block! { + #field_visitor + + struct __Visitor #de_from_impl_generics #where_clause { + dest: &#dest_life mut #this #ty_generics, + lifetime: _serde::export::PhantomData<&#delife ()>, + } + + impl #de_from_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_from_ty_generics #where_clause { + type Value = (); + + fn expecting(&self, formatter: &mut _serde::export::Formatter) -> _serde::export::fmt::Result { + _serde::export::Formatter::write_str(formatter, #expecting) + } + + #visit_seq + + #[inline] + fn visit_map<__A>(self, mut __map: __A) -> _serde::export::Result + where __A: _serde::de::MapAccess<#delife> + { + #visit_map + } + } + + #fields_stmt + + #dispatch + } +} + fn deserialize_enum( params: &Parameters, variants: &[Variant], @@ -1763,6 +2121,189 @@ fn deserialize_map( } } +#[cfg(feature = "deserialize_from")] +fn deserialize_from_struct_visitor( + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, +) -> (Fragment, Fragment, Fragment) { + let field_names_idents: Vec<_> = fields + .iter() + .enumerate() + .filter(|&(_, field)| !field.attrs.skip_deserializing()) + .map(|(i, field)| (field.attrs.name().deserialize_name(), field_i(i)),) + .collect(); + + let fields_stmt = { + let field_names = field_names_idents.iter().map(|&(ref name, _)| name); + quote_block! { + const FIELDS: &'static [&'static str] = &[ #(#field_names),* ]; + } + }; + + let field_visitor = deserialize_generated_identifier(field_names_idents, cattrs, false); + + let visit_map = deserialize_from_map(params, fields, cattrs); + + (field_visitor, fields_stmt, visit_map) +} + +#[cfg(feature = "deserialize_from")] +fn deserialize_from_map( + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, +) -> Fragment { + // Create the field names for the fields. + let fields_names: Vec<_> = fields + .iter() + .enumerate() + .map(|(i, field)| (field, field_i(i))) + .collect(); + + // For deserialize_from, declare booleans for each field that will be deserialized. + let let_flags = fields_names + .iter() + .filter(|&&(field, _)| !field.attrs.skip_deserializing()) + .map( + |&(_, ref name)| { + quote! { + let mut #name: bool = false; + } + }, + ); + + // Match arms to extract a value for a field. + let value_arms_from = fields_names.iter() + .filter(|&&(field, _)| !field.attrs.skip_deserializing()) + .map(|&(field, ref name)| { + let deser_name = field.attrs.name().deserialize_name(); + let field_name = &field.ident; + + let visit = match field.attrs.deserialize_with() { + None => { + quote! { + try!(_serde::de::MapAccess::next_value_seed(&mut __map, _serde::private::de::DeserializeFromSeed(&mut self.dest.#field_name))) + } + } + Some(path) => { + let (wrapper, wrapper_ty) = wrap_deserialize_field_with( + params, field.ty, path); + quote!({ + #wrapper + self.dest.#field_name = try!(_serde::de::MapAccess::next_value::<#wrapper_ty>(&mut __map)).value + }) + } + }; + quote! { + __Field::#name => { + if #name { + return _serde::export::Err(<__A::Error as _serde::de::Error>::duplicate_field(#deser_name)); + } + #visit; + #name = true; + } + } + }); + + // Visit ignored values to consume them + let ignored_arm = if cattrs.deny_unknown_fields() { + None + } else { + Some(quote! { + _ => { let _ = try!(_serde::de::MapAccess::next_value::<_serde::de::IgnoredAny>(&mut __map)); } + }) + }; + + let all_skipped = fields + .iter() + .all(|field| field.attrs.skip_deserializing()); + + let match_keys = if cattrs.deny_unknown_fields() && all_skipped { + quote! { + // FIXME: Once we drop support for Rust 1.15: + // let _serde::export::None::<__Field> = try!(_serde::de::MapAccess::next_key(&mut __map)); + _serde::export::Option::map( + try!(_serde::de::MapAccess::next_key::<__Field>(&mut __map)), + |__impossible| match __impossible {}); + } + } else { + quote! { + while let _serde::export::Some(__key) = try!(_serde::de::MapAccess::next_key::<__Field>(&mut __map)) { + match __key { + #(#value_arms_from)* + #ignored_arm + } + } + } + }; + + let check_flags = fields_names + .iter() + .filter(|&&(field, _)| !field.attrs.skip_deserializing()) + .map( + |&(field, ref name)| { + let missing_expr = expr_is_missing(&field, cattrs); + // If missing_expr unconditionally returns an error, don't try + // to assign its value to self.dest. Maybe this could be handled + // more elegantly. + if missing_expr.as_ref().as_str().starts_with("return ") { + let missing_expr = Stmts(missing_expr); + quote! { + if !#name { + #missing_expr; + } + } + } else { + let field_name = &field.ident; + let missing_expr = Expr(missing_expr); + quote! { + if !#name { + self.dest.#field_name = #missing_expr; + }; + } + } + }, + ); + + let this = ¶ms.this; + let (_, _, ty_generics, _) = split_with_de_lifetime(params,); + + let let_default = match *cattrs.default() { + attr::Default::Default => { + Some( + quote!( + let __default: #this #ty_generics = _serde::export::Default::default(); + ), + ) + } + attr::Default::Path(ref path) => { + Some( + quote!( + let __default: #this #ty_generics = #path(); + ), + ) + } + attr::Default::None => { + // We don't need the default value, to prevent an unused variable warning + // we'll leave the line empty. + None + } + }; + + quote_block! { + #(#let_flags)* + + #match_keys + + #let_default + + #(#check_flags)* + + _serde::export::Ok(()) + } +} + fn field_i(i: usize) -> Ident { Ident::new(format!("__field{}", i)) } @@ -1901,6 +2442,8 @@ fn expr_is_missing(field: &Field, cattrs: &attr::Container) -> Fragment { } struct DeImplGenerics<'a>(&'a Parameters); +#[cfg(feature = "deserialize_from")] +struct DeFromImplGenerics<'a>(&'a Parameters); impl<'a> ToTokens for DeImplGenerics<'a> { fn to_tokens(&self, tokens: &mut Tokens) { @@ -1913,7 +2456,38 @@ impl<'a> ToTokens for DeImplGenerics<'a> { } } +#[cfg(feature = "deserialize_from")] +impl<'a> ToTokens for DeFromImplGenerics<'a> { + fn to_tokens(&self, tokens: &mut Tokens) { + let dest_lifetime = dest_lifetime(); + let mut generics = self.0.generics.clone(); + + // Add lifetime for `&'dest mut Self, and `'a: 'dest` + for lifetime in &mut generics.lifetimes { + lifetime.bounds.push(dest_lifetime.lifetime.clone()); + } + for generic in &mut generics.ty_params { + generic.bounds.push(syn::TyParamBound::Region(dest_lifetime.lifetime.clone())); + } + generics.lifetimes.insert(0, dest_lifetime); + if let Some(de_lifetime) = self.0.borrowed.de_lifetime_def() { + generics.lifetimes.insert(0, de_lifetime); + } + let (impl_generics, _, _) = generics.split_for_impl(); + impl_generics.to_tokens(tokens); + } +} + +#[cfg(feature = "deserialize_from")] +impl<'a> DeImplGenerics<'a> { + fn with_dest(self) -> DeFromImplGenerics<'a> { + DeFromImplGenerics(self.0) + } +} + struct DeTyGenerics<'a>(&'a Parameters); +#[cfg(feature = "deserialize_from")] +struct DeFromTyGenerics<'a>(&'a Parameters); impl<'a> ToTokens for DeTyGenerics<'a> { fn to_tokens(&self, tokens: &mut Tokens) { @@ -1928,6 +2502,34 @@ impl<'a> ToTokens for DeTyGenerics<'a> { } } +#[cfg(feature = "deserialize_from")] +impl<'a> ToTokens for DeFromTyGenerics<'a> { + fn to_tokens(&self, tokens: &mut Tokens) { + let mut generics = self.0.generics.clone(); + generics.lifetimes.insert(0, dest_lifetime()); + + if self.0.borrowed.de_lifetime_def().is_some() { + generics + .lifetimes + .insert(0, syn::LifetimeDef::new("'de")); + } + let (_, ty_generics, _) = generics.split_for_impl(); + ty_generics.to_tokens(tokens); + } +} + +#[cfg(feature = "deserialize_from")] +impl<'a> DeTyGenerics<'a> { + fn with_dest(self) -> DeFromTyGenerics<'a> { + DeFromTyGenerics(self.0) + } +} + +#[cfg(feature = "deserialize_from")] +fn dest_lifetime() -> syn::LifetimeDef { + syn::LifetimeDef::new("'dest") +} + fn split_with_de_lifetime(params: &Parameters,) -> (DeImplGenerics, DeTyGenerics, syn::TyGenerics, &syn::WhereClause) { let de_impl_generics = DeImplGenerics(¶ms); diff --git a/serde_derive/src/fragment.rs b/serde_derive/src/fragment.rs index 58cf0a2ca..c882bcf9c 100644 --- a/serde_derive/src/fragment.rs +++ b/serde_derive/src/fragment.rs @@ -73,3 +73,12 @@ impl ToTokens for Match { } } } + +impl AsRef for Fragment { + fn as_ref(&self) -> &Tokens { + match *self { + Fragment::Expr(ref expr) => expr, + Fragment::Block(ref block) => block, + } + } +} diff --git a/serde_derive_internals/src/attr.rs b/serde_derive_internals/src/attr.rs index 9124c0ee9..efe9dff85 100644 --- a/serde_derive_internals/src/attr.rs +++ b/serde_derive_internals/src/attr.rs @@ -164,6 +164,15 @@ pub enum Identifier { Variant, } +impl Identifier { + pub fn is_some(self) -> bool { + match self { + Identifier::No => false, + Identifier::Field | Identifier::Variant => true, + } + } +} + impl Container { /// Extract out the `#[serde(...)]` attributes from an item. pub fn from_ast(cx: &Ctxt, item: &syn::DeriveInput) -> Self { diff --git a/serde_test/src/assert.rs b/serde_test/src/assert.rs index 72116b71c..903a56164 100644 --- a/serde_test/src/assert.rs +++ b/serde_test/src/assert.rs @@ -184,11 +184,27 @@ where T: Deserialize<'de> + PartialEq + Debug, { let mut de = Deserializer::new(tokens); - match T::deserialize(&mut de) { - Ok(v) => assert_eq!(v, *value), + let mut deserialized_val = match T::deserialize(&mut de) { + Ok(v) => { + assert_eq!(v, *value); + v + } Err(e) => panic!("tokens failed to deserialize: {}", e), + }; + if de.remaining() > 0 { + panic!("{} remaining tokens", de.remaining()); } + // Do the same thing for deserialize_from. This isn't *great* because a no-op + // impl of deserialize_from can technically succeed here. Still, this should + // catch a lot of junk. + let mut de = Deserializer::new(tokens); + match deserialized_val.deserialize_from(&mut de) { + Ok(()) => { + assert_eq!(deserialized_val, *value); + } + Err(e) => panic!("tokens failed to deserialize_from: {}", e), + } if de.remaining() > 0 { panic!("{} remaining tokens", de.remaining()); } diff --git a/test_suite/Cargo.toml b/test_suite/Cargo.toml index 74c4a7b5b..1bd4a99c6 100644 --- a/test_suite/Cargo.toml +++ b/test_suite/Cargo.toml @@ -11,7 +11,7 @@ unstable = ["serde/unstable", "compiletest_rs"] fnv = "1.0" rustc-serialize = "0.3.16" serde = { path = "../serde", features = ["rc"] } -serde_derive = { path = "../serde_derive" } +serde_derive = { path = "../serde_derive", features = ["deserialize_from"] } serde_test = { path = "../serde_test" } [dependencies] diff --git a/test_suite/tests/test_de.rs b/test_suite/tests/test_de.rs index 4c801d627..58bcf4648 100644 --- a/test_suite/tests/test_de.rs +++ b/test_suite/tests/test_de.rs @@ -86,6 +86,12 @@ struct StructSkipDefault { #[serde(skip_deserializing)] a: i32, } +#[derive(PartialEq, Debug, Deserialize)] +#[serde(default)] +struct StructSkipDefaultGeneric { + #[serde(skip_deserializing)] t: T, +} + impl Default for StructSkipDefault { fn default() -> Self { StructSkipDefault {