-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
2,295 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
DEMO_CONNECTOR_API_KEY="YOUR_DEMO_CONNECTOR_API_KEY" | ||
COHERE_API_KEY="YOUR_COHERE_API_KEY" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
openapi: 3.0.3 | ||
info: | ||
title: Search Connector API | ||
version: 0.0.1 | ||
paths: | ||
/search: | ||
post: | ||
description: >- | ||
<p>Searches the connected data source for documents related to the query and returns a set of key-value pairs representing the found documents.</p> | ||
operationId: search | ||
summary: Perform a search | ||
security: | ||
- api_key: [] | ||
requestBody: | ||
required: true | ||
content: | ||
application/json: | ||
schema: | ||
type: object | ||
required: | ||
- query | ||
properties: | ||
query: | ||
description: >- | ||
A plain-text query string to be used to search for relevant documents. | ||
type: string | ||
minLength: 1 | ||
example: | ||
query: embeddings | ||
responses: | ||
"200": | ||
description: Successful response | ||
content: | ||
application/json: | ||
schema: | ||
type: object | ||
properties: | ||
results: | ||
type: array | ||
items: | ||
type: object | ||
additionalProperties: | ||
type: string | ||
"400": | ||
description: Bad request | ||
"401": | ||
description: Unauthorized | ||
default: | ||
description: Error response | ||
|
||
/process: | ||
post: | ||
description: >- | ||
<p>Processes the documents from the provided sources.</p> | ||
operationId: process | ||
summary: Process documents | ||
security: | ||
- api_key: [] | ||
requestBody: | ||
required: true | ||
content: | ||
application/json: | ||
schema: | ||
type: object | ||
required: | ||
- sources | ||
properties: | ||
sources: | ||
description: >- | ||
A list of dictionaries representing the sources of the documents. Each dictionary should have 'title' and 'url' keys. | ||
type: array | ||
items: | ||
type: object | ||
properties: | ||
title: | ||
type: string | ||
url: | ||
type: string | ||
responses: | ||
"200": | ||
description: Successful response | ||
content: | ||
application/json: | ||
schema: | ||
type: object | ||
properties: | ||
message: | ||
type: string | ||
"400": | ||
description: Bad request | ||
"401": | ||
description: Unauthorized | ||
default: | ||
description: Error response | ||
|
||
components: | ||
securitySchemes: | ||
api_key: | ||
type: http | ||
scheme: bearer | ||
x-bearerInfoFunc: provider.app.apikey_auth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Template Quick Start Connector | ||
|
||
This is a _template_ for a simple quick start connector that returns static data. This can serve as a starting point for creating a brand new connector. | ||
|
||
## Configuration | ||
|
||
This connector is very simple and only needs two environment variables to be set: `DEMO_CONNECTOR_API_KEY`, which is used for bearer token authentication to protect this connector from abuse, and `COHERE_API_KEY`, which is used to call the Cohere API. | ||
|
||
A `.env-template` file is provided with all the environment variables that are used by this connector. | ||
|
||
## Development | ||
|
||
Create a virtual environment and install dependencies with poetry. We recommend using in-project virtual environments: | ||
|
||
```bash | ||
$ poetry config virtualenvs.in-project true | ||
$ poetry install --no-root | ||
``` | ||
|
||
Then start the server | ||
|
||
```bash | ||
$ poetry run flask --app provider --debug run --port 5000 | ||
``` | ||
|
||
and check with curl to see that everything is working | ||
|
||
```bash | ||
$ curl --request POST \ | ||
--url http://localhost:5000/search \ | ||
--header 'Content-Type: application/json' \ | ||
--data '{ | ||
"query": "which species of penguin is the tallest?" | ||
}' | ||
``` | ||
|
||
Alternatively, load up the Swagger UI and try out the API from a browser: http://localhost:5000/ui/ |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import logging | ||
import os | ||
|
||
import connexion # type: ignore | ||
from dotenv import load_dotenv | ||
|
||
# Runs at import time. Presumably populates the process environment from a
# local .env file (python-dotenv) so later configuration lookups can see it.
load_dotenv()

# File name of the OpenAPI specification handed to `add_api` in create_app().
# NOTE(review): despite the name, this holds a spec file name, not a version.
API_VERSION = "api.yaml"
|
||
|
||
class UpstreamProviderError(Exception):
    """Exception that carries a single message.

    Judging by its name it is meant for failures of the upstream provider;
    nothing in the visible code raises it.

    Attributes:
        message: The value passed to the constructor, kept unchanged.
    """

    def __init__(self, message) -> None:
        # Forward the message to Exception so that `args`, `repr()` and
        # pickling behave like they do for any built-in exception.
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        # Coerce to str: returning a non-string from __str__ raises TypeError.
        return str(self.message)
|
||
|
||
def create_app() -> connexion.FlaskApp:
    """Build the connector application from the OpenAPI specification.

    A connexion app is created with its specification directory set to
    "../.openapi", the spec named by API_VERSION is registered with handlers
    resolved relative to "provider.app", and root logging is configured at
    INFO level.

    Configuration is then read from environment variables whose prefix is the
    upper-cased name of the current working directory with underscores
    removed; the same prefix is stored in the config under "APP_ID".

    Returns:
        The `.app` attribute of the connexion application.
        NOTE(review): the annotation names connexion.FlaskApp, but the object
        returned is the wrapped app - confirm which one callers expect.
    """
    connexion_app = connexion.FlaskApp(__name__, specification_dir="../.openapi")
    handler_resolver = connexion.resolver.RelativeResolver("provider.app")
    connexion_app.add_api(API_VERSION, resolver=handler_resolver)
    logging.basicConfig(level=logging.INFO)

    # e.g. a working directory named "my_connector" yields "MYCONNECTOR"
    env_prefix = os.path.basename(os.getcwd()).upper().replace("_", "")

    wsgi_app = connexion_app.app
    wsgi_app.config.from_prefixed_env(env_prefix)
    wsgi_app.config["APP_ID"] = env_prefix
    return wsgi_app
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import logging | ||
|
||
from connexion.exceptions import Unauthorized | ||
from flask import current_app as app | ||
from provider.documents import Documents | ||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# In-memory store living in this module's globals: process() writes the
# Documents instance under the "docs" key and search() reads it back.
demo_store = {}
|
||
|
||
# This function is run for the /search endpoint | ||
# the results that are returned here will be passed to Cohere's model for RAG | ||
def search(body):
    """Handle the /search endpoint.

    Args:
        body: Parsed request body; its "query" entry is the search string.

    Returns:
        ``({"error": ...}, 404)`` when no documents have been stored by
        process() yet; otherwise ``({"results": [...]}, 200, headers)`` where
        the headers carry the configured "APP_ID" as "X-Connector-Id".
    """
    query = body["query"]
    # Lazy %-formatting: the message is only built if DEBUG is enabled.
    logger.debug("Search request: %s", query)

    # Guard only the store lookup. A KeyError raised inside retrieve() is a
    # real failure and must not be reported as "No documents processed yet".
    try:
        docs = demo_store["docs"]
    except KeyError:
        return {"error": "No documents processed yet"}, 404

    data = docs.retrieve(query)
    return {"results": data}, 200, {"X-Connector-Id": app.config.get("APP_ID")}
|
||
|
||
def process(body):
    """Handle the /process endpoint.

    Builds a Documents collection from the "sources" entry of the request
    body and stores it in `demo_store` under "docs", replacing any collection
    stored earlier.

    Args:
        body: Parsed request body containing a "sources" entry.

    Returns:
        A ``(payload, 200)`` tuple with a confirmation message.
    """
    sources = body["sources"]
    demo_store["docs"] = Documents(sources)

    confirmation = {"message": "Documents processed successfully"}
    return confirmation, 200
|
||
|
||
# This function is run for all endpoints to ensure requests are using a valid API key | ||
def apikey_auth(token):
    """Validate the bearer token sent with a request.

    Args:
        token: The bearer token supplied by the caller.

    Returns:
        An empty dict when the token matches the configured
        "CONNECTOR_API_KEY".

    Raises:
        Unauthorized: If the token does not match, or if no string API key is
            configured (an unset key must never authenticate anyone).
    """
    # Local import keeps this block self-contained; hmac is standard library.
    import hmac

    expected = app.config.get("CONNECTOR_API_KEY")

    # A missing or non-string key (or token) can never be a valid match.
    if not isinstance(expected, str) or not isinstance(token, str):
        raise Unauthorized()

    # Constant-time comparison so response timing does not leak how much of
    # the key a guess got right. Bytes are used because compare_digest only
    # accepts ASCII-only str arguments.
    if not hmac.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        raise Unauthorized()

    # successfully authenticated
    return {}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import os | ||
from typing import Dict, List | ||
|
||
import cohere | ||
import hnswlib | ||
from unstructured.chunking.title import chunk_by_title | ||
from unstructured.partition.html import partition_html | ||
|
||
# Module-level Cohere client used by the Documents methods below. It is built
# at import time, so importing this module raises KeyError when the
# COHERE_API_KEY environment variable is not set.
co = cohere.Client(os.environ["COHERE_API_KEY"])
|
||
class Documents:
    """
    A collection of documents that is loaded, embedded and indexed on creation.

    Parameters:
    sources (list): A list of dictionaries representing the sources of the documents. Each dictionary should have 'title' and 'url' keys.

    Attributes:
    sources (list): A list of dictionaries representing the sources of the documents.
    docs (list): A list of dictionaries representing the document chunks, with 'title', 'text', and 'url' keys.
    docs_embs (list): A list of the associated embeddings for the document chunks.
    retrieve_top_k (int): The number of chunks to retrieve during search.
    rerank_top_k (int): The number of chunks to keep after reranking.
    docs_len (int): The number of chunks in the collection (set by embed()).
    idx (hnswlib.Index | None): The index used for retrieval; None when there is nothing to index.

    Methods:
    load(): Loads the data from the sources and partitions the HTML content into chunks.
    embed(): Embeds the document chunks using the Cohere API.
    index(): Indexes the embeddings for efficient retrieval.
    retrieve(query): Retrieves document chunks based on the given query.
    """

    def __init__(self, sources: List[Dict[str, str]]):
        self.sources = sources
        self.docs = []
        self.docs_embs = []
        # Defined up front so every attribute exists even for an empty collection.
        self.docs_len = 0
        self.idx = None
        self.retrieve_top_k = 10
        self.rerank_top_k = 3
        self.load()
        self.embed()
        self.index()

    def load(self) -> None:
        """
        Loads the documents from the sources and chunks the HTML content.

        Each chunk becomes one entry in `self.docs`, tagged with the title and
        URL of the source it came from.
        """
        print("Loading documents...")

        for source in self.sources:
            elements = partition_html(url=source["url"])
            chunks = chunk_by_title(elements)
            self.docs.extend(
                {
                    "title": source["title"],
                    "text": str(chunk),
                    "url": source["url"],
                }
                for chunk in chunks
            )

    def embed(self) -> None:
        """
        Embeds the documents using the Cohere API.

        Texts are sent in batches and the resulting embeddings are appended to
        `self.docs_embs` in the same order as `self.docs`.
        """
        print("Embedding documents...")

        # NOTE(review): presumably chosen to stay under the API's per-request
        # limit on the number of texts - confirm against the Cohere docs.
        batch_size = 90
        self.docs_len = len(self.docs)

        for start in range(0, self.docs_len, batch_size):
            # Slicing already clamps to the end of the list.
            batch = self.docs[start : start + batch_size]
            texts = [item["text"] for item in batch]
            docs_embs_batch = co.embed(
                texts=texts, model="embed-english-v3.0", input_type="search_document"
            ).embeddings
            self.docs_embs.extend(docs_embs_batch)

    def index(self) -> None:
        """
        Indexes the documents for efficient retrieval.

        With no embeddings there is nothing to index: `self.idx` is left as
        None and retrieve() returns an empty list.
        """
        print("Indexing documents...")

        if not self.docs_embs:
            self.idx = None
            print("Indexing complete with 0 documents.")
            return

        # Take the dimensionality from the data instead of hard-coding it.
        dim = len(self.docs_embs[0])
        self.idx = hnswlib.Index(space="ip", dim=dim)
        self.idx.init_index(
            max_elements=len(self.docs_embs), ef_construction=512, M=64
        )
        self.idx.add_items(self.docs_embs, list(range(len(self.docs_embs))))

        print(f"Indexing complete with {self.idx.get_current_count()} documents.")

    def retrieve(self, query: str) -> List[Dict[str, str]]:
        """
        Retrieves documents based on the given query.

        The query is embedded, the nearest chunks are fetched from the index,
        and those candidates are reranked; the top reranked chunks are returned.

        Parameters:
        query (str): The query to retrieve documents for.

        Returns:
        List[Dict[str, str]]: A list of dictionaries representing the retrieved documents, with 'title', 'text', and 'url' keys. Empty when the collection holds no documents.
        """
        # Nothing indexed: skip the API calls and report no matches.
        if self.idx is None or not self.docs:
            return []

        query_emb = co.embed(
            texts=[query], model="embed-english-v3.0", input_type="search_query"
        ).embeddings

        # Asking the index for more neighbours than it holds fails, so clamp k
        # to the number of indexed items.
        k = min(self.retrieve_top_k, self.idx.get_current_count())
        doc_ids = self.idx.knn_query(query_emb, k=k)[0][0]

        docs_to_rerank = [self.docs[doc_id]["text"] for doc_id in doc_ids]

        rerank_results = co.rerank(
            query=query,
            documents=docs_to_rerank,
            top_n=self.rerank_top_k,
            model="rerank-english-v2.0",
        )

        # result.index is a position within docs_to_rerank; map it back to the
        # id of the chunk in self.docs.
        doc_ids_reranked = [doc_ids[result.index] for result in rerank_results]

        return [
            {
                "title": self.docs[doc_id]["title"],
                "text": self.docs[doc_id]["text"],
                "url": self.docs[doc_id]["url"],
            }
            for doc_id in doc_ids_reranked
        ]
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[tool.poetry] | ||
name = "demo" | ||
version = "0.1.0" | ||
description = "" | ||
authors = ["Walter B <[email protected]>"] | ||
readme = "README.md" | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.11" | ||
flask = "2.2.5" | ||
connexion = { extras = ["swagger-ui"], version = "^2.14.2" } | ||
python-dotenv = "^1.0.0" | ||
gunicorn = "^21.2.0" | ||
unstructured = "^0.11.2" | ||
hnswlib = "^0.8.0" | ||
cohere = "^4.37" | ||
|
||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
build-backend = "poetry.core.masonry.api" |