From fa5239a6ec1c3b526a15595e1edd0a374875f07c Mon Sep 17 00:00:00 2001 From: Wenqi Mou <452787782@qq.com> Date: Tue, 19 May 2020 15:04:29 -0700 Subject: [PATCH] Issue 72: Micro benchmark (#98) Micro benchmark --- Cargo.toml | 8 +- benches/benchmark.rs | 275 +++++++++++++++++++++++ controller-client/src/lib.rs | 2 +- controller-client/src/mock_controller.rs | 44 +++- src/client_factory.rs | 13 +- src/event_stream_writer.rs | 8 +- src/lib.rs | 2 +- wire_protocol/Cargo.toml | 2 +- wire_protocol/src/client_config.rs | 4 + wire_protocol/src/client_connection.rs | 59 ++--- wire_protocol/src/connection.rs | 42 ++-- wire_protocol/src/connection_factory.rs | 94 ++++---- wire_protocol/src/lib.rs | 1 + wire_protocol/src/mock_connection.rs | 204 +++++++++++++++++ 14 files changed, 627 insertions(+), 131 deletions(-) create mode 100644 benches/benchmark.rs create mode 100644 wire_protocol/src/mock_connection.rs diff --git a/Cargo.toml b/Cargo.toml index 6b55cade2..687fba448 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ authors = ["Tom Kaitchuck ", "Wenqi Mou Self { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("local server"); + let address = listener.local_addr().unwrap(); + MockServer { address, listener } + } + + pub async fn run(mut self) { + let (mut stream, _addr) = self.listener.accept().await.expect("get incoming stream"); + loop { + let mut header: Vec = vec![0; LENGTH_FIELD_OFFSET as usize + LENGTH_FIELD_LENGTH as usize]; + stream + .read_exact(&mut header[..]) + .await + .expect("read header from incoming stream"); + let mut rdr = Cursor::new(&header[4..8]); + let payload_length = + byteorder::ReadBytesExt::read_u32::(&mut rdr).expect("exact size"); + let mut payload: Vec = vec![0; payload_length as usize]; + stream + .read_exact(&mut payload[..]) + .await + .expect("read payload from incoming stream"); + let concatenated = [&header[..], &payload[..]].concat(); + let request: Requests = Requests::read_from(&concatenated).expect("decode wirecommand"); + match request { + Requests::Hello(cmd) => { + let reply = Replies::Hello(cmd).write_fields().expect("encode reply"); + stream + .write_all(&reply) + .await + .expect("write reply back to client"); + } + Requests::SetupAppend(cmd) => { + let reply = Replies::AppendSetup(AppendSetupCommand { + request_id: cmd.request_id, + segment: cmd.segment, + writer_id: cmd.writer_id, + last_event_number: -9223372036854775808, // when there is no previous event in this segment + }) + .write_fields() + .expect("encode reply"); + stream + .write_all(&reply) + .await + .expect("write reply back to client"); + } + Requests::AppendBlockEnd(cmd) => { + let reply = Replies::DataAppended(DataAppendedCommand { + writer_id: cmd.writer_id, + event_number: cmd.last_event_number, + previous_event_number: 0, //not used in event stream writer + request_id: cmd.request_id, + current_segment_write_offset: 0, //not used in event stream writer + }) + .write_fields() + .expect("encode reply"); + stream + .write_all(&reply) + .await + .expect("write reply back to client"); + } + _ => { + panic!("unsupported request {:?}", request); + } + } + } + } +} + +// This benchmark test uses a mock server that replies ok to any requests instantly. It involves +// kernel latency. +fn mock_server(c: &mut Criterion) { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + let mock_server = rt.block_on(MockServer::new()); + let config = ClientConfigBuilder::default() + .controller_uri(mock_server.address) + .mock(true) + .build() + .expect("creating config"); + let mut writer = rt.block_on(set_up(config)); + rt.spawn(async { MockServer::run(mock_server).await }); + + info!("start mock server performance testing"); + c.bench_function("mock server", |b| { + b.iter(|| { + rt.block_on(run(&mut writer)); + }); + }); + info!("mock server performance testing finished"); +} + +// This benchmark test uses a mock server that replies ok to any requests instantly. It involves +// kernel latency. It does not wait for reply. +fn mock_server_no_block(c: &mut Criterion) { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + let mock_server = rt.block_on(MockServer::new()); + let config = ClientConfigBuilder::default() + .controller_uri(mock_server.address) + .mock(true) + .build() + .expect("creating config"); + let mut writer = rt.block_on(set_up(config)); + rt.spawn(async { MockServer::run(mock_server).await }); + + info!("start mock server(no block) performance testing"); + c.bench_function("mock server(no block)", |b| { + b.iter(|| { + rt.block_on(run_no_block(&mut writer)); + }); + }); + info!("mock server(no block) performance testing finished"); +} + +// This benchmark test uses a mock connection that replies ok to any requests instantly. It does not +// involve kernel latency. +fn mock_connection(c: &mut Criterion) { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + let config = ClientConfigBuilder::default() + .controller_uri("127.0.0.1:9090".parse::().unwrap()) + .mock(true) + .connection_type(ConnectionType::Mock) + .build() + .expect("creating config"); + let mut writer = rt.block_on(set_up(config)); + + info!("start mock connection performance testing"); + c.bench_function("mock connection", |b| { + b.iter(|| { + rt.block_on(run(&mut writer)); + }); + }); + info!("mock server connection testing finished"); +} + +// This benchmark test uses a mock connection that replies ok to any requests instantly. It does not +// involve kernel latency. It does not wait for reply. +fn mock_connection_no_block(c: &mut Criterion) { + let mut rt = tokio::runtime::Runtime::new().unwrap(); + let config = ClientConfigBuilder::default() + .controller_uri("127.0.0.1:9090".parse::().unwrap()) + .mock(true) + .connection_type(ConnectionType::Mock) + .build() + .expect("creating config"); + let mut writer = rt.block_on(set_up(config)); + + info!("start mock connection(no block) performance testing"); + c.bench_function("mock connection(no block)", |b| { + b.iter(|| { + rt.block_on(run_no_block(&mut writer)); + }); + }); + info!("mock server connection(no block) testing finished"); +} + +// helper functions +async fn set_up(config: ClientConfig) -> EventStreamWriter { + let scope_name = Scope::new("testWriterPerf".into()); + let stream_name = Stream::new("testWriterPerf".into()); + let client_factory = ClientFactory::new(config.clone()); + let controller_client = client_factory.get_controller_client(); + create_scope_stream(controller_client, &scope_name, &stream_name, 1).await; + let scoped_stream = ScopedStream { + scope: scope_name.clone(), + stream: stream_name.clone(), + }; + client_factory.create_event_stream_writer(scoped_stream, config.clone()) +} + +async fn create_scope_stream( + controller_client: &dyn ControllerClient, + scope_name: &Scope, + stream_name: &Stream, + segment_number: i32, +) { + controller_client + .create_scope(scope_name) + .await + .expect("create scope"); + info!("Scope created"); + let request = StreamConfiguration { + scoped_stream: ScopedStream { + scope: scope_name.clone(), + stream: stream_name.clone(), + }, + scaling: Scaling { + scale_type: ScaleType::FixedNumSegments, + target_rate: 0, + scale_factor: 0, + min_num_segments: segment_number, + }, + retention: Retention { + retention_type: RetentionType::None, + retention_param: 0, + }, + }; + controller_client + .create_stream(&request) + .await + .expect("create stream"); + info!("Stream created"); +} + +// run sends request to server and wait for the reply +async fn run(writer: &mut EventStreamWriter) { + let mut receivers = vec![]; + for _i in 0..EVENT_NUM { + let rx = writer.write_event(vec![0; EVENT_SIZE]).await; + receivers.push(rx); + } + assert_eq!(receivers.len(), EVENT_NUM); + + for rx in receivers { + let reply: Result<(), EventStreamWriterError> = rx.await.expect("wait for result from oneshot"); + assert_eq!(reply.is_ok(), true); + } +} + +// run no block sends request to server and does not wait for the reply +async fn run_no_block(writer: &mut EventStreamWriter) { + let mut receivers = vec![]; + for _i in 0..EVENT_NUM { + let rx = writer.write_event(vec![0; EVENT_SIZE]).await; + receivers.push(rx); + } + assert_eq!(receivers.len(), EVENT_NUM); +} + +criterion_group!( + performance, + mock_server, + mock_server_no_block, + mock_connection, + mock_connection_no_block +); +criterion_main!(performance); diff --git a/controller-client/src/lib.rs b/controller-client/src/lib.rs index 9ed19608f..4d1656d20 100644 --- a/controller-client/src/lib.rs +++ b/controller-client/src/lib.rs @@ -61,7 +61,7 @@ pub mod controller { // this is the rs file name generated after compiling the proto file, located inside the target folder. } -mod mock_controller; +pub mod mock_controller; mod model_helper; #[cfg(test)] mod test; diff --git a/controller-client/src/mock_controller.rs b/controller-client/src/mock_controller.rs index f107766d9..36d2a50ca 100644 --- a/controller-client/src/mock_controller.rs +++ b/controller-client/src/mock_controller.rs @@ -12,28 +12,45 @@ use super::ControllerClient; use super::ControllerError; use async_trait::async_trait; use ordered_float::OrderedFloat; +use pravega_connection_pool::connection_pool::ConnectionPool; use pravega_rust_client_shared::*; -use pravega_wire_protocol::client_connection::ClientConnection; +use pravega_wire_protocol::client_connection::{ClientConnection, ClientConnectionImpl}; use pravega_wire_protocol::commands::{CreateSegmentCommand, DeleteSegmentCommand, MergeSegmentsCommand}; +use pravega_wire_protocol::connection_factory::{ + ConnectionFactory, ConnectionType, SegmentConnectionManager, +}; use pravega_wire_protocol::error::ClientConnectionError; use pravega_wire_protocol::wire_commands::{Replies, Requests}; use std::collections::HashSet; use std::collections::{BTreeMap, HashMap}; +use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; -use tokio::sync::{Mutex, RwLock, RwLockReadGuard}; +use tokio::sync::{RwLock, RwLockReadGuard}; use uuid::Uuid; static ID_GENERATOR: AtomicUsize = AtomicUsize::new(0); -struct MockController { - endpoint: String, - port: i32, - connection: Mutex>, // a fake client connection to send wire command. +pub struct MockController { + endpoint: SocketAddr, + pool: ConnectionPool, created_scopes: RwLock>>, created_streams: RwLock>, } +impl MockController { + pub fn new(endpoint: SocketAddr) -> Self { + let cf = ConnectionFactory::create(ConnectionType::Mock) as Box; + let manager = SegmentConnectionManager::new(cf, 10); + let pool = ConnectionPool::new(manager); + MockController { + endpoint, + pool, + created_scopes: RwLock::new(HashMap::new()), + created_streams: RwLock::new(HashMap::new()), + } + } +} #[async_trait] impl ControllerClient for MockController { async fn create_scope(&self, scope: &Scope) -> Result { @@ -251,8 +268,7 @@ impl ControllerClient for MockController { &self, _segment: &ScopedSegment, ) -> Result { - let uri = self.endpoint.clone() + ":" + &self.port.to_string(); - Ok(PravegaNodeUri(uri)) + Ok(PravegaNodeUri(self.endpoint.to_string())) } async fn get_or_refresh_delegation_token_for( @@ -578,6 +594,14 @@ async fn send_request_over_connection( command: &Requests, controller: &MockController, ) -> Result { - controller.connection.lock().await.write(command).await?; - controller.connection.lock().await.read().await + let pooled_connection = controller + .pool + .get_connection(controller.endpoint) + .await + .expect("get connection from pool"); + let mut connection = ClientConnectionImpl { + connection: pooled_connection, + }; + connection.write(command).await?; + connection.read().await } diff --git a/src/client_factory.rs b/src/client_factory.rs index 2b4fd6aa9..3fdb8dd6e 100644 --- a/src/client_factory.rs +++ b/src/client_factory.rs @@ -11,6 +11,7 @@ use std::net::SocketAddr; use pravega_connection_pool::connection_pool::ConnectionPool; +use pravega_controller_client::mock_controller::MockController; use pravega_controller_client::{ControllerClient, ControllerClientImpl}; use pravega_rust_client_shared::{ScopedSegment, ScopedStream}; use pravega_wire_protocol::client_config::ClientConfig; @@ -27,7 +28,7 @@ pub struct ClientFactory(Arc); pub struct ClientFactoryInternal { connection_pool: ConnectionPool, - controller_client: ControllerClientImpl, + controller_client: Box, } impl ClientFactory { @@ -35,7 +36,11 @@ impl ClientFactory { let _ = setup_logger(); //Ignore failure let cf = ConnectionFactory::create(config.connection_type); let pool = ConnectionPool::new(SegmentConnectionManager::new(cf, config.max_connections_in_pool)); - let controller = ControllerClientImpl::new(config); + let controller = if config.mock { + Box::new(MockController::new(config.controller_uri)) as Box + } else { + Box::new(ControllerClientImpl::new(config)) as Box + }; ClientFactory(Arc::new(ClientFactoryInternal { connection_pool: pool, controller_client: controller, @@ -78,7 +83,7 @@ impl ClientFactory { } pub fn get_controller_client(&self) -> &dyn ControllerClient { - &self.0.controller_client + &*self.0.controller_client } } @@ -92,6 +97,6 @@ impl ClientFactoryInternal { } pub(crate) fn get_controller_client(&self) -> &dyn ControllerClient { - &self.controller_client + &*self.controller_client } } diff --git a/src/event_stream_writer.rs b/src/event_stream_writer.rs index 8e92bff0a..5f883d602 100644 --- a/src/event_stream_writer.rs +++ b/src/event_stream_writer.rs @@ -477,10 +477,10 @@ impl EventSegmentWriter { } let acked = self.inflight.pop_front().expect("must have"); - if let Err(e) = acked.event.oneshot_sender.send(Result::Ok(())) { - error!( - "failed to send ack back to caller using oneshot due to {:?}: event id {:?}", - e, acked.event_id + if acked.event.oneshot_sender.send(Result::Ok(())).is_err() { + debug!( + "failed to send ack back to caller using oneshot due to Receiver dropped: event id {:?}", + acked.event_id ); } diff --git a/src/lib.rs b/src/lib.rs index 19b270941..38cafb282 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,7 +64,7 @@ pub fn setup_logger() -> Result<(), fern::InitError> { message )) }) - .level(log::LevelFilter::Debug) + .level(log::LevelFilter::Info) .chain(std::io::stdout()) .chain(fern::log_file("./output.log")?) .apply()?; diff --git a/wire_protocol/Cargo.toml b/wire_protocol/Cargo.toml index 9e3db4dc4..4814f3849 100644 --- a/wire_protocol/Cargo.toml +++ b/wire_protocol/Cargo.toml @@ -27,7 +27,7 @@ parking_lot = "0.10.0" uuid = {version = "0.8", features = ["v4"]} serde = { version = "1.0", features = ["derive"] } snafu = "0.6.2" -tokio = { version = "0.2.16", features = ["full"] } +tokio = { version = "0.2.20", features = ["full"] } futures = "0.3.5" dashmap = "3.4.4" log = "0.4.8" diff --git a/wire_protocol/src/client_config.rs b/wire_protocol/src/client_config.rs index 3001418ee..2197fc6df 100644 --- a/wire_protocol/src/client_config.rs +++ b/wire_protocol/src/client_config.rs @@ -38,6 +38,10 @@ pub struct ClientConfig { #[get_copy = "pub"] #[builder(setter(into))] pub controller_uri: SocketAddr, + + #[get_copy = "pub"] + #[builder(default = "false")] + pub mock: bool, } #[cfg(test)] diff --git a/wire_protocol/src/client_connection.rs b/wire_protocol/src/client_connection.rs index 2cb1e4303..d287097c0 100644 --- a/wire_protocol/src/client_connection.rs +++ b/wire_protocol/src/client_connection.rs @@ -37,11 +37,11 @@ pub struct ClientConnectionImpl<'a> { } pub struct ReadingClientConnection { - read_half: ReadingConnection, + read_half: Box, } pub struct WritingClientConnection { - write_half: WritingConnection, + write_half: Box, } impl<'a> ClientConnectionImpl<'a> { @@ -149,59 +149,34 @@ mod tests { use super::*; use crate::commands::HelloCommand; use crate::connection_factory::{ConnectionFactory, ConnectionType, SegmentConnectionManager}; - use crate::wire_commands::{Encode, Replies}; + use crate::wire_commands::Replies; use pravega_connection_pool::connection_pool::ConnectionPool; - use std::io::Write; - use std::net::{SocketAddr, TcpListener}; + use std::net::SocketAddr; use tokio::runtime::Runtime; - struct Server { - address: SocketAddr, - listener: TcpListener, - } - - impl Server { - pub fn new() -> Server { - let listener = TcpListener::bind("127.0.0.1:0").expect("local server"); - let address = listener.local_addr().expect("get listener address"); - Server { address, listener } - } - - pub fn send_hello_wirecommand(&mut self) { - let hello = Replies::Hello(HelloCommand { - high_version: 9, - low_version: 5, - }) - .write_fields() - .expect("serialize wirecommand"); - for stream in self.listener.incoming() { - let mut stream = stream.expect("get tcp stream"); - stream.write(&hello).expect("reply with hello wirecommand"); - break; - } - } - } #[test] - fn client_connection_read() { + fn client_connection_write_and_read() { let mut rt = Runtime::new().expect("create tokio Runtime"); - let mut server = Server::new(); - let connection_factory = ConnectionFactory::create(ConnectionType::Mock); let manager = SegmentConnectionManager::new(connection_factory, 1); let pool = ConnectionPool::new(manager); let connection = rt - .block_on(pool.get_connection(server.address)) + .block_on(pool.get_connection("127.0.0.1:9090".parse::().unwrap())) .expect("get connection from pool"); - // server send wirecommand - server.send_hello_wirecommand(); - + let mut client_connection = ClientConnectionImpl::new(connection); + // write wirecommand + let request = Requests::Hello(HelloCommand { + high_version: 9, + low_version: 5, + }); + rt.block_on(client_connection.write(&request)) + .expect("client connection write"); // read wirecommand - let mut reader = ClientConnectionImpl::new(connection); - - let fut = reader.read(); - let reply = rt.block_on(fut).expect("get reply from server"); + let reply = rt + .block_on(client_connection.read()) + .expect("client connection read"); assert_eq!( reply, diff --git a/wire_protocol/src/connection.rs b/wire_protocol/src/connection.rs index ab86ae1de..79e6eea37 100644 --- a/wire_protocol/src/connection.rs +++ b/wire_protocol/src/connection.rs @@ -64,7 +64,7 @@ pub trait Connection: Send + Sync { /// ``` async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError>; - fn split(&mut self) -> (ReadingConnection, WritingConnection); + fn split(&mut self) -> (Box, Box); fn get_endpoint(&self) -> SocketAddr; @@ -107,20 +107,20 @@ impl Connection for TokioConnection { Ok(()) } - fn split(&mut self) -> (ReadingConnection, WritingConnection) { + fn split(&mut self) -> (Box, Box) { assert!(!self.stream.is_none()); let (read_half, write_half) = tokio::io::split(self.stream.take().expect("take connection")); - let read = ReadingConnection { + let read = Box::new(ReadingConnectionImpl { uuid: self.uuid, endpoint: self.endpoint, read_half, - }; - let write = WritingConnection { + }) as Box; + let write = Box::new(WritingConnectionImpl { uuid: self.uuid, endpoint: self.endpoint, write_half, - }; + }) as Box; (read, write) } @@ -141,14 +141,21 @@ impl Connection for TokioConnection { } } -pub struct ReadingConnection { +#[async_trait] +pub trait ReadingConnection: Send + Sync { + async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError>; + fn get_id(&self) -> Uuid; +} + +pub struct ReadingConnectionImpl { uuid: Uuid, endpoint: SocketAddr, read_half: ReadHalf, } -impl ReadingConnection { - pub async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError> { +#[async_trait] +impl ReadingConnection for ReadingConnectionImpl { + async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError> { let endpoint = self.endpoint; self.read_half .read_exact(buf) @@ -157,19 +164,26 @@ impl ReadingConnection { Ok(()) } - pub fn get_id(&self) -> Uuid { + fn get_id(&self) -> Uuid { self.uuid } } -pub struct WritingConnection { +#[async_trait] +pub trait WritingConnection: Send + Sync { + async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError>; + fn get_id(&self) -> Uuid; +} + +pub struct WritingConnectionImpl { uuid: Uuid, endpoint: SocketAddr, write_half: WriteHalf, } -impl WritingConnection { - pub async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError> { +#[async_trait] +impl WritingConnection for WritingConnectionImpl { + async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError> { let endpoint = self.endpoint; self.write_half .write_all(payload) @@ -178,7 +192,7 @@ impl WritingConnection { Ok(()) } - pub fn get_id(&self) -> Uuid { + fn get_id(&self) -> Uuid { self.uuid } } diff --git a/wire_protocol/src/connection_factory.rs b/wire_protocol/src/connection_factory.rs index 2b72c8abb..a17ce40b6 100644 --- a/wire_protocol/src/connection_factory.rs +++ b/wire_protocol/src/connection_factory.rs @@ -12,6 +12,7 @@ use crate::client_connection::{read_wirecommand, write_wirecommand}; use crate::commands::{HelloCommand, OLDEST_COMPATIBLE_VERSION, WIRE_VERSION}; use crate::connection::{Connection, TokioConnection}; use crate::error::*; +use crate::mock_connection::MockConnection; use crate::wire_commands::{Replies, Requests}; use async_trait::async_trait; use pravega_connection_pool::connection_pool::{ConnectionPoolError, Manager}; @@ -108,18 +109,8 @@ impl ConnectionFactory for MockConnectionFactory { &self, endpoint: SocketAddr, ) -> Result, ConnectionFactoryError> { - let connection_type = ConnectionType::Mock; - let uuid = Uuid::new_v4(); - let stream = TcpStream::connect(endpoint).await.context(Connect { - connection_type, - endpoint, - })?; - let tokio_connection: Box = Box::new(TokioConnection { - uuid, - endpoint, - stream: Some(stream), - }) as Box; - Ok(tokio_connection) + let mock = MockConnection::new(endpoint); + Ok(Box::new(mock) as Box) } } @@ -196,54 +187,53 @@ impl Manager for SegmentConnectionManager { #[cfg(test)] mod tests { use super::*; - use std::io::Write; - use std::net::{SocketAddr, TcpListener}; + use crate::wire_commands::{Decode, Encode}; + use log::info; + use std::net::SocketAddr; use tokio::runtime::Runtime; - struct Server { - address: SocketAddr, - listener: TcpListener, - } - - impl Server { - pub fn new() -> Server { - let listener = TcpListener::bind("127.0.0.1:0").expect("local server"); - let address = listener.local_addr().unwrap(); - Server { address, listener } - } - - pub fn echo(&mut self) { - for stream in self.listener.incoming() { - let mut stream = stream.unwrap(); - stream.write(b"Hello World\r\n").unwrap(); - break; - } - } - } - #[test] - fn test_connection() { + fn test_mock_connection() { + info!("test mock connection factory"); let mut rt = Runtime::new().unwrap(); - let mut server = Server::new(); - let connection_factory = ConnectionFactory::create(ConnectionType::Mock); - let connection_future = connection_factory.establish_connection(server.address); - let mut connection = rt.block_on(connection_future).unwrap(); - - let mut payload: Vec = Vec::new(); - payload.push(12); - let fut = connection.send_async(&payload); - - let _res = rt.block_on(fut).unwrap(); + let connection_future = + connection_factory.establish_connection("127.1.1.1:9090".parse::().unwrap()); + let mut mock_connection = rt.block_on(connection_future).unwrap(); + + let request = Requests::Hello(HelloCommand { + high_version: 9, + low_version: 5, + }) + .write_fields() + .unwrap(); + let len = request.len(); + rt.block_on(mock_connection.send_async(&request)) + .expect("write to mock connection"); + let mut buf = vec![0; len]; + rt.block_on(mock_connection.read_async(&mut buf)) + .expect("read from mock connection"); + let reply = Replies::read_from(&buf).unwrap(); + let expected = Replies::Hello(HelloCommand { + high_version: 9, + low_version: 5, + }); + assert_eq!(reply, expected); + info!("mock connection factory test passed"); + } - server.echo(); - let mut buf = [0; 13]; + #[test] + #[should_panic] + fn test_tokio_connection() { + info!("test tokio connection factory"); + let mut rt = Runtime::new().unwrap(); - let fut = connection.read_async(&mut buf); - let _res = rt.block_on(fut).unwrap(); + let connection_factory = ConnectionFactory::create(ConnectionType::Tokio); + let connection_future = + connection_factory.establish_connection("127.1.1.1:9090".parse::().unwrap()); + let mut _connection = rt.block_on(connection_future).expect("create tokio connection"); - let echo = "Hello World\r\n".as_bytes(); - assert_eq!(buf, &echo[..]); + info!("tokio connection factory test passed"); } } diff --git a/wire_protocol/src/lib.rs b/wire_protocol/src/lib.rs index 83a957348..564bc6762 100644 --- a/wire_protocol/src/lib.rs +++ b/wire_protocol/src/lib.rs @@ -35,5 +35,6 @@ pub mod connection_factory; pub mod error; pub mod wire_commands; +pub mod mock_connection; #[cfg(test)] mod tests; diff --git a/wire_protocol/src/mock_connection.rs b/wire_protocol/src/mock_connection.rs new file mode 100644 index 000000000..8946d506b --- /dev/null +++ b/wire_protocol/src/mock_connection.rs @@ -0,0 +1,204 @@ +// +// Copyright (c) Dell Inc., or its subsidiaries. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// + +extern crate byteorder; +use crate::commands::{AppendSetupCommand, DataAppendedCommand}; +use crate::connection::{Connection, ReadingConnection, WritingConnection}; +use crate::error::*; +use crate::wire_commands::{Decode, Encode, Replies, Requests}; +use async_trait::async_trait; +use std::net::SocketAddr; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use uuid::Uuid; + +pub struct MockConnection { + id: Uuid, + endpoint: SocketAddr, + sender: Option>, + receiver: Option>, + buffer: Vec, + index: usize, +} + +impl MockConnection { + pub fn new(endpoint: SocketAddr) -> Self { + let (tx, rx) = unbounded_channel(); + MockConnection { + id: Uuid::new_v4(), + endpoint, + sender: Some(tx), + receiver: Some(rx), + buffer: vec![], + index: 0, + } + } +} + +#[async_trait] +impl Connection for MockConnection { + async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError> { + send(self.sender.as_mut().expect("get sender"), payload).await + } + + async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError> { + if self.index == self.buffer.len() { + let reply: Replies = self + .receiver + .as_mut() + .expect("get receiver") + .recv() + .await + .expect("read"); + self.buffer = reply.write_fields().expect("serialize reply"); + self.index = 0; + } + buf.copy_from_slice(&self.buffer[self.index..self.index + buf.len()]); + self.index += buf.len(); + assert!(self.index <= self.buffer.len()); + Ok(()) + } + + fn split(&mut self) -> (Box, Box) { + let reader = Box::new(MockReadingConnection { + id: self.id, + receiver: self + .receiver + .take() + .expect("split mock connection and get receiver"), + buffer: vec![], + index: 0, + }) as Box; + let writer = Box::new(MockWritingConnection { + id: self.id, + sender: self.sender.take().expect("split mock connection and get sender"), + }) as Box; + (reader, writer) + } + + fn get_endpoint(&self) -> SocketAddr { + self.endpoint + } + + fn get_uuid(&self) -> Uuid { + self.id + } + + fn is_valid(&self) -> bool { + true + } +} + +pub struct MockReadingConnection { + id: Uuid, + receiver: UnboundedReceiver, + buffer: Vec, + index: usize, +} + +pub struct MockWritingConnection { + id: Uuid, + sender: UnboundedSender, +} + +#[async_trait] +impl ReadingConnection for MockReadingConnection { + async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError> { + if self.index == self.buffer.len() { + let reply: Replies = self.receiver.recv().await.expect("read"); + self.buffer = reply.write_fields().expect("serialize reply"); + self.index = 0; + } + buf.copy_from_slice(&self.buffer[self.index..self.index + buf.len()]); + self.index += buf.len(); + assert!(self.index <= self.buffer.len()); + Ok(()) + } + + fn get_id(&self) -> Uuid { + self.id + } +} + +#[async_trait] +impl WritingConnection for MockWritingConnection { + async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError> { + send(&mut self.sender, payload).await + } + + fn get_id(&self) -> Uuid { + self.id + } +} + +async fn send(sender: &mut UnboundedSender, payload: &[u8]) -> Result<(), ConnectionError> { + let request: Requests = Requests::read_from(payload).expect("mock connection decode request"); + match request { + Requests::Hello(cmd) => { + let reply = Replies::Hello(cmd); + sender.send(reply).expect("send reply"); + } + Requests::SetupAppend(cmd) => { + let reply = Replies::AppendSetup(AppendSetupCommand { + request_id: cmd.request_id, + segment: cmd.segment, + writer_id: cmd.writer_id, + last_event_number: -9_223_372_036_854_775_808, // when there is no previous event in this segment + }); + sender.send(reply).expect("send reply"); + } + Requests::AppendBlockEnd(cmd) => { + let reply = Replies::DataAppended(DataAppendedCommand { + writer_id: cmd.writer_id, + event_number: cmd.last_event_number, + previous_event_number: 0, //not used in event stream writer + request_id: cmd.request_id, + current_segment_write_offset: 0, //not used in event stream writer + }); + sender.send(reply).expect("send reply"); + } + _ => { + panic!("unsupported request {:?}", request); + } + } + Ok(()) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::commands::HelloCommand; + use log::info; + + #[test] + fn test_simple_write_and_read() { + info!("mock client connection test"); + let mut rt = tokio::runtime::Runtime::new().unwrap(); + let mut mock_connection = MockConnection::new("127.1.1.1:9090".parse::().unwrap()); + let request = Requests::Hello(HelloCommand { + high_version: 9, + low_version: 5, + }) + .write_fields() + .unwrap(); + let len = request.len(); + rt.block_on(mock_connection.send_async(&request)) + .expect("write to mock connection"); + let mut buf = vec![0; len]; + rt.block_on(mock_connection.read_async(&mut buf)) + .expect("read from mock connection"); + let reply = Replies::read_from(&buf).unwrap(); + let expected = Replies::Hello(HelloCommand { + high_version: 9, + low_version: 5, + }); + assert_eq!(reply, expected); + info!("mock connection test passed"); + } +}