From 8f74b05a3cec52f0358285c10b25c8a753da308b Mon Sep 17 00:00:00 2001 From: Kelvie Wong Date: Tue, 16 Jan 2024 00:05:07 -0800 Subject: [PATCH 1/2] Add an "extra_fields" config to localai models Because there are so many local AIs out there with a bunch of custom parameters you can set, this allows users to send in extra parameters to a local LLM runner, e.g. `instruction_template: Alpaca`, so that Mixtral can take a system prompt. --- src/client/localai.rs | 15 ++++++++++++++- src/client/model.rs | 8 ++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/client/localai.rs b/src/client/localai.rs index 3bc06704..17820f9a 100644 --- a/src/client/localai.rs +++ b/src/client/localai.rs @@ -45,6 +45,7 @@ impl LocalAIClient { Model::new(client_name, &v.name) .set_capabilities(v.capabilities) .set_max_tokens(v.max_tokens) + .set_extra_fields(v.extra_fields.clone()) .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS) }) .collect() @@ -53,7 +54,19 @@ impl LocalAIClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.get_api_key().ok(); - let body = openai_build_body(data, self.model.name.clone()); + let mut body = openai_build_body(data, self.model.name.clone()); + + // merge fields from extra_fields into body + if let Some(extra_fields) = &self.model.extra_fields { + let extra_fields = extra_fields + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>(); + + body.as_object_mut() + .unwrap() + .extend(extra_fields.into_iter()); + } let chat_endpoint = self .config diff --git a/src/client/model.rs b/src/client/model.rs index 88d32942..4e912918 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -11,6 +11,7 @@ pub type TokensCountFactors = (usize, usize); // (per-messages, bias) pub struct Model { pub client_name: String, pub name: String, + pub extra_fields: Option>, pub max_tokens: Option, pub tokens_count_factors: TokensCountFactors, pub capabilities: ModelCapabilities, 
@@ -27,6 +28,7 @@ impl Model { Self { client_name: client_name.into(), name: name.into(), + extra_fields: None, max_tokens: None, tokens_count_factors: Default::default(), capabilities: ModelCapabilities::Text, @@ -73,6 +75,11 @@ impl Model { self } + pub fn set_extra_fields(mut self, extra_fields: Option>) -> Self { + self.extra_fields = extra_fields; + self + } + pub fn set_max_tokens(mut self, max_tokens: Option) -> Self { match max_tokens { None | Some(0) => self.max_tokens = None, @@ -127,6 +134,7 @@ impl Model { #[derive(Debug, Clone, Deserialize)] pub struct ModelConfig { pub name: String, + pub extra_fields: Option>, pub max_tokens: Option, #[serde(deserialize_with = "deserialize_capabilities")] #[serde(default = "default_capabilities")] From 62718c1c1d58b019745cd012f7c62c44c6dfcfed Mon Sep 17 00:00:00 2001 From: sigoden Date: Tue, 30 Jan 2024 11:36:08 +0000 Subject: [PATCH 2/2] support ollama --- config.example.yaml | 2 ++ src/client/localai.rs | 13 +------------ src/client/model.rs | 19 ++++++++++++++++--- src/client/ollama.rs | 5 ++++- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index ce12cce2..22b71cc3 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -38,6 +38,8 @@ clients: models: - name: mistral max_tokens: 8192 + extra_fields: # Optional field, set custom parameters + key: value - name: llava max_tokens: 8192 capabilities: text,vision # Optional field, possible values: text, vision diff --git a/src/client/localai.rs b/src/client/localai.rs index 17820f9a..77950391 100644 --- a/src/client/localai.rs +++ b/src/client/localai.rs @@ -55,18 +55,7 @@ impl LocalAIClient { let api_key = self.get_api_key().ok(); let mut body = openai_build_body(data, self.model.name.clone()); - - // merge fields from extra_fields into body - if let Some(extra_fields) = &self.model.extra_fields { - let extra_fields = extra_fields - .iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>(); - - 
body.as_object_mut() - .unwrap() - .extend(extra_fields.into_iter()); - } + self.model.merge_extra_fields(&mut body); let chat_endpoint = self .config diff --git a/src/client/model.rs b/src/client/model.rs index 4e912918..b29166ea 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -11,8 +11,8 @@ pub type TokensCountFactors = (usize, usize); // (per-messages, bias) pub struct Model { pub client_name: String, pub name: String, - pub extra_fields: Option>, pub max_tokens: Option, + pub extra_fields: Option>, pub tokens_count_factors: TokensCountFactors, pub capabilities: ModelCapabilities, } @@ -75,7 +75,10 @@ impl Model { self } - pub fn set_extra_fields(mut self, extra_fields: Option>) -> Self { + pub fn set_extra_fields( + mut self, + extra_fields: Option>, + ) -> Self { self.extra_fields = extra_fields; self } @@ -129,13 +132,23 @@ impl Model { } Ok(()) } + + pub fn merge_extra_fields(&self, body: &mut serde_json::Value) { + if let (Some(body), Some(extra_fields)) = (body.as_object_mut(), &self.extra_fields) { + for (k, v) in extra_fields { + if !body.contains_key(k) { + body.insert(k.clone(), v.clone()); + } + } + } + } } #[derive(Debug, Clone, Deserialize)] pub struct ModelConfig { pub name: String, - pub extra_fields: Option>, pub max_tokens: Option, + pub extra_fields: Option>, #[serde(deserialize_with = "deserialize_capabilities")] #[serde(default = "default_capabilities")] pub capabilities: ModelCapabilities, diff --git a/src/client/ollama.rs b/src/client/ollama.rs index 0705f2e9..a3e23c90 100644 --- a/src/client/ollama.rs +++ b/src/client/ollama.rs @@ -69,6 +69,7 @@ impl OllamaClient { Model::new(client_name, &v.name) .set_capabilities(v.capabilities) .set_max_tokens(v.max_tokens) + .set_extra_fields(v.extra_fields.clone()) .set_tokens_count_factors(TOKENS_COUNT_FACTORS) }) .collect() @@ -77,7 +78,9 @@ impl OllamaClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.get_api_key().ok(); - let 
body = build_body(data, self.model.name.clone())?; + let mut body = build_body(data, self.model.name.clone())?; + + self.model.merge_extra_fields(&mut body); let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");