From 99e6a3d759582cf7df1d19e4c41b99ef15120604 Mon Sep 17 00:00:00 2001 From: David Kircos Date: Mon, 10 Feb 2025 20:03:34 -0700 Subject: [PATCH 01/13] begin billing section build out --- ...eams.$uuid.billing.checkout.session.GET.ts | 29 ++- quadratic-api/src/stripe/stripe.ts | 2 +- .../src/routes/teams.$teamUuid.settings.tsx | 213 +++++++++++++----- 3 files changed, 189 insertions(+), 55 deletions(-) diff --git a/quadratic-api/src/routes/v0/teams.$uuid.billing.checkout.session.GET.ts b/quadratic-api/src/routes/v0/teams.$uuid.billing.checkout.session.GET.ts index 03f69502da..e7083c61d1 100644 --- a/quadratic-api/src/routes/v0/teams.$uuid.billing.checkout.session.GET.ts +++ b/quadratic-api/src/routes/v0/teams.$uuid.billing.checkout.session.GET.ts @@ -1,11 +1,14 @@ +import { SubscriptionStatus } from '@prisma/client'; import type { Request, Response } from 'express'; import type { ApiTypes } from 'quadratic-shared/typesAndSchemas'; import z from 'zod'; +import { getUsersFromAuth0 } from '../../auth/auth0'; +import dbClient from '../../dbClient'; import { getTeam } from '../../middleware/getTeam'; import { userMiddleware } from '../../middleware/user'; import { validateAccessToken } from '../../middleware/validateAccessToken'; import { validateRequestSchema } from '../../middleware/validateRequestSchema'; -import { createCheckoutSession, getMonthlyPriceId } from '../../stripe/stripe'; +import { createCheckoutSession, createCustomer, getMonthlyPriceId } from '../../stripe/stripe'; import type { RequestWithUser } from '../../types/Request'; export default [ @@ -24,9 +27,9 @@ export default [ async function handler(req: Request, res: Response) { const { params: { uuid }, - user: { id: userId }, + user: { id: userId, auth0Id }, } = req as RequestWithUser; - const { userMakingRequest } = await getTeam({ uuid, userId }); + const { userMakingRequest, team } = await getTeam({ uuid, userId }); // Can the user even edit this team? if (!userMakingRequest.permissions.includes('TEAM_MANAGE')) { @@ -35,6 +38,26 @@ async function handler(req: Request, res: Response) { .json({ error: { message: 'User does not have permission to access billing for this team.' } }); } + if (team?.stripeSubscriptionStatus === SubscriptionStatus.ACTIVE) { + return res.status(400).json({ error: { message: 'Team already has an active subscription.' 
} }); + } + + // create a stripe customer if one doesn't exist + if (!team?.stripeCustomerId) { + // Get user email from Auth0 + const auth0Record = await getUsersFromAuth0([{ id: userId, auth0Id }]); + const auth0User = auth0Record[userId]; + + // create Stripe customer + const stripeCustomer = await createCustomer(team.name, auth0User.email); + await dbClient.team.update({ + where: { uuid }, + data: { stripeCustomerId: stripeCustomer.id }, + }); + + team.stripeCustomerId = stripeCustomer.id; + } + const monthlyPriceId = await getMonthlyPriceId(); const session = await createCheckoutSession(uuid, monthlyPriceId, req.headers.origin || 'http://localhost:3000'); diff --git a/quadratic-api/src/stripe/stripe.ts b/quadratic-api/src/stripe/stripe.ts index 7bbead7c72..bdea7d73cf 100644 --- a/quadratic-api/src/stripe/stripe.ts +++ b/quadratic-api/src/stripe/stripe.ts @@ -112,7 +112,7 @@ export const getMonthlyPriceId = async () => { active: true, }); - const data = prices.data.filter((price) => price.lookup_key === 'team_monthly'); + const data = prices.data.filter((price) => price.lookup_key === 'team_monthly_ai'); if (data.length === 0) { throw new Error('No monthly price found'); } diff --git a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx index de832fabda..c20f82b280 100644 --- a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx +++ b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx @@ -2,6 +2,7 @@ import { DashboardHeader } from '@/dashboard/components/DashboardHeader'; import { SettingControl } from '@/dashboard/components/SettingControl'; import { useDashboardRouteLoaderData } from '@/routes/_dashboard'; import { getActionUpdateTeam, type TeamAction } from '@/routes/teams.$teamUuid'; +import { apiClient } from '@/shared/api/apiClient'; import { useGlobalSnackbar } from '@/shared/components/GlobalSnackbarProvider'; import { CheckIcon } from '@/shared/components/Icons'; import { Type } from '@/shared/components/Type'; @@ -21,6 +22,7 @@ export const Component = () => { activeTeam: { team, userMakingRequest: { teamPermissions }, + billing, }, } = useDashboardRouteLoaderData(); @@ -68,21 +70,6 @@ export const Component = () => { }); }; - // One day, when we have billing, we can add something akin to this - // - // {teamPermissions.includes('TEAM_MANAGE') && ( - // { - // // Get the billing session URL - // apiClient.teams.billing.getPortalSessionUrl(team.uuid).then((data) => { - // window.location.href = data.url; - // }); - // }} - // > - // Update billing - // - // )} - // If for some reason it failed, display an error useEffect(() => { if (fetcher.data && fetcher.data.ok === false) { @@ -90,7 +77,7 @@ export const Component = () => { } }, [fetcher.data, addGlobalSnackbar]); - // If you don’t have permission, you can't see this view + // If you don't have permission, you can't see this view if (!teamPermissions.includes('TEAM_EDIT')) { return ; } @@ -112,41 +99,165 @@ export const Component = () => { {teamPermissions.includes('TEAM_MANAGE') && ( - - - Privacy - - -
- - Help improve AI results by allowing Quadratic to store and analyze user prompts.{' '} - - Learn more - - . - - } - onCheckedChange={(checked) => { - handleUpdatePreference('analyticsAi', checked); - }} - checked={optimisticSettings.analyticsAi} - className="rounded border border-border px-3 py-2 shadow-sm" - /> -

- When using AI features your data is sent to our AI providers: -

-
    - {['OpenAI', 'Anthropic', 'AWS Bedrock'].map((item, i) => ( -
- {item}: zero-day data retention
- ))}
-
-
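For reference, the placeholder comment deleted earlier in this file's diff already sketched the billing-portal hookup that this commit starts to build. A minimal standalone version of that flow, using the `apiClient.teams.billing.getPortalSessionUrl` call exactly as the deleted comment spelled it (the wrapping handler is an assumption):

```ts
import { apiClient } from '@/shared/api/apiClient';

// Sketch: send the user to the Stripe billing portal for this team.
// getPortalSessionUrl and the { url } response shape come straight from the
// deleted placeholder comment; the handler itself is illustrative.
const handleUpdateBilling = (teamUuid: string) => {
  apiClient.teams.billing.getPortalSessionUrl(teamUuid).then((data) => {
    window.location.href = data.url;
  });
};
```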
+ <> + + + Billing + +
+ {/* Plan Comparison */} +
+ {/* Free Plan */} +
+

Free Plan

+
+
+ AI Messages Monthly + 50 +
+
+ Connection Runs Monthly + +
+
+ Team Members + +
+
+ Files + +
+
+ {billing.status === undefined && ( + + )} +
+ + {/* Team AI Plan */} +
+

Team AI Plan

+
+
+ AI Messages Month / User + +
+
+ Connection Runs Monthly + +
+
+ Team Members + +
+
+ Files + +
+
+ {billing.status === undefined && ( + + )} +
+
+ + {/* Current Usage */} +
+

Current Usage

+
+
+ AI Messages Month / User +
+ 0 + + / {billing.status === undefined ? '50' : '∞'} + +
+
+
+ Connection Runs Monthly +
+ 0 + / ∞ +
+
+
+ Team Members +
+ 0 + / ∞ +
+
+
+ Files +
+ 0 + / ∞ +
+
+
+
+ + {billing.status !== undefined && ( + + )} +
+
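The Upgrade buttons in the section above render only while `billing.status === undefined`, i.e. while the team has no subscription on record, and the manage-subscription control appears only once one exists. The upgrade click handler presumably calls the new `teams.$uuid.billing.checkout.session.GET` endpoint added in this commit and redirects to the returned Stripe Checkout URL; a sketch, where the client method name and the `{ url }` response shape are assumptions:

```ts
// Sketch: kick off Stripe Checkout for the team upgrade. Assumes the API
// client exposes the new checkout-session endpoint as getCheckoutSessionUrl
// and that it resolves to { url }, mirroring the billing-portal call above.
const handleUpgrade = async (teamUuid: string) => {
  const { url } = await apiClient.teams.billing.getCheckoutSessionUrl(teamUuid);
  window.location.href = url; // Stripe-hosted checkout page
};
```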
+ + + Privacy + + +
+ + Help improve AI results by allowing Quadratic to store and analyze user prompts.{' '} + + Learn more + + . + + } + onCheckedChange={(checked) => { + handleUpdatePreference('analyticsAi', checked); + }} + checked={optimisticSettings.analyticsAi} + className="rounded border border-border px-3 py-2 shadow-sm" + /> +

+ When using AI features, your data is sent to our AI providers: +

+
    + {['OpenAI', 'Anthropic', 'AWS Bedrock'].map((item, i) => ( +
+ {item}: zero-day data retention
+ ))} +
+
+
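The restored Privacy control calls `handleUpdatePreference('analyticsAi', checked)`, defined earlier in this component. A condensed sketch of that optimistic-update pattern as it reads here, where `setOptimisticSettings` and the payload produced by `getActionUpdateTeam` are assumptions about code outside this hunk:

```ts
// Sketch: flip the switch locally first, then persist via the route action;
// the effect above surfaces a snackbar if the action returns { ok: false }.
const handleUpdatePreference = (key: 'analyticsAi', checked: boolean) => {
  setOptimisticSettings((prev) => ({ ...prev, [key]: checked })); // assumed local state
  fetcher.submit(getActionUpdateTeam({ settings: { [key]: checked } }), {
    method: 'POST',
    encType: 'application/json', // assumed; matches the fetcher-based updates in this file
  });
};
```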
+ )} From 0fab67097087514c537b80da0475f033cc09c0ca Mon Sep 17 00:00:00 2001 From: David Kircos Date: Mon, 10 Feb 2025 20:08:11 -0700 Subject: [PATCH 02/13] on team rename, update Stripe --- quadratic-api/src/routes/v0/teams.$uuid.PATCH.ts | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/quadratic-api/src/routes/v0/teams.$uuid.PATCH.ts b/quadratic-api/src/routes/v0/teams.$uuid.PATCH.ts index 6c9c142920..9bbd941d3a 100644 --- a/quadratic-api/src/routes/v0/teams.$uuid.PATCH.ts +++ b/quadratic-api/src/routes/v0/teams.$uuid.PATCH.ts @@ -7,6 +7,7 @@ import { getTeam } from '../../middleware/getTeam'; import { userMiddleware } from '../../middleware/user'; import { validateAccessToken } from '../../middleware/validateAccessToken'; import { parseRequest } from '../../middleware/validateRequestSchema'; +import { updateCustomer } from '../../stripe/stripe'; import type { RequestWithUser } from '../../types/Request'; import { ApiError } from '../../utils/ApiError'; @@ -29,7 +30,7 @@ async function handler(req: RequestWithUser, res: Response Date: Mon, 10 Feb 2025 20:41:19 -0700 Subject: [PATCH 03/13] always save an AnalyticsAIChatMessage, don't always save message content --- .../migration.sql | 2 ++ quadratic-api/prisma/schema.prisma | 2 +- quadratic-api/src/routes/v0/ai.chat.POST.ts | 16 +++++++++++----- .../src/routes/teams.$teamUuid.settings.tsx | 12 ++++++------ 4 files changed, 20 insertions(+), 12 deletions(-) create mode 100644 quadratic-api/prisma/migrations/20250211034001_optional_analytics/migration.sql diff --git a/quadratic-api/prisma/migrations/20250211034001_optional_analytics/migration.sql b/quadratic-api/prisma/migrations/20250211034001_optional_analytics/migration.sql new file mode 100644 index 0000000000..b925ed76b2 --- /dev/null +++ b/quadratic-api/prisma/migrations/20250211034001_optional_analytics/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "AnalyticsAIChatMessage" ALTER COLUMN "s3_key" DROP NOT NULL; diff --git a/quadratic-api/prisma/schema.prisma b/quadratic-api/prisma/schema.prisma index 901729173a..80c79daa15 100644 --- a/quadratic-api/prisma/schema.prisma +++ b/quadratic-api/prisma/schema.prisma @@ -291,7 +291,7 @@ model AnalyticsAIChatMessage { chat AnalyticsAIChat @relation(fields: [chatId], references: [id]) model String messageIndex Int @map("message_index") - s3Key String @map("s3_key") + s3Key String? @map("s3_key") like Boolean? undo Boolean? codeRunError String? @map("code_run_error") diff --git a/quadratic-api/src/routes/v0/ai.chat.POST.ts b/quadratic-api/src/routes/v0/ai.chat.POST.ts index 42070c52c7..fc672a0165 100644 --- a/quadratic-api/src/routes/v0/ai.chat.POST.ts +++ b/quadratic-api/src/routes/v0/ai.chat.POST.ts @@ -68,7 +68,9 @@ async function handler(req: RequestWithUser, res: Response {

Free Plan

- AI Messages Monthly + AI Messages / User / Month 50
- Connection Runs Monthly + Connection Runs / Month
@@ -140,11 +140,11 @@ export const Component = () => {

Team AI Plan

- AI Messages Month / User + AI Messages / User / Month
- Connection Runs Monthly + Connection Runs / Month
@@ -176,7 +176,7 @@ export const Component = () => {

Current Usage

- AI Messages Month / User + AI Messages / User / Month
0 @@ -185,7 +185,7 @@ export const Component = () => {
- Connection Runs Monthly + Connection Runs / Month
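Worth calling out in this commit: with `s3_key` now nullable, an `AnalyticsAIChatMessage` row is recorded for every prompt (so usage can always be counted), while the raw payload is uploaded and linked only when the team has `settingAnalyticsAi` enabled. Condensed from the `ai.chat.POST.ts` change above:

```ts
// Only upload message content when the team opted into AI analytics;
// otherwise s3Key stays undefined but the message row is still created.
let s3Key: string | undefined;
if (ownerTeam.settingAnalyticsAi) {
  const contents = Buffer.from(JSON.stringify(args)).toString('base64');
  const response = await uploadFile(key, contents, jwt, S3Bucket.ANALYTICS);
  s3Key = response.key;
}

await dbClient.analyticsAIChat.upsert({
  where: { chatId },
  create: { userId, fileId, chatId, source, messages: { create: { model, messageIndex, s3Key } } },
  update: { messages: { create: { model, messageIndex, s3Key } }, updatedDate: new Date() },
});
```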
0 / ∞ From 8716f9ca6f5adb8db3400eccff63a27611b274ee Mon Sep 17 00:00:00 2001 From: David Kircos Date: Mon, 10 Feb 2025 23:30:31 -0700 Subject: [PATCH 04/13] show usage on dashboard --- quadratic-api/src/ai/usage.ts | 28 ++++++ quadratic-api/src/routes/v0/ai.chat.POST.ts | 91 ++++++++++--------- .../src/routes/v0/teams.$uuid.GET.ts | 5 +- .../src/routes/teams.$teamUuid.settings.tsx | 28 +++++- quadratic-shared/typesAndSchemas.ts | 6 ++ 5 files changed, 112 insertions(+), 46 deletions(-) create mode 100644 quadratic-api/src/ai/usage.ts diff --git a/quadratic-api/src/ai/usage.ts b/quadratic-api/src/ai/usage.ts new file mode 100644 index 0000000000..c201a2b389 --- /dev/null +++ b/quadratic-api/src/ai/usage.ts @@ -0,0 +1,28 @@ +import dbClient from '../dbClient'; + +// Get AI message usage aggregated by month for last 6 months +export const getAIMessageUsageForUser = async (userId: number) => { + return await dbClient.$queryRaw<{ month: string; ai_messages: number }[]>` + WITH RECURSIVE months AS ( + SELECT + DATE_TRUNC('month', CURRENT_DATE) as month + UNION ALL + SELECT + DATE_TRUNC('month', month - INTERVAL '1 month') + FROM months + WHERE month > DATE_TRUNC('month', CURRENT_DATE - INTERVAL '5 months') + ) + SELECT + TO_CHAR(m.month, 'YYYY-MM') as month, + COALESCE(COUNT(acm.id)::integer, 0) as ai_messages + FROM months m + LEFT JOIN "AnalyticsAIChat" ac ON + DATE_TRUNC('month', ac.created_date) = m.month + AND ac.user_id = ${userId} + AND ac.source IN ('ai_assistant', 'ai_analyst', 'ai_researcher') + LEFT JOIN "AnalyticsAIChatMessage" acm ON + acm.chat_id = ac.id + GROUP BY m.month + ORDER BY m.month ASC; +`; +}; diff --git a/quadratic-api/src/routes/v0/ai.chat.POST.ts b/quadratic-api/src/routes/v0/ai.chat.POST.ts index fc672a0165..bd313d73ed 100644 --- a/quadratic-api/src/routes/v0/ai.chat.POST.ts +++ b/quadratic-api/src/routes/v0/ai.chat.POST.ts @@ -68,57 +68,62 @@ async function handler(req: RequestWithUser, res: Response-__.json - const messageIndex = getLastUserPromptMessageIndex(args.messages); - const key = `${fileUuid}-${source}_${chatId.replace(/-/g, '_')}_${messageIndex}.json`; - - let s3Key; if (ownerTeam.settingAnalyticsAi) { + const key = `${fileUuid}-${source}_${chatId.replace(/-/g, '_')}_${messageIndex}.json`; + + // If we aren't using s3 or the analytics bucket name is not set, don't save the data + // This path is also used for self-hosted users, so we don't want to save the data in that case + if (STORAGE_TYPE !== 's3' || !getBucketName(S3Bucket.ANALYTICS)) { + return; + } + + const jwt = req.header('Authorization'); + if (!jwt) { + return; + } + const contents = Buffer.from(JSON.stringify(args)).toString('base64'); const response = await uploadFile(key, contents, jwt, S3Bucket.ANALYTICS); - s3Key = response.key; - } + const s3Key = response.key; - await dbClient.analyticsAIChat.upsert({ - where: { - chatId, - }, - create: { - userId, - fileId, - chatId, - source, - messages: { - create: { - model, - messageIndex, - s3Key, - }, + await dbClient.analyticsAIChatMessage.update({ + where: { + chatId_messageIndex: { chatId: chat.id, messageIndex }, }, - }, - update: { - messages: { - create: { - model, - messageIndex, - s3Key, - }, - }, - updatedDate: new Date(), - }, - }); + data: { s3Key }, + }); + } } catch (e) { console.error(e); } diff --git a/quadratic-api/src/routes/v0/teams.$uuid.GET.ts b/quadratic-api/src/routes/v0/teams.$uuid.GET.ts index 9567eb284e..8903db9d64 100644 --- a/quadratic-api/src/routes/v0/teams.$uuid.GET.ts +++ 
b/quadratic-api/src/routes/v0/teams.$uuid.GET.ts @@ -1,6 +1,7 @@ import type { Request, Response } from 'express'; import type { ApiTypes } from 'quadratic-shared/typesAndSchemas'; import { z } from 'zod'; +import { getAIMessageUsageForUser } from '../../ai/usage'; import { getUsers } from '../../auth/auth'; import dbClient from '../../dbClient'; import { licenseClient } from '../../licenseClient'; @@ -13,7 +14,6 @@ import type { RequestWithUser } from '../../types/Request'; import type { ResponseError } from '../../types/Response'; import { ApiError } from '../../utils/ApiError'; import { getFilePermissions } from '../../utils/permissions'; - export default [validateAccessToken, userMiddleware, handler]; const schema = z.object({ @@ -124,6 +124,8 @@ async function handler(req: Request, res: Response { return ; } + const latestUsage = billing.usage[billing.usage.length - 1] || { ai_messages: 0 }; + return ( <> @@ -176,9 +180,29 @@ export const Component = () => {

Current Usage

- AI Messages / User / Month +
+ AI Messages / User / Month + + + + + + + Usage History + +
+ {billing.usage.map((usage) => ( +
+ {usage.month} + {usage.ai_messages} +
+ ))} +
+
+
+
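The `billing.usage` array rendered in this popover comes from `getAIMessageUsageForUser` (the recursive-month query added in this commit): one `{ month, ai_messages }` row per month for the trailing six months, ordered oldest to newest, which is why the loader code above can take the last element as the current month. Condensed:

```ts
// Shape returned by getAIMessageUsageForUser and exposed as billing.usage,
// e.g. [{ month: '2024-09', ai_messages: 0 }, ..., { month: '2025-02', ai_messages: 12 }]
type MonthlyUsage = { month: string; ai_messages: number };

// Ordered oldest -> newest, so the current month is the final entry;
// fall back to zero for safety, as the component above does.
const latestUsage: MonthlyUsage =
  billing.usage[billing.usage.length - 1] ?? { month: '', ai_messages: 0 };
```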
- 0 + {latestUsage.ai_messages} / {billing.status === undefined ? '50' : '∞'} diff --git a/quadratic-shared/typesAndSchemas.ts b/quadratic-shared/typesAndSchemas.ts index f537330789..78de975eb7 100644 --- a/quadratic-shared/typesAndSchemas.ts +++ b/quadratic-shared/typesAndSchemas.ts @@ -324,6 +324,12 @@ export const ApiSchemas = { billing: z.object({ status: TeamSubscriptionStatusSchema.optional(), currentPeriodEnd: z.string().optional(), + usage: z.array( + z.object({ + month: z.string(), + ai_messages: z.number(), + }) + ), }), files: z.array( z.object({ From 7cd53efa17b7abef417f63cda56e07132adf0113 Mon Sep 17 00:00:00 2001 From: David Kircos Date: Mon, 10 Feb 2025 23:37:04 -0700 Subject: [PATCH 05/13] fix background --- .../src/routes/teams.$teamUuid.settings.tsx | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx index 4b635a0b59..78b0b26efb 100644 --- a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx +++ b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx @@ -25,6 +25,9 @@ export const Component = () => { team, userMakingRequest: { teamPermissions }, billing, + users, + files, + filesPrivate, }, } = useDashboardRouteLoaderData(); @@ -140,8 +143,8 @@ export const Component = () => {
{/* Team AI Plan */} -
-

Team AI Plan

+
+

Team Plan

AI Messages / User / Month @@ -169,7 +172,7 @@ export const Component = () => { }} className="mt-4 w-full" > - Upgrade to Team AI + Upgrade Team )}
@@ -211,21 +214,21 @@ export const Component = () => {
Connection Runs / Month
- 0 + - / ∞
Team Members
- 0 + {users.length} / ∞
Files
- 0 + {files.length + filesPrivate.length} / ∞
From 2fecf192ac100594a20c0db1bbfe794d2459cf60 Mon Sep 17 00:00:00 2001 From: AyushAgrawal-A2 Date: Tue, 11 Feb 2025 18:55:22 +0530 Subject: [PATCH 06/13] chore: track ai usage --- quadratic-api/src/ai/handler/anthropic.ts | 12 +- quadratic-api/src/ai/handler/bedrock.ts | 20 +-- quadratic-api/src/ai/handler/openai.ts | 15 ++- .../src/ai/helpers/anthropic.helper.ts | 114 ++++++++++++------ .../src/ai/helpers/bedrock.helper.ts | 99 +++++++++------ quadratic-api/src/ai/helpers/openai.helper.ts | 33 ++++- quadratic-api/src/ai/helpers/usage.helper.ts | 38 ++++++ quadratic-api/src/routes/v0/ai.chat.POST.ts | 17 +-- .../ui/menus/AIAnalyst/AIAnalystMessages.tsx | 7 -- .../hooks/useSubmitAIAnalystPrompt.tsx | 5 +- quadratic-shared/ai/models/AI_MODELS.ts | 63 ++++++++-- quadratic-shared/typesAndSchemasAI.ts | 45 ++++++- 12 files changed, 336 insertions(+), 132 deletions(-) create mode 100644 quadratic-api/src/ai/helpers/usage.helper.ts diff --git a/quadratic-api/src/ai/handler/anthropic.ts b/quadratic-api/src/ai/handler/anthropic.ts index 8fa2bae1c4..c109361879 100644 --- a/quadratic-api/src/ai/handler/anthropic.ts +++ b/quadratic-api/src/ai/handler/anthropic.ts @@ -1,7 +1,7 @@ import Anthropic from '@anthropic-ai/sdk'; import type { Response } from 'express'; import { getModelOptions } from 'quadratic-shared/ai/helpers/model.helper'; -import type { AIMessagePrompt, AIRequestHelperArgs, AnthropicModel } from 'quadratic-shared/typesAndSchemasAI'; +import type { AIRequestHelperArgs, AnthropicModel, ParsedAIResponse } from 'quadratic-shared/typesAndSchemasAI'; import { ANTHROPIC_API_KEY } from '../../env-vars'; import { getAnthropicApiArgs, parseAnthropicResponse, parseAnthropicStream } from '../helpers/anthropic.helper'; @@ -13,7 +13,7 @@ export const handleAnthropicRequest = async ( model: AnthropicModel, args: AIRequestHelperArgs, response: Response -): Promise => { +): Promise => { const { system, messages, tools, tool_choice } = getAnthropicApiArgs(args); const { stream, temperature, max_tokens } = getModelOptions(model, args); @@ -34,8 +34,8 @@ export const handleAnthropicRequest = async ( response.setHeader('Cache-Control', 'no-cache'); response.setHeader('Connection', 'keep-alive'); - const responseMessage = await parseAnthropicStream(chunks, response, model); - return responseMessage; + const parsedResponse = await parseAnthropicStream(chunks, response, model); + return parsedResponse; } catch (error: any) { if (!response.headersSent) { if (error instanceof Anthropic.APIError) { @@ -62,8 +62,8 @@ export const handleAnthropicRequest = async ( tool_choice, }); - const responseMessage = parseAnthropicResponse(result, response, model); - return responseMessage; + const parsedResponse = parseAnthropicResponse(result, response, model); + return parsedResponse; } catch (error: any) { if (error instanceof Anthropic.APIError) { response.status(error.status ?? 
400).json(error.message); diff --git a/quadratic-api/src/ai/handler/bedrock.ts b/quadratic-api/src/ai/handler/bedrock.ts index 3583e4e8ae..09d92f83bd 100644 --- a/quadratic-api/src/ai/handler/bedrock.ts +++ b/quadratic-api/src/ai/handler/bedrock.ts @@ -3,7 +3,7 @@ import Anthropic from '@anthropic-ai/sdk'; import { BedrockRuntimeClient, ConverseCommand, ConverseStreamCommand } from '@aws-sdk/client-bedrock-runtime'; import { type Response } from 'express'; import { getModelOptions, isBedrockAnthropicModel } from 'quadratic-shared/ai/helpers/model.helper'; -import type { AIMessagePrompt, AIRequestBody, BedrockModel } from 'quadratic-shared/typesAndSchemasAI'; +import type { AIRequestBody, BedrockModel, ParsedAIResponse } from 'quadratic-shared/typesAndSchemasAI'; import { AWS_S3_ACCESS_KEY_ID, AWS_S3_REGION, AWS_S3_SECRET_ACCESS_KEY } from '../../env-vars'; import { getAnthropicApiArgs, parseAnthropicResponse, parseAnthropicStream } from '../helpers/anthropic.helper'; import { getBedrockApiArgs, parseBedrockResponse, parseBedrockStream } from '../helpers/bedrock.helper'; @@ -25,7 +25,7 @@ export const handleBedrockRequest = async ( model: BedrockModel, args: Omit, response: Response -): Promise => { +): Promise => { const { stream, temperature, max_tokens } = getModelOptions(model, args); if (isBedrockAnthropicModel(model)) { @@ -47,8 +47,8 @@ export const handleBedrockRequest = async ( response.setHeader('Cache-Control', 'no-cache'); response.setHeader('Connection', 'keep-alive'); - const responseMessage = await parseAnthropicStream(chunks, response, model); - return responseMessage; + const parsedResponse = await parseAnthropicStream(chunks, response, model); + return parsedResponse; } catch (error: any) { if (!response.headersSent) { if (error instanceof Anthropic.APIError) { @@ -74,8 +74,8 @@ export const handleBedrockRequest = async ( tool_choice, }); - const responseMessage = parseAnthropicResponse(result, response, model); - return responseMessage; + const parsedResponse = parseAnthropicResponse(result, response, model); + return parsedResponse; } catch (error: any) { if (error instanceof Anthropic.APIError) { response.status(error.status ?? 
400).json(error.message); @@ -108,8 +108,8 @@ export const handleBedrockRequest = async ( response.setHeader('Cache-Control', 'no-cache'); response.setHeader('Connection', 'keep-alive'); - const responseMessage = await parseBedrockStream(chunks, response, model); - return responseMessage; + const parsedResponse = await parseBedrockStream(chunks, response, model); + return parsedResponse; } catch (error: any) { if (!response.headersSent) { if (error.response) { @@ -138,8 +138,8 @@ export const handleBedrockRequest = async ( }); const result = await bedrock.send(command); - const responseMessage = parseBedrockResponse(result.output, response, model); - return responseMessage; + const parsedResponse = parseBedrockResponse(result, response, model); + return parsedResponse; } catch (error: any) { if (error.response) { response.status(error.response.status).json(error.response.data); diff --git a/quadratic-api/src/ai/handler/openai.ts b/quadratic-api/src/ai/handler/openai.ts index 8d8120e542..a51b9b9eb7 100644 --- a/quadratic-api/src/ai/handler/openai.ts +++ b/quadratic-api/src/ai/handler/openai.ts @@ -1,7 +1,7 @@ import { type Response } from 'express'; import OpenAI from 'openai'; import { getModelOptions } from 'quadratic-shared/ai/helpers/model.helper'; -import type { AIMessagePrompt, AIRequestHelperArgs, OpenAIModel } from 'quadratic-shared/typesAndSchemasAI'; +import type { AIRequestHelperArgs, OpenAIModel, ParsedAIResponse } from 'quadratic-shared/typesAndSchemasAI'; import { OPENAI_API_KEY } from '../../env-vars'; import { getOpenAIApiArgs, parseOpenAIResponse, parseOpenAIStream } from '../helpers/openai.helper'; @@ -13,7 +13,7 @@ export const handleOpenAIRequest = async ( model: OpenAIModel, args: AIRequestHelperArgs, response: Response -): Promise => { +): Promise => { const { messages, tools, tool_choice } = getOpenAIApiArgs(args); const { stream, temperature } = getModelOptions(model, args); @@ -26,14 +26,17 @@ export const handleOpenAIRequest = async ( stream: true, tools, tool_choice, + stream_options: { + include_usage: true, + }, }); response.setHeader('Content-Type', 'text/event-stream'); response.setHeader('Cache-Control', 'no-cache'); response.setHeader('Connection', 'keep-alive'); - const responseMessage = await parseOpenAIStream(completion, response, model); - return responseMessage; + const parsedResponse = await parseOpenAIStream(completion, response, model); + return parsedResponse; } catch (error: any) { if (!response.headersSent) { if (error instanceof OpenAI.APIError) { @@ -57,8 +60,8 @@ export const handleOpenAIRequest = async ( tool_choice, }); - const responseMessage = parseOpenAIResponse(result, response, model); - return responseMessage; + const parsedResponse = parseOpenAIResponse(result, response, model); + return parsedResponse; } catch (error: any) { if (error instanceof OpenAI.APIError) { response.status(error.status ?? 
400).json(error.message); diff --git a/quadratic-api/src/ai/helpers/anthropic.helper.ts b/quadratic-api/src/ai/helpers/anthropic.helper.ts index ae02e3ba5f..ed6ebc410f 100644 --- a/quadratic-api/src/ai/helpers/anthropic.helper.ts +++ b/quadratic-api/src/ai/helpers/anthropic.helper.ts @@ -10,7 +10,9 @@ import type { AIRequestBody, AnthropicModel, BedrockAnthropicModel, + ParsedAIResponse, } from 'quadratic-shared/typesAndSchemasAI'; +import { calculateUsage } from './usage.helper'; export function getAnthropicApiArgs(args: Omit): { system: string | TextBlockParam[] | undefined; @@ -120,7 +122,7 @@ export async function parseAnthropicStream( chunks: Stream, response: Response, model: AnthropicModel | BedrockAnthropicModel -) { +): Promise { const responseMessage: AIMessagePrompt = { role: 'assistant', content: '', @@ -129,46 +131,70 @@ export async function parseAnthropicStream( model, }; + let input_tokens = 0; + let output_tokens = 0; + let cache_read_tokens = 0; + let cache_write_tokens = 0; + for await (const chunk of chunks) { if (!response.writableEnded) { - if (chunk.type === 'content_block_start') { - if (chunk.content_block.type === 'text') { - responseMessage.content += chunk.content_block.text; - } else if (chunk.content_block.type === 'tool_use') { - const toolCalls = [...responseMessage.toolCalls]; - const toolCall = { - id: chunk.content_block.id, - name: chunk.content_block.name, - arguments: '', - loading: true, - }; - toolCalls.push(toolCall); - responseMessage.toolCalls = toolCalls; - } - } else if (chunk.type === 'content_block_delta') { - if (chunk.delta.type === 'text_delta') { - responseMessage.content += chunk.delta.text; - } else if (chunk.delta.type === 'input_json_delta') { - const toolCalls = [...responseMessage.toolCalls]; - const toolCall = { - ...(toolCalls.pop() ?? { - id: '', - name: '', + switch (chunk.type) { + case 'content_block_start': + if (chunk.content_block.type === 'text') { + responseMessage.content += chunk.content_block.text; + } else if (chunk.content_block.type === 'tool_use') { + const toolCalls = [...responseMessage.toolCalls]; + const toolCall = { + id: chunk.content_block.id, + name: chunk.content_block.name, arguments: '', loading: true, - }), - }; - toolCall.arguments += chunk.delta.partial_json; - toolCalls.push(toolCall); - responseMessage.toolCalls = toolCalls; - } - } else if (chunk.type === 'content_block_stop') { - const toolCalls = [...responseMessage.toolCalls]; - const toolCall = toolCalls.pop(); - if (toolCall) { - toolCalls.push({ ...toolCall, loading: false }); - responseMessage.toolCalls = toolCalls; - } + }; + toolCalls.push(toolCall); + responseMessage.toolCalls = toolCalls; + } + break; + case 'content_block_delta': + if (chunk.delta.type === 'text_delta') { + responseMessage.content += chunk.delta.text; + } else if (chunk.delta.type === 'input_json_delta') { + const toolCalls = [...responseMessage.toolCalls]; + const toolCall = { + ...(toolCalls.pop() ?? 
{ + id: '', + name: '', + arguments: '', + loading: true, + }), + }; + toolCall.arguments += chunk.delta.partial_json; + toolCalls.push(toolCall); + responseMessage.toolCalls = toolCalls; + } + break; + case 'content_block_stop': + { + const toolCalls = [...responseMessage.toolCalls]; + const toolCall = toolCalls.pop(); + if (toolCall) { + toolCalls.push({ ...toolCall, loading: false }); + responseMessage.toolCalls = toolCalls; + } + } + break; + case 'message_start': + if (chunk.message.usage) { + input_tokens = Math.max(input_tokens, chunk.message.usage.input_tokens); + output_tokens = Math.max(output_tokens, chunk.message.usage.output_tokens); + cache_read_tokens = Math.max(cache_read_tokens, chunk.message.usage.cache_read_input_tokens ?? 0); + cache_write_tokens = Math.max(cache_write_tokens, chunk.message.usage.cache_creation_input_tokens ?? 0); + } + break; + case 'message_delta': + if (chunk.usage) { + output_tokens = Math.max(output_tokens, chunk.usage.output_tokens); + } + break; } response.write(`data: ${JSON.stringify(responseMessage)}\n\n`); @@ -187,14 +213,16 @@ export async function parseAnthropicStream( response.end(); } - return responseMessage; + const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens }); + + return { responseMessage, usage }; } export function parseAnthropicResponse( result: Anthropic.Messages.Message, response: Response, model: AnthropicModel | BedrockAnthropicModel -): AIMessagePrompt { +): ParsedAIResponse { const responseMessage: AIMessagePrompt = { role: 'assistant', content: '', @@ -233,5 +261,11 @@ export function parseAnthropicResponse( response.json(responseMessage); - return responseMessage; + const input_tokens = result.usage.input_tokens; + const output_tokens = result.usage.output_tokens; + const cache_read_tokens = result.usage.cache_read_input_tokens ?? 0; + const cache_write_tokens = result.usage.cache_creation_input_tokens ?? 
0; + const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens }); + + return { responseMessage, usage }; } diff --git a/quadratic-api/src/ai/helpers/bedrock.helper.ts b/quadratic-api/src/ai/helpers/bedrock.helper.ts index 436b01e4af..4da7467e2b 100644 --- a/quadratic-api/src/ai/helpers/bedrock.helper.ts +++ b/quadratic-api/src/ai/helpers/bedrock.helper.ts @@ -1,5 +1,5 @@ import { - type ConverseOutput, + type ConverseResponse, type ConverseStreamOutput, type Message, type SystemContentBlock, @@ -10,7 +10,13 @@ import type { Response } from 'express'; import { getSystemPromptMessages } from 'quadratic-shared/ai/helpers/message.helper'; import type { AITool } from 'quadratic-shared/ai/specs/aiToolsSpec'; import { aiToolsSpec } from 'quadratic-shared/ai/specs/aiToolsSpec'; -import type { AIMessagePrompt, AIRequestHelperArgs, BedrockModel } from 'quadratic-shared/typesAndSchemasAI'; +import type { + AIMessagePrompt, + AIRequestHelperArgs, + BedrockModel, + ParsedAIResponse, +} from 'quadratic-shared/typesAndSchemasAI'; +import { calculateUsage } from './usage.helper'; export function getBedrockApiArgs(args: AIRequestHelperArgs): { system: SystemContentBlock[] | undefined; @@ -121,7 +127,7 @@ export async function parseBedrockStream( chunks: AsyncIterable | never[], response: Response, model: BedrockModel -) { +): Promise { const responseMessage: AIMessagePrompt = { role: 'assistant', content: '', @@ -130,18 +136,29 @@ export async function parseBedrockStream( model, }; + let input_tokens = 0; + let output_tokens = 0; + for await (const chunk of chunks) { + if (chunk.metadata) { + input_tokens = Math.max(input_tokens, chunk.metadata.usage?.inputTokens ?? 0); + output_tokens = Math.max(output_tokens, chunk.metadata.usage?.outputTokens ?? 0); + } + if (!response.writableEnded) { - if (chunk.contentBlockStart && chunk.contentBlockStart.start && chunk.contentBlockStart.start.toolUse) { - const toolCalls = [...responseMessage.toolCalls]; - const toolCall = { - id: chunk.contentBlockStart.start.toolUse.toolUseId ?? '', - name: chunk.contentBlockStart.start.toolUse.name ?? '', - arguments: '', - loading: true, - }; - toolCalls.push(toolCall); - responseMessage.toolCalls = toolCalls; + if (chunk.contentBlockStart) { + // tool use start + if (chunk.contentBlockStart.start && chunk.contentBlockStart.start.toolUse) { + const toolCalls = [...responseMessage.toolCalls]; + const toolCall = { + id: chunk.contentBlockStart.start.toolUse.toolUseId ?? '', + name: chunk.contentBlockStart.start.toolUse.name ?? '', + arguments: '', + loading: true, + }; + toolCalls.push(toolCall); + responseMessage.toolCalls = toolCalls; + } } // tool use stop else if (chunk.contentBlockStop) { @@ -152,25 +169,27 @@ export async function parseBedrockStream( toolCalls.push(toolCall); responseMessage.toolCalls = toolCalls; } - } else if (chunk.contentBlockDelta && chunk.contentBlockDelta.delta) { - // text delta - if ('text' in chunk.contentBlockDelta.delta) { - responseMessage.content += chunk.contentBlockDelta.delta.text; - } - // tool use delta - if ('toolUse' in chunk.contentBlockDelta.delta) { - const toolCalls = [...responseMessage.toolCalls]; - const toolCall = { - ...(toolCalls.pop() ?? { - id: '', - name: '', - arguments: '', - loading: true, - }), - }; - toolCall.arguments += chunk.contentBlockDelta.delta.toolUse?.input ?? 
''; - toolCalls.push(toolCall); - responseMessage.toolCalls = toolCalls; + } else if (chunk.contentBlockDelta) { + if (chunk.contentBlockDelta.delta) { + // text delta + if ('text' in chunk.contentBlockDelta.delta) { + responseMessage.content += chunk.contentBlockDelta.delta.text; + } + // tool use delta + if ('toolUse' in chunk.contentBlockDelta.delta) { + const toolCalls = [...responseMessage.toolCalls]; + const toolCall = { + ...(toolCalls.pop() ?? { + id: '', + name: '', + arguments: '', + loading: true, + }), + }; + toolCall.arguments += chunk.contentBlockDelta.delta.toolUse?.input ?? ''; + toolCalls.push(toolCall); + responseMessage.toolCalls = toolCalls; + } } } @@ -190,14 +209,16 @@ export async function parseBedrockStream( response.end(); } - return responseMessage; + const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens: 0, cache_write_tokens: 0 }); + + return { responseMessage, usage }; } export function parseBedrockResponse( - result: ConverseOutput | undefined, + result: ConverseResponse, response: Response, model: BedrockModel -): AIMessagePrompt { +): ParsedAIResponse { const responseMessage: AIMessagePrompt = { role: 'assistant', content: '', @@ -206,7 +227,7 @@ export function parseBedrockResponse( model, }; - result?.message?.content?.forEach((contentBlock) => { + result.output?.message?.content?.forEach((contentBlock) => { if ('text' in contentBlock) { responseMessage.content += contentBlock.text; } @@ -235,5 +256,9 @@ export function parseBedrockResponse( response.json(responseMessage); - return responseMessage; + const input_tokens = result.usage?.inputTokens ?? 0; + const output_tokens = result.usage?.outputTokens ?? 0; + const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens: 0, cache_write_tokens: 0 }); + + return { responseMessage, usage }; } diff --git a/quadratic-api/src/ai/helpers/openai.helper.ts b/quadratic-api/src/ai/helpers/openai.helper.ts index 8ceae60206..873e22eaf8 100644 --- a/quadratic-api/src/ai/helpers/openai.helper.ts +++ b/quadratic-api/src/ai/helpers/openai.helper.ts @@ -5,7 +5,13 @@ import type { Stream } from 'openai/streaming'; import { getSystemPromptMessages } from 'quadratic-shared/ai/helpers/message.helper'; import type { AITool } from 'quadratic-shared/ai/specs/aiToolsSpec'; import { aiToolsSpec } from 'quadratic-shared/ai/specs/aiToolsSpec'; -import type { AIMessagePrompt, AIRequestHelperArgs, OpenAIModel } from 'quadratic-shared/typesAndSchemasAI'; +import type { + AIMessagePrompt, + AIRequestHelperArgs, + OpenAIModel, + ParsedAIResponse, +} from 'quadratic-shared/typesAndSchemasAI'; +import { calculateUsage } from './usage.helper'; export function getOpenAIApiArgs(args: AIRequestHelperArgs): { messages: ChatCompletionMessageParam[]; @@ -107,7 +113,7 @@ export async function parseOpenAIStream( chunks: Stream, response: Response, model: OpenAIModel -) { +): Promise { const responseMessage: AIMessagePrompt = { role: 'assistant', content: '', @@ -116,7 +122,17 @@ export async function parseOpenAIStream( model, }; + let input_tokens = 0; + let output_tokens = 0; + let cache_read_tokens = 0; + for await (const chunk of chunks) { + if (chunk.usage) { + input_tokens = Math.max(input_tokens, chunk.usage.prompt_tokens); + output_tokens = Math.max(output_tokens, chunk.usage.completion_tokens); + cache_read_tokens = Math.max(cache_read_tokens, chunk.usage.prompt_tokens_details?.cached_tokens ?? 
0); + } + if (!response.writableEnded) { if (chunk.choices && chunk.choices[0] && chunk.choices[0].delta) { // text delta @@ -187,14 +203,16 @@ export async function parseOpenAIStream( response.end(); } - return responseMessage; + const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens: 0 }); + + return { responseMessage, usage }; } export function parseOpenAIResponse( result: OpenAI.Chat.Completions.ChatCompletion, response: Response, model: OpenAIModel -): AIMessagePrompt { +): ParsedAIResponse { const responseMessage: AIMessagePrompt = { role: 'assistant', content: '', @@ -240,5 +258,10 @@ export function parseOpenAIResponse( response.json(responseMessage); - return responseMessage; + const input_tokens = result.usage?.prompt_tokens ?? 0; + const output_tokens = result.usage?.completion_tokens ?? 0; + const cache_read_tokens = result.usage?.prompt_tokens_details?.cached_tokens ?? 0; + const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens: 0 }); + + return { responseMessage, usage }; } diff --git a/quadratic-api/src/ai/helpers/usage.helper.ts b/quadratic-api/src/ai/helpers/usage.helper.ts new file mode 100644 index 0000000000..dee375fe29 --- /dev/null +++ b/quadratic-api/src/ai/helpers/usage.helper.ts @@ -0,0 +1,38 @@ +import { MODEL_OPTIONS } from 'quadratic-shared/ai/models/AI_MODELS'; +import type { AIModel, AIUsage } from 'quadratic-shared/typesAndSchemasAI'; + +export function calculateUsage({ + model, + input_tokens, + output_tokens, + cache_read_tokens, + cache_write_tokens, +}: { + model: AIModel; +} & Pick): AIUsage { + const rate_per_million_input_tokens = MODEL_OPTIONS[model].rate_per_million_input_tokens; + const rate_per_million_output_tokens = MODEL_OPTIONS[model].rate_per_million_output_tokens; + const rate_per_million_cache_read_tokens = MODEL_OPTIONS[model].rate_per_million_cache_read_tokens; + const rate_per_million_cache_write_tokens = MODEL_OPTIONS[model].rate_per_million_cache_write_tokens; + const net_cost = + (cache_read_tokens * rate_per_million_cache_read_tokens + + cache_write_tokens * rate_per_million_cache_write_tokens + + (input_tokens - cache_read_tokens) * rate_per_million_input_tokens + + output_tokens * rate_per_million_output_tokens) / + 1000000; + + const usage: AIUsage = { + model, + rate_per_million_input_tokens, + rate_per_million_output_tokens, + rate_per_million_cache_read_tokens, + rate_per_million_cache_write_tokens, + input_tokens, + output_tokens, + cache_read_tokens, + cache_write_tokens, + net_cost, + }; + + return usage; +} diff --git a/quadratic-api/src/routes/v0/ai.chat.POST.ts b/quadratic-api/src/routes/v0/ai.chat.POST.ts index bd313d73ed..7c3d113379 100644 --- a/quadratic-api/src/routes/v0/ai.chat.POST.ts +++ b/quadratic-api/src/routes/v0/ai.chat.POST.ts @@ -8,7 +8,7 @@ import { } from 'quadratic-shared/ai/helpers/model.helper'; import type { ApiTypes } from 'quadratic-shared/typesAndSchemas'; import { ApiSchemas } from 'quadratic-shared/typesAndSchemas'; -import { type AIMessagePrompt } from 'quadratic-shared/typesAndSchemasAI'; +import { type ParsedAIResponse } from 'quadratic-shared/typesAndSchemasAI'; import { z } from 'zod'; import { handleAnthropicRequest } from '../../ai/handler/anthropic'; import { handleBedrockRequest } from '../../ai/handler/bedrock'; @@ -49,21 +49,24 @@ async function handler(req: RequestWithUser, res: Response(null); - const settings = useRecoilValue(editorInteractionStateSettingsAtom); const logFeedback = 
useRecoilCallback( ({ snapshot }) => @@ -172,11 +170,6 @@ function FeedbackButtons() { // Log it to mixpanel mixpanel.track('[AIAnalyst].feedback', { like: newLike }); - // If they have AI analytics turned off, don't do anything else - if (!settings.analyticsAi) { - return; - } - // Otherwise, log it to our DB const messages = snapshot.getLoadable(aiAnalystCurrentChatMessagesAtom).getValue(); const messageIndex = getLastUserPromptMessageIndex(messages); diff --git a/quadratic-client/src/app/ui/menus/AIAnalyst/hooks/useSubmitAIAnalystPrompt.tsx b/quadratic-client/src/app/ui/menus/AIAnalyst/hooks/useSubmitAIAnalystPrompt.tsx index 6207c981cd..b55ad8b76a 100644 --- a/quadratic-client/src/app/ui/menus/AIAnalyst/hooks/useSubmitAIAnalystPrompt.tsx +++ b/quadratic-client/src/app/ui/menus/AIAnalyst/hooks/useSubmitAIAnalystPrompt.tsx @@ -27,6 +27,7 @@ import type { import { useRecoilCallback } from 'recoil'; import { v4 } from 'uuid'; +const USE_STREAM = true; const MAX_TOOL_CALL_ITERATIONS = 25; export type SubmitAIAnalystPromptArgs = { @@ -144,7 +145,7 @@ export function useSubmitAIAnalystPrompt() { source: 'AIAnalyst', model, messages: updatedMessages, - useStream: true, + useStream: USE_STREAM, useTools: true, useToolsPrompt: true, language: undefined, @@ -194,7 +195,7 @@ export function useSubmitAIAnalystPrompt() { source: 'AIAnalyst', model, messages: updatedMessages, - useStream: true, + useStream: USE_STREAM, useTools: true, useToolsPrompt: true, language: undefined, diff --git a/quadratic-shared/ai/models/AI_MODELS.ts b/quadratic-shared/ai/models/AI_MODELS.ts index 2a093f054d..1f8ef67e76 100644 --- a/quadratic-shared/ai/models/AI_MODELS.ts +++ b/quadratic-shared/ai/models/AI_MODELS.ts @@ -1,21 +1,13 @@ -import type { AIModel, AIProviders } from 'quadratic-shared/typesAndSchemasAI'; +import type { AIModel, AIModelOptions } from 'quadratic-shared/typesAndSchemasAI'; export const DEFAULT_MODEL: AIModel = 'anthropic.claude-3-5-sonnet-20241022-v2:0'; -export const DEFAULT_GET_CHAT_NAME_MODEL: AIModel = 'anthropic.claude-3-5-haiku-20241022-v1:0'; +export const DEFAULT_GET_CHAT_NAME_MODEL: AIModel = 'anthropic.claude-3-haiku-20240307-v1:0'; // updating this will force the model to be reset to the default model in local storage export const DEFAULT_MODEL_VERSION = 2; export const MODEL_OPTIONS: { - [key in AIModel]: { - displayName: string; - temperature: number; - max_tokens: number; - canStream: boolean; - canStreamWithToolCalls: boolean; - enabled: boolean; - provider: AIProviders; - }; + [key in AIModel]: AIModelOptions; } = { 'gpt-4o-2024-11-20': { displayName: 'OpenAI: GPT-4o', @@ -25,6 +17,10 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: true, provider: 'openai', + rate_per_million_input_tokens: 2.5, + rate_per_million_output_tokens: 10, + rate_per_million_cache_read_tokens: 1.25, + rate_per_million_cache_write_tokens: 0, }, 'o1-2024-12-17': { displayName: 'OpenAI: o1', @@ -34,6 +30,10 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: false, enabled: false, provider: 'openai', + rate_per_million_input_tokens: 15, + rate_per_million_output_tokens: 60, + rate_per_million_cache_read_tokens: 7.5, + rate_per_million_cache_write_tokens: 0, }, 'o3-mini-2025-01-31': { displayName: 'OpenAI: o3-mini', @@ -43,6 +43,10 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: true, provider: 'openai', + rate_per_million_input_tokens: 1.1, + rate_per_million_output_tokens: 4.4, + rate_per_million_cache_read_tokens: 0.55, + 
rate_per_million_cache_write_tokens: 0, }, 'claude-3-5-sonnet-20241022': { displayName: 'Anthropic: Claude 3.5 Sonnet', @@ -52,6 +56,10 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: false, provider: 'anthropic', + rate_per_million_input_tokens: 3, + rate_per_million_output_tokens: 15, + rate_per_million_cache_read_tokens: 0.3, + rate_per_million_cache_write_tokens: 3.75, }, 'claude-3-5-haiku-20241022': { displayName: 'Anthropic: Claude 3.5 Haiku', @@ -61,6 +69,10 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: false, provider: 'anthropic', + rate_per_million_input_tokens: 0.8, + rate_per_million_output_tokens: 4, + rate_per_million_cache_read_tokens: 0.08, + rate_per_million_cache_write_tokens: 1, }, 'anthropic.claude-3-5-sonnet-20241022-v2:0': { displayName: `Bedrock: Claude 3.5 Sonnet`, @@ -70,6 +82,10 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: true, provider: 'bedrock-anthropic', + rate_per_million_input_tokens: 3, + rate_per_million_output_tokens: 15, + rate_per_million_cache_read_tokens: 0.3, + rate_per_million_cache_write_tokens: 3.75, }, 'anthropic.claude-3-5-haiku-20241022-v1:0': { displayName: 'Bedrock: Claude 3.5 Haiku', @@ -79,6 +95,23 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: false, provider: 'bedrock-anthropic', + rate_per_million_input_tokens: 0.8, + rate_per_million_output_tokens: 4, + rate_per_million_cache_read_tokens: 0.08, + rate_per_million_cache_write_tokens: 1, + }, + 'anthropic.claude-3-haiku-20240307-v1:0': { + displayName: 'Bedrock: Claude 3 Haiku', + temperature: 0, + max_tokens: 4096, + canStream: true, + canStreamWithToolCalls: true, + enabled: false, + provider: 'bedrock-anthropic', + rate_per_million_input_tokens: 0.25, + rate_per_million_output_tokens: 1.25, + rate_per_million_cache_read_tokens: 0, + rate_per_million_cache_write_tokens: 0, }, 'us.meta.llama3-2-90b-instruct-v1:0': { displayName: 'Bedrock: Llama 3.2 90B Instruct', @@ -88,6 +121,10 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: false, enabled: false, provider: 'bedrock', + rate_per_million_input_tokens: 0.72, + rate_per_million_output_tokens: 0.72, + rate_per_million_cache_read_tokens: 0, + rate_per_million_cache_write_tokens: 0, }, 'mistral.mistral-large-2407-v1:0': { displayName: 'Bedrock: Mistral Large 2 (24.07)', @@ -97,5 +134,9 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: false, enabled: false, provider: 'bedrock', + rate_per_million_input_tokens: 2, + rate_per_million_output_tokens: 6, + rate_per_million_cache_read_tokens: 0, + rate_per_million_cache_write_tokens: 0, }, } as const; diff --git a/quadratic-shared/typesAndSchemasAI.ts b/quadratic-shared/typesAndSchemasAI.ts index 588fef566d..eb1d85d66a 100644 --- a/quadratic-shared/typesAndSchemasAI.ts +++ b/quadratic-shared/typesAndSchemasAI.ts @@ -8,6 +8,7 @@ const BedrockModelSchema = z .enum([ 'anthropic.claude-3-5-sonnet-20241022-v2:0', 'anthropic.claude-3-5-haiku-20241022-v1:0', + 'anthropic.claude-3-haiku-20240307-v1:0', 'us.meta.llama3-2-90b-instruct-v1:0', 'mistral.mistral-large-2407-v1:0', ]) @@ -15,7 +16,11 @@ const BedrockModelSchema = z export type BedrockModel = z.infer; const BedrockAnthropicModelSchema = z - .enum(['anthropic.claude-3-5-sonnet-20241022-v2:0', 'anthropic.claude-3-5-haiku-20241022-v1:0']) + .enum([ + 'anthropic.claude-3-5-sonnet-20241022-v2:0', + 'anthropic.claude-3-5-haiku-20241022-v1:0', + 'anthropic.claude-3-haiku-20240307-v1:0', + ]) 
.default('anthropic.claude-3-5-sonnet-20241022-v2:0'); export type BedrockAnthropicModel = z.infer; @@ -37,6 +42,26 @@ const AIModelSchema = z.union([ ]); export type AIModel = z.infer; +const AIRatesSchema = z.object({ + rate_per_million_input_tokens: z.number(), + rate_per_million_output_tokens: z.number(), + rate_per_million_cache_read_tokens: z.number(), + rate_per_million_cache_write_tokens: z.number(), +}); + +const AIModelOptionsSchema = z + .object({ + displayName: z.string(), + temperature: z.number(), + max_tokens: z.number(), + canStream: z.boolean(), + canStreamWithToolCalls: z.boolean(), + enabled: z.boolean(), + provider: AIProvidersSchema, + }) + .extend(AIRatesSchema.shape); +export type AIModelOptions = z.infer; + const InternalContextTypeSchema = z.enum([ 'quadraticDocs', 'currentFile', @@ -178,3 +203,21 @@ export const AIRequestBodySchema = z.object({ }); export type AIRequestBody = z.infer; export type AIRequestHelperArgs = Omit; + +const AIUsageSchema = z + .object({ + model: AIModelSchema, + input_tokens: z.number(), + output_tokens: z.number(), + cache_read_tokens: z.number(), + cache_write_tokens: z.number(), + net_cost: z.number(), + }) + .extend(AIRatesSchema.shape); +export type AIUsage = z.infer; + +const parsedAIResponseSchema = z.object({ + responseMessage: AIMessagePromptSchema, + usage: AIUsageSchema, +}); +export type ParsedAIResponse = z.infer; From 1f3fb9d84c313c8abb15118aba2440bb8365a8c8 Mon Sep 17 00:00:00 2001 From: AyushAgrawal-A2 Date: Tue, 11 Feb 2025 19:55:29 +0530 Subject: [PATCH 07/13] fix cache calculation --- quadratic-api/src/ai/helpers/openai.helper.ts | 5 +++-- quadratic-api/src/ai/helpers/usage.helper.ts | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/quadratic-api/src/ai/helpers/openai.helper.ts b/quadratic-api/src/ai/helpers/openai.helper.ts index 873e22eaf8..ca5a6a7bc3 100644 --- a/quadratic-api/src/ai/helpers/openai.helper.ts +++ b/quadratic-api/src/ai/helpers/openai.helper.ts @@ -203,6 +203,7 @@ export async function parseOpenAIStream( response.end(); } + input_tokens -= cache_read_tokens; const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens: 0 }); return { responseMessage, usage }; @@ -258,9 +259,9 @@ export function parseOpenAIResponse( response.json(responseMessage); - const input_tokens = result.usage?.prompt_tokens ?? 0; - const output_tokens = result.usage?.completion_tokens ?? 0; const cache_read_tokens = result.usage?.prompt_tokens_details?.cached_tokens ?? 0; + const input_tokens = (result.usage?.prompt_tokens ?? 0) - cache_read_tokens; + const output_tokens = result.usage?.completion_tokens ?? 
0; const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens: 0 }); return { responseMessage, usage }; diff --git a/quadratic-api/src/ai/helpers/usage.helper.ts b/quadratic-api/src/ai/helpers/usage.helper.ts index dee375fe29..b5bc8fe794 100644 --- a/quadratic-api/src/ai/helpers/usage.helper.ts +++ b/quadratic-api/src/ai/helpers/usage.helper.ts @@ -15,10 +15,10 @@ export function calculateUsage({ const rate_per_million_cache_read_tokens = MODEL_OPTIONS[model].rate_per_million_cache_read_tokens; const rate_per_million_cache_write_tokens = MODEL_OPTIONS[model].rate_per_million_cache_write_tokens; const net_cost = - (cache_read_tokens * rate_per_million_cache_read_tokens + - cache_write_tokens * rate_per_million_cache_write_tokens + - (input_tokens - cache_read_tokens) * rate_per_million_input_tokens + - output_tokens * rate_per_million_output_tokens) / + (input_tokens * rate_per_million_input_tokens + + output_tokens * rate_per_million_output_tokens + + cache_read_tokens * rate_per_million_cache_read_tokens + + cache_write_tokens * rate_per_million_cache_write_tokens) / 1000000; const usage: AIUsage = { From 27c21baa8341b11e234595b6504b281678102ccb Mon Sep 17 00:00:00 2001 From: AyushAgrawal-A2 Date: Tue, 11 Feb 2025 20:47:00 +0530 Subject: [PATCH 08/13] anthropic prompt caching --- quadratic-api/src/ai/helpers/anthropic.helper.ts | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/quadratic-api/src/ai/helpers/anthropic.helper.ts b/quadratic-api/src/ai/helpers/anthropic.helper.ts index ed6ebc410f..0003a65f20 100644 --- a/quadratic-api/src/ai/helpers/anthropic.helper.ts +++ b/quadratic-api/src/ai/helpers/anthropic.helper.ts @@ -23,7 +23,17 @@ export function getAnthropicApiArgs(args: Omit ({ + // type: 'text' as const, + // text: message, + // ...(index < 4 ? 
{ cache_control: { type: 'ephemeral' } } : {}), + // })); + const messages: MessageParam[] = promptMessages.reduce((acc, message) => { if (message.role === 'assistant' && message.contextType === 'userPrompt' && message.toolCalls.length > 0) { const anthropicMessages: MessageParam[] = [ From c197434d2fe0ee255466ac41f80088edf575e5b3 Mon Sep 17 00:00:00 2001 From: David Kircos Date: Tue, 11 Feb 2025 10:19:04 -0700 Subject: [PATCH 09/13] add stripe key for previews --- infra/aws-cloudformation/quadratic-preview.yml | 3 +++ quadratic-client/src/routes/teams.$teamUuid.settings.tsx | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/infra/aws-cloudformation/quadratic-preview.yml b/infra/aws-cloudformation/quadratic-preview.yml index 9d98f8c120..8c7a8b0386 100644 --- a/infra/aws-cloudformation/quadratic-preview.yml +++ b/infra/aws-cloudformation/quadratic-preview.yml @@ -139,6 +139,9 @@ Resources: ANTHROPIC_API_KEY=$(aws ssm get-parameter --name "/quadratic-development/ANTHROPIC_API_KEY" --with-decryption --query "Parameter.Value" --output text) EXA_API_KEY=$(aws ssm get-parameter --name "/quadratic-development/EXA_API_KEY" --with-decryption --query "Parameter.Value" --output text) + # stripe + STRIPE_SECRET_KEY=$(aws ssm get-parameter --name "/quadratic-development/STRIPE_SECRET_KEY" --with-decryption --query "Parameter.Value" --output text) + # aws ecr ECR_URL=${AWS::AccountId}.dkr.ecr.${AWS::Region}.amazonaws.com IMAGE_TAG=${ImageTag} diff --git a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx index 78b0b26efb..0ae5e6155a 100644 --- a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx +++ b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx @@ -193,6 +193,10 @@ export const Component = () => { Usage History +

+ Users receive 50 free AI messages per month across all teams. If a user belongs to a paid + team, they'll use that team's unlimited messages instead. +
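The copy above describes a per-user cap of 50 free AI messages per month, shared across teams and lifted while the user belongs to a paid team. The enforcement itself isn't shown in this series, so the following is purely illustrative: one way `getAIMessageUsageForUser` could back such a gate, where the constant, the paid-team check, and the error shape are all assumptions:

```ts
// Illustrative sketch only: reject chat requests once a free user hits the cap.
const FREE_MONTHLY_AI_MESSAGES = 50; // figure taken from the dashboard copy above

async function assertWithinFreeLimit(userId: number, onPaidTeam: boolean) {
  if (onPaidTeam) return; // paid teams are presented as unlimited
  const usage = await getAIMessageUsageForUser(userId);
  const thisMonth = usage[usage.length - 1]?.ai_messages ?? 0;
  if (thisMonth >= FREE_MONTHLY_AI_MESSAGES) {
    throw new ApiError(402, 'Monthly free AI message limit reached.'); // assumed error shape
  }
}
```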

{billing.usage.map((usage) => (
From 277f2b2c1e02fd6bcda3f001b86e9c620b55b376 Mon Sep 17 00:00:00 2001 From: AyushAgrawal-A2 Date: Wed, 12 Feb 2025 01:16:10 +0530 Subject: [PATCH 10/13] save to db, remove api rates --- .../migration.sql | 2 - .../migration.sql | 10 +++++ quadratic-api/prisma/schema.prisma | 34 +++++++++----- .../src/ai/helpers/anthropic.helper.ts | 45 +++++++++++-------- .../src/ai/helpers/bedrock.helper.ts | 25 ++++++----- quadratic-api/src/ai/helpers/openai.helper.ts | 32 +++++++------ quadratic-api/src/ai/helpers/usage.helper.ts | 38 ---------------- quadratic-api/src/routes/v0/ai.chat.POST.ts | 16 +++++-- quadratic-shared/ai/helpers/message.helper.ts | 21 +++++++-- quadratic-shared/ai/models/AI_MODELS.ts | 40 ----------------- quadratic-shared/typesAndSchemasAI.ts | 43 +++++++----------- 11 files changed, 137 insertions(+), 169 deletions(-) delete mode 100644 quadratic-api/prisma/migrations/20250211034001_optional_analytics/migration.sql create mode 100644 quadratic-api/prisma/migrations/20250211191943_optional_analytics_and_token_usage_and_message_type/migration.sql delete mode 100644 quadratic-api/src/ai/helpers/usage.helper.ts diff --git a/quadratic-api/prisma/migrations/20250211034001_optional_analytics/migration.sql b/quadratic-api/prisma/migrations/20250211034001_optional_analytics/migration.sql deleted file mode 100644 index b925ed76b2..0000000000 --- a/quadratic-api/prisma/migrations/20250211034001_optional_analytics/migration.sql +++ /dev/null @@ -1,2 +0,0 @@ --- AlterTable -ALTER TABLE "AnalyticsAIChatMessage" ALTER COLUMN "s3_key" DROP NOT NULL; diff --git a/quadratic-api/prisma/migrations/20250211191943_optional_analytics_and_token_usage_and_message_type/migration.sql b/quadratic-api/prisma/migrations/20250211191943_optional_analytics_and_token_usage_and_message_type/migration.sql new file mode 100644 index 0000000000..d329cd7782 --- /dev/null +++ b/quadratic-api/prisma/migrations/20250211191943_optional_analytics_and_token_usage_and_message_type/migration.sql @@ -0,0 +1,10 @@ +-- CreateEnum +CREATE TYPE "AIChatMessageType" AS ENUM ('user_prompt', 'tool_result'); + +-- AlterTable +ALTER TABLE "AnalyticsAIChatMessage" ADD COLUMN "cache_read_tokens" INTEGER, +ADD COLUMN "cache_write_tokens" INTEGER, +ADD COLUMN "input_tokens" INTEGER, +ADD COLUMN "message_type" "AIChatMessageType", +ADD COLUMN "output_tokens" INTEGER, +ALTER COLUMN "s3_key" DROP NOT NULL; diff --git a/quadratic-api/prisma/schema.prisma b/quadratic-api/prisma/schema.prisma index 80c79daa15..cb0f93e980 100644 --- a/quadratic-api/prisma/schema.prisma +++ b/quadratic-api/prisma/schema.prisma @@ -285,19 +285,29 @@ enum AIChatSource { GetFileName @map("get_file_name") } +enum AIChatMessageType { + userPrompt @map("user_prompt") + toolResult @map("tool_result") +} + model AnalyticsAIChatMessage { - id Int @id @default(autoincrement()) - chatId Int @map("chat_id") - chat AnalyticsAIChat @relation(fields: [chatId], references: [id]) - model String - messageIndex Int @map("message_index") - s3Key String? @map("s3_key") - like Boolean? - undo Boolean? - codeRunError String? @map("code_run_error") - responseError String? @map("response_error") - createdDate DateTime @default(now()) @map("created_date") - updatedDate DateTime @default(now()) @map("updated_date") + id Int @id @default(autoincrement()) + chatId Int @map("chat_id") + chat AnalyticsAIChat @relation(fields: [chatId], references: [id]) + model String + messageIndex Int @map("message_index") + messageType AIChatMessageType? @map("message_type") + s3Key String? 
@map("s3_key") + like Boolean? + undo Boolean? + codeRunError String? @map("code_run_error") + responseError String? @map("response_error") + inputTokens Int? @map("input_tokens") + outputTokens Int? @map("output_tokens") + cacheReadTokens Int? @map("cache_read_tokens") + cacheWriteTokens Int? @map("cache_write_tokens") + createdDate DateTime @default(now()) @map("created_date") + updatedDate DateTime @default(now()) @map("updated_date") @@unique([chatId, messageIndex], name: "chatId_messageIndex") @@index([chatId, messageIndex]) diff --git a/quadratic-api/src/ai/helpers/anthropic.helper.ts b/quadratic-api/src/ai/helpers/anthropic.helper.ts index 0003a65f20..ea1846566b 100644 --- a/quadratic-api/src/ai/helpers/anthropic.helper.ts +++ b/quadratic-api/src/ai/helpers/anthropic.helper.ts @@ -8,14 +8,14 @@ import { aiToolsSpec } from 'quadratic-shared/ai/specs/aiToolsSpec'; import type { AIMessagePrompt, AIRequestBody, + AIUsage, AnthropicModel, BedrockAnthropicModel, ParsedAIResponse, } from 'quadratic-shared/typesAndSchemasAI'; -import { calculateUsage } from './usage.helper'; export function getAnthropicApiArgs(args: Omit): { - system: string | TextBlockParam[] | undefined; + system: TextBlockParam[] | undefined; messages: MessageParam[]; tools: Tool[] | undefined; tool_choice: ToolChoice | undefined; @@ -25,7 +25,10 @@ export function getAnthropicApiArgs(args: Omit ({ + type: 'text' as const, + text: message, + })); // with prompt caching of system messages // const system: TextBlockParam[] = systemMessages.map((message, index) => ({ @@ -141,10 +144,12 @@ export async function parseAnthropicStream( model, }; - let input_tokens = 0; - let output_tokens = 0; - let cache_read_tokens = 0; - let cache_write_tokens = 0; + const usage: AIUsage = { + inputTokens: 0, + outputTokens: 0, + cacheReadTokens: 0, + cacheWriteTokens: 0, + }; for await (const chunk of chunks) { if (!response.writableEnded) { @@ -194,15 +199,18 @@ export async function parseAnthropicStream( break; case 'message_start': if (chunk.message.usage) { - input_tokens = Math.max(input_tokens, chunk.message.usage.input_tokens); - output_tokens = Math.max(output_tokens, chunk.message.usage.output_tokens); - cache_read_tokens = Math.max(cache_read_tokens, chunk.message.usage.cache_read_input_tokens ?? 0); - cache_write_tokens = Math.max(cache_write_tokens, chunk.message.usage.cache_creation_input_tokens ?? 0); + usage.inputTokens = Math.max(usage.inputTokens, chunk.message.usage.input_tokens); + usage.outputTokens = Math.max(usage.outputTokens, chunk.message.usage.output_tokens); + usage.cacheReadTokens = Math.max(usage.cacheReadTokens, chunk.message.usage.cache_read_input_tokens ?? 0); + usage.cacheWriteTokens = Math.max( + usage.cacheWriteTokens, + chunk.message.usage.cache_creation_input_tokens ?? 0 + ); } break; case 'message_delta': if (chunk.usage) { - output_tokens = Math.max(output_tokens, chunk.usage.output_tokens); + usage.outputTokens = Math.max(usage.outputTokens, chunk.usage.output_tokens); } break; } @@ -223,8 +231,6 @@ export async function parseAnthropicStream( response.end(); } - const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens }); - return { responseMessage, usage }; } @@ -271,11 +277,12 @@ export function parseAnthropicResponse( response.json(responseMessage); - const input_tokens = result.usage.input_tokens; - const output_tokens = result.usage.output_tokens; - const cache_read_tokens = result.usage.cache_read_input_tokens ?? 
0; - const cache_write_tokens = result.usage.cache_creation_input_tokens ?? 0; - const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens }); + const usage: AIUsage = { + inputTokens: result.usage.input_tokens, + outputTokens: result.usage.output_tokens, + cacheReadTokens: result.usage.cache_read_input_tokens ?? 0, + cacheWriteTokens: result.usage.cache_creation_input_tokens ?? 0, + }; return { responseMessage, usage }; } diff --git a/quadratic-api/src/ai/helpers/bedrock.helper.ts b/quadratic-api/src/ai/helpers/bedrock.helper.ts index 4da7467e2b..b58c7d14cb 100644 --- a/quadratic-api/src/ai/helpers/bedrock.helper.ts +++ b/quadratic-api/src/ai/helpers/bedrock.helper.ts @@ -13,10 +13,10 @@ import { aiToolsSpec } from 'quadratic-shared/ai/specs/aiToolsSpec'; import type { AIMessagePrompt, AIRequestHelperArgs, + AIUsage, BedrockModel, ParsedAIResponse, } from 'quadratic-shared/typesAndSchemasAI'; -import { calculateUsage } from './usage.helper'; export function getBedrockApiArgs(args: AIRequestHelperArgs): { system: SystemContentBlock[] | undefined; @@ -136,13 +136,17 @@ export async function parseBedrockStream( model, }; - let input_tokens = 0; - let output_tokens = 0; + const usage: AIUsage = { + inputTokens: 0, + outputTokens: 0, + cacheReadTokens: 0, + cacheWriteTokens: 0, + }; for await (const chunk of chunks) { if (chunk.metadata) { - input_tokens = Math.max(input_tokens, chunk.metadata.usage?.inputTokens ?? 0); - output_tokens = Math.max(output_tokens, chunk.metadata.usage?.outputTokens ?? 0); + usage.inputTokens = Math.max(usage.inputTokens, chunk.metadata.usage?.inputTokens ?? 0); + usage.outputTokens = Math.max(usage.outputTokens, chunk.metadata.usage?.outputTokens ?? 0); } if (!response.writableEnded) { @@ -209,8 +213,6 @@ export async function parseBedrockStream( response.end(); } - const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens: 0, cache_write_tokens: 0 }); - return { responseMessage, usage }; } @@ -256,9 +258,12 @@ export function parseBedrockResponse( response.json(responseMessage); - const input_tokens = result.usage?.inputTokens ?? 0; - const output_tokens = result.usage?.outputTokens ?? 0; - const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens: 0, cache_write_tokens: 0 }); + const usage: AIUsage = { + inputTokens: result.usage?.inputTokens ?? 0, + outputTokens: result.usage?.outputTokens ?? 
0, + cacheReadTokens: 0, + cacheWriteTokens: 0, + }; return { responseMessage, usage }; } diff --git a/quadratic-api/src/ai/helpers/openai.helper.ts b/quadratic-api/src/ai/helpers/openai.helper.ts index ca5a6a7bc3..d7d4354831 100644 --- a/quadratic-api/src/ai/helpers/openai.helper.ts +++ b/quadratic-api/src/ai/helpers/openai.helper.ts @@ -8,10 +8,10 @@ import { aiToolsSpec } from 'quadratic-shared/ai/specs/aiToolsSpec'; import type { AIMessagePrompt, AIRequestHelperArgs, + AIUsage, OpenAIModel, ParsedAIResponse, } from 'quadratic-shared/typesAndSchemasAI'; -import { calculateUsage } from './usage.helper'; export function getOpenAIApiArgs(args: AIRequestHelperArgs): { messages: ChatCompletionMessageParam[]; @@ -122,15 +122,19 @@ export async function parseOpenAIStream( model, }; - let input_tokens = 0; - let output_tokens = 0; - let cache_read_tokens = 0; + const usage: AIUsage = { + inputTokens: 0, + outputTokens: 0, + cacheReadTokens: 0, + cacheWriteTokens: 0, + }; for await (const chunk of chunks) { if (chunk.usage) { - input_tokens = Math.max(input_tokens, chunk.usage.prompt_tokens); - output_tokens = Math.max(output_tokens, chunk.usage.completion_tokens); - cache_read_tokens = Math.max(cache_read_tokens, chunk.usage.prompt_tokens_details?.cached_tokens ?? 0); + usage.inputTokens = Math.max(usage.inputTokens, chunk.usage.prompt_tokens); + usage.outputTokens = Math.max(usage.outputTokens, chunk.usage.completion_tokens); + usage.cacheReadTokens = Math.max(usage.cacheReadTokens, chunk.usage.prompt_tokens_details?.cached_tokens ?? 0); + usage.inputTokens -= usage.cacheReadTokens; } if (!response.writableEnded) { @@ -203,9 +207,6 @@ export async function parseOpenAIStream( response.end(); } - input_tokens -= cache_read_tokens; - const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens: 0 }); - return { responseMessage, usage }; } @@ -259,10 +260,13 @@ export function parseOpenAIResponse( response.json(responseMessage); - const cache_read_tokens = result.usage?.prompt_tokens_details?.cached_tokens ?? 0; - const input_tokens = (result.usage?.prompt_tokens ?? 0) - cache_read_tokens; - const output_tokens = result.usage?.completion_tokens ?? 0; - const usage = calculateUsage({ model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens: 0 }); + const cacheReadTokens = result.usage?.prompt_tokens_details?.cached_tokens ?? 0; + const usage: AIUsage = { + inputTokens: (result.usage?.prompt_tokens ?? 0) - cacheReadTokens, + outputTokens: result.usage?.completion_tokens ?? 
0, + cacheReadTokens, + cacheWriteTokens: 0, + }; return { responseMessage, usage }; } diff --git a/quadratic-api/src/ai/helpers/usage.helper.ts b/quadratic-api/src/ai/helpers/usage.helper.ts deleted file mode 100644 index b5bc8fe794..0000000000 --- a/quadratic-api/src/ai/helpers/usage.helper.ts +++ /dev/null @@ -1,38 +0,0 @@ -import { MODEL_OPTIONS } from 'quadratic-shared/ai/models/AI_MODELS'; -import type { AIModel, AIUsage } from 'quadratic-shared/typesAndSchemasAI'; - -export function calculateUsage({ - model, - input_tokens, - output_tokens, - cache_read_tokens, - cache_write_tokens, -}: { - model: AIModel; -} & Pick): AIUsage { - const rate_per_million_input_tokens = MODEL_OPTIONS[model].rate_per_million_input_tokens; - const rate_per_million_output_tokens = MODEL_OPTIONS[model].rate_per_million_output_tokens; - const rate_per_million_cache_read_tokens = MODEL_OPTIONS[model].rate_per_million_cache_read_tokens; - const rate_per_million_cache_write_tokens = MODEL_OPTIONS[model].rate_per_million_cache_write_tokens; - const net_cost = - (input_tokens * rate_per_million_input_tokens + - output_tokens * rate_per_million_output_tokens + - cache_read_tokens * rate_per_million_cache_read_tokens + - cache_write_tokens * rate_per_million_cache_write_tokens) / - 1000000; - - const usage: AIUsage = { - model, - rate_per_million_input_tokens, - rate_per_million_output_tokens, - rate_per_million_cache_read_tokens, - rate_per_million_cache_write_tokens, - input_tokens, - output_tokens, - cache_read_tokens, - cache_write_tokens, - net_cost, - }; - - return usage; -} diff --git a/quadratic-api/src/routes/v0/ai.chat.POST.ts b/quadratic-api/src/routes/v0/ai.chat.POST.ts index 7c3d113379..443fedcd02 100644 --- a/quadratic-api/src/routes/v0/ai.chat.POST.ts +++ b/quadratic-api/src/routes/v0/ai.chat.POST.ts @@ -1,5 +1,5 @@ import type { Response } from 'express'; -import { getLastUserPromptMessageIndex } from 'quadratic-shared/ai/helpers/message.helper'; +import { getLastPromptMessageType, getLastUserPromptMessageIndex } from 'quadratic-shared/ai/helpers/message.helper'; import { isAnthropicModel, isBedrockAnthropicModel, @@ -64,13 +64,11 @@ async function handler(req: RequestWithUser, res: Response { const systemMessages: SystemMessage[] = messages.filter( @@ -8,14 +15,22 @@ export const getSystemMessages = (messages: ChatMessage[]): string[] => { return systemMessages.map((message) => message.content); }; -export const getPromptMessages = (messages: ChatMessage[]): ChatMessage[] => { - return messages.filter((message) => message.contextType === 'userPrompt' || message.contextType === 'toolResult'); +export const getPromptMessages = (messages: ChatMessage[]): (UserMessagePrompt | ToolResultMessage)[] => { + return messages.filter( + (message): message is UserMessagePrompt | ToolResultMessage => + message.contextType === 'userPrompt' || message.contextType === 'toolResult' + ); }; export const getUserPromptMessages = (messages: ChatMessage[]): UserMessagePrompt[] => { return getPromptMessages(messages).filter((message): message is UserMessagePrompt => message.role === 'user'); }; +export const getLastPromptMessageType = (messages: ChatMessage[]): UserPromptContextType | ToolResultContextType => { + const userPromptMessage = getUserPromptMessages(messages); + return userPromptMessage[userPromptMessage.length - 1].contextType; +}; + export const getLastUserPromptMessageIndex = (messages: ChatMessage[]): number => { return getUserPromptMessages(messages).length - 1; }; diff --git 
a/quadratic-shared/ai/models/AI_MODELS.ts b/quadratic-shared/ai/models/AI_MODELS.ts index 1f8ef67e76..f44c8ec042 100644 --- a/quadratic-shared/ai/models/AI_MODELS.ts +++ b/quadratic-shared/ai/models/AI_MODELS.ts @@ -17,10 +17,6 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: true, provider: 'openai', - rate_per_million_input_tokens: 2.5, - rate_per_million_output_tokens: 10, - rate_per_million_cache_read_tokens: 1.25, - rate_per_million_cache_write_tokens: 0, }, 'o1-2024-12-17': { displayName: 'OpenAI: o1', @@ -30,10 +26,6 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: false, enabled: false, provider: 'openai', - rate_per_million_input_tokens: 15, - rate_per_million_output_tokens: 60, - rate_per_million_cache_read_tokens: 7.5, - rate_per_million_cache_write_tokens: 0, }, 'o3-mini-2025-01-31': { displayName: 'OpenAI: o3-mini', @@ -43,10 +35,6 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: true, provider: 'openai', - rate_per_million_input_tokens: 1.1, - rate_per_million_output_tokens: 4.4, - rate_per_million_cache_read_tokens: 0.55, - rate_per_million_cache_write_tokens: 0, }, 'claude-3-5-sonnet-20241022': { displayName: 'Anthropic: Claude 3.5 Sonnet', @@ -56,10 +44,6 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: false, provider: 'anthropic', - rate_per_million_input_tokens: 3, - rate_per_million_output_tokens: 15, - rate_per_million_cache_read_tokens: 0.3, - rate_per_million_cache_write_tokens: 3.75, }, 'claude-3-5-haiku-20241022': { displayName: 'Anthropic: Claude 3.5 Haiku', @@ -69,10 +53,6 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: false, provider: 'anthropic', - rate_per_million_input_tokens: 0.8, - rate_per_million_output_tokens: 4, - rate_per_million_cache_read_tokens: 0.08, - rate_per_million_cache_write_tokens: 1, }, 'anthropic.claude-3-5-sonnet-20241022-v2:0': { displayName: `Bedrock: Claude 3.5 Sonnet`, @@ -82,10 +62,6 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: true, provider: 'bedrock-anthropic', - rate_per_million_input_tokens: 3, - rate_per_million_output_tokens: 15, - rate_per_million_cache_read_tokens: 0.3, - rate_per_million_cache_write_tokens: 3.75, }, 'anthropic.claude-3-5-haiku-20241022-v1:0': { displayName: 'Bedrock: Claude 3.5 Haiku', @@ -95,10 +71,6 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: false, provider: 'bedrock-anthropic', - rate_per_million_input_tokens: 0.8, - rate_per_million_output_tokens: 4, - rate_per_million_cache_read_tokens: 0.08, - rate_per_million_cache_write_tokens: 1, }, 'anthropic.claude-3-haiku-20240307-v1:0': { displayName: 'Bedrock: Claude 3 Haiku', @@ -108,10 +80,6 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: true, enabled: false, provider: 'bedrock-anthropic', - rate_per_million_input_tokens: 0.25, - rate_per_million_output_tokens: 1.25, - rate_per_million_cache_read_tokens: 0, - rate_per_million_cache_write_tokens: 0, }, 'us.meta.llama3-2-90b-instruct-v1:0': { displayName: 'Bedrock: Llama 3.2 90B Instruct', @@ -121,10 +89,6 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: false, enabled: false, provider: 'bedrock', - rate_per_million_input_tokens: 0.72, - rate_per_million_output_tokens: 0.72, - rate_per_million_cache_read_tokens: 0, - rate_per_million_cache_write_tokens: 0, }, 'mistral.mistral-large-2407-v1:0': { displayName: 'Bedrock: Mistral Large 2 (24.07)', @@ -134,9 +98,5 @@ export const MODEL_OPTIONS: { canStreamWithToolCalls: false, enabled: 
false, provider: 'bedrock', - rate_per_million_input_tokens: 2, - rate_per_million_output_tokens: 6, - rate_per_million_cache_read_tokens: 0, - rate_per_million_cache_write_tokens: 0, }, } as const; diff --git a/quadratic-shared/typesAndSchemasAI.ts b/quadratic-shared/typesAndSchemasAI.ts index eb1d85d66a..d7554f87a6 100644 --- a/quadratic-shared/typesAndSchemasAI.ts +++ b/quadratic-shared/typesAndSchemasAI.ts @@ -42,24 +42,15 @@ const AIModelSchema = z.union([ ]); export type AIModel = z.infer; -const AIRatesSchema = z.object({ - rate_per_million_input_tokens: z.number(), - rate_per_million_output_tokens: z.number(), - rate_per_million_cache_read_tokens: z.number(), - rate_per_million_cache_write_tokens: z.number(), +const AIModelOptionsSchema = z.object({ + displayName: z.string(), + temperature: z.number(), + max_tokens: z.number(), + canStream: z.boolean(), + canStreamWithToolCalls: z.boolean(), + enabled: z.boolean(), + provider: AIProvidersSchema, }); - -const AIModelOptionsSchema = z - .object({ - displayName: z.string(), - temperature: z.number(), - max_tokens: z.number(), - canStream: z.boolean(), - canStreamWithToolCalls: z.boolean(), - enabled: z.boolean(), - provider: AIProvidersSchema, - }) - .extend(AIRatesSchema.shape); export type AIModelOptions = z.infer; const InternalContextTypeSchema = z.enum([ @@ -75,7 +66,9 @@ const InternalContextTypeSchema = z.enum([ 'tables', ]); const ToolResultContextTypeSchema = z.literal('toolResult'); +export type ToolResultContextType = z.infer; const UserPromptContextTypeSchema = z.literal('userPrompt'); +export type UserPromptContextType = z.infer; const ContextTypeSchema = z.union([ InternalContextTypeSchema, ToolResultContextTypeSchema, @@ -204,16 +197,12 @@ export const AIRequestBodySchema = z.object({ export type AIRequestBody = z.infer; export type AIRequestHelperArgs = Omit; -const AIUsageSchema = z - .object({ - model: AIModelSchema, - input_tokens: z.number(), - output_tokens: z.number(), - cache_read_tokens: z.number(), - cache_write_tokens: z.number(), - net_cost: z.number(), - }) - .extend(AIRatesSchema.shape); +const AIUsageSchema = z.object({ + inputTokens: z.number(), + outputTokens: z.number(), + cacheReadTokens: z.number(), + cacheWriteTokens: z.number(), +}); export type AIUsage = z.infer; const parsedAIResponseSchema = z.object({ From 6bbb039877955ee96014c95dc121a790849b3051 Mon Sep 17 00:00:00 2001 From: David Kircos Date: Tue, 11 Feb 2025 14:12:46 -0700 Subject: [PATCH 11/13] billing lifecycle --- quadratic-api/src/ai/usage.ts | 1 + .../src/routes/v0/teams.$uuid.GET.ts | 4 ++ .../src/routes/teams.$teamUuid.settings.tsx | 50 +++++++++++-------- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/quadratic-api/src/ai/usage.ts b/quadratic-api/src/ai/usage.ts index c201a2b389..1e30e416b5 100644 --- a/quadratic-api/src/ai/usage.ts +++ b/quadratic-api/src/ai/usage.ts @@ -22,6 +22,7 @@ export const getAIMessageUsageForUser = async (userId: number) => { AND ac.source IN ('ai_assistant', 'ai_analyst', 'ai_researcher') LEFT JOIN "AnalyticsAIChatMessage" acm ON acm.chat_id = ac.id + AND acm.message_type = 'user_prompt' GROUP BY m.month ORDER BY m.month ASC; `; diff --git a/quadratic-api/src/routes/v0/teams.$uuid.GET.ts b/quadratic-api/src/routes/v0/teams.$uuid.GET.ts index 8903db9d64..da450dfe8a 100644 --- a/quadratic-api/src/routes/v0/teams.$uuid.GET.ts +++ b/quadratic-api/src/routes/v0/teams.$uuid.GET.ts @@ -10,6 +10,7 @@ import { userMiddleware } from '../../middleware/user'; import { validateAccessToken } from 
'../../middleware/validateAccessToken'; import { parseRequest } from '../../middleware/validateRequestSchema'; import { getPresignedFileUrl } from '../../storage/storage'; +import { updateBillingIfNecessary } from '../../stripe/stripe'; import type { RequestWithUser } from '../../types/Request'; import type { ResponseError } from '../../types/Response'; import { ApiError } from '../../utils/ApiError'; @@ -31,6 +32,9 @@ async function handler(req: Request, res: Response {
{/* Free Plan */}
-

Free Plan

+
+

Free Plan

+ {billing.status === undefined && ( + Current Plan + )} +
AI Messages / User / Month
@@ -135,16 +140,16 @@ export const Component = () => {
- {billing.status === undefined && ( - - )}
{/* Team AI Plan */}
-
-

Team Plan

+
+
+

Team Plan

+ {billing.status === 'ACTIVE' && ( + Current Plan + )} +
AI Messages / User / Month
@@ -163,7 +168,7 @@ export const Component = () => {
- {billing.status === undefined && ( + {billing.status === undefined ? ( + ) : ( + billing.status === 'ACTIVE' && ( + + ) )}
@@ -238,19 +257,6 @@ export const Component = () => {
- - {billing.status !== undefined && ( - - )}
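[Editor's sketch] The hunks above key the plan cards' buttons off billing.status: no status means the team can start a Stripe Checkout session, while an ACTIVE subscription should route to subscription management instead. A minimal sketch of that branching follows; getCheckoutSessionUrl and getPortalSessionUrl are hypothetical stand-ins for the team billing endpoints, not the actual apiClient surface.

    // Hypothetical endpoint helpers -- stand-ins for the real billing client calls.
    declare function getCheckoutSessionUrl(teamUuid: string): Promise<{ url: string }>;
    declare function getPortalSessionUrl(teamUuid: string): Promise<{ url: string }>;

    function BillingAction({ teamUuid, status }: { teamUuid: string; status?: string }) {
      // No subscription recorded yet: offer the upgrade via Stripe Checkout.
      if (status === undefined) {
        return (
          <button onClick={() => getCheckoutSessionUrl(teamUuid).then(({ url }) => (window.location.href = url))}>
            Upgrade
          </button>
        );
      }
      // Already subscribed: manage the existing subscription instead.
      if (status === 'ACTIVE') {
        return (
          <button onClick={() => getPortalSessionUrl(teamUuid).then(({ url }) => (window.location.href = url))}>
            Manage billing
          </button>
        );
      }
      return null;
    }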
From 434ce6c8e8c5d435f4aaa4cd9781288958adf314 Mon Sep 17 00:00:00 2001 From: David Kircos Date: Tue, 11 Feb 2025 22:22:47 -0700 Subject: [PATCH 12/13] basic usage limit error --- quadratic-api/src/ai/usage.ts | 6 +++++- quadratic-api/src/routes/v0/ai.chat.POST.ts | 10 +++++++++- .../src/app/ai/hooks/useAIRequestToAPI.tsx | 14 ++++++++++---- .../src/routes/teams.$teamUuid.settings.tsx | 2 +- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/quadratic-api/src/ai/usage.ts b/quadratic-api/src/ai/usage.ts index 1e30e416b5..a724ab8c5c 100644 --- a/quadratic-api/src/ai/usage.ts +++ b/quadratic-api/src/ai/usage.ts @@ -24,6 +24,10 @@ export const getAIMessageUsageForUser = async (userId: number) => { acm.chat_id = ac.id AND acm.message_type = 'user_prompt' GROUP BY m.month - ORDER BY m.month ASC; + ORDER BY m.month DESC; `; }; + +export const userExceededUsageLimit = async (monthlyUsage: Awaited>) => { + return monthlyUsage[0]?.ai_messages > 5; +}; diff --git a/quadratic-api/src/routes/v0/ai.chat.POST.ts b/quadratic-api/src/routes/v0/ai.chat.POST.ts index 443fedcd02..a73f02d7b6 100644 --- a/quadratic-api/src/routes/v0/ai.chat.POST.ts +++ b/quadratic-api/src/routes/v0/ai.chat.POST.ts @@ -15,6 +15,7 @@ import { handleBedrockRequest } from '../../ai/handler/bedrock'; import { handleOpenAIRequest } from '../../ai/handler/openai'; import { getQuadraticContext, getToolUseContext } from '../../ai/helpers/context.helper'; import { ai_rate_limiter } from '../../ai/middleware/aiRateLimiter'; +import { getAIMessageUsageForUser, userExceededUsageLimit } from '../../ai/usage'; import dbClient from '../../dbClient'; import { STORAGE_TYPE } from '../../env-vars'; import { getFile } from '../../middleware/getFile'; @@ -24,7 +25,6 @@ import { parseRequest } from '../../middleware/validateRequestSchema'; import { getBucketName, S3Bucket } from '../../storage/s3'; import { uploadFile } from '../../storage/storage'; import type { RequestWithUser } from '../../types/Request'; - export default [validateAccessToken, ai_rate_limiter, userMiddleware, handler]; const schema = z.object({ @@ -36,6 +36,14 @@ async function handler(req: RequestWithUser, res: Response { + switch (response.status) { + case 429: + return 'You have exceeded the maximum number of requests. Please try again later.'; + case 402: + return 'You have exceeded your AI message limit. Please upgrade your plan to continue.'; + default: + return `Looks like there was a problem. 
Error: ${data.error}`; + } + })(); setMessages?.((prev) => [ ...prev.slice(0, -1), { role: 'assistant', content: error, contextType: 'userPrompt', model, toolCalls: [] }, diff --git a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx index 1627020572..3020a3ff0a 100644 --- a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx +++ b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx @@ -87,7 +87,7 @@ export const Component = () => { return ; } - const latestUsage = billing.usage[billing.usage.length - 1] || { ai_messages: 0 }; + const latestUsage = billing.usage[0] || { ai_messages: 0 }; return ( <> From e6de36103b2aed3a77787c6913ae75df416e5685 Mon Sep 17 00:00:00 2001 From: David Kircos Date: Wed, 12 Feb 2025 11:32:47 -0700 Subject: [PATCH 13/13] update plans display --- .../src/app/ai/hooks/useAIRequestToAPI.tsx | 7 +++ .../src/routes/teams.$teamUuid.settings.tsx | 59 +++++-------------- 2 files changed, 22 insertions(+), 44 deletions(-) diff --git a/quadratic-client/src/app/ai/hooks/useAIRequestToAPI.tsx b/quadratic-client/src/app/ai/hooks/useAIRequestToAPI.tsx index d6da12857c..e43c88125a 100644 --- a/quadratic-client/src/app/ai/hooks/useAIRequestToAPI.tsx +++ b/quadratic-client/src/app/ai/hooks/useAIRequestToAPI.tsx @@ -1,6 +1,8 @@ import { editorInteractionStateFileUuidAtom } from '@/app/atoms/editorInteractionStateAtom'; import { authClient } from '@/auth/auth'; import { apiClient } from '@/shared/api/apiClient'; +import { ROUTES } from '@/shared/constants/routes'; +import { useFileRouteLoaderData } from '@/shared/hooks/useFileRouteLoaderData'; import { getModelOptions } from 'quadratic-shared/ai/helpers/model.helper'; import { AIMessagePromptSchema, @@ -17,6 +19,11 @@ type HandleAIPromptProps = Omit & { }; export function useAIRequestToAPI() { + const { team } = useFileRouteLoaderData(); + console.log('team', team.uuid); + const url = ROUTES.TEAM_SETTINGS(team.uuid); + console.log('url', url); + const handleAIRequestToAPI = useRecoilCallback( ({ snapshot }) => async ({ diff --git a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx index 3020a3ff0a..730f6e42fd 100644 --- a/quadratic-client/src/routes/teams.$teamUuid.settings.tsx +++ b/quadratic-client/src/routes/teams.$teamUuid.settings.tsx @@ -26,8 +26,6 @@ export const Component = () => { userMakingRequest: { teamPermissions }, billing, users, - files, - filesPrivate, }, } = useDashboardRouteLoaderData(); @@ -93,7 +91,7 @@ export const Component = () => { <>
- + Name @@ -103,11 +101,11 @@ export const Component = () => { Save - + {teamPermissions.includes('TEAM_MANAGE') && ( <> - + Billing @@ -127,25 +125,17 @@ export const Component = () => { AI Messages / User / Month 50
-
- Connection Runs / Month - -
Team Members
-
- Files - -
{/* Team AI Plan */}
-

Team Plan

+

Pro Plan

{billing.status === 'ACTIVE' && ( Current Plan )}
@@ -155,17 +145,12 @@ export const Component = () => {
AI Messages / User / Month
-
- Connection Runs / Month - -
-
+
Team Members - -
-
- Files - + + $20

+ / User / Month +
{billing.status === undefined ? ( @@ -177,7 +162,7 @@ export const Component = () => { }} className="mt-4 w-full" > - Upgrade Team + Upgrade to Pro ) : ( billing.status === 'ACTIVE' && ( @@ -234,13 +219,6 @@ export const Component = () => {
-
- Connection Runs / Month -
- - - / ∞ -
-
Team Members
@@ -248,18 +226,11 @@ export const Component = () => { / ∞
-
- Files -
- {files.length + filesPrivate.length} - / ∞ -
-
- - + + Privacy @@ -293,7 +264,7 @@ export const Component = () => { ))}
- + )}
@@ -301,13 +272,13 @@ export const Component = () => { ); }; -function Row(props: { children: ReactNode[]; className?: string }) { +function SettingsRow(props: { children: ReactNode[]; className?: string }) { if (props.children.length !== 2) { throw new Error('Row must have exactly two children'); } return ( -
+
{props.children[0]}
{props.children[1]}
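[Editor's sketch] The series ends mid-diff, so the markup of the renamed SettingsRow helper is cut off. A self-contained version consistent with the visible pieces (exactly two children rendered side by side, optional className) might look like the following; the grid classes are assumptions, and the error text is updated to match the new name.

    import type { ReactNode } from 'react';

    // Minimal sketch of the renamed two-column helper; layout classes are assumed.
    function SettingsRow(props: { children: ReactNode[]; className?: string }) {
      if (props.children.length !== 2) {
        throw new Error('SettingsRow must have exactly two children');
      }
      return (
        <div className={`grid grid-cols-2 items-center gap-2 ${props.className ?? ''}`}>
          <div>{props.children[0]}</div>
          <div>{props.children[1]}</div>
        </div>
      );
    }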