Skip to content

Commit

Permalink
streaming refactor complete
Browse files Browse the repository at this point in the history
  • Loading branch information
AyushAgrawal-A2 committed Jan 7, 2025
1 parent fb772cf commit c69a3ae
Show file tree
Hide file tree
Showing 23 changed files with 1,550 additions and 1,504 deletions.
1,359 changes: 721 additions & 638 deletions package-lock.json

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions quadratic-api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
},
"author": "David Kircos",
"dependencies": {
"@anthropic-ai/bedrock-sdk": "^0.11.2",
"@anthropic-ai/sdk": "^0.32.1",
"@aws-sdk/client-bedrock-runtime": "^3.682.0",
"@anthropic-ai/bedrock-sdk": "^0.12.0",
"@anthropic-ai/sdk": "^0.33.1",
"@aws-sdk/client-bedrock-runtime": "^3.723.0",
"@aws-sdk/client-s3": "^3.427.0",
"@aws-sdk/client-secrets-manager": "^3.441.0",
"@aws-sdk/s3-request-presigner": "^3.427.0",
Expand All @@ -59,7 +59,7 @@
"multer": "^1.4.5-lts.1",
"multer-s3": "^3.0.1",
"newrelic": "^11.17.0",
"openai": "^4.58.1",
"openai": "^4.77.3",
"pg": "^8.11.3",
"stripe": "^14.16.0",
"supertest": "^6.3.3",
Expand Down
2 changes: 1 addition & 1 deletion quadratic-api/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ app.get('/', (req, res) => {

// App routes
// TODO: eventually move all of these into the `v0` directory and register them dynamically
app.use('/v0/ai', ai_router);
app.use('/ai', ai_router);
// Internal routes
app.use('/v0/internal', internal_router);

Expand Down
33 changes: 23 additions & 10 deletions quadratic-api/src/routes/ai/ai.ts
Original file line number Diff line number Diff line change
@@ -1,44 +1,57 @@
import express, { type Response } from 'express';
import {
isAnthropicBedrockModel,
isAnthropicModel,
isBedrockAnthropicModel,
isBedrockModel,
isOpenAIModel,
} from 'quadratic-shared/ai/helpers/model.helper';
import { getQuadraticContext } from 'quadratic-shared/ai/helpers/quadraticContext.helper';
import { getToolUseContext } from 'quadratic-shared/ai/helpers/toolUseContext.helper';
import { AIAutoCompleteRequestBodySchema } from 'quadratic-shared/typesAndSchemasAI';
import { AIAutoCompleteRequestBodySchema, type AIMessagePrompt } from 'quadratic-shared/typesAndSchemasAI';
import { validateAccessToken } from '../../middleware/validateAccessToken';
import type { Request } from '../../types/Request';
import { ai_rate_limiter } from './aiRateLimiter';
import { handleAnthropicRequest } from './anthropic';
import { handleBedrockRequest } from './bedrock';
import { getQuadraticContext, getToolUseContext } from './helpers/context.helper';
import { handleOpenAIRequest } from './openai';

const ai_router = express.Router();

ai_router.post('/', validateAccessToken, ai_rate_limiter, async (request: Request, response: Response) => {
try {
const { model, ...args } = AIAutoCompleteRequestBodySchema.parse(request.body);

if (args.useToolUsePrompt) {
const toolUseContext = getToolUseContext();
args.messages = [...toolUseContext, ...args.messages];
args.messages.unshift(...toolUseContext);
}

if (args.useQuadraticContext) {
const quadraticContext = getQuadraticContext(args.language);
args.messages = [...quadraticContext, ...args.messages];
args.messages.unshift(...quadraticContext);
}

let responseMessage: AIMessagePrompt | undefined;

switch (true) {
case isAnthropicBedrockModel(model) || isBedrockModel(model):
return handleBedrockRequest(model, args, response);
case isBedrockAnthropicModel(model) || isBedrockModel(model):
responseMessage = await handleBedrockRequest(model, args, response);
break;
case isAnthropicModel(model):
return handleAnthropicRequest(model, args, response);
responseMessage = await handleAnthropicRequest(model, args, response);
break;
case isOpenAIModel(model):
return handleOpenAIRequest(model, args, response);
responseMessage = await handleOpenAIRequest(model, args, response);
break;
default:
throw new Error('Model not supported');
}

if (responseMessage) {
args.messages.push(responseMessage);
}

// todo(ayush): check for permission and then log chats
console.log(args.messages);
} catch (error: any) {
response.status(400).json(error);
console.log(error);
Expand Down
24 changes: 8 additions & 16 deletions quadratic-api/src/routes/ai/anthropic.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import Anthropic from '@anthropic-ai/sdk';
import { type Response } from 'express';
import { getAnthropicApiArgs } from 'quadratic-shared/ai/helpers/anthropic.helper';
import type { Response } from 'express';
import { getModelOptions } from 'quadratic-shared/ai/helpers/model.helper';
import type { AIAutoCompleteRequestBody, AnthropicModel } from 'quadratic-shared/typesAndSchemasAI';
import type { AIAutoCompleteRequestBody, AIMessagePrompt, AnthropicModel } from 'quadratic-shared/typesAndSchemasAI';
import { ANTHROPIC_API_KEY } from '../../env-vars';
import { getAnthropicApiArgs, parseAnthropicResponse, parseAnthropicStream } from './helpers/anthropic.helper';

const anthropic = new Anthropic({
apiKey: ANTHROPIC_API_KEY,
Expand All @@ -13,7 +13,7 @@ export const handleAnthropicRequest = async (
model: AnthropicModel,
args: Omit<AIAutoCompleteRequestBody, 'model'>,
response: Response
) => {
): Promise<AIMessagePrompt | undefined> => {
const { system, messages, tools, tool_choice } = getAnthropicApiArgs(args);
const { stream, temperature, max_tokens } = getModelOptions(model, args);

Expand All @@ -33,17 +33,9 @@ export const handleAnthropicRequest = async (
response.setHeader('Content-Type', 'text/event-stream');
response.setHeader('Cache-Control', 'no-cache');
response.setHeader('Connection', 'keep-alive');
for await (const chunk of chunks) {
if (!response.writableEnded) {
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
} else {
break;
}
}

if (!response.writableEnded) {
response.end();
}
const responseMessage = await parseAnthropicStream(chunks, response);
return responseMessage;
} catch (error: any) {
if (!response.headersSent) {
if (error instanceof Anthropic.APIError) {
Expand All @@ -69,8 +61,8 @@ export const handleAnthropicRequest = async (
tools,
tool_choice,
});

response.json(result.content);
const responseMessage = parseAnthropicResponse(result, response);
return responseMessage;
} catch (error: any) {
if (error instanceof Anthropic.APIError) {
response.status(error.status ?? 400).json(error.message);
Expand Down
46 changes: 18 additions & 28 deletions quadratic-api/src/routes/ai/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ import { AnthropicBedrock } from '@anthropic-ai/bedrock-sdk';
import Anthropic from '@anthropic-ai/sdk';
import { BedrockRuntimeClient, ConverseCommand, ConverseStreamCommand } from '@aws-sdk/client-bedrock-runtime';
import { type Response } from 'express';
import { getAnthropicApiArgs } from 'quadratic-shared/ai/helpers/anthropic.helper';
import { getBedrockApiArgs } from 'quadratic-shared/ai/helpers/bedrock.helper';
import { getModelOptions, isAnthropicBedrockModel } from 'quadratic-shared/ai/helpers/model.helper';
import { type AIAutoCompleteRequestBody, type BedrockModel } from 'quadratic-shared/typesAndSchemasAI';
import { getModelOptions, isBedrockAnthropicModel } from 'quadratic-shared/ai/helpers/model.helper';
import {
type AIAutoCompleteRequestBody,
type AIMessagePrompt,
type BedrockModel,
} from 'quadratic-shared/typesAndSchemasAI';
import { AWS_S3_ACCESS_KEY_ID, AWS_S3_REGION, AWS_S3_SECRET_ACCESS_KEY } from '../../env-vars';
import { getAnthropicApiArgs, parseAnthropicResponse, parseAnthropicStream } from './helpers/anthropic.helper';
import { getBedrockApiArgs, parseBedrockResponse, parseBedrockStream } from './helpers/bedrock.helper';

// aws-sdk for bedrock, generic for all models
const bedrock = new BedrockRuntimeClient({
Expand All @@ -25,10 +29,10 @@ export const handleBedrockRequest = async (
model: BedrockModel,
args: Omit<AIAutoCompleteRequestBody, 'model'>,
response: Response
) => {
): Promise<AIMessagePrompt | undefined> => {
const { stream, temperature, max_tokens } = getModelOptions(model, args);

if (isAnthropicBedrockModel(model)) {
if (isBedrockAnthropicModel(model)) {
const { system, messages, tools, tool_choice } = getAnthropicApiArgs(args);
if (stream) {
try {
Expand All @@ -46,17 +50,9 @@ export const handleBedrockRequest = async (
response.setHeader('Content-Type', 'text/event-stream');
response.setHeader('Cache-Control', 'no-cache');
response.setHeader('Connection', 'keep-alive');
for await (const chunk of chunks) {
if (!response.writableEnded) {
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
} else {
break;
}
}

if (!response.writableEnded) {
response.end();
}
const responseMessage = await parseAnthropicStream(chunks, response);
return responseMessage;
} catch (error: any) {
if (!response.headersSent) {
if (error instanceof Anthropic.APIError) {
Expand All @@ -81,7 +77,8 @@ export const handleBedrockRequest = async (
tools,
tool_choice,
});
response.json(result.content);
const responseMessage = parseAnthropicResponse(result, response);
return responseMessage;
} catch (error: any) {
if (error instanceof Anthropic.APIError) {
response.status(error.status ?? 400).json(error.message);
Expand Down Expand Up @@ -113,17 +110,9 @@ export const handleBedrockRequest = async (
response.setHeader('Content-Type', 'text/event-stream');
response.setHeader('Cache-Control', 'no-cache');
response.setHeader('Connection', 'keep-alive');
for await (const chunk of chunks) {
if (!response.writableEnded) {
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
} else {
break;
}
}

if (!response.writableEnded) {
response.end();
}
const responseMessage = await parseBedrockStream(chunks, response);
return responseMessage;
} catch (error: any) {
if (!response.headersSent) {
if (error.response) {
Expand Down Expand Up @@ -152,7 +141,8 @@ export const handleBedrockRequest = async (
});

const result = await bedrock.send(command);
response.json(result.output);
const responseMessage = parseBedrockResponse(result.output, response);
return responseMessage;
} catch (error: any) {
if (error.response) {
response.status(error.response.status).json(error.response.data);
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit c69a3ae

Please sign in to comment.