diff --git a/crates/jmux-proxy/src/lib.rs b/crates/jmux-proxy/src/lib.rs index 53a22acdf..bba8cc3e0 100644 --- a/crates/jmux-proxy/src/lib.rs +++ b/crates/jmux-proxy/src/lib.rs @@ -846,24 +846,8 @@ impl StreamResolverTask { let host = destination_url.host(); let port = destination_url.port(); - let mut addrs = match tokio::net::lookup_host((host, port)).await { - Ok(addrs) => addrs, - Err(error) => { - debug!(?error, "tokio::net::lookup_host failed"); - msg_to_send_tx - .send(Message::open_failure( - channel.distant_id, - ReasonCode::from(error.kind()), - error.to_string(), - )) - .context("couldn’t send OPEN FAILURE message through mpsc channel")?; - anyhow::bail!("couldn't resolve {}:{}: {}", host, port, error); - } - }; - let socket_addr = addrs.next().expect("at least one resolved address should be present"); - match scheme { - "tcp" => match TcpStream::connect(socket_addr).await { + "tcp" => match TcpStream::connect((host, port)).await { Ok(stream) => { internal_msg_tx .send(InternalMessage::StreamResolved { channel, stream }) @@ -878,7 +862,7 @@ impl StreamResolverTask { error.to_string(), )) .context("couldn’t send OPEN FAILURE message through mpsc channel")?; - anyhow::bail!("Couldn’t connect TCP socket to {}:{}: {}", host, port, error); + anyhow::bail!("couldn’t open TCP stream to {}:{}: {}", host, port, error); } }, _ => anyhow::bail!("unsupported scheme: {}", scheme), diff --git a/crates/proxy-server/src/main.rs b/crates/proxy-server/src/main.rs index 734759d16..28544d49b 100644 --- a/crates/proxy-server/src/main.rs +++ b/crates/proxy-server/src/main.rs @@ -193,26 +193,15 @@ async fn process_https(incoming: TcpStream) -> io::Result<()> { println!("Requested proxying to {dest_addr:?}"); - let socket_addr = { - match dest_addr.clone() { - proxy_types::DestAddr::Ip(addr) => addr, - proxy_types::DestAddr::Domain(domain, port) => { - let mut addrs = match tokio::net::lookup_host((domain, port)).await { - Ok(addrs) => addrs, - Err(e) => { - acceptor.failure(proxy_http::ErrorCode::BadGateway).await?; - return Err(e); - } - }; - addrs.next().expect("at least one resolved address should be present") - } - } + let connect_result = match dest_addr { + proxy_types::DestAddr::Ip(addr) => TcpStream::connect(addr).await, + proxy_types::DestAddr::Domain(domain, port) => TcpStream::connect((domain.as_str(), *port)).await, }; - let target_stream = match TcpStream::connect(socket_addr).await { + let target_stream = match connect_result { Ok(stream) => stream, Err(e) => { - acceptor.failure(proxy_http::ErrorCode::InternalServerError).await?; + acceptor.failure(proxy_http::ErrorCode::BadGateway).await?; return Err(e); } }; diff --git a/devolutions-gateway/src/api/kdc_proxy.rs b/devolutions-gateway/src/api/kdc_proxy.rs index db06046a4..6f5403080 100644 --- a/devolutions-gateway/src/api/kdc_proxy.rs +++ b/devolutions-gateway/src/api/kdc_proxy.rs @@ -9,7 +9,6 @@ use tokio::net::{TcpStream, UdpSocket}; use crate::http::HttpError; use crate::token::AccessTokenClaims; -use crate::utils::resolve_target_addr; use crate::DgwState; pub fn make_router(state: DgwState) -> Router { @@ -79,14 +78,10 @@ async fn kdc_proxy( let protocol = kdc_addr.scheme(); - let kdc_addr = resolve_target_addr(kdc_addr) - .await - .map_err(HttpError::internal().with_msg("unable to locate KDC server").err())?; - trace!("Connecting to KDC server located at {kdc_addr} using protocol {protocol}..."); let kdc_reply_message = if protocol == "tcp" { - let mut connection = TcpStream::connect(kdc_addr) + let mut connection = TcpStream::connect(kdc_addr.as_addr()) .await .map_err(HttpError::internal().with_msg("unable to connect to KDC server").err())?; @@ -124,7 +119,7 @@ async fn kdc_proxy( // first 4 bytes contains message length. we don't need it for UDP udp_socket - .send_to(&kdc_proxy_message.kerb_message.0 .0[4..], kdc_addr) + .send_to(&kdc_proxy_message.kerb_message.0 .0[4..], kdc_addr.as_addr()) .await .map_err( HttpError::internal() diff --git a/devolutions-gateway/src/config.rs b/devolutions-gateway/src/config.rs index 97315f251..272b31f6d 100644 --- a/devolutions-gateway/src/config.rs +++ b/devolutions-gateway/src/config.rs @@ -394,7 +394,7 @@ fn read_pfx_file( let crypto_context = password .map(|pwd| Pkcs12CryptoContext::new_with_password(pwd.get())) - .unwrap_or_else(|| Pkcs12CryptoContext::new_without_password()); + .unwrap_or_else(Pkcs12CryptoContext::new_without_password); let parsing_params = Pkcs12ParsingParams::default(); let pfx_contents = normalize_data_path(path, &get_data_dir()) diff --git a/devolutions-gateway/src/target_addr.rs b/devolutions-gateway/src/target_addr.rs index b50f415de..0d430612c 100644 --- a/devolutions-gateway/src/target_addr.rs +++ b/devolutions-gateway/src/target_addr.rs @@ -2,6 +2,7 @@ use core::fmt; use serde::{de, ser}; use smol_str::SmolStr; use std::net::IpAddr; +use std::ops::RangeBounds; use std::str::FromStr; use tap::Pipe as _; @@ -96,11 +97,11 @@ impl TargetAddr { } pub fn scheme(&self) -> &str { - self.h_slice(0, self.scheme_end) + self.h_slice_repr(..self.scheme_end) } pub fn host(&self) -> &str { - self.h_slice(self.host_start, self.host_end) + self.h_slice_repr(self.host_start..self.host_end) } pub fn host_ip(&self) -> Option { @@ -114,9 +115,28 @@ impl TargetAddr { self.port } - #[inline] - fn h_slice(&self, start: u16, end: u16) -> &str { - &self.serialization[usize::from(start)..usize::from(end)] + pub fn as_addr(&self) -> &str { + self.h_slice_repr((self.scheme_end + 3)..) + } + + // Slice the internal representation using a [`Range`] + fn h_slice_repr(&self, range: impl RangeBounds) -> &str { + use std::ops::Bound; + + // TODO(@CBenoit): use Bound::map when stabilized (bound_map feature) + // https://github.com/rust-lang/rust/issues/86026 + let lo = match range.start_bound() { + Bound::Included(idx) => Bound::Included(usize::from(*idx)), + Bound::Excluded(idx) => Bound::Excluded(usize::from(*idx)), + Bound::Unbounded => Bound::Unbounded, + }; + let hi = match range.end_bound() { + Bound::Included(idx) => Bound::Included(usize::from(*idx)), + Bound::Excluded(idx) => Bound::Excluded(usize::from(*idx)), + Bound::Unbounded => Bound::Unbounded, + }; + + &self.serialization.as_str()[(lo, hi)] } } @@ -237,6 +257,14 @@ impl<'de> de::Deserialize<'de> for TargetAddr { } } +impl std::net::ToSocketAddrs for TargetAddr { + type Iter = std::vec::IntoIter; + + fn to_socket_addrs(&self) -> std::io::Result { + self.as_addr().to_socket_addrs() + } +} + #[cfg(test)] mod tests { use super::*; @@ -245,34 +273,38 @@ mod tests { use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; #[rstest] - #[case("localhost:80", "tcp", "localhost", None, 80)] + #[case("localhost:80", "tcp", "localhost", None, 80, "localhost:80")] #[case( "udp://127.0.0.1:8080", "udp", "127.0.0.1", Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))), - 8080 + 8080, + "127.0.0.1:8080" )] #[case( "tcp://[2001:db8::8a2e:370:7334]:7171", "tcp", "2001:db8::8a2e:370:7334", Some(IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0x8a2e, 0x0370, 0x7334))), - 7171 + 7171, + "[2001:db8::8a2e:370:7334]:7171" )] #[case( "https://[2001:0db8:0000:0000:0000:8a2e:0370:7334]:433", "https", "2001:0db8:0000:0000:0000:8a2e:0370:7334", Some(IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0x8a2e, 0x0370, 0x7334))), - 433 + 433, + "[2001:0db8:0000:0000:0000:8a2e:0370:7334]:433" )] #[case( "ws://[::1]:2222", "ws", "::1", Some(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))), - 2222 + 2222, + "[::1]:2222" )] fn target_addr_parsing( #[case] repr: &str, @@ -280,12 +312,14 @@ mod tests { #[case] host: &str, #[case] ip: Option, #[case] port: u16, + #[case] as_addr: &str, ) { let addr = TargetAddr::parse(repr, None).unwrap(); assert_eq!(addr.scheme(), scheme); assert_eq!(addr.host(), host); assert_eq!(addr.host_ip(), ip); assert_eq!(addr.port(), port); + assert_eq!(addr.as_addr(), as_addr); } #[rstest] diff --git a/devolutions-gateway/src/utils.rs b/devolutions-gateway/src/utils.rs index c75833816..6d76cb1ad 100644 --- a/devolutions-gateway/src/utils.rs +++ b/devolutions-gateway/src/utils.rs @@ -6,30 +6,31 @@ use url::Url; use crate::target_addr::TargetAddr; -pub async fn resolve_target_addr(dest: &TargetAddr) -> anyhow::Result { - let port = dest.port(); - - if let Some(ip) = dest.host_ip() { - Ok(SocketAddr::new(ip, port)) - } else { - lookup_host((dest.host(), port)) - .await? - .next() - .context("host lookup yielded no result") - } -} - pub async fn tcp_connect(dest: &TargetAddr) -> anyhow::Result<(TcpStream, SocketAddr)> { const CONNECTION_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(10); let fut = async move { - let socket_addr = resolve_target_addr(dest).await?; - let stream = TcpStream::connect(socket_addr) + let addrs = lookup_host(dest.as_addr()) .await - .context("couldn't connect stream")?; - Ok::<_, anyhow::Error>((stream, socket_addr)) + .context("failed to lookup destination address")?; + + let mut last_err = None; + + for addr in addrs { + match TcpStream::connect(addr).await { + Ok(stream) => return Ok((stream, addr)), + Err(error) => { + warn!(%error, resolved = %addr, destination = %dest, "Failed to connect to a resolved address"); + last_err = Some(anyhow::Error::new(error).context("TcpStream::connect")) + } + } + } + + Err::<_, anyhow::Error>(last_err.unwrap_or_else(|| anyhow::format_err!("could not resolve to any address"))) }; + let result = tokio::time::timeout(CONNECTION_TIMEOUT, fut).await??; + Ok(result) }