Commit a95e546

new: implemented huggingface message api support (closes #21)

1 parent c1640b7

5 files changed (+112 -5 lines)

Cargo.toml (+2 -1)

```diff
@@ -55,12 +55,13 @@ reqwest_cookie_store = "0.8.0"
 serde_json = "1.0.120"
 
 [features]
-default = ["ollama", "groq", "openai", "fireworks"]
+default = ["ollama", "groq", "openai", "fireworks", "hf"]
 
 ollama = ["dep:ollama-rs"]
 groq = ["dep:groq-api-rs", "dep:duration-string"]
 openai = ["dep:openai_api_rust"]
 fireworks = ["dep:openai_api_rust"]
+hf = ["dep:openai_api_rust"]
 
 [profile.release]
 lto = true # Enable link-time optimization
```

README.md (+9 -1)

````diff
@@ -34,7 +34,7 @@ While Nerve was inspired by other projects such as Autogen and Rigging, its main
 
 ## LLM Support
 
-Nerve features integrations for any model accessible via the [ollama](https://github.com/ollama/ollama), [groq](https://groq.com), [OpenAI](https://openai.com/index/openai-api/) and [Fireworks](https://fireworks.ai/) APIs.
+Nerve features integrations for any model accessible via the [ollama](https://github.com/ollama/ollama), [groq](https://groq.com), [OpenAI](https://openai.com/index/openai-api/), [Fireworks](https://fireworks.ai/) and [Huggingface](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) APIs.
 
 **The tool will automatically detect if the selected model natively supports function calling. If not, it will provide a compatibility layer that empowers older models to perform function calling anyway.**
 
@@ -64,6 +64,14 @@ For **Fireworks**:
 LLM_FIREWORKS_KEY=you-api-key nerve -G "fireworks://llama-v3-70b-instruct" ...
 ```
 
+For **Huggingface**:
+
+Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
+
+```sh
+HF_API_TOKEN=you-api-key nerve -G "hf://[email protected]" ...
+```
+
 ## Example
 
 Let's take a look at the `examples/ssh_agent` example tasklet (a "tasklet" is a YAML file describing a task and the instructions):
````

src/agent/generator/huggingface.rs (new file, +40)

```rust
use anyhow::Result;
use async_trait::async_trait;

use crate::agent::{state::SharedState, Invocation};

use super::{openai::OpenAIClient, Client, Options};

pub struct HuggingfaceMessageClient {
    client: OpenAIClient,
}

#[async_trait]
impl Client for HuggingfaceMessageClient {
    fn new(url: &str, _: u16, model_name: &str, _: u32) -> anyhow::Result<Self>
    where
        Self: Sized,
    {
        let message_api = format!("https://{}/v1/", url);
        let client = OpenAIClient::custom(model_name, "HF_API_TOKEN", &message_api)?;

        log::debug!("using huggingface message api @ {}", message_api);

        Ok(Self { client })
    }

    async fn chat(
        &self,
        state: SharedState,
        options: &Options,
    ) -> anyhow::Result<(String, Vec<Invocation>)> {
        self.client.chat(state, options).await
    }
}

#[async_trait]
impl mini_rag::Embedder for HuggingfaceMessageClient {
    async fn embed(&self, text: &str) -> Result<mini_rag::Embeddings> {
        self.client.embed(text).await
    }
}
```
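
Since the Huggingface messages API is OpenAI-compatible, the client above is a thin wrapper that points the existing OpenAIClient at `https://<endpoint>/v1/` and authenticates with `HF_API_TOKEN`. As a rough illustration, here is a hedged, self-contained sketch of the kind of request that ends up on the wire; the endpoint host is a placeholder, the `"tgi"` model name comes from the Huggingface TGI blog post linked above, and it assumes `reqwest` with the `blocking` and `json` features plus `serde_json`:

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder endpoint host; a real one comes from your HF inference endpoint.
    let host = "your-endpoint.endpoints.huggingface.cloud";
    let token = std::env::var("HF_API_TOKEN")?;

    // OpenAI-style chat completion body; TGI endpoints accept "tgi" as the model name.
    let body = json!({
        "model": "tgi",
        "messages": [{ "role": "user", "content": "Hello!" }]
    });

    // Same URL shape that HuggingfaceMessageClient builds: https://<host>/v1/...
    let resp = reqwest::blocking::Client::new()
        .post(format!("https://{}/v1/chat/completions", host))
        .bearer_auth(token)
        .json(&body)
        .send()?
        .text()?;

    println!("{}", resp);
    Ok(())
}
```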

src/agent/generator/mod.rs (+8)

```diff
@@ -13,6 +13,8 @@ use super::{state::SharedState, Invocation};
 mod fireworks;
 #[cfg(feature = "groq")]
 mod groq;
+#[cfg(feature = "hf")]
+mod huggingface;
 #[cfg(feature = "ollama")]
 mod ollama;
 #[cfg(feature = "openai")]
@@ -153,6 +155,12 @@ macro_rules! factory_body {
             $model_name,
             $context_window,
         )?)),
+        "hf" => Ok(Box::new(huggingface::HuggingfaceMessageClient::new(
+            $url,
+            $port,
+            $model_name,
+            $context_window,
+        )?)),
         #[cfg(feature = "groq")]
         "groq" => Ok(Box::new(groq::GroqClient::new(
             $url,
```
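
The factory maps the scheme in the generator string (the part before `://`) to a concrete client. For context, here is a hedged, self-contained sketch of that dispatch pattern with simplified stand-in types; the real macro arms also receive the parsed port and context window, and neighboring arms are gated behind their matching Cargo feature:

```rust
// Simplified stand-ins for the real Client trait and generator clients.
trait Client {
    fn describe(&self) -> String;
}

struct DummyClient(String);

impl Client for DummyClient {
    fn describe(&self) -> String {
        format!("client for {}", self.0)
    }
}

// Scheme-to-client dispatch, mirroring the match arms in factory_body!.
fn factory(type_name: &str, url: &str, model: &str) -> Result<Box<dyn Client>, String> {
    match type_name {
        "hf" => Ok(Box::new(DummyClient(format!("hf model {model} @ {url}")))),
        "ollama" => Ok(Box::new(DummyClient(format!("ollama model {model} @ {url}")))),
        _ => Err(format!("unknown generator type: {type_name}")),
    }
}

fn main() {
    // Hypothetical endpoint host; "tgi" is the model name TGI endpoints accept.
    let client = factory("hf", "your-endpoint.endpoints.huggingface.cloud", "tgi").unwrap();
    println!("{}", client.describe());
}
```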

src/cli.rs (+53 -3)

```diff
@@ -8,7 +8,7 @@ use regex::Regex;
 lazy_static! {
     pub static ref PUBLIC_GENERATOR_PARSER: Regex = Regex::new(r"(?m)^(.+)://(.+)$").unwrap();
     pub static ref LOCAL_GENERATOR_PARSER: Regex =
-        Regex::new(r"(?m)^(.+)://(.+)@(.+):(\d+)$").unwrap();
+        Regex::new(r"(?m)^(.+)://(.+)@([^:]+):?(\d+)?$").unwrap();
 }
 
 #[derive(Default)]
@@ -21,7 +21,7 @@ pub(crate) struct GeneratorOptions {
 }
 
 /// Get things done with LLMs.
-#[derive(Parser, Debug)]
+#[derive(Parser, Debug, Default)]
 #[command(version, about, long_about = None)]
 pub(crate) struct Args {
     /// Generator string as <type>://<model name>@<host>:<port>
@@ -96,7 +96,11 @@ impl Args {
                 .unwrap()
                 .as_str()
                 .clone_into(&mut generator.host);
-            generator.port = caps.get(4).unwrap().as_str().parse::<u16>().unwrap();
+            generator.port = if let Some(port) = caps.get(4) {
+                port.as_str().parse::<u16>().unwrap()
+            } else {
+                0
+            };
         } else {
             let caps = if let Some(caps) = PUBLIC_GENERATOR_PARSER.captures_iter(raw).next() {
                 caps
@@ -149,3 +153,49 @@ pub(crate) fn get_user_input(prompt: &str) -> String {
     println!();
     input.trim().to_string()
 }
+
+#[cfg(test)]
+mod tests {
+    use super::Args;
+
+    #[test]
+    fn test_wont_parse_invalid_generator() {
+        let mut args = Args::default();
+        args.generator = "not a valid generator".to_string();
+        let ret = args.to_generator_options();
+        assert!(ret.is_err());
+    }
+
+    #[test]
+    fn test_parse_local_generator_full() {
+        let mut args = Args::default();
+        args.generator = "ollama://llama3@localhost:11434".to_string();
+        let ret = args.to_generator_options().unwrap();
+        assert_eq!(ret.type_name, "ollama");
+        assert_eq!(ret.model_name, "llama3");
+        assert_eq!(ret.host, "localhost");
+        assert_eq!(ret.port, 11434);
+    }
+
+    #[test]
+    fn test_parse_local_generator_without_port() {
+        let mut args = Args::default();
+        args.generator = "ollama://llama3@localhost".to_string();
+        let ret = args.to_generator_options().unwrap();
+        assert_eq!(ret.type_name, "ollama");
+        assert_eq!(ret.model_name, "llama3");
+        assert_eq!(ret.host, "localhost");
+        assert_eq!(ret.port, 0);
+    }
+
+    #[test]
+    fn test_parse_public_generator() {
+        let mut args = Args::default();
+        args.generator = "groq://llama3".to_string();
+        let ret = args.to_generator_options().unwrap();
+        assert_eq!(ret.type_name, "groq");
+        assert_eq!(ret.model_name, "llama3");
+        assert_eq!(ret.host, "");
+        assert_eq!(ret.port, 0);
+    }
+}
```
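
The regex change is what makes host-only generator strings like `hf://tgi@my-endpoint` work: the host group now excludes `:` and the port group is optional, defaulting to 0 when absent (the Huggingface backend ignores the port anyway). Here is a quick, self-contained sketch of the new behavior, reusing the exact pattern from the diff and assuming only the `regex` crate:

```rust
use regex::Regex;

fn main() {
    // Same pattern as the updated LOCAL_GENERATOR_PARSER: ":port" is optional.
    let re = Regex::new(r"(?m)^(.+)://(.+)@([^:]+):?(\d+)?$").unwrap();

    for input in [
        "ollama://llama3@localhost:11434",
        "hf://[email protected]",
    ] {
        let caps = re.captures(input).unwrap();
        // Group 4 (the port) may be missing; fall back to 0 like cli.rs does.
        let port = caps.get(4).map_or(0, |p| p.as_str().parse::<u16>().unwrap());
        println!(
            "type={} model={} host={} port={}",
            &caps[1], &caps[2], &caps[3], port
        );
    }
}
```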
