diff --git a/actions/embedding.js b/actions/embedding.js
new file mode 100644
index 0000000..67344e2
--- /dev/null
+++ b/actions/embedding.js
@@ -0,0 +1,54 @@
+// coding=utf-8
+
+import { post } from "../tools/request.js";
+
+// Copyright [2024] [SkywardAI]
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+export async function embeddings(req, res) {
+    if(!req.headers.authorization) {
+        res.status(401).send("Not Authorized!");
+        return;
+    }
+
+    const { input } = req.body;
+    if(!input) {
+        res.status(400).send("Input sentence not specified!");
+        return;
+    }
+
+    const { embedding, http_error } = await post('embedding', {body: {
+        content: input
+    }}, { eng: "embedding" });
+
+    if(http_error) {
+        res.status(500).send("Embedding Engine Internal Server Error");
+        return;
+    }
+
+    res.status(200).send({
+        object: "list",
+        data: [
+            {
+                object: "embedding",
+                embedding,
+                index: 0
+            }
+        ],
+        model: "all-MiniLM-L6-v2",
+        usage: {
+            prompt_tokens: 0,
+            total_tokens: 0
+        }
+    })
+}
\ No newline at end of file
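A quick usage sketch for the handler above (illustrative, not taken from the patch): the path matches the /v1/embeddings entry added to swagger.json and routes/index.js further down, the auth check only requires that some Authorization header is present, and the host and port are assumptions.

// hypothetical client call; assumes the gateway listens on http://localhost:8000
const resp = await fetch("http://localhost:8000/v1/embeddings", {
    method: "POST",
    headers: {
        "Content-Type": "application/json",
        "Authorization": "Bearer any-non-empty-key"   // only presence is checked by embeddings()
    },
    body: JSON.stringify({ input: "Hello, world!" })
});
const { data, model } = await resp.json();
console.log(model, data[0].embedding.length);         // all-MiniLM-L6-v2, expected 384 values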
diff --git a/actions/inference.js b/actions/inference.js
index 32258b5..e66c6c2 100644
--- a/actions/inference.js
+++ b/actions/inference.js
@@ -16,6 +16,7 @@
 import { formatOpenAIContext } from "../tools/formatContext.js";
 import { generateFingerprint } from "../tools/generator.js";
 import { post } from "../tools/request.js";
+import { loadDataset, searchByMessage } from "../database/rag-inference.js";
 
 function generateResponseContent(id, object, model, system_fingerprint, stream, content, stopped) {
     const resp = {
@@ -53,11 +54,42 @@ export async function chatCompletion(req, res) {
         return;
     }
 
-    let {messages, max_tokens, ...request_body} = req.body;
+    let {
+        messages, max_tokens,
+        system_passed_extra_properties,
+        ...request_body
+    } = req.body;
+    // apply default values or send error messages
+    if(!messages || !messages.length) {
+        res.status(400).send("Messages not given!");
+        return;
+    }
+    if(!max_tokens) max_tokens = 128;
+
+    let genResp = generateResponseContent;
+    if(system_passed_extra_properties) {
+        const { inference_type, extra_fields } = system_passed_extra_properties;
+        if(inference_type === "rag") {
+            const { has_background, question, answer, _distance: distance } = extra_fields;
+            if(has_background) {
+                messages.splice(-1, 0, {
+                    role: 'system',
+                    content: `Your next answer should be based on this background: the question is "${question}" and the answer is "${answer}".`
+                })
+            }
+            genResp = (...args) => {
+                const content = generateResponseContent(...args)
+                if(args[6] && has_background) {
+                    return { content, rag_context: {question, answer, distance} }
+                } else return content;
+            }
+        }
+    }
+
+    // format requests to llamacpp format input
     request_body.prompt = formatOpenAIContext(messages);
-    if(max_tokens) request_body.n_predict = max_tokens;
+    request_body.n_predict = max_tokens;
     if(!request_body.stop) request_body.stop = [...default_stop_keywords];
 
     // extra
@@ -78,16 +110,40 @@
             const data = value.split("data: ").pop()
             const json_data = JSON.parse(data)
             const { content, stop } = json_data;
-            res.write(JSON.stringify(generateResponseContent(api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop))+'\n\n');
+            res.write(JSON.stringify(genResp(api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop))+'\n\n');
         }
         res.end();
     } else {
         const eng_resp = await post('completion', { body: request_body });
         const { model, content } = eng_resp;
-        const response_json = generateResponseContent(
+        const response_json = genResp(
             api_key, 'chat.completion', model, system_fingerprint, false, content, true
         )
         res.send(response_json);
     }
-}
\ No newline at end of file
+}
+
+export async function ragChatCompletion(req, res) {
+    const { dataset_name, dataset_url } = req.body;
+    if(!dataset_name || !dataset_url) {
+        res.status(400).send("Dataset information not specified.");
+        return;
+    }
+
+    await loadDataset(dataset_name, dataset_url);
+    if(!req.body.messages || !req.body.messages.length) {
+        res.status(400).send("Messages not given!");
+        return;
+    }
+    const latest_message = req.body.messages.slice(-1)[0].content;
+    const rag_result = await searchByMessage(dataset_name, latest_message);
+    req.body.system_passed_extra_properties = {
+        inference_type: "rag",
+        extra_fields: {
+            has_background: !!rag_result,
+            ...(rag_result || {})
+        }
+    }
+    chatCompletion(req, res);
+}
\ No newline at end of file
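A matching request sketch for ragChatCompletion (illustrative, not part of the patch): the dataset fields and messages follow the swagger example added below; the base URL and API key are assumptions. When the vector search finds a background QA, the non-streamed reply is wrapped as { content, rag_context }.

// hypothetical client call against the new RAG route
const resp = await fetch("http://localhost:8000/v1/chat/rag-completions", {
    method: "POST",
    headers: {
        "Content-Type": "application/json",
        "Authorization": "Bearer any-non-empty-key"
    },
    body: JSON.stringify({
        dataset_name: "aisuko/squad01",
        dataset_url: "https://datasets-server.huggingface.co/rows?dataset=aisuko%2Fsquad01&config=default&split=validation&offset=0&length=100",
        messages: [
            { role: "user", content: "tell me something interesting about Massachusetts" }
        ]
    })
});
const { content, rag_context } = await resp.json();
console.log(rag_context);   // { question, answer, distance } when a background QA was found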
diff --git a/database/index.js b/database/index.js
index dc55811..70f249a 100644
--- a/database/index.js
+++ b/database/index.js
@@ -1,3 +1,18 @@
+// coding=utf-8
+
+// Copyright [2024] [SkywardAI]
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 import { connect } from "@lancedb/lancedb";
 import {
     Schema, Field, FixedSizeList,
@@ -5,7 +20,7 @@ import {
     // eslint-disable-next-line
     Table
 } from "apache-arrow";
-import { DATASET_TABLE, SYSTEM_TABLE } from "./types";
+import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";
 
 const uri = "/tmp/lancedb/";
 const db = await connect(uri);
@@ -26,8 +41,6 @@ export async function initDB(force = false) {
     ]), open_options)
 }
 
-initDB();
-
 /**
  * Open a table with table name
  * @param {String} table_name table name to be opened
diff --git a/database/rag-inference.js b/database/rag-inference.js
index 3e0d350..76566af 100644
--- a/database/rag-inference.js
+++ b/database/rag-inference.js
@@ -1,3 +1,18 @@
+// coding=utf-8
+
+// Copyright [2024] [SkywardAI]
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 import { get, post } from "../tools/request.js";
 import { getTable } from "./index.js";
 import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";
diff --git a/database/types.js b/database/types.js
index 54a49be..2e04f93 100644
--- a/database/types.js
+++ b/database/types.js
@@ -1,2 +1,17 @@
+// coding=utf-8
+
+// Copyright [2024] [SkywardAI]
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 export const SYSTEM_TABLE = 'system';
 export const DATASET_TABLE = 'dataset';
\ No newline at end of file
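Because initDB() is no longer run as a side effect of importing database/index.js, every entry point that touches the tables now has to initialize explicitly; index.js below does this with a top-level await. A minimal sketch for any other script, under the assumption that it is a new, hypothetical file placed in the project root:

// hypothetical standalone script placed in the project root
import { initDB, getTable } from "./database/index.js";
import { DATASET_TABLE } from "./database/types.js";

await initDB();                                       // create/open the LanceDB tables first
const dataset_table = await getTable(DATASET_TABLE);  // then tables can be opened by name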
diff --git a/doc.html b/doc.html
deleted file mode 100644
index a67abc9..0000000
--- a/doc.html
+++ /dev/null
@@ -1,193 +0,0 @@
-[193 deleted lines: the static "VOYAGER APIs" HTML documentation page; markup not recoverable from this excerpt]
\ No newline at end of file
diff --git a/index.js b/index.js
index 5bc6db5..0bdc67c 100644
--- a/index.js
+++ b/index.js
@@ -18,6 +18,7 @@
 import cors from 'cors';
 import bodyParser from 'body-parser';
 import { configDotenv } from 'dotenv';
 
+import { initDB } from './database/index.js';
 import buildRoutes from './routes/index.js'
 import swStats from 'swagger-stats';
@@ -25,6 +26,7 @@
 import * as swaggerUi from 'swagger-ui-express'
 import swaggerSpec from "./swagger.json" with { type: "json" };
 
 configDotenv()
+await initDB()
 
 const app = express();
 app.use(cors());
diff --git a/routes/embedding.js b/routes/embedding.js
index 910e2eb..31e1d6a 100644
--- a/routes/embedding.js
+++ b/routes/embedding.js
@@ -14,8 +14,12 @@
 // limitations under the License.
 
 import { Router } from "express";
+import { embeddings } from "../actions/embedding.js";
 
 export default function embeddingRoute() {
     const router = Router();
+
+    router.post("/", embeddings);
+
     return router;
 }
\ No newline at end of file
diff --git a/routes/index.js b/routes/index.js
index 1553e2f..6e1fc0c 100644
--- a/routes/index.js
+++ b/routes/index.js
@@ -38,7 +38,7 @@ function generateAPIRouters() {
     api_router.use('/chat', inferenceRoute());
     api_router.use('/token', tokenRoute());
     api_router.use('/tracing', tracingRoute());
-    api_router.use('/embedding', embeddingRoute());
+    api_router.use('/embeddings', embeddingRoute());
     api_router.use('/encoder', encoderRoute());
     api_router.use('/decoder', decoderRoute());
 
diff --git a/routes/inference.js b/routes/inference.js
index 18fa803..b39ee8e 100644
--- a/routes/inference.js
+++ b/routes/inference.js
@@ -14,12 +14,13 @@
 // limitations under the License.
 
 import { Router } from "express";
-import { chatCompletion } from "../actions/inference.js";
+import { chatCompletion, ragChatCompletion } from "../actions/inference.js";
 
 export default function inferenceRoute() {
     const router = Router();
 
     router.post('/completions', chatCompletion);
+    router.post('/rag-completions', ragChatCompletion);
 
     return router;
 }
\ No newline at end of file
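For completeness, a sketch of a plain (non-RAG) call against the updated chatCompletion handler: messages is now mandatory (400 otherwise) and n_predict falls back to 128 when max_tokens is omitted. The base URL and key are again assumptions, not part of the patch.

// hypothetical client call; relies only on the defaults added in actions/inference.js
const resp = await fetch("http://localhost:8000/v1/chat/completions", {
    method: "POST",
    headers: {
        "Content-Type": "application/json",
        "Authorization": "Bearer any-non-empty-key"
    },
    body: JSON.stringify({
        messages: [{ role: "user", content: "Hello, tell me more about you!" }]
        // no max_tokens here: the handler now sets n_predict to 128 by default
    })
});
console.log(await resp.json());   // OpenAI-style chat.completion object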
diff --git a/swagger.json b/swagger.json
index 6dbbf9d..be22438 100644
--- a/swagger.json
+++ b/swagger.json
@@ -25,6 +25,10 @@
       "name": "Chat",
       "description": "v1 Chat APIs"
     },
+    {
+      "name": "Embedding",
+      "description": "v1 Embedding APIs"
+    },
     {
       "name": "Token",
       "description": "v1 Token APIs"
@@ -36,6 +40,7 @@
         "tags": [
           "Index"
         ],
+        "summary": "Route of this page",
         "description": "Route to get this page"
       }
     },
@@ -66,6 +71,7 @@
     "/stats": {
       "get": {
         "tags": ["Index"],
+        "summary": "Route to check stats",
         "description": "Graphical server stats, Click [here](/stats) to get the page.",
         "responses": {
           "200": {
@@ -91,6 +97,12 @@
           "description": "Request body of AI chat completion",
           "schema": {
             "$ref": "#/components/schemas/CompletionRequest"
+          },
+          "example": {
+            "messages": [
+              { "role": "system", "content": "You are a helpful assistant who helps users solve their questions." },
+              { "role": "user", "content": "Hello, tell me more about you!" }
+            ]
           }
         }
       }
@@ -108,6 +120,198 @@
                 }
               }
             }
+          },
+          "400": {
+            "description": "Invalid request, e.g. messages not given"
+          },
+          "401": {
+            "description": "Not authorized"
+          }
+        },
+        "security": [
+          {"api_key": []}
+        ]
+      }
+    },
+    "/v1/chat/rag-completions": {
+      "post": {
+        "tags": ["Chat"],
+        "summary": "AI chat completion with RAG dataset.",
+        "description": "Start a conversation with the given messages, using QAs retrieved from the dataset as context",
+        "requestBody": {
+          "required": true,
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/RAGCompletionRequest"
+              },
+              "example": {
+                "messages": [
+                  { "role": "system", "content": "You are a helpful assistant who helps users solve their questions." },
+                  { "role": "user", "content": "tell me something interesting about Massachusetts" }
+                ],
+                "dataset_name": "aisuko/squad01",
+                "dataset_url": "https://datasets-server.huggingface.co/rows?dataset=aisuko%2Fsquad01&config=default&split=validation&offset=0&length=100"
+              }
+            }
+          }
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "properties": {
+                    "content": {
+                      "$ref": "#/components/schemas/CompletionResponseEntire"
+                    },
+                    "rag_context": {
+                      "type": "object",
+                      "properties": {
+                        "question": {
+                          "type": "string",
+                          "examples": ["In what year did Massachusetts first require children to be educated in schools?"]
+                        },
+                        "answer": {
+                          "type": "string",
+                          "examples": ["1852"]
+                        },
+                        "distance": {
+                          "type": "number",
+                          "examples": [0.27701786160469055]
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          },
+          "400": {
+            "description": "Invalid request, e.g. messages or dataset information not given"
+          },
+          "401": {
+            "description": "Not authorized"
+          }
+        },
+        "security": [
+          {"api_key": []}
+        ]
+      }
+    },
+    "/v1/embeddings": {
+      "post": {
+        "tags": ["Embedding"],
+        "summary": "Get embedding of input",
+        "description": "Get the embedding of the given input, in OpenAI format",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "properties": {
+                  "input": {
+                    "type": "string",
+                    "examples": ["Hello, world!"]
+                  },
+                  "model": {
+                    "type": "string",
+                    "description": "You can pass model, but it has no effect at the current stage",
+                    "examples": ["all-MiniLM-L6-v2"]
+                  },
+                  "encoding_format": {
+                    "type": "string",
+                    "description": "You can pass encoding_format, but it has no effect as only float is currently supported.",
+                    "examples": ["float"]
+                  }
+                },
+                "required": [
+                  "input"
+                ],
+                "example": {
+                  "input": "Hello, world!",
+                  "model": "all-MiniLM-L6-v2"
+                }
+              }
+            }
+          }
+        },
+        "responses": {
+          "200": {
+            "description": "Everything works normally; returns the embedding result in OpenAI format",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "properties": {
+                    "object": {
+                      "type": "string",
+                      "examples": ["list"]
+                    },
+                    "data": {
+                      "type": "array",
+                      "items": {
+                        "type": "object",
+                        "properties": {
+                          "object": {
+                            "type": "string",
+                            "examples": ["embedding"]
+                          },
+                          "embedding": {
+                            "type": "array",
+                            "items": {
+                              "type": "number",
+                              "examples": [-0.02184351161122322, 0.049017686396837234, 0.06728602200746536]
+                            },
+                            "description": "The length should be exactly 384 items",
+                            "example": [
+                              -0.02184351161122322,
+                              0.049017686396837234,
+                              0.06728602200746536,
+                              0.06581537425518036,
+                              -0.05950690433382988,
+                              -0.08613293617963791,
+                              "*** totally 384 floats ***"
+                            ]
+                          },
+                          "index": {
+                            "type": "integer",
+                            "examples": [0]
+                          }
+                        }
+                      }
+                    },
+                    "model": {
+                      "type": "string",
+                      "examples": ["all-MiniLM-L6-v2"]
+                    },
+                    "usage": {
+                      "type": "object",
+                      "properties": {
+                        "prompt_tokens": {
+                          "type": "integer",
+                          "examples": [0]
+                        },
+                        "total_tokens": {
+                          "type": "integer",
+                          "examples": [0]
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          },
+          "400": {
+            "description": "Invalid request, e.g. input not given"
+          },
+          "401": {
+            "description": "Not authorized"
+          },
+          "500": {
+            "description": "Internal server error"
           }
         },
         "security": [
           {"api_key": []}
         ]
       }
     },
@@ -168,7 +372,7 @@
       },
       "max_tokens": {
         "type": "integer",
-        "examples": [ 32, 512 ]
+        "examples": [ 32, 128, 512 ]
       },
       "end": {
         "type": "array",
@@ -182,7 +386,68 @@
         "type": "boolean",
         "description": "If set to `true`, AI will response streamed."
       }
-    }
+    },
+    "required": [
+      "messages"
+    ]
+  },
+  "RAGCompletionRequest": {
+    "properties": {
+      "messages": {
+        "type": "array",
+        "items": {
+          "type": "object",
+          "properties": {
+            "role": {
+              "type": "string",
+              "examples": [
+                "system", "user", "assistant"
+              ]
+            },
+            "content": {
+              "type": "string",
+              "examples": [
+                "You are a helpful assistant who helps users solve their questions.",
+                "Hi, what can you do?"
+              ]
+            }
+          }
+        }
+      },
+      "max_tokens": {
+        "type": "integer",
+        "examples": [ 32, 128, 512 ]
+      },
+      "end": {
+        "type": "array",
+        "description": "When the AI outputs one of these patterns, the response ends.",
+        "items": {
+          "type": "string",
+          "examples": ["### user:"]
+        }
+      },
+      "stream": {
+        "type": "boolean",
+        "description": "If set to `true`, the response will be streamed."
+      },
+      "dataset_name": {
+        "type": "string",
+        "examples": [
+          "aisuko/squad01"
+        ]
+      },
+      "dataset_url": {
+        "type": "string",
+        "examples": [
+          "https://datasets-server.huggingface.co/rows?dataset=aisuko%2Fsquad01&config=default&split=validation&offset=0&length=100"
+        ]
+      }
+    },
+    "required": [
+      "messages",
+      "dataset_name",
+      "dataset_url"
+    ]
   },
   "CompletionResponseEntire": {
     "description": "Response of completion, where `stream=false`",