Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Endpoint connectors swappable #148

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
endpoint: Make endpoint contain a connector
Instead of providing a connector in the call to `connect()`, make `Endpoint`
contain a connector, by default an `HTTPConnector`, which the caller can
swap.
  • Loading branch information
akshayknarayan committed Dec 3, 2019
commit fd7470996c5d5ceb3da6e48e73590a7d9f806d57
4 changes: 2 additions & 2 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ fn generate_connect(service_ident: &syn::Ident) -> TokenStream {
/// connector.
pub async fn connect_with_connector<C, D>(dst: D, connector: C) -> Result<Self, tonic::transport::Error>
where
C: MakeConnection<http::Uri> + Send + 'static,
C: MakeConnection<http::Uri> + Send + Clone + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
D: std::convert::TryInto<tonic::transport::Endpoint>,
D::Error: Into<StdError>,
{
let conn = tonic::transport::Endpoint::new(dst)?.connect_with_connector(connector).await?;
let conn = tonic::transport::Endpoint::new(dst)?.connector(connector).connect().await?;
Ok(Self::new(conn))
}
}
Expand Down
24 changes: 13 additions & 11 deletions tonic-interop/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let test_cases = matches.test_case;

#[allow(unused_mut)]
let mut endpoint = Endpoint::from_static("http://localhost:10000")
let endpoint = Endpoint::from_static("http://localhost:10000")
.timeout(Duration::from_secs(5))
.concurrency_limit(30);

if matches.use_tls {
let channel = if matches.use_tls {
let pem = tokio::fs::read("tonic-interop/data/ca.pem").await?;
let ca = Certificate::from_pem(pem);
endpoint = endpoint.tls_config(
ClientTlsConfig::with_rustls()
.ca_certificate(ca)
.domain_name("foo.test.google.fr"),
);
}

let channel = endpoint.connect().await?;
endpoint
.tls_config(
ClientTlsConfig::with_rustls()
.ca_certificate(ca)
.domain_name("foo.test.google.fr"),
)
.connect()
.await?
} else {
endpoint.connect().await?
};

let mut client = client::TestClient::new(channel.clone());
let mut unimplemented_client = client::UnimplementedClient::new(channel);
Expand Down
2 changes: 1 addition & 1 deletion tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ tokio-rustls = { version = "=0.12.0-alpha.5", optional = true }
rustls-native-certs = { version = "0.1", optional = true }

[dev-dependencies]
hyper-unix-connector = "0.1.2"
hyper-unix-connector = "0.1.3"
static_assertions = "1.0"
rand = "0.7.2"
criterion = "0.3"
Expand Down
33 changes: 4 additions & 29 deletions tonic/src/transport/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,7 @@ impl Channel {
///
/// This creates a [`Channel`] that will load balance accross all the
/// provided endpoints.
pub fn balance_list_with_connector<C>(
list: impl Iterator<Item = Endpoint>,
connector: C,
) -> Self
pub fn balance_list<C>(list: impl Iterator<Item = Endpoint<C>>) -> Self
where
C: tower_make::MakeConnection<hyper::Uri> + Send + Clone + Unpin + 'static,
C::Connection: Unpin + Send + 'static,
Expand All @@ -120,34 +117,12 @@ impl Channel {
.next()
.and_then(|e| e.interceptor_headers.clone());

let discover = ServiceList::new(list, connector);
let discover = ServiceList::new(list);

Self::balance(discover, buffer_size, interceptor_headers)
}

/// Balance a list of [`Endpoint`]'s.
///
/// This creates a [`Channel`] that will load balance accross all the
/// provided endpoints.
pub fn balance_list(list: impl Iterator<Item = Endpoint>) -> Self {
// Backwards API compatibility.
// Uses TCP if the TLS feature is not enabled, and TLS otherwise.

let list = list.collect::<Vec<_>>();

#[cfg(feature = "tls")]
let connector = {
let tls_connector = list.iter().next().and_then(|e| e.tls.clone());
super::service::connector(tls_connector)
};

#[cfg(not(feature = "tls"))]
let connector = super::service::connector();

Channel::balance_list_with_connector(list.into_iter(), connector)
}

pub(crate) async fn connect<C>(endpoint: Endpoint, connector: C) -> Result<Self, super::Error>
pub(crate) async fn connect<C>(endpoint: Endpoint<C>) -> Result<Self, super::Error>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + 'static,
C::Connection: Unpin + Send + 'static,
Expand All @@ -157,7 +132,7 @@ impl Channel {
let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE);
let interceptor_headers = endpoint.interceptor_headers.clone();

let svc = Connection::new(endpoint, connector)
let svc = Connection::new(endpoint)
.await
.map_err(|e| super::Error::from_source(super::ErrorKind::Client, e))?;

Expand Down
149 changes: 84 additions & 65 deletions tonic/src/transport/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::{
};
use bytes::Bytes;
use http::uri::{InvalidUriBytes, Uri};
use hyper::client::connect::HttpConnector;
use std::{
convert::{TryFrom, TryInto},
fmt,
Expand All @@ -17,21 +18,20 @@ use std::{
///
/// This struct is used to build and configure HTTP/2 channels.
#[derive(Clone)]
pub struct Endpoint {
pub struct Endpoint<C = HttpConnector> {
pub(super) uri: Uri,
pub(super) connector: C,
pub(super) timeout: Option<Duration>,
pub(super) concurrency_limit: Option<usize>,
pub(super) rate_limit: Option<(u64, Duration)>,
#[cfg(feature = "tls")]
pub(super) tls: Option<TlsConnector>,
pub(super) buffer_size: Option<usize>,
pub(super) interceptor_headers:
Option<Arc<dyn Fn(&mut http::HeaderMap) + Send + Sync + 'static>>,
pub(super) init_stream_window_size: Option<u32>,
pub(super) init_connection_window_size: Option<u32>,
}

impl Endpoint {
impl<C> Endpoint<C> {
// FIXME: determine if we want to expose this or not. This is really
// just used in codegen for a shortcut.
#[doc(hidden)]
Expand All @@ -46,28 +46,6 @@ impl Endpoint {
Ok(me)
}

/// Convert an `Endpoint` from a static string.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_static("https://example.com");
/// ```
pub fn from_static(s: &'static str) -> Self {
let uri = Uri::from_static(s);
Self::from(uri)
}

/// Convert an `Endpoint` from shared bytes.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_shared("https://example.com".to_string());
/// ```
pub fn from_shared(s: impl Into<Bytes>) -> Result<Self, InvalidUriBytes> {
let uri = Uri::from_shared(s.into())?;
Ok(Self::from(uri))
}

/// Apply a timeout to each request.
///
/// ```
Expand Down Expand Up @@ -147,32 +125,32 @@ impl Endpoint {
}

/// Configures TLS for the endpoint.
///
/// Shortcut for configuring a TLS connector and calling [`Endpoint::connector`].
#[cfg(feature = "tls")]
pub fn tls_config(self, tls_config: ClientTlsConfig) -> Self {
Endpoint {
tls: Some(tls_config.tls_connector(self.uri.clone()).unwrap()),
..self
}
}

/// Create a channel from this config.
pub async fn connect(&self) -> Result<Channel, super::Error> {
// Backwards API compatibility.
// Uses TCP if the TLS feature is not enabled, and TLS otherwise.

#[cfg(feature = "tls")]
let connector = super::service::connector(self.tls.clone());

#[cfg(not(feature = "tls"))]
let connector = super::service::connector();

self.connect_with_connector(connector).await
pub fn tls_config(
self,
tls_config: ClientTlsConfig,
) -> Endpoint<
impl tower_make::MakeConnection<
hyper::Uri,
Connection = impl Unpin + Send + 'static,
Future = impl Send + 'static,
Error = impl Into<Box<dyn std::error::Error + Send + Sync>> + Send,
> + Clone,
> {
let tls_connector = tls_config.tls_connector(self.uri.clone()).unwrap();
let connector = super::service::tls_connector(Some(tls_connector));
self.connector(connector)
}
}

/// Create a channel using a custom connector.
impl<C> Endpoint<C> {
/// Use a custom connector for the underlying channel.
///
/// The [`tower_make::MakeConnection`] requirement is an alias for `tower::Service<Uri, Response = AsyncRead +
/// Async Write>` - for example, a TCP stream as in [`Endpoint::connect`] above.
/// Calling [`Endpoint::connect`] requires the connector implement
/// the [`tower_make::MakeConnection`] requirement, which is an alias for `tower::Service<Uri, Response = AsyncRead +
/// Async Write>` - for example, a TCP stream for the default [`HttpConnector`].
///
/// # Example
/// ```rust
Expand All @@ -185,7 +163,7 @@ impl Endpoint {
/// connector.set_nodelay(true);
///
/// let endpoint = Endpoint::from_static("http://example.com");
/// endpoint.connect_with_connector(connector); //.await
/// endpoint.connector(connector).connect(); //.await
/// ```
///
/// # Example with non-default Connector
Expand All @@ -195,28 +173,69 @@ impl Endpoint {
/// use tonic::transport::Endpoint;
///
/// let endpoint = Endpoint::from_static("http://example.com");
/// endpoint.connect_with_connector(UnixClient); //.await
/// endpoint.connector(UnixClient).connect(); //.await
/// ```
pub async fn connect_with_connector<C>(&self, connector: C) -> Result<Channel, super::Error>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
Channel::connect(self.clone(), connector).await
pub fn connector<D>(self, connector: D) -> Endpoint<D> {
Endpoint {
uri: self.uri,
connector,
concurrency_limit: self.concurrency_limit,
rate_limit: self.rate_limit,
timeout: self.timeout,
buffer_size: self.buffer_size,
interceptor_headers: self.interceptor_headers,
init_stream_window_size: self.init_stream_window_size,
init_connection_window_size: self.init_connection_window_size,
}
}
}

impl<C> Endpoint<C>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + Clone + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
/// Create the channel.
pub async fn connect(&self) -> Result<Channel, super::Error> {
let e: Self = self.clone();
Channel::connect(e).await
}
}

impl Endpoint<HttpConnector> {
/// Convert an `Endpoint` from a static string.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_static("https://example.com");
/// ```
pub fn from_static(s: &'static str) -> Self {
let uri = Uri::from_static(s);
Self::from(uri)
}

/// Convert an `Endpoint` from shared bytes.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_shared("https://example.com".to_string());
/// ```
pub fn from_shared(s: impl Into<Bytes>) -> Result<Self, InvalidUriBytes> {
let uri = Uri::from_shared(s.into())?;
Ok(Self::from(uri))
}
}

impl From<Uri> for Endpoint {
impl From<Uri> for Endpoint<HttpConnector> {
fn from(uri: Uri) -> Self {
Self {
uri,
connector: super::service::connector(),
concurrency_limit: None,
rate_limit: None,
timeout: None,
#[cfg(feature = "tls")]
tls: None,
buffer_size: None,
interceptor_headers: None,
init_stream_window_size: None,
Expand All @@ -225,23 +244,23 @@ impl From<Uri> for Endpoint {
}
}

impl TryFrom<Bytes> for Endpoint {
impl TryFrom<Bytes> for Endpoint<HttpConnector> {
type Error = InvalidUriBytes;

fn try_from(t: Bytes) -> Result<Self, Self::Error> {
Self::from_shared(t)
}
}

impl TryFrom<String> for Endpoint {
impl TryFrom<String> for Endpoint<HttpConnector> {
type Error = InvalidUriBytes;

fn try_from(t: String) -> Result<Self, Self::Error> {
Self::from_shared(t.into_bytes())
}
}

impl TryFrom<&'static str> for Endpoint {
impl TryFrom<&'static str> for Endpoint<HttpConnector> {
type Error = Never;

fn try_from(t: &'static str) -> Result<Self, Self::Error> {
Expand All @@ -260,7 +279,7 @@ impl std::fmt::Display for Never {

impl std::error::Error for Never {}

impl fmt::Debug for Endpoint {
impl<C> fmt::Debug for Endpoint<C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Endpoint").finish()
}
Expand Down
Loading