diff --git a/sdk/core/src/http_client.rs b/sdk/core/src/http_client.rs index 84eefdc61c..c502e48158 100644 --- a/sdk/core/src/http_client.rs +++ b/sdk/core/src/http_client.rs @@ -98,7 +98,7 @@ impl HttpClient for reqwest::Client { } async fn execute_request2(&self, request: &crate::Request) -> Result { - let url = url::Url::parse(&request.uri().to_string())?; + let url = request.url().clone(); let mut reqwest_request = self.request(request.method(), url); for (name, value) in request.headers().iter() { reqwest_request = reqwest_request.header(name, value); diff --git a/sdk/core/src/request.rs b/sdk/core/src/request.rs index 3a74b3833e..d850abd7ae 100644 --- a/sdk/core/src/request.rs +++ b/sdk/core/src/request.rs @@ -1,10 +1,9 @@ -use crate::error::{ErrorKind, Result, ResultExt}; use crate::headers::{AsHeaders, Headers}; use crate::SeekableStream; use bytes::Bytes; -use http::{Method, Uri}; +use http::Method; use std::fmt::Debug; -use std::str::FromStr; +use url::Url; /// An HTTP Body. #[derive(Debug, Clone)] @@ -36,7 +35,7 @@ impl From> for Body { /// body. Policies are expected to enrich the request by mutating it. #[derive(Debug, Clone)] pub struct Request { - pub(crate) uri: Uri, + pub(crate) url: Url, pub(crate) method: Method, pub(crate) headers: Headers, pub(crate) body: Body, @@ -44,21 +43,21 @@ pub struct Request { impl Request { /// Create a new request with an empty body and no headers - pub fn new(uri: Uri, method: Method) -> Self { + pub fn new(url: Url, method: Method) -> Self { Self { - uri, + url, method, headers: Headers::new(), body: Body::Bytes(bytes::Bytes::new()), } } - pub fn uri(&self) -> &Uri { - &self.uri + pub fn url(&self) -> &Url { + &self.url } - pub fn uri_mut(&mut self) -> &mut Uri { - &mut self.uri + pub fn url_mut(&mut self) -> &mut Url { + &mut self.url } pub fn method(&self) -> Method { @@ -86,11 +85,6 @@ impl Request { pub fn set_body(&mut self, body: impl Into) { self.body = body.into(); } - - /// Parse a `Uri` from a `str` - pub fn parse_uri(uri: &str) -> Result { - Uri::from_str(uri).map_kind(ErrorKind::DataConversion) - } } /// Temporary hack to convert preexisting requests into the new format. It @@ -99,7 +93,7 @@ impl From> for Request { fn from(request: http::Request) -> Self { let (parts, body) = request.into_parts(); Self { - uri: parts.uri, + url: Url::parse(&parts.uri.to_string()).unwrap(), method: parts.method, headers: parts.headers.into(), body: Body::Bytes(body), diff --git a/sdk/data_cosmos/src/authorization_policy.rs b/sdk/data_cosmos/src/authorization_policy.rs index 27f38ebd9f..4879d6da84 100644 --- a/sdk/data_cosmos/src/authorization_policy.rs +++ b/sdk/data_cosmos/src/authorization_policy.rs @@ -57,11 +57,15 @@ impl Policy for AuthorizationPolicy { let time_nonce = TimeNonce::new(); - let uri_path = &request.uri().path_and_query().unwrap().to_string()[1..]; + let mut uri_path = request.url().path().to_owned(); + if let Some(query) = request.url().query() { + uri_path.push('?'); + uri_path.push_str(query); + } trace!("uri_path used by AuthorizationPolicy == {:#?}", uri_path); let auth = { - let resource_link = generate_resource_link(uri_path); + let resource_link = generate_resource_link(&uri_path); trace!("resource_link == {}", resource_link); generate_authorization( &self.authorization_token, diff --git a/sdk/identity/src/client_credentials_flow/mod.rs b/sdk/identity/src/client_credentials_flow/mod.rs index 9611380df0..b3a63b7f25 100644 --- a/sdk/identity/src/client_credentials_flow/mod.rs +++ b/sdk/identity/src/client_credentials_flow/mod.rs @@ -47,7 +47,7 @@ use azure_core::{ use http::Method; use login_response::LoginResponse; use std::sync::Arc; -use url::form_urlencoded; +use url::{form_urlencoded, Url}; /// Perform the client credentials flow #[allow(clippy::manual_async_fn)] @@ -66,7 +66,7 @@ pub async fn perform( .append_pair("grant_type", "client_credentials") .finish(); - let url = Request::parse_uri(&format!( + let url = Url::parse(&format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/token", tenant_id )) diff --git a/sdk/identity/src/device_code_flow/mod.rs b/sdk/identity/src/device_code_flow/mod.rs index 582d2e9cf4..8a33b79cbd 100644 --- a/sdk/identity/src/device_code_flow/mod.rs +++ b/sdk/identity/src/device_code_flow/mod.rs @@ -8,7 +8,7 @@ mod device_code_responses; use async_timer::timer::new_timer; use azure_core::{ content_type, - error::{Error, ErrorKind, Result}, + error::{Error, ErrorKind, Result, ResultExt}, headers, HttpClient, Request, Response, }; pub use device_code_responses::*; @@ -17,7 +17,7 @@ use http::Method; use oauth2::ClientId; use serde::Deserialize; use std::{borrow::Cow, sync::Arc, time::Duration}; -use url::form_urlencoded; +use url::{form_urlencoded, Url}; /// Start the device authorization grant flow. /// The user has only 15 minutes to sign in (the usual value for expires_in). @@ -171,7 +171,7 @@ async fn post_form( url: &str, form_body: String, ) -> Result { - let url = Request::parse_uri(url)?; + let url = Url::parse(url).map_kind(ErrorKind::DataConversion)?; let mut req = Request::new(url, Method::POST); req.headers_mut().insert( headers::CONTENT_TYPE, diff --git a/sdk/identity/src/refresh_token.rs b/sdk/identity/src/refresh_token.rs index 1144544567..ed124ccfb0 100644 --- a/sdk/identity/src/refresh_token.rs +++ b/sdk/identity/src/refresh_token.rs @@ -10,7 +10,7 @@ use oauth2::{AccessToken, ClientId, ClientSecret}; use serde::Deserialize; use std::fmt; use std::sync::Arc; -use url::form_urlencoded; +use url::{form_urlencoded, Url}; /// Exchange a refresh token for a new access token and refresh token #[allow(clippy::manual_async_fn)] @@ -34,7 +34,7 @@ pub async fn exchange( let encoded = encoded.append_pair("refresh_token", refresh_token.secret()); let encoded = encoded.finish(); - let url = Request::parse_uri(&format!( + let url = Url::parse(&format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/token", tenant_id ))?; diff --git a/sdk/identity/src/token_credentials/imds_managed_identity_credentials.rs b/sdk/identity/src/token_credentials/imds_managed_identity_credentials.rs index a751a25c5d..11efe4c04e 100644 --- a/sdk/identity/src/token_credentials/imds_managed_identity_credentials.rs +++ b/sdk/identity/src/token_credentials/imds_managed_identity_credentials.rs @@ -112,7 +112,6 @@ impl TokenCredential for ImdsManagedIdentityCredential { "error parsing url for MSI endpoint", )?; - let url = Request::parse_uri(url.as_str())?; let mut req = Request::new(url, Method::GET); req.headers_mut().insert("Metadata", "true"); diff --git a/sdk/storage/src/account/operations/get_account_information.rs b/sdk/storage/src/account/operations/get_account_information.rs index 46f30edc99..ce3f412e01 100644 --- a/sdk/storage/src/account/operations/get_account_information.rs +++ b/sdk/storage/src/account/operations/get_account_information.rs @@ -34,8 +34,14 @@ impl GetAccountInformationBuilder { .blob_storage_request("", http::Method::GET); // TODO: add the query pairs - // request.uri_mut().query_pairs_mut().append_pair("restype", "account"); - // request.uri_mut().query_pairs_mut().append_pair("comp", "properties"); + request + .url_mut() + .query_pairs_mut() + .append_pair("restype", "account"); + request + .url_mut() + .query_pairs_mut() + .append_pair("comp", "properties"); let response = self .storage_client diff --git a/sdk/storage/src/authorization_policy.rs b/sdk/storage/src/authorization_policy.rs index a260f6c510..62972e7236 100644 --- a/sdk/storage/src/authorization_policy.rs +++ b/sdk/storage/src/authorization_policy.rs @@ -1,8 +1,10 @@ use azure_core::error::{ErrorKind, ResultExt}; use azure_core::{headers::*, Context, Policy, PolicyResult, Request}; use http::header::AUTHORIZATION; -use http::{Method, Uri}; +use http::Method; +use std::borrow::Cow; use std::sync::Arc; +use url::Url; use crate::clients::{ServiceType, StorageCredentials}; @@ -35,16 +37,10 @@ impl Policy for AuthorizationPolicy { ); let request = match &self.credentials { StorageCredentials::Key(account, key) => { - if !request - .uri() - .query() - .unwrap_or_default() - .split('&') - .any(|pair| matches!(pair.trim().split_once('='), Some(("sig", _)))) - { + if !request.url().query_pairs().any(|(k, _)| &*k == "sig") { let auth = generate_authorization( request.headers(), - request.uri(), + request.url(), &request.method(), account, key, @@ -56,23 +52,11 @@ impl Policy for AuthorizationPolicy { request } StorageCredentials::SASToken(query_pairs) => { - // TODO: switch to `url` crate. - // This is already very complex and we're not even url encoding - let query = request.uri().query(); - let new = query_pairs - .iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join("&"); - let new = match query { - Some(existing) => format!("{existing}&{new}"), - None => format!("?{new}"), - }; - let new = format!("{}{}", request.uri().path(), new); - let mut parts = request.uri().clone().into_parts(); - parts.path_and_query = - Some(http::uri::PathAndQuery::from_maybe_shared(new).unwrap()); - *request.uri_mut() = Uri::from_parts(parts).unwrap(); + request + .url_mut() + .query_pairs_mut() + .extend_pairs(query_pairs); + request } StorageCredentials::BearerToken(token) => { @@ -100,7 +84,7 @@ impl Policy for AuthorizationPolicy { fn generate_authorization( h: &Headers, - u: &Uri, + u: &Url, method: &Method, account: &str, key: &str, @@ -118,7 +102,7 @@ fn add_if_exists<'a>(h: &'a Headers, key: &'static str) -> &'a str { #[allow(unknown_lints)] fn string_to_sign( h: &Headers, - u: &Uri, + u: &Url, method: &Method, account: &str, service_type: &ServiceType, @@ -135,7 +119,7 @@ fn string_to_sign( ) } _ => { - // content lenght must only be specified if != 0 + // content length must only be specified if != 0 // this is valid from 2015-02-21 let content_length = h .get(CONTENT_LENGTH) @@ -160,33 +144,13 @@ fn string_to_sign( ) } } - - // expected - // GET\n /*HTTP Verb*/ - // \n /*Content-Encoding*/ - // \n /*Content-Language*/ - // \n /*Content-Length (include value when zero)*/ - // \n /*Content-MD5*/ - // \n /*Content-Type*/ - // \n /*Date*/ - // \n /*If-Modified-Since */ - // \n /*If-Match*/ - // \n /*If-None-Match*/ - // \n /*If-Unmodified-Since*/ - // \n /*Range*/ - // x-ms-date:Sun, 11 Oct 2009 21:49:13 GMT\nx-ms-version:2009-09-19\n - // /*CanonicalizedHeaders*/ - // /myaccount /mycontainer\ncomp:metadata\nrestype:container\ntimeout:20 - // /*CanonicalizedResource*/ - // - // } fn canonicalize_header(h: &Headers) -> String { let mut v_headers = h .iter() .filter(|(k, _)| k.as_str().starts_with("x-ms")) - .map(|(k, _)| k.as_str().to_owned()) + .map(|(k, _)| k) .collect::>(); v_headers.sort_unstable(); @@ -194,39 +158,34 @@ fn canonicalize_header(h: &Headers) -> String { for header_name in v_headers { let s = h.get(header_name.clone()).unwrap().as_str(); + let header_name = header_name.as_str(); can = format!("{can}{header_name}:{s}\n"); } can } -fn canonicalized_resource_table(account: &str, u: &Uri) -> String { +fn canonicalized_resource_table(account: &str, u: &Url) -> String { format!("/{}{}", account, u.path()) } -fn canonicalized_resource(account: &str, uri: &Uri) -> String { +fn canonicalized_resource(account: &str, uri: &Url) -> String { let mut can_res: String = String::new(); can_res += "/"; can_res += account; - let path = uri.path(); - - for p in path.split('/') { + for p in uri.path_segments().into_iter().flatten() { can_res.push('/'); - can_res.push_str(&*p); + can_res.push_str(p); } can_res += "\n"; // query parameters - let query_pairs = uri - .query() - .unwrap_or_default() - .split('&') - .filter_map(|p| p.split_once('=')); + let query_pairs = uri.query_pairs(); { - let mut qps = Vec::new(); - for (q, _p) in query_pairs.clone() { - if !(qps.iter().any(|x| x == q)) { - qps.push(q.to_owned()); + let mut qps: Vec = Vec::new(); + for (q, _) in query_pairs { + if !(qps.iter().any(|x| x == &*q)) { + qps.push(q.into_owned()); } } @@ -234,7 +193,7 @@ fn canonicalized_resource(account: &str, uri: &Uri) -> String { for qparam in qps { // find correct parameter - let ret = lexy_sort(query_pairs.clone(), &qparam); + let ret = lexy_sort(query_pairs, &qparam); can_res = can_res + &qparam.to_lowercase() + ":"; @@ -253,9 +212,9 @@ fn canonicalized_resource(account: &str, uri: &Uri) -> String { } fn lexy_sort<'a>( - vec: impl Iterator + 'a, + vec: impl Iterator, Cow<'a, str>)> + 'a, query_param: &str, -) -> Vec<&'a str> { +) -> Vec> { let mut values = vec .filter(|(k, _)| *k == query_param) .map(|(_, v)| v) diff --git a/sdk/storage_datalake/src/authorization_policies/shared_key.rs b/sdk/storage_datalake/src/authorization_policies/shared_key.rs index 50bd6f0008..953f0b6bed 100644 --- a/sdk/storage_datalake/src/authorization_policies/shared_key.rs +++ b/sdk/storage_datalake/src/authorization_policies/shared_key.rs @@ -40,7 +40,7 @@ impl Policy for SharedKeyAuthorizationPolicy { HeaderValue::from_str("2019-12-12")?, ); // TODO: Remove duplication with storage_account_client.rs - let url = url::Url::parse(&request.uri().to_string()).unwrap(); + let url = url::Url::parse(&request.url().to_string()).unwrap(); let auth = generate_authorization( request.headers(), &url,