Skip to content

Commit

Permalink
Merge pull request #396 from kazk/conditional-auth-layer
Browse files Browse the repository at this point in the history
Use AuthLayer conditionally
  • Loading branch information
clux authored Feb 7, 2021
2 parents 6653051 + 6340c9c commit cff35fb
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 59 deletions.
68 changes: 36 additions & 32 deletions kube/src/config/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,34 @@ pub(crate) enum Authentication {
None,
Basic(String),
Token(String),
RefreshableToken(Arc<Mutex<(String, DateTime<Utc>, AuthInfo)>>),
RefreshableToken(RefreshableToken),
}

#[derive(Debug, Clone)]
pub(crate) struct RefreshableToken(pub(crate) Arc<Mutex<(String, DateTime<Utc>, AuthInfo)>>);

impl RefreshableToken {
pub(crate) async fn to_header(&self) -> Result<header::HeaderValue> {
let data = &self.0;
let mut locked_data = data.lock().await;
// Add some wiggle room onto the current timestamp so we don't get any race
// conditions where the token expires while we are refreshing
if Utc::now() + Duration::seconds(60) >= locked_data.1 {
if let Authentication::RefreshableToken(d) =
Authentication::from_auth_info(&locked_data.2).await?
{
let (new_token, new_expire, new_info) = Arc::try_unwrap(d.0)
.expect("Unable to unwrap Arc, this is likely a programming error")
.into_inner();
locked_data.0 = new_token;
locked_data.1 = new_expire;
locked_data.2 = new_info;
} else {
return Err(ConfigError::UnrefreshableTokenResponse).map_err(Error::from);
}
}
Ok(header::HeaderValue::from_str(&locked_data.0).map_err(ConfigError::InvalidBearerToken)?)
}
}

impl Authentication {
Expand All @@ -43,28 +70,7 @@ impl Authentication {
Self::Token(value) => Ok(Some(
header::HeaderValue::from_str(value).map_err(ConfigError::InvalidBearerToken)?,
)),
Self::RefreshableToken(data) => {
let mut locked_data = data.lock().await;
// Add some wiggle room onto the current timestamp so we don't get any race
// conditions where the token expires while we are refreshing
if Utc::now() + Duration::seconds(60) >= locked_data.1 {
if let Authentication::RefreshableToken(d) =
Authentication::from_auth_info(&locked_data.2).await?
{
let (new_token, new_expire, new_info) = Arc::try_unwrap(d)
.expect("Unable to unwrap Arc, this is likely a programming error")
.into_inner();
locked_data.0 = new_token;
locked_data.1 = new_expire;
locked_data.2 = new_info;
} else {
return Err(ConfigError::UnrefreshableTokenResponse).map_err(Error::from);
}
}
Ok(Some(
header::HeaderValue::from_str(&locked_data.0).map_err(ConfigError::InvalidBearerToken)?,
))
}
Self::RefreshableToken(refreshable) => Ok(Some(refreshable.to_header().await?)),
}
}

Expand All @@ -80,11 +86,11 @@ impl Authentication {
provider.config.insert("access-token".into(), token.clone());
provider.config.insert("expiry".into(), expiry.to_rfc3339());
info.auth_provider = Some(provider);
return Ok(Self::RefreshableToken(Arc::new(Mutex::new((
return Ok(Self::RefreshableToken(RefreshableToken(Arc::new(Mutex::new((
format!("Bearer {}", token),
expiry,
info,
)))));
))))));
}

ProviderToken::GCP(token, None) => {
Expand Down Expand Up @@ -119,11 +125,9 @@ impl Authentication {
expiration,
) {
(Ok(token), _, None) => Ok(Authentication::Token(format!("Bearer {}", token))),
(Ok(token), _, Some(expire)) => Ok(Authentication::RefreshableToken(Arc::new(Mutex::new((
format!("Bearer {}", token),
expire,
auth_info.clone(),
))))),
(Ok(token), _, Some(expire)) => Ok(Authentication::RefreshableToken(RefreshableToken(Arc::new(
Mutex::new((format!("Bearer {}", token), expire, auth_info.clone())),
)))),
(_, (Some(u), Some(p)), _) => {
let encoded = base64::encode(&format!("{}:{}", u, p));
Ok(Authentication::Basic(format!("Basic {}", encoded)))
Expand Down Expand Up @@ -413,8 +417,8 @@ mod test {
let mut config: Kubeconfig = serde_yaml::from_str(&test_file).map_err(ConfigError::ParseYaml)?;
let auth_info = &mut config.auth_infos[0].auth_info;
match Authentication::from_auth_info(&auth_info).await {
Ok(Authentication::RefreshableToken(data)) => {
let (token, _expire, info) = Arc::try_unwrap(data).unwrap().into_inner();
Ok(Authentication::RefreshableToken(refreshable)) => {
let (token, _expire, info) = Arc::try_unwrap(refreshable.0).unwrap().into_inner();
assert_eq!(token, "Bearer my_token".to_owned());
let config = info.auth_provider.unwrap().config;
assert_eq!(config.get("access-token"), Some(&"my_token".to_owned()));
Expand Down
2 changes: 1 addition & 1 deletion kube/src/config/file_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub struct NamedAuthInfo {
}

/// AuthInfo stores information to tell cluster who you are.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct AuthInfo {
pub username: Option<String>,
pub password: Option<String>,
Expand Down
2 changes: 1 addition & 1 deletion kube/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod incluster_config;
mod utils;

use crate::{error::ConfigError, Result};
pub(crate) use auth::Authentication;
pub(crate) use auth::{Authentication, RefreshableToken};
use file_loader::ConfigLoader;
pub use file_loader::KubeConfigOptions;

Expand Down
34 changes: 21 additions & 13 deletions kube/src/service/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ use hyper::Body;
use pin_project::pin_project;
use tower::{layer::Layer, BoxError, Service};

use crate::{config::Authentication, Result};
use crate::{config::RefreshableToken, Result};

/// `Layer` to decorate the request with `Authorization` header.
pub struct AuthLayer {
auth: Authentication,
auth: RefreshableToken,
}

impl AuthLayer {
pub(crate) fn new(auth: Authentication) -> Self {
pub(crate) fn new(auth: RefreshableToken) -> Self {
Self { auth }
}
}
Expand All @@ -41,7 +41,7 @@ pub struct AuthService<S>
where
S: Service<Request<Body>>,
{
auth: Authentication,
auth: RefreshableToken,
service: S,
}

Expand Down Expand Up @@ -69,11 +69,8 @@ where

let auth = self.auth.clone();
let request = async move {
// If using authorization header, attach the updated value.
auth.to_header().await.map_err(BoxError::from).map(|opt| {
if let Some(value) = opt {
req.headers_mut().insert(AUTHORIZATION, value);
}
auth.to_header().await.map_err(BoxError::from).map(|value| {
req.headers_mut().insert(AUTHORIZATION, value);
req
})
};
Expand Down Expand Up @@ -136,20 +133,22 @@ where
mod tests {
use super::*;

use std::matches;
use std::{matches, sync::Arc};

use chrono::{Duration, Utc};
use futures::pin_mut;
use http::{HeaderValue, Request, Response};
use hyper::Body;
use tokio::sync::Mutex;
use tokio_test::assert_ready_ok;
use tower_test::mock;

use crate::{error::ConfigError, Error};
use crate::{config::AuthInfo, error::ConfigError, Error};

#[tokio::test(flavor = "current_thread")]
async fn valid_token() {
const TOKEN: &str = "Bearer test";
let auth = Authentication::Token(TOKEN.into());
let auth = test_token(TOKEN.into());
let (mut service, handle) = mock::spawn_layer(AuthLayer::new(auth));

let spawned = tokio::spawn(async move {
Expand All @@ -174,7 +173,7 @@ mod tests {
#[tokio::test(flavor = "current_thread")]
async fn invalid_token() {
const TOKEN: &str = "\n";
let auth = Authentication::Token(TOKEN.into());
let auth = test_token(TOKEN.into());
let (mut service, _handle) =
mock::spawn_layer::<Request<Body>, Response<Body>, _>(AuthLayer::new(auth));
let err = service
Expand All @@ -188,4 +187,13 @@ mod tests {
Error::Kubeconfig(ConfigError::InvalidBearerToken(_))
));
}

fn test_token(token: String) -> RefreshableToken {
let expiry = Utc::now() + Duration::seconds(60 * 60);
let info = AuthInfo {
token: Some(token.clone()),
..Default::default()
};
RefreshableToken(Arc::new(Mutex::new((token, expiry, info))))
}
}
50 changes: 38 additions & 12 deletions kube/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use tls::HttpsConnector;

use std::convert::{TryFrom, TryInto};

use http::{Request, Response};
use http::{HeaderValue, Request, Response};
use hyper::{Body, Client as HyperClient};
use hyper_timeout::TimeoutConnector;
use tower::{buffer::Buffer, util::BoxService, BoxError, ServiceBuilder};

use crate::{Config, Error, Result};
use crate::{config::Authentication, error::ConfigError, Config, Error, Result};

// - `Buffer` for cheap clone
// - `BoxService` to avoid type parameters in `Client`
Expand Down Expand Up @@ -65,10 +65,30 @@ impl TryFrom<Config> for Service {
/// Convert [`Config`] into a [`Service`]
fn try_from(config: Config) -> Result<Self> {
let cluster_url = config.cluster_url.clone();
let default_headers = config.headers.clone();
let mut default_headers = config.headers.clone();
let timeout = config.timeout;
let auth = config.auth_header.clone();

// AuthLayer is not necessary unless `RefreshableToken`
if let Authentication::Basic(value) = &auth {
default_headers.insert(
http::header::AUTHORIZATION,
HeaderValue::from_str(value).map_err(ConfigError::InvalidBasicAuth)?,
);
} else if let Authentication::Token(value) = &auth {
default_headers.insert(
http::header::AUTHORIZATION,
HeaderValue::from_str(value).map_err(ConfigError::InvalidBearerToken)?,
);
}

let common = ServiceBuilder::new()
.map_request(move |r| set_cluster_url(r, &cluster_url))
.map_request(move |r| set_default_headers(r, default_headers.clone()))
.map_request(accept_compressed)
.map_response(maybe_decompress)
.into_inner();

let https: HttpsConnector<_> = config.try_into()?;
let mut connector = TimeoutConnector::new(https);
if let Some(timeout) = timeout {
Expand All @@ -81,14 +101,20 @@ impl TryFrom<Config> for Service {
}
let client: HyperClient<_, Body> = HyperClient::builder().build(connector);

let inner = ServiceBuilder::new()
.map_request(move |r| set_cluster_url(r, &cluster_url))
.map_request(move |r| set_default_headers(r, default_headers.clone()))
.map_request(accept_compressed)
.map_response(maybe_decompress)
.layer(AuthLayer::new(auth))
.layer(tower::layer::layer_fn(LogRequest::new))
.service(client);
Ok(Self::new(inner))
if let Authentication::RefreshableToken(refreshable) = auth {
let inner = ServiceBuilder::new()
.layer(common)
.layer(AuthLayer::new(refreshable))
.layer(tower::layer::layer_fn(LogRequest::new))
.service(client);
Ok(Self::new(inner))
} else {
let inner = ServiceBuilder::new()
.layer(common)
.map_err(BoxError::from)
.layer(tower::layer::layer_fn(LogRequest::new))
.service(client);
Ok(Self::new(inner))
}
}
}

0 comments on commit cff35fb

Please sign in to comment.