Skip to content

Commit

Permalink
feat: add ernie:ernie-bot-8k qianwen:qwen-max (#252)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Nov 27, 2023
1 parent 2508d56 commit 18f16c6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
5 changes: 3 additions & 2 deletions src/client/ernie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ use std::env;
const API_BASE: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1";
const ACCESS_TOKEN_URL: &str = "https://aip.baidubce.com/oauth/2.0/token";

const MODELS: [(&str, &str); 3] = [
("eb-instant", "/wenxinworkshop/chat/eb-instant"),
const MODELS: [(&str, &str); 4] = [
("ernie-bot-turbo", "/wenxinworkshop/chat/eb-instant"),
("ernie-bot", "/wenxinworkshop/chat/completions"),
("ernie-bot-8k", "/wenxinworkshop/chat/ernie_bot_8k"),
("ernie-bot-4", "/wenxinworkshop/chat/completions_pro"),
];

Expand Down
26 changes: 13 additions & 13 deletions src/client/qianwen.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use super::{QianwenClient, Client, ExtraConfig, PromptType, SendData, Model};
use super::{Client, ExtraConfig, Model, PromptType, QianwenClient, SendData};

use crate::{
config::GlobalConfig,
render::ReplyHandler,
utils::PromptKind,
};
use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};

use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
Expand All @@ -17,7 +13,11 @@ use serde_json::{json, Value};
const API_URL: &str =
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation";

const MODELS: [(&str, usize); 2] = [("qwen-turbo", 6144), ("qwen-plus", 6144)];
const MODELS: [(&str, usize); 3] = [
("qwen-turbo", 6144),
("qwen-plus", 6144),
("qwen-max", 6144),
];

#[derive(Debug, Clone, Deserialize, Default)]
pub struct QianwenConfig {
Expand Down Expand Up @@ -58,7 +58,9 @@ impl QianwenClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens)| Model::new(client_name, name).set_max_tokens(Some(max_tokens)))
.map(|(name, max_tokens)| {
Model::new(client_name, name).set_max_tokens(Some(max_tokens))
})
.collect()
}

Expand All @@ -83,16 +85,14 @@ async fn send_message(builder: RequestBuilder) -> Result<String> {
let data: Value = builder.send().await?.json().await?;
check_error(&data)?;

let output = data["output"]["text"].as_str()
let output = data["output"]["text"]
.as_str()
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;

Ok(output.to_string())
}

async fn send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
) -> Result<()> {
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
let mut es = builder.eventsource()?;

while let Some(event) = es.next().await {
Expand Down

0 comments on commit 18f16c6

Please sign in to comment.