From ee97b7ebc5a09e8b26ea3558b7bf80b7b27f84c0 Mon Sep 17 00:00:00 2001 From: AyushAgrawal-A2 Date: Mon, 13 Jan 2025 10:51:34 +0530 Subject: [PATCH] chat logging tests --- .../src/routes/v0/ai.chat.POST.test.ts | 174 +++++++++++++++--- ...POST.test.ts => ai.feedback.PATCH.test.ts} | 21 +-- ....feedback.POST.ts => ai.feedback.PATCH.ts} | 4 +- ...nal.user.$auth0Id.connections.$uuid.GET.ts | 2 +- .../src/routes/v0/teams.$uuid.PATCH.ts | 8 +- .../src/routes/v0/teams.$uuid.invites.POST.ts | 7 +- .../v0/teams.$uuid.users.$userId.PATCH.ts | 7 +- quadratic-client/src/shared/api/apiClient.ts | 6 +- quadratic-shared/typesAndSchemas.ts | 4 +- 9 files changed, 183 insertions(+), 50 deletions(-) rename quadratic-api/src/routes/v0/{ai.feedback.POST.test.ts => ai.feedback.PATCH.test.ts} (91%) rename quadratic-api/src/routes/v0/{ai.feedback.POST.ts => ai.feedback.PATCH.ts} (92%) diff --git a/quadratic-api/src/routes/v0/ai.chat.POST.test.ts b/quadratic-api/src/routes/v0/ai.chat.POST.test.ts index f74098a0ea..44e9dca878 100644 --- a/quadratic-api/src/routes/v0/ai.chat.POST.test.ts +++ b/quadratic-api/src/routes/v0/ai.chat.POST.test.ts @@ -1,7 +1,11 @@ import request from 'supertest'; import { app } from '../../app'; +import dbClient from '../../dbClient'; +import { getFile } from '../../middleware/getFile'; import { clearDb, createFile, createTeam, createUser } from '../../tests/testDataGenerator'; +const auth0Id = 'user'; + const payload = { chatId: '00000000-0000-0000-0000-000000000000', fileUuid: '11111111-1111-1111-1111-111111111111', @@ -16,21 +20,6 @@ const payload = { useQuadraticContext: false, }; -beforeAll(async () => { - const user = await createUser({ auth0Id: 'user' }); - const team = await createTeam({ users: [{ userId: user.id, role: 'OWNER' }] }); - await createFile({ - data: { - uuid: payload.fileUuid, - name: 'Untitled', - ownerTeamId: team.id, - creatorUserId: user.id, - }, - }); -}); - -afterAll(clearDb); - jest.mock('@anthropic-ai/bedrock-sdk', () => ({ AnthropicBedrock: jest.fn().mockImplementation(() => ({ messages: { @@ -53,18 +42,93 @@ jest.mock('@anthropic-ai/bedrock-sdk', () => ({ })), })); +beforeAll(async () => { + const user = await createUser({ auth0Id }); + const team = await createTeam({ users: [{ userId: user.id, role: 'OWNER' }] }); + await createFile({ + data: { + uuid: payload.fileUuid, + name: 'Untitled', + ownerTeamId: team.id, + creatorUserId: user.id, + }, + }); +}); + +afterAll(clearDb); + describe('POST /v0/ai/chat', () => { - describe('an unauthorized user', () => { - it('responds with a 401', async () => { - await request(app).post('/v0/ai/chat').send(payload).set('Authorization', `Bearer InvalidToken user`).expect(401); + describe('authentication', () => { + it('responds with a 401 when the token is invalid', async () => { + await request(app) + .post('/v0/ai/chat') + .send({ ...payload, chatId: '00000000-0000-0000-0000-000000000001' }) + .set('Authorization', `Bearer InvalidToken user`) + .expect(401); + }); + + it('responds with model response when the token is valid', async () => { + await request(app) + .post('/v0/ai/chat') + .send({ ...payload, chatId: '00000000-0000-0000-0000-000000000002' }) + .set('Authorization', `Bearer ValidToken user`) + .expect(200) + .expect(({ body }) => { + expect(body).toEqual({ + role: 'assistant', + content: 'This is a mocked response from Claude', + contextType: 'userPrompt', + toolCalls: [ + { + id: 'tool_123', + name: 'example_tool', + arguments: JSON.stringify({ param1: 'value1' }), + loading: false, + }, + ], + model: payload.model, + }); + }); + + // wait for the chat to be saved + await new Promise((resolve) => setTimeout(resolve, 250)); }); }); - describe('an authorized user', () => { - it('responds with a 200', async () => { + describe('Analytics AI Chat', () => { + beforeEach(async () => { + await dbClient.$transaction([ + dbClient.analyticsAIChatMessage.deleteMany(), + dbClient.analyticsAIChat.deleteMany(), + ]); + }); + + it('saves the chat in storage when analyticsAi is enabled', async () => { + const analyticsAIChatsBefore = await dbClient.analyticsAIChat.findMany(); + expect(analyticsAIChatsBefore.length).toBe(0); + + const user = await dbClient.user.findUnique({ + where: { + auth0Id, + }, + }); + expect(user).not.toBeNull(); + if (!user) { + throw new Error('User not found'); + } + + const { + file: { ownerTeam }, + } = await getFile({ uuid: payload.fileUuid, userId: user.id }); + expect(ownerTeam).not.toBeNull(); + if (!ownerTeam) { + throw new Error('Owner team not found'); + } + expect(ownerTeam.settingAnalyticsAi).toBe(true); + await request(app) .post('/v0/ai/chat') - .send(payload) + .send({ ...payload, chatId: '00000000-0000-0000-0000-000000000003' }) .set('Authorization', `Bearer ValidToken user`) .expect(200) .expect(({ body }) => { @@ -83,6 +147,74 @@ describe('POST /v0/ai/chat', () => { model: payload.model, }); }); + + // wait for the chat to be saved + await new Promise((resolve) => setTimeout(resolve, 250)); + + const analyticsAIChatsAfter = await dbClient.analyticsAIChat.findMany(); + expect(analyticsAIChatsAfter.length).toBe(1); + }); + + it('does not save the chat in storage when analyticsAi is disabled', async () => { + const analyticsAIChatsBefore = await dbClient.analyticsAIChat.findMany(); + expect(analyticsAIChatsBefore.length).toBe(0); + + const user = await dbClient.user.findUnique({ + where: { + auth0Id, + }, + }); + expect(user).not.toBeNull(); + if (!user) { + throw new Error('User not found'); + } + + const { + file: { ownerTeam }, + } = await getFile({ uuid: payload.fileUuid, userId: user.id }); + expect(ownerTeam).not.toBeNull(); + if (!ownerTeam) { + throw new Error('Owner team not found'); + } + expect(ownerTeam.settingAnalyticsAi).toBe(true); + + await request(app) + .patch(`/v0/teams/${ownerTeam.uuid}`) + .set('Authorization', `Bearer ValidToken user`) + .send({ settings: { analyticsAi: false } }) + + .expect(200) + .expect((res) => { + expect(res.body.settings.analyticsAi).toBe(false); + }); + + await request(app) + .post('/v0/ai/chat') + .set('Authorization', `Bearer ValidToken user`) + .send({ ...payload, chatId: '00000000-0000-0000-0000-000000000004' }) + .expect(200) + .expect(({ body }) => { + expect(body).toEqual({ + role: 'assistant', + content: 'This is a mocked response from Claude', + contextType: 'userPrompt', + toolCalls: [ + { + id: 'tool_123', + name: 'example_tool', + arguments: JSON.stringify({ param1: 'value1' }), + loading: false, + }, + ], + model: payload.model, + }); + }); + + // wait for the chat to be saved + await new Promise((resolve) => setTimeout(resolve, 250)); + + const analyticsAIChatsAfter = await dbClient.analyticsAIChat.findMany(); + expect(analyticsAIChatsAfter.length).toBe(0); }); }); }); diff --git a/quadratic-api/src/routes/v0/ai.feedback.POST.test.ts b/quadratic-api/src/routes/v0/ai.feedback.PATCH.test.ts similarity index 91% rename from quadratic-api/src/routes/v0/ai.feedback.POST.test.ts rename to quadratic-api/src/routes/v0/ai.feedback.PATCH.test.ts index d5c6829524..76c568fec4 100644 --- a/quadratic-api/src/routes/v0/ai.feedback.POST.test.ts +++ b/quadratic-api/src/routes/v0/ai.feedback.PATCH.test.ts @@ -23,15 +23,15 @@ beforeAll(async () => { afterAll(clearDb); describe('POST /v0/ai/feedback', () => { - describe('an unauthorized user', () => { - it('responds with a 401', async () => { + describe('authentication', () => { + it('responds with a 401 when the token is invalid', async () => { await request(app) - .post('/v0/ai/feedback') + .patch('/v0/ai/feedback') + .set('Authorization', `Bearer InvalidToken user`) .send({ ...payload, like: true, }) - .set('Authorization', `Bearer InvalidToken user`) .expect(401); }); }); @@ -51,13 +51,12 @@ describe('POST /v0/ai/feedback', () => { // create a like await request(app) - .post('/v0/ai/feedback') + .patch('/v0/ai/feedback') + .set('Authorization', `Bearer ValidToken user`) .send({ ...payload, like: true, }) - .set('Accept', 'application/json') - .set('Authorization', `Bearer ValidToken user`) .expect(200) .expect(({ body }) => { expect(body.message).toBe('Feedback received'); @@ -75,12 +74,12 @@ describe('POST /v0/ai/feedback', () => { // set to dislike await request(app) - .post('/v0/ai/feedback') + .patch('/v0/ai/feedback') + .set('Authorization', `Bearer ValidToken user`) .send({ ...payload, like: false, }) - .set('Authorization', `Bearer ValidToken user`) .expect(200) .expect(({ body }) => { expect(body.message).toBe('Feedback received'); @@ -98,12 +97,12 @@ describe('POST /v0/ai/feedback', () => { // unset like await request(app) - .post('/v0/ai/feedback') + .patch('/v0/ai/feedback') + .set('Authorization', `Bearer ValidToken user`) .send({ ...payload, like: null, }) - .set('Authorization', `Bearer ValidToken user`) .expect(200) .expect(({ body }) => { expect(body.message).toBe('Feedback received'); diff --git a/quadratic-api/src/routes/v0/ai.feedback.POST.ts b/quadratic-api/src/routes/v0/ai.feedback.PATCH.ts similarity index 92% rename from quadratic-api/src/routes/v0/ai.feedback.POST.ts rename to quadratic-api/src/routes/v0/ai.feedback.PATCH.ts index 76be448de4..7e0db7c55f 100644 --- a/quadratic-api/src/routes/v0/ai.feedback.POST.ts +++ b/quadratic-api/src/routes/v0/ai.feedback.PATCH.ts @@ -10,10 +10,10 @@ import type { RequestWithUser } from '../../types/Request'; export default [validateAccessToken, handler]; const schema = z.object({ - body: ApiSchemas['/v0/ai/feedback.POST.request'], + body: ApiSchemas['/v0/ai/feedback.PATCH.request'], }); -async function handler(req: RequestWithUser, res: Response) { +async function handler(req: RequestWithUser, res: Response) { const { body: { chatId, messageIndex, like }, } = parseRequest(req, schema); diff --git a/quadratic-api/src/routes/v0/internal.user.$auth0Id.connections.$uuid.GET.ts b/quadratic-api/src/routes/v0/internal.user.$auth0Id.connections.$uuid.GET.ts index 755e891014..dd32fa4d0b 100644 --- a/quadratic-api/src/routes/v0/internal.user.$auth0Id.connections.$uuid.GET.ts +++ b/quadratic-api/src/routes/v0/internal.user.$auth0Id.connections.$uuid.GET.ts @@ -1,4 +1,4 @@ -import { Request, Response } from 'express'; +import type { Request, Response } from 'express'; import z from 'zod'; import dbClient from '../../dbClient'; import { validateM2MAuth } from '../../internal/validateM2MAuth'; diff --git a/quadratic-api/src/routes/v0/teams.$uuid.PATCH.ts b/quadratic-api/src/routes/v0/teams.$uuid.PATCH.ts index 6d7aee1e09..6c9c142920 100644 --- a/quadratic-api/src/routes/v0/teams.$uuid.PATCH.ts +++ b/quadratic-api/src/routes/v0/teams.$uuid.PATCH.ts @@ -29,7 +29,7 @@ async function handler(req: RequestWithUser, res: Response