Skip to content

Commit

Permalink
Rag inference & embedding route (#31)
Browse files Browse the repository at this point in the history
* add boilerplate

Signed-off-by: cbh778899 <[email protected]>

* implement rag inference

Signed-off-by: cbh778899 <[email protected]>

* init db when app loads

Signed-off-by: cbh778899 <[email protected]>

* implement embeddings route

Signed-off-by: cbh778899 <[email protected]>

* update swagger documentation

Signed-off-by: cbh778899 <[email protected]>

* delete doc.html as we don't need it anymore

Signed-off-by: cbh778899 <[email protected]>

---------

Signed-off-by: cbh778899 <[email protected]>
  • Loading branch information
cbh778899 authored Aug 5, 2024
1 parent d49634c commit 624f9b0
Show file tree
Hide file tree
Showing 11 changed files with 434 additions and 204 deletions.
53 changes: 53 additions & 0 deletions actions/embedding.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// coding=utf-8

import { post } from "../tools/request.js";

// Copyright [2024] [SkywardAI]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

export async function embeddings(req, res) {
if(!req.headers.authorization) {
res.status(401).send("Not Authorized!");
return;
}

const { input } = req.body;
if(!input) {
res.status(400).send("Input sentence not specified!");
}

const { embedding, http_error } = await post('embedding', {body: {
content: input
}}, { eng: "embedding" });

if(http_error) {
res.status(500).send("Embedding Engine Internal Server Error")
return;
}

res.status(200).send({
object: "list",
data: [
{
object: "embedding",
embedding,
index: 0
}
],
model: "all-MiniLM-L6-v2",
usage: {
prompt_tokens: 0,
total_tokens: 0
}
})
}
63 changes: 59 additions & 4 deletions actions/inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import { formatOpenAIContext } from "../tools/formatContext.js";
import { generateFingerprint } from "../tools/generator.js";
import { post } from "../tools/request.js";
import { loadDataset, searchByMessage } from "../database/rag-inference.js";

function generateResponseContent(id, object, model, system_fingerprint, stream, content, stopped) {
const resp = {
Expand Down Expand Up @@ -53,11 +54,42 @@ export async function chatCompletion(req, res) {
return;
}

let {messages, max_tokens, ...request_body} = req.body;
let {
messages, max_tokens,
system_passed_extra_properties,
...request_body
} = req.body;
// apply default values or send error messages
if(!messages || !messages.length) {
res.status(400).send("Messages not given!");
return;
}
if(!max_tokens) max_tokens = 128;

let genResp = generateResponseContent;
if(system_passed_extra_properties) {
const { inference_type, extra_fields } = system_passed_extra_properties;
if(inference_type === "rag") {
const { has_background, question, answer, _distance:distance } = extra_fields;
if(has_background) {
messages.splice(-1, 0, {
role: 'system',
content: `Your next answer should based on this background: the question is "${question}" and the answer is "${answer}".`
})
}
genResp = (...args) => {
const content = generateResponseContent(...args)
if(args[6] && has_background) {
return { content, rag_context: {question, answer, distance} }
} else return content;
}
}
}


// format requests to llamacpp format input
request_body.prompt = formatOpenAIContext(messages);
if(max_tokens) request_body.n_predict = max_tokens;
request_body.n_predict = max_tokens;
if(!request_body.stop) request_body.stop = [...default_stop_keywords];

// extra
Expand All @@ -78,16 +110,39 @@ export async function chatCompletion(req, res) {
const data = value.split("data: ").pop()
const json_data = JSON.parse(data)
const { content, stop } = json_data;
res.write(JSON.stringify(generateResponseContent(api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop))+'\n\n');
res.write(JSON.stringify(genResp(api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop))+'\n\n');
}
res.end();
} else {
const eng_resp = await post('completion', { body: request_body });
const { model, content } = eng_resp;
const response_json = generateResponseContent(
const response_json = genResp(
api_key, 'chat.completion', model, system_fingerprint,
false, content, true
)
res.send(response_json);
}
}

export async function ragChatCompletion(req, res) {
const { dataset_name, dataset_url } = req.body;
if(!dataset_name || !dataset_url) {
res.status(400).send("Dataset information not specified.");
}

await loadDataset(dataset_name, dataset_url);
if(!req.body.messages || !req.body.messages.length) {
res.status(400).send("Messages not given!");
return;
}
const latest_message = req.body.messages.slice(-1)[0].content;
const rag_result = await searchByMessage(dataset_name, latest_message);
req.body.system_passed_extra_properties = {
inference_type: "rag",
extra_fields: {
has_background: !!rag_result,
...(rag_result || {})
}
}
chatCompletion(req, res);
}
19 changes: 16 additions & 3 deletions database/index.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
// coding=utf-8

// Copyright [2024] [SkywardAI]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import { connect } from "@lancedb/lancedb";
import {
Schema, Field, FixedSizeList,
Float32, Utf8,
// eslint-disable-next-line
Table
} from "apache-arrow";
import { DATASET_TABLE, SYSTEM_TABLE } from "./types";
import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";

const uri = "/tmp/lancedb/";
const db = await connect(uri);
Expand All @@ -26,8 +41,6 @@ export async function initDB(force = false) {
]), open_options)
}

initDB();

/**
* Open a table with table name
* @param {String} table_name table name to be opened
Expand Down
15 changes: 15 additions & 0 deletions database/rag-inference.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// coding=utf-8

// Copyright [2024] [SkywardAI]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import { get, post } from "../tools/request.js";
import { getTable } from "./index.js";
import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";
Expand Down
15 changes: 15 additions & 0 deletions database/types.js
Original file line number Diff line number Diff line change
@@ -1,2 +1,17 @@
// coding=utf-8

// Copyright [2024] [SkywardAI]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

export const SYSTEM_TABLE = 'system';
export const DATASET_TABLE = 'dataset';
Loading

0 comments on commit 624f9b0

Please sign in to comment.