Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenEnum: type-safe wrapper for enum field values #1061

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ impl<'a> CodeGenerator<'a> {

fn append_field(&mut self, fq_message_name: &str, field: &Field) {
let type_ = field.descriptor.r#type();
let repeated = field.descriptor.label == Some(Label::Repeated as i32);
let repeated = field.descriptor.label.and_then(|v| v.known()) == Some(Label::Repeated);
let deprecated = self.deprecated(&field.descriptor);
let optional = self.optional(&field.descriptor);
let boxed = self.boxed(&field.descriptor, fq_message_name, None);
Expand Down Expand Up @@ -947,7 +947,7 @@ impl<'a> CodeGenerator<'a> {
Type::Double => String::from("f64"),
Type::Uint32 | Type::Fixed32 => String::from("u32"),
Type::Uint64 | Type::Fixed64 => String::from("u64"),
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
Type::Int32 | Type::Sfixed32 | Type::Sint32 => String::from("i32"),
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String => format!("{}::alloc::string::String", prost_path(self.config)),
Expand All @@ -960,6 +960,11 @@ impl<'a> CodeGenerator<'a> {
.rust_type()
.to_owned(),
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
Type::Enum => format!(
"{}::OpenEnum<{}>",
prost_path(self.config),
self.resolve_ident(field.type_name())
),
}
}

Expand Down Expand Up @@ -1063,7 +1068,7 @@ impl<'a> CodeGenerator<'a> {
fq_message_name: &str,
oneof: Option<&str>,
) -> bool {
let repeated = field.label == Some(Label::Repeated as i32);
let repeated = field.label.and_then(|v| v.known()) == Some(Label::Repeated);
let fd_type = field.r#type();
if !repeated
&& (fd_type == Type::Message || fd_type == Type::Group)
Expand Down
24 changes: 9 additions & 15 deletions prost-derive/src/field/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ impl Field {
let module = self.map_ty.module();
match &self.value_ty {
ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
let default = quote!(#ty::default() as i32);
let default = quote!(::prost::OpenEnum::from(#ty::default()));
quote! {
::prost::encoding::#module::encode_with_default(
#ke,
#kl,
::prost::encoding::int32::encode,
::prost::encoding::int32::encoded_len,
::prost::encoding::enumeration::encode,
::prost::encoding::enumeration::encoded_len,
&(#default),
#tag,
&#ident,
Expand Down Expand Up @@ -184,11 +184,11 @@ impl Field {
let module = self.map_ty.module();
match &self.value_ty {
ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
let default = quote!(#ty::default() as i32);
let default = quote!(::prost::OpenEnum::from(#ty::default()));
quote! {
::prost::encoding::#module::merge_with_default(
#km,
::prost::encoding::int32::merge,
::prost::encoding::enumeration::merge,
#default,
&mut #ident,
buf,
Expand Down Expand Up @@ -221,11 +221,11 @@ impl Field {
let module = self.map_ty.module();
match &self.value_ty {
ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
let default = quote!(#ty::default() as i32);
let default = quote!(::prost::OpenEnum::from(#ty::default()));
quote! {
::prost::encoding::#module::encoded_len_with_default(
#kl,
::prost::encoding::int32::encoded_len,
::prost::encoding::enumeration::encoded_len,
&(#default),
#tag,
&#ident,
Expand Down Expand Up @@ -275,17 +275,11 @@ impl Field {
Some(quote! {
#[doc=#get_doc]
pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> {
self.#ident.get(#take_ref key).cloned().and_then(|x| {
let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
result.ok()
})
self.#ident.get(#take_ref key).cloned().and_then(|x| { x.known() })
}
#[doc=#insert_doc]
pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> {
self.#ident.insert(key, value as i32).and_then(|x| {
let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
result.ok()
})
self.#ident.insert(key, value.into()).and_then(|x| { x.known() })
}
})
} else {
Expand Down
36 changes: 16 additions & 20 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,10 @@ impl Field {
fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream {
if let Ty::Enumeration(ref ty) = self.ty {
quote! {
struct #wrap_name<'a>(&'a i32);
struct #wrap_name<'a>(&'a ::prost::OpenEnum<#ty>);
impl<'a> ::core::fmt::Debug for #wrap_name<'a> {
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
let res: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(*self.0);
let res = self.0.known_or(());
match res {
Err(_) => ::core::fmt::Debug::fmt(&self.0, f),
Ok(en) => ::core::fmt::Debug::fmt(&en, f),
Expand Down Expand Up @@ -296,12 +296,12 @@ impl Field {
quote! {
#[doc=#get_doc]
pub fn #get(&self) -> #ty {
::core::convert::TryFrom::try_from(self.#ident).unwrap_or(#default)
self.#ident.unwrap_or(#default)
}

#[doc=#set_doc]
pub fn #set(&mut self, value: #ty) {
self.#ident = value as i32;
self.#ident = value.into();
}
}
}
Expand All @@ -314,15 +314,12 @@ impl Field {
quote! {
#[doc=#get_doc]
pub fn #get(&self) -> #ty {
self.#ident.and_then(|x| {
let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
result.ok()
}).unwrap_or(#default)
self.#ident.and_then(|x| { x.known() }).unwrap_or(#default)
}

#[doc=#set_doc]
pub fn #set(&mut self, value: #ty) {
self.#ident = ::core::option::Option::Some(value as i32);
self.#ident = ::core::option::Option::Some(value.into());
}
}
}
Expand All @@ -333,20 +330,18 @@ impl Field {
);
let push = Ident::new(&format!("push_{}", ident_str), Span::call_site());
let push_doc = format!("Appends the provided enum value to `{}`.", ident_str);
let wrapped_ty = quote!(::prost::OpenEnum<#ty>);
quote! {
#[doc=#iter_doc]
pub fn #get(&self) -> ::core::iter::FilterMap<
::core::iter::Cloned<::core::slice::Iter<i32>>,
fn(i32) -> ::core::option::Option<#ty>,
::core::iter::Cloned<::core::slice::Iter<#wrapped_ty>>,
fn(#wrapped_ty) -> ::core::option::Option<#ty>,
> {
self.#ident.iter().cloned().filter_map(|x| {
let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
result.ok()
})
self.#ident.iter().cloned().filter_map(|x| { x.known() })
}
#[doc=#push_doc]
pub fn #push(&mut self, value: #ty) {
self.#ident.push(value as i32);
self.#ident.push(value.into());
}
}
}
Expand Down Expand Up @@ -532,13 +527,14 @@ impl Ty {
match self {
Ty::String => quote!(::prost::alloc::string::String),
Ty::Bytes(ty) => ty.rust_type(),
Ty::Enumeration(path) => quote!(::prost::OpenEnum<#path>),
_ => self.rust_ref_type(),
}
}

// TODO: rename to 'ref_type'
pub fn rust_ref_type(&self) -> TokenStream {
match *self {
match self {
Ty::Double => quote!(f64),
Ty::Float => quote!(f32),
Ty::Int32 => quote!(i32),
Expand All @@ -554,13 +550,13 @@ impl Ty {
Ty::Bool => quote!(bool),
Ty::String => quote!(&str),
Ty::Bytes(..) => quote!(&[u8]),
Ty::Enumeration(..) => quote!(i32),
Ty::Enumeration(..) => unreachable!("an enum should never be queried for its ref type"),
}
}

pub fn module(&self) -> Ident {
match *self {
Ty::Enumeration(..) => Ident::new("int32", Span::call_site()),
Ty::Enumeration(..) => Ident::new("enumeration", Span::call_site()),
_ => Ident::new(self.as_str(), Span::call_site()),
}
}
Expand Down Expand Up @@ -798,7 +794,7 @@ impl DefaultValue {

pub fn typed(&self) -> TokenStream {
if let DefaultValue::Enumeration(_) = *self {
quote!(#self as i32)
quote!(::prost::OpenEnum::from(#self))
} else {
quote!(#self)
}
Expand Down
30 changes: 17 additions & 13 deletions prost-types/src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ pub struct FieldDescriptorProto {
#[prost(int32, optional, tag = "3")]
pub number: ::core::option::Option<i32>,
#[prost(enumeration = "field_descriptor_proto::Label", optional, tag = "4")]
pub label: ::core::option::Option<i32>,
pub label: ::core::option::Option<::prost::OpenEnum<field_descriptor_proto::Label>>,
/// If type_name is set, this need not be set. If both this and type_name
/// are set, this must be one of TYPE_ENUM, TYPE_MESSAGE or TYPE_GROUP.
#[prost(enumeration = "field_descriptor_proto::Type", optional, tag = "5")]
pub r#type: ::core::option::Option<i32>,
pub r#type: ::core::option::Option<::prost::OpenEnum<field_descriptor_proto::Type>>,
/// For message and enum types, this is the name of the type. If the name
/// starts with a '.', it is fully-qualified. Otherwise, C++-like scoping
/// rules are used to find the type (i.e. first the nested types within this
Expand Down Expand Up @@ -484,7 +484,9 @@ pub struct FileOptions {
tag = "9",
default = "Speed"
)]
pub optimize_for: ::core::option::Option<i32>,
pub optimize_for: ::core::option::Option<
::prost::OpenEnum<file_options::OptimizeMode>,
>,
/// Sets the Go package where structs generated from this .proto will be
/// placed. If omitted, the Go package will be derived from the following:
///
Expand Down Expand Up @@ -680,7 +682,7 @@ pub struct FieldOptions {
tag = "1",
default = "String"
)]
pub ctype: ::core::option::Option<i32>,
pub ctype: ::core::option::Option<::prost::OpenEnum<field_options::CType>>,
/// The packed option can be enabled for repeated primitive fields to enable
/// a more efficient representation on the wire. Rather than repeatedly
/// writing the tag and type for each element, the entire array is encoded as
Expand All @@ -705,7 +707,7 @@ pub struct FieldOptions {
tag = "6",
default = "JsNormal"
)]
pub jstype: ::core::option::Option<i32>,
pub jstype: ::core::option::Option<::prost::OpenEnum<field_options::JsType>>,
/// Should this field be parsed lazily? Lazy applies only to message-type
/// fields. It means that when the outer message is initially parsed, the
/// inner message's contents will not be parsed but instead stored in encoded
Expand Down Expand Up @@ -898,7 +900,9 @@ pub struct MethodOptions {
tag = "34",
default = "IdempotencyUnknown"
)]
pub idempotency_level: ::core::option::Option<i32>,
pub idempotency_level: ::core::option::Option<
::prost::OpenEnum<method_options::IdempotencyLevel>,
>,
/// The parser stores options it doesn't recognize here. See above.
#[prost(message, repeated, tag = "999")]
pub uninterpreted_option: ::prost::alloc::vec::Vec<UninterpretedOption>,
Expand Down Expand Up @@ -1332,18 +1336,18 @@ pub struct Type {
pub source_context: ::core::option::Option<SourceContext>,
/// The source syntax.
#[prost(enumeration = "Syntax", tag = "6")]
pub syntax: i32,
pub syntax: ::prost::OpenEnum<Syntax>,
}
/// A single field of a message type.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Field {
/// The field type.
#[prost(enumeration = "field::Kind", tag = "1")]
pub kind: i32,
pub kind: ::prost::OpenEnum<field::Kind>,
/// The field cardinality.
#[prost(enumeration = "field::Cardinality", tag = "2")]
pub cardinality: i32,
pub cardinality: ::prost::OpenEnum<field::Cardinality>,
/// The field number.
#[prost(int32, tag = "3")]
pub number: i32,
Expand Down Expand Up @@ -1546,7 +1550,7 @@ pub struct Enum {
pub source_context: ::core::option::Option<SourceContext>,
/// The source syntax.
#[prost(enumeration = "Syntax", tag = "5")]
pub syntax: i32,
pub syntax: ::prost::OpenEnum<Syntax>,
}
/// Enum value definition.
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down Expand Up @@ -1661,7 +1665,7 @@ pub struct Api {
pub mixins: ::prost::alloc::vec::Vec<Mixin>,
/// The source syntax of the service.
#[prost(enumeration = "Syntax", tag = "7")]
pub syntax: i32,
pub syntax: ::prost::OpenEnum<Syntax>,
}
/// Method represents a method of an API interface.
#[allow(clippy::derive_partial_eq_without_eq)]
Expand All @@ -1687,7 +1691,7 @@ pub struct Method {
pub options: ::prost::alloc::vec::Vec<Option>,
/// The source syntax of this method.
#[prost(enumeration = "Syntax", tag = "7")]
pub syntax: i32,
pub syntax: ::prost::OpenEnum<Syntax>,
}
/// Declares an API Interface to be included in this interface. The including
/// interface must redeclare all the methods from the included interface, but
Expand Down Expand Up @@ -2137,7 +2141,7 @@ pub mod value {
pub enum Kind {
/// Represents a null value.
#[prost(enumeration = "super::NullValue", tag = "1")]
NullValue(i32),
NullValue(::prost::OpenEnum<super::NullValue>),
/// Represents a double value.
#[prost(double, tag = "2")]
NumberValue(f64),
Expand Down
Loading
Loading