Skip to content

Commit

Permalink
[proxy] Pass extra parameters to the console
Browse files Browse the repository at this point in the history
With this change we now pass additional params
to the console's auth methods.
  • Loading branch information
funbringer committed Sep 21, 2022
1 parent 71c92e0 commit f8422d2
Show file tree
Hide file tree
Showing 13 changed files with 257 additions and 164 deletions.
6 changes: 6 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ bstr = "0.2.17"
bytes = { version = "1.0.1", features = ['serde'] }
clap = "3.0"
futures = "0.3.13"
git-version = "0.3.5"
hashbrown = "0.12"
hex = "0.4.3"
hmac = "0.12.1"
hyper = "0.14"
itertools = "0.10.3"
once_cell = "1.13.0"
md5 = "0.7.0"
once_cell = "1.13.0"
parking_lot = "0.12"
pin-project-lite = "0.2.7"
rand = "0.8.3"
Expand All @@ -35,14 +36,13 @@ tokio = { version = "1.17", features = ["macros"] }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
tokio-rustls = "0.23.0"
url = "2.2.2"
git-version = "0.3.5"
uuid = { version = "0.8.2", features = ["v4", "serde"]}
x509-parser = "0.13.2"

utils = { path = "../libs/utils" }
metrics = { path = "../libs/metrics" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

x509-parser = "0.13.2"

[dev-dependencies]
rcgen = "0.8.14"
rstest = "0.12"
Expand Down
2 changes: 1 addition & 1 deletion proxy/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::{BackendType, DatabaseInfo};
pub use backend::{BackendType, ConsoleReqExtra, DatabaseInfo};

mod credentials;
pub use credentials::ClientCredentials;
Expand Down
132 changes: 68 additions & 64 deletions proxy/src/auth/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ pub use console::{GetAuthInfoError, WakeComputeError};

use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute, config, mgmt,
stream::PqStream,
compute, http, mgmt, stream, url,
waiters::{self, Waiter, Waiters},
};

use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use tokio::io::{AsyncRead, AsyncWrite};

static CPLANE_WAITERS: Lazy<Waiters<mgmt::ComputeReady>> = Lazy::new(Default::default);
Expand Down Expand Up @@ -75,6 +74,14 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
}
}

/// Extra query params we'd like to pass to the console.
pub struct ConsoleReqExtra<'a> {
/// A unique identifier for a connection.
pub session_id: uuid::Uuid,
/// Name of client application, if set.
pub application_name: Option<&'a str>,
}

/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
Expand All @@ -83,53 +90,83 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
/// * However, when we substitute `T` with [`ClientCredentials`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendType<T> {
#[derive(Debug)]
pub enum BackendType<'a, T> {
/// Current Cloud API (V2).
Console(T),
Console(Cow<'a, http::Endpoint>, T),
/// Local mock of Cloud API (V2).
Postgres(T),
Postgres(Cow<'a, url::ApiUrl>, T),
/// Authentication via a web browser.
Link,
Link(Cow<'a, url::ApiUrl>),
}

impl std::fmt::Display for BackendType<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
Console(endpoint, _) => fmt
.debug_tuple("Console")
.field(&endpoint.url().as_str())
.finish(),
Postgres(endpoint, _) => fmt
.debug_tuple("Postgres")
.field(&endpoint.as_str())
.finish(),
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
}

impl<T> BackendType<T> {
impl<T> BackendType<'_, T> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub fn as_ref(&self) -> BackendType<'_, &T> {
use BackendType::*;
match self {
Console(c, x) => Console(Cow::Borrowed(c), x),
Postgres(c, x) => Postgres(Cow::Borrowed(c), x),
Link(c) => Link(Cow::Borrowed(c)),
}
}
}

impl<'a, T> BackendType<'a, T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<R> {
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
use BackendType::*;
match self {
Console(x) => Console(f(x)),
Postgres(x) => Postgres(f(x)),
Link => Link,
Console(c, x) => Console(c, f(x)),
Postgres(c, x) => Postgres(c, f(x)),
Link(c) => Link(c),
}
}
}

impl<T, E> BackendType<Result<T, E>> {
impl<'a, T, E> BackendType<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<T>, E> {
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
use BackendType::*;
match self {
Console(x) => x.map(Console),
Postgres(x) => x.map(Postgres),
Link => Ok(Link),
Console(c, x) => x.map(|x| Console(c, x)),
Postgres(c, x) => x.map(|x| Postgres(c, x)),
Link(c) => Ok(Link(c)),
}
}
}

impl BackendType<ClientCredentials<'_>> {
impl BackendType<'_, ClientCredentials<'_>> {
/// Authenticate the client via the requested backend, possibly using credentials.
pub async fn authenticate(
mut self,
urls: &config::AuthUrls,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> super::Result<compute::NodeInfo> {
use BackendType::*;

if let Console(creds) | Postgres(creds) = &mut self {
if let Console(_, creds) | Postgres(_, creds) = &mut self {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the project name.
// We now expect to see a very specific payload in the place of password.
Expand All @@ -145,15 +182,13 @@ impl BackendType<ClientCredentials<'_>> {
creds.project = Some(payload.project.into());

let mut config = match &self {
Console(creds) => {
console::Api::new(&urls.auth_endpoint, creds)
Console(endpoint, creds) => {
console::Api::new(endpoint, extra, creds)
.wake_compute()
.await?
}
Postgres(creds) => {
postgres::Api::new(&urls.auth_endpoint, creds)
.wake_compute()
.await?
Postgres(endpoint, creds) => {
postgres::Api::new(endpoint, creds).wake_compute().await?
}
_ => unreachable!("see the patterns above"),
};
Expand All @@ -169,49 +204,18 @@ impl BackendType<ClientCredentials<'_>> {
}

match self {
Console(creds) => {
console::Api::new(&urls.auth_endpoint, &creds)
Console(endpoint, creds) => {
console::Api::new(&endpoint, extra, &creds)
.handle_user(client)
.await
}
Postgres(creds) => {
postgres::Api::new(&urls.auth_endpoint, &creds)
Postgres(endpoint, creds) => {
postgres::Api::new(&endpoint, &creds)
.handle_user(client)
.await
}
// NOTE: this auth backend doesn't use client credentials.
Link => link::handle_user(&urls.auth_link_uri, client).await,
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_backend_type_map() {
let values = [
BackendType::Console(0),
BackendType::Postgres(0),
BackendType::Link,
];

for value in values {
assert_eq!(value.map(|x| x), value);
}
}

#[test]
fn test_backend_type_transpose() {
let values = [
BackendType::Console(Ok::<_, ()>(0)),
BackendType::Postgres(Ok(0)),
BackendType::Link,
];

for value in values {
assert_eq!(value.map(Result::unwrap), value.transpose().unwrap());
Link(url) => link::handle_user(&url, client).await,
}
}
}
57 changes: 39 additions & 18 deletions proxy/src/auth/backend/console.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//! Cloud API V2.
use super::ConsoleReqExtra;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute::{self, ComputeConnCfg},
error::{io_error, UserFacingError},
scram,
http, scram,
stream::PqStream,
url::ApiUrl,
};
use serde::{Deserialize, Serialize};
use std::future::Future;
Expand Down Expand Up @@ -120,14 +120,23 @@ pub enum AuthInfo {

#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
endpoint: &'a http::Endpoint,
extra: &'a ConsoleReqExtra<'a>,
creds: &'a ClientCredentials<'a>,
}

impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
Self { endpoint, creds }
pub(super) fn new(
endpoint: &'a http::Endpoint,
extra: &'a ConsoleReqExtra<'a>,
creds: &'a ClientCredentials,
) -> Self {
Self {
endpoint,
extra,
creds,
}
}

/// Authenticate the existing user or throw an error.
Expand All @@ -139,16 +148,22 @@ impl<'a> Api<'a> {
}

async fn get_auth_info(&self) -> Result<AuthInfo, GetAuthInfoError> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_get_role_secret");
url.query_pairs_mut()
.append_pair("project", self.creds.project().expect("impossible"))
.append_pair("role", self.creds.user);
let req = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", uuid::Uuid::new_v4().to_string())
.query(&[("session_id", self.extra.session_id)])
.query(&[
("application_name", self.extra.application_name),
("project", Some(self.creds.project().expect("impossible"))),
("role", Some(self.creds.user)),
])
.build()?;

// TODO: use a proper logger
println!("cplane request: {url}");
println!("cplane request: {}", req.url());

let resp = reqwest::get(url.into_inner()).await?;
let resp = self.endpoint.execute(req).await?;
if !resp.status().is_success() {
return Err(TransportError::HttpStatus(resp.status()).into());
}
Expand All @@ -162,15 +177,21 @@ impl<'a> Api<'a> {

/// Wake up the compute node and return the corresponding connection info.
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg, WakeComputeError> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_wake_compute");
url.query_pairs_mut()
.append_pair("project", self.creds.project().expect("impossible"));
let req = self
.endpoint
.get("proxy_wake_compute")
.header("X-Request-ID", uuid::Uuid::new_v4().to_string())
.query(&[("session_id", self.extra.session_id)])
.query(&[
("application_name", self.extra.application_name),
("project", Some(self.creds.project().expect("impossible"))),
])
.build()?;

// TODO: use a proper logger
println!("cplane request: {url}");
println!("cplane request: {}", req.url());

let resp = reqwest::get(url.into_inner()).await?;
let resp = self.endpoint.execute(req).await?;
if !resp.status().is_success() {
return Err(TransportError::HttpStatus(resp.status()).into());
}
Expand Down
Loading

0 comments on commit f8422d2

Please sign in to comment.