diff --git a/README.md b/README.md
index ed75625..2ef1b2b 100644
--- a/README.md
+++ b/README.md
@@ -31,9 +31,9 @@ Nerve features integrations for any model accessible via the following providers
 |----------|----------------------------|------------------|
 | **Ollama** | - | `ollama://llama3@localhost:11434` |
 | **Groq** | `GROQ_API_KEY` | `groq://llama3-70b-8192` |
-| **OpenAI** | `OPENAI_API_KEY` | `openai://gpt-4` |
+| **OpenAI**¹ | `OPENAI_API_KEY` | `openai://gpt-4` |
 | **Fireworks** | `LLM_FIREWORKS_KEY` | `fireworks://llama-v3-70b-instruct` |
-| **Huggingface**¹ | `HF_API_TOKEN` | `hf://tgi@your-custom-endpoint.aws.endpoints.huggingface.cloud` |
+| **Huggingface**² | `HF_API_TOKEN` | `hf://tgi@your-custom-endpoint.aws.endpoints.huggingface.cloud` |
 | **Anthropic** | `ANTHROPIC_API_KEY` | `anthropic://claude` |
 | **Nvidia NIM** | `NIM_API_KEY` | `nim://nvidia/nemotron-4-340b-instruct` |
 | **DeepSeek** | `DEEPSEEK_API_KEY` | `deepseek://deepseek-chat` |
@@ -41,7 +41,9 @@ Nerve features integrations for any model accessible via the following providers
 | **Mistral.ai** | `MISTRAL_API_KEY` | `mistral://mistral-large-latest` |
 | **Novita** | `NOVITA_API_KEY` | `novita://meta-llama/llama-3.1-70b-instruct` |
 
-¹ Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
+¹ **o1-preview and o1 models do not support function calling directly** and do not support a system prompt. It is possible to work around this by adding the `--user-only` flag to the command line.
+
+² Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
 
 ## Installing with Cargo
diff --git a/src/agent/generator/anthropic.rs b/src/agent/generator/anthropic.rs
index 70d6198..2a748e1 100644
--- a/src/agent/generator/anthropic.rs
+++ b/src/agent/generator/anthropic.rs
@@ -183,7 +183,10 @@ impl Client for AnthropicClient {
 
         let request_body = MessagesRequestBody {
             model: self.model,
-            system: Some(SystemPrompt::new(options.system_prompt.trim())),
+            system: match &options.system_prompt {
+                Some(sp) => Some(SystemPrompt::new(sp.trim())),
+                None => None,
+            },
             messages,
             max_tokens,
             tools: if tools.is_empty() { None } else { Some(tools) },
diff --git a/src/agent/generator/groq.rs b/src/agent/generator/groq.rs
index 2efa79e..c3cfe39 100644
--- a/src/agent/generator/groq.rs
+++ b/src/agent/generator/groq.rs
@@ -114,20 +114,30 @@ impl Client for GroqClient {
         state: SharedState,
         options: &ChatOptions,
     ) -> anyhow::Result<ChatResponse> {
-        let mut chat_history = vec![
-            crate::api::groq::completion::message::Message::SystemMessage {
-                role: Some("system".to_string()),
-                content: Some(options.system_prompt.trim().to_string()),
-                name: None,
-                tool_call_id: None,
-            },
-            crate::api::groq::completion::message::Message::UserMessage {
-                role: Some("user".to_string()),
-                content: Some(options.prompt.trim().to_string()),
-                name: None,
-                tool_call_id: None,
-            },
-        ];
+        let mut chat_history = match &options.system_prompt {
+            Some(sp) => vec![
+                crate::api::groq::completion::message::Message::SystemMessage {
+                    role: Some("system".to_string()),
+                    content: Some(sp.trim().to_string()),
+                    name: None,
+                    tool_call_id: None,
+                },
+                crate::api::groq::completion::message::Message::UserMessage {
+                    role: Some("user".to_string()),
+                    content: Some(options.prompt.trim().to_string()),
+                    name: None,
+                    tool_call_id: None,
+                },
+            ],
+            None => vec![
+                crate::api::groq::completion::message::Message::UserMessage {
+                    role: Some("user".to_string()),
+                    content: Some(options.prompt.trim().to_string()),
+                    name: None,
+                    tool_call_id: None,
+                },
+            ],
+        };
 
         let mut call_idx = 0;
diff --git a/src/agent/generator/mod.rs b/src/agent/generator/mod.rs
index 1796ad1..f2e9255 100644
--- a/src/agent/generator/mod.rs
+++ b/src/agent/generator/mod.rs
@@ -15,13 +15,13 @@ mod deepseek;
 mod fireworks;
 mod groq;
 mod huggingface;
+mod mistral;
 mod nim;
 mod novita;
 mod ollama;
 mod openai;
 mod openai_compatible;
 mod xai;
-mod mistral;
 
 pub(crate) mod history;
 mod options;
@@ -36,14 +36,14 @@ lazy_static! {
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ChatOptions {
-    pub system_prompt: String,
+    pub system_prompt: Option<String>,
     pub prompt: String,
     pub history: ChatHistory,
 }
 
 impl ChatOptions {
     pub fn new(
-        system_prompt: String,
+        system_prompt: Option<String>,
         prompt: String,
         conversation: Vec<Message>,
         history_strategy: ConversationWindow,
@@ -97,8 +97,8 @@ pub trait Client: mini_rag::Embedder + Send + Sync {
 
     async fn chat(&self, state: SharedState, options: &ChatOptions) -> Result<ChatResponse>;
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        Ok(false)
+    async fn supports_system_prompt(&self) -> Result<bool> {
+        Ok(true)
     }
 
     async fn check_rate_limit(&self, error: &str) -> bool {
diff --git a/src/agent/generator/ollama.rs b/src/agent/generator/ollama.rs
index 24c81fd..479b53d 100644
--- a/src/agent/generator/ollama.rs
+++ b/src/agent/generator/ollama.rs
@@ -102,10 +102,13 @@ impl Client for OllamaClient {
         // - msg 1
         // - ...
         // - msg n
-        let mut chat_history = vec![
-            ChatMessage::system(options.system_prompt.trim().to_string()),
-            ChatMessage::user(options.prompt.to_string()),
-        ];
+        let mut chat_history = match &options.system_prompt {
+            Some(sp) => vec![
+                ChatMessage::system(sp.trim().to_string()),
+                ChatMessage::user(options.prompt.to_string()),
+            ],
+            None => vec![ChatMessage::user(options.prompt.to_string())],
+        };
 
         for m in options.history.iter() {
             chat_history.push(match m {
@@ -229,9 +232,9 @@ impl Client for OllamaClient {
                 content,
                 invocations,
                 usage: res.final_data.map(|final_data| super::Usage {
-                        input_tokens: final_data.prompt_eval_count as u32,
-                        output_tokens: final_data.eval_count as u32,
-                    }),
+                    input_tokens: final_data.prompt_eval_count as u32,
+                    output_tokens: final_data.eval_count as u32,
+                }),
             })
         } else {
             log::warn!("model returned an empty message.");
@@ -239,9 +242,9 @@ impl Client for OllamaClient {
                 content: "".to_string(),
                 invocations: vec![],
                 usage: res.final_data.map(|final_data| super::Usage {
-                        input_tokens: final_data.prompt_eval_count as u32,
-                        output_tokens: final_data.eval_count as u32,
-                    }),
+                    input_tokens: final_data.prompt_eval_count as u32,
+                    output_tokens: final_data.eval_count as u32,
+                }),
             })
         }
     }
diff --git a/src/agent/generator/openai.rs b/src/agent/generator/openai.rs
index 08f5bb7..453e614 100644
--- a/src/agent/generator/openai.rs
+++ b/src/agent/generator/openai.rs
@@ -191,18 +191,25 @@ impl Client for OpenAIClient {
         state: SharedState,
         options: &ChatOptions,
     ) -> anyhow::Result<ChatResponse> {
-        let mut chat_history = vec![
-            crate::api::openai::Message {
-                role: Role::System,
-                content: Some(options.system_prompt.trim().to_string()),
-                tool_calls: None,
-            },
-            crate::api::openai::Message {
+        let mut chat_history = match &options.system_prompt {
+            Some(sp) => vec![
+                crate::api::openai::Message {
+                    role: Role::System,
+                    content: Some(sp.trim().to_string()),
+                    tool_calls: None,
+                },
+                crate::api::openai::Message {
+                    role: Role::User,
+                    content: Some(options.prompt.trim().to_string()),
+                    tool_calls: None,
+                },
+            ],
+            None => vec![crate::api::openai::Message {
                 role: Role::User,
                 content: Some(options.prompt.trim().to_string()),
                 tool_calls: None,
-            },
-        ];
+            }],
+        };
 
         for m in options.history.iter() {
             chat_history.push(match m {
diff --git a/src/agent/mod.rs b/src/agent/mod.rs
index ae34b41..f7aca43 100644
--- a/src/agent/mod.rs
+++ b/src/agent/mod.rs
@@ -108,6 +108,7 @@ pub struct Agent {
     serializer: serialization::Strategy,
     use_native_tools_format: bool,
+    user_only: bool,
 }
 
 impl Agent {
@@ -119,6 +120,7 @@ impl Agent {
         serializer: serialization::Strategy,
         conversation_window: ConversationWindow,
         force_strategy: bool,
+        user_only: bool,
         max_iterations: usize,
     ) -> Result<Self> {
         let use_native_tools_format = if force_strategy {
@@ -156,6 +158,7 @@ impl Agent {
             state,
             task_timeout,
             use_native_tools_format,
+            user_only,
             serializer,
             conversation_window,
         })
@@ -235,9 +238,10 @@ impl Agent {
     async fn on_state_update(&self, options: &ChatOptions, refresh: bool) -> Result<()> {
         let mut opts = options.clone();
         if refresh {
-            opts.system_prompt = self
-                .serializer
-                .system_prompt_for_state(&*self.state.lock().await)?;
+            opts.system_prompt = Some(
+                self.serializer
+                    .system_prompt_for_state(&*self.state.lock().await)?,
+            );
 
             let messages = self.state.lock().await.to_chat_history(&self.serializer)?;
 
@@ -368,6 +372,14 @@ impl Agent {
         let system_prompt = self.serializer.system_prompt_for_state(&mut_state)?;
         let prompt = mut_state.to_prompt()?;
+
+        let (system_prompt, prompt) = if self.user_only {
+            // combine with user prompt for models like the openai/o1 family
+            (None, format!("{system_prompt}\n\n{prompt}"))
+        } else {
+            (Some(system_prompt), prompt)
+        };
+
         let history = mut_state.to_chat_history(&self.serializer)?;
 
         let options = ChatOptions::new(system_prompt, prompt, history, self.conversation_window);
diff --git a/src/cli/cli.rs b/src/cli/cli.rs
index 0af3126..b208ba9 100644
--- a/src/cli/cli.rs
+++ b/src/cli/cli.rs
@@ -14,6 +14,9 @@ pub struct Args {
     /// Run in judge mode.
    #[arg(long)]
     pub judge_mode: bool,
+    /// Only rely on the user prompt. Use for models like the openai/o1 family that don't allow a system prompt.
+    #[arg(long)]
+    pub user_only: bool,
     /// Embedder string as <provider>://<model>@<host>:<port>
     #[arg(
         short = 'E',
diff --git a/src/cli/setup.rs b/src/cli/setup.rs
index dc8334b..9309bc4 100644
--- a/src/cli/setup.rs
+++ b/src/cli/setup.rs
@@ -101,6 +101,7 @@ pub async fn setup_agent(args: &cli::Args) -> Result<(Agent, events::Receiver)>
         args.serialization.clone(),
         conversation_window,
         args.force_format,
+        args.user_only,
         args.max_iterations,
     )
     .await?;
diff --git a/src/cli/ui/text.rs b/src/cli/ui/text.rs
index 5f6f6f4..ab74011 100644
--- a/src/cli/ui/text.rs
+++ b/src/cli/ui/text.rs
@@ -73,7 +73,7 @@ pub async fn consume_events(args: cli::Args, mut events_rx: Receiver) {
         if let Some(prompt_path) = &args.save_to {
             let data = format!(
                 "[SYSTEM PROMPT]\n\n{}\n\n[PROMPT]\n\n{}\n\n[CHAT]\n\n{}",
-                &opts.system_prompt,
+                &opts.system_prompt.unwrap_or_default(),
                 &opts.prompt,
                 opts.history
                     .iter()
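
Taken together, the generator changes above all apply the same pattern: the chat history is built from an `Option<String>` system prompt, and when `--user-only` is passed the agent folds the system prompt into the user prompt before the request is assembled. The standalone Rust sketch below illustrates that logic only; `Message`, `build_history` and `fold_user_only` are simplified stand-ins for this note, not types or functions from the crate.

// Stand-in message type; each real backend uses its own provider-specific struct
// (see the per-provider hunks above).
#[derive(Debug)]
struct Message {
    role: &'static str,
    content: String,
}

// Mirrors the `match &options.system_prompt` pattern applied to every generator:
// a system message is emitted only when a system prompt is present.
fn build_history(system_prompt: &Option<String>, prompt: &str) -> Vec<Message> {
    let mut history = Vec::new();
    if let Some(sp) = system_prompt {
        history.push(Message { role: "system", content: sp.trim().to_string() });
    }
    history.push(Message { role: "user", content: prompt.trim().to_string() });
    history
}

// Mirrors the --user-only branch in src/agent/mod.rs: drop the system role and
// prepend the system prompt to the user prompt instead.
fn fold_user_only(system_prompt: String, prompt: String) -> (Option<String>, String) {
    (None, format!("{system_prompt}\n\n{prompt}"))
}

fn main() {
    let (system_prompt, prompt) = fold_user_only(
        "You are a helpful agent.".to_string(),
        "List the open ports.".to_string(),
    );
    // With user-only folding the history contains a single "user" message,
    // which is what models that reject system prompts expect.
    println!("{:?}", build_history(&system_prompt, &prompt));
}

Keeping `system_prompt` optional at the `ChatOptions` level lets each backend decide how to degrade, instead of every caller special-casing models that reject the system role.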