examples: RAG Question-answering Web Service #3098

Merged · 12 commits · Feb 13, 2024
23 changes: 23 additions & 0 deletions examples/rag-chat-api/README.md
@@ -0,0 +1,23 @@
## RAG Question-answering Web Service

This example deploys a small local web service that answers questions using context retrieved from a configured vector database.
To run it, first install the required dependencies:
```
pip install -r requirements.txt
```

Then start the web service:
```
uvicorn chat_api:app --reload
```

You can now query the web service with the following command:
```
curl http://127.0.0.1:8000/question/ -H "Content-Type: application/json" -d '{"question": "INPUT-QUESTION-HERE"}'
```
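
For example, with the service running and the database populated (the question below is only an illustration):
```
curl http://127.0.0.1:8000/question/ -H "Content-Type: application/json" -d '{"question": "What is a data job?"}'
```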

Alternatively, you can use the bundled Python script, which sends the request and formats the response for you:
```
python question.py "INPUT-QUESTION-HERE"
```
15 changes: 15 additions & 0 deletions examples/rag-chat-api/api_config.ini
@@ -0,0 +1,15 @@
[db]
postgres_dbname=
postgres_dsn=
postgres_host=
postgres_password=
postgres_user=

[llm]
auth_token=
llm_host=
llm_model=

[tables]
embeddings_table=
metadata_table=
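
For reference, a filled-in config might look like the following; every value below is a hypothetical placeholder, not part of this change:
```ini
[db]
postgres_dbname=rag_demo
postgres_dsn=
postgres_host=localhost
postgres_password=changeme
postgres_user=postgres

[llm]
auth_token=YOUR-AUTH-TOKEN
llm_host=http://localhost:8001/v1
llm_model=llama-2-7b-chat

[tables]
embeddings_table=vdk_doc_embeddings
metadata_table=vdk_doc_metadata
```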
101 changes: 101 additions & 0 deletions examples/rag-chat-api/chat_api.py
@@ -0,0 +1,101 @@
# Copyright 2021-2024 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import configparser

import psycopg2
from clean_embed_text import get_question_embedding
from clean_embed_text import setup_nltk
from fastapi import FastAPI
from openai import OpenAI
from pgvector.psycopg2 import register_vector
from pydantic import BaseModel

# TODO: figure out how to make the parts configurable, i.e. embedding model could be configured here
# but it would also need to be the same at the document ingestion step so that the similarity search
# can work


class QuestionModel(BaseModel):
question: str


app = FastAPI()


@app.post("/question/")
async def answer_question(question: QuestionModel):
setup_nltk(".")

config = configparser.ConfigParser()
config.read("api_config.ini")

    # Embed the question with the same model used at document-ingestion time.
    embedding = get_question_embedding(question.question)

cur = get_db_cursor(config)

    # Retrieve the 3 most similar documents and cap the combined context at 2000 words.
    docs = get_similar_documents(embedding, cur, config, 3)
    docs = truncate(docs, 2000)

prompt = build_prompt(docs, question.question)

client = OpenAI(
api_key=config["llm"]["auth_token"], base_url=config["llm"]["llm_host"]
)

completion = client.completions.create(
model=config["llm"]["llm_model"],
prompt=prompt,
max_tokens=512,
temperature=0,
stream=True,
)

    # The completion is streamed; accumulate the chunks into a single answer string.
    model_output = ""
    for c in completion:
        model_output += c.choices[0].text

return model_output


def truncate(s, wordcount):
    """Truncate a string to at most `wordcount` whitespace-separated words."""
    return " ".join(s.split()[:wordcount])


def get_db_cursor(config):
    """Open a PostgreSQL connection, register the pgvector type, and return a cursor."""
    db_conn = psycopg2.connect(
dsn=config["db"]["postgres_dsn"],
dbname=config["db"]["postgres_dbname"],
user=config["db"]["postgres_user"],
password=config["db"]["postgres_password"],
host=config["db"]["postgres_host"],
)
register_vector(db_conn)
cur = db_conn.cursor()

return cur


def build_prompt(context, question):
prompt = f"""Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer.
Context: {context}
Question: {question}
Helpful Answer:"""

# Standard formatting for LLaMa 2
return f"<s>[INST] <<SYS>>\nBelow is an instruction that describes a task. Write a response that appropriately completes the request.\n<</SYS>>\n\n{prompt} [/INST] "


def get_similar_documents(question_embedding, db_cursor, config, doc_count):
    """Return the `doc_count` documents most similar to the question embedding,
    ranked by pgvector's L2 distance operator (<->)."""
    db_cursor.execute(
f"""
SELECT {config["tables"]["metadata_table"]}.data
FROM {config["tables"]["metadata_table"]}
JOIN {config["tables"]["embeddings_table"]}
ON {config["tables"]["metadata_table"]}.id = {config["tables"]["embeddings_table"]}.id
ORDER BY {config["tables"]["embeddings_table"]}.embedding <-> %s LIMIT {doc_count}
""",
(question_embedding,),
)
res = db_cursor.fetchall()

return "\n".join([doc[0] for doc in res])
52 changes: 52 additions & 0 deletions examples/rag-chat-api/clean_embed_text.py
@@ -0,0 +1,52 @@
# Copyright 2021-2024 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import os
import re

import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sentence_transformers import SentenceTransformer

# This file contains all the reused code from vdk example jobs


def clean_text(text):
"""
TODO: Copied from the embed-ingest-job-example. Needs to be replaced by a more robust approach, something
off the shelf ideally.
"""
text = text.lower()
# remove punctuation and special characters
text = re.sub(r"[^\w\s]", "", text)
# remove stopwords and lemmatize
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
text = " ".join(
[lemmatizer.lemmatize(word) for word in text.split() if word not in stop_words]
)
return text


def setup_nltk(temp_dir):
    """
    Set up NLTK by creating a temporary directory for NLTK data and downloading required resources.
    """
    from pathlib import Path

    nltk_data_path = Path(temp_dir) / "nltk_data"
    nltk.data.path.append(str(nltk_data_path))

    # If the data directory already exists, assume the resources were downloaded
    # on a previous run and skip the downloads.
    if nltk_data_path.is_dir():
        return
    nltk_data_path.mkdir(exist_ok=True)

    nltk.download("stopwords", download_dir=str(nltk_data_path))
    nltk.download("wordnet", download_dir=str(nltk_data_path))


def get_question_embedding(question):
    # NOTE: this must match the embedding model used at document-ingestion time,
    # otherwise the similarity search will not work.
    model = SentenceTransformer("all-mpnet-base-v2")
    embedding = model.encode(clean_text(question))

return embedding
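
A quick way to sanity-check the embedding helper from a Python shell; the printed shape assumes all-mpnet-base-v2, which produces 768-dimensional vectors:
```python
from clean_embed_text import get_question_embedding, setup_nltk

setup_nltk(".")  # downloads stopwords/wordnet on first run
emb = get_question_embedding("What is a data job?")
print(emb.shape)  # (768,)
```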
20 changes: 20 additions & 0 deletions examples/rag-chat-api/question.py
@@ -0,0 +1,20 @@
# Copyright 2021-2024 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import sys

import requests


def question():
    if len(sys.argv) != 2:
        print('Usage: python question.py "INPUT-QUESTION-HERE" (wrap your question in quotation marks)')
        sys.exit(1)

headers = {"Content-Type": "application/json"}
data = {"question": sys.argv[1]}
res = requests.post("http://127.0.0.1:8000/question/", headers=headers, json=data)

    # The service returns a JSON-encoded string; unescape newlines/tabs for readability.
    print(res.text.replace("\\n", "\n").replace("\\t", "\t"))


if __name__ == "__main__":
question()
7 changes: 7 additions & 0 deletions examples/rag-chat-api/requirements.txt
@@ -0,0 +1,7 @@
"uvicorn[standard]"
fastapi
langchain
nltk
openai
pgvector
sentence-transformers