refactor: Update OpenAI module with new functionality and structure improvements (#76)

* refactor: Update OpenAI module with new functionality and structure improvements

* refactor(llm): changing with_option of llm trait to add_options (#78)

Renames the LLM trait's with_option method to add_options for readability.

* Feature/question answer with options (#79)

* refactor: Update OpenAI module with new functionality and structure improvements

* feat: adding from Tool trait to OpenAI functions

* feat(chain): add ChainCallOptions to load_stuff_qa functions

This commit introduces an optional ChainCallOptions parameter to the load_stuff_qa function in both src/chain/question_answering.rs and src/chain/stuff_documents/chain.rs, allowing options to be set when loading question answering. A new load_stuff_qa_with_options function was also added to the StuffDocument struct to handle the new option.
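The optional-options pattern described above can be sketched in miniature. The real langchain_rust signatures return a chain and take an LLM; these stand-in types and the single max_tokens field are assumptions made only to keep the example self-contained:

```rust
// Stand-in for the crate's chain options; field set is an assumption.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ChainCallOptions {
    pub max_tokens: Option<u32>,
}

pub struct StuffDocument;

impl StuffDocument {
    // Original entry point keeps its signature by delegating with no options.
    pub fn load_stuff_qa() -> ChainCallOptions {
        Self::load_stuff_qa_with_options(None)
    }

    // New entry point: an explicit Option lets callers omit options entirely,
    // falling back to the defaults.
    pub fn load_stuff_qa_with_options(options: Option<ChainCallOptions>) -> ChainCallOptions {
        options.unwrap_or_default()
    }
}
```

Callers that never cared about options are untouched, while new callers can pass `Some(ChainCallOptions { .. })`.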

* feat: streamline LLMChainBuilder in question_answering.rs

Removed an unnecessary mutable declaration and consolidated the assignment and building steps into one.

* refactor: change order for readability

* feat: add OpenAiToolAgent and related changes

Implemented the OpenAiToolAgent, an agent built around OpenAI tool calling and constructed via a new OpenAiToolAgentBuilder. This involved various file changes and additions across several modules.

A new merge_options function was added to the CallOptions struct to merge the default option set with options incoming from user configurations.
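The merging behavior can be sketched as follows. The field names and the precedence rule (user-supplied values win, unset fields keep the defaults) are assumptions; the real CallOptions struct has more fields:

```rust
// Stand-in for the crate's per-call LLM options.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct CallOptions {
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
}

impl CallOptions {
    // Overwrite each field only when the incoming options actually set it,
    // so defaults survive where the user expressed no preference.
    pub fn merge_options(&mut self, incoming: CallOptions) {
        if incoming.max_tokens.is_some() {
            self.max_tokens = incoming.max_tokens;
        }
        if incoming.temperature.is_some() {
            self.temperature = incoming.temperature;
        }
    }
}
```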

The AgentAction struct now includes a "log" field, and a new struct, LogTools, was introduced.

Modified several tool-related files (WebScrapper, SerpApi, Wolfram) to use the new interface and updated traits provided by the OpenAiToolAgent. This included renaming the tool's call function to run and revising the input type from &str to serde_json::Value.

In the messages module, added a new MessageType variant for ToolMessage and a function to create new tool messages. In addition, "tool_calls" was added as an optional field to the Message struct.
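A dependency-free sketch of that messages change is below. In the real crate tool_calls holds a serde_json::Value and the struct has further fields; the String stand-in and the exact constructor signature are assumptions:

```rust
// Stand-ins for the crate's message types.
#[derive(Debug, PartialEq)]
pub enum MessageType {
    SystemMessage,
    AiMessage,
    HumanMessage,
    ToolMessage, // new variant described above
}

pub struct Message {
    pub content: String,
    pub message_type: MessageType,
    pub id: Option<String>,
    pub tool_calls: Option<String>, // new optional field; serde_json::Value in the crate
}

impl Message {
    // New constructor for tool messages: carries the tool-call id back to the
    // model alongside the tool's observation text.
    pub fn new_tool_message<S: Into<String>>(content: S, id: S) -> Self {
        Message {
            content: content.into(),
            message_type: MessageType::ToolMessage,
            id: Some(id.into()),
            tool_calls: None,
        }
    }
}
```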

New schemas, FunctionCallBehavior and FunctionDefinition, were added to replace those removed from the language_models options module. A corresponding tools_openai_like module was added under schemas to house these additions.
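A speculative sketch of those two schemas follows; the actual definitions in the tools_openai_like module may differ in variants and field types (parameters is a JSON Schema value in practice):

```rust
// How the model is allowed to pick functions (assumed variants).
pub enum FunctionCallBehavior {
    None,           // never call a function
    Auto,           // let the model decide
    Named(String),  // force a specific function by name
}

// One OpenAI-style function declaration (assumed field set).
pub struct FunctionDefinition {
    pub name: String,
    pub description: String,
    pub parameters: String, // JSON Schema for the arguments in the real crate
}
```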

* refactor: delete println

* refactor: delete println

* feat: Refactor agent creation and add OpenAiToolAgent example

This update changes the chain of responsibility in the agent creation process. Removed the OpenAiToolAgent from examples/agent.rs and created a separate example in examples/open_ai_tools_agent.rs. It also adjusts how tools are added to an agent in src/agent/chat/builder.rs, changing the tools method to accept a slice instead of a vector. Tests in src/agent/chat/chat_agent.rs were adjusted to match the new method signature.
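The builder change is visible in the diff for src/agent/chat/builder.rs further down; condensed with a minimal stand-in Tool trait, it looks like this:

```rust
use std::sync::Arc;

// Minimal stand-in for the crate's Tool trait.
pub trait Tool {
    fn name(&self) -> String;
}

#[derive(Default)]
pub struct ConversationalAgentBuilder {
    pub tools: Option<Vec<Arc<dyn Tool>>>,
}

impl ConversationalAgentBuilder {
    // Taking &[Arc<dyn Tool>] instead of Vec<Arc<dyn Tool>> lets callers pass
    // an array literal like &[Arc::new(tool)]; cloning an Arc only bumps a
    // reference count, so to_vec is cheap.
    pub fn tools(mut self, tools: &[Arc<dyn Tool>]) -> Self {
        self.tools = Some(tools.to_vec());
        self
    }
}

struct Echo;
impl Tool for Echo {
    fn name(&self) -> String {
        "Echo".to_string()
    }
}
```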

* refactor: delete unused imports

* chore: Update tool.rs with additional documentation and method implementations

- Added a method to provide a description of the tool
- Updated formatting and comments for better readability
- Implemented methods for parsing input, running the tool, and calling the tool asynchronously

This commit improves the documentation and functionality of the tool.rs file.

* chore: Update code with additional documentation

* fix: make OpenAI work when the agent returns multiple tool calls

* chore: trim and replace whitespace with underscores in tool names
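That normalization appears in the executor diff below, where tools are registered in the name-to-tool map under a cleaned key (surrounding whitespace trimmed, inner spaces replaced with underscores, matching OpenAI's function-name constraints). As a standalone helper — the function name here is illustrative, the executor inlines the expression:

```rust
// Normalize a tool's display name into an OpenAI-safe function key.
fn normalize_tool_name(name: &str) -> String {
    name.trim().replace(" ", "_")
}
```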

* chore: adding OpenAI tool agent example with multiple tools

* chore: fixing examples
Abraxas-365 authored Mar 22, 2024
1 parent 88cb964 commit 31c78c3
Showing 21 changed files with 529 additions and 97 deletions.
4 changes: 2 additions & 2 deletions examples/agent.rs
@@ -15,7 +15,7 @@ async fn main() {
let memory = SimpleMemory::new();
let serpapi_tool = SerpApi::default();
let agent = ConversationalAgentBuilder::new()
.tools(vec![Arc::new(serpapi_tool)])
.tools(&[Arc::new(serpapi_tool)])
.output_parser(ChatOutputParser::new().into())
.options(ChainCallOptions::new().with_max_tokens(1000))
.build(llm)
@@ -24,7 +24,7 @@ async fn main() {
let executor = AgentExecutor::from_agent(agent).with_memory(memory.into());

let input_variables = prompt_args! {
"input" => "Who is Leonardo DiCaprio's girlfriend, and how old is she?",
"input" => "Who is the creator of vim, and how old is vim",
};

match executor.invoke(input_variables).await {
53 changes: 53 additions & 0 deletions examples/open_ai_tools_agent.rs
@@ -0,0 +1,53 @@
use std::{error::Error, sync::Arc};

use async_trait::async_trait;
use langchain_rust::{
agent::{AgentExecutor, OpenAiToolAgentBuilder},
chain::{options::ChainCallOptions, Chain},
llm::openai::OpenAI,
memory::SimpleMemory,
prompt_args,
tools::{SerpApi, Tool},
};

use serde_json::Value;
struct Date {}

#[async_trait]
impl Tool for Date {
fn name(&self) -> String {
"Date".to_string()
}
fn description(&self) -> String {
"Useful when you need to get the date,input is a query".to_string()
}
async fn run(&self, _input: Value) -> Result<String, Box<dyn Error>> {
Ok("25 of november of 2025".to_string())
}
}

#[tokio::main]
async fn main() {
let llm = OpenAI::default();
let memory = SimpleMemory::new();
let serpapi_tool = SerpApi::default();
let tool_calc = Date {};
let agent = OpenAiToolAgentBuilder::new()
.tools(&[Arc::new(serpapi_tool), Arc::new(tool_calc)])
.options(ChainCallOptions::new().with_max_tokens(1000))
.build(llm)
.unwrap();

let executor = AgentExecutor::from_agent(agent).with_memory(memory.into());

let input_variables = prompt_args! {
"input" => "Who is the creator of vim, and Whats the current date?",
};

match executor.invoke(input_variables).await {
Ok(result) => {
println!("Result: {:?}", result);
}
Err(e) => panic!("Error invoking LLMChain: {:?}", e),
}
}
12 changes: 6 additions & 6 deletions src/agent/chat/builder.rs
@@ -31,8 +31,8 @@ impl ConversationalAgentBuilder {
}
}

pub fn tools(mut self, tools: Vec<Arc<dyn Tool>>) -> Self {
self.tools = Some(tools);
pub fn tools(mut self, tools: &[Arc<dyn Tool>]) -> Self {
self.tools = Some(tools.to_vec());
self
}

@@ -41,13 +41,13 @@ impl ConversationalAgentBuilder {
self
}

pub fn prefix(mut self, prefix: String) -> Self {
self.prefix = Some(prefix);
pub fn prefix<S: Into<String>>(mut self, prefix: S) -> Self {
self.prefix = Some(prefix.into());
self
}

pub fn suffix(mut self, suffix: String) -> Self {
self.suffix = Some(suffix);
pub fn suffix<S: Into<String>>(mut self, suffix: S) -> Self {
self.suffix = Some(suffix.into());
self
}

5 changes: 3 additions & 2 deletions src/agent/chat/chat_agent.rs
@@ -112,6 +112,7 @@ mod tests {
use std::{error::Error, sync::Arc};

use async_trait::async_trait;
use serde_json::Value;

use crate::{
agent::{
@@ -135,7 +136,7 @@
fn description(&self) -> String {
"Usefull to make calculations".to_string()
}
async fn call(&self, _input: &str) -> Result<String, Box<dyn Error>> {
async fn run(&self, _input: Value) -> Result<String, Box<dyn Error>> {
Ok("25".to_string())
}
}
@@ -147,7 +148,7 @@
let memory = SimpleMemory::new();
let tool_calc = Calc {};
let agent = ConversationalAgentBuilder::new()
.tools(vec![Arc::new(tool_calc)])
.tools(&[Arc::new(tool_calc)])
.output_parser(ChatOutputParser::new().into())
.build(llm)
.unwrap();
4 changes: 2 additions & 2 deletions src/agent/chat/output_parser.rs
@@ -42,11 +42,11 @@ impl AgentOutputParser for ChatOutputParser {
output: agent_output.action_input,
}))
} else {
Ok(AgentEvent::Action(AgentAction {
Ok(AgentEvent::Action(vec![AgentAction {
tool: agent_output.action,
tool_input: agent_output.action_input,
log: text.to_string(),
}))
}]))
}
}
None => {
48 changes: 28 additions & 20 deletions src/agent/executor.rs
@@ -60,7 +60,7 @@
let mut name_to_tool = HashMap::new();
for tool in self.agent.get_tools().iter() {
log::debug!("Loading Tool:{}", tool.name());
name_to_tool.insert(tool.name(), tool.clone());
name_to_tool.insert(tool.name().trim().replace(" ", "_"), tool.clone());
}
name_to_tool
}
@@ -89,27 +89,35 @@
loop {
let agent_event = self.agent.plan(&steps, input_variables.clone()).await?;
match agent_event {
AgentEvent::Action(action) => {
log::debug!("Action: {:?}", action.tool_input);
let tool = name_to_tools.get(&action.tool).ok_or("Tool not found")?; //TODO:Check
//what to do with the error

let observation_result = tool.call(&action.tool_input).await;

let observation = match observation_result {
Ok(result) => result,
Err(err) => {
log::info!("The tool return the following error: {}", err.to_string());
if self.break_if_error {
return Err(err); // return the error immediately
} else {
format!("The tool return the following error: {}", err.to_string())
// convert the error to a string and continue
AgentEvent::Action(actions) => {
for action in actions {
log::debug!("Action: {:?}", action.tool_input);
let tool = name_to_tools.get(&action.tool).ok_or("Tool not found")?; //TODO:Check
//what to do with the error

let observation_result = tool.call(&action.tool_input).await;

let observation = match observation_result {
Ok(result) => result,
Err(err) => {
log::info!(
"The tool return the following error: {}",
err.to_string()
);
if self.break_if_error {
return Err(err); // return the error immediately
} else {
format!(
"The tool return the following error: {}",
err.to_string()
)
// convert the error to a string and continue
}
}
}
};
};

steps.push((action, observation));
steps.push((action, observation));
}
}
AgentEvent::Finish(finish) => {
if let Some(memory) = &self.memory {
10 changes: 7 additions & 3 deletions src/agent/mod.rs
@@ -1,7 +1,11 @@
mod agent;
mod chat;
pub use agent::*;

mod executor;
pub use executor::*;

pub use agent::*;
mod chat;
pub use chat::*;
pub use executor::*;

mod open_ai_tools;
pub use open_ai_tools::*;
101 changes: 101 additions & 0 deletions src/agent/open_ai_tools/agent.rs
@@ -0,0 +1,101 @@
use std::{error::Error, sync::Arc};

use async_trait::async_trait;
use serde_json::json;

use crate::{
agent::Agent,
chain::Chain,
fmt_message, fmt_placeholder, fmt_template, message_formatter,
prompt::{HumanMessagePromptTemplate, MessageFormatterStruct, PromptArgs},
schemas::{
agent::{AgentAction, AgentEvent, AgentFinish, LogTools},
messages::Message,
FunctionCallResponse,
},
template_jinja2,
tools::Tool,
};

pub struct OpenAiToolAgent {
pub(crate) chain: Box<dyn Chain>,
pub(crate) tools: Vec<Arc<dyn Tool>>,
}

impl OpenAiToolAgent {
pub fn create_prompt(prefix: &str) -> Result<MessageFormatterStruct, Box<dyn Error>> {
let prompt = message_formatter![
fmt_message!(Message::new_system_message(prefix)),
fmt_placeholder!("chat_history"),
fmt_template!(HumanMessagePromptTemplate::new(template_jinja2!(
"{{input}}",
"input"
))),
fmt_placeholder!("agent_scratchpad")
];

Ok(prompt)
}

fn construct_scratchpad(
&self,
intermediate_steps: &[(AgentAction, String)],
) -> Result<Vec<Message>, Box<dyn Error>> {
let mut thoughts: Vec<Message> = Vec::new();

for (action, observation) in intermediate_steps {
// Deserialize directly and embed in method calls to streamline code.
// Extract the tool ID and tool calls from the log.
let LogTools { tool_id, tools } = serde_json::from_str(&action.log)?;
let tools: Vec<FunctionCallResponse> = serde_json::from_str(&tools)?;

// For the first action, add an AI message with all tools called in this session.
if thoughts.is_empty() {
thoughts.push(Message::new_ai_message("").with_tool_calls(json!(tools)));
}

// Add a tool message for each observation.
thoughts.push(Message::new_tool_message(observation, tool_id));
}

Ok(thoughts)
}
}

#[async_trait]
impl Agent for OpenAiToolAgent {
async fn plan(
&self,
intermediate_steps: &[(AgentAction, String)],
inputs: PromptArgs,
) -> Result<AgentEvent, Box<dyn Error>> {
let mut inputs = inputs.clone();
let scratchpad = self.construct_scratchpad(&intermediate_steps)?;
inputs.insert("agent_scratchpad".to_string(), json!(scratchpad));
let output = self.chain.call(inputs).await?.generation;
match serde_json::from_str::<Vec<FunctionCallResponse>>(&output) {
Ok(tools) => {
let mut actions: Vec<AgentAction> = Vec::new();
for tool in tools {
//Log tools will be send as log
let log: LogTools = LogTools {
tool_id: tool.id.clone(),
tools: output.clone(), //We send the complete tools ouput, we will need it in
//the open ai call
};
actions.push(AgentAction {
tool: tool.function.name.clone(),
tool_input: tool.function.arguments.clone(),
log: serde_json::to_string(&log)?, //We send this as string to minimise changes
});
}
return Ok(AgentEvent::Action(actions));
}
Err(_) => return Ok(AgentEvent::Finish(AgentFinish { output })),
}
}

fn get_tools(&self) -> Vec<Arc<dyn Tool>> {
self.tools.clone()
}
}
64 changes: 64 additions & 0 deletions src/agent/open_ai_tools/builder.rs
@@ -0,0 +1,64 @@
use std::{error::Error, sync::Arc};

use crate::{
chain::{options::ChainCallOptions, LLMChainBuilder},
language_models::{llm::LLM, options::CallOptions},
schemas::FunctionDefinition,
tools::Tool,
};

use super::{prompt::PREFIX, OpenAiToolAgent};

pub struct OpenAiToolAgentBuilder {
tools: Option<Vec<Arc<dyn Tool>>>,
prefix: Option<String>,
options: Option<ChainCallOptions>,
}

impl OpenAiToolAgentBuilder {
pub fn new() -> Self {
Self {
tools: None,
prefix: None,
options: None,
}
}

pub fn tools(mut self, tools: &[Arc<dyn Tool>]) -> Self {
self.tools = Some(tools.to_vec());
self
}

pub fn prefix<S: Into<String>>(mut self, prefix: S) -> Self {
self.prefix = Some(prefix.into());
self
}

pub fn options(mut self, options: ChainCallOptions) -> Self {
self.options = Some(options);
self
}

pub fn build<L: LLM + 'static>(self, llm: L) -> Result<OpenAiToolAgent, Box<dyn Error>> {
let tools = self.tools.unwrap_or_else(Vec::new);
let prefix = self.prefix.unwrap_or_else(|| PREFIX.to_string());
let mut llm = llm;

let prompt = OpenAiToolAgent::create_prompt(&prefix)?;
let default_options = ChainCallOptions::default().with_max_tokens(1000);
let functions = tools
.iter()
.map(|tool| FunctionDefinition::from_langchain_tool(tool))
.collect::<Vec<FunctionDefinition>>();
llm.add_options(CallOptions::new().with_functions(functions));
let chain = Box::new(
LLMChainBuilder::new()
.prompt(prompt)
.llm(llm)
.options(self.options.unwrap_or(default_options))
.build()?,
);

Ok(OpenAiToolAgent { chain, tools })
}
}
7 changes: 7 additions & 0 deletions src/agent/open_ai_tools/mod.rs
@@ -0,0 +1,7 @@
mod builder;
pub use builder::*;

mod agent;
pub use agent::*;

mod prompt;