@@ -6,6 +6,7 @@ use either::Either;
6
6
use futures_channel:: mpsc;
7
7
use futures_core:: future:: BoxFuture ;
8
8
use futures_core:: stream:: { BoxStream , Stream } ;
9
+ use futures_util:: { FutureExt , StreamExt , TryStreamExt } ;
9
10
10
11
use crate :: describe:: Describe ;
11
12
use crate :: error:: Error ;
@@ -96,6 +97,7 @@ impl PgListener {
96
97
/// The channel name is quoted here to ensure case sensitivity.
97
98
pub async fn listen ( & mut self , channel : & str ) -> Result < ( ) , Error > {
98
99
self . connection ( )
100
+ . await ?
99
101
. execute ( & * format ! ( r#"LISTEN "{}""# , ident( channel) ) )
100
102
. await ?;
101
103
@@ -112,21 +114,22 @@ impl PgListener {
112
114
let beg = self . channels . len ( ) ;
113
115
self . channels . extend ( channels. into_iter ( ) . map ( |s| s. into ( ) ) ) ;
114
116
115
- self . connection
116
- . as_mut ( )
117
- . unwrap ( )
118
- . execute ( & * build_listen_all_query ( & self . channels [ beg..] ) )
119
- . await ?;
117
+ let query = build_listen_all_query ( & self . channels [ beg..] ) ;
118
+ self . connection ( ) . await ?. execute ( & * query) . await ?;
120
119
121
120
Ok ( ( ) )
122
121
}
123
122
124
123
/// Stops listening for notifications on a channel.
125
124
/// The channel name is quoted here to ensure case sensitivity.
126
125
pub async fn unlisten ( & mut self , channel : & str ) -> Result < ( ) , Error > {
127
- self . connection ( )
128
- . execute ( & * format ! ( r#"UNLISTEN "{}""# , ident( channel) ) )
129
- . await ?;
126
+ // use RAW connection and do NOT re-connect automatically, since this is not required for
127
+ // UNLISTEN (we've disconnected anyways)
128
+ if let Some ( connection) = self . connection . as_mut ( ) {
129
+ connection
130
+ . execute ( & * format ! ( r#"UNLISTEN "{}""# , ident( channel) ) )
131
+ . await ?;
132
+ }
130
133
131
134
if let Some ( pos) = self . channels . iter ( ) . position ( |s| s == channel) {
132
135
self . channels . remove ( pos) ;
@@ -137,7 +140,11 @@ impl PgListener {
137
140
138
141
/// Stops listening for notifications on all channels.
139
142
pub async fn unlisten_all ( & mut self ) -> Result < ( ) , Error > {
140
- self . connection ( ) . execute ( "UNLISTEN *" ) . await ?;
143
+ // use RAW connection and do NOT re-connect automatically, since this is not required for
144
+ // UNLISTEN (we've disconnected anyways)
145
+ if let Some ( connection) = self . connection . as_mut ( ) {
146
+ connection. execute ( "UNLISTEN *" ) . await ?;
147
+ }
141
148
142
149
self . channels . clear ( ) ;
143
150
@@ -161,8 +168,11 @@ impl PgListener {
161
168
}
162
169
163
170
#[ inline]
164
- fn connection ( & mut self ) -> & mut PgConnection {
165
- self . connection . as_mut ( ) . unwrap ( )
171
+ async fn connection ( & mut self ) -> Result < & mut PgConnection , Error > {
172
+ // Ensure we have an active connection to work with.
173
+ self . connect_if_needed ( ) . await ?;
174
+
175
+ Ok ( self . connection . as_mut ( ) . unwrap ( ) )
166
176
}
167
177
168
178
/// Receives the next notification available from any of the subscribed channels.
@@ -237,10 +247,7 @@ impl PgListener {
237
247
let mut close_event = ( !self . ignore_close_event ) . then ( || self . pool . close_event ( ) ) ;
238
248
239
249
loop {
240
- // Ensure we have an active connection to work with.
241
- self . connect_if_needed ( ) . await ?;
242
-
243
- let next_message = self . connection ( ) . stream . recv_unchecked ( ) ;
250
+ let next_message = self . connection ( ) . await ?. stream . recv_unchecked ( ) ;
244
251
245
252
let res = if let Some ( ref mut close_event) = close_event {
246
253
// cancels the wait and returns `Err(PoolClosed)` if the pool is closed
@@ -256,7 +263,7 @@ impl PgListener {
256
263
// The connection is dead, ensure that it is dropped,
257
264
// update self state, and loop to try again.
258
265
Err ( Error :: Io ( err) ) if err. kind ( ) == io:: ErrorKind :: ConnectionAborted => {
259
- self . buffer_tx = self . connection ( ) . stream . notifications . take ( ) ;
266
+ self . buffer_tx = self . connection ( ) . await ? . stream . notifications . take ( ) ;
260
267
self . connection = None ;
261
268
262
269
// lost connection
@@ -277,7 +284,7 @@ impl PgListener {
277
284
278
285
// Mark the connection as ready for another query
279
286
MessageFormat :: ReadyForQuery => {
280
- self . connection ( ) . pending_ready_for_query_count -= 1 ;
287
+ self . connection ( ) . await ? . pending_ready_for_query_count -= 1 ;
281
288
}
282
289
283
290
// Ignore unexpected messages
@@ -336,7 +343,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
336
343
' c : ' e ,
337
344
E : Execute < ' q , Self :: Database > ,
338
345
{
339
- self . connection ( ) . fetch_many ( query)
346
+ futures_util:: stream:: once ( async move {
347
+ // need some basic type annotation to help the compiler a bit
348
+ let res: Result < _ , Error > = Ok ( self . connection ( ) . await ?. fetch_many ( query) ) ;
349
+ res
350
+ } )
351
+ . try_flatten ( )
352
+ . boxed ( )
340
353
}
341
354
342
355
fn fetch_optional < ' e , ' q : ' e , E : ' q > (
@@ -347,7 +360,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
347
360
' c : ' e ,
348
361
E : Execute < ' q , Self :: Database > ,
349
362
{
350
- self . connection ( ) . fetch_optional ( query)
363
+ async move { self . connection ( ) . await ? . fetch_optional ( query) . await } . boxed ( )
351
364
}
352
365
353
366
fn prepare_with < ' e , ' q : ' e > (
@@ -358,7 +371,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
358
371
where
359
372
' c : ' e ,
360
373
{
361
- self . connection ( ) . prepare_with ( query, parameters)
374
+ async move {
375
+ self . connection ( )
376
+ . await ?
377
+ . prepare_with ( query, parameters)
378
+ . await
379
+ }
380
+ . boxed ( )
362
381
}
363
382
364
383
#[ doc( hidden) ]
@@ -369,7 +388,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
369
388
where
370
389
' c : ' e ,
371
390
{
372
- self . connection ( ) . describe ( query)
391
+ async move { self . connection ( ) . await ? . describe ( query) . await } . boxed ( )
373
392
}
374
393
}
375
394
0 commit comments