Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rag inference & embedding route #31

Merged
merged 6 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions actions/embedding.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// coding=utf-8

import { post } from "../tools/request.js";

// Copyright [2024] [SkywardAI]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/**
 * Request handler that returns an OpenAI-style embedding response for the
 * text given in `req.body.input`.
 *
 * Responses:
 * - 401 when the request has no `authorization` header;
 * - 400 when `input` is missing or empty;
 * - 500 when the embedding engine call reports `http_error`;
 * - 200 with a `list` object wrapping a single `embedding` entry otherwise.
 *
 * @param {object} req incoming request; reads `headers.authorization` and `body.input`
 * @param {object} res outgoing response; used through `status()` and `send()`
 * @returns {Promise<void>} resolves once a response has been sent
 */
export async function embeddings(req, res) {
    if (!req.headers.authorization) {
        res.status(401).send("Not Authorized!");
        return;
    }

    // Tolerate a missing body so the request is answered with 400 instead of throwing.
    const { input } = req.body || {};
    if (!input) {
        res.status(400).send("Input sentence not specified!");
        // Must stop here: without this return the engine would still be called
        // and a second response would be attempted on the same request.
        return;
    }

    const { embedding, http_error } = await post('embedding', {
        body: { content: input }
    }, { eng: "embedding" });

    if (http_error) {
        res.status(500).send("Embedding Engine Internal Server Error");
        return;
    }

    res.status(200).send({
        object: "list",
        data: [
            {
                object: "embedding",
                embedding,
                index: 0
            }
        ],
        model: "all-MiniLM-L6-v2",
        // Token usage is not computed here; both counters are reported as 0.
        usage: {
            prompt_tokens: 0,
            total_tokens: 0
        }
    });
}
63 changes: 59 additions & 4 deletions actions/inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import { formatOpenAIContext } from "../tools/formatContext.js";
import { generateFingerprint } from "../tools/generator.js";
import { post } from "../tools/request.js";
import { loadDataset, searchByMessage } from "../database/rag-inference.js";

function generateResponseContent(id, object, model, system_fingerprint, stream, content, stopped) {
const resp = {
Expand Down Expand Up @@ -53,11 +54,42 @@ export async function chatCompletion(req, res) {
return;
}

let {messages, max_tokens, ...request_body} = req.body;
let {
messages, max_tokens,
system_passed_extra_properties,
...request_body
} = req.body;
// apply default values or send error messages
if(!messages || !messages.length) {
res.status(400).send("Messages not given!");
return;
}
if(!max_tokens) max_tokens = 128;

let genResp = generateResponseContent;
if(system_passed_extra_properties) {
const { inference_type, extra_fields } = system_passed_extra_properties;
if(inference_type === "rag") {
const { has_background, question, answer, _distance:distance } = extra_fields;
if(has_background) {
messages.splice(-1, 0, {
role: 'system',
content: `Your next answer should based on this background: the question is "${question}" and the answer is "${answer}".`
})
}
genResp = (...args) => {
const content = generateResponseContent(...args)
if(args[6] && has_background) {
return { content, rag_context: {question, answer, distance} }
} else return content;
}
}
}


// format requests to llamacpp format input
request_body.prompt = formatOpenAIContext(messages);
if(max_tokens) request_body.n_predict = max_tokens;
request_body.n_predict = max_tokens;
if(!request_body.stop) request_body.stop = [...default_stop_keywords];

// extra
Expand All @@ -78,16 +110,39 @@ export async function chatCompletion(req, res) {
const data = value.split("data: ").pop()
const json_data = JSON.parse(data)
const { content, stop } = json_data;
res.write(JSON.stringify(generateResponseContent(api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop))+'\n\n');
res.write(JSON.stringify(genResp(api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop))+'\n\n');
}
res.end();
} else {
const eng_resp = await post('completion', { body: request_body });
const { model, content } = eng_resp;
const response_json = generateResponseContent(
const response_json = genResp(
api_key, 'chat.completion', model, system_fingerprint,
false, content, true
)
res.send(response_json);
}
}

/**
 * Request handler for RAG-backed chat completion.
 *
 * Loads the dataset named in the request, searches it with the content of the
 * last message, attaches the search result to `req.body` under
 * `system_passed_extra_properties`, and then delegates to `chatCompletion`.
 *
 * Responses sent directly by this handler:
 * - 400 when `dataset_name` or `dataset_url` is missing;
 * - 400 when `messages` is missing or empty.
 * Every other response is produced by `chatCompletion`.
 *
 * @param {object} req incoming request; reads `body.dataset_name`,
 *   `body.dataset_url` and `body.messages`, and writes
 *   `body.system_passed_extra_properties`
 * @param {object} res outgoing response
 * @returns {Promise<void>} resolves after the delegated completion finishes
 */
export async function ragChatCompletion(req, res) {
    const { dataset_name, dataset_url } = req.body || {};
    if (!dataset_name || !dataset_url) {
        res.status(400).send("Dataset information not specified.");
        // Must stop here: otherwise the dataset load runs with undefined
        // arguments and a second response may be sent.
        return;
    }

    await loadDataset(dataset_name, dataset_url);
    if (!req.body.messages || !req.body.messages.length) {
        res.status(400).send("Messages not given!");
        return;
    }

    // The most recent message is used as the search query.
    const latest_message = req.body.messages.slice(-1)[0].content;
    const rag_result = await searchByMessage(dataset_name, latest_message);

    // Hand the search outcome to chatCompletion through the request body.
    req.body.system_passed_extra_properties = {
        inference_type: "rag",
        extra_fields: {
            has_background: !!rag_result,
            ...(rag_result || {})
        }
    };

    // Await so that a failure inside chatCompletion rejects this handler's
    // promise instead of becoming an unhandled floating rejection.
    await chatCompletion(req, res);
}
19 changes: 16 additions & 3 deletions database/index.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
// coding=utf-8

// Copyright [2024] [SkywardAI]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import { connect } from "@lancedb/lancedb";
import {
Schema, Field, FixedSizeList,
Float32, Utf8,
// eslint-disable-next-line
Table
} from "apache-arrow";
import { DATASET_TABLE, SYSTEM_TABLE } from "./types";
import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";

const uri = "/tmp/lancedb/";
const db = await connect(uri);
Expand All @@ -26,8 +41,6 @@ export async function initDB(force = false) {
]), open_options)
}

initDB();

/**
* Open a table with table name
* @param {String} table_name table name to be opened
Expand Down
15 changes: 15 additions & 0 deletions database/rag-inference.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// coding=utf-8

// Copyright [2024] [SkywardAI]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import { get, post } from "../tools/request.js";
import { getTable } from "./index.js";
import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";
Expand Down
15 changes: 15 additions & 0 deletions database/types.js
Original file line number Diff line number Diff line change
@@ -1,2 +1,17 @@
// coding=utf-8

// Copyright [2024] [SkywardAI]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Shared string constants, presumably database table names (inferred from the
// identifiers; the table contents are not visible here — confirm against the
// database module).

// Name constant with value 'system'.
export const SYSTEM_TABLE = 'system';
// Name constant with value 'dataset'.
export const DATASET_TABLE = 'dataset';
Loading
Loading