Skip to content

Commit

Permalink
Upgrade json-transport to Tokio 0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
vorot93 committed Aug 30, 2019
1 parent 46bcc0f commit 6e7a4cd
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 56 deletions.
10 changes: 7 additions & 3 deletions example-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use futures::{
prelude::*,
};
use service::World;
use std::{io, net::SocketAddr};
use std::{
io,
net::{IpAddr, SocketAddr},
};
use tarpc::{
context,
server::{self, Channel, Handler},
Expand Down Expand Up @@ -59,11 +62,12 @@ async fn main() -> io::Result<()> {
.parse()
.unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));

let server_addr = ([0, 0, 0, 0], port).into();
let server_addr = (IpAddr::from([0, 0, 0, 0]), port);

// tarpc_json_transport is provided by the associated crate tarpc-json-transport. It makes it easy
// to start up a serde-powered json serialization strategy over TCP.
tarpc_json_transport::listen(&server_addr)?
tarpc_json_transport::listen(&server_addr)
.await?
// Ignore accept errors.
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
Expand Down
8 changes: 3 additions & 5 deletions json-transport/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ description = "A JSON-based transport for tarpc services."

[dependencies]
futures-preview = { version = "0.3.0-alpha.18", features = ["compat"] }
futures_legacy = { version = "0.1", package = "futures" }
pin-utils = "0.1.0-alpha.4"
serde = "1.0"
serde_json = "1.0"
tokio = { version = "0.1", default-features = false, features = ["codec"] }
tokio-io = "0.1"
tokio-serde-json = "0.2"
tokio-tcp = "0.1"
tokio = { version = "0.2.0-alpha.4", default-features = false, features = ["codec", "io", "net"] }
tokio-net = "0.2.0-alpha.4"
tokio-serde-json = { git = "https://github.com/vorot93/tokio-serde-json" }

[dev-dependencies]
futures-test-preview = { version = "0.3.0-alpha.18" }
Expand Down
146 changes: 98 additions & 48 deletions json-transport/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#![deny(missing_docs)]

use futures::{compat::*, prelude::*, ready};
use futures::{prelude::*, ready};
use pin_utils::unsafe_pinned;
use serde::{Deserialize, Serialize};
use std::{
Expand All @@ -20,31 +20,23 @@ use std::{
task::{Context, Poll},
};
use tokio::codec::{length_delimited::LengthDelimitedCodec, Framed};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio_net::ToSocketAddrs;
use tokio_serde_json::*;
use tokio_tcp::{TcpListener, TcpStream};

/// A transport that serializes to, and deserializes from, a [`TcpStream`].
pub struct Transport<S: AsyncWrite, Item, SinkItem> {
inner: Compat01As03Sink<
ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
SinkItem,
>,
pub struct Transport<S, Item, SinkItem> {
inner: ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
}

impl<S: AsyncWrite, Item, SinkItem> Transport<S, Item, SinkItem> {
unsafe_pinned!(
inner:
Compat01As03Sink<
ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
SinkItem,
>
);
impl<S, Item, SinkItem> Transport<S, Item, SinkItem> {
unsafe_pinned!(inner: ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>);
}

impl<S, Item, SinkItem> Stream for Transport<S, Item, SinkItem>
where
S: AsyncWrite + AsyncRead,
S: AsyncWrite + AsyncRead + Unpin,
Item: for<'a> Deserialize<'a>,
{
type Item = io::Result<Item>;
Expand All @@ -63,21 +55,21 @@ where

impl<S, Item, SinkItem> Sink<SinkItem> for Transport<S, Item, SinkItem>
where
S: AsyncWrite,
S: AsyncWrite + Unpin,
SinkItem: Serialize,
{
type Error = io::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.inner().poll_ready(cx))
}

fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.inner()
.start_send(item)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.inner().poll_ready(cx))
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.inner().poll_flush(cx))
}
Expand All @@ -100,22 +92,12 @@ fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
impl<Item, SinkItem> Transport<TcpStream, Item, SinkItem> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner
.get_ref()
.get_ref()
.get_ref()
.get_ref()
.peer_addr()
self.inner.get_ref().get_ref().get_ref().peer_addr()
}

/// Returns the local address of the underlying TcpStream.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner
.get_ref()
.get_ref()
.get_ref()
.get_ref()
.local_addr()
self.inner.get_ref().get_ref().get_ref().local_addr()
}
}

Expand All @@ -133,10 +115,10 @@ impl<S: AsyncWrite + AsyncRead, Item: serde::de::DeserializeOwned, SinkItem: Ser
{
fn from(inner: S) -> Self {
Transport {
inner: Compat01As03Sink::new(ReadJson::new(WriteJson::new(Framed::new(
inner: ReadJson::new(WriteJson::new(Framed::new(
inner,
LengthDelimitedCodec::new(),
)))),
))),
}
}
}
Expand All @@ -149,35 +131,39 @@ where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
Ok(new(TcpStream::connect(addr).compat().await?))
Ok(new(TcpStream::connect(addr).await?))
}

/// Listens on `addr`, wrapping accepted connections in JSON transports.
pub fn listen<Item, SinkItem>(addr: &SocketAddr) -> io::Result<Incoming<Item, SinkItem>>
pub async fn listen<A, Item, SinkItem>(addr: A) -> io::Result<Incoming<Item, SinkItem>>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let listener = TcpListener::bind(addr)?;
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
let incoming = listener.incoming().compat();
let incoming = Box::pin(listener.incoming());
Ok(Incoming {
incoming,
local_addr,
ghost: PhantomData,
})
}

trait IncomingTrait: Stream<Item = io::Result<TcpStream>> + std::fmt::Debug + Send {}
impl<T: Stream<Item = io::Result<TcpStream>> + std::fmt::Debug + Send> IncomingTrait for T {}

/// A [`TcpListener`] that wraps connections in JSON transports.
#[derive(Debug)]
pub struct Incoming<Item, SinkItem> {
incoming: Compat01As03<tokio_tcp::Incoming>,
incoming: Pin<Box<dyn IncomingTrait>>,
local_addr: SocketAddr,
ghost: PhantomData<(Item, SinkItem)>,
}

impl<Item, SinkItem> Incoming<Item, SinkItem> {
unsafe_pinned!(incoming: Compat01As03<tokio_tcp::Incoming>);
unsafe_pinned!(incoming: Pin<Box<dyn IncomingTrait>>);

/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
Expand Down Expand Up @@ -206,19 +192,50 @@ mod tests {
use futures_test::task::noop_waker_ref;
use pin_utils::pin_mut;
use std::{
io::Cursor,
io::{self, Cursor},
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};

fn ctx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
}

#[test]
fn test_stream() {
let reader = *b"\x00\x00\x00\x18\"Test one, check check.\"";
let reader: Box<[u8]> = Box::new(reader);
let transport = Transport::<_, String, String>::from(Cursor::new(reader));
struct TestIo(Cursor<&'static [u8]>);

impl AsyncRead for TestIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf)
}
}

impl AsyncWrite for TestIo {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
}

let data = b"\x00\x00\x00\x18\"Test one, check check.\"";
let transport = Transport::<_, String, String>::from(TestIo(Cursor::new(data)));
pin_mut!(transport);

assert_matches!(
Expand All @@ -228,8 +245,41 @@ mod tests {

#[test]
fn test_sink() {
let writer: &mut [u8] = &mut [0; 28];
let transport = Transport::<_, String, String>::from(Cursor::new(&mut *writer));
struct TestIo<'a>(&'a mut Vec<u8>);

impl<'a> AsyncRead for TestIo<'a> {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}
}

impl<'a> AsyncWrite for TestIo<'a> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx)
}
}

let mut writer = vec![];
let transport = Transport::<_, String, String>::from(TestIo(&mut writer));
pin_mut!(transport);

assert_matches!(
Expand Down

0 comments on commit 6e7a4cd

Please sign in to comment.