Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move to using Arc<dyn HttpClient> in azure_identity #799

Merged
merged 3 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ openssl = { version = "0.10", optional=true }
base64 = "0.13.0"
uuid = { version = "1.0", features = ["v4"] }
http = "0.2"
# work around https://github.com/rust-lang/rust/issues/63033
fix-hidden-lifetime-bug = "0.2"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😱


[dev-dependencies]
reqwest = { version = "0.11", features = ["json"], default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/examples/client_credentials_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
let http_client = azure_core::new_http_client();
// This will give you the final token to use in authorization.
let token = client_credentials_flow::perform(
http_client.as_ref(),
http_client.clone(),
&client_id,
&client_secret,
&["https://management.azure.com/"],
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/examples/client_credentials_flow_blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
let http_client = azure_core::new_http_client();

let token = client_credentials_flow::perform(
http_client.as_ref(),
http_client.clone(),
&client_id,
&client_secret,
&[&format!(
Expand Down
7 changes: 5 additions & 2 deletions sdk/identity/src/client_credentials_flow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
//! let http_client = azure_core::new_http_client();
//! // This will give you the final token to use in authorization.
//! let token = client_credentials_flow::perform(
//! http_client.as_ref(),
//! http_client.clone(),
//! &client_id,
//! &client_secret,
//! &["https://management.azure.com/"],
Expand All @@ -46,11 +46,14 @@ use azure_core::{
};
use http::Method;
use login_response::LoginResponse;
use std::sync::Arc;
use url::form_urlencoded;

/// Perform the client credentials flow
#[allow(clippy::manual_async_fn)]
#[fix_hidden_lifetime_bug::fix_hidden_lifetime_bug]
pub async fn perform(
http_client: &dyn HttpClient,
http_client: Arc<dyn HttpClient>,
client_id: &oauth2::ClientId,
client_secret: &oauth2::ClientSecret,
scopes: &[&str],
Expand Down
32 changes: 18 additions & 14 deletions sdk/identity/src/device_code_flow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@
//! You can learn more about this authorization flow [here](https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-device-code).
mod device_code_responses;

use azure_core::error::{Error, ErrorKind, Result};
use azure_core::{content_type, headers, HttpClient, Request, Response};
pub use device_code_responses::*;
use http::Method;

use async_timer::timer::new_timer;
use azure_core::{
content_type,
error::{Error, ErrorKind, Result},
headers, HttpClient, Request, Response,
};
pub use device_code_responses::*;
use futures::stream::unfold;
use http::Method;
use oauth2::ClientId;
use serde::Deserialize;
use std::{borrow::Cow, sync::Arc, time::Duration};
use url::form_urlencoded;

use std::borrow::Cow;
use std::time::Duration;

/// Start the device authorization grant flow.
/// The user has only 15 minutes to sign in (the usual value for expires_in).
pub async fn start<'a, 'b, T>(
http_client: &'a dyn HttpClient,
http_client: Arc<dyn HttpClient>,
tenant_id: T,
client_id: &'a ClientId,
scopes: &'b [&'b str],
Expand All @@ -41,7 +41,7 @@ where
let encoded = encoded.append_pair("scope", &scopes.join(" "));
let encoded = encoded.finish();

let rsp = post_form(http_client, url, encoded).await?;
let rsp = post_form(http_client.clone(), url, encoded).await?;
let rsp_status = rsp.status();
let rsp_body = rsp.into_body().await;
if !rsp_status.is_success() {
Expand Down Expand Up @@ -78,7 +78,7 @@ pub struct DeviceCodePhaseOneResponse<'a> {
// The skipped fields below do not come from the Azure answer.
// They will be added manually after deserialization
#[serde(skip)]
http_client: Option<&'a dyn HttpClient>,
http_client: Option<Arc<dyn HttpClient>>,
#[serde(skip)]
tenant_id: Cow<'a, str>,
// We store the ClientId as string instead of the original type, because it
Expand Down Expand Up @@ -122,9 +122,9 @@ impl<'a> DeviceCodePhaseOneResponse<'a> {
let encoded = encoded.append_pair("device_code", &self.device_code);
let encoded = encoded.finish();

let http_client = self.http_client.unwrap();
let http_client = self.http_client.clone().unwrap();

match post_form(http_client, url, encoded).await {
match post_form(http_client.clone(), url, encoded).await {
Ok(rsp) => {
let rsp_status = rsp.status();
let rsp_body = rsp.into_body().await;
Expand Down Expand Up @@ -166,7 +166,11 @@ impl<'a> DeviceCodePhaseOneResponse<'a> {
}
}

async fn post_form(http_client: &dyn HttpClient, url: &str, form_body: String) -> Result<Response> {
async fn post_form(
http_client: Arc<dyn HttpClient>,
url: &str,
form_body: String,
) -> Result<Response> {
let url = Request::parse_uri(url)?;
let mut req = Request::new(url, Method::POST);
req.headers_mut().insert(
Expand Down
5 changes: 4 additions & 1 deletion sdk/identity/src/refresh_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ use http::Method;
use oauth2::{AccessToken, ClientId, ClientSecret};
use serde::Deserialize;
use std::fmt;
use std::sync::Arc;
use url::form_urlencoded;

/// Exchange a refresh token for a new access token and refresh token
#[allow(clippy::manual_async_fn)]
#[fix_hidden_lifetime_bug::fix_hidden_lifetime_bug]
pub async fn exchange(
http_client: &dyn HttpClient,
http_client: Arc<dyn HttpClient>,
tenant_id: &str,
client_id: &ClientId,
client_secret: Option<&ClientSecret>,
Expand Down
8 changes: 4 additions & 4 deletions sdk/storage_blobs/examples/device_code_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
.nth(1)
.expect("please specify the storage account name as first command line parameter");

let client = reqwest::Client::new();
let http_client = azure_core::new_http_client();

// the process requires two steps. The first is to ask for
// the code to show to the user. This is done with the following
Expand All @@ -26,7 +26,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
// receive the refresh token as well.
// We are requesting access to the storage account passed as parameter.
let device_code_flow = device_code_flow::start(
&client,
http_client.clone(),
&tenant_id,
&client_id,
&[
Expand Down Expand Up @@ -73,7 +73,6 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
// this example we are creating an Azure Storage client
// using the access token.

let http_client = azure_core::new_http_client();
let storage_account_client = StorageAccountClient::new_bearer_token(
http_client.clone(),
&storage_account_name,
Expand All @@ -89,7 +88,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
// now let's refresh the token, if available
if let Some(refresh_token) = authorization.refresh_token() {
let refreshed_token =
refresh_token::exchange(&client, &tenant_id, &client_id, None, refresh_token).await?;
refresh_token::exchange(http_client, &tenant_id, &client_id, None, refresh_token)
.await?;
println!("refreshed token == {:#?}", refreshed_token);
}

Expand Down