Skip to content

Commit

Permalink
Switch to url crate
Browse files Browse the repository at this point in the history
  • Loading branch information
rylev committed Jun 16, 2022
1 parent 2478841 commit 81c7b46
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 98 deletions.
2 changes: 1 addition & 1 deletion sdk/core/src/http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl HttpClient for reqwest::Client {
}

async fn execute_request2(&self, request: &crate::Request) -> Result<crate::Response> {
let url = url::Url::parse(&request.uri().to_string())?;
let url = request.url().clone();
let mut reqwest_request = self.request(request.method(), url);
for (name, value) in request.headers().iter() {
reqwest_request = reqwest_request.header(name, value);
Expand Down
26 changes: 10 additions & 16 deletions sdk/core/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use crate::error::{ErrorKind, Result, ResultExt};
use crate::headers::{AsHeaders, Headers};
use crate::SeekableStream;
use bytes::Bytes;
use http::{Method, Uri};
use http::Method;
use std::fmt::Debug;
use std::str::FromStr;
use url::Url;

/// An HTTP Body.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -36,29 +35,29 @@ impl From<Box<dyn SeekableStream>> for Body {
/// body. Policies are expected to enrich the request by mutating it.
#[derive(Debug, Clone)]
pub struct Request {
pub(crate) uri: Uri,
pub(crate) url: Url,
pub(crate) method: Method,
pub(crate) headers: Headers,
pub(crate) body: Body,
}

impl Request {
/// Create a new request with an empty body and no headers
pub fn new(uri: Uri, method: Method) -> Self {
pub fn new(url: Url, method: Method) -> Self {
Self {
uri,
url,
method,
headers: Headers::new(),
body: Body::Bytes(bytes::Bytes::new()),
}
}

pub fn uri(&self) -> &Uri {
&self.uri
pub fn url(&self) -> &Url {
&self.url
}

pub fn uri_mut(&mut self) -> &mut Uri {
&mut self.uri
pub fn url_mut(&mut self) -> &mut Url {
&mut self.url
}

pub fn method(&self) -> Method {
Expand Down Expand Up @@ -86,11 +85,6 @@ impl Request {
pub fn set_body(&mut self, body: impl Into<Body>) {
self.body = body.into();
}

/// Parse a `Uri` from a `str`
pub fn parse_uri(uri: &str) -> Result<Uri> {
Uri::from_str(uri).map_kind(ErrorKind::DataConversion)
}
}

/// Temporary hack to convert preexisting requests into the new format. It
Expand All @@ -99,7 +93,7 @@ impl From<http::Request<bytes::Bytes>> for Request {
fn from(request: http::Request<bytes::Bytes>) -> Self {
let (parts, body) = request.into_parts();
Self {
uri: parts.uri,
url: Url::parse(&parts.uri.to_string()).unwrap(),
method: parts.method,
headers: parts.headers.into(),
body: Body::Bytes(body),
Expand Down
8 changes: 6 additions & 2 deletions sdk/data_cosmos/src/authorization_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,15 @@ impl Policy for AuthorizationPolicy {

let time_nonce = TimeNonce::new();

let uri_path = &request.uri().path_and_query().unwrap().to_string()[1..];
let mut uri_path = request.url().path().to_owned();
if let Some(query) = request.url().query() {
uri_path.push('?');
uri_path.push_str(query);
}
trace!("uri_path used by AuthorizationPolicy == {:#?}", uri_path);

let auth = {
let resource_link = generate_resource_link(uri_path);
let resource_link = generate_resource_link(&uri_path);
trace!("resource_link == {}", resource_link);
generate_authorization(
&self.authorization_token,
Expand Down
4 changes: 2 additions & 2 deletions sdk/identity/src/client_credentials_flow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use azure_core::{
use http::Method;
use login_response::LoginResponse;
use std::sync::Arc;
use url::form_urlencoded;
use url::{form_urlencoded, Url};

/// Perform the client credentials flow
#[allow(clippy::manual_async_fn)]
Expand All @@ -66,7 +66,7 @@ pub async fn perform(
.append_pair("grant_type", "client_credentials")
.finish();

let url = Request::parse_uri(&format!(
let url = Url::parse(&format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
tenant_id
))
Expand Down
6 changes: 3 additions & 3 deletions sdk/identity/src/device_code_flow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod device_code_responses;
use async_timer::timer::new_timer;
use azure_core::{
content_type,
error::{Error, ErrorKind, Result},
error::{Error, ErrorKind, Result, ResultExt},
headers, HttpClient, Request, Response,
};
pub use device_code_responses::*;
Expand All @@ -17,7 +17,7 @@ use http::Method;
use oauth2::ClientId;
use serde::Deserialize;
use std::{borrow::Cow, sync::Arc, time::Duration};
use url::form_urlencoded;
use url::{form_urlencoded, Url};

/// Start the device authorization grant flow.
/// The user has only 15 minutes to sign in (the usual value for expires_in).
Expand Down Expand Up @@ -171,7 +171,7 @@ async fn post_form(
url: &str,
form_body: String,
) -> Result<Response> {
let url = Request::parse_uri(url)?;
let url = Url::parse(url).map_kind(ErrorKind::DataConversion)?;
let mut req = Request::new(url, Method::POST);
req.headers_mut().insert(
headers::CONTENT_TYPE,
Expand Down
4 changes: 2 additions & 2 deletions sdk/identity/src/refresh_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use oauth2::{AccessToken, ClientId, ClientSecret};
use serde::Deserialize;
use std::fmt;
use std::sync::Arc;
use url::form_urlencoded;
use url::{form_urlencoded, Url};

/// Exchange a refresh token for a new access token and refresh token
#[allow(clippy::manual_async_fn)]
Expand All @@ -34,7 +34,7 @@ pub async fn exchange(
let encoded = encoded.append_pair("refresh_token", refresh_token.secret());
let encoded = encoded.finish();

let url = Request::parse_uri(&format!(
let url = Url::parse(&format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
tenant_id
))?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ impl TokenCredential for ImdsManagedIdentityCredential {
"error parsing url for MSI endpoint",
)?;

let url = Request::parse_uri(url.as_str())?;
let mut req = Request::new(url, Method::GET);

req.headers_mut().insert("Metadata", "true");
Expand Down
10 changes: 8 additions & 2 deletions sdk/storage/src/account/operations/get_account_information.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,14 @@ impl GetAccountInformationBuilder {
.blob_storage_request("", http::Method::GET);

// TODO: add the query pairs
// request.uri_mut().query_pairs_mut().append_pair("restype", "account");
// request.uri_mut().query_pairs_mut().append_pair("comp", "properties");
request
.url_mut()
.query_pairs_mut()
.append_pair("restype", "account");
request
.url_mut()
.query_pairs_mut()
.append_pair("comp", "properties");

let response = self
.storage_client
Expand Down
95 changes: 27 additions & 68 deletions sdk/storage/src/authorization_policy.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use azure_core::error::{ErrorKind, ResultExt};
use azure_core::{headers::*, Context, Policy, PolicyResult, Request};
use http::header::AUTHORIZATION;
use http::{Method, Uri};
use http::Method;
use std::borrow::Cow;
use std::sync::Arc;
use url::Url;

use crate::clients::{ServiceType, StorageCredentials};

Expand Down Expand Up @@ -35,16 +37,10 @@ impl Policy for AuthorizationPolicy {
);
let request = match &self.credentials {
StorageCredentials::Key(account, key) => {
if !request
.uri()
.query()
.unwrap_or_default()
.split('&')
.any(|pair| matches!(pair.trim().split_once('='), Some(("sig", _))))
{
if !request.url().query_pairs().any(|(k, _)| &*k == "sig") {
let auth = generate_authorization(
request.headers(),
request.uri(),
request.url(),
&request.method(),
account,
key,
Expand All @@ -56,23 +52,11 @@ impl Policy for AuthorizationPolicy {
request
}
StorageCredentials::SASToken(query_pairs) => {
// TODO: switch to `url` crate.
// This is already very complex and we're not even url encoding
let query = request.uri().query();
let new = query_pairs
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect::<Vec<String>>()
.join("&");
let new = match query {
Some(existing) => format!("{existing}&{new}"),
None => format!("?{new}"),
};
let new = format!("{}{}", request.uri().path(), new);
let mut parts = request.uri().clone().into_parts();
parts.path_and_query =
Some(http::uri::PathAndQuery::from_maybe_shared(new).unwrap());
*request.uri_mut() = Uri::from_parts(parts).unwrap();
request
.url_mut()
.query_pairs_mut()
.extend_pairs(query_pairs);

request
}
StorageCredentials::BearerToken(token) => {
Expand Down Expand Up @@ -100,7 +84,7 @@ impl Policy for AuthorizationPolicy {

fn generate_authorization(
h: &Headers,
u: &Uri,
u: &Url,
method: &Method,
account: &str,
key: &str,
Expand All @@ -118,7 +102,7 @@ fn add_if_exists<'a>(h: &'a Headers, key: &'static str) -> &'a str {
#[allow(unknown_lints)]
fn string_to_sign(
h: &Headers,
u: &Uri,
u: &Url,
method: &Method,
account: &str,
service_type: &ServiceType,
Expand All @@ -135,7 +119,7 @@ fn string_to_sign(
)
}
_ => {
// content lenght must only be specified if != 0
// content length must only be specified if != 0
// this is valid from 2015-02-21
let content_length = h
.get(CONTENT_LENGTH)
Expand All @@ -160,81 +144,56 @@ fn string_to_sign(
)
}
}

// expected
// GET\n /*HTTP Verb*/
// \n /*Content-Encoding*/
// \n /*Content-Language*/
// \n /*Content-Length (include value when zero)*/
// \n /*Content-MD5*/
// \n /*Content-Type*/
// \n /*Date*/
// \n /*If-Modified-Since */
// \n /*If-Match*/
// \n /*If-None-Match*/
// \n /*If-Unmodified-Since*/
// \n /*Range*/
// x-ms-date:Sun, 11 Oct 2009 21:49:13 GMT\nx-ms-version:2009-09-19\n
// /*CanonicalizedHeaders*/
// /myaccount /mycontainer\ncomp:metadata\nrestype:container\ntimeout:20
// /*CanonicalizedResource*/
//
//
}

fn canonicalize_header(h: &Headers) -> String {
let mut v_headers = h
.iter()
.filter(|(k, _)| k.as_str().starts_with("x-ms"))
.map(|(k, _)| k.as_str().to_owned())
.map(|(k, _)| k)
.collect::<Vec<_>>();
v_headers.sort_unstable();

let mut can = String::new();

for header_name in v_headers {
let s = h.get(header_name.clone()).unwrap().as_str();
let header_name = header_name.as_str();
can = format!("{can}{header_name}:{s}\n");
}
can
}

fn canonicalized_resource_table(account: &str, u: &Uri) -> String {
fn canonicalized_resource_table(account: &str, u: &Url) -> String {
format!("/{}{}", account, u.path())
}

fn canonicalized_resource(account: &str, uri: &Uri) -> String {
fn canonicalized_resource(account: &str, uri: &Url) -> String {
let mut can_res: String = String::new();
can_res += "/";
can_res += account;

let path = uri.path();

for p in path.split('/') {
for p in uri.path_segments().into_iter().flatten() {
can_res.push('/');
can_res.push_str(&*p);
can_res.push_str(p);
}
can_res += "\n";

// query parameters
let query_pairs = uri
.query()
.unwrap_or_default()
.split('&')
.filter_map(|p| p.split_once('='));
let query_pairs = uri.query_pairs();
{
let mut qps = Vec::new();
for (q, _p) in query_pairs.clone() {
if !(qps.iter().any(|x| x == q)) {
qps.push(q.to_owned());
let mut qps: Vec<String> = Vec::new();
for (q, _) in query_pairs {
if !(qps.iter().any(|x| x == &*q)) {
qps.push(q.into_owned());
}
}

qps.sort();

for qparam in qps {
// find correct parameter
let ret = lexy_sort(query_pairs.clone(), &qparam);
let ret = lexy_sort(query_pairs, &qparam);

can_res = can_res + &qparam.to_lowercase() + ":";

Expand All @@ -253,9 +212,9 @@ fn canonicalized_resource(account: &str, uri: &Uri) -> String {
}

fn lexy_sort<'a>(
vec: impl Iterator<Item = (&'a str, &'a str)> + 'a,
vec: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
query_param: &str,
) -> Vec<&'a str> {
) -> Vec<Cow<'a, str>> {
let mut values = vec
.filter(|(k, _)| *k == query_param)
.map(|(_, v)| v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl Policy for SharedKeyAuthorizationPolicy {
HeaderValue::from_str("2019-12-12")?,
); // TODO: Remove duplication with storage_account_client.rs

let url = url::Url::parse(&request.uri().to_string()).unwrap();
let url = url::Url::parse(&request.url().to_string()).unwrap();
let auth = generate_authorization(
request.headers(),
&url,
Expand Down

0 comments on commit 81c7b46

Please sign in to comment.