Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new: finished implementing anthropic support (closes #23) #25

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion nerve-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ reqwest_cookie_store = "0.8.0"
serde_json = "1.0.120"
clap = { version = "4.5.6", features = ["derive"] }
tera = { version = "1.20.0", default-features = false }
clust = { version = "0.9.0", optional = true }

[features]
default = ["ollama", "groq", "openai", "fireworks", "hf", "novita"]
default = ["ollama", "groq", "openai", "fireworks", "hf", "novita", "anthropic"]

ollama = ["dep:ollama-rs"]
groq = ["dep:groq-api-rs", "dep:duration-string"]
openai = ["dep:openai_api_rust"]
fireworks = ["dep:openai_api_rust"]
hf = ["dep:openai_api_rust"]
novita = ["dep:openai_api_rust"]
anthropic = ["dep:clust"]
262 changes: 262 additions & 0 deletions nerve-core/src/agent/generator/anthropic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
use std::collections::HashMap;

use crate::agent::{state::SharedState, Invocation};
use anyhow::Result;
use async_trait::async_trait;
use clust::messages::{
ClaudeModel, MaxTokens, Message, MessagesRequestBody, Role, SystemPrompt, ToolDefinition,
};
use serde::{Deserialize, Serialize};

use super::{ChatOptions, Client};

/// A single property entry inside a tool's JSON input schema, as expected by
/// the Anthropic messages API (`{"type": "...", "description": "..."}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicToolFunctionParameterProperty {
    /// JSON schema type of the property; serialized as the reserved key "type".
    #[serde(rename = "type")]
    pub the_type: String,
    /// Human-readable description shown to the model.
    pub description: String,
}

/// Generator backed by the Anthropic messages API via the `clust` crate.
pub struct AnthropicClient {
    // Concrete Claude model resolved from the user-supplied model name.
    model: ClaudeModel,
    // Underlying HTTP client; configured from environment variables.
    client: clust::Client,
}

impl AnthropicClient {
    /// Builds an Anthropic `ToolDefinition` for every action in every namespace
    /// available to the agent, or an empty list when native tool calling is
    /// disabled (e.g. XML-based calling was forced).
    async fn get_tools_if_supported(&self, state: &SharedState) -> Vec<ToolDefinition> {
        let mut tools = vec![];

        // Acquire the shared-state lock once and reuse the guard, instead of
        // re-locking for the flag check and again for the namespace listing.
        let state = state.lock().await;

        // if native tool calls are supported (and XML was not forced)
        if state.native_tools_support {
            // for every namespace available to the model
            for group in state.get_namespaces() {
                // for every action of the namespace
                for action in &group.actions {
                    let mut required = vec![];
                    let mut properties = HashMap::new();

                    // the main "payload" argument, documented via its example template
                    if let Some(example) = action.example_payload() {
                        required.push("payload".to_string());
                        properties.insert(
                            "payload".to_string(),
                            AnthropicToolFunctionParameterProperty {
                                the_type: "string".to_string(),
                                description: format!(
                                    "The main function argument, use this as a template: {}",
                                    example
                                ),
                            },
                        );
                    }

                    // every example attribute becomes a required string property
                    if let Some(attrs) = action.example_attributes() {
                        for name in attrs.keys() {
                            required.push(name.to_string());
                            properties.insert(
                                name.to_string(),
                                AnthropicToolFunctionParameterProperty {
                                    the_type: "string".to_string(),
                                    description: name.to_string(),
                                },
                            );
                        }
                    }

                    // JSON schema shape expected by the Anthropic tools API
                    let input_schema = serde_json::json!({
                        "properties": properties,
                        "required": required,
                        "type": "object",
                    });

                    tools.push(ToolDefinition::new(
                        action.name(),
                        Some(action.description().to_string()),
                        input_schema,
                    ));
                }
            }
        }

        tools
    }
}

#[async_trait]
impl Client for AnthropicClient {
    /// Creates a client for the Anthropic API. The URL, port and context
    /// window arguments are ignored (the API endpoint is fixed and the
    /// context window is model-determined); the API key is read from the
    /// environment by `clust::Client::from_env`.
    fn new(_url: &str, _port: u16, model_name: &str, _context_window: u32) -> anyhow::Result<Self> {
        // Map a free-form model name onto a concrete Claude model, defaulting
        // to Claude 3.5 Sonnet when nothing more specific matches.
        let model: ClaudeModel = if model_name.contains("opus") {
            ClaudeModel::Claude3Opus20240229
        } else if model_name.contains("sonnet") && !model_name.contains('5') {
            ClaudeModel::Claude3Sonnet20240229
        } else if model_name.contains("haiku") {
            ClaudeModel::Claude3Haiku20240307
        } else {
            ClaudeModel::Claude35Sonnet20240620
        };

        let client = clust::Client::from_env()?;

        Ok(Self { model, client })
    }

    /// Probes the model with a single dummy tool and reports whether it
    /// responded with a native tool-use block for that tool.
    async fn check_native_tools_support(&self) -> Result<bool> {
        let messages = vec![Message::user("Execute the test function.")];
        let max_tokens = MaxTokens::new(4096, self.model)?;

        let request_body = MessagesRequestBody {
            model: self.model,
            system: Some(SystemPrompt::new("You are a helpful assistant.")),
            messages,
            max_tokens,
            tools: Some(vec![ToolDefinition::new(
                "test",
                Some("This is a test function.".to_string()),
                serde_json::json!({
                    "properties": {},
                    "required": [],
                    "type": "object",
                }),
            )]),
            ..Default::default()
        };

        let response = self.client.create_a_message(request_body).await?;

        log::debug!("response = {:?}", response);

        // Native support is confirmed only if the model actually invoked
        // the dummy tool by name.
        if let Ok(tool_use) = response.content.flatten_into_tool_use() {
            Ok(tool_use.name == "test")
        } else {
            Ok(false)
        }
    }

    /// Sends the prompt plus conversation history to the model and returns
    /// the text content along with any tool invocations it produced.
    async fn chat(
        &self,
        state: SharedState,
        options: &ChatOptions,
    ) -> anyhow::Result<(String, Vec<Invocation>)> {
        let mut messages = vec![Message::user(options.prompt.trim().to_string())];
        let max_tokens = MaxTokens::new(4096, self.model)?;

        for m in &options.history {
            // all messages must have non-empty content except for the optional
            // final assistant message, so blank entries are skipped
            match m {
                super::Message::Agent(data, _) => {
                    let trimmed = data.trim();
                    if trimmed.is_empty() {
                        log::warn!("ignoring empty assistant message: {:?}", m);
                    } else {
                        messages.push(Message::assistant(trimmed));
                    }
                }
                super::Message::Feedback(data, _) => {
                    let trimmed = data.trim();
                    if trimmed.is_empty() {
                        log::warn!("ignoring empty user message: {:?}", m);
                    } else {
                        messages.push(Message::user(trimmed));
                    }
                }
            }
        }

        // If the last message is an assistant message, remove it. This handles
        // "Your API request included an `assistant` message in the final position,
        // which would pre-fill the `assistant` response".
        if matches!(
            messages.last(),
            Some(Message {
                role: Role::Assistant,
                ..
            })
        ) {
            messages.pop();
        }

        let tools = self.get_tools_if_supported(&state).await;

        let request_body = MessagesRequestBody {
            model: self.model,
            system: Some(SystemPrompt::new(options.system_prompt.trim())),
            messages,
            max_tokens,
            tools: if tools.is_empty() { None } else { Some(tools) },
            ..Default::default()
        };

        log::debug!("request_body = {:?}", request_body);

        let response = match self.client.create_a_message(request_body.clone()).await {
            Ok(r) => r,
            Err(e) => {
                log::error!("failed to send chat message: {e} - {:?}", request_body);
                return Err(anyhow::anyhow!("failed to send chat message: {e}"));
            }
        };

        log::debug!("response = {:?}", response);

        let content = response.content.flatten_into_text().unwrap_or_default();
        let tool_use = response.content.flatten_into_tool_use().ok();

        let mut invocations = vec![];

        log::debug!("tool_use={:?}", &tool_use);

        if let Some(tool_use) = tool_use {
            let mut attributes = HashMap::new();
            let mut payload = None;

            let object = match tool_use.input.as_object() {
                Some(o) => o,
                None => {
                    log::error!("tool_use.input is not an object: {:?}", tool_use.input);
                    return Err(anyhow::anyhow!("tool_use.input is not an object"));
                }
            };

            for (name, value) in object {
                log::debug!("tool_call.input[{}] = {:?}", name, value);

                // JSON strings are used as their raw (unescaped) content; any
                // other value keeps its JSON textual representation.
                let value_content = if let serde_json::Value::String(s) = value {
                    s.to_string()
                } else {
                    value.to_string()
                };

                let str_val = value_content.trim_matches('"').to_string();
                // "payload" is the main argument; everything else is a named attribute
                if name == "payload" {
                    payload = Some(str_val);
                } else {
                    attributes.insert(name.to_string(), str_val);
                }
            }

            let inv = Invocation {
                action: tool_use.name.to_string(),
                attributes: if attributes.is_empty() {
                    None
                } else {
                    Some(attributes)
                },
                payload,
            };

            invocations.push(inv);

            log::debug!("tool_use={:?}", tool_use);
            log::debug!("invocations={:?}", &invocations);
        }

        if invocations.is_empty() && content.is_empty() {
            log::warn!("empty tool calls and content in response: {:?}", response);
        }

        Ok((content.to_string(), invocations))
    }
}

#[async_trait]
impl mini_rag::Embedder for AnthropicClient {
    /// Embeddings are not implemented for the Anthropic backend: the API does
    /// not expose an embeddings endpoint through the `clust` client, so this
    /// panics via `todo!` if ever called.
    async fn embed(&self, _text: &str) -> Result<mini_rag::Embeddings> {
        // TODO: extend the rust client to do this
        todo!("anthropic embeddings generation not yet implemented")
    }
}
9 changes: 9 additions & 0 deletions nerve-core/src/agent/generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use serde::{Deserialize, Serialize};

use super::{state::SharedState, Invocation};

#[cfg(feature = "anthropic")]
mod anthropic;
#[cfg(feature = "fireworks")]
mod fireworks;
#[cfg(feature = "groq")]
Expand Down Expand Up @@ -184,6 +186,13 @@ macro_rules! factory_body {
$model_name,
$context_window,
)?)),
#[cfg(feature = "anthropic")]
"anthropic" | "claude" => Ok(Box::new(anthropic::AnthropicClient::new(
$url,
$port,
$model_name,
$context_window,
)?)),
"http" => Ok(Box::new(openai_compatible::OpenAiCompatibleClient::new(
$url,
$port,
Expand Down
7 changes: 6 additions & 1 deletion nerve-core/src/agent/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,12 @@ impl Strategy {
"".to_string()
} else {
// model does not support tool calls, we need to provide the actions in its system prompt
include_str!("actions.prompt").to_owned() + "\n" + &self.actions_for_state(state)?
let mut raw = include_str!("actions.prompt").to_owned();

raw.push_str("\n");
raw.push_str(&self.actions_for_state(state)?);

raw
};

let iterations = if state.metrics.max_steps > 0 {
Expand Down
Loading
Loading