Skip to content

Commit 29073cb

Browse files
authored
fix: ensure PG connection is established before using it (#1989)
Fixes #1940.
1 parent 5e08cd0 commit 29073cb

File tree

1 file changed

+40
-21
lines changed

1 file changed

+40
-21
lines changed

sqlx-core/src/postgres/listener.rs

+40-21
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use either::Either;
66
use futures_channel::mpsc;
77
use futures_core::future::BoxFuture;
88
use futures_core::stream::{BoxStream, Stream};
9+
use futures_util::{FutureExt, StreamExt, TryStreamExt};
910

1011
use crate::describe::Describe;
1112
use crate::error::Error;
@@ -96,6 +97,7 @@ impl PgListener {
9697
/// The channel name is quoted here to ensure case sensitivity.
9798
pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
9899
self.connection()
100+
.await?
99101
.execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
100102
.await?;
101103

@@ -112,21 +114,22 @@ impl PgListener {
112114
let beg = self.channels.len();
113115
self.channels.extend(channels.into_iter().map(|s| s.into()));
114116

115-
self.connection
116-
.as_mut()
117-
.unwrap()
118-
.execute(&*build_listen_all_query(&self.channels[beg..]))
119-
.await?;
117+
let query = build_listen_all_query(&self.channels[beg..]);
118+
self.connection().await?.execute(&*query).await?;
120119

121120
Ok(())
122121
}
123122

124123
/// Stops listening for notifications on a channel.
125124
/// The channel name is quoted here to ensure case sensitivity.
126125
pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
127-
self.connection()
128-
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
129-
.await?;
126+
// use RAW connection and do NOT re-connect automatically, since this is not required for
127+
// UNLISTEN (we've disconnected anyways)
128+
if let Some(connection) = self.connection.as_mut() {
129+
connection
130+
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
131+
.await?;
132+
}
130133

131134
if let Some(pos) = self.channels.iter().position(|s| s == channel) {
132135
self.channels.remove(pos);
@@ -137,7 +140,11 @@ impl PgListener {
137140

138141
/// Stops listening for notifications on all channels.
139142
pub async fn unlisten_all(&mut self) -> Result<(), Error> {
140-
self.connection().execute("UNLISTEN *").await?;
143+
// use RAW connection and do NOT re-connect automatically, since this is not required for
144+
// UNLISTEN (we've disconnected anyways)
145+
if let Some(connection) = self.connection.as_mut() {
146+
connection.execute("UNLISTEN *").await?;
147+
}
141148

142149
self.channels.clear();
143150

@@ -161,8 +168,11 @@ impl PgListener {
161168
}
162169

163170
#[inline]
164-
fn connection(&mut self) -> &mut PgConnection {
165-
self.connection.as_mut().unwrap()
171+
async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
172+
// Ensure we have an active connection to work with.
173+
self.connect_if_needed().await?;
174+
175+
Ok(self.connection.as_mut().unwrap())
166176
}
167177

168178
/// Receives the next notification available from any of the subscribed channels.
@@ -237,10 +247,7 @@ impl PgListener {
237247
let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
238248

239249
loop {
240-
// Ensure we have an active connection to work with.
241-
self.connect_if_needed().await?;
242-
243-
let next_message = self.connection().stream.recv_unchecked();
250+
let next_message = self.connection().await?.stream.recv_unchecked();
244251

245252
let res = if let Some(ref mut close_event) = close_event {
246253
// cancels the wait and returns `Err(PoolClosed)` if the pool is closed
@@ -256,7 +263,7 @@ impl PgListener {
256263
// The connection is dead, ensure that it is dropped,
257264
// update self state, and loop to try again.
258265
Err(Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
259-
self.buffer_tx = self.connection().stream.notifications.take();
266+
self.buffer_tx = self.connection().await?.stream.notifications.take();
260267
self.connection = None;
261268

262269
// lost connection
@@ -277,7 +284,7 @@ impl PgListener {
277284

278285
// Mark the connection as ready for another query
279286
MessageFormat::ReadyForQuery => {
280-
self.connection().pending_ready_for_query_count -= 1;
287+
self.connection().await?.pending_ready_for_query_count -= 1;
281288
}
282289

283290
// Ignore unexpected messages
@@ -336,7 +343,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
336343
'c: 'e,
337344
E: Execute<'q, Self::Database>,
338345
{
339-
self.connection().fetch_many(query)
346+
futures_util::stream::once(async move {
347+
// need some basic type annotation to help the compiler a bit
348+
let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
349+
res
350+
})
351+
.try_flatten()
352+
.boxed()
340353
}
341354

342355
fn fetch_optional<'e, 'q: 'e, E: 'q>(
@@ -347,7 +360,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
347360
'c: 'e,
348361
E: Execute<'q, Self::Database>,
349362
{
350-
self.connection().fetch_optional(query)
363+
async move { self.connection().await?.fetch_optional(query).await }.boxed()
351364
}
352365

353366
fn prepare_with<'e, 'q: 'e>(
@@ -358,7 +371,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
358371
where
359372
'c: 'e,
360373
{
361-
self.connection().prepare_with(query, parameters)
374+
async move {
375+
self.connection()
376+
.await?
377+
.prepare_with(query, parameters)
378+
.await
379+
}
380+
.boxed()
362381
}
363382

364383
#[doc(hidden)]
@@ -369,7 +388,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
369388
where
370389
'c: 'e,
371390
{
372-
self.connection().describe(query)
391+
async move { self.connection().await?.describe(query).await }.boxed()
373392
}
374393
}
375394

0 commit comments

Comments
 (0)