Commit 9aa905b

committed
new: working on issue #23
1 parent 5756535 commit 9aa905b

File tree

6 files changed: +334 -4 lines changed

6 files changed

+334
-4
lines changed

Cargo.lock

+45
Some generated files are not rendered by default.

nerve-core/Cargo.toml

+3 -1

@@ -46,13 +46,15 @@ reqwest_cookie_store = "0.8.0"
 serde_json = "1.0.120"
 clap = { version = "4.5.6", features = ["derive"] }
 tera = { version = "1.20.0", default-features = false }
+clust = { version = "0.9.0", optional = true }
 
 [features]
-default = ["ollama", "groq", "openai", "fireworks", "hf", "novita"]
+default = ["ollama", "groq", "openai", "fireworks", "hf", "novita", "anthropic"]
 
 ollama = ["dep:ollama-rs"]
 groq = ["dep:groq-api-rs", "dep:duration-string"]
 openai = ["dep:openai_api_rust"]
 fireworks = ["dep:openai_api_rust"]
 hf = ["dep:openai_api_rust"]
 novita = ["dep:openai_api_rust"]
+anthropic = ["dep:clust"]
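Since clust sits behind the optional, default-enabled "anthropic" feature, a downstream build can also opt in to just this backend. A minimal sketch (the dependency path is assumed for illustration, not part of this commit):

[dependencies]
# build nerve-core with only the Anthropic backend instead of the full default feature set
nerve-core = { path = "../nerve-core", default-features = false, features = ["anthropic"] }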
nerve-core/src/agent/generator/anthropic.rs

+262

@@ -0,0 +1,262 @@
use std::collections::HashMap;

use crate::agent::{state::SharedState, Invocation};
use anyhow::Result;
use async_trait::async_trait;
use clust::messages::{
    ClaudeModel, MaxTokens, Message, MessagesRequestBody, Role, SystemPrompt, ToolDefinition,
};
use serde::{Deserialize, Serialize};

use super::{ChatOptions, Client};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicToolFunctionParameterProperty {
    #[serde(rename(serialize = "type", deserialize = "type"))]
    pub the_type: String,
    pub description: String,
}

pub struct AnthropicClient {
    model: ClaudeModel,
    client: clust::Client,
}

impl AnthropicClient {
    async fn get_tools_if_supported(&self, state: &SharedState) -> Vec<ToolDefinition> {
        let mut tools = vec![];

        // if native tool calls are supported (and XML was not forced)
        if state.lock().await.native_tools_support {
            // for every namespace available to the model
            for group in state.lock().await.get_namespaces() {
                // for every action of the namespace
                for action in &group.actions {
                    let mut required = vec![];
                    let mut properties = HashMap::new();

                    if let Some(example) = action.example_payload() {
                        required.push("payload".to_string());
                        properties.insert(
                            "payload".to_string(),
                            AnthropicToolFunctionParameterProperty {
                                the_type: "string".to_string(),
                                description: format!(
                                    "The main function argument, use this as a template: {}",
                                    example
                                ),
                            },
                        );
                    }

                    if let Some(attrs) = action.example_attributes() {
                        for name in attrs.keys() {
                            required.push(name.to_string());
                            properties.insert(
                                name.to_string(),
                                AnthropicToolFunctionParameterProperty {
                                    the_type: "string".to_string(),
                                    description: name.to_string(),
                                },
                            );
                        }
                    }

                    let input_schema = serde_json::json!({
                        "properties": properties,
                        "required": required,
                        "type": "object",
                    });

                    tools.push(ToolDefinition::new(
                        action.name(),
                        Some(action.description().to_string()),
                        input_schema,
                    ));
                }
            }
        }

        tools
    }
}

#[async_trait]
impl Client for AnthropicClient {
    fn new(_url: &str, _port: u16, model_name: &str, _context_window: u32) -> anyhow::Result<Self> {
        let model: ClaudeModel = if model_name.contains("opus") {
            ClaudeModel::Claude3Opus20240229
        } else if model_name.contains("sonnet") && !model_name.contains("5") {
            ClaudeModel::Claude3Sonnet20240229
        } else if model_name.contains("haiku") {
            ClaudeModel::Claude3Haiku20240307
        } else {
            ClaudeModel::Claude35Sonnet20240620
        };

        let client = clust::Client::from_env()?;
        Ok(Self { model, client })
    }

    async fn check_native_tools_support(&self) -> Result<bool> {
        let messages = vec![Message::user("Execute the test function.")];
        let max_tokens = MaxTokens::new(4096, self.model)?;

        let request_body = MessagesRequestBody {
            model: self.model,
            system: Some(SystemPrompt::new("You are a helpful assistant.")),
            messages,
            max_tokens,
            tools: Some(vec![ToolDefinition::new(
                "test",
                Some("This is a test function.".to_string()),
                serde_json::json!({
                    "properties": {},
                    "required": [],
                    "type": "object",
                }),
            )]),
            ..Default::default()
        };

        let response = self.client.create_a_message(request_body).await?;

        log::debug!("response = {:?}", response);

        if let Ok(tool_use) = response.content.flatten_into_tool_use() {
            Ok(tool_use.name == "test")
        } else {
            Ok(false)
        }
    }

    async fn chat(
        &self,
        state: SharedState,
        options: &ChatOptions,
    ) -> anyhow::Result<(String, Vec<Invocation>)> {
        let mut messages = vec![Message::user(options.prompt.trim().to_string())];
        let max_tokens = MaxTokens::new(4096, self.model)?;

        for m in &options.history {
            // all messages must have non-empty content except for the optional final assistant message
            match m {
                super::Message::Agent(data, _) => {
                    let trimmed = data.trim();
                    if !trimmed.is_empty() {
                        messages.push(Message::assistant(data.trim()))
                    } else {
                        log::warn!("ignoring empty assistant message: {:?}", m);
                    }
                }
                super::Message::Feedback(data, _) => {
                    let trimmed = data.trim();
                    if !trimmed.is_empty() {
                        messages.push(Message::user(data.trim()))
                    } else {
                        log::warn!("ignoring empty user message: {:?}", m);
                    }
                }
            }
        }

        // if the last message is an assistant message, remove it
        if let Some(Message { role, content }) = messages.last() {
            // handles "Your API request included an `assistant` message in the final position, which would pre-fill the `assistant` response"
            if matches!(role, Role::Assistant) {
                messages.pop();
            }
        }

        let tools = self.get_tools_if_supported(&state).await;

        let request_body = MessagesRequestBody {
            model: self.model,
            system: Some(SystemPrompt::new(options.system_prompt.trim())),
            messages,
            max_tokens,
            tools: if tools.is_empty() { None } else { Some(tools) },
            ..Default::default()
        };

        log::debug!("request_body = {:?}", request_body);

        let response = match self.client.create_a_message(request_body.clone()).await {
            Ok(r) => r,
            Err(e) => {
                log::error!("failed to send chat message: {e} - {:?}", request_body);
                return Err(anyhow::anyhow!("failed to send chat message: {e}"));
            }
        };

        log::debug!("response = {:?}", response);

        let (content, tool_use) = if let Ok(m) = response.content.flatten_into_tool_use() {
            (response.content.flatten_into_text()?, Some(m))
        } else {
            ("", None)
        };

        let mut invocations = vec![];

        log::debug!("tool_use={:?}", &tool_use);

        if let Some(tool_use) = tool_use {
            let mut attributes = HashMap::new();
            let mut payload = None;

            let object = match tool_use.input.as_object() {
                Some(o) => o,
                None => {
                    log::error!("tool_use.input is not an object: {:?}", tool_use.input);
                    return Err(anyhow::anyhow!("tool_use.input is not an object"));
                }
            };

            for (name, value) in object {
                log::debug!("tool_call.input[{}] = {:?}", name, value);

                let mut content = value.to_string();
                if let serde_json::Value::String(escaped_json) = &value {
                    content = escaped_json.to_string();
                }

                let str_val = content.trim_matches('"').to_string();
                if name == "payload" {
                    payload = Some(str_val);
                } else {
                    attributes.insert(name.to_string(), str_val);
                }
            }

            let inv = Invocation {
                action: tool_use.name.to_string(),
                attributes: if attributes.is_empty() {
                    None
                } else {
                    Some(attributes)
                },
                payload,
            };

            invocations.push(inv);

            log::debug!("tool_use={:?}", tool_use);
            log::debug!("invocations={:?}", &invocations);
        }

        if invocations.is_empty() && content.is_empty() {
            log::warn!("response = {:?}", response);
        }

        Ok((content.to_string(), invocations))
    }
}

#[async_trait]
impl mini_rag::Embedder for AnthropicClient {
    async fn embed(&self, _text: &str) -> Result<mini_rag::Embeddings> {
        // TODO: extend the rust client to do this
        todo!("anthropic embeddings generation not yet implemented")
    }
}
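For reference, a minimal sketch (not part of the commit) of the input_schema value that get_tools_if_supported produces for one action; the action's "url" attribute and its example payload are invented here purely for illustration:

// Hypothetical shape of a per-action tool schema sent to the Messages API:
// "payload" comes from example_payload(), and every example attribute becomes
// a required string property described by its own name.
fn main() {
    let input_schema = serde_json::json!({
        "type": "object",
        "required": ["payload", "url"],
        "properties": {
            "payload": {
                "type": "string",
                "description": "The main function argument, use this as a template: GET /index.html"
            },
            "url": { "type": "string", "description": "url" }
        }
    });
    println!("{}", serde_json::to_string_pretty(&input_schema).unwrap());
}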

nerve-core/src/agent/generator/mod.rs

+9
@@ -9,6 +9,8 @@ use serde::{Deserialize, Serialize};
 
 use super::{state::SharedState, Invocation};
 
+#[cfg(feature = "anthropic")]
+mod anthropic;
 #[cfg(feature = "fireworks")]
 mod fireworks;
 #[cfg(feature = "groq")]

@@ -184,6 +186,13 @@ macro_rules! factory_body {
             $model_name,
             $context_window,
         )?)),
+        #[cfg(feature = "anthropic")]
+        "anthropic" | "claude" => Ok(Box::new(anthropic::AnthropicClient::new(
+            $url,
+            $port,
+            $model_name,
+            $context_window,
+        )?)),
         "http" => Ok(Box::new(openai_compatible::OpenAiCompatibleClient::new(
             $url,
             $port,

nerve-core/src/agent/serialization/mod.rs

+6 -1

@@ -102,7 +102,12 @@ impl Strategy {
             "".to_string()
         } else {
             // model does not support tool calls, we need to provide the actions in its system prompt
-            include_str!("actions.prompt").to_owned() + "\n" + &self.actions_for_state(state)?
+            let mut raw = include_str!("actions.prompt").to_owned();
+
+            raw.push_str("\n");
+            raw.push_str(&self.actions_for_state(state)?);
+
+            raw
         };

         let iterations = if state.metrics.max_steps > 0 {
