Skip to content

Commit d8fcb45

Browse files
committed
new: added cookies persistency for http namespace
1 parent 35c63c7 commit d8fcb45

File tree

5 files changed

+131
-21
lines changed

5 files changed

+131
-21
lines changed

Cargo.lock

+72-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ regex = "1.10.5"
3535
serde = { version = "1.0.203", features = ["derive", "serde_derive"] }
3636
serde_trim = "1.1.0"
3737
serde_yaml = "0.9.34"
38-
simple-home-dir = "0.3.5"
38+
simple-home-dir = "0.4.0"
3939
tokio = "1.38.0"
4040
xml-rs = "0.8.20"
4141
duration-string = { version = "0.4.0", optional = true }
@@ -53,6 +53,7 @@ reqwest = { version = "0.12.5", default-features = false, features = [
5353
"rustls-tls",
5454
] }
5555
url = "2.5.2"
56+
reqwest_cookie_store = "0.8.0"
5657

5758
[features]
5859
default = ["ollama", "groq", "openai", "fireworks"]

src/agent/mod.rs

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{collections::HashMap, sync::Arc, time::Duration};
1+
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
22

33
use anyhow::Result;
44
use mini_rag::Embedder;
@@ -18,6 +18,21 @@ pub mod serialization;
1818
pub mod state;
1919
pub mod task;
2020

21+
pub fn data_path(path: &str) -> Result<PathBuf> {
22+
let user_home = match simple_home_dir::home_dir() {
23+
Some(path) => path,
24+
None => return Err(anyhow!("can't get user home folder")),
25+
};
26+
27+
let inner_path = user_home.join(".nerve").join(path);
28+
if !inner_path.exists() {
29+
log::info!("creating {} ...", inner_path.display());
30+
std::fs::create_dir_all(&inner_path)?;
31+
}
32+
33+
Ok(inner_path)
34+
}
35+
2136
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
2237
pub struct Invocation {
2338
pub action: String,

src/agent/namespaces/http/mod.rs

+33-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
use std::{
22
collections::HashMap,
3+
fs::OpenOptions,
34
str::FromStr,
5+
sync::Arc,
46
time::{Duration, Instant},
57
};
68

79
use anyhow::Result;
810
use async_trait::async_trait;
911
use colored::Colorize;
12+
use lazy_static::lazy_static;
13+
use reqwest_cookie_store::CookieStoreMutex;
1014
use url::Url;
1115

1216
use crate::agent::state::SharedState;
@@ -15,6 +19,28 @@ use super::{Action, Namespace, StorageDescriptor};
1519

1620
const DEFAULT_HTTP_SCHEMA: &str = "https";
1721

22+
lazy_static! {
23+
static ref COOKIE_STORE: Arc<CookieStoreMutex> = {
24+
let cookies_file = crate::agent::data_path("http")
25+
.unwrap()
26+
.join("cookies.json");
27+
28+
let file = OpenOptions::new()
29+
.read(true)
30+
.write(true)
31+
.create(true)
32+
.truncate(true)
33+
.open(&cookies_file)
34+
.map(std::io::BufReader::new)
35+
.unwrap_or_else(|_| panic!("can't open {}", cookies_file.display()));
36+
37+
let cookie_store = reqwest_cookie_store::CookieStore::load_json(file)
38+
.unwrap_or_else(|_| panic!("can't load {}", cookies_file.display()));
39+
40+
Arc::new(reqwest_cookie_store::CookieStoreMutex::new(cookie_store))
41+
};
42+
}
43+
1844
#[derive(Debug, Default, Clone)]
1945
struct ClearHeaders {}
2046

@@ -153,9 +179,14 @@ impl Request {
153179

154180
fn create_request(method: &str, target_url: Url) -> Result<reqwest::RequestBuilder> {
155181
let method = reqwest::Method::from_str(method)?;
156-
let mut request = reqwest::Client::new().request(method.clone(), target_url.clone());
157-
let query_str = target_url.query().unwrap_or("").to_string();
158182

183+
let mut request = reqwest::Client::builder()
184+
.cookie_provider(COOKIE_STORE.clone())
185+
.build()?
186+
.request(method.clone(), target_url.clone());
187+
188+
// get query string if any
189+
let query_str = target_url.query().unwrap_or("").to_string();
159190
// if there're parameters and we're not in GET, set them as the body
160191
if !query_str.is_empty() && !matches!(method, reqwest::Method::GET) {
161192
request = request.header(
@@ -212,8 +243,6 @@ impl Action for Request {
212243
let target_url_str = target_url.to_string();
213244
let mut request = Self::create_request(method, target_url)?;
214245

215-
// TODO: handle cookie/session persistency
216-
217246
// add defined headers
218247
for (key, value) in state.lock().await.get_storage("http-headers")?.iter() {
219248
request = request.header(key, &value.data);

src/agent/task/tasklet.rs

+8-12
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use colored::Colorize;
99
use duration_string::DurationString;
1010
use serde::Deserialize;
1111
use serde_trim::*;
12-
use simple_home_dir::home_dir;
1312

1413
use super::{variables::interpolate_variables, Task};
1514
use crate::{
@@ -251,25 +250,22 @@ pub struct Tasklet {
251250
}
252251

253252
impl Tasklet {
254-
pub fn from_path(path: &str, defines: &Vec<String>) -> Result<Self> {
253+
pub fn from_path(tasklet_path: &str, defines: &Vec<String>) -> Result<Self> {
255254
parse_pre_defined_values(defines)?;
256255

257-
let mut ppath = PathBuf::from_str(path)?;
258-
256+
let mut tasklet_path = PathBuf::from_str(tasklet_path)?;
259257
// try to look it up in ~/.nerve/tasklets
260-
if !ppath.exists() {
261-
let in_home = home_dir()
262-
.unwrap()
263-
.join(PathBuf::from_str(".nerve/tasklets")?.join(&ppath));
258+
if !tasklet_path.exists() {
259+
let in_home = crate::agent::data_path("tasklets")?.join(&tasklet_path);
264260
if in_home.exists() {
265-
ppath = in_home;
261+
tasklet_path = in_home;
266262
}
267263
}
268264

269-
if ppath.is_dir() {
270-
Self::from_folder(ppath.to_str().unwrap())
265+
if tasklet_path.is_dir() {
266+
Self::from_folder(tasklet_path.to_str().unwrap())
271267
} else {
272-
Self::from_yaml_file(ppath.to_str().unwrap())
268+
Self::from_yaml_file(tasklet_path.to_str().unwrap())
273269
}
274270
}
275271

0 commit comments

Comments
 (0)