Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Endpoint connectors swappable #148

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ fn generate_connect(service_ident: &syn::Ident) -> TokenStream {
let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
Ok(Self::new(conn))
}

/// Attempt to create a new client by connecting to a given endpoint using a custom
/// connector.
pub async fn connect_with_connector<C, D>(dst: D, connector: C) -> Result<Self, tonic::transport::Error>
where
C: MakeConnection<http::Uri> + Send + Clone + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
D: std::convert::TryInto<tonic::transport::Endpoint>,
D::Error: Into<StdError>,
{
let conn = tonic::transport::Endpoint::new(dst)?.connector(connector).connect().await?;
Ok(Self::new(conn))
}
}
}
}
Expand Down
24 changes: 13 additions & 11 deletions tonic-interop/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let test_cases = matches.test_case;

#[allow(unused_mut)]
let mut endpoint = Endpoint::from_static("http://localhost:10000")
let endpoint = Endpoint::from_static("http://localhost:10000")
.timeout(Duration::from_secs(5))
.concurrency_limit(30);

if matches.use_tls {
let channel = if matches.use_tls {
let pem = tokio::fs::read("tonic-interop/data/ca.pem").await?;
let ca = Certificate::from_pem(pem);
endpoint = endpoint.tls_config(
ClientTlsConfig::with_rustls()
.ca_certificate(ca)
.domain_name("foo.test.google.fr"),
);
}

let channel = endpoint.connect().await?;
endpoint
.tls_config(
ClientTlsConfig::with_rustls()
.ca_certificate(ca)
.domain_name("foo.test.google.fr"),
)
.connect()
.await?
} else {
endpoint.connect().await?
};

let mut client = client::TestClient::new(channel.clone());
let mut unimplemented_client = client::UnimplementedClient::new(channel);
Expand Down
1 change: 1 addition & 0 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ tokio-rustls = { version = "0.12", optional = true }
rustls-native-certs = { version = "0.1", optional = true }

[dev-dependencies]
hyper-unix-connector = "0.1.3"
tokio = { version = "0.2", features = ["rt-core", "macros"] }
static_assertions = "1.0"
rand = "0.6"
Expand Down
2 changes: 2 additions & 0 deletions tonic/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ pub use std::future::Future;
pub use std::pin::Pin;
pub use std::sync::Arc;
pub use std::task::{Context, Poll};
#[cfg(feature = "transport")]
pub use tower_make::MakeConnection;
pub use tower_service::Service;
pub type StdError = Box<dyn std::error::Error + Send + Sync + 'static>;
pub use crate::body::Body;
Expand Down
16 changes: 14 additions & 2 deletions tonic/src/transport/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,13 @@ impl Channel {
///
/// This creates a [`Channel`] that will load balance accross all the
/// provided endpoints.
pub fn balance_list(list: impl Iterator<Item = Endpoint>) -> Self {
pub fn balance_list<C>(list: impl Iterator<Item = Endpoint<C>>) -> Self
where
C: tower_make::MakeConnection<hyper::Uri> + Send + Clone + Unpin + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
let list = list.collect::<Vec<_>>();

let buffer_size = list
Expand All @@ -116,7 +122,13 @@ impl Channel {
Self::balance(discover, buffer_size, interceptor_headers)
}

pub(crate) async fn connect(endpoint: Endpoint) -> Result<Self, super::Error> {
pub(crate) async fn connect<C>(endpoint: Endpoint<C>) -> Result<Self, super::Error>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE);
let interceptor_headers = endpoint.interceptor_headers.clone();

Expand Down
145 changes: 106 additions & 39 deletions tonic/src/transport/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::{
};
use bytes::Bytes;
use http::uri::{InvalidUri, Uri};
use hyper::client::connect::HttpConnector;
use std::{
convert::{TryFrom, TryInto},
fmt,
Expand All @@ -17,13 +18,12 @@ use std::{
///
/// This struct is used to build and configure HTTP/2 channels.
#[derive(Clone)]
pub struct Endpoint {
pub struct Endpoint<C = HttpConnector> {
pub(super) uri: Uri,
pub(super) connector: C,
pub(super) timeout: Option<Duration>,
pub(super) concurrency_limit: Option<usize>,
pub(super) rate_limit: Option<(u64, Duration)>,
#[cfg(feature = "tls")]
pub(super) tls: Option<TlsConnector>,
pub(super) buffer_size: Option<usize>,
pub(super) interceptor_headers:
Option<Arc<dyn Fn(&mut http::HeaderMap) + Send + Sync + 'static>>,
Expand All @@ -33,7 +33,7 @@ pub struct Endpoint {
pub(super) tcp_nodelay: bool,
}

impl Endpoint {
impl<C> Endpoint<C> {
// FIXME: determine if we want to expose this or not. This is really
// just used in codegen for a shortcut.
#[doc(hidden)]
Expand All @@ -48,28 +48,6 @@ impl Endpoint {
Ok(me)
}

/// Convert an `Endpoint` from a static string.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_static("https://example.com");
/// ```
pub fn from_static(s: &'static str) -> Self {
let uri = Uri::from_static(s);
Self::from(uri)
}

/// Convert an `Endpoint` from shared bytes.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_shared("https://example.com".to_string());
/// ```
pub fn from_shared(s: impl Into<Bytes>) -> Result<Self, InvalidUri> {
let uri = Uri::from_maybe_shared(s.into())?;
Ok(Self::from(uri))
}

/// Apply a timeout to each request.
///
/// ```
Expand Down Expand Up @@ -164,14 +142,79 @@ impl Endpoint {
}

/// Configures TLS for the endpoint.
///
/// Shortcut for configuring a TLS connector and calling [`Endpoint::connector`].
#[cfg(feature = "tls")]
pub fn tls_config(self, tls_config: ClientTlsConfig) -> Self {
pub fn tls_config(
self,
tls_config: ClientTlsConfig,
) -> Endpoint<
impl tower_make::MakeConnection<
hyper::Uri,
Connection = impl Unpin + Send + 'static,
Future = impl Send + 'static,
Error = impl Into<Box<dyn std::error::Error + Send + Sync>> + Send,
> + Clone,
> {
let tls_connector = tls_config.tls_connector(self.uri.clone()).unwrap();
let connector = super::service::tls_connector(Some(tls_connector));
self.connector(connector)
}
}

impl<C> Endpoint<C> {
/// Use a custom connector for the underlying channel.
///
/// Calling [`Endpoint::connect`] requires the connector implement
/// the [`tower_make::MakeConnection`] requirement, which is an alias for `tower::Service<Uri, Response = AsyncRead +
/// Async Write>` - for example, a TCP stream for the default [`HttpConnector`].
///
/// # Example
/// ```rust
/// use hyper::client::connect::HttpConnector;
/// use tonic::transport::Endpoint;
///
/// // note: This connector is the same as the default provided in `connect()`.
/// let mut connector = HttpConnector::new();
/// connector.enforce_http(false);
/// connector.set_nodelay(true);
///
/// let endpoint = Endpoint::from_static("http://example.com");
/// endpoint.connector(connector).connect(); //.await
/// ```
///
/// # Example with non-default Connector
/// ```rust
/// // Use for unix-domain sockets
/// use hyper_unix_connector::UnixClient;
/// use tonic::transport::Endpoint;
///
/// let endpoint = Endpoint::from_static("http://example.com");
/// endpoint.connector(UnixClient).connect(); //.await
/// ```
pub fn connector<D>(self, connector: D) -> Endpoint<D> {
Endpoint {
tls: Some(tls_config.tls_connector(self.uri.clone()).unwrap()),
..self
uri: self.uri,
connector,
concurrency_limit: self.concurrency_limit,
rate_limit: self.rate_limit,
timeout: self.timeout,
buffer_size: self.buffer_size,
interceptor_headers: self.interceptor_headers,
init_stream_window_size: self.init_stream_window_size,
init_connection_window_size: self.init_connection_window_size,
}
}
}

impl<C> Endpoint<C>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + Clone + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
/// Create the channel.
/// Set the value of `TCP_NODELAY` option for accepted connections. Enabled by default.
pub fn tcp_nodelay(self, enabled: bool) -> Self {
Endpoint {
Expand All @@ -182,19 +225,43 @@ impl Endpoint {

/// Create a channel from this config.
pub async fn connect(&self) -> Result<Channel, super::Error> {
Channel::connect(self.clone()).await
let e: Self = self.clone();
Channel::connect(e).await
}
}

impl Endpoint<HttpConnector> {
/// Convert an `Endpoint` from a static string.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_static("https://example.com");
/// ```
pub fn from_static(s: &'static str) -> Self {
let uri = Uri::from_static(s);
Self::from(uri)
}

/// Convert an `Endpoint` from shared bytes.
///
/// ```
/// # use tonic::transport::Endpoint;
/// Endpoint::from_shared("https://example.com".to_string());
/// ```
pub fn from_shared(s: impl Into<Bytes>) -> Result<Self, InvalidUriBytes> {
let uri = Uri::from_shared(s.into())?;
Ok(Self::from(uri))
}
}

impl From<Uri> for Endpoint {
impl From<Uri> for Endpoint<HttpConnector> {
fn from(uri: Uri) -> Self {
Self {
uri,
connector: super::service::connector(),
concurrency_limit: None,
rate_limit: None,
timeout: None,
#[cfg(feature = "tls")]
tls: None,
buffer_size: None,
interceptor_headers: None,
init_stream_window_size: None,
Expand All @@ -205,23 +272,23 @@ impl From<Uri> for Endpoint {
}
}

impl TryFrom<Bytes> for Endpoint {
type Error = InvalidUri;
impl TryFrom<Bytes> for Endpoint<HttpConnector> {
type Error = InvalidUriBytes;

fn try_from(t: Bytes) -> Result<Self, Self::Error> {
Self::from_shared(t)
}
}

impl TryFrom<String> for Endpoint {
type Error = InvalidUri;
impl TryFrom<String> for Endpoint<HttpConnector> {
type Error = InvalidUriBytes;

fn try_from(t: String) -> Result<Self, Self::Error> {
Self::from_shared(t.into_bytes())
}
}

impl TryFrom<&'static str> for Endpoint {
impl TryFrom<&'static str> for Endpoint<HttpConnector> {
type Error = Never;

fn try_from(t: &'static str) -> Result<Self, Self::Error> {
Expand All @@ -240,7 +307,7 @@ impl std::fmt::Display for Never {

impl std::error::Error for Never {}

impl fmt::Debug for Endpoint {
impl<C> fmt::Debug for Endpoint<C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Endpoint").finish()
}
Expand Down
Loading