Skip to content

Commit

Permalink
new: implemented evaluation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Jan 29, 2025
1 parent efea051 commit 9a43df5
Show file tree
Hide file tree
Showing 12 changed files with 233 additions and 16 deletions.
34 changes: 34 additions & 0 deletions examples/eval_test/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys
import json

"""
An evaluator is a script that receives the current state of the agent via stdin
and performs some evaluation, at the end of which it can:
1. Exit with a 42 status code if the task is completed successfully.
2. Exit with any other status code if the task is not completed successfully.
3. Print anything to stdout; that output is added to the chat history as feedback.
"""

if __name__ == "__main__":
    # the agent state arrives on stdin as a single JSON document
    raw = sys.stdin.read()

    # just check for the number 42 in the raw input
    if "42" in raw:
        # use sys.exit(): the bare exit() helper is injected by the site module
        # for interactive use and is not available under `python -S` or in
        # frozen/embedded interpreters
        sys.exit(42)

    # parsed state, used by the optional history check below; a malformed
    # document raises here and the script exits with a non-42 status
    state = json.loads(raw)

    # uncomment this to validate the output of a tool in the history
    """
    # in this case we're looping the chat history, we could just do substring matching really ...
    for message in state["chat"]["history"]["conversation"]:
        if message["type"] == "feedback":
            invocation = message["data"][1]
            if invocation is not None:
                if invocation["action"] == "solution" and "42" in invocation["payload"]:
                    sys.exit(42)
    """

    # add a feedback message to the chat history
    print("try thinking about a funny book reference to answer")
27 changes: 27 additions & 0 deletions examples/eval_test/task.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using: []

system_prompt: >
You are an useful assistant that resolves problems and answers questions.
prompt: >
What is the meaning of life?
evaluator:
command:
- python3
- eval.py

# python: ...


# tools are not needed here, the evaluator will just check the chat history

# functions:
# - name: Solve
# description: You will use these actions to provide the answer to the problem.
# actions:
# - name: solution
# description: "To provide the answer to the problem:"
# example_payload: foobar
# # if no tool is provided, the input payload will be returned as the output
# # so that the evaluation can be done by inspecting the chat history
11 changes: 10 additions & 1 deletion src/agent/events/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

mod channel;
Expand All @@ -10,6 +12,13 @@ use super::{
Invocation,
};

/// Snapshot of the agent state carried by `Event::StateUpdate`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct StateUpdate {
/// Chat options (system prompt, prompt, history, ...) at the time of the update.
pub chat: ChatOptions,
/// Global variables, as returned by `task::variables::get_variables()`.
pub globals: HashMap<String, String>,
/// Variables held by the agent state (`State::get_variables()`).
pub variables: HashMap<String, String>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Event {
MetricsUpdate(Metrics),
Expand All @@ -20,7 +29,7 @@ pub enum Event {
prev: Option<String>,
new: Option<String>,
},
StateUpdate(ChatOptions),
StateUpdate(StateUpdate),
EmptyResponse,
InvalidResponse(String),
InvalidAction {
Expand Down
4 changes: 2 additions & 2 deletions src/agent/generator/history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};

use super::Message;

#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
pub enum ConversationWindow {
/// Use the history as is.
Full,
Expand Down Expand Up @@ -42,7 +42,7 @@ impl std::fmt::Display for ConversationWindow {
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ChatHistory {
// full list of messages as is
conversation: Vec<Message>,
Expand Down
3 changes: 2 additions & 1 deletion src/agent/generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ lazy_static! {
static ref CONN_RESET_PARSER: Regex = Regex::new(r"(?m)^.+onnection reset by peer.*").unwrap();
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ChatOptions {
pub system_prompt: Option<String>,
pub prompt: String,
Expand All @@ -58,6 +58,7 @@ impl ChatOptions {
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", content = "data", rename_all = "lowercase")]
pub enum Message {
Agent(String, Option<Invocation>),
Feedback(String, Option<Invocation>),
Expand Down
58 changes: 52 additions & 6 deletions src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ use anyhow::Result;
use mini_rag::Embedder;
use serde::{Deserialize, Serialize};

use events::Event;
use events::{Event, StateUpdate};
use generator::{
history::{ChatHistory, ConversationWindow},
ChatOptions, ChatResponse, Client,
};
use namespaces::Action;
use state::{SharedState, State};
use task::Task;
use task::{eval::Evaluator, Task};

pub mod events;
pub mod generator;
Expand Down Expand Up @@ -104,7 +104,11 @@ pub struct Agent {
events_chan: events::Sender,
generator: Box<dyn Client>,
state: SharedState,

task_timeout: Option<Duration>,
task_evaluator: Option<Evaluator>,
task_working_directory: Option<String>,

conversation_window: ConversationWindow,

serializer: serialization::Strategy,
Expand Down Expand Up @@ -152,6 +156,9 @@ impl Agent {
};

let task_timeout = task.get_timeout();
let task_evaluation = task.get_evaluator();
let task_working_directory = task.get_working_directory();

let state = Arc::new(tokio::sync::Mutex::new(
State::new(
events_chan.clone(),
Expand All @@ -168,6 +175,8 @@ impl Agent {
generator,
state,
task_timeout,
task_evaluator: task_evaluation,
task_working_directory,
use_native_tools_format,
user_only,
serializer,
Expand Down Expand Up @@ -247,19 +256,48 @@ impl Agent {
}

async fn on_state_update(&self, options: &ChatOptions, refresh: bool) -> Result<()> {
let mut opts = options.clone();
let mut state_update = StateUpdate {
chat: options.clone(),
globals: task::variables::get_variables(),
variables: self.state.lock().await.get_variables().clone(),
};

if refresh {
opts.system_prompt = Some(
state_update.chat.system_prompt = Some(
self.serializer
.system_prompt_for_state(&*self.state.lock().await)?,
);

let messages = self.state.lock().await.to_chat_history(&self.serializer)?;

opts.history = ChatHistory::create(messages, self.conversation_window);
state_update.chat.history = ChatHistory::create(messages, self.conversation_window);
}

// if there was a state change
if refresh {
// if this task has an evaluation strategy
if let Some(task_evaluation) = &self.task_evaluator {
// run it
let evaluation = task_evaluation
.evaluate(&state_update, &self.task_working_directory)
.await;
if let Err(e) = evaluation {
log::error!("error evaluating task: {}", e);
} else {
let evaluation = evaluation.unwrap();
if evaluation.completed {
self.state
.lock()
.await
.on_complete(false, Some("evaluation success".to_string()))?;
} else if let Some(feedback) = evaluation.feedback {
self.state.lock().await.add_feedback_to_history(feedback);
}
}
}
}

self.on_event(events::Event::StateUpdate(opts))
self.on_event(events::Event::StateUpdate(state_update))
}

// TODO: move these feedback strings to a common place
Expand Down Expand Up @@ -440,6 +478,8 @@ impl Agent {
self.on_valid_response().await;
}

let mut any_state_updates = false;

// for each parsed invocation
for mut inv in invocations {
// lookup action
Expand Down Expand Up @@ -521,13 +561,19 @@ impl Agent {
}

self.on_state_update(&options, true).await?;
any_state_updates = true;

// break the loop if we're done
if self.state.lock().await.is_complete() {
break;
}
}

// trigger a final state update if there were no state changes
if !any_state_updates {
self.on_state_update(&options, true).await?;
}

Ok(())
}

Expand Down
9 changes: 9 additions & 0 deletions src/agent/state/history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ impl Execution {
}
}

/// Creates an execution entry that only carries a feedback message: the
/// message is stored as the result, while invocation, response and error
/// are left unset.
pub fn with_feedback(message: String) -> Self {
    let result = Some(message);

    Self {
        invocation: None,
        response: None,
        result,
        error: None,
    }
}

pub fn with_error(invocation: Invocation, error: String) -> Self {
Self {
invocation: Some(invocation),
Expand Down
4 changes: 4 additions & 0 deletions src/agent/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ impl State {
.push(Execution::with_unparsed_response(response, error));
}

/// Appends a feedback-only execution entry to the history.
pub fn add_feedback_to_history(&mut self, feedback: String) {
    let entry = Execution::with_feedback(feedback);
    self.history.push(entry);
}

pub fn get_action(&self, name: &str) -> Option<Box<dyn namespaces::Action>> {
for group in &self.namespaces {
for action in &group.actions {
Expand Down
67 changes: 67 additions & 0 deletions src/agent/task/eval.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use anyhow::Result;
use serde::Deserialize;

use crate::agent::events::StateUpdate;

const SUCCESS_CODE: i32 = 42;

/// Outcome of a single evaluator run.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Evaluation {
    /// Whether the evaluator reported the task as successfully completed.
    pub completed: bool,
    /// Text printed by the evaluator, if any, to be fed back to the agent.
    pub feedback: Option<String>,
}

/// Task evaluator configuration: an external command that is run to evaluate
/// the agent state.
#[derive(Default, Deserialize, Debug, Clone)]
pub struct Evaluator {
/// Command line of the evaluator process: the first element is the program,
/// the remaining elements are its arguments.
command: Vec<String>,
}

impl Evaluator {
pub async fn evaluate(
&self,
state: &StateUpdate,
working_directory: &Option<String>,
) -> Result<Evaluation> {
log::info!("📊 running evaluation ...");

let mut eval = Evaluation {
completed: false,
feedback: None,
};

let json = serde_json::to_string(&state)?;

let mut cmd = tokio::process::Command::new(&self.command[0]);
if self.command.len() > 1 {
cmd.args(&self.command[1..]);
}

if let Some(working_directory) = working_directory {
cmd.current_dir(working_directory);
}

let mut child = cmd
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()?;

// write JSON to stdin
if let Some(mut stdin) = child.stdin.take() {
tokio::io::AsyncWriteExt::write_all(&mut stdin, json.as_bytes()).await?;
}

let output = child.wait_with_output().await?;
if !output.stdout.is_empty() {
eval.feedback = Some(String::from_utf8_lossy(&output.stdout).trim().to_string());
log::info!("📊 feedback: {}", eval.feedback.as_ref().unwrap());
}

if !output.stderr.is_empty() {
log::error!("📊 {}", String::from_utf8_lossy(&output.stderr));
}

eval.completed = output.status.code() == Some(SUCCESS_CODE);

Ok(eval)
}
}
12 changes: 10 additions & 2 deletions src/agent/task/mod.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
use std::time::Duration;

use anyhow::Result;
use eval::Evaluator;

use super::namespaces::Namespace;

pub mod eval;
pub mod robopages;
pub mod tasklet;
pub mod variables;

// TODO: comment the shit out of everything.

pub trait Task: std::fmt::Debug + Send + Sync {
fn to_system_prompt(&self) -> Result<String>;
fn to_prompt(&self) -> Result<String>;
fn get_functions(&self) -> Vec<Namespace>;

fn get_working_directory(&self) -> Option<String> {
None
}

fn get_evaluator(&self) -> Option<Evaluator> {
None
}

fn get_timeout(&self) -> Option<Duration> {
None
}
Expand Down
Loading

0 comments on commit 9a43df5

Please sign in to comment.