Commit 5ebae31

new: implemented new --user-only argument for openai/o1 models
1 parent 15a99b3 commit 5ebae31

File tree

10 files changed: +87 -46 lines changed

README.md

+5-3
```diff
@@ -31,17 +31,19 @@ Nerve features integrations for any model accessible via the following providers
 |----------|----------------------------|------------------|
 | **Ollama** | - | `ollama://llama3@localhost:11434` |
 | **Groq** | `GROQ_API_KEY` | `groq://llama3-70b-8192` |
-| **OpenAI** | `OPENAI_API_KEY` | `openai://gpt-4` |
+| **OpenAI**¹ | `OPENAI_API_KEY` | `openai://gpt-4` |
 | **Fireworks** | `LLM_FIREWORKS_KEY` | `fireworks://llama-v3-70b-instruct` |
-| **Huggingface**¹ | `HF_API_TOKEN` | `hf://[email protected]` |
+| **Huggingface**² | `HF_API_TOKEN` | `hf://[email protected]` |
 | **Anthropic** | `ANTHROPIC_API_KEY` | `anthropic://claude` |
 | **Nvidia NIM** | `NIM_API_KEY` | `nim://nvidia/nemotron-4-340b-instruct` |
 | **DeepSeek** | `DEEPSEEK_API_KEY` | `deepseek://deepseek-chat` |
 | **xAI** | `XAI_API_KEY` | `xai://grok-beta` |
 | **Mistral.ai** | `MISTRAL_API_KEY` | `mistral://mistral-large-latest` |
 | **Novita** | `NOVITA_API_KEY` | `novita://meta-llama/llama-3.1-70b-instruct` |
 
-¹ Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
+¹ **o1-preview and o1 models do not support function calling directly** and do not support a system prompt. It is possible to work around this by adding the `--user-only` flag to the command line.
+
+² Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
 
 ## Installing with Cargo
```
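In practice the flag folds the (now optional) system prompt into the user message before any provider is called. Below is a self-contained sketch of the resulting message shape; the `Msg` struct and the prompt values are illustrative only, while the real provider-specific structs appear in the per-file diffs that follow.

```rust
// Simplified stand-in for the provider message types shown in the diffs below.
struct Msg {
    role: &'static str,
    content: String,
}

fn build_history(system_prompt: Option<String>, prompt: String) -> Vec<Msg> {
    match system_prompt {
        // default: separate system and user messages
        Some(sp) => vec![
            Msg { role: "system", content: sp },
            Msg { role: "user", content: prompt },
        ],
        // --user-only: the system prompt was already folded into `prompt`,
        // so no system message is emitted at all
        None => vec![Msg { role: "user", content: prompt }],
    }
}

fn main() {
    let history = build_history(None, "instructions\n\ntask".to_string());
    assert_eq!(history.len(), 1);
    assert_eq!(history[0].role, "user");
}
```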

src/agent/generator/anthropic.rs

+4-1
```diff
@@ -183,7 +183,10 @@ impl Client for AnthropicClient {
 
         let request_body = MessagesRequestBody {
             model: self.model,
-            system: Some(SystemPrompt::new(options.system_prompt.trim())),
+            system: match &options.system_prompt {
+                Some(sp) => Some(SystemPrompt::new(sp.trim())),
+                None => None,
+            },
             messages,
             max_tokens,
             tools: if tools.is_empty() { None } else { Some(tools) },
```
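As a side note, the `match` above is the long form of an `Option::map`. A runnable equivalent, with `SystemPrompt` stubbed out purely for illustration:

```rust
// Stub standing in for the anthropic crate's SystemPrompt type.
struct SystemPrompt(String);

impl SystemPrompt {
    fn new(s: &str) -> Self {
        Self(s.to_string())
    }
}

fn main() {
    let system_prompt: Option<String> = Some("  be concise  ".into());

    // Same result as the match in the hunk above.
    let system = system_prompt.as_ref().map(|sp| SystemPrompt::new(sp.trim()));
    assert!(system.is_some());
}
```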

src/agent/generator/groq.rs

+24-14
```diff
@@ -114,20 +114,30 @@ impl Client for GroqClient {
         state: SharedState,
         options: &ChatOptions,
     ) -> anyhow::Result<ChatResponse> {
-        let mut chat_history = vec![
-            crate::api::groq::completion::message::Message::SystemMessage {
-                role: Some("system".to_string()),
-                content: Some(options.system_prompt.trim().to_string()),
-                name: None,
-                tool_call_id: None,
-            },
-            crate::api::groq::completion::message::Message::UserMessage {
-                role: Some("user".to_string()),
-                content: Some(options.prompt.trim().to_string()),
-                name: None,
-                tool_call_id: None,
-            },
-        ];
+        let mut chat_history = match &options.system_prompt {
+            Some(sp) => vec![
+                crate::api::groq::completion::message::Message::SystemMessage {
+                    role: Some("system".to_string()),
+                    content: Some(sp.trim().to_string()),
+                    name: None,
+                    tool_call_id: None,
+                },
+                crate::api::groq::completion::message::Message::UserMessage {
+                    role: Some("user".to_string()),
+                    content: Some(options.prompt.trim().to_string()),
+                    name: None,
+                    tool_call_id: None,
+                },
+            ],
+            None => vec![
+                crate::api::groq::completion::message::Message::UserMessage {
+                    role: Some("user".to_string()),
+                    content: Some(options.prompt.trim().to_string()),
+                    name: None,
+                    tool_call_id: None,
+                },
+            ],
+        };
 
         let mut call_idx = 0;
```
src/agent/generator/mod.rs

+5-5
```diff
@@ -15,13 +15,13 @@ mod deepseek;
 mod fireworks;
 mod groq;
 mod huggingface;
+mod mistral;
 mod nim;
 mod novita;
 mod ollama;
 mod openai;
 mod openai_compatible;
 mod xai;
-mod mistral;
 
 pub(crate) mod history;
 mod options;
@@ -36,14 +36,14 @@ lazy_static! {
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ChatOptions {
-    pub system_prompt: String,
+    pub system_prompt: Option<String>,
     pub prompt: String,
     pub history: ChatHistory,
 }
 
 impl ChatOptions {
     pub fn new(
-        system_prompt: String,
+        system_prompt: Option<String>,
         prompt: String,
         conversation: Vec<Message>,
         history_strategy: ConversationWindow,
@@ -97,8 +97,8 @@ pub trait Client: mini_rag::Embedder + Send + Sync {
 
     async fn chat(&self, state: SharedState, options: &ChatOptions) -> Result<ChatResponse>;
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        Ok(false)
+    async fn supports_system_prompt(&self) -> Result<bool> {
+        Ok(true)
     }
 
     async fn check_rate_limit(&self, error: &str) -> bool {
```
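The new `supports_system_prompt` hook defaults to `true`, so only providers that cannot take a system role need to override it. A de-async'd sketch of the pattern (the real trait method is `async`, and the `O1Client` below is hypothetical, not part of this diff):

```rust
trait Client {
    // Default added by this commit: assume system prompts are supported.
    fn supports_system_prompt(&self) -> bool {
        true
    }
}

// Hypothetical provider for o1-style models that reject the system role.
struct O1Client;

impl Client for O1Client {
    fn supports_system_prompt(&self) -> bool {
        false
    }
}

fn main() {
    assert!(!O1Client.supports_system_prompt());
}
```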

src/agent/generator/ollama.rs

+13-10
```diff
@@ -102,10 +102,13 @@ impl Client for OllamaClient {
         // - msg 1
         // - ...
         // - msg n
-        let mut chat_history = vec![
-            ChatMessage::system(options.system_prompt.trim().to_string()),
-            ChatMessage::user(options.prompt.to_string()),
-        ];
+        let mut chat_history = match &options.system_prompt {
+            Some(sp) => vec![
+                ChatMessage::system(sp.trim().to_string()),
+                ChatMessage::user(options.prompt.to_string()),
+            ],
+            None => vec![ChatMessage::user(options.prompt.to_string())],
+        };
 
         for m in options.history.iter() {
             chat_history.push(match m {
@@ -229,19 +232,19 @@ impl Client for OllamaClient {
                 content,
                 invocations,
                 usage: res.final_data.map(|final_data| super::Usage {
-                        input_tokens: final_data.prompt_eval_count as u32,
-                        output_tokens: final_data.eval_count as u32,
-                    }),
+                    input_tokens: final_data.prompt_eval_count as u32,
+                    output_tokens: final_data.eval_count as u32,
+                }),
             })
         } else {
             log::warn!("model returned an empty message.");
             Ok(ChatResponse {
                 content: "".to_string(),
                 invocations: vec![],
                 usage: res.final_data.map(|final_data| super::Usage {
-                        input_tokens: final_data.prompt_eval_count as u32,
-                        output_tokens: final_data.eval_count as u32,
-                    }),
+                    input_tokens: final_data.prompt_eval_count as u32,
+                    output_tokens: final_data.eval_count as u32,
+                }),
             })
         }
     }
```

src/agent/generator/openai.rs

+16-9
```diff
@@ -191,18 +191,25 @@ impl Client for OpenAIClient {
         state: SharedState,
         options: &ChatOptions,
     ) -> anyhow::Result<ChatResponse> {
-        let mut chat_history = vec![
-            crate::api::openai::Message {
-                role: Role::System,
-                content: Some(options.system_prompt.trim().to_string()),
-                tool_calls: None,
-            },
-            crate::api::openai::Message {
+        let mut chat_history = match &options.system_prompt {
+            Some(sp) => vec![
+                crate::api::openai::Message {
+                    role: Role::System,
+                    content: Some(sp.trim().to_string()),
+                    tool_calls: None,
+                },
+                crate::api::openai::Message {
+                    role: Role::User,
+                    content: Some(options.prompt.trim().to_string()),
+                    tool_calls: None,
+                },
+            ],
+            None => vec![crate::api::openai::Message {
                 role: Role::User,
                 content: Some(options.prompt.trim().to_string()),
                 tool_calls: None,
-            },
-        ];
+            }],
+        };
 
         for m in options.history.iter() {
             chat_history.push(match m {
```

src/agent/mod.rs

+15-3
```diff
@@ -108,6 +108,7 @@ pub struct Agent {
 
     serializer: serialization::Strategy,
     use_native_tools_format: bool,
+    user_only: bool,
 }
 
 impl Agent {
@@ -119,6 +120,7 @@ impl Agent {
         serializer: serialization::Strategy,
         conversation_window: ConversationWindow,
         force_strategy: bool,
+        user_only: bool,
         max_iterations: usize,
     ) -> Result<Self> {
         let use_native_tools_format = if force_strategy {
@@ -156,6 +158,7 @@ impl Agent {
             state,
             task_timeout,
             use_native_tools_format,
+            user_only,
             serializer,
             conversation_window,
         })
@@ -235,9 +238,10 @@ impl Agent {
     async fn on_state_update(&self, options: &ChatOptions, refresh: bool) -> Result<()> {
         let mut opts = options.clone();
         if refresh {
-            opts.system_prompt = self
-                .serializer
-                .system_prompt_for_state(&*self.state.lock().await)?;
+            opts.system_prompt = Some(
+                self.serializer
+                    .system_prompt_for_state(&*self.state.lock().await)?,
+            );
 
             let messages = self.state.lock().await.to_chat_history(&self.serializer)?;
 
@@ -368,6 +372,14 @@ impl Agent {
 
         let system_prompt = self.serializer.system_prompt_for_state(&mut_state)?;
         let prompt = mut_state.to_prompt()?;
+
+        let (system_prompt, prompt) = if self.user_only {
+            // combine with user prompt for models like the openai/o1 family
+            (None, format!("{system_prompt}\n\n{prompt}"))
+        } else {
+            (Some(system_prompt), prompt)
+        };
+
         let history = mut_state.to_chat_history(&self.serializer)?;
         let options = ChatOptions::new(system_prompt, prompt, history, self.conversation_window);
```
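Extracted from its surroundings, the shadowing fold in the last hunk behaves as follows; the prompt values here are made up, only the fold itself mirrors the committed code:

```rust
fn main() {
    // stand-ins for the serializer output in the hunk above
    let system_prompt = "you are an agent".to_string();
    let prompt = "solve the task".to_string();
    let user_only = true;

    // same shadowing fold as in src/agent/mod.rs
    let (system_prompt, prompt) = if user_only {
        (None, format!("{system_prompt}\n\n{prompt}"))
    } else {
        (Some(system_prompt), prompt)
    };

    assert_eq!(system_prompt, None);
    assert_eq!(prompt, "you are an agent\n\nsolve the task");
}
```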

src/cli/cli.rs

+3
```diff
@@ -14,6 +14,9 @@ pub struct Args {
     /// Run in judge mode.
     #[arg(long)]
     pub judge_mode: bool,
+    /// Only rely on the user prompt. Use for models like the openai/o1 family that don't allow a system prompt.
+    #[arg(long)]
+    pub user_only: bool,
     /// Embedder string as <type>://<model name>@<host>:<port>
     #[arg(
         short = 'E',
```
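With clap's derive API, the new field name is exposed as a kebab-case long flag, so `pub user_only: bool` is parsed from `--user-only`. A minimal, self-contained check (assumes clap 4 with the `derive` feature; the binary name in `parse_from` is illustrative):

```rust
use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Only rely on the user prompt.
    #[arg(long)]
    user_only: bool,
}

fn main() {
    // `--user-only` sets the flag; omitting it leaves the bool false.
    let args = Args::parse_from(["nerve", "--user-only"]);
    assert!(args.user_only);
}
```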

src/cli/setup.rs

+1
```diff
@@ -101,6 +101,7 @@ pub async fn setup_agent(args: &cli::Args) -> Result<(Agent, events::Receiver)>
         args.serialization.clone(),
         conversation_window,
         args.force_format,
+        args.user_only,
         args.max_iterations,
     )
     .await?;
```

src/cli/ui/text.rs

+1-1
```diff
@@ -73,7 +73,7 @@ pub async fn consume_events(args: cli::Args, mut events_rx: Receiver) {
         if let Some(prompt_path) = &args.save_to {
             let data = format!(
                 "[SYSTEM PROMPT]\n\n{}\n\n[PROMPT]\n\n{}\n\n[CHAT]\n\n{}",
-                &opts.system_prompt,
+                &opts.system_prompt.unwrap_or_default(),
                 &opts.prompt,
                 opts.history
                     .iter()
```
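`unwrap_or_default()` on an `Option<String>` substitutes an empty string when no system prompt was set, so a transcript saved under `--user-only` keeps its `[SYSTEM PROMPT]` header with empty content. A minimal check:

```rust
fn main() {
    let system_prompt: Option<String> = None; // the --user-only case
    assert_eq!(system_prompt.unwrap_or_default(), "");
}
```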
