From e3bca00986423c1f05fb1bbdfd7ca325542a772c Mon Sep 17 00:00:00 2001 From: Charlie Bacon Date: Tue, 11 Jun 2024 23:51:02 +0200 Subject: [PATCH] added hashmap to split conversations into their chats --- src/bot.rs | 19 +++++++++++++++---- src/gpt.rs | 19 +++++++++++++------ src/lib.rs | 9 ++++++--- src/main.rs | 6 ++---- 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/src/bot.rs b/src/bot.rs index f636eb1..e56d142 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -1,3 +1,5 @@ +use std::collections::VecDeque; + use crate::{gpt, types}; use teloxide::{prelude::*, utils::command::BotCommands}; @@ -41,7 +43,7 @@ pub async fn handle_commands( bot.send_message(msg.chat.id, answer).await?; } Command::Mediate => { - let answer = match gpt::mediate(messages).await { + let answer = match gpt::mediate(messages, msg.chat.id).await { Ok(response) => response, Err(err) => format!("Error getting an answer from OpenAI: {err}"), }; @@ -55,10 +57,19 @@ pub async fn handle_commands( pub async fn handle_messages(messages: types::Messages, msg: Message) -> ResponseResult<()> { let mut messages_lock = messages.write().await; - if messages_lock.len() == types::STORE_CAPACITY { - messages_lock.pop_front(); + match messages_lock.get_mut(&msg.chat.id) { + Some(buffer) => { + if buffer.len() == types::STORE_CAPACITY { + buffer.pop_front(); + } + buffer.push_back(msg.clone()); + } + None => { + let mut buffer = VecDeque::new(); + buffer.push_back(msg.clone()); + messages_lock.insert(msg.chat.id, buffer); + } } - messages_lock.push_back(msg.clone()); Ok(()) } diff --git a/src/gpt.rs b/src/gpt.rs index 244c317..4d0c4ac 100644 --- a/src/gpt.rs +++ b/src/gpt.rs @@ -1,4 +1,4 @@ -use crate::types::Messages; +use crate::types; use async_openai::{ config::OpenAIConfig, types::{ @@ -9,6 +9,7 @@ use async_openai::{ }; use std::{env, error::Error}; use string_builder::Builder; +use teloxide::types::ChatId; pub async fn ask(question: String) -> Result> { let client = init_gpt_client()?; @@ -18,7 +19,7 @@ pub async fn ask(question: String) -> Result> { .max_tokens(512u16) .messages(vec![ ChatCompletionRequestSystemMessageArgs::default() - .content(crate::types::PERSONALITY) + .content(types::PERSONALITY) .build()? .into(), ChatCompletionRequestUserMessageArgs::default() @@ -37,11 +38,17 @@ pub async fn ask(question: String) -> Result> { } // TODO: Better user and message handling -pub async fn mediate(messages: Messages) -> Result> { +pub async fn mediate(messages: types::Messages, chat_id: ChatId) -> Result> { let messages_lock = messages.read().await; let mut conversation = Builder::default(); - for message in messages_lock.iter() { + + let buffer = match messages_lock.get(&chat_id) { + Some(b) => b, + None => return Err("Buffer with selected ChatId does not exist".into()), + }; + + for message in buffer.iter() { conversation.append(message.from().unwrap().full_name()); conversation.append(": "); conversation.append(message.text().unwrap()); @@ -56,11 +63,11 @@ pub async fn mediate(messages: Messages) -> Result> { .max_tokens(4096u16) .messages(vec![ ChatCompletionRequestSystemMessageArgs::default() - .content(crate::types::PERSONALITY) + .content(types::PERSONALITY) .build()? .into(), ChatCompletionRequestSystemMessageArgs::default() - .content(crate::types::MEDIATE_QUERY) + .content(types::MEDIATE_QUERY) .build()? .into(), ChatCompletionRequestSystemMessageArgs::default() diff --git a/src/lib.rs b/src/lib.rs index dd5e6d5..69b9633 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,13 +2,16 @@ pub mod bot; pub mod gpt; pub mod types { - use std::{collections::VecDeque, sync::Arc}; - use teloxide::types::Message; + use std::{ + collections::{HashMap, VecDeque}, + sync::Arc, + }; + use teloxide::types::{ChatId, Message}; use tokio::sync::RwLock; pub const PERSONALITY: &str= "Eres un asistente andaluz con jerga informal y algo irónica. Ayudas a todo aquel que te necesite, no sin antes quejarte un poco, ya que eres algo vago."; pub const MEDIATE_QUERY: &str= "A partir de los siguientes mensajes, analiza una posible discusión y da la razón a alguno de los implicados, con una pequeña argumentación."; pub const STORE_CAPACITY: usize = 200; - pub type Messages = Arc>>; + pub type Messages = Arc>>>; } diff --git a/src/main.rs b/src/main.rs index 2848de9..47df712 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{collections::VecDeque, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; use telitairos_bot::{bot, types}; use teloxide::prelude::*; use tokio::sync::RwLock; @@ -10,9 +10,7 @@ async fn main() { let bot = Bot::from_env(); - let messages_store: types::Messages = Arc::new(RwLock::new(VecDeque::with_capacity( - crate::types::STORE_CAPACITY, - ))); + let messages_store: types::Messages = Arc::new(RwLock::new(HashMap::new())); let handler = dptree::entry() .branch(