From ec575b1bf4ff9de57181ba26d7d66081f2fdd59a Mon Sep 17 00:00:00 2001 From: Bohan Cheng <47214785+cbh778899@users.noreply.github.com> Date: Tue, 6 Aug 2024 13:07:59 +1000 Subject: [PATCH] add max_distance to limit the result provided by vector db (#35) Signed-off-by: cbh778899 --- database/rag-inference.js | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/database/rag-inference.js b/database/rag-inference.js index 76566af..c0ce4ca 100644 --- a/database/rag-inference.js +++ b/database/rag-inference.js @@ -62,9 +62,11 @@ export async function loadDataset(dataset_name, dataset_url, force = false) { * Search in given dataset using provided embedding value to get Q/A pair * @param {String} dataset_name The dataset name to be query from * @param {Array} vector The embedding result to be searched + * @param {Number} max_distance If the calculated distance is over given max_distance, then the result will be excluded. + * Default to `1`. * @returns {Promise} If there's no result, returns null, otherwise returns the result */ -export async function searchByEmbedding(dataset_name, vector) { +export async function searchByEmbedding(dataset_name, vector, max_distance = 1) { const embedding_result = (await ( await getTable(DATASET_TABLE) ).search(vector).where(`dataset_name = "${dataset_name}"`) @@ -72,6 +74,7 @@ export async function searchByEmbedding(dataset_name, vector) { if(embedding_result) { const { question, answer, _distance } = embedding_result; + if(_distance >= max_distance) return null; return { question, answer, _distance } } return null; @@ -82,12 +85,14 @@ export async function searchByEmbedding(dataset_name, vector) { * This will firstly embedding the message and query use {@link searchByEmbedding} * @param {String} dataset_name The dataset name to be query from * @param {String} message The message to be searched + * @param {Number} max_distance If the calculated distance is over given max_distance, then the result will be excluded. + * Default to `1`. * @returns {Promise} If there's no result, returns null, otherwise returns the result */ -export async function searchByMessage(dataset_name, message) { +export async function searchByMessage(dataset_name, message, max_distance = 1) { const { embedding, http_error } = await post('embedding', {body: { content: message }}, { eng: "embedding" }); - return http_error ? null : await searchByEmbedding(dataset_name, embedding); + return http_error ? null : await searchByEmbedding(dataset_name, embedding, max_distance); } \ No newline at end of file