Skip to content

Commit 4939eff

Browse files
committed
feat: Add try_from attribute for FromRow
1 parent 20877d8 commit 4939eff

File tree

3 files changed

+127
-16
lines changed

3 files changed

+127
-16
lines changed

sqlx-macros/src/derives/attributes.rs

+8
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ pub struct SqlxChildAttributes {
7171
pub rename: Option<String>,
7272
pub default: bool,
7373
pub flatten: bool,
74+
pub try_from: Option<Ident>,
7475
}
7576

7677
pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
@@ -178,6 +179,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContai
178179
pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttributes> {
179180
let mut rename = None;
180181
let mut default = false;
182+
let mut try_from = None;
181183
let mut flatten = false;
182184

183185
for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) {
@@ -194,6 +196,11 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
194196
lit: Lit::Str(val),
195197
..
196198
}) if path.is_ident("rename") => try_set!(rename, val.value(), value),
199+
Meta::NameValue(MetaNameValue {
200+
path,
201+
lit: Lit::Str(val),
202+
..
203+
}) if path.is_ident("try_from") => try_set!(try_from, val.parse()?, value),
197204
Meta::Path(path) if path.is_ident("default") => default = true,
198205
Meta::Path(path) if path.is_ident("flatten") => flatten = true,
199206
u => fail!(u, "unexpected attribute"),
@@ -208,6 +215,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
208215
rename,
209216
default,
210217
flatten,
218+
try_from,
211219
})
212220
}
213221

sqlx-macros/src/derives/row.rs

+39-16
Original file line numberDiff line numberDiff line change
@@ -72,22 +72,45 @@ fn expand_derive_from_row_struct(
7272
let attributes = parse_child_attributes(&field.attrs).unwrap();
7373
let ty = &field.ty;
7474

75-
let expr: Expr = if attributes.flatten {
76-
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
77-
parse_quote!(#ty::from_row(row))
78-
} else {
79-
predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
80-
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));
81-
82-
let id_s = attributes
83-
.rename
84-
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
85-
.map(|s| match container_attributes.rename_all {
86-
Some(pattern) => rename_all(&s, pattern),
87-
None => s,
88-
})
89-
.unwrap();
90-
parse_quote!(row.try_get(#id_s))
75+
let expr: Expr = match (attributes.flatten, attributes.try_from) {
76+
(true, None) => {
77+
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
78+
parse_quote!(#ty::from_row(row))
79+
}
80+
(false, None) => {
81+
predicates
82+
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
83+
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));
84+
85+
let id_s = attributes
86+
.rename
87+
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
88+
.map(|s| match container_attributes.rename_all {
89+
Some(pattern) => rename_all(&s, pattern),
90+
None => s,
91+
})
92+
.unwrap();
93+
parse_quote!(row.try_get(#id_s))
94+
}
95+
(true,Some(try_from)) => {
96+
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
97+
parse_quote!(#try_from::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
98+
}
99+
(false,Some(try_from)) => {
100+
predicates
101+
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
102+
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));
103+
104+
let id_s = attributes
105+
.rename
106+
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
107+
.map(|s| match container_attributes.rename_all {
108+
Some(pattern) => rename_all(&s, pattern),
109+
None => s,
110+
})
111+
.unwrap();
112+
parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
113+
}
91114
};
92115

93116
if attributes.default {

tests/mysql/macros.rs

+80
Original file line numberDiff line numberDiff line change
@@ -354,4 +354,84 @@ async fn test_column_override_exact_enum() -> anyhow::Result<()> {
354354
Ok(())
355355
}
356356

357+
#[sqlx_macros::test]
358+
async fn test_try_from_attr_for_native_type() -> anyhow::Result<()> {
359+
#[derive(sqlx::FromRow)]
360+
struct Record {
361+
#[sqlx(try_from = "i64")]
362+
id: u64,
363+
}
364+
365+
let mut conn = new::<MySql>().await?;
366+
let (mut conn, id) = with_test_row(&mut conn).await?;
367+
368+
let record = sqlx::query_as::<_, Record>("select id from tweet")
369+
.fetch_one(&mut conn)
370+
.await?;
371+
372+
assert_eq!(record.id, id.0 as u64);
373+
374+
Ok(())
375+
}
376+
377+
#[sqlx_macros::test]
378+
async fn test_try_from_attr_for_custom_type() -> anyhow::Result<()> {
379+
#[derive(sqlx::FromRow)]
380+
struct Record {
381+
#[sqlx(try_from = "i64")]
382+
id: Id,
383+
}
384+
385+
#[derive(Debug, PartialEq)]
386+
struct Id(i64);
387+
impl std::convert::TryFrom<i64> for Id {
388+
type Error = std::io::Error;
389+
fn try_from(value: i64) -> Result<Self, Self::Error> {
390+
Ok(Id(value))
391+
}
392+
}
393+
394+
let mut conn = new::<MySql>().await?;
395+
let (mut conn, id) = with_test_row(&mut conn).await?;
396+
397+
let record = sqlx::query_as::<_, Record>("select id from tweet")
398+
.fetch_one(&mut conn)
399+
.await?;
400+
401+
assert_eq!(record.id, Id(id.0));
402+
403+
Ok(())
404+
}
405+
406+
#[sqlx_macros::test]
407+
async fn test_try_from_attr_with_flatten() -> anyhow::Result<()> {
408+
#[derive(sqlx::FromRow)]
409+
struct Record {
410+
#[sqlx(try_from = "Id", flatten)]
411+
id: u64,
412+
}
413+
414+
#[derive(Debug, PartialEq, sqlx::FromRow)]
415+
struct Id {
416+
id: i64,
417+
};
418+
impl std::convert::TryFrom<Id> for u64 {
419+
type Error = std::io::Error;
420+
fn try_from(value: Id) -> Result<Self, Self::Error> {
421+
Ok(value.id as u64)
422+
}
423+
}
424+
425+
let mut conn = new::<MySql>().await?;
426+
let (mut conn, id) = with_test_row(&mut conn).await?;
427+
428+
let record = sqlx::query_as::<_, Record>("select id from tweet")
429+
.fetch_one(&mut conn)
430+
.await?;
431+
432+
assert_eq!(record.id, id.0 as u64);
433+
434+
Ok(())
435+
}
436+
357437
// we don't emit bind parameter type-checks for MySQL so testing the overrides is redundant

0 commit comments

Comments
 (0)