
Commit 68da5ae

Box PgConnection fields (#3529)
* Update PgConnection code
* rustfmt
1 parent: 81298b8 · commit: 68da5ae

7 files changed: +137 −104 lines

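The diffs below mechanically reroute every field access through a new `inner` field. The struct definitions themselves live in changed files that are not expanded on this page, so the following is only a sketch of the resulting shape, reconstructed from the constructor in establish.rs below; the placeholder type aliases stand in for sqlx's real PgStream, StatementCache, PgTypeInfo, Oid and LogSettings types.

use std::collections::HashMap;

// Placeholder aliases so the sketch is self-contained; the real field types
// are sqlx-internal and may differ.
type PgStream = ();
type TransactionStatus = u8;
type StatementId = u32;
type StatementCache = HashMap<String, u32>;
type Oid = u32;
type PgTypeInfo = String;
type LogSettings = ();

// After this commit, the connection itself holds a single Box and all
// per-connection state moves into the boxed struct.
pub struct PgConnection {
    pub(crate) inner: Box<PgConnectionInner>,
}

pub(crate) struct PgConnectionInner {
    pub(crate) stream: PgStream,
    pub(crate) process_id: u32,
    pub(crate) secret_key: u32,
    pub(crate) transaction_status: TransactionStatus,
    pub(crate) transaction_depth: usize,
    pub(crate) pending_ready_for_query_count: usize,
    pub(crate) next_statement_id: StatementId,
    pub(crate) cache_statement: StatementCache,
    pub(crate) cache_type_oid: HashMap<String, Oid>,
    pub(crate) cache_type_info: HashMap<Oid, PgTypeInfo>,
    pub(crate) cache_elem_type_to_array: HashMap<Oid, Oid>,
    pub(crate) log_settings: LogSettings,
}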

sqlx-postgres/src/connection/describe.rs

+25 −10
@@ -163,7 +163,7 @@ impl PgConnection {
         }
 
         // next we check a local cache for user-defined type names <-> object id
-        if let Some(info) = self.cache_type_info.get(&oid) {
+        if let Some(info) = self.inner.cache_type_info.get(&oid) {
             return Ok(info.clone());
         }
 
@@ -173,8 +173,9 @@ impl PgConnection {
 
         // cache the type name <-> oid relationship in a paired hashmap
         // so we don't come down this road again
-        self.cache_type_info.insert(oid, info.clone());
-        self.cache_type_oid
+        self.inner.cache_type_info.insert(oid, info.clone());
+        self.inner
+            .cache_type_oid
             .insert(info.0.name().to_string().into(), oid);
 
         Ok(info)
@@ -374,7 +375,7 @@ WHERE rngtypid = $1
     }
 
     pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
-        if let Some(oid) = self.cache_type_oid.get(name) {
+        if let Some(oid) = self.inner.cache_type_oid.get(name) {
             return Ok(*oid);
         }
 
@@ -387,15 +388,18 @@ WHERE rngtypid = $1
                 type_name: name.into(),
             })?;
 
-        self.cache_type_oid.insert(name.to_string().into(), oid);
+        self.inner
+            .cache_type_oid
+            .insert(name.to_string().into(), oid);
         Ok(oid)
     }
 
     pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result<Oid, Error> {
         if let Some(oid) = self
+            .inner
             .cache_type_oid
             .get(&array.elem_name)
-            .and_then(|elem_oid| self.cache_elem_type_to_array.get(elem_oid))
+            .and_then(|elem_oid| self.inner.cache_elem_type_to_array.get(elem_oid))
         {
             return Ok(*oid);
         }
@@ -411,10 +415,13 @@ WHERE rngtypid = $1
             })?;
 
         // Avoids copying `elem_name` until necessary
-        self.cache_type_oid
+        self.inner
+            .cache_type_oid
             .entry_ref(&array.elem_name)
             .insert(elem_oid);
-        self.cache_elem_type_to_array.insert(elem_oid, array_oid);
+        self.inner
+            .cache_elem_type_to_array
+            .insert(elem_oid, array_oid);
 
         Ok(array_oid)
     }
@@ -475,8 +482,16 @@ WHERE rngtypid = $1
             })?;
 
         // If the server is CockroachDB or Materialize, skip this step (#1248).
-        if !self.stream.parameter_statuses.contains_key("crdb_version")
-            && !self.stream.parameter_statuses.contains_key("mz_version")
+        if !self
+            .inner
+            .stream
+            .parameter_statuses
+            .contains_key("crdb_version")
+            && !self
+                .inner
+                .stream
+                .parameter_statuses
+                .contains_key("mz_version")
         {
             // patch up our null inference with data from EXPLAIN
             let nullable_patch = self
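The hunks above in describe.rs revolve around the paired cache those inline comments describe: one map from OID to resolved type info and one from type name back to OID, so a repeat lookup in either direction skips the catalog round-trip. A minimal, self-contained sketch of that idea, using placeholder Oid/TypeInfo types rather than sqlx's real ones:

use std::collections::HashMap;

// Placeholder types standing in for sqlx's Oid and PgTypeInfo.
type Oid = u32;

#[derive(Clone, Debug)]
struct TypeInfo {
    name: String,
}

#[derive(Default)]
struct TypeCache {
    // oid -> resolved type info (cf. cache_type_info)
    by_oid: HashMap<Oid, TypeInfo>,
    // type name -> oid (cf. cache_type_oid)
    by_name: HashMap<String, Oid>,
}

impl TypeCache {
    // Cache both directions at once, "so we don't come down this road again".
    fn insert(&mut self, oid: Oid, info: TypeInfo) {
        self.by_name.insert(info.name.clone(), oid);
        self.by_oid.insert(oid, info);
    }

    fn get_by_oid(&self, oid: Oid) -> Option<&TypeInfo> {
        self.by_oid.get(&oid)
    }

    fn get_oid_by_name(&self, name: &str) -> Option<Oid> {
        self.by_name.get(name).copied()
    }
}

fn main() {
    let mut cache = TypeCache::default();
    cache.insert(16, TypeInfo { name: "bool".into() });
    assert_eq!(cache.get_oid_by_name("bool"), Some(16));
    assert!(cache.get_by_oid(16).is_some());
}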

sqlx-postgres/src/connection/establish.rs

+16 −12
@@ -9,6 +9,8 @@ use crate::message::{
 };
 use crate::{PgConnectOptions, PgConnection};
 
+use super::PgConnectionInner;
+
 // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3
 // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11
 
@@ -134,18 +136,20 @@ impl PgConnection {
         }
 
         Ok(PgConnection {
-            stream,
-            process_id,
-            secret_key,
-            transaction_status,
-            transaction_depth: 0,
-            pending_ready_for_query_count: 0,
-            next_statement_id: StatementId::NAMED_START,
-            cache_statement: StatementCache::new(options.statement_cache_capacity),
-            cache_type_oid: HashMap::new(),
-            cache_type_info: HashMap::new(),
-            cache_elem_type_to_array: HashMap::new(),
-            log_settings: options.log_settings.clone(),
+            inner: Box::new(PgConnectionInner {
+                stream,
+                process_id,
+                secret_key,
+                transaction_status,
+                transaction_depth: 0,
+                pending_ready_for_query_count: 0,
+                next_statement_id: StatementId::NAMED_START,
+                cache_statement: StatementCache::new(options.statement_cache_capacity),
+                cache_type_oid: HashMap::new(),
+                cache_type_info: HashMap::new(),
+                cache_elem_type_to_array: HashMap::new(),
+                log_settings: options.log_settings.clone(),
+            }),
         })
     }
 }
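The commit message does not spell out the motivation, but the usual reason for this pattern is size: with every field moved behind one Box, the connection value itself becomes pointer-sized, so moving it around (or embedding it in futures and enums) copies a machine word instead of the whole block of stream buffers and caches. A self-contained illustration of the effect, unrelated to sqlx's actual types:

// A stand-in for a connection with its state inline vs. boxed.
struct InlineConn {
    buf: [u8; 1024], // stands in for stream buffers and caches
    counter: usize,
}

struct BoxedConn {
    inner: Box<InlineConn>,
}

fn main() {
    // The boxed variant is just one (non-null) pointer wide.
    assert_eq!(
        std::mem::size_of::<BoxedConn>(),
        std::mem::size_of::<usize>()
    );
    println!("inline: {} bytes", std::mem::size_of::<InlineConn>());
    println!("boxed:  {} bytes", std::mem::size_of::<BoxedConn>());
}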

sqlx-postgres/src/connection/executor.rs

+29 −24
@@ -26,8 +26,8 @@ async fn prepare(
     parameters: &[PgTypeInfo],
     metadata: Option<Arc<PgStatementMetadata>>,
 ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
-    let id = conn.next_statement_id;
-    conn.next_statement_id = id.next();
+    let id = conn.inner.next_statement_id;
+    conn.inner.next_statement_id = id.next();
 
     // build a list of type OIDs to send to the database in the PARSE command
     // we have not yet started the query sequence, so we are *safe* to cleanly make
@@ -43,23 +43,25 @@ async fn prepare(
     conn.wait_until_ready().await?;
 
     // next we send the PARSE command to the server
-    conn.stream.write_msg(Parse {
+    conn.inner.stream.write_msg(Parse {
         param_types: &param_types,
         query: sql,
         statement: id,
     })?;
 
     if metadata.is_none() {
         // get the statement columns and parameters
-        conn.stream.write_msg(message::Describe::Statement(id))?;
+        conn.inner
+            .stream
+            .write_msg(message::Describe::Statement(id))?;
     }
 
     // we ask for the server to immediately send us the result of the PARSE command
     conn.write_sync();
-    conn.stream.flush().await?;
+    conn.inner.stream.flush().await?;
 
     // indicates that the SQL query string is now successfully parsed and has semantic validity
-    conn.stream.recv_expect::<ParseComplete>().await?;
+    conn.inner.stream.recv_expect::<ParseComplete>().await?;
 
     let metadata = if let Some(metadata) = metadata {
         // each SYNC produces one READY FOR QUERY
@@ -94,11 +96,11 @@ async fn prepare(
 }
 
 async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
-    conn.stream.recv_expect().await
+    conn.inner.stream.recv_expect().await
 }
 
 async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
-    let rows: Option<RowDescription> = match conn.stream.recv().await? {
+    let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
         // describes the rows that will be returned when the statement is eventually executed
         message if message.format == BackendMessageFormat::RowDescription => {
             Some(message.decode()?)
@@ -123,7 +125,7 @@ impl PgConnection {
     pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
         // we need to wait for the [CloseComplete] to be returned from the server
         while count > 0 {
-            match self.stream.recv().await? {
+            match self.inner.stream.recv().await? {
                 message if message.format == BackendMessageFormat::PortalSuspended => {
                     // there was an open portal
                     // this can happen if the last time a statement was used it was not fully executed
@@ -148,12 +150,13 @@ impl PgConnection {
 
     #[inline(always)]
     pub(crate) fn write_sync(&mut self) {
-        self.stream
+        self.inner
+            .stream
             .write_msg(message::Sync)
             .expect("BUG: Sync should not be too big for protocol");
 
         // all SYNC messages will return a ReadyForQuery
-        self.pending_ready_for_query_count += 1;
+        self.inner.pending_ready_for_query_count += 1;
     }
 
     async fn get_or_prepare<'a>(
@@ -166,18 +169,18 @@ impl PgConnection {
         // a statement object
         metadata: Option<Arc<PgStatementMetadata>>,
     ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
-        if let Some(statement) = self.cache_statement.get_mut(sql) {
+        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
             return Ok((*statement).clone());
         }
 
         let statement = prepare(self, sql, parameters, metadata).await?;
 
-        if store_to_cache && self.cache_statement.is_enabled() {
-            if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
-                self.stream.write_msg(Close::Statement(id))?;
+        if store_to_cache && self.inner.cache_statement.is_enabled() {
+            if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
+                self.inner.stream.write_msg(Close::Statement(id))?;
                 self.write_sync();
 
-                self.stream.flush().await?;
+                self.inner.stream.flush().await?;
 
                 self.wait_for_close_complete(1).await?;
                 self.recv_ready_for_query().await?;
195198
persistent: bool,
196199
metadata_opt: Option<Arc<PgStatementMetadata>>,
197200
) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
198-
let mut logger = QueryLogger::new(query, self.log_settings.clone());
201+
let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
199202

200203
// before we continue, wait until we are "ready" to accept more queries
201204
self.wait_until_ready().await?;
@@ -231,7 +234,7 @@ impl PgConnection {
         self.wait_until_ready().await?;
 
         // bind to attach the arguments to the statement and create a portal
-        self.stream.write_msg(Bind {
+        self.inner.stream.write_msg(Bind {
             portal: PortalId::UNNAMED,
             statement,
             formats: &[PgValueFormat::Binary],
@@ -242,7 +245,7 @@ impl PgConnection {
 
         // executes the portal up to the passed limit
         // the protocol-level limit acts nearly identically to the `LIMIT` in SQL
-        self.stream.write_msg(message::Execute {
+        self.inner.stream.write_msg(message::Execute {
             portal: PortalId::UNNAMED,
             limit: limit.into(),
         })?;
@@ -255,7 +258,9 @@ impl PgConnection {
 
         // we ask the database server to close the unnamed portal and free the associated resources
         // earlier - after the execution of the current query.
-        self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?;
+        self.inner
+            .stream
+            .write_msg(Close::Portal(PortalId::UNNAMED))?;
 
         // finally, [Sync] asks postgres to process the messages that we sent and respond with
         // a [ReadyForQuery] message when it's completely done. Theoretically, we could send
@@ -268,8 +273,8 @@ impl PgConnection {
             PgValueFormat::Binary
         } else {
             // Query will trigger a ReadyForQuery
-            self.stream.write_msg(Query(query))?;
-            self.pending_ready_for_query_count += 1;
+            self.inner.stream.write_msg(Query(query))?;
+            self.inner.pending_ready_for_query_count += 1;
 
             // metadata starts out as "nothing"
             metadata = Arc::new(PgStatementMetadata::default());
@@ -278,11 +283,11 @@ impl PgConnection {
             PgValueFormat::Text
         };
 
-        self.stream.flush().await?;
+        self.inner.stream.flush().await?;
 
         Ok(try_stream! {
             loop {
-                let message = self.stream.recv().await?;
+                let message = self.inner.stream.recv().await?;
 
                 match message.format {
                     BackendMessageFormat::BindComplete
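Several of the executor.rs hunks touch pending_ready_for_query_count, and the inline comments state the invariant behind it: every Sync written to the wire is eventually answered by exactly one ReadyForQuery, so the counter tracks how many of those responses are still outstanding before the connection is "ready" for the next query. A toy, synchronous sketch of that bookkeeping with hypothetical types, not sqlx's API:

struct Conn {
    pending_ready_for_query_count: usize,
}

impl Conn {
    fn write_sync(&mut self) {
        // ...write the Sync message to the wire...
        // every Sync is answered by one ReadyForQuery, so note the debt
        self.pending_ready_for_query_count += 1;
    }

    fn wait_until_ready(&mut self) {
        // drain one ReadyForQuery per outstanding Sync before issuing
        // the next query
        while self.pending_ready_for_query_count > 0 {
            // ...read messages until a ReadyForQuery arrives...
            self.pending_ready_for_query_count -= 1;
        }
    }
}

fn main() {
    let mut conn = Conn { pending_ready_for_query_count: 0 };
    conn.write_sync();
    conn.write_sync();
    conn.wait_until_ready();
    assert_eq!(conn.pending_ready_for_query_count, 0);
}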
