From 6ec809df2af52809c31622cad8c1ff27e553f88c Mon Sep 17 00:00:00 2001 From: Atris Date: Mon, 20 Jan 2025 23:55:33 +0100 Subject: [PATCH] feat: handle multiple tool calls --- Cargo.lock | 113 --------------- rig-core/Cargo.toml | 10 +- rig-core/examples/local_agent_with_tools.rs | 146 ++++++++++++++++++++ rig-core/src/agent.rs | 13 ++ rig-core/src/completion.rs | 2 + rig-core/src/providers/openai.rs | 27 ++-- rig-core/src/providers/xai/completion.rs | 2 +- 7 files changed, 185 insertions(+), 128 deletions(-) create mode 100644 rig-core/examples/local_agent_with_tools.rs diff --git a/Cargo.lock b/Cargo.lock index 72cab45e..3704d621 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,12 +61,6 @@ dependencies = [ "libc", ] -[[package]] -name = "anstyle" -version = "1.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" - [[package]] name = "anyhow" version = "1.0.93" @@ -319,21 +313,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "assert_fs" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7efdb1fdb47602827a342857666feb372712cbc64b414172bd6b167a02927674" -dependencies = [ - "anstyle", - "doc-comment", - "globwalk", - "predicates", - "predicates-core", - "predicates-tree", - "tempfile", -] - [[package]] name = "async-attributes" version = "1.1.2" @@ -1193,16 +1172,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "bstr" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a68f1f47cdf0ec8ee4b941b2eee2a80cb796db73118c0dd09ac63fbe405be22" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "bumpalo" version = "3.16.0" @@ -1997,12 +1966,6 @@ dependencies = [ "syn 2.0.89", ] -[[package]] -name = "difflib" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" - [[package]] name = "digest" version = "0.10.7" @@ -2490,30 +2453,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" -[[package]] -name = "globset" -version = "0.4.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15f1ce686646e7f1e19bf7d5533fe443a45dbfb990e00629110797578b42fb19" -dependencies = [ - "aho-corasick", - "bstr", - "log", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "globwalk" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf760ebf69878d9fd8f110c89703d90ce35095324d1f1edcb595c63945ee757" -dependencies = [ - "bitflags 2.6.0", - "ignore", - "walkdir", -] - [[package]] name = "gloo-timers" version = "0.3.0" @@ -3165,22 +3104,6 @@ dependencies = [ "icu_properties", ] -[[package]] -name = "ignore" -version = "0.4.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d89fd380afde86567dfba715db065673989d6253f42b88179abd3eae47bda4b" -dependencies = [ - "crossbeam-deque", - "globset", - "log", - "memchr", - "regex-automata 0.4.9", - "same-file", - "walkdir", - "winapi-util", -] - [[package]] name = "indexmap" version = "1.9.3" @@ -4738,33 +4661,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" -[[package]] -name = "predicates" -version = "3.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e9086cc7640c29a356d1a29fd134380bee9d8f79a17410aa76e7ad295f42c97" -dependencies = [ - "anstyle", - "difflib", - "predicates-core", -] - -[[package]] -name = "predicates-core" -version = "1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae8177bee8e75d6846599c6b9ff679ed51e882816914eec639944d7c9aa11931" - -[[package]] -name = "predicates-tree" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41b740d195ed3166cd147c8047ec98db0e22ec019eb8eeb76d343b795304fb13" -dependencies = [ - "predicates-core", - "termtree", -] - [[package]] name = "prettyplease" version = "0.2.25" @@ -5257,8 +5153,6 @@ checksum = "4389f1d5789befaf6029ebd9f7dac4af7f7e3d61b69d4f30e2ac02b57e7712b0" name = "rig-core" version = "0.6.1" dependencies = [ - "anyhow", - "assert_fs", "futures", "glob", "lopdf", @@ -5270,7 +5164,6 @@ dependencies = [ "serde", "serde_json", "thiserror 1.0.69", - "tokio", "tokio-test", "tracing", "tracing-subscriber", @@ -6415,12 +6308,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "termtree" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" - [[package]] name = "testcontainers" version = "0.23.1" diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 1afd7c05..d8888d2e 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -26,14 +26,11 @@ thiserror = "1.0.61" rig-derive = { version = "0.1.0", path = "./rig-core-derive", optional = true } glob = "0.3.1" lopdf = { version = "0.34.0", optional = true } -rayon = { version = "1.10.0", optional = true} +rayon = { version = "1.10.0", optional = true } worker = { version = "0.5", optional = true } [dev-dependencies] -anyhow = "1.0.75" -assert_fs = "1.1.2" -tokio = { version = "1.34.0", features = ["full"] } -tracing-subscriber = "0.3.18" +tracing-subscriber = { version = "0.3.18" } tokio-test = "0.4.4" [features] @@ -66,3 +63,6 @@ required-features = ["derive"] [[example]] name = "xai_embeddings" required-features = ["derive"] +anyhow = "1.0.75" +assert_fs = "1.1.2" +tokio = { version = "1.34.0", features = ["full"] } diff --git a/rig-core/examples/local_agent_with_tools.rs b/rig-core/examples/local_agent_with_tools.rs new file mode 100644 index 00000000..5e08ded1 --- /dev/null +++ b/rig-core/examples/local_agent_with_tools.rs @@ -0,0 +1,146 @@ +use anyhow::Result; +use rig::{ + completion::{Chat, Message, Prompt, ToolDefinition}, + providers, + tool::Tool, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize, Serialize)] +struct Adder; +impl Tool for Adder { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + } + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + tracing::info!("Adding {} and {}", args.x, args.y); + let result = args.x + args.y; + Ok(result) + } +} + +#[derive(Deserialize, Serialize)] +struct Subtract; +impl Tool for Subtract { + const NAME: &'static str = "subtract"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "subtract", + "description": "Subtract y from x (i.e.: x - y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The number to substract from" + }, + "y": { + "type": "number", + "description": "The number to substract" + } + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + tracing::info!("Subtracting {} from {}", args.y, args.x); + let result = args.x - args.y; + Ok(result) + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Create local client + let local = providers::openai::Client::from_url("", "http://192.168.0.10:11434/v1"); + + let span = info_span!("calculator_agent"); + + // Create agent with a single context prompt and two tools + let calculator_agent = local + .agent("c4ai-command-r7b-12-2024-abliterated") + .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.") + .tool(Adder) + .tool(Subtract) + .max_tokens(1024) + .build(); + + // Initialize chat history + let mut chat_history = Vec::new(); + println!("Calculator Agent: Ready to help with calculations! (Type 'quit' to exit)"); + + loop { + print!("\nYou: "); + let mut input = String::new(); + std::io::stdin().read_line(&mut input)?; + let input = input.trim(); + + if input.to_lowercase() == "quit" { + break; + } + + // Add user message to history + chat_history.push(Message { + role: "user".into(), + content: input.into(), + }); + + // Get response from agent + let response = calculator_agent + .chat(input, chat_history.clone()) + .instrument(span.clone()) + .await?; + + // Add assistant's response to history + chat_history.push(Message { + role: "assistant".into(), + content: response.clone(), + }); + + println!("Calculator Agent: {}", response); + } + + println!("\nGoodbye!"); + Ok(()) +} diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 7b51e5a4..1ad1aeb0 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -268,6 +268,19 @@ impl Chat for Agent { choice: ModelChoice::ToolCall(toolname, _, args), .. } => Ok(self.tools.call(&toolname, args.to_string()).await?), + CompletionResponse { + choice: ModelChoice::MultipleToolCalls(tool_calls), + .. + } => { + let mut results = Vec::new(); + for tool_call in tool_calls { + if let ModelChoice::ToolCall(toolname, _, args) = tool_call { + let result = self.tools.call(&toolname, args.to_string()).await?; + results.push(result); + } + } + Ok(results.join("\n")) + } } } } diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index a08f9bd2..6cbd292d 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -223,6 +223,8 @@ pub enum ModelChoice { /// Represents a completion response as a tool call of the form /// `ToolCall(function_name, id, function_params)`. ToolCall(String, String, serde_json::Value), + /// Represents a completion response with multiple tool calls + MultipleToolCalls(Vec), } /// Trait defining a completion model that can be used to generate completion responses. diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 1cc88d08..92049e97 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -385,17 +385,26 @@ impl TryFrom for completion::CompletionResponse - { - let call = calls.first().unwrap(); + }, ..] => { + if calls.is_empty() { + return Err(CompletionError::ResponseError( + "Tool selection is empty".into(), + )); + } + + let tool_calls = calls + .iter() + .map(|call| { + Ok(completion::ModelChoice::ToolCall( + call.function.name.clone(), + "".to_owned(), + serde_json::from_str(&call.function.arguments)?, + )) + }) + .collect::, CompletionError>>()?; Ok(completion::CompletionResponse { - choice: completion::ModelChoice::ToolCall( - call.function.name.clone(), - "".to_owned(), - serde_json::from_str(&call.function.arguments)?, - ), + choice: completion::ModelChoice::MultipleToolCalls(tool_calls), raw_response: value, }) } diff --git a/rig-core/src/providers/xai/completion.rs b/rig-core/src/providers/xai/completion.rs index 1c789196..c7b48615 100644 --- a/rig-core/src/providers/xai/completion.rs +++ b/rig-core/src/providers/xai/completion.rs @@ -136,7 +136,7 @@ pub mod xai_api_types { Ok(completion::CompletionResponse { choice: completion::ModelChoice::ToolCall( call.function.name.clone(), - "".to_owned(), + call.id.clone(), serde_json::from_str(&call.function.arguments)?, ), raw_response: value,