diff --git a/Makefile b/Makefile
index 7b8f200..ba81c2e 100644
--- a/Makefile
+++ b/Makefile
@@ -7,6 +7,7 @@ ENV_FILE:=.env
 APP_PORT:=8000
 DATABASE_BIND_PATH:=./lancedb
 REGION:=ap-southeast-2
+MODEL_ID:=anthropic.claude-3-sonnet-20240229-v1:0
 
 # build and run this service only
 .PHONY: build
@@ -22,6 +23,7 @@ env:
 	@echo "APP_PORT=$(APP_PORT)"> $(ENV_FILE)
 	@echo "DATABASE_BIND_PATH=$(DATABASE_BIND_PATH)">> $(ENV_FILE)
 	@echo "REGION=$(REGION)">> $(ENV_FILE)
+	@echo "MODEL_ID=$(MODEL_ID)">> $(ENV_FILE)
 
 # normal build & up
 .PHONY: compose-build
diff --git a/actions/bedrock.js b/actions/bedrock.js
new file mode 100644
index 0000000..c0f35bc
--- /dev/null
+++ b/actions/bedrock.js
@@ -0,0 +1,88 @@
+import { BedrockRuntimeClient, ConverseCommand, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime";
+
+/**
+ * @type { BedrockRuntimeClient? }
+ */
+let client = null;
+
+/**
+ * Initialize or re-initialize the Bedrock runtime client in the configured region.
+ */
+export function rebuildBedrockClient() {
+    client = new BedrockRuntimeClient({
+        region: process.env.REGION || 'ap-southeast-2'
+    })
+}
+
+/**
+ * @callback InferenceCallback
+ * @param {String} text_piece Piece of the response text
+ * @param {Boolean} finished Indicates whether the response has finished
+ */
+
+/**
+ * @typedef MessageContent
+ * @property {String} text Text of the content block
+ */
+
+/**
+ * @typedef Message
+ * @property {"user"|"assistant"} role
+ * @property {MessageContent[]} content
+ */
+
+/**
+ * @typedef Settings
+ * @property {Boolean} stream Whether to stream the response
+ * @property {Number} max_tokens The maximum number of tokens in the response
+ * @property {Number} top_p Top-p value of the request
+ * @property {Number} temperature Temperature of the request
+ */
+
+/**
+ * Run inference on AWS Bedrock.
+ * @param {Message[]} messages Messages to run inference on
+ * @param {Settings} settings
+ * @param {InferenceCallback} cb
+ * @returns {Promise<String>} The full response text, whether streamed or not
+ */
+export async function inference(messages, settings, cb = null) {
+    if(!client) rebuildBedrockClient();
+
+    const { top_p, temperature, max_tokens } = settings;
+
+    const input = {
+        modelId: process.env.MODEL_ID || 'anthropic.claude-3-sonnet-20240229-v1:0',
+        messages,
+        inferenceConfig: {
+            maxTokens: max_tokens || 2048,
+            temperature: temperature || 0.7,
+            topP: top_p || 0.9
+        }
+    }
+
+    let command;
+    if(settings.stream) command = new ConverseStreamCommand(input);
+    else command = new ConverseCommand(input);
+
+    const response = await client.send(command);
+
+    let response_text = '';
+    if(settings.stream) {
+        // streaming: accumulate text deltas and forward each piece to the callback
+        for await (const resp of response.stream) {
+            if(resp.contentBlockDelta) {
+                const text_piece = resp.contentBlockDelta.delta.text;
+                response_text += text_piece;
+                cb && cb(text_piece, false);
+            }
+        }
+        cb && cb('', true)
+    } else {
+        // non-streaming: the full message arrives in a single Converse response
+        response_text = response.output.message.content.map(block => block.text).join('');
+        cb && cb(response_text, true);
+    }
+
+    return response_text;
+}
\ No newline at end of file
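
For reference, a minimal sketch of calling the new helper directly (the import path is relative to the repo root; credentials are assumed to come from the default AWS provider chain, and the messages shape follows the Converse typedefs above):

import { inference } from "./actions/bedrock.js";

const reply = await inference(
    [{ role: "user", content: [{ text: "Say hello in one sentence." }] }],
    { stream: true, max_tokens: 512, temperature: 0.7, top_p: 0.9 },
    (text_piece, finished) => {
        // each streamed delta lands here; finished marks the final, empty callback
        if(!finished) process.stdout.write(text_piece);
    }
);
console.log("\nfull text:", reply);
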
diff --git a/actions/inference.js b/actions/inference.js
index 5ce8405..651e321 100644
--- a/actions/inference.js
+++ b/actions/inference.js
@@ -13,12 +13,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-import { formatOpenAIContext } from "../tools/formatContext.js";
+// import { formatOpenAIContext } from "../tools/formatContext.js";
 import { generateFingerprint } from "../tools/generator.js";
-import { post } from "../tools/request.js";
-import { searchByMessage } from "../database/rag-inference.js";
-import { userMessageHandler } from "../tools/plugin.js";
+// import { post } from "../tools/request.js";
+// import { searchByMessage } from "../database/rag-inference.js";
+// import { userMessageHandler } from "../tools/plugin.js";
 import { extractAPIKeyFromHeader, validateAPIKey } from "../tools/apiKey.js";
+import { inference } from "./bedrock.js";
 
 /**
  * Generates a response content object for chat completion.
@@ -69,33 +70,33 @@ function generateResponseContent(
     return resp;
 }
 
-/**
- * Post to inference engine
- * @param {Object} req_body The request body to be sent
- * @param {Function} callback Callback function, takes one parameter contains parsed response json
- * @param {Boolean} isStream To set the callback behaviour
- */
-async function doInference(req_body, callback, isStream) {
-    if(isStream) {
-        const eng_resp = await post('completion', { body: req_body }, { getJSON: false });
-        const reader = eng_resp.body.pipeThrough(new TextDecoderStream()).getReader();
-        while(true) {
-            const { value, done } = await reader.read();
-            if(done) break;
-            const data = value.split("data: ").pop()
-            try {
-                callback(JSON.parse(data));
-            } catch(error) {
-                console.log(error)
-                callback({content: "", stop: true})
-            }
-        }
-    } else {
-        const eng_resp = await post('completion', { body: req_body });
-        if(eng_resp.http_error) return;
-        callback(eng_resp);
-    }
-}
+// /**
+//  * Post to inference engine
+//  * @param {Object} req_body The request body to be sent
+//  * @param {Function} callback Callback function, takes one parameter contains parsed response json
+//  * @param {Boolean} isStream To set the callback behaviour
+//  */
+// async function doInference(req_body, callback, isStream) {
+//     if(isStream) {
+//         const eng_resp = await post('completion', { body: req_body }, { getJSON: false });
+//         const reader = eng_resp.body.pipeThrough(new TextDecoderStream()).getReader();
+//         while(true) {
+//             const { value, done } = await reader.read();
+//             if(done) break;
+//             const data = value.split("data: ").pop()
+//             try {
+//                 callback(JSON.parse(data));
+//             } catch(error) {
+//                 console.log(error)
+//                 callback({content: "", stop: true})
+//             }
+//         }
+//     } else {
+//         const eng_resp = await post('completion', { body: req_body });
+//         if(eng_resp.http_error) return;
+//         callback(eng_resp);
+//     }
+// }
 
 function retrieveData(req_header, req_body) {
     // retrieve api key
@@ -105,7 +106,7 @@
     }
 
     // get attributes required special consideration
-    let { messages, max_tokens, ...request_body } = req_body;
+    let { messages, ...request_body } = req_body;
 
     // validate messages
     if(!messages || !messages.length) {
@@ -118,11 +119,11 @@
     })
 
     // apply n_predict value
-    if(!max_tokens) max_tokens = 128;
-    request_body.n_predict = max_tokens;
+    // if(!max_tokens) max_tokens = 128;
+    // request_body.n_predict = max_tokens;
 
     // apply stop value
-    if(!req_body.stop) request_body.stop = [...default_stop_keywords];
+    // if(!req_body.stop) request_body.stop = [...default_stop_keywords];
 
     // generated fields
     const system_fingerprint = generateFingerprint();
@@ -133,7 +134,7 @@
 
 }
 
-const default_stop_keywords = ["<|endoftext|>", "<|end|>", "<|user|>", "<|assistant|>"]
["<|endoftext|>", "<|end|>", "<|user|>", "<|assistant|>"] +// const default_stop_keywords = ["<|endoftext|>", "<|end|>", "<|user|>", "<|assistant|>"] /** * Handles a chat completion request, generating a response based on the input messages. @@ -191,73 +192,88 @@ export async function chatCompletion(req, res) { res.setHeader("X-Accel-Buffering", "no"); res.setHeader("Connection", "Keep-Alive"); } - doInference(request_body, (data) => { - const { content, stop } = data; + // doInference(request_body, (data) => { + // const { content, stop } = data; + // if(isStream) { + // res.write(JSON.stringify( + // generateResponseContent( + // api_key, 'chat.completion.chunk', model, system_fingerprint, isStream, content, stop + // ) + // )+'\n\n'); + // if(stop) res.end(); + // } else { + // res.send(generateResponseContent( + // api_key, 'chat.completion', model, system_fingerprint, + // isStream, content, true + // )) + // } + // }, isStream) + inference(messages, request_body, (text_piece, finished) => { if(isStream) { res.write(JSON.stringify( generateResponseContent( - api_key, 'chat.completion.chunk', model, system_fingerprint, isStream, content, stop + api_key, 'chat.completion.chunk', model, system_fingerprint, isStream, text_piece, finished ) )+'\n\n'); - if(stop) res.end(); + if(finished) res.end(); } else { res.send(generateResponseContent( api_key, 'chat.completion', model, system_fingerprint, - isStream, content, true + isStream, text_piece, true )) } - }, isStream) + }) } -/** - * Handles a RAG-based (Retrieval-Augmented Generation) chat completion request. - * - * @async - * @param {Request} req - The HTTP request object. - * @param {Response} res - The HTTP response object. - * @returns {Promise} A promise that resolves when the response is sent. - */ -export async function ragChatCompletion(req, res) { - const {error, body, status, message} = retrieveData(req.headers, req.body); - if(error) { - res.status(status).send(message); - return; - } - const { dataset_name, ...request_body } = body.request_body; - if(!dataset_name) { - res.status(422).send("Dataset name not specified."); - } - const { api_key, model, system_fingerprint, messages } = body +// /** +// * Handles a RAG-based (Retrieval-Augmented Generation) chat completion request. +// * +// * @async +// * @param {Request} req - The HTTP request object. +// * @param {Response} res - The HTTP response object. +// * @returns {Promise} A promise that resolves when the response is sent. 
+//  */
+// export async function ragChatCompletion(req, res) {
+//     const {error, body, status, message} = retrieveData(req.headers, req.body);
+//     if(error) {
+//         res.status(status).send(message);
+//         return;
+//     }
+//     const { dataset_name, ...request_body } = body.request_body;
+//     if(!dataset_name) {
+//         res.status(422).send("Dataset name not specified.");
+//     }
+//     const { api_key, model, system_fingerprint, messages } = body
 
-    const latest_message = messages.slice(-1)[0].content;
-    const rag_result = await searchByMessage(dataset_name, latest_message);
+//     const latest_message = messages.slice(-1)[0].content;
+//     const rag_result = await searchByMessage(dataset_name, latest_message);
 
-    const context = [...messages];
-    if(rag_result) context.push({
-        role: "system",
-        content: `This background information is useful for your next answer: "${rag_result.context}"`
-    })
-    request_body.prompt = formatOpenAIContext(context);
+//     const context = [...messages];
+//     if(rag_result) context.push({
+//         role: "system",
+//         content: `This background information is useful for your next answer: "${rag_result.context}"`
+//     })
+//     request_body.prompt = formatOpenAIContext(context);
 
-    const isStream = !!request_body.stream;
-    if(isStream) {
-        res.setHeader("Content-Type", "text/event-stream");
-        res.setHeader("Cache-Control", "no-cache");
-        res.setHeader("X-Accel-Buffering", "no");
-        res.setHeader("Connection", "Keep-Alive");
-    }
-    doInference(request_body, (data) => {
-        const { content, stop } = data;
-        const openai_response = generateResponseContent(
-            api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop
-        )
-        const rag_response = stop ? { content: openai_response, rag_context: rag_result } : openai_response;
+//     const isStream = !!request_body.stream;
+//     if(isStream) {
+//         res.setHeader("Content-Type", "text/event-stream");
+//         res.setHeader("Cache-Control", "no-cache");
+//         res.setHeader("X-Accel-Buffering", "no");
+//         res.setHeader("Connection", "Keep-Alive");
+//     }
+//     doInference(request_body, (data) => {
+//         const { content, stop } = data;
+//         const openai_response = generateResponseContent(
+//             api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop
+//         )
+//         const rag_response = stop ? { content: openai_response, rag_context: rag_result } : openai_response;
 
-        if(isStream) {
-            res.write(JSON.stringify(rag_response)+'\n\n');
-            if(stop) res.end();
-        } else {
-            res.send(rag_response);
-        }
-    }, isStream)
-}
+//         if(isStream) {
+//             res.write(JSON.stringify(rag_response)+'\n\n');
+//             if(stop) res.end();
+//         } else {
+//             res.send(rag_response);
+//         }
+//     }, isStream)
+// }
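
With ragChatCompletion retired, the service exposes only the Bedrock-backed chat completion. A hypothetical end-to-end call against a local instance follows; the route path and Authorization header are assumptions not shown in this diff, the port comes from the Makefile, and the body fields mirror what retrieveData and inference consume. Note that each message's content must use Converse-style { text } blocks, per the typedefs in bedrock.js, rather than plain OpenAI-style strings:

const res = await fetch("http://localhost:8000/v1/chat/completions", {  // route path assumed
    method: "POST",
    headers: {
        "Content-Type": "application/json",
        "Authorization": "Bearer my-api-key"                             // header format assumed
    },
    body: JSON.stringify({
        model: "anthropic.claude-3-sonnet-20240229-v1:0",
        stream: true,
        max_tokens: 1024,
        temperature: 0.7,
        top_p: 0.9,
        messages: [{ role: "user", content: [{ text: "Hello there!" }] }]
    })
});

// the handler writes each chunk as JSON.stringify(...) + '\n\n',
// so parse the stream by splitting on blank lines
const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
let buffered = "";
while(true) {
    const { value, done } = await reader.read();
    if(done) break;
    buffered += value;
    const chunks = buffered.split("\n\n");
    buffered = chunks.pop();
    for(const chunk of chunks) {
        if(chunk.trim()) console.log(JSON.parse(chunk));
    }
}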