Skip to content

Commit

Permalink
fmt the code
Browse files Browse the repository at this point in the history
  • Loading branch information
Joinhack committed Jun 16, 2024
1 parent 9ea1c2a commit ad0a1e2
Showing 1 changed file with 54 additions and 73 deletions.
127 changes: 54 additions & 73 deletions crates/blockless-drivers/src/http_driver/reqwest_driver.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
use std::{collections::HashMap, sync::Once, time::Duration, pin::Pin};
use std::{collections::HashMap, pin::Pin, sync::Once, time::Duration};

use bytes::{Bytes, Buf};
use bytes::{Buf, Bytes};
use futures_util::StreamExt;
use log::{error, debug};
use log::{debug, error};
use reqwest::Response;

use crate::HttpErrorKind;
use futures_core;
use futures_core::Stream;

type StreamInBox = Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send >>;
type StreamInBox = Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send>>;

struct StreamState {
stream: StreamInBox,
buffer: Option<Bytes>,
}


enum HttpCtx {
Response(Response),
StreamState(StreamState),
Expand All @@ -26,14 +25,10 @@ enum HttpCtx {
fn get_ctx() -> Option<&'static mut HashMap<u32, HttpCtx>> {
static mut CTX: Option<HashMap<u32, HttpCtx>> = None;
static CTX_ONCE: Once = Once::new();
CTX_ONCE.call_once(||{
unsafe {
CTX = Some(HashMap::new());
}
CTX_ONCE.call_once(|| unsafe {
CTX = Some(HashMap::new());
});
unsafe {
CTX.as_mut()
}
unsafe { CTX.as_mut() }
}

fn increase_fd() -> Option<u32> {
Expand All @@ -45,10 +40,7 @@ fn increase_fd() -> Option<u32> {
}

/// request the url and the return the fd handle.
pub(crate) async fn http_req(
url: &str,
opts: &str,
) -> Result<(u32, i32), HttpErrorKind> {
pub(crate) async fn http_req(url: &str, opts: &str) -> Result<(u32, i32), HttpErrorKind> {
let json = match json::parse(opts) {
Ok(o) => o,
Err(_) => return Err(HttpErrorKind::RequestError),
Expand All @@ -66,14 +58,12 @@ pub(crate) async fn http_req(
let connect_timeout = json["connectTimeout"]
.as_u64()
.map(|s| Duration::from_secs(s));
let read_timeout = json["readTimeout"]
.as_u64()
.map(|s| Duration::from_secs(s));
let read_timeout = json["readTimeout"].as_u64().map(|s| Duration::from_secs(s));

// build the headers from the options json
// build the headers from the options json
let mut headers = reqwest::header::HeaderMap::new();
let header_value = &json["headers"];

// Check if header_value is a valid string
let header_obj = match json::parse(header_value.as_str().unwrap_or_default()) {
Ok(o) => o,
Expand All @@ -89,10 +79,11 @@ pub(crate) async fn http_req(
};

// Handle possible errors from from_str
let header_value = match reqwest::header::HeaderValue::from_str(value.as_str().unwrap_or_default()) {
Ok(value) => value,
Err(_) => return Err(HttpErrorKind::HeadersValidationError),
};
let header_value =
match reqwest::header::HeaderValue::from_str(value.as_str().unwrap_or_default()) {
Ok(value) => value,
Err(_) => return Err(HttpErrorKind::HeadersValidationError),
};

headers.insert(header_name, header_value);
}
Expand Down Expand Up @@ -129,10 +120,7 @@ pub(crate) async fn http_req(
}

/// read from handle
pub(crate) fn http_read_head(
fd: u32,
head: &str,
) -> Result<String, HttpErrorKind> {
pub(crate) fn http_read_head(fd: u32, head: &str) -> Result<String, HttpErrorKind> {
let ctx = get_ctx().unwrap();
let respone = match ctx.get_mut(&fd) {
Some(HttpCtx::Response(ref h)) => h,
Expand All @@ -141,13 +129,11 @@ pub(crate) fn http_read_head(
};
let headers = respone.headers();
match headers.get(head) {
Some(h) => {
match h.to_str() {
Ok(s) => Ok(s.into()),
Err(_) => Err(HttpErrorKind::InvalidEncoding),
}
}
None => Err(HttpErrorKind::HeaderNotFound)
Some(h) => match h.to_str() {
Ok(s) => Ok(s.into()),
Err(_) => Err(HttpErrorKind::InvalidEncoding),
},
None => Err(HttpErrorKind::HeaderNotFound),
}
}

Expand All @@ -169,7 +155,7 @@ async fn stream_read(state: &mut StreamState, dest: &mut [u8]) -> usize {
loop {
match state.buffer {
Some(ref mut buffer) => {
let n = read_call(buffer, &mut dest[readn..]);
let n = read_call(buffer, &mut dest[readn..]);
if n + readn <= dest.len() {
readn += n;
}
Expand All @@ -195,7 +181,7 @@ async fn stream_read(state: &mut StreamState, dest: &mut [u8]) -> usize {
}
if readn + n < dest.len() {
readn += n;
} else if n + readn == dest.len() {
} else if n + readn == dest.len() {
return readn + n;
} else {
unreachable!("can't be happend!");
Expand All @@ -205,10 +191,7 @@ async fn stream_read(state: &mut StreamState, dest: &mut [u8]) -> usize {
}
}

pub async fn http_read_body(
fd: u32,
buf: &mut [u8],
) -> Result<u32, HttpErrorKind> {
pub async fn http_read_body(fd: u32, buf: &mut [u8]) -> Result<u32, HttpErrorKind> {
let ctx = get_ctx().unwrap();
match ctx.remove(&fd) {
Some(HttpCtx::Response(resp)) => {
Expand Down Expand Up @@ -242,70 +225,68 @@ pub(crate) fn http_close(fd: u32) -> Result<(), HttpErrorKind> {
#[cfg(test)]
mod test {
use super::*;
use bytes::BytesMut;
use tokio::runtime::{Builder, Runtime};
use std::task::Poll;
use crate::error::HttpErrorKind;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use bytes::BytesMut;
use json::JsonValue;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use std::task::Poll;
use tokio::runtime::{Builder, Runtime};

struct TestStream(Vec<Bytes>);

impl Stream for TestStream {
type Item = reqwest::Result<Bytes>;

fn poll_next(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let s = self.get_mut().0.pop().map(|s| Ok(s));
Poll::Ready(s)
}

}

fn build_headers(json_str: &str) -> Result<HeaderMap, HttpErrorKind> {
let parsed_json = match json::parse(json_str) {
Ok(json) => json,
Err(_) => return Err(HttpErrorKind::HeadersValidationError),
};

let headers_value = match &parsed_json["headers"] {
JsonValue::Object(obj) => obj,
_ => return Err(HttpErrorKind::HeadersValidationError),
};

let mut headers = HeaderMap::new();
for (key, value) in headers_value.iter() {
let header_name = match HeaderName::from_bytes(key.as_bytes()) {
Ok(name) => name,
Err(_) => return Err(HttpErrorKind::HeadersValidationError),
};

let header_value = match value.as_str() {
Some(val) => match HeaderValue::from_str(val) {
Ok(value) => value,
Err(_) => return Err(HttpErrorKind::HeadersValidationError),
},
None => return Err(HttpErrorKind::HeadersValidationError),
};

headers.insert(header_name, header_value);
}

Ok(headers)
}

fn get_runtime() -> Runtime {
let rt = Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let rt = Builder::new_current_thread().enable_all().build().unwrap();
return rt;
}


// Test for valid headers
#[test]
fn test_valid_headers() {
let json_str = r#"
// Test for valid headers
#[test]
fn test_valid_headers() {
let json_str = r#"
{
"headers": {
"Content-Type": "application/json",
Expand All @@ -314,12 +295,12 @@ mod test {
}
"#;

let result = build_headers(json_str);
assert!(result.is_ok());
let headers = result.unwrap();
assert_eq!(headers.get("Content-Type").unwrap(), "application/json");
assert_eq!(headers.get("Authorization").unwrap(), "Bearer token");
}
let result = build_headers(json_str);
assert!(result.is_ok());
let headers = result.unwrap();
assert_eq!(headers.get("Content-Type").unwrap(), "application/json");
assert_eq!(headers.get("Authorization").unwrap(), "Bearer token");
}

// Test for invalid JSON headers
#[test]
Expand Down Expand Up @@ -388,7 +369,7 @@ mod test {
fn test_stream_read_2step() {
let rt = get_runtime();
rt.block_on(async move {
let data: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12];
let data: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let bytes = BytesMut::from(data);
let mut state = StreamState {
stream: Box::pin(TestStream(vec![bytes.freeze()])),
Expand All @@ -410,8 +391,8 @@ mod test {
fn test_stream_read_3step() {
let rt = get_runtime();
rt.block_on(async move {
let data: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12];
let data2: &[u8] = &[13, 14,15,16];
let data: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let data2: &[u8] = &[13, 14, 15, 16];
let mut state = StreamState {
stream: Box::pin(TestStream(vec![Bytes::from(data2), Bytes::from(data)])),
buffer: None,
Expand All @@ -430,4 +411,4 @@ mod test {
assert!(src == dest);
});
}
}
}

0 comments on commit ad0a1e2

Please sign in to comment.