-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: Update OpenAI module with new functionality and structure i… (
#76) * refactor: Update OpenAI module with new functionality and structure improvements * refactor(llm): changing with_option of llm trait to add_options (#78) changing with_option of llm trait to add_options for redability * Feature/question answer with options (#79) * refactor: Update OpenAI module with new functionality and structure improvements * feat: adding from Tool trait to open ai fucntions * feat(chain): add ChainCallOptions to load_stuff_qa functions This commit introduces an Optional ChainCallOptions parameter to the load_stuff_qa function in both src/chain/question_answering.rs and src/chain/stuff_documents/chain.rs files to allow setting options while loading question and answering. A new load_stuff_qa_with_options function was also added to the StuffDocument struct to handle the new option. * feat: streamline LLMChainBuilder in question_answering.rs Removed the unnecessary mutable declaration on and consolidated the assignment and building steps into one. All changes made in . * refactor: change order for redability * feat: add OpenAiToolAgent and related changes Implemented OpenAiToolAgent functionality allowing a tool to function as an AI agent, using a new OpenAiToolAgentBuilder. This involved various file changes and additions across several modules. A new function was added to the CallOptions struct, merge_options, to help manage setting options between the default set and those incoming from user configurations. The AgentAction struct now includes a "log" field, and a new struct, LogTools, was introduced. Modified several tool-related files (WebScrapper, SerpApi, Wolfram) to utilize the new interface and updated traits provided by the OpenAiToolAgent. This included changing the function signature for the function to and additionally revising the input type from string to serde_json::Value. In the messages module, added a new MessageType variant for ToolMessage and included a new function to create new Tool messages. Further, "tool_calls" was added as an optional field to the Message struct. New schemas, FunctionCallBehavior and FunctionDefinition, were added to replace those removed from the language_models options module. A corresponding tools_openai_like module was added under schemas to house these additions. * Feature/question answer with options (#79) * refactor: Update OpenAI module with new functionality and structure improvements * feat: adding from Tool trait to open ai fucntions * feat(chain): add ChainCallOptions to load_stuff_qa functions This commit introduces an Optional ChainCallOptions parameter to the load_stuff_qa function in both src/chain/question_answering.rs and src/chain/stuff_documents/chain.rs files to allow setting options while loading question and answering. A new load_stuff_qa_with_options function was also added to the StuffDocument struct to handle the new option. * feat: streamline LLMChainBuilder in question_answering.rs Removed the unnecessary mutable declaration on and consolidated the assignment and building steps into one. All changes made in . * refactor: delete println * refactor: delete println * feat: Refactor agent creation and add OpenAiToolAgent example This update changes the chain of responsibility in the agent creation process. Removed the OpenAiToolAgent from examples/agent.rs and created a separate OpenAiToolAgent in examples/open_ai_tools_agent.rs. It also adjusts the ability to add tools to an agent in src/agent/chat/builder.rs, changes the method to accept a slice instead of a vector. Tests were adjusted as well to match this new method signature in src/agent/chat/chat_agent.rs. * refactor: delete unused imports * chore: Update tool.rs with additional documentation and method implementations - Added a method to provide a description of the tool - Updated formatting and comments for better readability - Implemented methods for parsing input, running the tool, and calling the tool asynchronously This commit improves the documentation and functionality of the tool.rs file. * chore: Update code with additional documentation * fix: make open ai work if agent return multiple tools * chore: trim and replace whitespace with underscores in tool names * chore: adding open ai tool agent example with multiples tools * chore:fixing examples
- Loading branch information
1 parent
88cb964
commit 31c78c3
Showing
21 changed files
with
529 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
use std::{error::Error, sync::Arc}; | ||
|
||
use async_trait::async_trait; | ||
use langchain_rust::{ | ||
agent::{AgentExecutor, OpenAiToolAgentBuilder}, | ||
chain::{options::ChainCallOptions, Chain}, | ||
llm::openai::OpenAI, | ||
memory::SimpleMemory, | ||
prompt_args, | ||
tools::{SerpApi, Tool}, | ||
}; | ||
|
||
use serde_json::Value; | ||
struct Date {} | ||
|
||
#[async_trait] | ||
impl Tool for Date { | ||
fn name(&self) -> String { | ||
"Date".to_string() | ||
} | ||
fn description(&self) -> String { | ||
"Useful when you need to get the date,input is a query".to_string() | ||
} | ||
async fn run(&self, _input: Value) -> Result<String, Box<dyn Error>> { | ||
Ok("25 of november of 2025".to_string()) | ||
} | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
let llm = OpenAI::default(); | ||
let memory = SimpleMemory::new(); | ||
let serpapi_tool = SerpApi::default(); | ||
let tool_calc = Date {}; | ||
let agent = OpenAiToolAgentBuilder::new() | ||
.tools(&[Arc::new(serpapi_tool), Arc::new(tool_calc)]) | ||
.options(ChainCallOptions::new().with_max_tokens(1000)) | ||
.build(llm) | ||
.unwrap(); | ||
|
||
let executor = AgentExecutor::from_agent(agent).with_memory(memory.into()); | ||
|
||
let input_variables = prompt_args! { | ||
"input" => "Who is the creator of vim, and Whats the current date?", | ||
}; | ||
|
||
match executor.invoke(input_variables).await { | ||
Ok(result) => { | ||
println!("Result: {:?}", result); | ||
} | ||
Err(e) => panic!("Error invoking LLMChain: {:?}", e), | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,11 @@ | ||
mod agent; | ||
mod chat; | ||
pub use agent::*; | ||
|
||
mod executor; | ||
pub use executor::*; | ||
|
||
pub use agent::*; | ||
mod chat; | ||
pub use chat::*; | ||
pub use executor::*; | ||
|
||
mod open_ai_tools; | ||
pub use open_ai_tools::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
use std::{error::Error, sync::Arc}; | ||
|
||
use async_trait::async_trait; | ||
use serde_json::json; | ||
|
||
use crate::{ | ||
agent::Agent, | ||
chain::Chain, | ||
fmt_message, fmt_placeholder, fmt_template, message_formatter, | ||
prompt::{HumanMessagePromptTemplate, MessageFormatterStruct, PromptArgs}, | ||
schemas::{ | ||
agent::{AgentAction, AgentEvent, AgentFinish, LogTools}, | ||
messages::Message, | ||
FunctionCallResponse, | ||
}, | ||
template_jinja2, | ||
tools::Tool, | ||
}; | ||
|
||
pub struct OpenAiToolAgent { | ||
pub(crate) chain: Box<dyn Chain>, | ||
pub(crate) tools: Vec<Arc<dyn Tool>>, | ||
} | ||
|
||
impl OpenAiToolAgent { | ||
pub fn create_prompt(prefix: &str) -> Result<MessageFormatterStruct, Box<dyn Error>> { | ||
let prompt = message_formatter![ | ||
fmt_message!(Message::new_system_message(prefix)), | ||
fmt_placeholder!("chat_history"), | ||
fmt_template!(HumanMessagePromptTemplate::new(template_jinja2!( | ||
"{{input}}", | ||
"input" | ||
))), | ||
fmt_placeholder!("agent_scratchpad") | ||
]; | ||
|
||
Ok(prompt) | ||
} | ||
|
||
fn construct_scratchpad( | ||
&self, | ||
intermediate_steps: &[(AgentAction, String)], | ||
) -> Result<Vec<Message>, Box<dyn Error>> { | ||
let mut thoughts: Vec<Message> = Vec::new(); | ||
|
||
for (action, observation) in intermediate_steps { | ||
// Deserialize directly and embed in method calls to streamline code. | ||
// Extract the tool ID and tool calls from the log. | ||
let LogTools { tool_id, tools } = serde_json::from_str(&action.log)?; | ||
let tools: Vec<FunctionCallResponse> = serde_json::from_str(&tools)?; | ||
|
||
// For the first action, add an AI message with all tools called in this session. | ||
if thoughts.is_empty() { | ||
thoughts.push(Message::new_ai_message("").with_tool_calls(json!(tools))); | ||
} | ||
|
||
// Add a tool message for each observation. | ||
thoughts.push(Message::new_tool_message(observation, tool_id)); | ||
} | ||
|
||
Ok(thoughts) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl Agent for OpenAiToolAgent { | ||
async fn plan( | ||
&self, | ||
intermediate_steps: &[(AgentAction, String)], | ||
inputs: PromptArgs, | ||
) -> Result<AgentEvent, Box<dyn Error>> { | ||
let mut inputs = inputs.clone(); | ||
let scratchpad = self.construct_scratchpad(&intermediate_steps)?; | ||
inputs.insert("agent_scratchpad".to_string(), json!(scratchpad)); | ||
let output = self.chain.call(inputs).await?.generation; | ||
match serde_json::from_str::<Vec<FunctionCallResponse>>(&output) { | ||
Ok(tools) => { | ||
let mut actions: Vec<AgentAction> = Vec::new(); | ||
for tool in tools { | ||
//Log tools will be send as log | ||
let log: LogTools = LogTools { | ||
tool_id: tool.id.clone(), | ||
tools: output.clone(), //We send the complete tools ouput, we will need it in | ||
//the open ai call | ||
}; | ||
actions.push(AgentAction { | ||
tool: tool.function.name.clone(), | ||
tool_input: tool.function.arguments.clone(), | ||
log: serde_json::to_string(&log)?, //We send this as string to minimise changes | ||
}); | ||
} | ||
return Ok(AgentEvent::Action(actions)); | ||
} | ||
Err(_) => return Ok(AgentEvent::Finish(AgentFinish { output })), | ||
} | ||
} | ||
|
||
fn get_tools(&self) -> Vec<Arc<dyn Tool>> { | ||
self.tools.clone() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
use std::{error::Error, sync::Arc}; | ||
|
||
use crate::{ | ||
chain::{options::ChainCallOptions, LLMChainBuilder}, | ||
language_models::{llm::LLM, options::CallOptions}, | ||
schemas::FunctionDefinition, | ||
tools::Tool, | ||
}; | ||
|
||
use super::{prompt::PREFIX, OpenAiToolAgent}; | ||
|
||
pub struct OpenAiToolAgentBuilder { | ||
tools: Option<Vec<Arc<dyn Tool>>>, | ||
prefix: Option<String>, | ||
options: Option<ChainCallOptions>, | ||
} | ||
|
||
impl OpenAiToolAgentBuilder { | ||
pub fn new() -> Self { | ||
Self { | ||
tools: None, | ||
prefix: None, | ||
options: None, | ||
} | ||
} | ||
|
||
pub fn tools(mut self, tools: &[Arc<dyn Tool>]) -> Self { | ||
self.tools = Some(tools.to_vec()); | ||
self | ||
} | ||
|
||
pub fn prefix<S: Into<String>>(mut self, prefix: S) -> Self { | ||
self.prefix = Some(prefix.into()); | ||
self | ||
} | ||
|
||
pub fn options(mut self, options: ChainCallOptions) -> Self { | ||
self.options = Some(options); | ||
self | ||
} | ||
|
||
pub fn build<L: LLM + 'static>(self, llm: L) -> Result<OpenAiToolAgent, Box<dyn Error>> { | ||
let tools = self.tools.unwrap_or_else(Vec::new); | ||
let prefix = self.prefix.unwrap_or_else(|| PREFIX.to_string()); | ||
let mut llm = llm; | ||
|
||
let prompt = OpenAiToolAgent::create_prompt(&prefix)?; | ||
let default_options = ChainCallOptions::default().with_max_tokens(1000); | ||
let functions = tools | ||
.iter() | ||
.map(|tool| FunctionDefinition::from_langchain_tool(tool)) | ||
.collect::<Vec<FunctionDefinition>>(); | ||
llm.add_options(CallOptions::new().with_functions(functions)); | ||
let chain = Box::new( | ||
LLMChainBuilder::new() | ||
.prompt(prompt) | ||
.llm(llm) | ||
.options(self.options.unwrap_or(default_options)) | ||
.build()?, | ||
); | ||
|
||
Ok(OpenAiToolAgent { chain, tools }) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
mod builder; | ||
pub use builder::*; | ||
|
||
mod agent; | ||
pub use agent::*; | ||
|
||
mod prompt; |
Oops, something went wrong.