diff --git a/actions/embedding.js b/actions/embedding.js
new file mode 100644
index 0000000..67344e2
--- /dev/null
+++ b/actions/embedding.js
@@ -0,0 +1,53 @@
+// coding=utf-8
+
+// Copyright [2024] [SkywardAI]
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import { post } from "../tools/request.js";
+
+export async function embeddings(req, res) {
+ if(!req.headers.authorization) {
+ res.status(401).send("Not Authorized!");
+ return;
+ }
+
+ const { input } = req.body;
+    if(!input) {
+        return res.status(400).send("Input sentence not specified!");
+    }
+
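+    // Request the embedding from the engine; the { eng: "embedding" } option
+    // selects the embedding backend (see tools/request.js).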
+ const { embedding, http_error } = await post('embedding', {body: {
+ content: input
+ }}, { eng: "embedding" });
+
+ if(http_error) {
+        res.status(500).send("Embedding Engine Internal Server Error");
+ return;
+ }
+
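+    // Wrap the raw embedding in an OpenAI-compatible response body; token usage
+    // is not tracked yet, so the counts below are placeholders.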
+ res.status(200).send({
+ object: "list",
+ data: [
+ {
+ object: "embedding",
+ embedding,
+ index: 0
+ }
+ ],
+ model: "all-MiniLM-L6-v2",
+ usage: {
+ prompt_tokens: 0,
+ total_tokens: 0
+ }
+ })
+}
\ No newline at end of file
diff --git a/actions/inference.js b/actions/inference.js
index 32258b5..e66c6c2 100644
--- a/actions/inference.js
+++ b/actions/inference.js
@@ -16,6 +16,7 @@
import { formatOpenAIContext } from "../tools/formatContext.js";
import { generateFingerprint } from "../tools/generator.js";
import { post } from "../tools/request.js";
+import { loadDataset, searchByMessage } from "../database/rag-inference.js";
function generateResponseContent(id, object, model, system_fingerprint, stream, content, stopped) {
const resp = {
@@ -53,11 +54,42 @@ export async function chatCompletion(req, res) {
return;
}
- let {messages, max_tokens, ...request_body} = req.body;
+ let {
+ messages, max_tokens,
+ system_passed_extra_properties,
+ ...request_body
+ } = req.body;
+ // apply default values or send error messages
+ if(!messages || !messages.length) {
+ res.status(400).send("Messages not given!");
+ return;
+ }
+ if(!max_tokens) max_tokens = 128;
+
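+    // genResp defaults to the plain OpenAI-style response formatter; RAG requests
+    // below swap in a wrapper that also attaches the retrieved context.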
+ let genResp = generateResponseContent;
+ if(system_passed_extra_properties) {
+ const { inference_type, extra_fields } = system_passed_extra_properties;
+ if(inference_type === "rag") {
+ const { has_background, question, answer, _distance:distance } = extra_fields;
+ if(has_background) {
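+                // Insert the retrieved QA pair as a system message right before the
+                // latest user message so the model answers with this background.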
+ messages.splice(-1, 0, {
+ role: 'system',
+                    content: `Your next answer should be based on this background: the question is "${question}" and the answer is "${answer}".`
+ })
+ }
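+            // Wrap the response generator so the final (stopped) response also
+            // carries the rag_context used for this answer.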
+ genResp = (...args) => {
+ const content = generateResponseContent(...args)
+ if(args[6] && has_background) {
+ return { content, rag_context: {question, answer, distance} }
+ } else return content;
+ }
+ }
+ }
+
// format requests to llamacpp format input
request_body.prompt = formatOpenAIContext(messages);
- if(max_tokens) request_body.n_predict = max_tokens;
+ request_body.n_predict = max_tokens;
if(!request_body.stop) request_body.stop = [...default_stop_keywords];
// extra
@@ -78,16 +110,39 @@ export async function chatCompletion(req, res) {
const data = value.split("data: ").pop()
const json_data = JSON.parse(data)
const { content, stop } = json_data;
- res.write(JSON.stringify(generateResponseContent(api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop))+'\n\n');
+ res.write(JSON.stringify(genResp(api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop))+'\n\n');
}
res.end();
} else {
const eng_resp = await post('completion', { body: request_body });
const { model, content } = eng_resp;
- const response_json = generateResponseContent(
+ const response_json = genResp(
api_key, 'chat.completion', model, system_fingerprint,
false, content, true
)
res.send(response_json);
}
+}
+
+export async function ragChatCompletion(req, res) {
+ const { dataset_name, dataset_url } = req.body;
+ if(!dataset_name || !dataset_url) {
+        return res.status(400).send("Dataset information not specified.");
+ }
+
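+    // Make sure the dataset is present in the local vector store before searching it.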
+ await loadDataset(dataset_name, dataset_url);
+ if(!req.body.messages || !req.body.messages.length) {
+ res.status(400).send("Messages not given!");
+ return;
+ }
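+    // The latest user message is used as the retrieval query against the dataset.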
+ const latest_message = req.body.messages.slice(-1)[0].content;
+ const rag_result = await searchByMessage(dataset_name, latest_message);
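+    // Hand the retrieval result to chatCompletion via internal extra properties;
+    // has_background is false when the search returned nothing.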
+ req.body.system_passed_extra_properties = {
+ inference_type: "rag",
+ extra_fields: {
+ has_background: !!rag_result,
+ ...(rag_result || {})
+ }
+ }
+    await chatCompletion(req, res);
}
\ No newline at end of file
diff --git a/database/index.js b/database/index.js
index dc55811..70f249a 100644
--- a/database/index.js
+++ b/database/index.js
@@ -1,3 +1,18 @@
+// coding=utf-8
+
+// Copyright [2024] [SkywardAI]
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
import { connect } from "@lancedb/lancedb";
import {
Schema, Field, FixedSizeList,
@@ -5,7 +20,7 @@ import {
// eslint-disable-next-line
Table
} from "apache-arrow";
-import { DATASET_TABLE, SYSTEM_TABLE } from "./types";
+import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";
const uri = "/tmp/lancedb/";
const db = await connect(uri);
@@ -26,8 +41,6 @@ export async function initDB(force = false) {
]), open_options)
}
-initDB();
-
/**
* Open a table with table name
* @param {String} table_name table name to be opened
diff --git a/database/rag-inference.js b/database/rag-inference.js
index 3e0d350..76566af 100644
--- a/database/rag-inference.js
+++ b/database/rag-inference.js
@@ -1,3 +1,18 @@
+// coding=utf-8
+
+// Copyright [2024] [SkywardAI]
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
import { get, post } from "../tools/request.js";
import { getTable } from "./index.js";
import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";
diff --git a/database/types.js b/database/types.js
index 54a49be..2e04f93 100644
--- a/database/types.js
+++ b/database/types.js
@@ -1,2 +1,17 @@
+// coding=utf-8
+
+// Copyright [2024] [SkywardAI]
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
export const SYSTEM_TABLE = 'system';
export const DATASET_TABLE = 'dataset';
\ No newline at end of file
diff --git a/doc.html b/doc.html
deleted file mode 100644
index a67abc9..0000000
--- a/doc.html
+++ /dev/null
@@ -1,193 +0,0 @@
-
-
-
-
-
- VOYAGER APIs
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/index.js b/index.js
index 5bc6db5..0bdc67c 100644
--- a/index.js
+++ b/index.js
@@ -18,6 +18,7 @@ import cors from 'cors';
import bodyParser from 'body-parser';
import { configDotenv } from 'dotenv';
+import { initDB } from './database/index.js';
import buildRoutes from './routes/index.js'
import swStats from 'swagger-stats';
@@ -25,6 +26,7 @@ import * as swaggerUi from 'swagger-ui-express'
import swaggerSpec from "./swagger.json" with { type: "json" };
configDotenv()
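+// Initialize the vector database before the server starts accepting requests.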
+await initDB()
const app = express();
app.use(cors());
diff --git a/routes/embedding.js b/routes/embedding.js
index 910e2eb..31e1d6a 100644
--- a/routes/embedding.js
+++ b/routes/embedding.js
@@ -14,8 +14,12 @@
// limitations under the License.
import { Router } from "express";
+import { embeddings } from "../actions/embedding.js";
export default function embeddingRoute() {
const router = Router();
+
+ router.post("/", embeddings);
+
return router;
}
\ No newline at end of file
diff --git a/routes/index.js b/routes/index.js
index 1553e2f..6e1fc0c 100644
--- a/routes/index.js
+++ b/routes/index.js
@@ -38,7 +38,7 @@ function generateAPIRouters() {
api_router.use('/chat', inferenceRoute());
api_router.use('/token', tokenRoute());
api_router.use('/tracing', tracingRoute());
- api_router.use('/embedding', embeddingRoute());
+ api_router.use('/embeddings', embeddingRoute());
api_router.use('/encoder', encoderRoute());
api_router.use('/decoder', decoderRoute());
diff --git a/routes/inference.js b/routes/inference.js
index 18fa803..b39ee8e 100644
--- a/routes/inference.js
+++ b/routes/inference.js
@@ -14,12 +14,13 @@
// limitations under the License.
import { Router } from "express";
-import { chatCompletion } from "../actions/inference.js";
+import { chatCompletion, ragChatCompletion } from "../actions/inference.js";
export default function inferenceRoute() {
const router = Router();
router.post('/completions', chatCompletion);
+ router.post('/rag-completions', ragChatCompletion);
return router;
}
\ No newline at end of file
diff --git a/swagger.json b/swagger.json
index 6dbbf9d..be22438 100644
--- a/swagger.json
+++ b/swagger.json
@@ -25,6 +25,10 @@
"name": "Chat",
"description": "v1 Chat APIs"
},
+ {
+ "name": "Embedding",
+      "description": "v1 Embedding APIs"
+ },
{
"name": "Token",
"description": "v1 Token APIs"
@@ -36,6 +40,7 @@
"tags": [
"Index"
],
+        "summary": "Route of this page",
"description": "Route to get this page"
}
},
@@ -66,6 +71,7 @@
"/stats": {
"get": {
"tags": ["Index"],
+        "summary": "Route to check stats",
"description": "Graphical server stats, Click [here](/stats) to get the page.",
"responses": {
"200": {
@@ -91,6 +97,12 @@
"description": "Request body of AI chat completion",
"schema": {
"$ref": "#/components/schemas/CompletionRequest"
+ },
+ "example": {
+ "messages": [
+ { "role": "system", "content": "You are a helpful assistant who helps users solve their questions." },
+ { "role": "user", "content": "Hello, tell me more about you!" }
+ ]
}
}
}
@@ -108,6 +120,198 @@
}
}
}
+ },
+ "400": {
+                    "description": "Some errors occurred"
+ },
+ "401": {
+ "description": "Not authorized"
+ }
+ },
+ "security": [
+ {"api_key": []}
+ ]
+ }
+ },
+ "/v1/chat/rag-completions": {
+ "post": {
+ "tags": ["Chat"],
+ "summary": "AI chat completion with RAG dataset.",
+            "description": "Start a conversation with the given messages, using QA pairs retrieved from the dataset as context",
+ "requestBody": {
+ "required": true,
+ "content": {
+ "application/json": {
+ "schema": {
+                        "$ref": "#/components/schemas/RAGCompletionRequest"
+ },
+ "example": {
+ "messages": [
+ { "role": "system", "content": "You are a helpful assistant who helps users solve their questions." },
+                            { "role": "user", "content": "tell me something interesting about Massachusetts" }
+ ],
+ "dataset_name": "aisuko/squad01",
+ "dataset_url": "https://datasets-server.huggingface.co/rows?dataset=aisuko%2Fsquad01&config=default&split=validation&offset=0&length=100"
+ }
+ }
+ }
+ },
+ "responses": {
+                "200": {
+                    "description": "Chat completion result together with the RAG context used",
+ "content": {
+ "application/json": {
+ "schema": {
+ "properties": {
+ "content": {
+ "$ref": "#/components/schemas/CompletionResponseEntire"
+ },
+ "rag_context": {
+ "type": "object",
+ "properties": {
+ "question": {
+ "type": "string",
+ "examples": ["In what year did Massachusetts first require children to be educated in schools?"]
+ },
+ "answer": {
+ "type": "string",
+ "examples": ["1852"]
+ },
+ "distance": {
+ "type": "number",
+ "examples": [0.27701786160469055]
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "400": {
+                    "description": "Some errors occurred"
+ },
+ "401": {
+ "description": "Not authorized"
+ }
+ },
+ "security": [
+ {"api_key": []}
+ ]
+ }
+ },
+ "/v1/embeddings": {
+ "post": {
+ "tags": ["Embedding"],
+ "summary": "Get embedding of input",
+            "description": "Get the embedding of the given input, returned in OpenAI format",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "properties": {
+ "input": {
+ "type": "string",
+ "examples": ["Hello, world!"]
+ },
+ "model": {
+ "type": "string",
+                                "description": "You can pass a model name, but it has no effect at the current stage",
+ "examples": ["all-MiniLM-L6-v2"]
+ },
+ "encoding_format": {
+ "type": "string",
+                                "description": "You can pass encoding_format, but only float is currently supported",
+ "examples": ["float"]
+ }
+ },
+ "required": [
+ "input"
+ ],
+ "example": {
+ "input": "Hello, world!",
+ "model": "all-MiniLM-L6-v2"
+ }
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+                    "description": "Success; the embedding result is returned in OpenAI format",
+ "content": {
+ "application/json": {
+ "schema": {
+ "properties": {
+ "object": {
+ "type": "string",
+ "examples": ["list"]
+ },
+ "data": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "object": {
+ "type": "string",
+ "examples": ["embedding"]
+ },
+ "embedding": {
+ "type": "array",
+ "items": {
+ "type": "number",
+ "examples": [-0.02184351161122322, 0.049017686396837234, 0.06728602200746536]
+ },
+                                        "description": "The array contains exactly 384 floats",
+ "example": [
+ -0.02184351161122322,
+ 0.049017686396837234,
+ 0.06728602200746536,
+ 0.06581537425518036,
+ -0.05950690433382988,
+ -0.08613293617963791,
+                                            "*** 384 floats in total ***"
+ ]
+ },
+ "index": {
+ "type": "integer",
+ "examples": [0]
+ }
+ }
+ }
+ },
+ "model": {
+ "type": "string",
+ "examples": ["all-MiniLM-L6-v2"]
+ },
+                                "usage": {
+                                    "type": "object",
+                                    "properties": {
+                                        "prompt_tokens": {
+                                            "type": "integer",
+                                            "examples": [0]
+                                        },
+                                        "total_tokens": {
+                                            "type": "integer",
+                                            "examples": [0]
+                                        }
+                                    }
+                                }
+ }
+ }
+ }
+ }
+ },
+ "400": {
+                    "description": "Some errors occurred"
+ },
+ "401": {
+ "description": "Not authorized"
+ },
+ "500": {
+ "description": "Internal server error"
}
},
"security": [
@@ -168,7 +372,7 @@
},
"max_tokens": {
"type": "integer",
- "examples": [ 32, 512 ]
+ "examples": [ 32, 128, 512 ]
},
"end": {
"type": "array",
@@ -182,7 +386,68 @@
"type": "boolean",
"description": "If set to `true`, AI will response streamed."
}
- }
+ },
+ "required": [
+ "messages"
+ ]
+ },
+ "RAGCompletionRequest": {
+ "properties": {
+ "messages": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties":{
+ "role": {
+ "type": "string",
+ "examples": [
+ "system", "user", "assistant"
+ ]
+ },
+ "content": {
+ "type": "string",
+ "examples": [
+ "You are a helpful assistant who helps users solve their questions.",
+ "Hi, what can you do?"
+ ]
+ }
+ }
+ }
+ },
+ "max_tokens": {
+ "type": "integer",
+ "examples": [ 32, 128, 512 ]
+ },
+ "end": {
+ "type": "array",
+                    "description": "When the AI outputs the end pattern, the response ends.",
+ "items": {
+ "type": "string",
+ "examples": ["### user:"]
+ }
+ },
+ "stream": {
+ "type": "boolean",
+                    "description": "If set to `true`, the response will be streamed."
+ },
+ "dataset_name": {
+ "type": "string",
+ "examples": [
+ "aisuko/squad01"
+ ]
+ },
+ "dataset_url": {
+ "type": "string",
+ "examples": [
+ "https://datasets-server.huggingface.co/rows?dataset=aisuko%2Fsquad01&config=default&split=validation&offset=0&length=100"
+ ]
+ }
+ },
+ "required": [
+ "messages",
+ "dataset_name",
+ "dataset_url"
+ ]
},
"CompletionResponseEntire": {
"description": "Response of completion, where `stream=false`",