Skip to content

Commit 16e3f10

Browse files
committed
fix(postgres): add missing type resolution for arrays by name
1 parent efbf572 commit 16e3f10

File tree

19 files changed

+333
-84
lines changed

19 files changed

+333
-84
lines changed

Cargo.lock

+14-14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ bit-vec = "0.6.3"
135135
chrono = { version = "0.4.22", default-features = false }
136136
ipnetwork = "0.20.0"
137137
mac_address = "1.1.5"
138-
rust_decimal = "1.26.1"
138+
rust_decimal = { version = "1.26.1", default-features = false, features = ["std"] }
139139
time = { version = "0.3.36", features = ["formatting", "parsing", "macros"] }
140140
uuid = "1.1.2"
141141

sqlx-core/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ uuid = { workspace = true, optional = true }
5151

5252
async-io = { version = "1.9.0", optional = true }
5353
paste = "1.0.6"
54-
ahash = "0.8.7"
5554
atoi = "2.0"
5655

5756
bytes = "1.1.0"
@@ -88,6 +87,7 @@ bstr = { version = "1.0", default-features = false, features = ["std"], optional
8887
hashlink = "0.9.0"
8988
indexmap = "2.0"
9089
event-listener = "5.2.0"
90+
hashbrown = "0.14.5"
9191

9292
[dev-dependencies]
9393
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }

sqlx-core/src/ext/ustr.rs

+14
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ impl UStr {
1717
pub fn new(s: &str) -> Self {
1818
UStr::Shared(Arc::from(s.to_owned()))
1919
}
20+
21+
/// Apply [str::strip_prefix], without copying if possible.
22+
pub fn strip_prefix(this: &Self, prefix: &str) -> Option<Self> {
23+
match this {
24+
UStr::Static(s) => s.strip_prefix(prefix).map(Self::Static),
25+
UStr::Shared(s) => s.strip_prefix(prefix).map(|s| Self::Shared(s.into())),
26+
}
27+
}
2028
}
2129

2230
impl Deref for UStr {
@@ -60,6 +68,12 @@ impl From<&'static str> for UStr {
6068
}
6169
}
6270

71+
impl<'a> From<&'a UStr> for UStr {
72+
fn from(value: &'a UStr) -> Self {
73+
value.clone()
74+
}
75+
}
76+
6377
impl From<String> for UStr {
6478
#[inline]
6579
fn from(s: String) -> Self {

sqlx-core/src/lib.rs

+2-5
Original file line numberDiff line numberDiff line change
@@ -95,18 +95,15 @@ pub mod testing;
9595

9696
pub use error::{Error, Result};
9797

98-
/// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance.
99-
pub use ahash::AHashMap as HashMap;
10098
pub use either::Either;
99+
pub use hashbrown::{hash_map, HashMap};
101100
pub use indexmap::IndexMap;
102101
pub use percent_encoding;
103102
pub use smallvec::SmallVec;
104103
pub use url::{self, Url};
105104

106105
pub use bytes;
107106

108-
//type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
109-
110107
/// Helper module to get drivers compiling again that used to be in this crate,
111108
/// to avoid having to replace tons of `use crate::<...>` imports.
112109
///
@@ -119,6 +116,6 @@ pub mod driver_prelude {
119116
};
120117

121118
pub use crate::error::{Error, Result};
122-
pub use crate::HashMap;
119+
pub use crate::{hash_map, HashMap};
123120
pub use either::Either;
124121
}

sqlx-core/src/type_info.rs

+10
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@ pub trait TypeInfo: Debug + Display + Clone + PartialEq<Self> + Send + Sync {
99
/// should be a rough approximation of how they are written in SQL in the given database.
1010
fn name(&self) -> &str;
1111

12+
/// Return `true` if `self` and `other` represent mutually compatible types.
13+
///
14+
/// Defaults to `self == other`.
15+
fn type_compatible(&self, other: &Self) -> bool
16+
where
17+
Self: Sized,
18+
{
19+
self == other
20+
}
21+
1222
#[doc(hidden)]
1323
fn is_void(&self) -> bool {
1424
false

sqlx-core/src/types/mod.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,10 @@ pub trait Type<DB: Database> {
210210
///
211211
/// When binding arguments with `query!` or `query_as!`, this method is consulted to determine
212212
/// if the Rust type is acceptable.
213+
///
214+
/// Defaults to checking [`TypeInfo::type_compatible()`].
213215
fn compatible(ty: &DB::TypeInfo) -> bool {
214-
*ty == Self::type_info()
216+
Self::type_info().type_compatible(ty)
215217
}
216218
}
217219

sqlx-macros-core/src/derives/type.rs

+36-16
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,42 @@ use syn::{
1414
pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
1515
let attrs = parse_container_attributes(&input.attrs)?;
1616
match &input.data {
17+
// Newtype structs:
18+
// struct Foo(i32);
1719
Data::Struct(DataStruct {
1820
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
1921
..
20-
}) if unnamed.len() == 1 => {
21-
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
22+
}) => {
23+
if unnamed.len() == 1 {
24+
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
25+
} else {
26+
Err(syn::Error::new_spanned(
27+
input,
28+
"structs with zero or more than one unnamed field are not supported",
29+
))
30+
}
2231
}
23-
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
24-
Some(_) => expand_derive_has_sql_type_weak_enum(input, variants),
25-
None => expand_derive_has_sql_type_strong_enum(input, variants),
26-
},
32+
// Record types
33+
// struct Foo { foo: i32, bar: String }
2734
Data::Struct(DataStruct {
2835
fields: Fields::Named(FieldsNamed { named, .. }),
2936
..
3037
}) => expand_derive_has_sql_type_struct(input, named),
31-
Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")),
32-
Data::Struct(DataStruct {
33-
fields: Fields::Unnamed(..),
34-
..
35-
}) => Err(syn::Error::new_spanned(
36-
input,
37-
"structs with zero or more than one unnamed field are not supported",
38-
)),
3938
Data::Struct(DataStruct {
4039
fields: Fields::Unit,
4140
..
4241
}) => Err(syn::Error::new_spanned(
4342
input,
4443
"unit structs are not supported",
4544
)),
45+
46+
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
47+
// Enums that encode to/from integers (weak enums)
48+
Some(_) => expand_derive_has_sql_type_weak_enum(input, variants),
49+
// Enums that decode to/from strings (strong enums)
50+
None => expand_derive_has_sql_type_strong_enum(input, variants),
51+
},
52+
Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")),
4653
}
4754
}
4855

@@ -148,9 +155,10 @@ fn expand_derive_has_sql_type_weak_enum(
148155

149156
if cfg!(feature = "postgres") && !attrs.no_pg_array {
150157
ts.extend(quote!(
158+
#[automatically_derived]
151159
impl ::sqlx::postgres::PgHasArrayType for #ident {
152160
fn array_type_info() -> ::sqlx::postgres::PgTypeInfo {
153-
<#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info()
161+
<#repr as ::sqlx::postgres::PgHasArrayType>::array_type_info()
154162
}
155163
}
156164
));
@@ -197,9 +205,10 @@ fn expand_derive_has_sql_type_strong_enum(
197205

198206
if !attributes.no_pg_array {
199207
tts.extend(quote!(
208+
#[automatically_derived]
200209
impl ::sqlx::postgres::PgHasArrayType for #ident {
201210
fn array_type_info() -> ::sqlx::postgres::PgTypeInfo {
202-
<#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info()
211+
::sqlx::postgres::PgTypeInfo::array_of(#ty_name)
203212
}
204213
}
205214
));
@@ -244,6 +253,17 @@ fn expand_derive_has_sql_type_struct(
244253
}
245254
}
246255
));
256+
257+
if !attributes.no_pg_array {
258+
tts.extend(quote!(
259+
#[automatically_derived]
260+
impl ::sqlx::postgres::PgHasArrayType for #ident {
261+
fn array_type_info() -> ::sqlx::postgres::PgTypeInfo {
262+
::sqlx::postgres::PgTypeInfo::array_of(#ty_name)
263+
}
264+
}
265+
));
266+
}
247267
}
248268

249269
Ok(tts)

sqlx-postgres/Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,8 @@ workspace = true
7171
# We use JSON in the driver implementation itself so there's no reason not to enable it here.
7272
features = ["json"]
7373

74+
[dev-dependencies]
75+
sqlx.workspace = true
76+
7477
[target.'cfg(target_os = "windows")'.dependencies]
7578
etcetera = "0.8.0"

sqlx-postgres/src/arguments.rs

+26-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
use std::fmt::{self, Write};
22
use std::ops::{Deref, DerefMut};
3+
use std::sync::Arc;
34

45
use crate::encode::{Encode, IsNull};
56
use crate::error::Error;
67
use crate::ext::ustr::UStr;
78
use crate::types::Type;
89
use crate::{PgConnection, PgTypeInfo, Postgres};
910

11+
use crate::type_info::PgArrayOf;
1012
pub(crate) use sqlx_core::arguments::Arguments;
1113
use sqlx_core::error::BoxDynError;
1214

@@ -41,7 +43,12 @@ pub struct PgArgumentBuffer {
4143
// This is done for Records and Arrays as the OID is needed well before we are in an async
4244
// function and can just ask postgres.
4345
//
44-
type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }>
46+
type_holes: Vec<(usize, HoleKind)>, // Vec<{ offset, type_name }>
47+
}
48+
49+
enum HoleKind {
50+
Type { name: UStr },
51+
Array(Arc<PgArrayOf>),
4552
}
4653

4754
struct Patch {
@@ -106,8 +113,11 @@ impl PgArguments {
106113
(patch.callback)(buf, ty);
107114
}
108115

109-
for (offset, name) in type_holes {
110-
let oid = conn.fetch_type_id_by_name(name).await?;
116+
for (offset, kind) in type_holes {
117+
let oid = match kind {
118+
HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
119+
HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
120+
};
111121
buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
112122
}
113123

@@ -186,7 +196,19 @@ impl PgArgumentBuffer {
186196
let offset = self.len();
187197

188198
self.extend_from_slice(&0_u32.to_be_bytes());
189-
self.type_holes.push((offset, type_name.clone()));
199+
self.type_holes.push((
200+
offset,
201+
HoleKind::Type {
202+
name: type_name.clone(),
203+
},
204+
));
205+
}
206+
207+
pub(crate) fn patch_array_type(&mut self, array: Arc<PgArrayOf>) {
208+
let offset = self.len();
209+
210+
self.extend_from_slice(&0_u32.to_be_bytes());
211+
self.type_holes.push((offset, HoleKind::Array(array)));
190212
}
191213

192214
fn snapshot(&self) -> PgArgumentBufferSnapshot {

0 commit comments

Comments
 (0)