Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for mcp tools with rig agents + add anthropic prompt caching #213

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
832 changes: 531 additions & 301 deletions Cargo.lock

Large diffs are not rendered by default.

21 changes: 11 additions & 10 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
[workspace]
resolver = "2"
members = [
"rig-core",
"rig-lancedb",
"rig-mongodb",
"rig-neo4j",
"rig-postgres",
"rig-qdrant",
"rig-core/rig-core-derive",
"rig-sqlite",
"rig-eternalai", "rig-fastembed",
"rig-surrealdb",
"rig-core",
"rig-lancedb",
"rig-mongodb",
"rig-neo4j",
"rig-postgres",
"rig-qdrant",
"rig-core/rig-core-derive",
"rig-sqlite",
"rig-eternalai",
"rig-fastembed",
"rig-surrealdb",
]
3 changes: 2 additions & 1 deletion rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@ rig-derive = { version = "0.1.0", path = "./rig-core-derive", optional = true }
glob = "0.3.1"
lopdf = { version = "0.35.0", optional = true }
rayon = { version = "1.10.0", optional = true }
mcp_client_rs = { git = "https://github.com/edisontim/mcp_client_rust", branch = "ref/cleanup", default-features = false }
worker = { version = "0.5", optional = true }
bytes = "1.9.0"
async-stream = "0.3.6"

[dev-dependencies]
anyhow = "1.0.75"
assert_fs = "1.1.2"
tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"
tokio = { version = "1.34.0", features = ["full"] }
tokio-test = "0.4.4"
serde_path_to_error = "0.1.16"
base64 = "0.22.1"
Expand Down
109 changes: 98 additions & 11 deletions rig-core/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,22 @@
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//! .expect("Failed to prompt the agent");
//! ```
use std::collections::HashMap;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};

use futures::{stream, StreamExt, TryStreamExt};
use mcp_client_rs::{client::Client, CallToolResult, MessageContent};

use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
Message, Prompt, PromptError, ToolDefinition,
},
message::AssistantContent,
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
tool::{Tool, ToolSet},
tool::{Tool, ToolDyn, ToolError, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};

Expand All @@ -147,8 +148,10 @@ use crate::{
pub struct Agent<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// Cached preamble
cached_preamble: Option<Vec<String>>,
/// System prompt
preamble: String,
preamble: Vec<String>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (identified by their name)
Expand Down Expand Up @@ -179,6 +182,7 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
let completion_request = self
.model
.completion_request(prompt)
.cached_preamble(self.cached_preamble.clone())
.preamble(self.preamble.clone())
.messages(chat_history)
.temperature_opt(self.temperature)
Expand Down Expand Up @@ -314,6 +318,7 @@ impl<M: CompletionModel> Chat for Agent<M> {
tool_call.function.arguments.to_string(),
)
.await?),
_ => unreachable!(),
}
}
}
Expand Down Expand Up @@ -342,8 +347,10 @@ impl<M: CompletionModel> Chat for Agent<M> {
pub struct AgentBuilder<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// Cached preamble
cached_preamble: Option<Vec<String>>,
/// System prompt
preamble: Option<String>,
preamble: Option<Vec<String>>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (by name)
Expand All @@ -366,6 +373,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
pub fn new(model: M) -> Self {
Self {
model,
cached_preamble: None,
preamble: None,
static_context: vec![],
static_tools: vec![],
Expand All @@ -378,19 +386,25 @@ impl<M: CompletionModel> AgentBuilder<M> {
}
}

pub fn cached_preamble(mut self, cached_preamble: Vec<String>) -> Self {
self.cached_preamble = Some(cached_preamble);
self
}

/// Set the system prompt, replacing any preamble configured so far.
pub fn preamble(mut self, preamble: &str) -> Self {
    self.preamble.replace(vec![preamble.to_owned()]);
    self
}

/// Append a segment to the agent's preamble, initializing it if unset.
pub fn append_preamble(mut self, doc: &str) -> Self {
    // Push in place: the previous version rebuilt the whole Vec with
    // `to_vec()` (a full clone) and re-assigned the Option to itself on
    // every call.
    self.preamble.get_or_insert_with(Vec::new).push(doc.into());
    self
}

Expand Down Expand Up @@ -437,6 +451,13 @@ impl<M: CompletionModel> AgentBuilder<M> {
self
}

/// Register a tool exposed by an MCP server as one of the agent's static tools.
pub fn mcp_tool(mut self, tool: mcp_client_rs::Tool, client: Arc<Client>) -> Self {
    // Record the name before `tool` is moved into the wrapper.
    self.static_tools.push(tool.name.clone());
    self.tools.add_tool(MCPTool::from_mcp_server(tool, client));
    self
}

/// Set the temperature of the model
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
Expand All @@ -459,6 +480,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
pub fn build(self) -> Agent<M> {
Agent {
model: self.model,
cached_preamble: self.cached_preamble,
preamble: self.preamble.unwrap_or_default(),
static_context: self.static_context,
static_tools: self.static_tools,
Expand Down Expand Up @@ -502,3 +524,68 @@ impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
.await
}
}
/// Adapter that lets a tool hosted on an MCP server be used as a rig tool
/// (see the `ToolDyn` implementation below).
pub struct MCPTool {
    // Shared handle to the MCP client used to invoke the tool.
    client: Arc<Client>,
    // Tool metadata (name, description, input schema) as reported by the server.
    definition: mcp_client_rs::Tool,
}

impl MCPTool {
pub fn from_mcp_server(definition: mcp_client_rs::Tool, client: Arc<Client>) -> Self {
Self { client, definition }
}
}

#[derive(Debug, thiserror::Error)]
#[error("MCP tool error")]
pub struct MCPToolError(String);

impl ToolDyn for MCPTool {
    /// The tool's name as reported by the MCP server.
    fn name(&self) -> String {
        self.definition.name.clone()
    }

    /// Build the rig [`ToolDefinition`] from the server-provided metadata.
    /// The `_prompt` argument is unused: the definition is static.
    fn definition(
        &self,
        _prompt: String,
    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
        Box::pin(async move {
            ToolDefinition {
                name: self.definition.name.clone(),
                description: self.definition.description.clone(),
                parameters: self.definition.input_schema.clone(),
            }
        })
    }

    /// Invoke the tool on the MCP server with the given JSON-encoded `args`
    /// and concatenate the text fragments of the result.
    fn call(
        &self,
        args: String,
    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
        let client = self.client.clone();
        let name = self.definition.name.clone();
        // Parse `args` directly — the previous version cloned the string
        // first for no reason.
        // NOTE(review): a malformed argument string silently degenerates to
        // `Value::Null` here instead of surfacing a parse error — confirm
        // this is the intended behavior.
        let args: serde_json::Value = serde_json::from_str(&args).unwrap_or_default();
        Box::pin(async move {
            // Transport/protocol failures from the client.
            let result: CallToolResult = client.call_tool(&name, args).await.map_err(|e| {
                ToolError::ToolCallError(Box::new(MCPToolError(format!(
                    "Tool returned an error: {}",
                    e
                ))))
            })?;
            // The server can also report an application-level error flag.
            if result.is_error {
                return Err(ToolError::ToolCallError(Box::new(MCPToolError(
                    "Tool returned an error".to_string(),
                ))));
            }
            // Keep only text content; non-text fragments are dropped
            // (equivalent to the previous map-to-empty-string + join).
            Ok(result
                .content
                .into_iter()
                .filter_map(|c| match c {
                    MessageContent::Text { text } => Some(text),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join(""))
        })
    }
}
47 changes: 47 additions & 0 deletions rig-core/src/completion/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ pub enum UserContent {
/// Content produced by the assistant side of a conversation.
pub enum AssistantContent {
    /// Plain text output.
    Text(Text),
    /// A request by the model to invoke a tool.
    ToolCall(ToolCall),
    /// A "thinking" block with its accompanying signature string.
    Thinking(Thinking),
    /// A thinking block returned by the provider in redacted/opaque form.
    RedactedThinking(RedactedThinking),
}

/// An explicit "thinking" block emitted by the model.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Thinking {
    // The thinking text itself.
    pub thinking: String,
    // Opaque signature accompanying the block — presumably provider-issued
    // so the block can be passed back unmodified; TODO confirm against the
    // provider's extended-thinking docs.
    pub signature: String,
}

/// A thinking block whose content the provider returned only in redacted
/// (opaque) form.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct RedactedThinking {
    // Opaque payload exactly as received from the provider.
    pub data: String,
}

/// Tool result content containing information about a tool call and it's resulting content.
Expand Down Expand Up @@ -228,6 +241,15 @@ impl Message {
content: OneOrMany::one(AssistantContent::text(text)),
}
}

pub fn tool_result(id: impl Into<String>, content: OneOrMany<ToolResultContent>) -> Self {
Message::User {
content: OneOrMany::one(UserContent::ToolResult(ToolResult {
id: id.into(),
content,
})),
}
}
}

impl UserContent {
Expand Down Expand Up @@ -284,6 +306,10 @@ impl UserContent {
content,
})
}

/// Returns `true` if this content is the plain-text variant.
pub fn is_text(&self) -> bool {
    matches!(self, UserContent::Text(_))
}
}

impl AssistantContent {
Expand All @@ -306,6 +332,17 @@ impl AssistantContent {
},
})
}

pub fn thinking(thinking: impl Into<String>, signature: impl Into<String>) -> Self {
AssistantContent::Thinking(Thinking {
thinking: thinking.into(),
signature: signature.into(),
})
}

pub fn redacted_thinking(data: impl Into<String>) -> Self {
AssistantContent::RedactedThinking(RedactedThinking { data: data.into() })
}
}

impl ToolResultContent {
Expand Down Expand Up @@ -544,6 +581,16 @@ impl From<String> for UserContent {
}
}

// ================================================================
// From<Message> impls
// ================================================================

impl From<OneOrMany<AssistantContent>> for Message {
fn from(content: OneOrMany<AssistantContent>) -> Self {
Message::Assistant { content }
}
}

// ================================================================
// Error types
// ================================================================
Expand Down
20 changes: 16 additions & 4 deletions rig-core/src/completion/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,10 @@ pub trait CompletionModel: Clone + Send + Sync {
pub struct CompletionRequest {
/// The prompt to be sent to the completion model provider
pub prompt: Message,
/// The preamble to be sent to the completion model provider
pub preamble: Option<String>,
/// The preambles to be sent to the completion model provider
pub preamble: Option<Vec<String>>,
/// The cached preamble to be sent to the completion model provider
pub cached_preamble: Option<Vec<String>>,
/// The chat history to be sent to the completion model provider
pub chat_history: Vec<Message>,
/// The documents to be sent to the completion model provider
Expand Down Expand Up @@ -322,7 +324,8 @@ impl CompletionRequest {
pub struct CompletionRequestBuilder<M: CompletionModel> {
model: M,
prompt: Message,
preamble: Option<String>,
preamble: Option<Vec<String>>,
cached_preamble: Option<Vec<String>>,
chat_history: Vec<Message>,
documents: Vec<Document>,
tools: Vec<ToolDefinition>,
Expand All @@ -337,6 +340,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
model,
prompt: prompt.into(),
preamble: None,
cached_preamble: None,
chat_history: Vec::new(),
documents: Vec::new(),
tools: Vec::new(),
Expand All @@ -347,7 +351,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
}

/// Sets the preamble for the completion request.
pub fn preamble(mut self, preamble: String) -> Self {
pub fn preamble(mut self, preamble: Vec<String>) -> Self {
self.preamble = Some(preamble);
self
}
Expand Down Expand Up @@ -444,11 +448,18 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
self
}

/// Sets the cached preamble for the completion request.
pub fn cached_preamble(mut self, cached_preamble: Option<Vec<String>>) -> Self {
self.cached_preamble = cached_preamble;
self
}

/// Builds the completion request.
pub fn build(self) -> CompletionRequest {
CompletionRequest {
prompt: self.prompt,
preamble: self.preamble,
cached_preamble: self.cached_preamble,
chat_history: self.chat_history,
documents: self.documents,
tools: self.tools,
Expand Down Expand Up @@ -529,6 +540,7 @@ mod tests {
let request = CompletionRequest {
prompt: "What is the capital of France?".into(),
preamble: None,
cached_preamble: None,
chat_history: Vec::new(),
documents: vec![doc1, doc2],
tools: Vec::new(),
Expand Down
Loading