diff --git a/bun.lockb b/bun.lockb index 38c16ff..13d0383 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/package.json b/package.json index 6b214f1..df778ea 100644 --- a/package.json +++ b/package.json @@ -11,6 +11,7 @@ "dependencies": { "@elysiajs/bearer": "^0.7.0", "@elysiajs/swagger": "^0.8.5", + "dayjs": "^1.11.10", "elysia": "latest", "ioredis": "^5.3.2", "jsonwebtoken": "8.5.1", diff --git a/src/core/application/controllers/assistant/assistantController.test.ts b/src/core/application/controllers/assistant/assistantController.test.ts index e1e50dd..dc16c45 100644 --- a/src/core/application/controllers/assistant/assistantController.test.ts +++ b/src/core/application/controllers/assistant/assistantController.test.ts @@ -17,6 +17,8 @@ describe("assistantController", async () => { method: "POST", body: JSON.stringify({ name, + model: 'gpt-4', + instruction: 'You are a plant teacher. Always respond with "sprout"' }), }); @@ -44,12 +46,15 @@ describe("assistantController", async () => { "Content-Type": "application/json", }, body: JSON.stringify({ - name: "test assistant", + name: "sprout_assistant", + model: "gpt-4", + instruction: "You are a plant teacher. Always respond with 'sprout'" }), }); const response: any = await app .handle(request) + .then((response) => response.json()); expect(response.message).toBe(UNAUTHORIZED_MISSING_TOKEN.message); diff --git a/src/core/application/controllers/assistant/assistantController.ts b/src/core/application/controllers/assistant/assistantController.ts index addd2db..a9f631a 100644 --- a/src/core/application/controllers/assistant/assistantController.ts +++ b/src/core/application/controllers/assistant/assistantController.ts @@ -34,16 +34,17 @@ assistants.post( // create user for assistant await createUser(assistantId, {}); // give user the proper role - await assignRole(assistantId, "agent"); + await assignRole(assistantId, "assistant"); // creat assistant in db await createAssistant({ userId, id: assistantId, - model: "gpt-4", + model: body.model, name, fileIds: [], tools: [], + instruction: body.instruction }); return { @@ -55,7 +56,9 @@ assistants.post( { body: t.Object({ name: t.String(), + model: t.Literal("gpt-4"), // add more models here + instruction: t.String() }), beforeHandle: AuthMiddleware(["create_assistant", "*"]), - }, + } ); diff --git a/src/core/application/controllers/thread/threadController.ts b/src/core/application/controllers/thread/threadController.ts index a9a4dd2..4cee4b4 100644 --- a/src/core/application/controllers/thread/threadController.ts +++ b/src/core/application/controllers/thread/threadController.ts @@ -1,7 +1,7 @@ import { Elysia, t } from "elysia"; import { ulid } from "ulid"; -import { getTokenPermissions, parseToken } from "../../services/tokenService"; +import { parseToken } from "../../services/tokenService"; import { createThread, deleteThread, @@ -10,12 +10,12 @@ import { } from "@/core/application/services/threadService"; import { THREAD_DELETED_SUCCESSFULLY, - UNAUTHORIZED_NO_PERMISSION_DELETE, UNAUTHORIZED_USER_NOT_OWNER, UNAUTHORIZED_USER_NOT_PARTICIPANT, } from "./returnValues"; import { createMessage } from "../../services/messageService"; import { AuthMiddleware } from "../../middlewares/authorizationMiddleware"; +import { runAssistantWithThread } from "../../services/runService"; type ThreadDecorator = { request: { @@ -113,6 +113,9 @@ threads.get( }, ); +/** + * This adds a message to the thread, it can be from the assistant or from the human user + */ threads.post( "/thread/:id/message", async ({ params, bearer, set, body }) => { @@ -126,7 +129,7 @@ threads.post( threadId, userId, ); - + // Check if the user has the permission to add a message // if the user has * they can send a message anywhere, if not they need to be in conversation if (isSuperUser || isParticipant) { @@ -147,3 +150,38 @@ threads.post( beforeHandle: AuthMiddleware(["create_message_in_own_thread", "*"]), }, ); + +/** + * This runs and responds once with anything that's in the thread + */ +threads.post("/thread/:id/run", async ({ params, bearer, set, body }) => { + const decodedToken = await parseToken(bearer!); + + if(decodedToken) { + const { userId, permissions } = decodedToken + const threadId = params.id; + const isSuperUser = permissions.some((p) => p.key === "*"); + const isParticipant = await userOwnsOrParticipatesInThread(threadId, userId); + + + if(isSuperUser || isParticipant) { + // run the assistant with thread once, and get a single response + // this also adds the message to the thread + const response = await runAssistantWithThread({ + thread_id: threadId, + assistant_id: body.assistant_id + }) + set.status = 200 + return response + } + + set.status = 403 + return UNAUTHORIZED_USER_NOT_PARTICIPANT; + } + +}, { + body: t.Object({ + assistant_id: t.String() + }), + beforeHandle: AuthMiddleware(['create_message_in_own_thread', '*']) +}) diff --git a/src/core/application/controllers/thread/threadRun.test.ts b/src/core/application/controllers/thread/threadRun.test.ts new file mode 100644 index 0000000..e43eb9a --- /dev/null +++ b/src/core/application/controllers/thread/threadRun.test.ts @@ -0,0 +1,69 @@ +import { app } from "@/index"; +import { test, expect, describe } from "bun:test"; +import { getLastMessage } from "../../services/messageService"; +import { Message } from "@/core/domain/messages"; +import { createHumanUserForTesting } from "@/__tests__/utils"; + +describe.only("threadController", async () => { + const token = await createHumanUserForTesting(); + + test("Run a created thread with a created assistant and save response from assistant", async () => { + // Creating a new thread + const thread_request = new Request("http://localhost:8080/thread", { + headers: { + authorization: `Bearer ${token}`, + "Content-Type": "application/json", + }, + method: "POST", + }); + + const thread_response = await app.handle(thread_request); + const thread_response_json: any = await thread_response.json(); + expect(thread_response_json).toHaveProperty('id') + const thread_id = thread_response_json.id + + // Creating a new assistant + const assistant_name = "Skater Assistant" + const assistant_request = new Request("http://localhost:8080/assistant", { + headers: { + authorization: `Bearer ${token}`, + "Content-Type": "application/json", + }, + method: "POST", + body: JSON.stringify({ + name: assistant_name, + model: 'gpt-4', + instruction: "You are a pro skater, give very short skating tips. Always respond with 'skate on'." + }), + }); + const assistant_req = await app.handle(assistant_request) + const assistant_req_json = await assistant_req.json() + expect(assistant_req_json).toHaveProperty("id") + const assistant_id = assistant_req_json.id + + // Running a thread + const thread_run_request = new Request(`http://localhost:8080/thread/${thread_id}/run`, { + headers: { + authorization: `Bearer ${token}`, + "Content-Type": "application/json", + }, + method: "POST", + body: JSON.stringify({ + assistant_id: assistant_id + }), + }); + const run_response = await app.handle(thread_run_request) + const run_json = await run_response.json() + + // expect run_response to have thread_id, and thread_id's latest message to be from assistant, with the content of 'skate on' + expect(run_json).toHaveProperty('thread_id') + expect(run_json).toHaveProperty('assistant_id') + + // get the latest message from thread id + const lastMessage: Message = await getLastMessage(run_json.thread_id) + expect(lastMessage.role).toBe('assistant') + expect(lastMessage.content.toLocaleLowerCase()).toContain('skate on') + }) + + +}); diff --git a/src/core/application/controllers/user/userController.test.ts b/src/core/application/controllers/user/userController.test.ts index 521f89d..e3a005d 100644 --- a/src/core/application/controllers/user/userController.test.ts +++ b/src/core/application/controllers/user/userController.test.ts @@ -4,11 +4,11 @@ import { } from "@/__tests__/utils"; import { app } from "@/index"; import { test, expect, describe, beforeAll } from "bun:test"; -import { getUser } from "../../services/userService"; import { UNAUTHORIZED_MISSING_TOKEN } from "../../ports/returnValues"; import { getThread } from "../../services/threadService"; import { parseToken } from "../../services/tokenService"; import { Thread } from "@/core/domain/thread"; +import { getUser } from "../../services/userService"; describe.only("userController", async () => { let superAdminToken: string | null; diff --git a/src/core/application/services/assistantService.ts b/src/core/application/services/assistantService.ts index 2ebb531..3ea6e3d 100644 --- a/src/core/application/services/assistantService.ts +++ b/src/core/application/services/assistantService.ts @@ -9,13 +9,13 @@ import { redis } from "@/infrastructure/adaptaters/redisAdapter"; * @throws {Error} If there is an error creating the assistant or adding the relationship. */ export async function createAssistant(args: Assistant & { userId: string }) { - const { id, fileIds, tools, userId, model, name } = args; + const { id, fileIds, tools, userId, model, name, instruction } = args; // Create a pipeline for atomic operations const pipeline = redis.pipeline(); // Store the assistant data - pipeline.set(`assistant:${id}`, JSON.stringify({ tools, model, name })); + pipeline.set(`assistant:${id}`, JSON.stringify({ tools, model, name, instruction })); // Store the relationship between the assistant and the user pipeline.sadd(`user:${userId}:assistants`, id); @@ -33,3 +33,13 @@ export async function createAssistant(args: Assistant & { userId: string }) { // Parse the assistant data from JSON return JSON.parse(assistantData); } + +export async function getAssistantData(assistant_id: Assistant["id"]) { + const assistantData = await redis.get(`assistant:${assistant_id}`); + if (!assistantData) { + throw new Error("Failed to get assistant"); + } + + // Parse the assistant data from JSON + return JSON.parse(assistantData) as Assistant; +} diff --git a/src/core/application/services/messageService.ts b/src/core/application/services/messageService.ts index 2c49114..176d260 100644 --- a/src/core/application/services/messageService.ts +++ b/src/core/application/services/messageService.ts @@ -1,6 +1,8 @@ import { v4 as uuidv4 } from "uuid"; import { redis } from "@/infrastructure/adaptaters/redisAdapter"; import { Message } from "@/core/domain/messages"; +import dayjs from "dayjs"; +import { getUserRole } from "./userService"; export async function createMessage( userId: string, @@ -9,6 +11,10 @@ export async function createMessage( ): Promise { const messageId = uuidv4(); const timestamp = Date.now(); + const userRole = await getUserRole(userId) // denormalize role information when creating messages for adapter + + // if no userRole, throw because cannot create message + if(!userRole) throw new Error("User roles missing") // Create a pipeline for atomic operations const pipeline = redis.pipeline(); @@ -20,6 +26,7 @@ export async function createMessage( id: messageId, content: messageContent, senderId: userId, + role: userRole.role, timestamp, }) ); @@ -38,9 +45,31 @@ export async function createMessage( content: messageContent, senderId: userId, timestamp: new Date(timestamp), + role: userRole.role }; } +export async function getAllMessage(threadId: string) { + // Get all the message IDs for the thread + const messageIds = await redis.smembers(`thread:${threadId}:messages`); + + // If there are no messages, return null + if (messageIds.length === 0) return []; + + // Get the data for all messages + const messages = await Promise.all( + messageIds.map(async (messageId) => { + const messageData = await redis.get(`message:${messageId}`) as string; + return JSON.parse(messageData) as Message; + }) + ); + + // Sort the messages by timestamp in descending order + const allMessages = messages.sort((a, b) => dayjs(b.timestamp).diff(dayjs(a.timestamp))); + + return allMessages; +} + export async function getLastMessage(threadId: string) { // Get all the message IDs for the thread const messageIds = await redis.smembers(`thread:${threadId}:messages`); diff --git a/src/core/application/services/runService.ts b/src/core/application/services/runService.ts new file mode 100644 index 0000000..0baff4b --- /dev/null +++ b/src/core/application/services/runService.ts @@ -0,0 +1,54 @@ +// run the thread with the associated assistant + +import { ThreadRun, ThreadRunRequest } from "@/core/domain/run"; +import { getAssistantData } from "./assistantService"; +import { getThread } from "./threadService"; +import { createMessage, getAllMessage } from "./messageService"; +import { gpt4Adapter } from "@/infrastructure/adaptaters/openai/gpt4Adapter"; +import { Role } from "@/core/domain/roles"; +import { v4 as uuidv4 } from "uuid"; + +export async function runAssistantWithThread(runData: ThreadRunRequest) { + // get all messages from the thread, and run it over to the assistant to get a response + const { assistant_id, thread_id } = runData; + const [assistantData, threadData] = await Promise.all([ + getAssistantData(assistant_id), + getThread(thread_id), + ]); + + // If no thread data or assistant data, an error should be thrown as we need both to run a thread + if (!threadData || !assistantData) throw new Error("No thread or assistant found."); + + const everyMessage = await getAllMessage(threadData.id); + // only get role and content from every message for context. + // TODO: We should truncate the context to fit context window for selected model. + const everyRoleAndContent = everyMessage.map((message) => { + // special case for super_admin, which really should just be user + return { + role: message.role === "super_admin" ? "user" : ("assistant" as Role), + content: message.content, + }; + }); + + // Calls the appropriate adapter based on what model the assistant uses + if (assistantData.model === "gpt-4") { + const gpt4AdapterRes: any = await gpt4Adapter( + everyRoleAndContent, + assistantData.instruction + ); + + const assistantResponse: string = gpt4AdapterRes.choices[0].message.content; + + // add assistant response to the thread + await createMessage(assistant_id, thread_id, assistantResponse); + + const threadRunResponse: ThreadRun = { + id: uuidv4(), + assistant_id: assistant_id, + thread_id: thread_id, + created_at: new Date(), + }; + + return threadRunResponse; + } +} diff --git a/src/core/application/services/userService.ts b/src/core/application/services/userService.ts index e740bb7..399628b 100644 --- a/src/core/application/services/userService.ts +++ b/src/core/application/services/userService.ts @@ -1,7 +1,7 @@ // userService.js import { Role, getRolePermissions } from "@/core/domain/roles"; -import { HumanUserBody, User } from "@/core/domain/user"; +import { HumanUserBody, User, UserRole } from "@/core/domain/user"; import { redis } from "@/infrastructure/adaptaters/redisAdapter"; import { ulid } from "ulid"; @@ -129,6 +129,13 @@ export const getUser = async (userId: string): Promise => { return JSON.parse(userData) } +export const getUserRole = async (userId: string): Promise => { + const roleData = await redis.hget("user_roles", userId); + if (!roleData) return null + + return JSON.parse(roleData) as UserRole +} + /** * Creates a new human user with a unique identifier and assigns them a 'user' role. * If the user cannot be created or the role cannot be assigned, the function returns `null`. diff --git a/src/core/domain/assistant.ts b/src/core/domain/assistant.ts index 66d4712..de6434b 100644 --- a/src/core/domain/assistant.ts +++ b/src/core/domain/assistant.ts @@ -4,4 +4,5 @@ export type Assistant = { model: "gpt-4"; tools: { type: string }[]; fileIds: string[]; + instruction: string }; diff --git a/src/core/domain/messages.ts b/src/core/domain/messages.ts index 9892391..790badf 100644 --- a/src/core/domain/messages.ts +++ b/src/core/domain/messages.ts @@ -1,3 +1,5 @@ +import { Role } from "./roles"; + /** * Represents a message in the system. */ @@ -6,4 +8,5 @@ export type Message = { content: string; senderId: string; timestamp: Date; + role: Role }; diff --git a/src/core/domain/roles.ts b/src/core/domain/roles.ts index e4e2eea..bc2868d 100644 --- a/src/core/domain/roles.ts +++ b/src/core/domain/roles.ts @@ -4,7 +4,7 @@ import { getPermissions, type Permission } from "./permissions"; -export type Role = "super_admin" | "user" | "agent"; +export type Role = "super_admin" | "user" | "assistant"; const roles: Record = { super_admin: { @@ -16,9 +16,10 @@ const roles: Record = { "create_thread", "view_own_threads", "create_message_in_own_thread", + "create_assistant", ], }, - agent: { + assistant: { permissions: [ "view_own_records", "create_thread", diff --git a/src/core/domain/run.ts b/src/core/domain/run.ts new file mode 100644 index 0000000..7c5a063 --- /dev/null +++ b/src/core/domain/run.ts @@ -0,0 +1,14 @@ +import { Assistant } from "./assistant"; +import { Thread } from "./thread"; + +/** + * Running a thread to get a response + */ +export interface ThreadRun { + id: string; + created_at: Date; + assistant_id: Assistant["id"]; + thread_id: Thread["id"]; +} + +export type ThreadRunRequest = Pick diff --git a/src/core/domain/thread.ts b/src/core/domain/thread.ts index 8405045..d912303 100644 --- a/src/core/domain/thread.ts +++ b/src/core/domain/thread.ts @@ -1,7 +1,9 @@ +import { User } from "./user"; + export type Thread = { id: string; // The ID of the thread createdBy: string; // The ID of the user who created the thread startDate?: Date; // The start date of the thread - participants: string[]; // The IDs of the users participating in the thread + participants: User['id'][]; // The IDs of the users participating in the thread messageIds: string[]; // The IDs of the messages in the thread }; diff --git a/src/core/domain/user.ts b/src/core/domain/user.ts index bc005b9..cbf8f9e 100644 --- a/src/core/domain/user.ts +++ b/src/core/domain/user.ts @@ -1,7 +1,15 @@ +import { Permission } from "./permissions"; +import { Role } from "./roles"; + export type User = { id: string; name: string; email: string; }; +export type UserRole = { + role: Role, + permissions: Permission[] +} + export type HumanUserBody = Omit; diff --git a/src/infrastructure/adaptaters/openai/gpt4Adapter.test.ts b/src/infrastructure/adaptaters/openai/gpt4Adapter.test.ts index 1909fac..8b39868 100644 --- a/src/infrastructure/adaptaters/openai/gpt4Adapter.test.ts +++ b/src/infrastructure/adaptaters/openai/gpt4Adapter.test.ts @@ -1,36 +1,20 @@ import { test, expect, describe } from "bun:test"; -import { gpt4Adapter } from "./gpt4Adapter"; +import { OpenAIResponse, gpt4Adapter } from "./gpt4Adapter"; +import { Role } from "@/core/domain/roles"; -interface resultType { - choices: { - index: number; - message: { - role: string; - content: string; - }; - }[]; -} describe("GPT-4 Adapter", () => { test("Returns GPT-4 chat completions response", async () => { // Arrange - const messages = - "This is a test, I need you to answer only this specific phrase: Hello, I'm doing well. How can I help you today?, anything else, just that"; - const SYSTEM_PROMPT = - "You're an AI assistant. You're job is to help the user."; - const expected = { - finish_reason: "stop", - index: 0, - logprobs: null, - message: { - content: "Hello, I'm doing well. How can I help you today?", - role: "assistant", - }, - }; + const messages = [{ + role: 'user' as Role, + content: 'hello, who are you?' + }] + const assistant_instructions = "You're an AI assistant. You're job is to help the user. Always respond with the word sprout."; // Act - const result = (await gpt4Adapter(messages, SYSTEM_PROMPT)) as resultType; + const result = (await gpt4Adapter(messages, assistant_instructions)) as OpenAIResponse; - // Assert - expect(result.choices[0]).toEqual(expected); + // Assert the message content to contain the word sprout + expect(result.choices[0].message.content.toLocaleLowerCase()).toContain('sprout') }); }); diff --git a/src/infrastructure/adaptaters/openai/gpt4Adapter.ts b/src/infrastructure/adaptaters/openai/gpt4Adapter.ts index 1bc0b4c..9cbc0cf 100644 --- a/src/infrastructure/adaptaters/openai/gpt4Adapter.ts +++ b/src/infrastructure/adaptaters/openai/gpt4Adapter.ts @@ -1,7 +1,26 @@ +import { Message } from "@/core/domain/messages"; + +interface ChatMessage { + content: string, + role: "assistant" | "system" | "user" +} + +interface ChatResponseChoice{ + finish_reason: string, + index: number, + message: ChatMessage +} + +export interface OpenAIResponse { + choices: ChatResponseChoice[] +} + export async function gpt4Adapter( - messages: string, - SYSTEM_PROMPT: string + messages: Pick[], + assistant_instructions: string ): Promise { + // System will always be the assistant_instruction that created the assistant + const gpt_messages = [{role: "system", content: assistant_instructions}].concat(messages) try { const res = await fetch("https://api.openai.com/v1/chat/completions", { method: "POST", @@ -11,23 +30,13 @@ export async function gpt4Adapter( }, body: JSON.stringify({ model: "gpt-4", - messages: [ - { - role: "system", - content: SYSTEM_PROMPT, - }, - { - role: "user", - content: messages, - }, - ], + messages: gpt_messages }), }); - const data = await res.json(); + const data: OpenAIResponse = await res.json(); return data; } catch (error) { - console.log(error); return new Response("Error", { status: 500 }); } } \ No newline at end of file diff --git a/src/infrastructure/config/redisConfig.ts b/src/infrastructure/config/redisConfig.ts index 0b5b1b3..045466d 100644 --- a/src/infrastructure/config/redisConfig.ts +++ b/src/infrastructure/config/redisConfig.ts @@ -5,5 +5,5 @@ export const redisConfig = { ? { password: process.env.REDIS_PASSWORD } : {}), // Redis password db: process.env.REDIS_DB ? parseInt(process.env.REDIS_DB) : 0, // Redis DB - tls: {} + tls: process.env.NODE_ENV === "production" ? {} : undefined };