1
- use sqlx:: { pool:: PoolConnection , postgres:: PgPoolOptions , Error , PgPool , Postgres } ;
1
+ use anyhow:: Context ;
2
+ use sqlx:: {
3
+ pool:: PoolConnection ,
4
+ postgres:: { PgConnectOptions , PgPoolOptions } ,
5
+ ConnectOptions , Error , PgPool , Postgres ,
6
+ } ;
7
+ use tide_disco:: Url ;
2
8
3
9
use crate :: DatabaseOptions ;
4
10
@@ -8,13 +14,41 @@ impl PostgresClient {
8
14
pub async fn connect ( opts : DatabaseOptions ) -> anyhow:: Result < Self > {
9
15
let DatabaseOptions {
10
16
url,
17
+ host,
18
+ port,
19
+ db_name,
20
+ username,
21
+ password,
11
22
max_connections,
12
23
acquire_timeout,
13
24
migrations,
14
25
} = opts;
15
26
16
27
let mut options = PgPoolOptions :: new ( ) ;
17
28
29
+ let postgres_url: Url = match url {
30
+ Some ( url) => url. parse ( ) ?,
31
+ None => {
32
+ let host = host. context ( "host not provided" ) ?;
33
+ let port = port. context ( "port not provided" ) ?;
34
+ let mut connect_opts = PgConnectOptions :: new ( ) . host ( & host) . port ( port) ;
35
+
36
+ if let Some ( username) = username {
37
+ connect_opts = connect_opts. username ( & username) ;
38
+ }
39
+
40
+ if let Some ( password) = password {
41
+ connect_opts = connect_opts. password ( & password) ;
42
+ }
43
+
44
+ if let Some ( db_name) = db_name {
45
+ connect_opts = connect_opts. database ( & db_name) ;
46
+ }
47
+
48
+ connect_opts. to_url_lossy ( )
49
+ }
50
+ } ;
51
+
18
52
if let Some ( max_connections) = max_connections {
19
53
options = options. max_connections ( max_connections) ;
20
54
}
@@ -23,7 +57,7 @@ impl PostgresClient {
23
57
options = options. acquire_timeout ( acquire_timeout) ;
24
58
}
25
59
26
- let connection = options. connect ( & url ) . await ?;
60
+ let connection = options. connect ( postgres_url . as_str ( ) ) . await ?;
27
61
28
62
if migrations {
29
63
sqlx:: migrate!( "./migration" ) . run ( & connection) . await ?;
@@ -32,7 +66,56 @@ impl PostgresClient {
32
66
Ok ( Self ( connection) )
33
67
}
34
68
69
+ pub fn pool ( & self ) -> & PgPool {
70
+ & self . 0
71
+ }
72
+
35
73
pub async fn acquire ( & self ) -> Result < PoolConnection < Postgres > , Error > {
36
74
self . 0 . acquire ( ) . await
37
75
}
38
76
}
77
+
78
+ #[ cfg( all( test, not( target_os = "windows" ) ) ) ]
79
+ mod test {
80
+ use hotshot_query_service:: data_source:: sql:: testing:: TmpDb ;
81
+
82
+ use super :: PostgresClient ;
83
+ use crate :: DatabaseOptions ;
84
+
85
+ #[ async_std:: test]
86
+ async fn test_database_connection ( ) {
87
+ let db = TmpDb :: init ( ) . await ;
88
+ let host = db. host ( ) ;
89
+ let port = db. port ( ) ;
90
+
91
+ let opts = DatabaseOptions {
92
+ url : None ,
93
+ host : Some ( host) ,
94
+ port : Some ( port) ,
95
+ db_name : None ,
96
+ username : Some ( "postgres" . to_string ( ) ) ,
97
+ password : Some ( "password" . to_string ( ) ) ,
98
+ max_connections : Some ( 100 ) ,
99
+ acquire_timeout : None ,
100
+ migrations : true ,
101
+ } ;
102
+
103
+ let client = PostgresClient :: connect ( opts)
104
+ . await
105
+ . expect ( "failed to connect to database" ) ;
106
+
107
+ let pool = client. pool ( ) ;
108
+
109
+ sqlx:: query ( "INSERT INTO test (str) VALUES ('testing');" )
110
+ . execute ( pool)
111
+ . await
112
+ . unwrap ( ) ;
113
+
114
+ let result: i64 = sqlx:: query_scalar ( "Select id from test where str = 'testing';" )
115
+ . fetch_one ( pool)
116
+ . await
117
+ . unwrap ( ) ;
118
+
119
+ assert_eq ! ( result, 1 ) ;
120
+ }
121
+ }
0 commit comments