Skip to content

Commit 74229f9

Browse files
committed
fix: ensure PG connection is established before using it
Fixes launchbadge#1940.
1 parent 2182925 commit 74229f9

File tree

1 file changed

+33
-18
lines changed

1 file changed

+33
-18
lines changed

sqlx-core/src/postgres/listener.rs

+33-18
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ use either::Either;
99
use futures_channel::mpsc;
1010
use futures_core::future::BoxFuture;
1111
use futures_core::stream::{BoxStream, Stream};
12+
use futures_util::{
13+
stream::{StreamExt, TryStreamExt},
14+
FutureExt,
15+
};
1216
use std::fmt::{self, Debug};
1317
use std::io;
1418
use std::str::from_utf8;
@@ -65,6 +69,7 @@ impl PgListener {
6569
/// The channel name is quoted here to ensure case sensitivity.
6670
pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
6771
self.connection()
72+
.await?
6873
.execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
6974
.await?;
7075

@@ -81,11 +86,8 @@ impl PgListener {
8186
let beg = self.channels.len();
8287
self.channels.extend(channels.into_iter().map(|s| s.into()));
8388

84-
self.connection
85-
.as_mut()
86-
.unwrap()
87-
.execute(&*build_listen_all_query(&self.channels[beg..]))
88-
.await?;
89+
let query = build_listen_all_query(&self.channels[beg..]);
90+
self.connection().await?.execute(&*query).await?;
8991

9092
Ok(())
9193
}
@@ -94,6 +96,7 @@ impl PgListener {
9496
/// The channel name is quoted here to ensure case sensitivity.
9597
pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
9698
self.connection()
99+
.await?
97100
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
98101
.await?;
99102

@@ -106,7 +109,7 @@ impl PgListener {
106109

107110
/// Stops listening for notifications on all channels.
108111
pub async fn unlisten_all(&mut self) -> Result<(), Error> {
109-
self.connection().execute("UNLISTEN *").await?;
112+
self.connection().await?.execute("UNLISTEN *").await?;
110113

111114
self.channels.clear();
112115

@@ -130,8 +133,11 @@ impl PgListener {
130133
}
131134

132135
#[inline]
133-
fn connection(&mut self) -> &mut PgConnection {
134-
self.connection.as_mut().unwrap()
136+
async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
137+
// Ensure we have an active connection to work with.
138+
self.connect_if_needed().await?;
139+
140+
Ok(self.connection.as_mut().unwrap())
135141
}
136142

137143
/// Receives the next notification available from any of the subscribed channels.
@@ -203,16 +209,13 @@ impl PgListener {
203209
}
204210

205211
loop {
206-
// Ensure we have an active connection to work with.
207-
self.connect_if_needed().await?;
208-
209-
let message = match self.connection().stream.recv_unchecked().await {
212+
let message = match self.connection().await?.stream.recv_unchecked().await {
210213
Ok(message) => message,
211214

212215
// The connection is dead, ensure that it is dropped,
213216
// update self state, and loop to try again.
214217
Err(Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
215-
self.buffer_tx = self.connection().stream.notifications.take();
218+
self.buffer_tx = self.connection().await?.stream.notifications.take();
216219
self.connection = None;
217220

218221
// lost connection
@@ -233,7 +236,7 @@ impl PgListener {
233236

234237
// Mark the connection as ready for another query
235238
MessageFormat::ReadyForQuery => {
236-
self.connection().pending_ready_for_query_count -= 1;
239+
self.connection().await?.pending_ready_for_query_count -= 1;
237240
}
238241

239242
// Ignore unexpected messages
@@ -292,7 +295,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
292295
'c: 'e,
293296
E: Execute<'q, Self::Database>,
294297
{
295-
self.connection().fetch_many(query)
298+
futures_util::stream::once(async move {
299+
// need some basic type annotation to help the compiler a bit
300+
let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
301+
res
302+
})
303+
.try_flatten()
304+
.boxed()
296305
}
297306

298307
fn fetch_optional<'e, 'q: 'e, E: 'q>(
@@ -303,7 +312,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
303312
'c: 'e,
304313
E: Execute<'q, Self::Database>,
305314
{
306-
self.connection().fetch_optional(query)
315+
async move { self.connection().await?.fetch_optional(query).await }.boxed()
307316
}
308317

309318
fn prepare_with<'e, 'q: 'e>(
@@ -314,7 +323,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
314323
where
315324
'c: 'e,
316325
{
317-
self.connection().prepare_with(query, parameters)
326+
async move {
327+
self.connection()
328+
.await?
329+
.prepare_with(query, parameters)
330+
.await
331+
}
332+
.boxed()
318333
}
319334

320335
#[doc(hidden)]
@@ -325,7 +340,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
325340
where
326341
'c: 'e,
327342
{
328-
self.connection().describe(query)
343+
async move { self.connection().await?.describe(query).await }.boxed()
329344
}
330345
}
331346

0 commit comments

Comments
 (0)