diff --git a/crates/blockless-drivers/src/http_driver/reqwest_driver.rs b/crates/blockless-drivers/src/http_driver/reqwest_driver.rs index 92dbcad..0117479 100644 --- a/crates/blockless-drivers/src/http_driver/reqwest_driver.rs +++ b/crates/blockless-drivers/src/http_driver/reqwest_driver.rs @@ -1,22 +1,21 @@ -use std::{collections::HashMap, sync::Once, time::Duration, pin::Pin}; +use std::{collections::HashMap, pin::Pin, sync::Once, time::Duration}; -use bytes::{Bytes, Buf}; +use bytes::{Buf, Bytes}; use futures_util::StreamExt; -use log::{error, debug}; +use log::{debug, error}; use reqwest::Response; use crate::HttpErrorKind; use futures_core; use futures_core::Stream; -type StreamInBox = Pin> + Send >>; +type StreamInBox = Pin> + Send>>; struct StreamState { stream: StreamInBox, buffer: Option, } - enum HttpCtx { Response(Response), StreamState(StreamState), @@ -26,14 +25,10 @@ enum HttpCtx { fn get_ctx() -> Option<&'static mut HashMap> { static mut CTX: Option> = None; static CTX_ONCE: Once = Once::new(); - CTX_ONCE.call_once(||{ - unsafe { - CTX = Some(HashMap::new()); - } + CTX_ONCE.call_once(|| unsafe { + CTX = Some(HashMap::new()); }); - unsafe { - CTX.as_mut() - } + unsafe { CTX.as_mut() } } fn increase_fd() -> Option { @@ -45,10 +40,7 @@ fn increase_fd() -> Option { } /// request the url and the return the fd handle. -pub(crate) async fn http_req( - url: &str, - opts: &str, -) -> Result<(u32, i32), HttpErrorKind> { +pub(crate) async fn http_req(url: &str, opts: &str) -> Result<(u32, i32), HttpErrorKind> { let json = match json::parse(opts) { Ok(o) => o, Err(_) => return Err(HttpErrorKind::RequestError), @@ -66,14 +58,12 @@ pub(crate) async fn http_req( let connect_timeout = json["connectTimeout"] .as_u64() .map(|s| Duration::from_secs(s)); - let read_timeout = json["readTimeout"] - .as_u64() - .map(|s| Duration::from_secs(s)); + let read_timeout = json["readTimeout"].as_u64().map(|s| Duration::from_secs(s)); - // build the headers from the options json + // build the headers from the options json let mut headers = reqwest::header::HeaderMap::new(); let header_value = &json["headers"]; - + // Check if header_value is a valid string let header_obj = match json::parse(header_value.as_str().unwrap_or_default()) { Ok(o) => o, @@ -89,10 +79,11 @@ pub(crate) async fn http_req( }; // Handle possible errors from from_str - let header_value = match reqwest::header::HeaderValue::from_str(value.as_str().unwrap_or_default()) { - Ok(value) => value, - Err(_) => return Err(HttpErrorKind::HeadersValidationError), - }; + let header_value = + match reqwest::header::HeaderValue::from_str(value.as_str().unwrap_or_default()) { + Ok(value) => value, + Err(_) => return Err(HttpErrorKind::HeadersValidationError), + }; headers.insert(header_name, header_value); } @@ -129,10 +120,7 @@ pub(crate) async fn http_req( } /// read from handle -pub(crate) fn http_read_head( - fd: u32, - head: &str, -) -> Result { +pub(crate) fn http_read_head(fd: u32, head: &str) -> Result { let ctx = get_ctx().unwrap(); let respone = match ctx.get_mut(&fd) { Some(HttpCtx::Response(ref h)) => h, @@ -141,13 +129,11 @@ pub(crate) fn http_read_head( }; let headers = respone.headers(); match headers.get(head) { - Some(h) => { - match h.to_str() { - Ok(s) => Ok(s.into()), - Err(_) => Err(HttpErrorKind::InvalidEncoding), - } - } - None => Err(HttpErrorKind::HeaderNotFound) + Some(h) => match h.to_str() { + Ok(s) => Ok(s.into()), + Err(_) => Err(HttpErrorKind::InvalidEncoding), + }, + None => Err(HttpErrorKind::HeaderNotFound), } } @@ -169,7 +155,7 @@ async fn stream_read(state: &mut StreamState, dest: &mut [u8]) -> usize { loop { match state.buffer { Some(ref mut buffer) => { - let n = read_call(buffer, &mut dest[readn..]); + let n = read_call(buffer, &mut dest[readn..]); if n + readn <= dest.len() { readn += n; } @@ -195,7 +181,7 @@ async fn stream_read(state: &mut StreamState, dest: &mut [u8]) -> usize { } if readn + n < dest.len() { readn += n; - } else if n + readn == dest.len() { + } else if n + readn == dest.len() { return readn + n; } else { unreachable!("can't be happend!"); @@ -205,10 +191,7 @@ async fn stream_read(state: &mut StreamState, dest: &mut [u8]) -> usize { } } -pub async fn http_read_body( - fd: u32, - buf: &mut [u8], -) -> Result { +pub async fn http_read_body(fd: u32, buf: &mut [u8]) -> Result { let ctx = get_ctx().unwrap(); match ctx.remove(&fd) { Some(HttpCtx::Response(resp)) => { @@ -242,23 +225,25 @@ pub(crate) fn http_close(fd: u32) -> Result<(), HttpErrorKind> { #[cfg(test)] mod test { use super::*; - use bytes::BytesMut; - use tokio::runtime::{Builder, Runtime}; - use std::task::Poll; use crate::error::HttpErrorKind; - use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; + use bytes::BytesMut; use json::JsonValue; + use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; + use std::task::Poll; + use tokio::runtime::{Builder, Runtime}; struct TestStream(Vec); - + impl Stream for TestStream { type Item = reqwest::Result; - fn poll_next(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll> { + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { let s = self.get_mut().0.pop().map(|s| Ok(s)); Poll::Ready(s) } - } fn build_headers(json_str: &str) -> Result { @@ -266,19 +251,19 @@ mod test { Ok(json) => json, Err(_) => return Err(HttpErrorKind::HeadersValidationError), }; - + let headers_value = match &parsed_json["headers"] { JsonValue::Object(obj) => obj, _ => return Err(HttpErrorKind::HeadersValidationError), }; - + let mut headers = HeaderMap::new(); for (key, value) in headers_value.iter() { let header_name = match HeaderName::from_bytes(key.as_bytes()) { Ok(name) => name, Err(_) => return Err(HttpErrorKind::HeadersValidationError), }; - + let header_value = match value.as_str() { Some(val) => match HeaderValue::from_str(val) { Ok(value) => value, @@ -286,26 +271,22 @@ mod test { }, None => return Err(HttpErrorKind::HeadersValidationError), }; - + headers.insert(header_name, header_value); } - + Ok(headers) } fn get_runtime() -> Runtime { - let rt = Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); + let rt = Builder::new_current_thread().enable_all().build().unwrap(); return rt; } - - // Test for valid headers - #[test] - fn test_valid_headers() { - let json_str = r#" + // Test for valid headers + #[test] + fn test_valid_headers() { + let json_str = r#" { "headers": { "Content-Type": "application/json", @@ -314,12 +295,12 @@ mod test { } "#; - let result = build_headers(json_str); - assert!(result.is_ok()); - let headers = result.unwrap(); - assert_eq!(headers.get("Content-Type").unwrap(), "application/json"); - assert_eq!(headers.get("Authorization").unwrap(), "Bearer token"); - } + let result = build_headers(json_str); + assert!(result.is_ok()); + let headers = result.unwrap(); + assert_eq!(headers.get("Content-Type").unwrap(), "application/json"); + assert_eq!(headers.get("Authorization").unwrap(), "Bearer token"); + } // Test for invalid JSON headers #[test] @@ -388,7 +369,7 @@ mod test { fn test_stream_read_2step() { let rt = get_runtime(); rt.block_on(async move { - let data: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12]; + let data: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; let bytes = BytesMut::from(data); let mut state = StreamState { stream: Box::pin(TestStream(vec![bytes.freeze()])), @@ -410,8 +391,8 @@ mod test { fn test_stream_read_3step() { let rt = get_runtime(); rt.block_on(async move { - let data: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12]; - let data2: &[u8] = &[13, 14,15,16]; + let data: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let data2: &[u8] = &[13, 14, 15, 16]; let mut state = StreamState { stream: Box::pin(TestStream(vec![Bytes::from(data2), Bytes::from(data)])), buffer: None, @@ -430,4 +411,4 @@ mod test { assert!(src == dest); }); } -} \ No newline at end of file +}