Skip to content

Commit

Permalink
fix(dgw): use all resolved addresses when connecting (#601)
Browse files Browse the repository at this point in the history
This patch ensures Devolutions Gateway does not immediately discard
resolved addresses which are not emitted first by Tokio’s `lookup_host`.

Typically, the first address is enough and there is no need to try
subsequent ones. Therefore, it is not expected for this change to
cause any additional latence in the the vast majority of the cases.
However, just to be on the safe side and enable easier troubleshooting,
a WARN-level log is emitted when failing at connecting to a resolved
address. If latence were to be introduced by this patch, we can
easily be made aware of the problem and investigate further (network
configuration, etc).

If this proves to be a problem in the future, we can add filtering
options. For instance, on a network where IPv4 is not supported or
disabled, we may want to filter out all the IPv4 addresses which may
be resolved by the Devolutions Gateway.

Issue: DGW-125
  • Loading branch information
CBenoit authored Nov 30, 2023
1 parent dc03e73 commit fe4dc63
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 69 deletions.
20 changes: 2 additions & 18 deletions crates/jmux-proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -846,24 +846,8 @@ impl StreamResolverTask {
let host = destination_url.host();
let port = destination_url.port();

let mut addrs = match tokio::net::lookup_host((host, port)).await {
Ok(addrs) => addrs,
Err(error) => {
debug!(?error, "tokio::net::lookup_host failed");
msg_to_send_tx
.send(Message::open_failure(
channel.distant_id,
ReasonCode::from(error.kind()),
error.to_string(),
))
.context("couldn’t send OPEN FAILURE message through mpsc channel")?;
anyhow::bail!("couldn't resolve {}:{}: {}", host, port, error);
}
};
let socket_addr = addrs.next().expect("at least one resolved address should be present");

match scheme {
"tcp" => match TcpStream::connect(socket_addr).await {
"tcp" => match TcpStream::connect((host, port)).await {
Ok(stream) => {
internal_msg_tx
.send(InternalMessage::StreamResolved { channel, stream })
Expand All @@ -878,7 +862,7 @@ impl StreamResolverTask {
error.to_string(),
))
.context("couldn’t send OPEN FAILURE message through mpsc channel")?;
anyhow::bail!("Couldn’t connect TCP socket to {}:{}: {}", host, port, error);
anyhow::bail!("couldn’t open TCP stream to {}:{}: {}", host, port, error);
}
},
_ => anyhow::bail!("unsupported scheme: {}", scheme),
Expand Down
21 changes: 5 additions & 16 deletions crates/proxy-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,26 +193,15 @@ async fn process_https(incoming: TcpStream) -> io::Result<()> {

println!("Requested proxying to {dest_addr:?}");

let socket_addr = {
match dest_addr.clone() {
proxy_types::DestAddr::Ip(addr) => addr,
proxy_types::DestAddr::Domain(domain, port) => {
let mut addrs = match tokio::net::lookup_host((domain, port)).await {
Ok(addrs) => addrs,
Err(e) => {
acceptor.failure(proxy_http::ErrorCode::BadGateway).await?;
return Err(e);
}
};
addrs.next().expect("at least one resolved address should be present")
}
}
let connect_result = match dest_addr {
proxy_types::DestAddr::Ip(addr) => TcpStream::connect(addr).await,
proxy_types::DestAddr::Domain(domain, port) => TcpStream::connect((domain.as_str(), *port)).await,
};

let target_stream = match TcpStream::connect(socket_addr).await {
let target_stream = match connect_result {
Ok(stream) => stream,
Err(e) => {
acceptor.failure(proxy_http::ErrorCode::InternalServerError).await?;
acceptor.failure(proxy_http::ErrorCode::BadGateway).await?;
return Err(e);
}
};
Expand Down
9 changes: 2 additions & 7 deletions devolutions-gateway/src/api/kdc_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use tokio::net::{TcpStream, UdpSocket};

use crate::http::HttpError;
use crate::token::AccessTokenClaims;
use crate::utils::resolve_target_addr;
use crate::DgwState;

pub fn make_router<S>(state: DgwState) -> Router<S> {
Expand Down Expand Up @@ -79,14 +78,10 @@ async fn kdc_proxy(

let protocol = kdc_addr.scheme();

let kdc_addr = resolve_target_addr(kdc_addr)
.await
.map_err(HttpError::internal().with_msg("unable to locate KDC server").err())?;

trace!("Connecting to KDC server located at {kdc_addr} using protocol {protocol}...");

let kdc_reply_message = if protocol == "tcp" {
let mut connection = TcpStream::connect(kdc_addr)
let mut connection = TcpStream::connect(kdc_addr.as_addr())
.await
.map_err(HttpError::internal().with_msg("unable to connect to KDC server").err())?;

Expand Down Expand Up @@ -124,7 +119,7 @@ async fn kdc_proxy(

// first 4 bytes contains message length. we don't need it for UDP
udp_socket
.send_to(&kdc_proxy_message.kerb_message.0 .0[4..], kdc_addr)
.send_to(&kdc_proxy_message.kerb_message.0 .0[4..], kdc_addr.as_addr())
.await
.map_err(
HttpError::internal()
Expand Down
2 changes: 1 addition & 1 deletion devolutions-gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ fn read_pfx_file(

let crypto_context = password
.map(|pwd| Pkcs12CryptoContext::new_with_password(pwd.get()))
.unwrap_or_else(|| Pkcs12CryptoContext::new_without_password());
.unwrap_or_else(Pkcs12CryptoContext::new_without_password);
let parsing_params = Pkcs12ParsingParams::default();

let pfx_contents = normalize_data_path(path, &get_data_dir())
Expand Down
54 changes: 44 additions & 10 deletions devolutions-gateway/src/target_addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use core::fmt;
use serde::{de, ser};
use smol_str::SmolStr;
use std::net::IpAddr;
use std::ops::RangeBounds;
use std::str::FromStr;
use tap::Pipe as _;

Expand Down Expand Up @@ -96,11 +97,11 @@ impl TargetAddr {
}

pub fn scheme(&self) -> &str {
self.h_slice(0, self.scheme_end)
self.h_slice_repr(..self.scheme_end)
}

pub fn host(&self) -> &str {
self.h_slice(self.host_start, self.host_end)
self.h_slice_repr(self.host_start..self.host_end)
}

pub fn host_ip(&self) -> Option<IpAddr> {
Expand All @@ -114,9 +115,28 @@ impl TargetAddr {
self.port
}

#[inline]
fn h_slice(&self, start: u16, end: u16) -> &str {
&self.serialization[usize::from(start)..usize::from(end)]
pub fn as_addr(&self) -> &str {
self.h_slice_repr((self.scheme_end + 3)..)
}

// Slice the internal representation using a [`Range<u16>`]
fn h_slice_repr(&self, range: impl RangeBounds<u16>) -> &str {
use std::ops::Bound;

// TODO(@CBenoit): use Bound::map when stabilized (bound_map feature)
// https://github.com/rust-lang/rust/issues/86026
let lo = match range.start_bound() {
Bound::Included(idx) => Bound::Included(usize::from(*idx)),
Bound::Excluded(idx) => Bound::Excluded(usize::from(*idx)),
Bound::Unbounded => Bound::Unbounded,
};
let hi = match range.end_bound() {
Bound::Included(idx) => Bound::Included(usize::from(*idx)),
Bound::Excluded(idx) => Bound::Excluded(usize::from(*idx)),
Bound::Unbounded => Bound::Unbounded,
};

&self.serialization.as_str()[(lo, hi)]
}
}

Expand Down Expand Up @@ -237,6 +257,14 @@ impl<'de> de::Deserialize<'de> for TargetAddr {
}
}

impl std::net::ToSocketAddrs for TargetAddr {
type Iter = std::vec::IntoIter<std::net::SocketAddr>;

fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
self.as_addr().to_socket_addrs()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -245,47 +273,53 @@ mod tests {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

#[rstest]
#[case("localhost:80", "tcp", "localhost", None, 80)]
#[case("localhost:80", "tcp", "localhost", None, 80, "localhost:80")]
#[case(
"udp://127.0.0.1:8080",
"udp",
"127.0.0.1",
Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
8080
8080,
"127.0.0.1:8080"
)]
#[case(
"tcp://[2001:db8::8a2e:370:7334]:7171",
"tcp",
"2001:db8::8a2e:370:7334",
Some(IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0x8a2e, 0x0370, 0x7334))),
7171
7171,
"[2001:db8::8a2e:370:7334]:7171"
)]
#[case(
"https://[2001:0db8:0000:0000:0000:8a2e:0370:7334]:433",
"https",
"2001:0db8:0000:0000:0000:8a2e:0370:7334",
Some(IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0x8a2e, 0x0370, 0x7334))),
433
433,
"[2001:0db8:0000:0000:0000:8a2e:0370:7334]:433"
)]
#[case(
"ws://[::1]:2222",
"ws",
"::1",
Some(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
2222
2222,
"[::1]:2222"
)]
fn target_addr_parsing(
#[case] repr: &str,
#[case] scheme: &str,
#[case] host: &str,
#[case] ip: Option<IpAddr>,
#[case] port: u16,
#[case] as_addr: &str,
) {
let addr = TargetAddr::parse(repr, None).unwrap();
assert_eq!(addr.scheme(), scheme);
assert_eq!(addr.host(), host);
assert_eq!(addr.host_ip(), ip);
assert_eq!(addr.port(), port);
assert_eq!(addr.as_addr(), as_addr);
}

#[rstest]
Expand Down
35 changes: 18 additions & 17 deletions devolutions-gateway/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,31 @@ use url::Url;

use crate::target_addr::TargetAddr;

pub async fn resolve_target_addr(dest: &TargetAddr) -> anyhow::Result<SocketAddr> {
let port = dest.port();

if let Some(ip) = dest.host_ip() {
Ok(SocketAddr::new(ip, port))
} else {
lookup_host((dest.host(), port))
.await?
.next()
.context("host lookup yielded no result")
}
}

pub async fn tcp_connect(dest: &TargetAddr) -> anyhow::Result<(TcpStream, SocketAddr)> {
const CONNECTION_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(10);

let fut = async move {
let socket_addr = resolve_target_addr(dest).await?;
let stream = TcpStream::connect(socket_addr)
let addrs = lookup_host(dest.as_addr())
.await
.context("couldn't connect stream")?;
Ok::<_, anyhow::Error>((stream, socket_addr))
.context("failed to lookup destination address")?;

let mut last_err = None;

for addr in addrs {
match TcpStream::connect(addr).await {
Ok(stream) => return Ok((stream, addr)),
Err(error) => {
warn!(%error, resolved = %addr, destination = %dest, "Failed to connect to a resolved address");
last_err = Some(anyhow::Error::new(error).context("TcpStream::connect"))
}
}
}

Err::<_, anyhow::Error>(last_err.unwrap_or_else(|| anyhow::format_err!("could not resolve to any address")))
};

let result = tokio::time::timeout(CONNECTION_TIMEOUT, fut).await??;

Ok(result)
}

Expand Down

0 comments on commit fe4dc63

Please sign in to comment.