Skip to content

Commit

Permalink
Issue 245: Async token refresh (#251)
Browse files Browse the repository at this point in the history
* async token refresh

Signed-off-by: Wenqi Mou <[email protected]>
  • Loading branch information
Wenqi Mou authored May 6, 2021
1 parent 766502e commit 94a4351
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 57 deletions.
75 changes: 41 additions & 34 deletions config/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ use std::fs::File;
use std::io::{BufReader, Read};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::SystemTime;
use tokio::runtime::Runtime;
use tokio::sync::Mutex;

pub const URL_TOKEN: &str = "/realms/{realm-name}/protocol/openid-connect/token";
pub const BASIC: &str = "Basic";
Expand Down Expand Up @@ -70,8 +69,12 @@ impl Credentials {
}
}

pub fn get_request_metadata(&self) -> String {
self.inner.get_request_metadata()
pub async fn get_request_metadata(&self) -> String {
self.inner.get_request_metadata().await
}

pub fn is_expired(&self) -> bool {
self.inner.is_expired()
}
}

Expand All @@ -85,7 +88,8 @@ impl Clone for Credentials {

#[async_trait]
trait Cred: Debug + CredClone + Send + Sync {
fn get_request_metadata(&self) -> String;
async fn get_request_metadata(&self) -> String;
fn is_expired(&self) -> bool;
}

trait CredClone {
Expand All @@ -109,9 +113,13 @@ struct Basic {

#[async_trait]
impl Cred for Basic {
fn get_request_metadata(&self) -> String {
async fn get_request_metadata(&self) -> String {
format!("{} {}", self.method, self.token)
}

fn is_expired(&self) -> bool {
false
}
}

#[derive(Debug, Clone)]
Expand All @@ -124,17 +132,23 @@ struct KeyCloak {

#[async_trait]
impl Cred for KeyCloak {
fn get_request_metadata(&self) -> String {
async fn get_request_metadata(&self) -> String {
if self.is_expired() {
self.refresh_rpt_token();
self.refresh_rpt_token().await;
}
format!("{} {}", self.method, *self.token.lock().expect("lock token"))
format!("{} {}", self.method, *self.token.lock().await)
}

fn is_expired(&self) -> bool {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("get unix time");
now.as_secs() + REFRESH_THRESHOLD_SECONDS >= self.expires_at.load(Ordering::Relaxed)
}
}

impl KeyCloak {
fn refresh_rpt_token(&self) {
let rt = Runtime::new().expect("create tokio runtime to get rpt token");
async fn refresh_rpt_token(&self) {
// read keycloak json
let file = File::open(self.path.to_string()).expect("open keycloak.json");
let mut buf_reader = BufReader::new(file);
Expand All @@ -145,39 +159,32 @@ impl KeyCloak {
let key_cloak_json: KeyCloakJson = serde_json::from_slice(&buffer).expect("decode slice to struct");

// first POST request for access token
let access_token = rt
.block_on(obtain_access_token(
&key_cloak_json.auth_server_url,
&key_cloak_json.realm,
&key_cloak_json.resource,
&key_cloak_json.credentials.secret,
))
.expect("obtain access token");
let access_token = obtain_access_token(
&key_cloak_json.auth_server_url,
&key_cloak_json.realm,
&key_cloak_json.resource,
&key_cloak_json.credentials.secret,
)
.await
.expect("obtain access token");

// second POST request for rpt
let rpt = rt
.block_on(authorize(
&key_cloak_json.auth_server_url,
&key_cloak_json.realm,
&access_token,
))
.expect("get rpt");
let rpt = authorize(
&key_cloak_json.auth_server_url,
&key_cloak_json.realm,
&access_token,
)
.await
.expect("get rpt");

let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("get unix time");
let expires_at = now.as_secs() + rpt.expires_in;

*self.token.lock().expect("lock token") = rpt.access_token;
*self.token.lock().await = rpt.access_token;
self.expires_at.store(expires_at, Ordering::Relaxed);
}

fn is_expired(&self) -> bool {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("get unix time");
now.as_secs() + REFRESH_THRESHOLD_SECONDS >= self.expires_at.load(Ordering::Relaxed)
}
}

#[derive(Serialize, Deserialize)]
Expand Down
10 changes: 5 additions & 5 deletions config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ mod tests {
assert_eq!(config.retry_policy(), RetryWithBackoff::default());
}

#[test]
fn test_extract_credentials() {
#[tokio::test]
async fn test_extract_credentials() {
// test empty env
let config = ClientConfigBuilder::default()
.controller_uri("127.0.0.2:9091".to_string())
Expand All @@ -193,7 +193,7 @@ mod tests {
let token = encode(":");

assert_eq!(
config.credentials.get_request_metadata(),
config.credentials.get_request_metadata().await,
format!("{} {}", "Basic", token)
);

Expand All @@ -209,7 +209,7 @@ mod tests {

let token = encode("hello:12345");
assert_eq!(
config.credentials.get_request_metadata(),
config.credentials.get_request_metadata().await,
format!("{} {}", "Basic", token)
);

Expand All @@ -220,7 +220,7 @@ mod tests {
.build()
.unwrap();
assert_eq!(
config.credentials.get_request_metadata(),
config.credentials.get_request_metadata().await,
format!("{} {}", "Basic", "ABCDE")
);
env::remove_var("pravega_client_auth_method");
Expand Down
Loading

0 comments on commit 94a4351

Please sign in to comment.