Skip to content

Commit

Permalink
added hashmap to split conversations into their chats
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie Bacon committed Jun 11, 2024
1 parent 9957d4c commit e3bca00
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 17 deletions.
19 changes: 15 additions & 4 deletions src/bot.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::VecDeque;

use crate::{gpt, types};
use teloxide::{prelude::*, utils::command::BotCommands};

Expand Down Expand Up @@ -41,7 +43,7 @@ pub async fn handle_commands(
bot.send_message(msg.chat.id, answer).await?;
}
Command::Mediate => {
let answer = match gpt::mediate(messages).await {
let answer = match gpt::mediate(messages, msg.chat.id).await {
Ok(response) => response,
Err(err) => format!("Error getting an answer from OpenAI: {err}"),
};
Expand All @@ -55,10 +57,19 @@ pub async fn handle_commands(

pub async fn handle_messages(messages: types::Messages, msg: Message) -> ResponseResult<()> {
let mut messages_lock = messages.write().await;
if messages_lock.len() == types::STORE_CAPACITY {
messages_lock.pop_front();
match messages_lock.get_mut(&msg.chat.id) {
Some(buffer) => {
if buffer.len() == types::STORE_CAPACITY {
buffer.pop_front();
}
buffer.push_back(msg.clone());
}
None => {
let mut buffer = VecDeque::new();
buffer.push_back(msg.clone());
messages_lock.insert(msg.chat.id, buffer);
}
}
messages_lock.push_back(msg.clone());

Ok(())
}
19 changes: 13 additions & 6 deletions src/gpt.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::types::Messages;
use crate::types;
use async_openai::{
config::OpenAIConfig,
types::{
Expand All @@ -9,6 +9,7 @@ use async_openai::{
};
use std::{env, error::Error};
use string_builder::Builder;
use teloxide::types::ChatId;

pub async fn ask(question: String) -> Result<String, Box<dyn Error>> {
let client = init_gpt_client()?;
Expand All @@ -18,7 +19,7 @@ pub async fn ask(question: String) -> Result<String, Box<dyn Error>> {
.max_tokens(512u16)
.messages(vec![
ChatCompletionRequestSystemMessageArgs::default()
.content(crate::types::PERSONALITY)
.content(types::PERSONALITY)
.build()?
.into(),
ChatCompletionRequestUserMessageArgs::default()
Expand All @@ -37,11 +38,17 @@ pub async fn ask(question: String) -> Result<String, Box<dyn Error>> {
}

// TODO: Better user and message handling
pub async fn mediate(messages: Messages) -> Result<String, Box<dyn Error>> {
pub async fn mediate(messages: types::Messages, chat_id: ChatId) -> Result<String, Box<dyn Error>> {
let messages_lock = messages.read().await;

let mut conversation = Builder::default();
for message in messages_lock.iter() {

let buffer = match messages_lock.get(&chat_id) {
Some(b) => b,
None => return Err("Buffer with selected ChatId does not exist".into()),
};

for message in buffer.iter() {
conversation.append(message.from().unwrap().full_name());
conversation.append(": ");
conversation.append(message.text().unwrap());
Expand All @@ -56,11 +63,11 @@ pub async fn mediate(messages: Messages) -> Result<String, Box<dyn Error>> {
.max_tokens(4096u16)
.messages(vec![
ChatCompletionRequestSystemMessageArgs::default()
.content(crate::types::PERSONALITY)
.content(types::PERSONALITY)
.build()?
.into(),
ChatCompletionRequestSystemMessageArgs::default()
.content(crate::types::MEDIATE_QUERY)
.content(types::MEDIATE_QUERY)
.build()?
.into(),
ChatCompletionRequestSystemMessageArgs::default()
Expand Down
9 changes: 6 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ pub mod bot;
pub mod gpt;

pub mod types {
use std::{collections::VecDeque, sync::Arc};
use teloxide::types::Message;
use std::{
collections::{HashMap, VecDeque},
sync::Arc,
};
use teloxide::types::{ChatId, Message};
use tokio::sync::RwLock;

pub const PERSONALITY: &str= "Eres un asistente andaluz con jerga informal y algo irónica. Ayudas a todo aquel que te necesite, no sin antes quejarte un poco, ya que eres algo vago.";
pub const MEDIATE_QUERY: &str= "A partir de los siguientes mensajes, analiza una posible discusión y da la razón a alguno de los implicados, con una pequeña argumentación.";
pub const STORE_CAPACITY: usize = 200;

pub type Messages = Arc<RwLock<VecDeque<Message>>>;
pub type Messages = Arc<RwLock<HashMap<ChatId, VecDeque<Message>>>>;
}
6 changes: 2 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::VecDeque, sync::Arc};
use std::{collections::HashMap, sync::Arc};
use telitairos_bot::{bot, types};
use teloxide::prelude::*;
use tokio::sync::RwLock;
Expand All @@ -10,9 +10,7 @@ async fn main() {

let bot = Bot::from_env();

let messages_store: types::Messages = Arc::new(RwLock::new(VecDeque::with_capacity(
crate::types::STORE_CAPACITY,
)));
let messages_store: types::Messages = Arc::new(RwLock::new(HashMap::new()));

let handler = dptree::entry()
.branch(
Expand Down

0 comments on commit e3bca00

Please sign in to comment.