Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proxy error reworking #6453

Merged
merged 9 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions proxy/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ pub use backend::BackendType;

mod credentials;
pub use credentials::{
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern,
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
ComputeUserInfoParseError, IpPattern,
};

mod password_hack;
Expand All @@ -14,8 +15,12 @@ use password_hack::PasswordHackPayload;

mod flow;
pub use flow::*;
use tokio::time::error::Elapsed;

use crate::{console, error::UserFacingError};
use crate::{
console,
error::{ReportableError, UserFacingError},
};
use std::io;
use thiserror::Error;

Expand Down Expand Up @@ -67,6 +72,9 @@ pub enum AuthErrorImpl {

#[error("Too many connections to this endpoint. Please try again later.")]
TooManyConnections,

#[error("Authentication timed out")]
UserTimeout(Elapsed),
}

#[derive(Debug, Error)]
Expand All @@ -93,6 +101,10 @@ impl AuthError {
/// Reports whether this error wraps the `AuthErrorImpl::AuthFailed` variant.
pub fn is_auth_failed(&self) -> bool {
// Pattern-only check: the variant's payload is ignored.
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
}

pub fn user_timeout(elapsed: Elapsed) -> Self {
AuthErrorImpl::UserTimeout(elapsed).into()
}
}

impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
Expand All @@ -116,6 +128,27 @@ impl UserFacingError for AuthError {
Io(_) => "Internal error".to_string(),
IpAddressNotAllowed => self.to_string(),
TooManyConnections => self.to_string(),
UserTimeout(_) => self.to_string(),
}
}
}

impl ReportableError for AuthError {
fn get_error_type(&self) -> crate::error::ErrorKind {
use AuthErrorImpl::*;
match self.0.as_ref() {
Link(e) => e.get_error_type(),
GetAuthInfo(e) => e.get_error_type(),
WakeCompute(e) => e.get_error_type(),
Sasl(e) => e.get_error_type(),
AuthFailed(_) => crate::error::ErrorKind::User,
BadAuthMethod(_) => crate::error::ErrorKind::User,
MalformedPassword(_) => crate::error::ErrorKind::User,
MissingEndpointName => crate::error::ErrorKind::User,
Io(_) => crate::error::ErrorKind::Disconnect,
IpAddressNotAllowed => crate::error::ErrorKind::User,
TooManyConnections => crate::error::ErrorKind::RateLimit,
UserTimeout(_) => crate::error::ErrorKind::User,
}
}
}
4 changes: 2 additions & 2 deletions proxy/src/auth/backend/classic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ pub(super) async fn authenticate(
}
)
.await
.map_err(|error| {
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::io::Error::new(auth::io::ErrorKind::TimedOut, error)
auth::AuthError::user_timeout(e)
})??;

let client_key = match auth_outcome {
Expand Down
18 changes: 10 additions & 8 deletions proxy/src/auth/backend/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
auth, compute,
console::{self, provider::NodeInfo},
context::RequestMonitoring,
error::UserFacingError,
error::{ReportableError, UserFacingError},
stream::PqStream,
waiters,
};
Expand All @@ -14,10 +14,6 @@ use tracing::{info, info_span};

#[derive(Debug, Error)]
pub enum LinkAuthError {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),

#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),

Expand All @@ -30,10 +26,16 @@ pub enum LinkAuthError {

impl UserFacingError for LinkAuthError {
fn to_string_client(&self) -> String {
use LinkAuthError::*;
"Internal error".to_string()
}
}

impl ReportableError for LinkAuthError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
AuthFailed(_) => self.to_string(),
_ => "Internal error".to_string(),
LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
LinkAuthError::Io(_) => crate::error::ErrorKind::Disconnect,
}
}
}
Expand Down
14 changes: 12 additions & 2 deletions proxy/src/auth/credentials.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
//! User credentials used in authentication.

use crate::{
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
proxy::NeonOptions,
serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
};
use itertools::Itertools;
Expand Down Expand Up @@ -39,6 +43,12 @@ pub enum ComputeUserInfoParseError {

// Empty impl: no `UserFacingError` methods are overridden for this type.
impl UserFacingError for ComputeUserInfoParseError {}

impl ReportableError for ComputeUserInfoParseError {
fn get_error_type(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}

/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down
11 changes: 9 additions & 2 deletions proxy/src/bin/pg_sni_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,9 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
?unexpected,
"unexpected startup packet, rejecting connection"
);
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
stream
.throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User)
.await?
}
}
}
Expand Down Expand Up @@ -272,5 +274,10 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;

let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await

// doesn't yet matter as pg-sni-router doesn't report analytics logs
ctx.set_success();
ctx.log();

proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await
}
34 changes: 26 additions & 8 deletions proxy/src/cancellation.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,42 @@
use anyhow::Context;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;

use crate::error::ReportableError;

/// Enables serving `CancelRequest`s.
///
/// Wraps a `DashMap` from `CancelKeyData` to an optional `CancelClosure`.
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);

/// Failure modes of a query cancellation attempt.
#[derive(Debug, Error)]
pub enum CancelError {
/// An I/O error; displayed using the inner error's own message.
#[error("{0}")]
IO(#[from] std::io::Error),
/// An error from `tokio_postgres`; displayed using the inner error's own message.
#[error("{0}")]
Postgres(#[from] tokio_postgres::Error),
}

impl ReportableError for CancelError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
CancelError::IO(_) => crate::error::ErrorKind::Compute,
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
}
}
}

impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
// NB: we should immediately release the lock after cloning the token.
let cancel_closure = self
.0
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("query cancellation key not found: {key}"))?;
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
return Ok(());
};

info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
Expand Down Expand Up @@ -81,7 +99,7 @@ impl CancelClosure {
}

/// Cancels the query running on user's compute node.
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
async fn try_cancel_query(self) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;

Expand Down
19 changes: 17 additions & 2 deletions proxy/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use crate::{
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::errors::WakeComputeError,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
Expand Down Expand Up @@ -58,6 +62,17 @@ impl UserFacingError for ConnectionError {
}
}

impl ReportableError for ConnectionError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_type(),
}
}
}

/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;

Expand Down
31 changes: 30 additions & 1 deletion proxy/src/console/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tracing::info;

pub mod errors {
use crate::{
error::{io_error, UserFacingError},
error::{io_error, ReportableError, UserFacingError},
http,
proxy::retry::ShouldRetry,
};
Expand Down Expand Up @@ -81,6 +81,15 @@ pub mod errors {
}
}

impl ReportableError for ApiError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
}
}
}

impl ShouldRetry for ApiError {
fn could_retry(&self) -> bool {
match self {
Expand Down Expand Up @@ -150,6 +159,16 @@ pub mod errors {
}
}
}

impl ReportableError for GetAuthInfoError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}

#[derive(Debug, Error)]
pub enum WakeComputeError {
#[error("Console responded with a malformed compute address: {0}")]
Expand Down Expand Up @@ -194,6 +213,16 @@ pub mod errors {
}
}
}

impl ReportableError for WakeComputeError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
WakeComputeError::ApiError(e) => e.get_error_type(),
WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
}
}
}
}

/// Auth secret which is managed by the cloud.
Expand Down
18 changes: 16 additions & 2 deletions proxy/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use tokio::sync::mpsc;
use uuid::Uuid;

use crate::{
console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId,
EndpointId, ProjectId, RoleName,
console::messages::MetricsAuxInfo,
error::ErrorKind,
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
BranchId, EndpointId, ProjectId, RoleName,
};

pub mod parquet;
Expand Down Expand Up @@ -108,6 +110,18 @@ impl RequestMonitoring {
self.user = Some(user);
}

/// Records `error` as this request's error kind and updates the per-kind
/// error metrics.
pub fn set_error_kind(&mut self, error: ErrorKind) {
    // Metric labelled by error kind only.
    ERROR_BY_KIND
        .with_label_values(&[error.to_metric_label()])
        .inc();

    // The per-endpoint metric is touched only when an endpoint id is already set.
    if let Some(endpoint) = self.endpoint_id.as_ref() {
        ENDPOINT_ERRORS_BY_KIND
            .with_label_values(&[error.to_metric_label()])
            .measure(endpoint);
    }

    self.error_kind = Some(error);
}

/// Sets the `success` flag on this context to `true`.
pub fn set_success(&mut self) {
self.success = true;
}
Expand Down
Loading
Loading