diff --git a/controller-client/src/lib.rs b/controller-client/src/lib.rs index ab2a697c7..433aa7386 100644 --- a/controller-client/src/lib.rs +++ b/controller-client/src/lib.rs @@ -29,8 +29,10 @@ use std::result::Result as StdResult; use std::time::Duration; +use snafu::ResultExt; use snafu::Snafu; use tonic::transport::channel::Channel; +use tonic::transport::Error as tonicError; use tonic::{Code, Status}; use async_trait::async_trait; @@ -45,7 +47,7 @@ use log::debug; use pravega_rust_client_shared::*; use pravega_wire_protocol::client_config::ClientConfig; use pravega_wire_protocol::connection_pool::{ConnectionPool, Manager, PooledConnection}; -use pravega_wire_protocol::error::*; +use pravega_wire_protocol::error::ConnectionPoolError; use std::convert::{From, Into}; use std::net::SocketAddr; use uuid::Uuid; @@ -78,6 +80,13 @@ pub enum ControllerError { can_retry: bool, endpoint: String, error_msg: String, + source: tonicError, + }, + #[snafu(display("Could not get connection from connection pool"))] + PoolError { + can_retry: bool, + endpoint: String, + source: ConnectionPoolError, }, } @@ -218,11 +227,10 @@ impl ControllerClientImpl { #[async_trait] impl ControllerClient for ControllerClientImpl { async fn create_scope(&self, scope: &Scope) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; create_scope(scope, connection).await } @@ -231,74 +239,66 @@ impl ControllerClient for ControllerClientImpl { } async fn delete_scope(&self, scope: &Scope) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; delete_scope(scope, connection).await } async fn create_stream(&self, stream_config: &StreamConfiguration) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; create_stream(stream_config, connection).await } async fn update_stream(&self, stream_config: &StreamConfiguration) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; update_stream(stream_config, connection).await } async fn truncate_stream(&self, stream_cut: &StreamCut) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; truncate_stream(stream_cut, connection).await } async fn seal_stream(&self, stream: &ScopedStream) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; seal_stream(stream, connection).await } async fn delete_stream(&self, stream: &ScopedStream) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; delete_stream(stream, connection).await } async fn get_current_segments(&self, stream: &ScopedStream) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; get_current_segments(stream, connection).await } async fn create_transaction(&self, stream: &ScopedStream, lease: Duration) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; create_transaction(stream, lease, connection).await } @@ -308,11 +308,10 @@ impl ControllerClient for ControllerClientImpl { tx_id: TxId, lease: Duration, ) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; ping_transaction(stream, tx_id, lease, connection).await } @@ -323,20 +322,18 @@ impl ControllerClient for ControllerClientImpl { writer_id: WriterId, time: Timestamp, ) -> Result<()> { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; commit_transaction(stream, tx_id, writer_id, time, connection).await } async fn abort_transaction(&self, stream: &ScopedStream, tx_id: TxId) -> Result<()> { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; abort_transaction(stream, tx_id, connection).await } @@ -345,20 +342,18 @@ impl ControllerClient for ControllerClientImpl { stream: &ScopedStream, tx_id: TxId, ) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; check_transaction_status(stream, tx_id, connection).await } async fn get_endpoint_for_segment(&self, segment: &ScopedSegment) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; get_uri_segment(segment, connection).await } @@ -367,11 +362,10 @@ impl ControllerClient for ControllerClientImpl { } async fn get_successors(&self, segment: &ScopedSegment) -> Result { - let connection = self - .pool - .get_connection(self.endpoint) - .await - .expect("get connection"); + let connection = self.pool.get_connection(self.endpoint).await.context(PoolError { + can_retry: true, + endpoint: self.endpoint.to_string(), + })?; get_successors(segment, connection).await } } @@ -411,11 +405,17 @@ impl Manager for ControllerConnectionManager { &self, endpoint: SocketAddr, ) -> std::result::Result { - let channel = create_connection(&format!("{}{}", "http://", &endpoint.to_string())).await; - Ok(ControllerConnection::new(endpoint, channel)) + let result = create_connection(&format!("{}{}", "http://", &endpoint.to_string())).await; + match result { + Ok(channel) => Ok(ControllerConnection::new(endpoint, channel)), + Err(_e) => Err(ConnectionPoolError::EstablishConnection { + endpoint: endpoint.to_string(), + error_msg: String::from("Could not establish connection"), + }), + } } - fn is_valid(&self, _conn: &PooledConnection<'_, Self::Conn>) -> bool { + fn is_valid(&self, _conn: &Self::Conn) -> bool { true } @@ -425,12 +425,16 @@ impl Manager for ControllerConnectionManager { } /// create_connection with the given controller uri. -pub async fn create_connection(uri: &str) -> ControllerServiceClient { +async fn create_connection(uri: &str) -> Result> { // Placeholder to add authentication headers. - let connection: ControllerServiceClient = ControllerServiceClient::connect(uri.to_string()) + let connection = ControllerServiceClient::connect(uri.to_string()) .await - .expect("Failed to create a channel"); - connection + .context(ConnectionError { + can_retry: true, + endpoint: String::from(uri), + error_msg: String::from("Connection Refused"), + })?; + Ok(connection) } // Method used to translate grpc errors to custom error. diff --git a/integration_test/Cargo.toml b/integration_test/Cargo.toml index 5842f212a..82666bbb6 100644 --- a/integration_test/Cargo.toml +++ b/integration_test/Cargo.toml @@ -17,7 +17,8 @@ pravega-client-rust = { path = "../" } pravega-wire-protocol = { path = "../wire_protocol"} pravega-controller-client = { path = "../controller-client"} pravega-rust-client-shared = { path = "../shared"} -tokio = { version = "0.2.8", features = ["full"] } +pravega-rust-client-retry = {path = "../retry"} +tokio = { version = "0.2.13", features = ["full"] } lazy_static = "1.4.0" uuid = {version = "0.8", features = ["v4"]} diff --git a/integration_test/src/disconnection_tests.rs b/integration_test/src/disconnection_tests.rs new file mode 100644 index 000000000..3e9f3860f --- /dev/null +++ b/integration_test/src/disconnection_tests.rs @@ -0,0 +1,248 @@ +use super::check_standalone_status; +use super::wait_for_standalone_with_timeout; +use crate::pravega_service::{PravegaService, PravegaStandaloneService}; +use log::info; +use pravega_client_rust::raw_client::{RawClient, RawClientImpl}; +use pravega_client_rust::setup_logger; +use pravega_controller_client::{ControllerClient, ControllerClientImpl}; +use pravega_rust_client_retry::retry_async::retry_async; +use pravega_rust_client_retry::retry_policy::RetryWithBackoff; +use pravega_rust_client_retry::retry_result::RetryResult; +use pravega_rust_client_shared::*; +use pravega_wire_protocol::client_config::ClientConfigBuilder; +use pravega_wire_protocol::client_connection::{ClientConnection, ClientConnectionImpl}; +use pravega_wire_protocol::commands::{HelloCommand, SealSegmentCommand}; +use pravega_wire_protocol::connection_factory::{ConnectionFactory, ConnectionType}; +use pravega_wire_protocol::connection_pool::{ConnectionPool, SegmentConnectionManager}; +use pravega_wire_protocol::wire_commands::Requests; +use pravega_wire_protocol::wire_commands::{Encode, Replies}; +use std::io::{Read, Write}; +use std::net::{Shutdown, SocketAddr, TcpListener}; +use std::process::Command; +use std::time::Duration; +use std::{thread, time}; +use tokio::runtime::Runtime; + +pub async fn disconnection_test_wrapper() { + test_retry_with_no_connection().await; + let mut pravega = PravegaStandaloneService::start(false); + test_retry_while_start_pravega().await; + assert_eq!(check_standalone_status(), true); + test_retry_with_unexpected_reply().await; + pravega.stop().unwrap(); + wait_for_standalone_with_timeout(false, 10); + test_with_mock_server().await; +} + +async fn test_retry_with_no_connection() { + let retry_policy = RetryWithBackoff::default().max_tries(4); + // give a wrong endpoint + let endpoint = "127.0.0.1:0" + .parse::() + .expect("Unable to parse socket address"); + + let cf = ConnectionFactory::create(ConnectionType::Tokio); + let manager = SegmentConnectionManager::new(cf, 1); + let pool = ConnectionPool::new(manager); + + let raw_client = RawClientImpl::new(&pool, endpoint); + + let result = retry_async(retry_policy, || async { + let request = Requests::Hello(HelloCommand { + low_version: 5, + high_version: 9, + }); + let reply = raw_client.send_request(&request).await; + match reply { + Ok(r) => RetryResult::Success(r), + Err(error) => RetryResult::Retry(error), + } + }) + .await; + if let Err(e) = result { + assert_eq!(e.tries, 5); + } else { + panic!("Test failed.") + } +} + +async fn test_retry_while_start_pravega() { + let retry_policy = RetryWithBackoff::default().max_tries(10); + let controller_uri = "127.0.0.1:9090" + .parse::() + .expect("parse to socketaddr"); + let config = ClientConfigBuilder::default() + .controller_uri(controller_uri) + .build() + .expect("build client config"); + let controller_client = ControllerClientImpl::new(config); + + let scope_name = Scope::new("retryScope".into()); + + let result = retry_async(retry_policy, || async { + let result = controller_client.create_scope(&scope_name).await; + match result { + Ok(created) => RetryResult::Success(created), + Err(error) => RetryResult::Retry(error), + } + }) + .await + .expect("create scope"); + assert!(result, true); + + let stream_name = Stream::new("testStream".into()); + let request = StreamConfiguration { + scoped_stream: ScopedStream { + scope: scope_name.clone(), + stream: stream_name.clone(), + }, + scaling: Scaling { + scale_type: ScaleType::FixedNumSegments, + target_rate: 0, + scale_factor: 0, + min_num_segments: 1, + }, + retention: Retention { + retention_type: RetentionType::None, + retention_param: 0, + }, + }; + let retry_policy = RetryWithBackoff::default().max_tries(10); + let result = retry_async(retry_policy, || async { + let result = controller_client.create_stream(&request).await; + match result { + Ok(created) => RetryResult::Success(created), + Err(error) => RetryResult::Retry(error), + } + }) + .await + .expect("create stream"); + assert!(result, true); +} + +async fn test_retry_with_unexpected_reply() { + let retry_policy = RetryWithBackoff::default().max_tries(4); + let scope_name = Scope::new("retryScope".into()); + let stream_name = Stream::new("retryStream".into()); + let controller_uri = "127.0.0.1:9090" + .parse::() + .expect("parse to socketaddr"); + let config = ClientConfigBuilder::default() + .controller_uri(controller_uri) + .build() + .expect("build client config"); + + let controller_client = ControllerClientImpl::new(config); + + //Get the endpoint. + let segment_name = ScopedSegment { + scope: scope_name.clone(), + stream: stream_name.clone(), + segment: Segment { number: 0 }, + }; + let endpoint = controller_client + .get_endpoint_for_segment(&segment_name) + .await + .expect("get endpoint for segment") + .parse::() + .expect("convert to socketaddr"); + + let cf = ConnectionFactory::create(ConnectionType::Tokio); + let manager = SegmentConnectionManager::new(cf, 1); + let pool = ConnectionPool::new(manager); + let raw_client = RawClientImpl::new(&pool, endpoint); + let result = retry_async(retry_policy, || async { + let request = Requests::SealSegment(SealSegmentCommand { + segment: segment_name.to_string(), + request_id: 0, + delegation_token: String::from(""), + }); + let reply = raw_client.send_request(&request).await; + match reply { + Ok(r) => match r { + Replies::SegmentSealed(_) => RetryResult::Success(r), + Replies::NoSuchSegment(_) => RetryResult::Retry("No Such Segment"), + _ => RetryResult::Fail("Wrong reply type"), + }, + Err(_error) => RetryResult::Retry("Connection Refused"), + } + }) + .await; + if let Err(e) = result { + assert_eq!(e.error, "No Such Segment"); + } else { + panic!("Test failed.") + } +} + +struct Server { + listener: TcpListener, + address: SocketAddr, +} + +impl Server { + pub fn new() -> Server { + let listener = TcpListener::bind("127.0.0.1:0").expect("local server"); + let address = listener.local_addr().expect("get listener address"); + Server { address, listener } + } +} + +async fn test_with_mock_server() { + let server = Server::new(); + let endpoint = server.address; + + thread::spawn(move || { + for stream in server.listener.incoming() { + let mut client = stream.expect("get a new client connection"); + let mut buffer = [0u8; 100]; + let _size = client.read(&mut buffer).unwrap(); + let reply = Replies::Hello(HelloCommand { + high_version: 9, + low_version: 5, + }); + let data = reply.write_fields().expect("serialize"); + client.write(&data).expect("send back the reply"); + // close connection immediately to mock the connection failed. + client.shutdown(Shutdown::Both).expect("shutdown the connection"); + } + drop(server); + }); + + let cf = ConnectionFactory::create(ConnectionType::Mock); + let manager = SegmentConnectionManager::new(cf, 3); + let pool = ConnectionPool::new(manager); + + // test with 3 requests, they should be all succeed. + for _i in 0..3 { + let retry_policy = RetryWithBackoff::default().max_tries(5); + let result = retry_async(retry_policy, || async { + let connection = pool + .get_connection(endpoint) + .await + .expect("get connection from pool"); + let mut client_connection = ClientConnectionImpl { connection }; + let request = Requests::Hello(HelloCommand { + high_version: 9, + low_version: 5, + }); + let reply = client_connection.write(&request).await; + if let Err(error) = reply { + return RetryResult::Retry(error); + } + + let reply = client_connection.read().await; + match reply { + Ok(r) => RetryResult::Success(r), + Err(error) => RetryResult::Retry(error), + } + }) + .await; + + if let Ok(r) = result { + println!("reply is {:?}", r); + } else { + panic!("Test failed.") + } + } +} diff --git a/integration_test/src/lib.rs b/integration_test/src/lib.rs index ec3146ea1..df0a22795 100644 --- a/integration_test/src/lib.rs +++ b/integration_test/src/lib.rs @@ -11,6 +11,8 @@ #![allow(dead_code)] #![allow(unused_imports)] +#[cfg(test)] +mod disconnection_tests; mod event_stream_writer_tests; mod pravega_service; mod wirecommand_tests; @@ -46,6 +48,7 @@ fn check_standalone_status() -> bool { #[cfg(test)] mod test { use super::*; + use wirecommand_tests::*; #[test] fn integration_test() { @@ -55,12 +58,15 @@ mod test { let mut pravega = PravegaStandaloneService::start(false); wait_for_standalone_with_timeout(true, 30); - wirecommand_tests::test_wirecommand(&mut rt); + rt.block_on(wirecommand_tests::wirecommand_test_wrapper()); rt.block_on(event_stream_writer_tests::test_event_stream_writer()); // Shut down Pravega standalone pravega.stop().unwrap(); wait_for_standalone_with_timeout(false, 30); + + // disconnection test will start its own Pravega Standalone. + rt.block_on(disconnection_tests::disconnection_test_wrapper()); } } diff --git a/integration_test/src/wirecommand_tests.rs b/integration_test/src/wirecommand_tests.rs index fa6ebbcdf..98a51ffa2 100644 --- a/integration_test/src/wirecommand_tests.rs +++ b/integration_test/src/wirecommand_tests.rs @@ -11,9 +11,7 @@ use lazy_static::*; use pravega_client_rust::raw_client::RawClient; use pravega_client_rust::raw_client::RawClientImpl; -use pravega_controller_client::{ - create_connection, ControllerClient, ControllerClientImpl, ControllerConnectionManager, -}; +use pravega_controller_client::{ControllerClient, ControllerClientImpl, ControllerConnectionManager}; use pravega_rust_client_shared::*; use pravega_wire_protocol::client_config::{ClientConfig, ClientConfigBuilder, TEST_CONTROLLER_URI}; use pravega_wire_protocol::client_connection::{ClientConnection, ClientConnectionImpl}; @@ -44,76 +42,57 @@ lazy_static! { static ref CONTROLLER_CLIENT: ControllerClientImpl = { ControllerClientImpl::new(CONFIG.clone()) }; } -pub fn test_wirecommand(rt: &mut Runtime) { +pub async fn wirecommand_test_wrapper() { let timeout_second = time::Duration::from_secs(30); - rt.block_on(async { - timeout(timeout_second, test_hello()).await.unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_keep_alive()).await.unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_setup_append()).await.unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_create_segment()).await.unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_update_and_get_segment_attribute()) - .await - .unwrap(); - }); + timeout(timeout_second, test_hello()).await.unwrap(); - rt.block_on(async { - timeout(timeout_second, test_get_stream_segment_info()) - .await - .unwrap(); - }); + timeout(timeout_second, test_keep_alive()).await.unwrap(); - rt.block_on(async { - timeout(timeout_second, test_seal_segment()).await.unwrap(); - }); + timeout(timeout_second, test_setup_append()).await.unwrap(); - rt.block_on(async { - timeout(timeout_second, test_delete_segment()).await.unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_conditional_append_and_read_segment()) - .await - .unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_update_segment_policy()) - .await - .unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_merge_segment()).await.unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_truncate_segment()).await.unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_update_table_entries()) - .await - .unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_read_table_key()).await.unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_read_table()).await.unwrap(); - }); - rt.block_on(async { - timeout(timeout_second, test_read_table_entries()).await.unwrap(); - }); + timeout(timeout_second, test_create_segment()).await.unwrap(); + + timeout(timeout_second, test_update_and_get_segment_attribute()) + .await + .unwrap(); + + timeout(timeout_second, test_get_stream_segment_info()) + .await + .unwrap(); + + timeout(timeout_second, test_seal_segment()).await.unwrap(); + + timeout(timeout_second, test_delete_segment()).await.unwrap(); + + timeout(timeout_second, test_conditional_append_and_read_segment()) + .await + .unwrap(); + + timeout(timeout_second, test_update_segment_policy()) + .await + .unwrap(); + + timeout(timeout_second, test_merge_segment()).await.unwrap(); + + timeout(timeout_second, test_truncate_segment()).await.unwrap(); + + timeout(timeout_second, test_update_table_entries()) + .await + .unwrap(); + + timeout(timeout_second, test_read_table_key()).await.unwrap(); + + timeout(timeout_second, test_read_table()).await.unwrap(); + + timeout(timeout_second, test_read_table_entries()).await.unwrap(); } async fn test_hello() { let scope_name = Scope::new("testScope".into()); let stream_name = Stream::new("testStream".into()); // Create scope and stream + CONTROLLER_CLIENT .create_scope(&scope_name) .await @@ -354,6 +333,7 @@ async fn test_update_and_get_segment_attribute() { stream: stream_name.clone(), segment: Segment { number: 0 }, }; + let endpoint = CONTROLLER_CLIENT .get_endpoint_for_segment(&segment_name) .await diff --git a/wire_protocol/Cargo.toml b/wire_protocol/Cargo.toml index 35c3f5ab3..f585f5f9c 100644 --- a/wire_protocol/Cargo.toml +++ b/wire_protocol/Cargo.toml @@ -26,6 +26,8 @@ parking_lot = "0.10.0" uuid = {version = "0.8", features = ["v4"]} serde = { version = "1.0", features = ["derive"] } snafu = "0.6.2" -tokio = { version = "0.2.8", features = ["full"] } +tokio = { version = "0.2.16", features = ["full"] } +futures = "0.3.4" dashmap = "3.4.4" -log = "0.4.8" \ No newline at end of file +log = "0.4.8" + diff --git a/wire_protocol/src/connection.rs b/wire_protocol/src/connection.rs index ba7f26773..ab86ae1de 100644 --- a/wire_protocol/src/connection.rs +++ b/wire_protocol/src/connection.rs @@ -133,7 +133,11 @@ impl Connection for TokioConnection { } fn is_valid(&self) -> bool { - self.stream.is_some() + let result = self.stream.as_ref().expect("get connection").peer_addr(); + match result { + Err(_e) => false, + Ok(_addr) => true, + } } } diff --git a/wire_protocol/src/connection_pool.rs b/wire_protocol/src/connection_pool.rs index 0018af03f..e5c9fbf25 100644 --- a/wire_protocol/src/connection_pool.rs +++ b/wire_protocol/src/connection_pool.rs @@ -10,11 +10,9 @@ use crate::connection::Connection; use crate::connection_factory::ConnectionFactory; -use crate::error::*; - +use crate::error::ConnectionPoolError; use async_trait::async_trait; use dashmap::DashMap; -use snafu::ResultExt; use std::fmt; use std::net::SocketAddr; use std::ops::{Deref, DerefMut}; @@ -51,7 +49,7 @@ pub trait Manager { /// Check whether this connection is still valid. This method will be used to filter out /// invalid connections when putting connection back to the pool - fn is_valid(&self, conn: &PooledConnection<'_, Self::Conn>) -> bool; + fn is_valid(&self, conn: &Self::Conn) -> bool; /// Get the maximum connections in the pool fn get_max_connections(&self) -> u32; @@ -82,14 +80,19 @@ impl Manager for SegmentConnectionManager { type Conn = Box; async fn establish_connection(&self, endpoint: SocketAddr) -> Result { - self.connection_factory - .establish_connection(endpoint) - .await - .context(EstablishConnection {}) + let result = self.connection_factory.establish_connection(endpoint).await; + + match result { + Ok(conn) => Ok(conn), + Err(_e) => Err(ConnectionPoolError::EstablishConnection { + endpoint: endpoint.to_string(), + error_msg: String::from("Could not establish connection"), + }), + } } - fn is_valid(&self, conn: &PooledConnection<'_, Self::Conn>) -> bool { - conn.inner.as_ref().expect("get inner connection").is_valid() + fn is_valid(&self, conn: &Self::Conn) -> bool { + conn.is_valid() } fn get_max_connections(&self) -> u32 { @@ -150,23 +153,33 @@ where &self, endpoint: SocketAddr, ) -> Result, ConnectionPoolError> { - match self.managed_pool.get_connection(endpoint) { - Ok(internal_conn) => Ok(PooledConnection { - uuid: internal_conn.uuid, - inner: Some(internal_conn.conn), - endpoint, - pool: &self.managed_pool, - valid: true, - }), - Err(_e) => { - let conn = self.manager.establish_connection(endpoint).await?; - Ok(PooledConnection { - uuid: Uuid::new_v4(), - inner: Some(conn), - endpoint, - pool: &self.managed_pool, - valid: true, - }) + // use an infinite loop. + loop { + match self.managed_pool.get_connection(endpoint) { + Ok(internal_conn) => { + let conn = internal_conn.conn; + if self.manager.is_valid(&conn) { + return Ok(PooledConnection { + uuid: internal_conn.uuid, + inner: Some(conn), + endpoint, + pool: &self.managed_pool, + valid: true, + }); + } + + //if it is not valid, will be delete automatically + } + Err(_e) => { + let conn = self.manager.establish_connection(endpoint).await?; + return Ok(PooledConnection { + uuid: Uuid::new_v4(), + inner: Some(conn), + endpoint, + pool: &self.managed_pool, + valid: true, + }); + } } } } @@ -305,7 +318,7 @@ mod tests { Ok(FooConnection {}) } - fn is_valid(&self, _conn: &PooledConnection<'_, Self::Conn>) -> bool { + fn is_valid(&self, _conn: &Self::Conn) -> bool { true } diff --git a/wire_protocol/src/error.rs b/wire_protocol/src/error.rs index 9cf1a0018..a4a7fee1f 100644 --- a/wire_protocol/src/error.rs +++ b/wire_protocol/src/error.rs @@ -142,11 +142,8 @@ pub enum ClientConnectionError { #[derive(Debug, Snafu)] #[snafu(visibility = "pub(crate)")] pub enum ConnectionPoolError { - #[snafu(display("Could not establish connection to endpoint: {}", source))] - EstablishConnection { - source: ConnectionFactoryError, - backtrace: Backtrace, - }, + #[snafu(display("Could not establish connection to endpoint: {}", endpoint))] + EstablishConnection { endpoint: String, error_msg: String }, #[snafu(display("No available connection in the internal pool"))] NoAvailableConnection {},