-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
examples: RAG Question-answering Web Service (#3098)
This PR introduces a script which allows for the quick deployment of a local service which can answer questions based on the context provided from a configured vector database. I've left the table names hardcoded, looking for suggestions as to what the best approach for making them configurable is - putting them in the config.ini file did not seem appropriate. Testing done: local --------- Signed-off-by: Gabriel Georgiev <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
7f0ecd1
commit 33adbb4
Showing
6 changed files
with
218 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
## RAG Question-answering Web Service | ||
|
||
This script allows for the quick deployment of a local service which can answer questions based on | ||
the context provided from a configured vector database. | ||
To run it, you need to first install the required dependencies: | ||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
Then, you need to start the web service: | ||
``` | ||
uvicorn private_ai_api:app --reload | ||
``` | ||
|
||
You can now query the web service through the following command: | ||
``` | ||
curl http://127.0.0.1:8000/question/ -H "Content-Type: application/json" -d '{"question": "INPUT-QUESTION-HERE"}' | ||
``` | ||
|
||
Alternatively, you can use the Python script for easier access + string formatting: | ||
``` | ||
python question.py "INPUT-QUESTION-HERE" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
[db] | ||
postgres_dbname= | ||
postgres_dsn= | ||
postgres_host= | ||
postgres_password= | ||
postgres_user= | ||
|
||
[llm] | ||
auth_token= | ||
llm_host= | ||
llm_model= | ||
|
||
[tables] | ||
embeddings_table= | ||
data_table= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright 2021-2024 VMware, Inc. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import configparser | ||
|
||
import psycopg2 | ||
from clean_embed_text import get_question_embedding | ||
from clean_embed_text import setup_nltk | ||
from fastapi import FastAPI | ||
from openai import OpenAI | ||
from pgvector.psycopg2 import register_vector | ||
from pydantic import BaseModel | ||
|
||
# TODO: figure out how to make the parts configurable, i.e. embedding model could be configured here | ||
# but it would also need to be the same at the document ingestion step so that the similarity search | ||
# can work | ||
|
||
|
||
class QuestionModel(BaseModel):
    """Request body for POST /question/.

    Attributes are validated by pydantic/FastAPI on deserialization.
    """

    # The free-text natural-language question to answer.
    question: str
|
||
|
||
app = FastAPI() | ||
|
||
|
||
@app.post("/question/")
async def answer_question(question: QuestionModel):
    """Answer a question using retrieval-augmented generation.

    Pipeline: embed the question, fetch the most similar documents from the
    configured vector database, build a LLaMa-2 style prompt with that
    context, and stream a completion from the configured LLM endpoint.

    :param question: request body carrying the question text.
    :return: the model's answer as a plain string.
    """
    setup_nltk(".")

    config = configparser.ConfigParser()
    config.read("api_config.ini")

    embedding = get_question_embedding(question.question)

    cur = get_db_cursor(config)
    try:
        docs = get_similar_documents(embedding, cur, config, 3)
    finally:
        # Fix: the original leaked one DB connection per request. Close the
        # cursor and its underlying connection once retrieval is done.
        conn = cur.connection
        cur.close()
        conn.close()

    # Cap the context size so the prompt stays within the model's input budget.
    docs = truncate(docs, 2000)

    prompt = build_prompt(docs, question.question)

    client = OpenAI(
        api_key=config["llm"]["auth_token"], base_url=config["llm"]["llm_host"]
    )

    completion = client.completions.create(
        model=config["llm"]["llm_model"],
        prompt=prompt,
        max_tokens=512,
        temperature=0,
        stream=True,
    )

    # Drain the streamed chunks into a single response string.
    chunks = []
    for chunk in completion:
        chunks.append(chunk.choices[0].text)

    return "".join(chunks)
|
||
|
||
def truncate(s, wordcount):
    """Return at most the first *wordcount* whitespace-separated words of *s*,
    re-joined with single spaces (runs of whitespace are collapsed)."""
    words = s.split()
    del words[wordcount:]
    return " ".join(words)
|
||
|
||
def get_db_cursor(config):
    """Open a Postgres connection from the ``[db]`` config section and return
    a cursor on it.

    NOTE(review): the connection itself is not returned and is never closed
    here, so each call leaks one connection unless the caller closes it via
    ``cursor.connection`` — consider making callers responsible for cleanup.

    :param config: parsed ``configparser`` object with a ``[db]`` section.
    :return: an open psycopg2 cursor (its connection is ``cursor.connection``).
    """
    db_conn = psycopg2.connect(
        dsn=config["db"]["postgres_dsn"],
        dbname=config["db"]["postgres_dbname"],
        user=config["db"]["postgres_user"],
        password=config["db"]["postgres_password"],
        host=config["db"]["postgres_host"],
    )
    # Register the pgvector type adapter so vector values (and the <->
    # operator's parameters) round-trip through psycopg2.
    register_vector(db_conn)
    cur = db_conn.cursor()

    return cur
|
||
|
||
def build_prompt(context, question):
    """Assemble the final LLM prompt from retrieved *context* and the user's
    *question*, wrapped in the standard LLaMa 2 instruction format.

    :return: the complete prompt string ready to send to the model.
    """
    qa_block = f"""Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer.
Context: {context}
Question: {question}
Helpful Answer:"""

    # LLaMa 2 chat wrapper: <s>[INST] <<SYS>> system text <</SYS>> user text [/INST]
    system_text = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
    return f"<s>[INST] <<SYS>>\n{system_text}\n<</SYS>>\n\n{qa_block} [/INST] "
|
||
|
||
def get_similar_documents(question_embedding, db_cursor, config, doc_count):
    """Return the *doc_count* documents nearest to *question_embedding*,
    joined into one newline-separated string.

    Bug fix: the original read ``config["tables"]["metadata_table"]``, but the
    shipped ``api_config.ini`` template only defines ``data_table`` — every
    request would have failed with a KeyError. Use ``data_table`` to match.

    :param question_embedding: embedding vector of the cleaned question.
    :param db_cursor: open psycopg2 cursor with pgvector registered.
    :param config: parsed config with a ``[tables]`` section.
    :param doc_count: maximum number of documents to retrieve.
    """
    data_table = config["tables"]["data_table"]
    embeddings_table = config["tables"]["embeddings_table"]

    # NOTE(review): table names are identifiers and cannot be bound as query
    # parameters; they come from the operator-controlled config file, not from
    # user input. The embedding and LIMIT are bound as real parameters.
    db_cursor.execute(
        f"""
        SELECT {data_table}.data
        FROM {data_table}
        JOIN {embeddings_table}
        ON {data_table}.id = {embeddings_table}.id
        ORDER BY {embeddings_table}.embedding <-> %s LIMIT %s
        """,
        (question_embedding, doc_count),
    )
    res = db_cursor.fetchall()

    return "\n".join(doc[0] for doc in res)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright 2021-2024 VMware, Inc. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import os | ||
import re | ||
|
||
import nltk | ||
from nltk.corpus import stopwords | ||
from nltk.stem import WordNetLemmatizer | ||
from sentence_transformers import SentenceTransformer | ||
|
||
# This file contains all the reused code from vdk example jobs | ||
|
||
|
||
def clean_text(text):
    """Normalize *text* for embedding: lowercase it, strip punctuation and
    special characters, drop English stopwords, and lemmatize each word.

    TODO: Copied from the embed-ingest-job-example. Needs to be replaced by a
    more robust approach, something off the shelf ideally.
    """
    # Lowercase and remove everything that is not a word character or space.
    normalized = re.sub(r"[^\w\s]", "", text.lower())

    stop_words = set(stopwords.words("english"))
    lemmatizer = WordNetLemmatizer()

    kept = [
        lemmatizer.lemmatize(token)
        for token in normalized.split()
        if token not in stop_words
    ]
    return " ".join(kept)
|
||
|
||
def setup_nltk(temp_dir):
    """Set up NLTK by creating a local ``nltk_data`` directory under
    *temp_dir* and downloading the required resources on first use.

    Bug fix: the original called ``mkdir(exist_ok=True)`` *before* the
    existence check, so the early-return always fired and the 'stopwords' /
    'wordnet' downloads never ran.

    :param temp_dir: directory under which ``nltk_data`` is created.
    """
    from pathlib import Path

    nltk_data_path = Path(temp_dir) / "nltk_data"

    # Always make the directory visible to nltk's resource lookup.
    nltk.data.path.append(str(nltk_data_path))

    # An existing directory means the corpora were downloaded previously.
    if nltk_data_path.is_dir():
        return

    nltk_data_path.mkdir()
    nltk.download("stopwords", download_dir=str(nltk_data_path))
    nltk.download("wordnet", download_dir=str(nltk_data_path))
|
||
|
||
# Lazily-initialized embedding model, cached across calls. The original
# re-instantiated SentenceTransformer on every request, which reloads the
# model weights each time.
_QUESTION_EMBEDDING_MODEL = None


def get_question_embedding(question):
    """Return the sentence embedding of *question* after cleaning it.

    NOTE(review): the model name must stay in sync with the one used at
    document-ingestion time or the similarity search breaks — see the TODO in
    the API module about making this configurable.
    """
    global _QUESTION_EMBEDDING_MODEL
    if _QUESTION_EMBEDDING_MODEL is None:
        _QUESTION_EMBEDDING_MODEL = SentenceTransformer("all-mpnet-base-v2")

    return _QUESTION_EMBEDDING_MODEL.encode(clean_text(question))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright 2021-2024 VMware, Inc. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import sys | ||
|
||
import requests | ||
|
||
|
||
def question():
    """CLI helper: POST the single command-line argument to the local QA
    service and print the answer with common escape sequences expanded.

    Exits with status 1 when the argument count is wrong.
    """
    if len(sys.argv) != 2:
        # Bug fix: the original printed this hint and then fell through to
        # sys.argv[1], crashing with an IndexError. Exit cleanly instead.
        print("Wrap your question in quotation marks")
        sys.exit(1)

    headers = {"Content-Type": "application/json"}
    data = {"question": sys.argv[1]}
    res = requests.post("http://127.0.0.1:8000/question/", headers=headers, json=data)

    # The service returns a JSON-encoded string; expand the literal "\n" and
    # "\t" sequences for readable terminal output.
    print(res.text.replace("\\n", "\n").replace("\\t", "\t"))


if __name__ == "__main__":
    question()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"uvicorn[standard]" | ||
fastapi | ||
langchain | ||
nltk | ||
openai | ||
pgvector | ||
psycopg2-binary | ||
requests | ||
sentence-transformers |