new: implemented new --user-only argument for openai/o1 models
evilsocket committed Dec 10, 2024
1 parent 15a99b3 commit 5ebae31
Showing 10 changed files with 87 additions and 46 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -31,17 +31,19 @@ Nerve features integrations for any model accessible via the following providers
 |----------|----------------------------|------------------|
 | **Ollama** | - | `ollama://llama3@localhost:11434` |
 | **Groq** | `GROQ_API_KEY` | `groq://llama3-70b-8192` |
-| **OpenAI** | `OPENAI_API_KEY` | `openai://gpt-4` |
+| **OpenAI**¹ | `OPENAI_API_KEY` | `openai://gpt-4` |
 | **Fireworks** | `LLM_FIREWORKS_KEY` | `fireworks://llama-v3-70b-instruct` |
-| **Huggingface**¹ | `HF_API_TOKEN` | `hf://[email protected]` |
+| **Huggingface**² | `HF_API_TOKEN` | `hf://[email protected]` |
 | **Anthropic** | `ANTHROPIC_API_KEY` | `anthropic://claude` |
 | **Nvidia NIM** | `NIM_API_KEY` | `nim://nvidia/nemotron-4-340b-instruct` |
 | **DeepSeek** | `DEEPSEEK_API_KEY` | `deepseek://deepseek-chat` |
 | **xAI** | `XAI_API_KEY` | `xai://grok-beta` |
 | **Mistral.ai** | `MISTRAL_API_KEY` | `mistral://mistral-large-latest` |
 | **Novita** | `NOVITA_API_KEY` | `novita://meta-llama/llama-3.1-70b-instruct` |

-¹ Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
+¹ **o1-preview and o1 models do not support function calling directly** and do not support a system prompt. You can work around this by passing the `--user-only` flag on the command line (see the example below).
+
+² Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
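
For example, combining the generator string from the table with the new flag, an o1 run would look like `nerve -G "openai://o1-preview" -T <tasklet> --user-only`. Everything in that invocation other than `--user-only` is illustrative and assumes the project's usual `-G`/`-T` flags; it is not part of this commit.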

## Installing with Cargo

5 changes: 4 additions & 1 deletion src/agent/generator/anthropic.rs
@@ -183,7 +183,10 @@ impl Client for AnthropicClient {

         let request_body = MessagesRequestBody {
             model: self.model,
-            system: Some(SystemPrompt::new(options.system_prompt.trim())),
+            system: match &options.system_prompt {
+                Some(sp) => Some(SystemPrompt::new(sp.trim())),
+                None => None,
+            },
             messages,
             max_tokens,
             tools: if tools.is_empty() { None } else { Some(tools) },
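
As an aside, the added `match` is exactly what `Option::map` expresses; an equivalent, more idiomatic spelling of the new field initializer (same behavior; a suggestion, not what the commit ships) would be:

```rust
// equivalent to the Some/None match above
system: options
    .system_prompt
    .as_ref()
    .map(|sp| SystemPrompt::new(sp.trim())),
```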
38 changes: 24 additions & 14 deletions src/agent/generator/groq.rs
@@ -114,20 +114,30 @@ impl Client for GroqClient {
         state: SharedState,
         options: &ChatOptions,
     ) -> anyhow::Result<ChatResponse> {
-        let mut chat_history = vec![
-            crate::api::groq::completion::message::Message::SystemMessage {
-                role: Some("system".to_string()),
-                content: Some(options.system_prompt.trim().to_string()),
-                name: None,
-                tool_call_id: None,
-            },
-            crate::api::groq::completion::message::Message::UserMessage {
-                role: Some("user".to_string()),
-                content: Some(options.prompt.trim().to_string()),
-                name: None,
-                tool_call_id: None,
-            },
-        ];
+        let mut chat_history = match &options.system_prompt {
+            Some(sp) => vec![
+                crate::api::groq::completion::message::Message::SystemMessage {
+                    role: Some("system".to_string()),
+                    content: Some(sp.trim().to_string()),
+                    name: None,
+                    tool_call_id: None,
+                },
+                crate::api::groq::completion::message::Message::UserMessage {
+                    role: Some("user".to_string()),
+                    content: Some(options.prompt.trim().to_string()),
+                    name: None,
+                    tool_call_id: None,
+                },
+            ],
+            None => vec![
+                crate::api::groq::completion::message::Message::UserMessage {
+                    role: Some("user".to_string()),
+                    content: Some(options.prompt.trim().to_string()),
+                    name: None,
+                    tool_call_id: None,
+                },
+            ],
+        };
 
         let mut call_idx = 0;
 
10 changes: 5 additions & 5 deletions src/agent/generator/mod.rs
@@ -15,13 +15,13 @@ mod deepseek;
 mod fireworks;
 mod groq;
 mod huggingface;
+mod mistral;
 mod nim;
 mod novita;
 mod ollama;
 mod openai;
 mod openai_compatible;
 mod xai;
-mod mistral;
 
 pub(crate) mod history;
 mod options;
@@ -36,14 +36,14 @@ lazy_static! {

 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ChatOptions {
-    pub system_prompt: String,
+    pub system_prompt: Option<String>,
     pub prompt: String,
     pub history: ChatHistory,
 }
 
 impl ChatOptions {
     pub fn new(
-        system_prompt: String,
+        system_prompt: Option<String>,
         prompt: String,
         conversation: Vec<Message>,
         history_strategy: ConversationWindow,
@@ -97,8 +97,8 @@ pub trait Client: mini_rag::Embedder + Send + Sync {

     async fn chat(&self, state: SharedState, options: &ChatOptions) -> Result<ChatResponse>;
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        Ok(false)
+    async fn supports_system_prompt(&self) -> Result<bool> {
+        Ok(true)
     }
 
     async fn check_rate_limit(&self, error: &str) -> bool {
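
Nothing shown in this diff overrides or calls `supports_system_prompt` yet; presumably a client whose backing model rejects system messages (the o1 family being the motivating case) would opt out inside its `impl Client for ... {}` block along these lines (hypothetical, not part of this commit):

```rust
// Hypothetical override: report that this backend cannot accept a system
// prompt, so callers know to fold it into the user prompt instead.
async fn supports_system_prompt(&self) -> Result<bool> {
    Ok(false)
}
```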
23 changes: 13 additions & 10 deletions src/agent/generator/ollama.rs
@@ -102,10 +102,13 @@ impl Client for OllamaClient {
         // - msg 1
         // - ...
         // - msg n
-        let mut chat_history = vec![
-            ChatMessage::system(options.system_prompt.trim().to_string()),
-            ChatMessage::user(options.prompt.to_string()),
-        ];
+        let mut chat_history = match &options.system_prompt {
+            Some(sp) => vec![
+                ChatMessage::system(sp.trim().to_string()),
+                ChatMessage::user(options.prompt.to_string()),
+            ],
+            None => vec![ChatMessage::user(options.prompt.to_string())],
+        };
 
         for m in options.history.iter() {
             chat_history.push(match m {
@@ -229,19 +232,19 @@ impl Client for OllamaClient {
                 content,
                 invocations,
                 usage: res.final_data.map(|final_data| super::Usage {
-                        input_tokens: final_data.prompt_eval_count as u32,
-                        output_tokens: final_data.eval_count as u32,
-                    }),
+                    input_tokens: final_data.prompt_eval_count as u32,
+                    output_tokens: final_data.eval_count as u32,
+                }),
             })
         } else {
             log::warn!("model returned an empty message.");
             Ok(ChatResponse {
                 content: "".to_string(),
                 invocations: vec![],
                 usage: res.final_data.map(|final_data| super::Usage {
-                        input_tokens: final_data.prompt_eval_count as u32,
-                        output_tokens: final_data.eval_count as u32,
-                    }),
+                    input_tokens: final_data.prompt_eval_count as u32,
+                    output_tokens: final_data.eval_count as u32,
+                }),
             })
         }
     }
25 changes: 16 additions & 9 deletions src/agent/generator/openai.rs
@@ -191,18 +191,25 @@ impl Client for OpenAIClient {
         state: SharedState,
         options: &ChatOptions,
     ) -> anyhow::Result<ChatResponse> {
-        let mut chat_history = vec![
-            crate::api::openai::Message {
-                role: Role::System,
-                content: Some(options.system_prompt.trim().to_string()),
-                tool_calls: None,
-            },
-            crate::api::openai::Message {
+        let mut chat_history = match &options.system_prompt {
+            Some(sp) => vec![
+                crate::api::openai::Message {
+                    role: Role::System,
+                    content: Some(sp.trim().to_string()),
+                    tool_calls: None,
+                },
+                crate::api::openai::Message {
+                    role: Role::User,
+                    content: Some(options.prompt.trim().to_string()),
+                    tool_calls: None,
+                },
+            ],
+            None => vec![crate::api::openai::Message {
                 role: Role::User,
                 content: Some(options.prompt.trim().to_string()),
                 tool_calls: None,
-            },
-        ];
+            }],
+        };
 
         for m in options.history.iter() {
             chat_history.push(match m {
18 changes: 15 additions & 3 deletions src/agent/mod.rs
@@ -108,6 +108,7 @@ pub struct Agent {

     serializer: serialization::Strategy,
     use_native_tools_format: bool,
+    user_only: bool,
 }
 
 impl Agent {
@@ -119,6 +120,7 @@ impl Agent {
         serializer: serialization::Strategy,
         conversation_window: ConversationWindow,
         force_strategy: bool,
+        user_only: bool,
         max_iterations: usize,
     ) -> Result<Self> {
         let use_native_tools_format = if force_strategy {
@@ -156,6 +158,7 @@ impl Agent {
             state,
             task_timeout,
             use_native_tools_format,
+            user_only,
             serializer,
             conversation_window,
         })
@@ -235,9 +238,10 @@ impl Agent {
     async fn on_state_update(&self, options: &ChatOptions, refresh: bool) -> Result<()> {
         let mut opts = options.clone();
         if refresh {
-            opts.system_prompt = self
-                .serializer
-                .system_prompt_for_state(&*self.state.lock().await)?;
+            opts.system_prompt = Some(
+                self.serializer
+                    .system_prompt_for_state(&*self.state.lock().await)?,
+            );
 
             let messages = self.state.lock().await.to_chat_history(&self.serializer)?;
 
@@ -368,6 +372,14 @@ impl Agent {

         let system_prompt = self.serializer.system_prompt_for_state(&mut_state)?;
         let prompt = mut_state.to_prompt()?;
+
+        let (system_prompt, prompt) = if self.user_only {
+            // combine with user prompt for models like the openai/o1 family
+            (None, format!("{system_prompt}\n\n{prompt}"))
+        } else {
+            (Some(system_prompt), prompt)
+        };
+
         let history = mut_state.to_chat_history(&self.serializer)?;
         let options = ChatOptions::new(system_prompt, prompt, history, self.conversation_window);
 
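
For concreteness, here is a minimal, self-contained reproduction of the merge performed by the new branch above (the prompt strings are illustrative, not from the repository):

```rust
// Minimal sketch of the --user-only merge: with the flag set, the system
// prompt is folded into the user prompt and no system message is sent.
fn main() {
    let system_prompt = "You are a helpful agent.".to_string();
    let prompt = "Scan the target for open ports.".to_string();
    let user_only = true; // what the new CLI flag toggles

    let (system_prompt, prompt) = if user_only {
        (None::<String>, format!("{system_prompt}\n\n{prompt}"))
    } else {
        (Some(system_prompt), prompt)
    };

    assert!(system_prompt.is_none());
    assert_eq!(
        prompt,
        "You are a helpful agent.\n\nScan the target for open ports."
    );
}
```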
3 changes: 3 additions & 0 deletions src/cli/cli.rs
@@ -14,6 +14,9 @@ pub struct Args {
     /// Run in judge mode.
     #[arg(long)]
     pub judge_mode: bool,
+    /// Only rely on the user prompt. Use for models like the openai/o1 family that don't allow a system prompt.
+    #[arg(long)]
+    pub user_only: bool,
     /// Embedder string as <type>://<model name>@<host>:<port>
     #[arg(
         short = 'E',
1 change: 1 addition & 0 deletions src/cli/setup.rs
@@ -101,6 +101,7 @@ pub async fn setup_agent(args: &cli::Args) -> Result<(Agent, events::Receiver)>
         args.serialization.clone(),
         conversation_window,
         args.force_format,
+        args.user_only,
         args.max_iterations,
     )
     .await?;
2 changes: 1 addition & 1 deletion src/cli/ui/text.rs
@@ -73,7 +73,7 @@ pub async fn consume_events(args: cli::Args, mut events_rx: Receiver) {
         if let Some(prompt_path) = &args.save_to {
             let data = format!(
                 "[SYSTEM PROMPT]\n\n{}\n\n[PROMPT]\n\n{}\n\n[CHAT]\n\n{}",
-                &opts.system_prompt,
+                &opts.system_prompt.unwrap_or_default(),
                 &opts.prompt,
                 opts.history
                     .iter()
