@@ -9,6 +9,10 @@ use either::Either;
9
9
use futures_channel:: mpsc;
10
10
use futures_core:: future:: BoxFuture ;
11
11
use futures_core:: stream:: { BoxStream , Stream } ;
12
+ use futures_util:: {
13
+ stream:: { StreamExt , TryStreamExt } ,
14
+ FutureExt ,
15
+ } ;
12
16
use std:: fmt:: { self , Debug } ;
13
17
use std:: io;
14
18
use std:: str:: from_utf8;
@@ -65,6 +69,7 @@ impl PgListener {
65
69
/// The channel name is quoted here to ensure case sensitivity.
66
70
pub async fn listen ( & mut self , channel : & str ) -> Result < ( ) , Error > {
67
71
self . connection ( )
72
+ . await ?
68
73
. execute ( & * format ! ( r#"LISTEN "{}""# , ident( channel) ) )
69
74
. await ?;
70
75
@@ -81,11 +86,8 @@ impl PgListener {
81
86
let beg = self . channels . len ( ) ;
82
87
self . channels . extend ( channels. into_iter ( ) . map ( |s| s. into ( ) ) ) ;
83
88
84
- self . connection
85
- . as_mut ( )
86
- . unwrap ( )
87
- . execute ( & * build_listen_all_query ( & self . channels [ beg..] ) )
88
- . await ?;
89
+ let query = build_listen_all_query ( & self . channels [ beg..] ) ;
90
+ self . connection ( ) . await ?. execute ( & * query) . await ?;
89
91
90
92
Ok ( ( ) )
91
93
}
@@ -94,6 +96,7 @@ impl PgListener {
94
96
/// The channel name is quoted here to ensure case sensitivity.
95
97
pub async fn unlisten ( & mut self , channel : & str ) -> Result < ( ) , Error > {
96
98
self . connection ( )
99
+ . await ?
97
100
. execute ( & * format ! ( r#"UNLISTEN "{}""# , ident( channel) ) )
98
101
. await ?;
99
102
@@ -106,7 +109,7 @@ impl PgListener {
106
109
107
110
/// Stops listening for notifications on all channels.
108
111
pub async fn unlisten_all ( & mut self ) -> Result < ( ) , Error > {
109
- self . connection ( ) . execute ( "UNLISTEN *" ) . await ?;
112
+ self . connection ( ) . await ? . execute ( "UNLISTEN *" ) . await ?;
110
113
111
114
self . channels . clear ( ) ;
112
115
@@ -130,8 +133,11 @@ impl PgListener {
130
133
}
131
134
132
135
#[ inline]
133
- fn connection ( & mut self ) -> & mut PgConnection {
134
- self . connection . as_mut ( ) . unwrap ( )
136
+ async fn connection ( & mut self ) -> Result < & mut PgConnection , Error > {
137
+ // Ensure we have an active connection to work with.
138
+ self . connect_if_needed ( ) . await ?;
139
+
140
+ Ok ( self . connection . as_mut ( ) . unwrap ( ) )
135
141
}
136
142
137
143
/// Receives the next notification available from any of the subscribed channels.
@@ -203,16 +209,13 @@ impl PgListener {
203
209
}
204
210
205
211
loop {
206
- // Ensure we have an active connection to work with.
207
- self . connect_if_needed ( ) . await ?;
208
-
209
- let message = match self . connection ( ) . stream . recv_unchecked ( ) . await {
212
+ let message = match self . connection ( ) . await ?. stream . recv_unchecked ( ) . await {
210
213
Ok ( message) => message,
211
214
212
215
// The connection is dead, ensure that it is dropped,
213
216
// update self state, and loop to try again.
214
217
Err ( Error :: Io ( err) ) if err. kind ( ) == io:: ErrorKind :: ConnectionAborted => {
215
- self . buffer_tx = self . connection ( ) . stream . notifications . take ( ) ;
218
+ self . buffer_tx = self . connection ( ) . await ? . stream . notifications . take ( ) ;
216
219
self . connection = None ;
217
220
218
221
// lost connection
@@ -233,7 +236,7 @@ impl PgListener {
233
236
234
237
// Mark the connection as ready for another query
235
238
MessageFormat :: ReadyForQuery => {
236
- self . connection ( ) . pending_ready_for_query_count -= 1 ;
239
+ self . connection ( ) . await ? . pending_ready_for_query_count -= 1 ;
237
240
}
238
241
239
242
// Ignore unexpected messages
@@ -292,7 +295,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
292
295
' c : ' e ,
293
296
E : Execute < ' q , Self :: Database > ,
294
297
{
295
- self . connection ( ) . fetch_many ( query)
298
+ futures_util:: stream:: once ( async move {
299
+ // need some basic type annotation to help the compiler a bit
300
+ let res: Result < _ , Error > = Ok ( self . connection ( ) . await ?. fetch_many ( query) ) ;
301
+ res
302
+ } )
303
+ . try_flatten ( )
304
+ . boxed ( )
296
305
}
297
306
298
307
fn fetch_optional < ' e , ' q : ' e , E : ' q > (
@@ -303,7 +312,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
303
312
' c : ' e ,
304
313
E : Execute < ' q , Self :: Database > ,
305
314
{
306
- self . connection ( ) . fetch_optional ( query)
315
+ async move { self . connection ( ) . await ? . fetch_optional ( query) . await } . boxed ( )
307
316
}
308
317
309
318
fn prepare_with < ' e , ' q : ' e > (
@@ -314,7 +323,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
314
323
where
315
324
' c : ' e ,
316
325
{
317
- self . connection ( ) . prepare_with ( query, parameters)
326
+ async move {
327
+ self . connection ( )
328
+ . await ?
329
+ . prepare_with ( query, parameters)
330
+ . await
331
+ }
332
+ . boxed ( )
318
333
}
319
334
320
335
#[ doc( hidden) ]
@@ -325,7 +340,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
325
340
where
326
341
' c : ' e ,
327
342
{
328
- self . connection ( ) . describe ( query)
343
+ async move { self . connection ( ) . await ? . describe ( query) . await } . boxed ( )
329
344
}
330
345
}
331
346
0 commit comments