Skip to content

Commit

Permalink
update demo
Browse files Browse the repository at this point in the history
  • Loading branch information
mrmer1 committed Dec 13, 2023
1 parent 0be8a8d commit 31e4dc7
Show file tree
Hide file tree
Showing 8 changed files with 2,295 additions and 0 deletions.
2 changes: 2 additions & 0 deletions chat_rag_connector/demo/.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DEMO_CONNECTOR_API_KEY="YOUR_DEMO_CONNECTOR_API_KEY"
COHERE_API_KEY="YOUR_COHERE_API_KEY"
101 changes: 101 additions & 0 deletions chat_rag_connector/demo/.openapi/api.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
openapi: 3.0.3
info:
title: Search Connector API
version: 0.0.1
paths:
/search:
post:
description: >-
<p>Searches the connected data source for documents related to the query and returns a set of key-value pairs representing the found documents.</p>
operationId: search
summary: Perform a search
security:
- api_key: []
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- query
properties:
query:
description: >-
A plain-text query string to be used to search for relevant documents.
type: string
minLength: 1
example:
query: embeddings
responses:
"200":
description: Successful response
content:
application/json:
schema:
type: object
properties:
results:
type: array
items:
type: object
additionalProperties:
type: string
"400":
description: Bad request
"401":
description: Unauthorized
default:
description: Error response

/process:
post:
description: >-
<p>Processes the documents from the provided sources.</p>
operationId: process
summary: Process documents
security:
- api_key: []
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- sources
properties:
sources:
description: >-
A list of dictionaries representing the sources of the documents. Each dictionary should have 'title' and 'url' keys.
type: array
items:
type: object
properties:
title:
type: string
url:
type: string
responses:
"200":
description: Successful response
content:
application/json:
schema:
type: object
properties:
message:
type: string
"400":
description: Bad request
"401":
description: Unauthorized
default:
description: Error response

components:
securitySchemes:
api_key:
type: http
scheme: bearer
x-bearerInfoFunc: provider.app.apikey_auth
37 changes: 37 additions & 0 deletions chat_rag_connector/demo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Template Quick Start Connector

This is a _template_ for a simple quick start connector that returns static data. This can serve as a starting point for creating a brand new connector.

## Configuration

This connector is very simple and needs two environment variables to be set: `DEMO_CONNECTOR_API_KEY`, which is used for bearer token authentication to protect this connector from abuse, and `COHERE_API_KEY`, which is used to call the Cohere API for embedding and reranking.

A `.env` file is provided with all the environment variables that are used by this connector.

## Development

Create a virtual environment and install dependencies with poetry. We recommend using in-project virtual environments:

```bash
$ poetry config virtualenvs.in-project true
$ poetry install --no-root
```

Then start the server

```bash
$ poetry run flask --app provider --debug run --port 5000
```

and check with curl to see that everything is working

```bash
$ curl --request POST \
    --url http://localhost:5000/search \
    --header 'Content-Type: application/json' \
    --header 'Authorization: Bearer <DEMO_CONNECTOR_API_KEY>' \
    --data '{
    "query": "which species of penguin is the tallest?"
  }'
```

Alternatively, load up the Swagger UI and try out the API from a browser: http://localhost:5000/ui/
1,931 changes: 1,931 additions & 0 deletions chat_rag_connector/demo/poetry.lock

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions chat_rag_connector/demo/provider/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging
import os

import connexion # type: ignore
from dotenv import load_dotenv

# Load variables from a local .env file (if present) into the process
# environment, so they are available when create_app() reads its configuration.
load_dotenv()

# File name of the OpenAPI specification handed to connexion in create_app().
API_VERSION = "api.yaml"


class UpstreamProviderError(Exception):
    """Exception that carries a human-readable message.

    Nothing in this module raises it; judging by the name it is presumably
    meant for failures of the upstream data provider (confirm with callers).

    Parameters:
        message: Description of the failure. Exposed as ``self.message`` and
            returned by ``str(error)``.
    """

    def __init__(self, message) -> None:
        # Forward the message to Exception so ``args``, ``repr()`` and
        # pickling behave like any standard exception.
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        # __str__ must return a str even when a non-string message was given.
        return str(self.message)


def create_app() -> connexion.FlaskApp:
    """Build the connector application from the OpenAPI specification.

    The endpoints declared in the spec are resolved against handlers in
    ``provider.app``. Configuration is read from environment variables whose
    prefix is the upper-cased name of the current working directory with
    underscores removed; the same value is stored as ``APP_ID``.

    Returns:
        The underlying Flask application object (``app.app``) of the
        connexion app, with its configuration loaded.
    """
    # Let connexion generate the Flask routes from the spec file.
    connexion_app = connexion.FlaskApp(__name__, specification_dir="../.openapi")
    handler_resolver = connexion.resolver.RelativeResolver("provider.app")
    connexion_app.add_api(API_VERSION, resolver=handler_resolver)

    logging.basicConfig(level=logging.INFO)

    # The env-var prefix is derived from the directory the server is started in,
    # e.g. running from ".../demo" yields the prefix "DEMO".
    working_dir_name = os.path.basename(os.getcwd())
    config_prefix = working_dir_name.upper().replace("_", "")

    flask_app = connexion_app.app
    flask_app.config.from_prefixed_env(config_prefix)
    flask_app.config["APP_ID"] = config_prefix
    return flask_app
37 changes: 37 additions & 0 deletions chat_rag_connector/demo/provider/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging

from connexion.exceptions import Unauthorized
from flask import current_app as app
from provider.documents import Documents

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# In-memory, per-process store shared by the request handlers: process() saves
# the current Documents instance under the "docs" key and search() reads it.
demo_store = {}


# This function is run for the /search endpoint
# the results that are returned here will be passed to Cohere's model for RAG
def search(body):
    """Handle the /search endpoint: find stored documents relevant to a query.

    Parameters:
        body (dict): Parsed JSON request body; must contain a "query" key.

    Returns:
        tuple: ``({"results": [...]}, 200, headers)`` on success, where the
        headers carry the connector id from the app config, or
        ``({"error": "No documents processed yet"}, 404)`` when no document
        collection has been stored yet.
    """
    query = body["query"]
    logger.debug("Search request: %s", query)

    # Only a missing store entry means "nothing processed yet". The retrieval
    # call is kept outside this check so that a KeyError raised while searching
    # is not misreported as a 404.
    docs = demo_store.get("docs")
    if docs is None:
        return {"error": "No documents processed yet"}, 404

    data = docs.retrieve(query)
    return {"results": data}, 200, {"X-Connector-Id": app.config.get("APP_ID")}


def process(body):
    """Handle the /process endpoint: build and store a document collection.

    Parameters:
        body (dict): Parsed JSON request body; must contain a "sources" key,
            which is passed to ``Documents``.

    Returns:
        tuple: A confirmation message and the HTTP status code 200.
    """
    sources = body["sources"]
    # Replaces any collection stored by a previous call.
    demo_store["docs"] = Documents(sources)

    confirmation = {"message": "Documents processed successfully"}
    return confirmation, 200


# This function is run for all endpoints to ensure requests are using a valid API key
def apikey_auth(token):
    """Validate the bearer token of an incoming request.

    Parameters:
        token (str): The bearer token supplied by the client.

    Returns:
        dict: An empty token-info dictionary when the token is valid.

    Raises:
        Unauthorized: If the token does not match the configured
            ``CONNECTOR_API_KEY``, or if no key is configured.
    """
    # Local import keeps this block self-contained.
    import hmac

    expected = app.config.get("CONNECTOR_API_KEY")

    # Fail closed: a missing or non-string key must never authenticate anyone.
    if not isinstance(token, str) or not isinstance(expected, str):
        raise Unauthorized()

    # Constant-time comparison avoids leaking the key through response timing.
    # Bytes are compared because compare_digest rejects non-ASCII str values.
    if not hmac.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        raise Unauthorized()

    # successfully authenticated
    return {}
134 changes: 134 additions & 0 deletions chat_rag_connector/demo/provider/documents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import os
from typing import Dict, List

import cohere
import hnswlib
from unstructured.chunking.title import chunk_by_title
from unstructured.partition.html import partition_html

# Shared Cohere client used by Documents for embedding and reranking.
# It is created at import time, so importing this module raises KeyError
# when the COHERE_API_KEY environment variable is not set.
co = cohere.Client(os.environ["COHERE_API_KEY"])

class Documents:
    """
    A collection of web documents that is chunked, embedded, indexed and searched.

    Parameters:
        sources (list): A list of dictionaries representing the sources of the
            documents. Each dictionary should have 'title' and 'url' keys.

    Attributes:
        sources (list): The sources the collection was built from.
        docs (list): A list of dictionaries representing the document chunks,
            with 'title', 'text', and 'url' keys.
        docs_embs (list): The embeddings of the chunks, in the same order as docs.
        retrieve_top_k (int): The number of chunks to fetch from the index during search.
        rerank_top_k (int): The number of chunks to keep after reranking.
        docs_len (int): The number of chunks in the collection.
        idx (hnswlib.Index or None): The index used for retrieval; None while
            there is nothing to index.

    Methods:
        load(): Loads the data from the sources and partitions the HTML content into chunks.
        embed(): Embeds the chunks using the Cohere API.
        index(): Indexes the embeddings for efficient retrieval.
        retrieve(query): Retrieves chunks based on the given query.
    """

    def __init__(self, sources: List[Dict[str, str]]):
        self.sources = sources
        self.docs = []
        self.docs_embs = []
        self.retrieve_top_k = 10
        self.rerank_top_k = 3
        # Defined up front so they exist even if embed()/index() have no work to do.
        self.docs_len = 0
        self.idx = None
        self.load()
        self.embed()
        self.index()

    def load(self) -> None:
        """
        Loads the documents from the sources and chunks the HTML content.

        Each chunk becomes one entry in ``self.docs`` carrying the title and
        URL of the source it came from.
        """
        print("Loading documents...")

        for source in self.sources:
            elements = partition_html(url=source["url"])
            for chunk in chunk_by_title(elements):
                self.docs.append(
                    {
                        "title": source["title"],
                        "text": str(chunk),
                        "url": source["url"],
                    }
                )

    def embed(self) -> None:
        """
        Embeds the documents using the Cohere API.

        The chunks are sent in batches and the resulting embeddings are
        appended to ``self.docs_embs`` in chunk order.
        """
        print("Embedding documents...")

        batch_size = 90
        self.docs_len = len(self.docs)

        for start in range(0, self.docs_len, batch_size):
            # Slicing past the end is safe, so no explicit upper bound is needed.
            texts = [item["text"] for item in self.docs[start : start + batch_size]]
            docs_embs_batch = co.embed(
                texts=texts, model="embed-english-v3.0", input_type="search_document"
            ).embeddings
            self.docs_embs.extend(docs_embs_batch)

    def index(self) -> None:
        """
        Indexes the documents for efficient retrieval.

        When there are no embeddings, no index is built and ``self.idx`` stays
        None; ``retrieve`` then returns an empty list.
        """
        print("Indexing documents...")

        if not self.docs_embs:
            # Nothing to index: skip hnswlib entirely instead of handing it
            # an empty array.
            self.idx = None
            print("Indexing complete with 0 documents.")
            return

        num_embs = len(self.docs_embs)
        # Take the dimension from the data rather than assuming a fixed size.
        dim = len(self.docs_embs[0])

        self.idx = hnswlib.Index(space="ip", dim=dim)
        self.idx.init_index(max_elements=num_embs, ef_construction=512, M=64)
        self.idx.add_items(self.docs_embs, list(range(num_embs)))

        print(f"Indexing complete with {self.idx.get_current_count()} documents.")

    def retrieve(self, query: str) -> List[Dict[str, str]]:
        """
        Retrieves documents based on the given query.

        The query is embedded, the nearest chunks are fetched from the index,
        and those candidates are reranked with the Cohere rerank endpoint.

        Parameters:
            query (str): The query to retrieve documents for.

        Returns:
            List[Dict[str, str]]: A list of dictionaries representing the
            retrieved documents, with 'title', 'text', and 'url' keys, in
            reranked order. Empty when the collection holds no documents.
        """
        if self.idx is None or not self.docs:
            return []

        query_emb = co.embed(
            texts=[query], model="embed-english-v3.0", input_type="search_query"
        ).embeddings

        # Never ask the index for more neighbours than it holds; hnswlib
        # raises an error when k exceeds the number of indexed items.
        k = min(self.retrieve_top_k, len(self.docs))
        doc_ids = self.idx.knn_query(query_emb, k=k)[0][0]

        docs_to_rerank = [self.docs[doc_id]["text"] for doc_id in doc_ids]

        rerank_results = co.rerank(
            query=query,
            documents=docs_to_rerank,
            top_n=min(self.rerank_top_k, len(docs_to_rerank)),
            model="rerank-english-v2.0",
        )

        docs_retrieved = []
        for result in rerank_results:
            # result.index points into docs_to_rerank; map it back to the
            # position of the chunk in self.docs.
            doc = self.docs[doc_ids[result.index]]
            docs_retrieved.append(
                {
                    "title": doc["title"],
                    "text": doc["text"],
                    "url": doc["url"],
                }
            )

        return docs_retrieved

21 changes: 21 additions & 0 deletions chat_rag_connector/demo/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[tool.poetry]
name = "demo"
version = "0.1.0"
description = ""
authors = ["Walter B <[email protected]>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.11"
flask = "2.2.5"
connexion = { extras = ["swagger-ui"], version = "^2.14.2" }
python-dotenv = "^1.0.0"
gunicorn = "^21.2.0"
unstructured = "^0.11.2"
hnswlib = "^0.8.0"
cohere = "^4.37"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

0 comments on commit 31e4dc7

Please sign in to comment.