Skip to content

Commit 35c63c7

Browse files
committed
new: added unit tests for http namespace
1 parent 02121cf commit 35c63c7

File tree

2 files changed

+221
-15
lines changed

2 files changed

+221
-15
lines changed

src/agent/namespaces/http/mod.rs

+216-15
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use crate::agent::state::SharedState;
1313

1414
use super::{Action, Namespace, StorageDescriptor};
1515

16+
const DEFAULT_HTTP_SCHEMA: &str = "https";
17+
1618
#[derive(Debug, Default, Clone)]
1719
struct ClearHeaders {}
1820

@@ -97,7 +99,7 @@ impl Request {
9799

98100
// add schema if not present
99101
if !http_target.contains("://") {
100-
http_target = format!("http://{http_target}");
102+
http_target = format!("{DEFAULT_HTTP_SCHEMA}://{http_target}");
101103
}
102104

103105
Url::parse(&http_target)
@@ -148,6 +150,23 @@ impl Request {
148150

149151
Ok((reason.to_string(), resp))
150152
}
153+
154+
fn create_request(method: &str, target_url: Url) -> Result<reqwest::RequestBuilder> {
155+
let method = reqwest::Method::from_str(method)?;
156+
let mut request = reqwest::Client::new().request(method.clone(), target_url.clone());
157+
let query_str = target_url.query().unwrap_or("").to_string();
158+
159+
// if there're parameters and we're not in GET, set them as the body
160+
if !query_str.is_empty() && !matches!(method, reqwest::Method::GET) {
161+
request = request.header(
162+
reqwest::header::CONTENT_TYPE,
163+
"application/x-www-form-urlencoded",
164+
);
165+
request = request.body(query_str);
166+
}
167+
168+
Ok(request)
169+
}
151170
}
152171

153172
#[async_trait]
@@ -188,33 +207,23 @@ impl Action for Request {
188207
) -> Result<Option<String>> {
189208
// create a parsed Url from the attributes, payload and HTTP_TARGET variable
190209
let attrs = attrs.unwrap();
191-
let method = reqwest::Method::from_str(attrs.get("method").unwrap())?;
210+
let method = attrs.get("method").unwrap();
192211
let target_url = Self::create_target_url_from(&state, payload.clone()).await?;
193-
let query_str = target_url.query().unwrap_or("").to_string();
212+
let target_url_str = target_url.to_string();
213+
let mut request = Self::create_request(method, target_url)?;
194214

195215
// TODO: handle cookie/session persistency
196216

197-
let mut request = reqwest::Client::new().request(method.clone(), target_url.clone());
198-
199217
// add defined headers
200218
for (key, value) in state.lock().await.get_storage("http-headers")?.iter() {
201219
request = request.header(key, &value.data);
202220
}
203221

204-
// if there're parameters and we're not in GET, set them as the body
205-
if !query_str.is_empty() && !matches!(method, reqwest::Method::GET) {
206-
request = request.header(
207-
reqwest::header::CONTENT_TYPE,
208-
"application/x-www-form-urlencoded",
209-
);
210-
request = request.body(query_str);
211-
}
212-
213222
log::info!(
214223
"{}.{} {} ...",
215224
"http".bold(),
216225
method.to_string().yellow(),
217-
target_url.to_string(),
226+
target_url_str,
218227
);
219228

220229
// perform the request
@@ -262,3 +271,195 @@ pub(crate) fn get_namespace() -> Namespace {
262271
]),
263272
)
264273
}
274+
275+
#[cfg(test)]
276+
mod tests {
277+
use std::sync::Arc;
278+
279+
use crate::agent::state::State;
280+
281+
use super::*;
282+
283+
#[derive(Debug)]
284+
struct TestTask {}
285+
286+
impl crate::agent::task::Task for TestTask {
287+
fn to_system_prompt(&self) -> Result<String> {
288+
Ok("test".to_string())
289+
}
290+
291+
fn to_prompt(&self) -> Result<String> {
292+
Ok("test".to_string())
293+
}
294+
295+
fn get_functions(&self) -> Vec<Namespace> {
296+
vec![]
297+
}
298+
}
299+
300+
struct TestEmbedder {}
301+
302+
#[async_trait]
303+
impl mini_rag::Embedder for TestEmbedder {
304+
async fn embed(&self, _text: &str) -> Result<mini_rag::Embeddings> {
305+
todo!()
306+
}
307+
}
308+
309+
#[allow(unused_variables)]
310+
async fn create_test_state(vars: Vec<(String, String)>) -> Result<SharedState> {
311+
let (tx, _rx) = crate::agent::events::create_channel();
312+
313+
let task = Box::new(TestTask {});
314+
let embedder = Box::new(TestEmbedder {});
315+
316+
let mut state = State::new(tx, task, embedder, 10).await?;
317+
318+
for (name, value) in vars {
319+
state.set_variable(name, value);
320+
}
321+
322+
Ok(Arc::new(tokio::sync::Mutex::new(state)))
323+
}
324+
325+
#[tokio::test]
326+
async fn test_parse_no_target() {
327+
let state = create_test_state(vec![]).await.unwrap();
328+
let payload = Some("/".to_string());
329+
let target_url = Request::create_target_url_from(&state, payload.clone()).await;
330+
331+
assert!(target_url.is_err());
332+
}
333+
334+
#[tokio::test]
335+
async fn test_parse_simple_get_without_schema() {
336+
let state = create_test_state(vec![(
337+
"HTTP_TARGET".to_string(),
338+
"www.example.com".to_string(),
339+
)])
340+
.await
341+
.unwrap();
342+
343+
let payload = Some("/".to_string());
344+
let target_url = Request::create_target_url_from(&state, payload.clone())
345+
.await
346+
.unwrap();
347+
348+
assert_eq!(
349+
target_url.to_string(),
350+
format!("{DEFAULT_HTTP_SCHEMA}://www.example.com/")
351+
);
352+
}
353+
354+
#[tokio::test]
355+
async fn test_parse_simple_get_with_schema() {
356+
let state = create_test_state(vec![(
357+
"HTTP_TARGET".to_string(),
358+
"ftp://www.example.com".to_string(),
359+
)])
360+
.await
361+
.unwrap();
362+
363+
let payload = Some("/".to_string());
364+
let target_url = Request::create_target_url_from(&state, payload.clone())
365+
.await
366+
.unwrap();
367+
368+
assert_eq!(target_url.to_string(), format!("ftp://www.example.com/"));
369+
}
370+
371+
#[tokio::test]
372+
async fn test_parse_simple_get_with_schema_and_port() {
373+
let state = create_test_state(vec![(
374+
"HTTP_TARGET".to_string(),
375+
"ftp://www.example.com:1012".to_string(),
376+
)])
377+
.await
378+
.unwrap();
379+
380+
let payload = Some("/".to_string());
381+
let target_url = Request::create_target_url_from(&state, payload.clone())
382+
.await
383+
.unwrap();
384+
385+
assert_eq!(
386+
target_url.to_string(),
387+
format!("ftp://www.example.com:1012/")
388+
);
389+
}
390+
391+
#[tokio::test]
392+
async fn test_parse_query_string() {
393+
let state = create_test_state(vec![(
394+
"HTTP_TARGET".to_string(),
395+
"www.example.com".to_string(),
396+
)])
397+
.await
398+
.unwrap();
399+
400+
let payload = Some("/index.php?id=1&name=foo".to_string());
401+
let target_url = Request::create_target_url_from(&state, payload.clone())
402+
.await
403+
.unwrap();
404+
405+
assert_eq!(
406+
target_url.to_string(),
407+
format!("{DEFAULT_HTTP_SCHEMA}://www.example.com/index.php?id=1&name=foo")
408+
);
409+
}
410+
411+
#[tokio::test]
412+
async fn test_parse_query_string_is_escaped() {
413+
let state = create_test_state(vec![(
414+
"HTTP_TARGET".to_string(),
415+
"www.example.com".to_string(),
416+
)])
417+
.await
418+
.unwrap();
419+
420+
let payload = Some("/index.php?id=1&name=foo' or ''='".to_string());
421+
let target_url = Request::create_target_url_from(&state, payload.clone())
422+
.await
423+
.unwrap();
424+
425+
assert_eq!(
426+
target_url.to_string(),
427+
format!("{DEFAULT_HTTP_SCHEMA}://www.example.com/index.php?id=1&name=foo%27%20or%20%27%27=%27")
428+
);
429+
}
430+
#[tokio::test]
431+
async fn test_parse_body_post() {
432+
let state = create_test_state(vec![(
433+
"HTTP_TARGET".to_string(),
434+
"www.example.com".to_string(),
435+
)])
436+
.await
437+
.unwrap();
438+
439+
let method = "POST";
440+
let payload = Some("/login.php?user=admin&pass=' OR ''='".to_string());
441+
let target_url = Request::create_target_url_from(&state, payload.clone())
442+
.await
443+
.unwrap();
444+
let expected_body_string = "user=admin&pass=%27%20OR%20%27%27=%27".to_string();
445+
let expected_target_url_string = format!(
446+
"{DEFAULT_HTTP_SCHEMA}://www.example.com/login.php?{}",
447+
expected_body_string
448+
);
449+
450+
assert_eq!(target_url.to_string(), expected_target_url_string);
451+
452+
let request = Request::create_request(method, target_url)
453+
.unwrap()
454+
.build()
455+
.unwrap();
456+
457+
assert_eq!(request.method().to_string(), method.to_string());
458+
assert_eq!(request.url().to_string(), expected_target_url_string);
459+
assert!(request.body().is_some());
460+
assert_eq!(
461+
request.body().unwrap().as_bytes(),
462+
Some(expected_body_string.as_bytes())
463+
);
464+
}
465+
}

src/agent/state/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ impl State {
202202
self.variables.get(name)
203203
}
204204

205+
#[allow(dead_code)]
206+
pub fn set_variable(&mut self, name: String, value: String) {
207+
self.variables.insert(name, value);
208+
}
209+
205210
pub fn get_storages(&self) -> Vec<&Storage> {
206211
self.storages.values().collect()
207212
}

0 commit comments

Comments
 (0)