Skip to content

Commit

Permalink
endpoint: Make endpoint contain a connector
Browse files Browse the repository at this point in the history
Instead of providing a connector in the call to `connect()`, make `Endpoint`
contain a connector, by default an `HTTPConnector`, which the caller can
swap.
  • Loading branch information
akshayknarayan committed Dec 3, 2019
1 parent 8efaba6 commit fd74709
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 130 deletions.
4 changes: 2 additions & 2 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ fn generate_connect(service_ident: &syn::Ident) -> TokenStream {
/// connector.
pub async fn connect_with_connector<C, D>(dst: D, connector: C) -> Result<Self, tonic::transport::Error>
where
C: MakeConnection<http::Uri> + Send + 'static,
C: MakeConnection<http::Uri> + Send + Clone + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
D: std::convert::TryInto<tonic::transport::Endpoint>,
D::Error: Into<StdError>,
{
let conn = tonic::transport::Endpoint::new(dst)?.connect_with_connector(connector).await?;
let conn = tonic::transport::Endpoint::new(dst)?.connector(connector).connect().await?;
Ok(Self::new(conn))
}
}
Expand Down
24 changes: 13 additions & 11 deletions tonic-interop/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let test_cases = matches.test_case;

#[allow(unused_mut)]
let mut endpoint = Endpoint::from_static("http://localhost:10000")
let endpoint = Endpoint::from_static("http://localhost:10000")
.timeout(Duration::from_secs(5))
.concurrency_limit(30);

if matches.use_tls {
let channel = if matches.use_tls {
let pem = tokio::fs::read("tonic-interop/data/ca.pem").await?;
let ca = Certificate::from_pem(pem);
endpoint = endpoint.tls_config(
ClientTlsConfig::with_rustls()
.ca_certificate(ca)
.domain_name("foo.test.google.fr"),
);
}

let channel = endpoint.connect().await?;
endpoint
.tls_config(
ClientTlsConfig::with_rustls()
.ca_certificate(ca)
.domain_name("foo.test.google.fr"),
)
.connect()
.await?
} else {
endpoint.connect().await?
};

let mut client = client::TestClient::new(channel.clone());
let mut unimplemented_client = client::UnimplementedClient::new(channel);
Expand Down
2 changes: 1 addition & 1 deletion tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ tokio-rustls = { version = "=0.12.0-alpha.5", optional = true }
rustls-native-certs = { version = "0.1", optional = true }

[dev-dependencies]
hyper-unix-connector = "0.1.2"
hyper-unix-connector = "0.1.3"
static_assertions = "1.0"
rand = "0.7.2"
criterion = "0.3"
Expand Down
33 changes: 4 additions & 29 deletions tonic/src/transport/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,7 @@ impl Channel {
///
/// This creates a [`Channel`] that will load balance accross all the
/// provided endpoints.
pub fn balance_list_with_connector<C>(
list: impl Iterator<Item = Endpoint>,
connector: C,
) -> Self
pub fn balance_list<C>(list: impl Iterator<Item = Endpoint<C>>) -> Self
where
C: tower_make::MakeConnection<hyper::Uri> + Send + Clone + Unpin + 'static,
C::Connection: Unpin + Send + 'static,
Expand All @@ -120,34 +117,12 @@ impl Channel {
.next()
.and_then(|e| e.interceptor_headers.clone());

let discover = ServiceList::new(list, connector);
let discover = ServiceList::new(list);

Self::balance(discover, buffer_size, interceptor_headers)
}

/// Balance a list of [`Endpoint`]'s.
///
/// This creates a [`Channel`] that will load balance accross all the
/// provided endpoints.
pub fn balance_list(list: impl Iterator<Item = Endpoint>) -> Self {
// Backwards API compatibility.
// Uses TCP if the TLS feature is not enabled, and TLS otherwise.

let list = list.collect::<Vec<_>>();

#[cfg(feature = "tls")]
let connector = {
let tls_connector = list.iter().next().and_then(|e| e.tls.clone());
super::service::connector(tls_connector)
};

#[cfg(not(feature = "tls"))]
let connector = super::service::connector();

Channel::balance_list_with_connector(list.into_iter(), connector)
}

pub(crate) async fn connect<C>(endpoint: Endpoint, connector: C) -> Result<Self, super::Error>
pub(crate) async fn connect<C>(endpoint: Endpoint<C>) -> Result<Self, super::Error>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + 'static,
C::Connection: Unpin + Send + 'static,
Expand All @@ -157,7 +132,7 @@ impl Channel {
let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE);
let interceptor_headers = endpoint.interceptor_headers.clone();

let svc = Connection::new(endpoint, connector)
let svc = Connection::new(endpoint)
.await
.map_err(|e| super::Error::from_source(super::ErrorKind::Client, e))?;

Expand Down
149 changes: 84 additions & 65 deletions tonic/src/transport/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::{
};
use bytes::Bytes;
use http::uri::{InvalidUriBytes, Uri};
use hyper::client::connect::HttpConnector;
use std::{
convert::{TryFrom, TryInto},
fmt,
Expand All @@ -17,21 +18,20 @@ use std::{
///
/// This struct is used to build and configure HTTP/2 channels.
#[derive(Clone)]
pub struct Endpoint {
pub struct Endpoint<C = HttpConnector> {
pub(super) uri: Uri,
pub(super) connector: C,
pub(super) timeout: Option<Duration>,
pub(super) concurrency_limit: Option<usize>,
pub(super) rate_limit: Option<(u64, Duration)>,
#[cfg(feature = "tls")]
pub(super) tls: Option<TlsConnector>,
pub(super) buffer_size: Option<usize>,
pub(super) interceptor_headers:
Option<Arc<dyn Fn(&mut http::HeaderMap) + Send + Sync + 'static>>,
pub(super) init_stream_window_size: Option<u32>,
pub(super) init_connection_window_size: Option<u32>,
}

impl Endpoint {
impl<C> Endpoint<C> {
// FIXME: determine if we want to expose this or not. This is really
// just used in codegen for a shortcut.
#[doc(hidden)]
Expand All @@ -46,28 +46,6 @@ impl Endpoint {
Ok(me)
}

/// Convert an `Endpoint` from a static string.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_static("https://example.com");
/// ```
pub fn from_static(s: &'static str) -> Self {
let uri = Uri::from_static(s);
Self::from(uri)
}

/// Convert an `Endpoint` from shared bytes.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_shared("https://example.com".to_string());
/// ```
pub fn from_shared(s: impl Into<Bytes>) -> Result<Self, InvalidUriBytes> {
let uri = Uri::from_shared(s.into())?;
Ok(Self::from(uri))
}

/// Apply a timeout to each request.
///
/// ```
Expand Down Expand Up @@ -147,32 +125,32 @@ impl Endpoint {
}

/// Configures TLS for the endpoint.
///
/// Shortcut for configuring a TLS connector and calling [`Endpoint::connector`].
#[cfg(feature = "tls")]
pub fn tls_config(self, tls_config: ClientTlsConfig) -> Self {
Endpoint {
tls: Some(tls_config.tls_connector(self.uri.clone()).unwrap()),
..self
}
}

/// Create a channel from this config.
pub async fn connect(&self) -> Result<Channel, super::Error> {
// Backwards API compatibility.
// Uses TCP if the TLS feature is not enabled, and TLS otherwise.

#[cfg(feature = "tls")]
let connector = super::service::connector(self.tls.clone());

#[cfg(not(feature = "tls"))]
let connector = super::service::connector();

self.connect_with_connector(connector).await
pub fn tls_config(
self,
tls_config: ClientTlsConfig,
) -> Endpoint<
impl tower_make::MakeConnection<
hyper::Uri,
Connection = impl Unpin + Send + 'static,
Future = impl Send + 'static,
Error = impl Into<Box<dyn std::error::Error + Send + Sync>> + Send,
> + Clone,
> {
let tls_connector = tls_config.tls_connector(self.uri.clone()).unwrap();
let connector = super::service::tls_connector(Some(tls_connector));
self.connector(connector)
}
}

/// Create a channel using a custom connector.
impl<C> Endpoint<C> {
/// Use a custom connector for the underlying channel.
///
/// The [`tower_make::MakeConnection`] requirement is an alias for `tower::Service<Uri, Response = AsyncRead +
/// Async Write>` - for example, a TCP stream as in [`Endpoint::connect`] above.
/// Calling [`Endpoint::connect`] requires the connector implement
/// the [`tower_make::MakeConnection`] requirement, which is an alias for `tower::Service<Uri, Response = AsyncRead +
/// Async Write>` - for example, a TCP stream for the default [`HttpConnector`].
///
/// # Example
/// ```rust
Expand All @@ -185,7 +163,7 @@ impl Endpoint {
/// connector.set_nodelay(true);
///
/// let endpoint = Endpoint::from_static("http://example.com");
/// endpoint.connect_with_connector(connector); //.await
/// endpoint.connector(connector).connect(); //.await
/// ```
///
/// # Example with non-default Connector
Expand All @@ -195,28 +173,69 @@ impl Endpoint {
/// use tonic::transport::Endpoint;
///
/// let endpoint = Endpoint::from_static("http://example.com");
/// endpoint.connect_with_connector(UnixClient); //.await
/// endpoint.connector(UnixClient).connect(); //.await
/// ```
pub async fn connect_with_connector<C>(&self, connector: C) -> Result<Channel, super::Error>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
Channel::connect(self.clone(), connector).await
pub fn connector<D>(self, connector: D) -> Endpoint<D> {
Endpoint {
uri: self.uri,
connector,
concurrency_limit: self.concurrency_limit,
rate_limit: self.rate_limit,
timeout: self.timeout,
buffer_size: self.buffer_size,
interceptor_headers: self.interceptor_headers,
init_stream_window_size: self.init_stream_window_size,
init_connection_window_size: self.init_connection_window_size,
}
}
}

impl<C> Endpoint<C>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + Clone + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
/// Create the channel.
pub async fn connect(&self) -> Result<Channel, super::Error> {
let e: Self = self.clone();
Channel::connect(e).await
}
}

impl Endpoint<HttpConnector> {
/// Convert an `Endpoint` from a static string.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_static("https://example.com");
/// ```
pub fn from_static(s: &'static str) -> Self {
let uri = Uri::from_static(s);
Self::from(uri)
}

/// Convert an `Endpoint` from shared bytes.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_shared("https://example.com".to_string());
/// ```
pub fn from_shared(s: impl Into<Bytes>) -> Result<Self, InvalidUriBytes> {
let uri = Uri::from_shared(s.into())?;
Ok(Self::from(uri))
}
}

impl From<Uri> for Endpoint {
impl From<Uri> for Endpoint<HttpConnector> {
fn from(uri: Uri) -> Self {
Self {
uri,
connector: super::service::connector(),
concurrency_limit: None,
rate_limit: None,
timeout: None,
#[cfg(feature = "tls")]
tls: None,
buffer_size: None,
interceptor_headers: None,
init_stream_window_size: None,
Expand All @@ -225,23 +244,23 @@ impl From<Uri> for Endpoint {
}
}

impl TryFrom<Bytes> for Endpoint {
impl TryFrom<Bytes> for Endpoint<HttpConnector> {
type Error = InvalidUriBytes;

fn try_from(t: Bytes) -> Result<Self, Self::Error> {
Self::from_shared(t)
}
}

impl TryFrom<String> for Endpoint {
impl TryFrom<String> for Endpoint<HttpConnector> {
type Error = InvalidUriBytes;

fn try_from(t: String) -> Result<Self, Self::Error> {
Self::from_shared(t.into_bytes())
}
}

impl TryFrom<&'static str> for Endpoint {
impl TryFrom<&'static str> for Endpoint<HttpConnector> {
type Error = Never;

fn try_from(t: &'static str) -> Result<Self, Self::Error> {
Expand All @@ -260,7 +279,7 @@ impl std::fmt::Display for Never {

impl std::error::Error for Never {}

impl fmt::Debug for Endpoint {
impl<C> fmt::Debug for Endpoint<C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Endpoint").finish()
}
Expand Down
Loading

0 comments on commit fd74709

Please sign in to comment.