@@ -13,6 +13,8 @@ use crate::agent::state::SharedState;
13
13
14
14
use super :: { Action , Namespace , StorageDescriptor } ;
15
15
16
+ const DEFAULT_HTTP_SCHEMA : & str = "https" ;
17
+
16
18
#[ derive( Debug , Default , Clone ) ]
17
19
struct ClearHeaders { }
18
20
@@ -97,7 +99,7 @@ impl Request {
97
99
98
100
// add schema if not present
99
101
if !http_target. contains ( "://" ) {
100
- http_target = format ! ( "http ://{http_target}" ) ;
102
+ http_target = format ! ( "{DEFAULT_HTTP_SCHEMA} ://{http_target}" ) ;
101
103
}
102
104
103
105
Url :: parse ( & http_target)
@@ -148,6 +150,23 @@ impl Request {
148
150
149
151
Ok ( ( reason. to_string ( ) , resp) )
150
152
}
153
+
154
+ fn create_request ( method : & str , target_url : Url ) -> Result < reqwest:: RequestBuilder > {
155
+ let method = reqwest:: Method :: from_str ( method) ?;
156
+ let mut request = reqwest:: Client :: new ( ) . request ( method. clone ( ) , target_url. clone ( ) ) ;
157
+ let query_str = target_url. query ( ) . unwrap_or ( "" ) . to_string ( ) ;
158
+
159
+ // if there're parameters and we're not in GET, set them as the body
160
+ if !query_str. is_empty ( ) && !matches ! ( method, reqwest:: Method :: GET ) {
161
+ request = request. header (
162
+ reqwest:: header:: CONTENT_TYPE ,
163
+ "application/x-www-form-urlencoded" ,
164
+ ) ;
165
+ request = request. body ( query_str) ;
166
+ }
167
+
168
+ Ok ( request)
169
+ }
151
170
}
152
171
153
172
#[ async_trait]
@@ -188,33 +207,23 @@ impl Action for Request {
188
207
) -> Result < Option < String > > {
189
208
// create a parsed Url from the attributes, payload and HTTP_TARGET variable
190
209
let attrs = attrs. unwrap ( ) ;
191
- let method = reqwest :: Method :: from_str ( attrs. get ( "method" ) . unwrap ( ) ) ? ;
210
+ let method = attrs. get ( "method" ) . unwrap ( ) ;
192
211
let target_url = Self :: create_target_url_from ( & state, payload. clone ( ) ) . await ?;
193
- let query_str = target_url. query ( ) . unwrap_or ( "" ) . to_string ( ) ;
212
+ let target_url_str = target_url. to_string ( ) ;
213
+ let mut request = Self :: create_request ( method, target_url) ?;
194
214
195
215
// TODO: handle cookie/session persistency
196
216
197
- let mut request = reqwest:: Client :: new ( ) . request ( method. clone ( ) , target_url. clone ( ) ) ;
198
-
199
217
// add defined headers
200
218
for ( key, value) in state. lock ( ) . await . get_storage ( "http-headers" ) ?. iter ( ) {
201
219
request = request. header ( key, & value. data ) ;
202
220
}
203
221
204
- // if there're parameters and we're not in GET, set them as the body
205
- if !query_str. is_empty ( ) && !matches ! ( method, reqwest:: Method :: GET ) {
206
- request = request. header (
207
- reqwest:: header:: CONTENT_TYPE ,
208
- "application/x-www-form-urlencoded" ,
209
- ) ;
210
- request = request. body ( query_str) ;
211
- }
212
-
213
222
log:: info!(
214
223
"{}.{} {} ..." ,
215
224
"http" . bold( ) ,
216
225
method. to_string( ) . yellow( ) ,
217
- target_url . to_string ( ) ,
226
+ target_url_str ,
218
227
) ;
219
228
220
229
// perform the request
@@ -262,3 +271,195 @@ pub(crate) fn get_namespace() -> Namespace {
262
271
] ) ,
263
272
)
264
273
}
274
+
275
+ #[ cfg( test) ]
276
+ mod tests {
277
+ use std:: sync:: Arc ;
278
+
279
+ use crate :: agent:: state:: State ;
280
+
281
+ use super :: * ;
282
+
283
+ #[ derive( Debug ) ]
284
+ struct TestTask { }
285
+
286
+ impl crate :: agent:: task:: Task for TestTask {
287
+ fn to_system_prompt ( & self ) -> Result < String > {
288
+ Ok ( "test" . to_string ( ) )
289
+ }
290
+
291
+ fn to_prompt ( & self ) -> Result < String > {
292
+ Ok ( "test" . to_string ( ) )
293
+ }
294
+
295
+ fn get_functions ( & self ) -> Vec < Namespace > {
296
+ vec ! [ ]
297
+ }
298
+ }
299
+
300
+ struct TestEmbedder { }
301
+
302
+ #[ async_trait]
303
+ impl mini_rag:: Embedder for TestEmbedder {
304
+ async fn embed ( & self , _text : & str ) -> Result < mini_rag:: Embeddings > {
305
+ todo ! ( )
306
+ }
307
+ }
308
+
309
+ #[ allow( unused_variables) ]
310
+ async fn create_test_state ( vars : Vec < ( String , String ) > ) -> Result < SharedState > {
311
+ let ( tx, _rx) = crate :: agent:: events:: create_channel ( ) ;
312
+
313
+ let task = Box :: new ( TestTask { } ) ;
314
+ let embedder = Box :: new ( TestEmbedder { } ) ;
315
+
316
+ let mut state = State :: new ( tx, task, embedder, 10 ) . await ?;
317
+
318
+ for ( name, value) in vars {
319
+ state. set_variable ( name, value) ;
320
+ }
321
+
322
+ Ok ( Arc :: new ( tokio:: sync:: Mutex :: new ( state) ) )
323
+ }
324
+
325
+ #[ tokio:: test]
326
+ async fn test_parse_no_target ( ) {
327
+ let state = create_test_state ( vec ! [ ] ) . await . unwrap ( ) ;
328
+ let payload = Some ( "/" . to_string ( ) ) ;
329
+ let target_url = Request :: create_target_url_from ( & state, payload. clone ( ) ) . await ;
330
+
331
+ assert ! ( target_url. is_err( ) ) ;
332
+ }
333
+
334
+ #[ tokio:: test]
335
+ async fn test_parse_simple_get_without_schema ( ) {
336
+ let state = create_test_state ( vec ! [ (
337
+ "HTTP_TARGET" . to_string( ) ,
338
+ "www.example.com" . to_string( ) ,
339
+ ) ] )
340
+ . await
341
+ . unwrap ( ) ;
342
+
343
+ let payload = Some ( "/" . to_string ( ) ) ;
344
+ let target_url = Request :: create_target_url_from ( & state, payload. clone ( ) )
345
+ . await
346
+ . unwrap ( ) ;
347
+
348
+ assert_eq ! (
349
+ target_url. to_string( ) ,
350
+ format!( "{DEFAULT_HTTP_SCHEMA}://www.example.com/" )
351
+ ) ;
352
+ }
353
+
354
+ #[ tokio:: test]
355
+ async fn test_parse_simple_get_with_schema ( ) {
356
+ let state = create_test_state ( vec ! [ (
357
+ "HTTP_TARGET" . to_string( ) ,
358
+ "ftp://www.example.com" . to_string( ) ,
359
+ ) ] )
360
+ . await
361
+ . unwrap ( ) ;
362
+
363
+ let payload = Some ( "/" . to_string ( ) ) ;
364
+ let target_url = Request :: create_target_url_from ( & state, payload. clone ( ) )
365
+ . await
366
+ . unwrap ( ) ;
367
+
368
+ assert_eq ! ( target_url. to_string( ) , format!( "ftp://www.example.com/" ) ) ;
369
+ }
370
+
371
+ #[ tokio:: test]
372
+ async fn test_parse_simple_get_with_schema_and_port ( ) {
373
+ let state = create_test_state ( vec ! [ (
374
+ "HTTP_TARGET" . to_string( ) ,
375
+ "ftp://www.example.com:1012" . to_string( ) ,
376
+ ) ] )
377
+ . await
378
+ . unwrap ( ) ;
379
+
380
+ let payload = Some ( "/" . to_string ( ) ) ;
381
+ let target_url = Request :: create_target_url_from ( & state, payload. clone ( ) )
382
+ . await
383
+ . unwrap ( ) ;
384
+
385
+ assert_eq ! (
386
+ target_url. to_string( ) ,
387
+ format!( "ftp://www.example.com:1012/" )
388
+ ) ;
389
+ }
390
+
391
+ #[ tokio:: test]
392
+ async fn test_parse_query_string ( ) {
393
+ let state = create_test_state ( vec ! [ (
394
+ "HTTP_TARGET" . to_string( ) ,
395
+ "www.example.com" . to_string( ) ,
396
+ ) ] )
397
+ . await
398
+ . unwrap ( ) ;
399
+
400
+ let payload = Some ( "/index.php?id=1&name=foo" . to_string ( ) ) ;
401
+ let target_url = Request :: create_target_url_from ( & state, payload. clone ( ) )
402
+ . await
403
+ . unwrap ( ) ;
404
+
405
+ assert_eq ! (
406
+ target_url. to_string( ) ,
407
+ format!( "{DEFAULT_HTTP_SCHEMA}://www.example.com/index.php?id=1&name=foo" )
408
+ ) ;
409
+ }
410
+
411
+ #[ tokio:: test]
412
+ async fn test_parse_query_string_is_escaped ( ) {
413
+ let state = create_test_state ( vec ! [ (
414
+ "HTTP_TARGET" . to_string( ) ,
415
+ "www.example.com" . to_string( ) ,
416
+ ) ] )
417
+ . await
418
+ . unwrap ( ) ;
419
+
420
+ let payload = Some ( "/index.php?id=1&name=foo' or ''='" . to_string ( ) ) ;
421
+ let target_url = Request :: create_target_url_from ( & state, payload. clone ( ) )
422
+ . await
423
+ . unwrap ( ) ;
424
+
425
+ assert_eq ! (
426
+ target_url. to_string( ) ,
427
+ format!( "{DEFAULT_HTTP_SCHEMA}://www.example.com/index.php?id=1&name=foo%27%20or%20%27%27=%27" )
428
+ ) ;
429
+ }
430
+ #[ tokio:: test]
431
+ async fn test_parse_body_post ( ) {
432
+ let state = create_test_state ( vec ! [ (
433
+ "HTTP_TARGET" . to_string( ) ,
434
+ "www.example.com" . to_string( ) ,
435
+ ) ] )
436
+ . await
437
+ . unwrap ( ) ;
438
+
439
+ let method = "POST" ;
440
+ let payload = Some ( "/login.php?user=admin&pass=' OR ''='" . to_string ( ) ) ;
441
+ let target_url = Request :: create_target_url_from ( & state, payload. clone ( ) )
442
+ . await
443
+ . unwrap ( ) ;
444
+ let expected_body_string = "user=admin&pass=%27%20OR%20%27%27=%27" . to_string ( ) ;
445
+ let expected_target_url_string = format ! (
446
+ "{DEFAULT_HTTP_SCHEMA}://www.example.com/login.php?{}" ,
447
+ expected_body_string
448
+ ) ;
449
+
450
+ assert_eq ! ( target_url. to_string( ) , expected_target_url_string) ;
451
+
452
+ let request = Request :: create_request ( method, target_url)
453
+ . unwrap ( )
454
+ . build ( )
455
+ . unwrap ( ) ;
456
+
457
+ assert_eq ! ( request. method( ) . to_string( ) , method. to_string( ) ) ;
458
+ assert_eq ! ( request. url( ) . to_string( ) , expected_target_url_string) ;
459
+ assert ! ( request. body( ) . is_some( ) ) ;
460
+ assert_eq ! (
461
+ request. body( ) . unwrap( ) . as_bytes( ) ,
462
+ Some ( expected_body_string. as_bytes( ) )
463
+ ) ;
464
+ }
465
+ }
0 commit comments