Skip to content

Commit

Permalink
Merge pull request #3 from mannjaro/feature/embedding_model
Browse files Browse the repository at this point in the history
add: schema
  • Loading branch information
mannjaro authored Sep 16, 2024
2 parents 12957e6 + 80897aa commit 703bffd
Show file tree
Hide file tree
Showing 13 changed files with 248 additions and 41 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/deploy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ jobs:
platforms: linux/arm64
# リポジトリをクローンする
- name: Checkout repository
uses: actions/checkout@v4

uses: actions/checkout@v4
# AWS の認証
- name: configure aws credentials
uses: aws-actions/configure-aws-credentials@v4
Expand Down
31 changes: 31 additions & 0 deletions .github/workflows/format.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Formats and lints the TypeScript sources with Biome, commits any automatic
# fixes back to the branch, then runs Biome in CI mode as the final gate.
name: Code quality

on:
  push:
  pull_request:

jobs:
  quality:
    permissions:
      # Once a `permissions` block is declared, every scope that is not listed
      # is set to `none`. The "Commit and push" step below needs write access
      # to repository contents, so it has to be granted explicitly.
      contents: write
      pull-requests: write
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Setup Biome
        uses: biomejs/setup-biome@v2
        with:
          version: latest
      # NOTE(review): the unquoted glob is expanded by the runner's shell, not
      # by Biome; confirm it reaches files nested more than one directory deep.
      - name: Run biome check
        run: biome check **/*.ts --write
      # Best effort: `git commit` exits non-zero when there is nothing to
      # commit, so a failure here must not fail the job.
      - name: Commit and push
        continue-on-error: true
        run: |
          git config user.email 41898282+github-actions[bot]@users.noreply.github.com
          git config user.name github-actions[bot]
          git add .
          git commit -m "format by bot"
          git push
      - name: Run Biome
        run: biome ci **/*.ts
29 changes: 21 additions & 8 deletions biome.json
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
{
"$schema": "https://biomejs.dev/schemas/1.8.3/schema.json",
"$schema": "https://biomejs.dev/schemas/1.9.1/schema.json",
"vcs": {
"enabled": true,
"clientKind": "git",
"useIgnoreFile": true,
"defaultBranch": "main"
},
"files": {
"ignoreUnknown": false,
"ignore": []
},
"formatter": {
"enabled": true,
"indentStyle": "space",
"indentWidth": 2
},
"organizeImports": {
"enabled": true
},
"linter": {
"enabled": true,
"rules": {
"recommended": true,
"style": {
"useImportType": "info"
}
"recommended": true
}
},
"formatter": {
"indentStyle": "space",
"indentWidth": 2
"javascript": {
"formatter": {
"quoteStyle": "double"
}
}
}
6 changes: 3 additions & 3 deletions lambda/bedrock-proxy/src/api/chat.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { Hono } from "hono";
import { BedrockModel } from "../model/bedrock";
import { zValidator } from "@hono/zod-validator";
import { Hono } from "hono";
import { BedrockModel } from "../model/bedrock/chat";

import { ChatRequestSchema } from "../schema/request/chat";
import { streamText } from "hono/streaming";
import { ChatRequestSchema } from "../schema/request/chat";

const chat = new Hono();

Expand Down
25 changes: 25 additions & 0 deletions lambda/bedrock-proxy/src/api/embedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { zValidator } from "@hono/zod-validator";
import { Hono } from "hono";
import { CohereEmbeddingModel } from "../model/bedrock/embedding";

import { EmbeddingRequestSchema } from "../schema/request/embedding";

const embedding = new Hono();

/**
 * POST handler registered on the root path of the embedding router.
 *
 * The JSON body is validated against `EmbeddingRequestSchema`. A payload that
 * fails validation is logged and rejected with HTTP 400; a valid payload is
 * forwarded to `CohereEmbeddingModel.embed` and its result is returned as JSON.
 */
embedding.post(
  "",
  zValidator("json", EmbeddingRequestSchema, async (result, c) => {
    if (!result.success) {
      // Log the raw body to help diagnose what the client actually sent.
      console.log(await c.req.text());
      // Respond with 400 so clients can tell a rejected payload apart from a
      // successful call (the default status would be 200).
      return c.json({ message: "Validation failed" }, 400);
    }
  }),
  async (c) => {
    const model = new CohereEmbeddingModel();
    // `valid` returns the already-parsed payload synchronously.
    const embeddingRequest = c.req.valid("json");
    const response = await model.embed(embeddingRequest);
    return c.json(response);
  },
);

export { embedding };
10 changes: 8 additions & 2 deletions lambda/bedrock-proxy/src/app.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import { Hono } from "hono";

import type { LambdaEvent } from "hono/aws-lambda";

import { chat } from "./api/chat";
import { embedding } from "./api/embedding";
// Bindings type handed to the Hono app's generic parameter; it exposes
// `event`, typed as the LambdaEvent imported from hono/aws-lambda.
type Bindings = {
event: LambdaEvent;
};

const app = new Hono();
const app = new Hono<{ Bindings: Bindings }>();

// Root route: always answers with the plain-text body "ok".
app.get("/", (c) => c.text("ok"));

app.route("/chat", chat);

app.route("/embeddings", embedding);
export default app;
Original file line number Diff line number Diff line change
@@ -1,37 +1,36 @@
import type {
Message,
Tool,
ToolResultBlock,
ToolUseBlock,
ConverseCommandOutput,
ConverseStreamCommandOutput,
ContentBlock,
ConverseCommandInput,
ConverseCommandOutput,
ConverseStreamCommandInput,
ContentBlock,
ImageFormat,
ConverseStreamCommandOutput,
ConverseStreamOutput,
ImageFormat,
Message,
StopReason,
Tool,
ToolResultBlock,
ToolUseBlock,
} from "@aws-sdk/client-bedrock-runtime";
import {
BedrockRuntime,
BedrockRuntimeServiceException,
} from "@aws-sdk/client-bedrock-runtime";

import { BaseModel } from "./base";
import { BaseModel } from "../base";

import type { StreamingApi } from "hono/utils/stream";
import type {
ChatRequest,
FunctionInput,
UserMessage,
} from "../schema/request/chat";

} from "../../schema/request/chat";
import type {
ChatResponse,
ChatStreamResponse,
ChatResponseMessage,
ChatStreamResponse,
ToolCall,
} from "../schema/response/chat";
import { StreamingApi } from "hono/utils/stream";
} from "../../schema/response/chat";

export class BedrockModel extends BaseModel {
async _invokeBedrock(
Expand Down
98 changes: 98 additions & 0 deletions lambda/bedrock-proxy/src/model/bedrock/embedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import {
BedrockRuntime,
type InvokeModelCommandInput,
ValidationException,
} from "@aws-sdk/client-bedrock-runtime";
import type { EmbeddingRequest } from "../../schema/request/embedding";
import type {
Embedding,
EmbeddingResponse,
} from "../../schema/response/embedding";

interface InvokeModelArgs {
texts: Array<string>;
input_type: "search_document";
truncate: "END";
}

/**
 * Base class for embedding models invoked through the Bedrock runtime.
 *
 * It owns the raw `InvokeModel` call and the conversion of embedding vectors
 * into the list-shaped `EmbeddingResponse`. Subclasses are expected to build
 * the model-specific request payload.
 */
class BedrockEmbeddingModel {
  /**
   * Invokes `modelId` with `args` serialized as the JSON request body.
   *
   * @param args - Model payload; sent as `JSON.stringify(args)`.
   * @param modelId - Identifier of the Bedrock model to invoke.
   * @returns The raw `invokeModel` output.
   * @throws Error with an `HTTPException: ` prefixed message when Bedrock
   *   rejects the request with a `ValidationException`; any other error is
   *   rethrown unchanged.
   */
  async _invokeModel(args: InvokeModelArgs, modelId: string) {
    const bedrockRuntime = new BedrockRuntime({ region: "us-east-1" });
    const input: InvokeModelCommandInput = {
      modelId: modelId,
      body: JSON.stringify(args),
      accept: "application/json",
      contentType: "application/json",
    };
    try {
      // Await inside the try block: returning the bare promise would let a
      // rejection bypass the catch clause below entirely.
      return await bedrockRuntime.invokeModel(input);
    } catch (error) {
      if (error instanceof ValidationException) {
        throw new Error(`HTTPException: ${error.message}`);
      }
      throw error;
    }
  }

  /**
   * Builds the list-shaped embedding response.
   *
   * @param embeddings - Either a single vector (`number[]`) or a batch of
   *   vectors (`number[][]`). A batch yields one `data` entry per vector,
   *   indexed in input order; a single vector yields one entry at index 0.
   * @param model - Model name echoed back in the response.
   * @param inputTokens - Reported as `usage.prompt_tokens`.
   * @param outputTokens - Reported as `usage.completion_tokens`.
   * @param encodingFormat - Currently unused; vectors are always emitted as
   *   plain number arrays.
   * @returns The assembled response, with `usage.total_tokens` set to the sum
   *   of both token counts.
   */
  _createEmbeddingResponse(
    embeddings: Array<number> | Array<Array<number>>,
    model: string,
    inputTokens: number,
    outputTokens: number,
    encodingFormat: "float" | "base64",
  ): EmbeddingResponse {
    // A nested first element means the caller passed a batch of vectors.
    const vectors = Array.isArray(embeddings[0])
      ? (embeddings as Array<Array<number>>)
      : [embeddings as Array<number>];
    const data: Array<Embedding> = vectors.map((vector, index) => ({
      object: "embedding",
      index: index,
      embedding: vector,
    }));
    return {
      object: "list",
      data: data,
      model: model,
      usage: {
        prompt_tokens: inputTokens,
        completion_tokens: outputTokens,
        total_tokens: inputTokens + outputTokens,
      },
    };
  }
}

/**
 * Embedding model that sends Cohere-style payloads through the inherited
 * Bedrock invocation helper.
 */
class CohereEmbeddingModel extends BedrockEmbeddingModel {
  /**
   * Converts an embedding request into the payload passed to `_invokeModel`.
   *
   * @param embeddingRequest - Request whose `input` is a string or an array
   *   of strings.
   * @returns Payload with `texts` as an array, `input_type` fixed to
   *   "search_document" and `truncate` fixed to "END".
   * @throws Error "Invalid input type" when `input` is neither a string nor
   *   an array.
   */
  _parseRequest(embeddingRequest: EmbeddingRequest): InvokeModelArgs {
    const { input } = embeddingRequest;

    let texts: Array<string>;
    if (typeof input === "string") {
      // A single string becomes a one-element batch.
      texts = [input];
    } else if (Array.isArray(input)) {
      texts = input;
    } else {
      throw new Error("Invalid input type");
    }

    return { texts, input_type: "search_document", truncate: "END" };
  }

  /**
   * Invokes the model named in the request and wraps the `embeddings` field
   * of the decoded response body into an `EmbeddingResponse`.
   *
   * Token usage is reported as zero and the encoding format passed on is
   * always "float".
   */
  async embed(embeddingRequest: EmbeddingRequest): Promise<EmbeddingResponse> {
    const { model } = embeddingRequest;
    const invokeArgs = this._parseRequest(embeddingRequest);
    const response = await this._invokeModel(invokeArgs, model);

    // The response body arrives as bytes; decode it before parsing the JSON.
    const decoded = new TextDecoder().decode(response.body);
    const body = JSON.parse(decoded);
    console.log(body);

    return this._createEmbeddingResponse(body.embeddings, model, 0, 0, "float");
  }
}

export { CohereEmbeddingModel };
11 changes: 11 additions & 0 deletions lambda/bedrock-proxy/src/schema/request/embedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { z } from "zod";

/** Accepted `input` payload: one string or a list of strings. */
const EmbeddingInputSchema = z.union([z.string(), z.array(z.string())]);

/** Allowed `encoding_format` values; "float" is applied when the field is omitted. */
const EncodingFormatSchema = z.enum(["float", "base64"]).default("float");

/**
 * Validation schema for the embedding request body.
 *
 * `input` and `model` are required; `encoding_format` defaults to "float";
 * `dimensions` and `user` are optional.
 */
export const EmbeddingRequestSchema = z.object({
  input: EmbeddingInputSchema,
  model: z.string(),
  encoding_format: EncodingFormatSchema,
  dimensions: z.number().optional(),
  user: z.string().optional(),
});

/** Parsed (output) type of `EmbeddingRequestSchema`. */
export type EmbeddingRequest = z.infer<typeof EmbeddingRequestSchema>;
26 changes: 26 additions & 0 deletions lambda/bedrock-proxy/src/schema/response/embedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { z } from "zod";

/** Schema of one entry in the response's `data` list: a vector and its position. */
const EmbeddingSchema = z.object({
  object: z.literal("embedding"),
  index: z.number(),
  embedding: z.array(z.number()),
});

/**
 * Schema of the embedding response: a "list" object holding the embedding
 * entries, the model name and token usage counters.
 *
 * `data` reuses `EmbeddingSchema` so that the `Embedding` type and the shape
 * of the response entries cannot drift apart.
 */
export const EmbeddingResponseSchema = z.object({
  object: z.literal("list"),
  data: z.array(EmbeddingSchema),
  model: z.string(),
  usage: z.object({
    prompt_tokens: z.number(),
    completion_tokens: z.number(),
    total_tokens: z.number(),
  }),
});

/** One embedding entry, inferred from `EmbeddingSchema`. */
export type Embedding = z.infer<typeof EmbeddingSchema>;
/** Full embedding response, inferred from `EmbeddingResponseSchema`. */
export type EmbeddingResponse = z.infer<typeof EmbeddingResponseSchema>;
2 changes: 1 addition & 1 deletion lib/construct/lambda.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as cdk from "aws-cdk-lib";
import * as lambda from "aws-cdk-lib/aws-lambda";
import * as iam from "aws-cdk-lib/aws-iam";
import * as lambda from "aws-cdk-lib/aws-lambda";
import { NodejsFunction } from "aws-cdk-lib/aws-lambda-nodejs";
import { Construct } from "constructs";

Expand Down
2 changes: 1 addition & 1 deletion parameter.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Environment } from "aws-cdk-lib";
import type { Environment } from "aws-cdk-lib";

export interface AppParameter {
env?: Environment;
Expand Down
19 changes: 9 additions & 10 deletions test/cloudfront-lambda.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

// example test. To run these tests, uncomment this file along with the
// example resource in lib/cloudfront-lambda-stack.ts
test('SQS Queue Created', () => {
// const app = new cdk.App();
// // WHEN
// const stack = new CloudfrontLambda.CloudfrontLambdaStack(app, 'MyTestStack');
// // THEN
// const template = Template.fromStack(stack);

// template.hasResourceProperties('AWS::SQS::Queue', {
// VisibilityTimeout: 300
// });
test("SQS Queue Created", () => {
// const app = new cdk.App();
// // WHEN
// const stack = new CloudfrontLambda.CloudfrontLambdaStack(app, 'MyTestStack');
// // THEN
// const template = Template.fromStack(stack);
// template.hasResourceProperties('AWS::SQS::Queue', {
// VisibilityTimeout: 300
// });
});

0 comments on commit 703bffd

Please sign in to comment.