Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade json-transport to Tokio 0.2 #263

Merged
merged 2 commits into from
Oct 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions example-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use futures::{
prelude::*,
};
use service::World;
use std::{io, net::SocketAddr};
use std::{
io,
net::{IpAddr, SocketAddr},
};
use tarpc::{
context,
server::{self, Channel, Handler},
Expand Down Expand Up @@ -59,11 +62,12 @@ async fn main() -> io::Result<()> {
.parse()
.unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));

let server_addr = ([0, 0, 0, 0], port).into();
let server_addr = (IpAddr::from([0, 0, 0, 0]), port);

// tarpc_json_transport is provided by the associated crate tarpc-json-transport. It makes it easy
// to start up a serde-powered json serialization strategy over TCP.
tarpc_json_transport::listen(&server_addr)?
tarpc_json_transport::listen(&server_addr)
.await?
// Ignore accept errors.
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
Expand Down
18 changes: 8 additions & 10 deletions json-transport/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@ readme = "../README.md"
description = "A JSON-based transport for tarpc services."

[dependencies]
futures-preview = { version = "0.3.0-alpha.18", features = ["compat"] }
futures_legacy = { version = "0.1", package = "futures" }
pin-utils = "0.1.0-alpha.4"
serde = "1.0"
serde_json = "1.0"
tokio = { version = "0.1", default-features = false, features = ["codec"] }
tokio-io = "0.1"
tokio-serde-json = "0.2"
tokio-tcp = "0.1"
futures-preview = "0.3.0-alpha"
pin-project = "0.4"
serde = "1"
serde_json = "1"
tokio = { version = "0.2.0-alpha", default-features = false, features = ["codec", "io", "net"] }
tokio-net = "0.2.0-alpha"
tokio-serde-json = "0.3"

[dev-dependencies]
futures-test-preview = { version = "0.3.0-alpha.18" }
pin-utils = "0.1.0-alpha"
assert_matches = "1.0"
166 changes: 108 additions & 58 deletions json-transport/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

#![deny(missing_docs)]

use futures::{compat::*, prelude::*, ready};
use pin_utils::unsafe_pinned;
use futures::{prelude::*, ready};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use std::{
error::Error,
Expand All @@ -20,37 +20,28 @@ use std::{
task::{Context, Poll},
};
use tokio::codec::{length_delimited::LengthDelimitedCodec, Framed};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio_net::ToSocketAddrs;
use tokio_serde_json::*;
use tokio_tcp::{TcpListener, TcpStream};

/// A transport that serializes to, and deserializes from, a [`TcpStream`].
pub struct Transport<S: AsyncWrite, Item, SinkItem> {
inner: Compat01As03Sink<
ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
SinkItem,
>,
}

impl<S: AsyncWrite, Item, SinkItem> Transport<S, Item, SinkItem> {
unsafe_pinned!(
inner:
Compat01As03Sink<
ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
SinkItem,
>
);
#[pin_project]
pub struct Transport<S, Item, SinkItem> {
#[pin]
inner: ReadJson<WriteJson<Framed<S, LengthDelimitedCodec>, SinkItem>, Item>,
}

impl<S, Item, SinkItem> Stream for Transport<S, Item, SinkItem>
where
S: AsyncWrite + AsyncRead,
// TODO: Remove Unpin bound when tokio-rs/tokio#1272 is resolved.
S: AsyncWrite + AsyncRead + Unpin,
vorot93 marked this conversation as resolved.
Show resolved Hide resolved
Item: for<'a> Deserialize<'a>,
{
type Item = io::Result<Item>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
match self.inner().poll_next(cx) {
match self.project().inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))),
Expand All @@ -63,27 +54,28 @@ where

impl<S, Item, SinkItem> Sink<SinkItem> for Transport<S, Item, SinkItem>
where
S: AsyncWrite,
S: AsyncWrite + Unpin,
SinkItem: Serialize,
{
type Error = io::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_ready(cx))
}

fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.inner()
self.project()
.inner
.start_send(item)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.inner().poll_ready(cx))
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.inner().poll_flush(cx))
convert(self.project().inner.poll_flush(cx))
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.inner().poll_close(cx))
convert(self.project().inner.poll_close(cx))
}
}

Expand All @@ -100,22 +92,12 @@ fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
impl<Item, SinkItem> Transport<TcpStream, Item, SinkItem> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner
.get_ref()
.get_ref()
.get_ref()
.get_ref()
.peer_addr()
self.inner.get_ref().get_ref().get_ref().peer_addr()
}

/// Returns the local address of the underlying TcpStream.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner
.get_ref()
.get_ref()
.get_ref()
.get_ref()
.local_addr()
self.inner.get_ref().get_ref().get_ref().local_addr()
}
}

Expand All @@ -133,10 +115,10 @@ impl<S: AsyncWrite + AsyncRead, Item: serde::de::DeserializeOwned, SinkItem: Ser
{
fn from(inner: S) -> Self {
Transport {
inner: Compat01As03Sink::new(ReadJson::new(WriteJson::new(Framed::new(
inner: ReadJson::new(WriteJson::new(Framed::new(
inner,
LengthDelimitedCodec::new(),
)))),
))),
}
}
}
Expand All @@ -149,36 +131,40 @@ where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
Ok(new(TcpStream::connect(addr).compat().await?))
Ok(new(TcpStream::connect(addr).await?))
}

/// Listens on `addr`, wrapping accepted connections in JSON transports.
pub fn listen<Item, SinkItem>(addr: &SocketAddr) -> io::Result<Incoming<Item, SinkItem>>
pub async fn listen<A, Item, SinkItem>(addr: A) -> io::Result<Incoming<Item, SinkItem>>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let listener = TcpListener::bind(addr)?;
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
let incoming = listener.incoming().compat();
let incoming = Box::pin(listener.incoming());
Ok(Incoming {
incoming,
local_addr,
ghost: PhantomData,
})
}

trait IncomingTrait: Stream<Item = io::Result<TcpStream>> + std::fmt::Debug + Send {}
tikue marked this conversation as resolved.
Show resolved Hide resolved
impl<T: Stream<Item = io::Result<TcpStream>> + std::fmt::Debug + Send> IncomingTrait for T {}

/// A [`TcpListener`] that wraps connections in JSON transports.
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem> {
incoming: Compat01As03<tokio_tcp::Incoming>,
#[pin]
incoming: Pin<Box<dyn IncomingTrait>>,
tikue marked this conversation as resolved.
Show resolved Hide resolved
local_addr: SocketAddr,
ghost: PhantomData<(Item, SinkItem)>,
}

impl<Item, SinkItem> Incoming<Item, SinkItem> {
unsafe_pinned!(incoming: Compat01As03<tokio_tcp::Incoming>);

/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
Expand All @@ -193,7 +179,7 @@ where
type Item = io::Result<Transport<TcpStream, Item, SinkItem>>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let next = ready!(self.incoming().poll_next(cx)?);
let next = ready!(self.project().incoming.poll_next(cx)?);
Poll::Ready(next.map(|conn| Ok(new(conn))))
}
}
Expand All @@ -202,23 +188,54 @@ where
mod tests {
use super::Transport;
use assert_matches::assert_matches;
use futures::task::noop_waker_ref;
use futures::{Sink, Stream};
use futures_test::task::noop_waker_ref;
use pin_utils::pin_mut;
use std::{
io::Cursor,
io::{self, Cursor},
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};

fn ctx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
}

#[test]
fn test_stream() {
let reader = *b"\x00\x00\x00\x18\"Test one, check check.\"";
let reader: Box<[u8]> = Box::new(reader);
let transport = Transport::<_, String, String>::from(Cursor::new(reader));
struct TestIo(Cursor<&'static [u8]>);

impl AsyncRead for TestIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf)
}
}

impl AsyncWrite for TestIo {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
}

let data = b"\x00\x00\x00\x18\"Test one, check check.\"";
let transport = Transport::<_, String, String>::from(TestIo(Cursor::new(data)));
pin_mut!(transport);

assert_matches!(
Expand All @@ -228,8 +245,41 @@ mod tests {

#[test]
fn test_sink() {
let writer: &mut [u8] = &mut [0; 28];
let transport = Transport::<_, String, String>::from(Cursor::new(&mut *writer));
struct TestIo<'a>(&'a mut Vec<u8>);

impl<'a> AsyncRead for TestIo<'a> {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}
}

impl<'a> AsyncWrite for TestIo<'a> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx)
}
}

let mut writer = vec![];
let transport = Transport::<_, String, String>::from(TestIo(&mut writer));
pin_mut!(transport);

assert_matches!(
Expand Down