Skip to content

Commit

Permalink
handshake error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
conradludgate committed Jan 23, 2024
1 parent a687732 commit 977c54e
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 44 deletions.
10 changes: 10 additions & 0 deletions proxy/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use password_hack::PasswordHackPayload;

mod flow;
pub use flow::*;
use tokio::time::error::Elapsed;

use crate::{
console,
Expand Down Expand Up @@ -70,6 +71,9 @@ pub enum AuthErrorImpl {

#[error("Too many connections to this endpoint. Please try again later.")]
TooManyConnections,

#[error("Authentication timed out")]
UserTimeout(Elapsed),
}

#[derive(Debug, Error)]
Expand All @@ -96,6 +100,10 @@ impl AuthError {
pub fn is_auth_failed(&self) -> bool {
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
}

pub fn user_timeout(elapsed: Elapsed) -> Self {
AuthErrorImpl::UserTimeout(elapsed).into()
}
}

impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
Expand All @@ -119,6 +127,7 @@ impl UserFacingError for AuthError {
Io(_) => "Internal error".to_string(),
IpAddressNotAllowed => self.to_string(),
TooManyConnections => self.to_string(),
UserTimeout(_) => self.to_string(),
}
}
}
Expand All @@ -138,6 +147,7 @@ impl ReportableError for AuthError {
Io(_) => crate::error::ErrorKind::Disconnect,
IpAddressNotAllowed => crate::error::ErrorKind::User,
TooManyConnections => crate::error::ErrorKind::RateLimit,
UserTimeout(_) => crate::error::ErrorKind::User,
}
}
}
4 changes: 2 additions & 2 deletions proxy/src/auth/backend/classic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ pub(super) async fn authenticate(
}
)
.await
.map_err(|error| {
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::io::Error::new(auth::io::ErrorKind::TimedOut, error)
auth::AuthError::user_timeout(e)
})??;

let client_key = match auth_outcome {
Expand Down
11 changes: 1 addition & 10 deletions proxy/src/auth/backend/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ use tracing::{info, info_span};

#[derive(Debug, Error)]
pub enum LinkAuthError {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),

#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),

Expand All @@ -30,18 +26,13 @@ pub enum LinkAuthError {

impl UserFacingError for LinkAuthError {
fn to_string_client(&self) -> String {
use LinkAuthError::*;
match self {
AuthFailed(_) => self.to_string(),
_ => "Internal error".to_string(),
}
"Internal error".to_string()
}
}

impl ReportableError for LinkAuthError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
LinkAuthError::AuthFailed(_) => crate::error::ErrorKind::User,
LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
LinkAuthError::Io(_) => crate::error::ErrorKind::Disconnect,
Expand Down
15 changes: 9 additions & 6 deletions proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ use crate::{
context::RequestMonitoring,
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
protocol2::WithClientIp,
proxy::{handshake::handshake, passthrough::proxy_pass},
proxy::{
handshake::{handshake, HandshakeData},
passthrough::proxy_pass,
},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
};
Expand All @@ -34,7 +37,6 @@ use tracing::{error, info, info_span, Instrument};
use self::connect_compute::{connect_to_compute, TcpMechanism};

const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";

pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
Expand Down Expand Up @@ -189,10 +191,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let tls = config.tls_config.as_ref();

let pause = ctx.latency_timer.pause();
let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
let (mut stream, params) = match handshake(stream, mode.handshake_tls(tls)).await? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return cancel_map.cancel_session(cancel_key_data).await
}
};
drop(pause);

Expand Down
74 changes: 58 additions & 16 deletions proxy/src/proxy/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,57 @@
use anyhow::{bail, Context};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;

use crate::{
cancellation::CancelMap,
config::TlsConfig,
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
stream::{PqStream, Stream},
error::ReportableError,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
};

#[derive(Error, Debug)]
pub enum HandshakeError {
#[error("data is sent before server replied with EncryptionResponse")]
EarlyData,

#[error("protocol violation")]
ProtocolViolation,

#[error("connection is insecure (try using `sslmode=require`)")]
InsecureConnection,

#[error("missing certificate")]
MissingCertificate,

#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),

#[error("{0}")]
Io(#[from] std::io::Error),
}

impl ReportableError for HandshakeError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
HandshakeError::InsecureConnection => crate::error::ErrorKind::User,
HandshakeError::MissingCertificate => todo!(),
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::User,
StreamUpgradeError::Io(_) => crate::error::ErrorKind::Disconnect,
},
HandshakeError::Io(_) => crate::error::ErrorKind::Disconnect,
}
}
}

pub enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData),
}

/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
Expand All @@ -18,8 +60,7 @@ use crate::{
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
) -> Result<HandshakeData<S>, HandshakeError> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);

Expand Down Expand Up @@ -49,22 +90,22 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
return Err(HandshakeError::EarlyData);
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;

let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
.ok_or(HandshakeError::MissingCertificate)?;

stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
_ => return Err(HandshakeError::ProtocolViolation),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
Expand All @@ -73,23 +114,24 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!(ERR_PROTO_VIOLATION),
_ => return Err(HandshakeError::ProtocolViolation),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
stream
.write_message(&Be::ErrorResponse(ERR_INSECURE_CONNECTION, None))
.await?;
return Err(HandshakeError::InsecureConnection);
}

info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
break Ok(HandshakeData::Startup(stream, params));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;

info!(session_type = "cancellation", "successful handshake");
break Ok(None);
break Ok(HandshakeData::Cancel(cancel_key_data));
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions proxy/src/proxy/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@ async fn dummy_proxy(
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let client = WithClientIp::new(client);
let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
.await?
.context("handshake failed")?;
let mut stream = match handshake(client, tls.as_ref()).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};

auth.authenticate(&mut stream).await?;

Expand Down
10 changes: 4 additions & 6 deletions proxy/src/proxy/tests/mitm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ async fn proxy_mitm(
tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
// process handshake with end_client
let (end_client, startup) =
handshake(client1, Some(&server_config1), &CancelMap::default())
.await
.unwrap()
.unwrap();
let (end_client, startup) = match handshake(client1, Some(&server_config1)).await.unwrap() {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};

let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();
Expand Down

0 comments on commit 977c54e

Please sign in to comment.