Merge pull request #375 from SkywardAI/rag
Feat: new API for vectordb search
Micost authored Aug 26, 2024
2 parents 2c3665d + 929067e commit 2ece301
Showing 7 changed files with 145 additions and 15 deletions.
74 changes: 74 additions & 0 deletions backend/src/api/routes/chat.py
@@ -28,6 +28,8 @@
ChatsWithTime,
ChatInMessage,
ChatInResponse,
SearchInMessage,
SearchResponse,
SessionUpdate,
SessionResponse,
ChatUUIDResponse,
@@ -176,6 +178,78 @@ async def chat_uuid(
return ChatUUIDResponse(sessionUuid=session_uuid)


@router.post(
"/search",
name="chat:chatbot",
response_model=SearchResponse,
status_code=fastapi.status.HTTP_200_OK,
)
async def search(
search_in_msg: SearchInMessage,
token: str = fastapi.Depends(oauth2_scheme),
session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)),
rag_chat_repo: RAGChatModelRepository = fastapi.Depends(get_rag_repository(repo_type=RAGChatModelRepository)),
account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)),
jwt_payload: dict = fastapi.Depends(jwt_required),
) -> SearchResponse:
"""
Search rag result with give messages and session.
**Note:**
You need to sing up and sign in before calling this API. If you are using
the Swagger UI. You can get the token automatically by login in through `api/auth/verify` API.
**Anonymous users**, we will create anonymous usef infor in the database. So, you can login through Authorize button in Swagger UI.
- **username**: anonymous
- **password**: Marlboro@2211
**Example of the request body:**
```bash
curl -X 'POST'
'http://localhost:8000/api/chat'
-H 'accept: application/json'
-H 'Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6ImFub255bW91cyIsImVtYWlsIjoiYW5vbnltb3VzQGFub255LmNvbSIsImV4cCI6MTcyMTA3MTI0MCwic3ViIjoiWU9VUi1KV1QtU1VCSkVDVCJ9.hip3zPA2yN-MOwKHFOm_KhZuvaC4soY4MgwegyYJu2s'
-H 'Content-Type: application/json'
-d '{
"sessionUuid": "string",
"message": "do you know RMIT?",
}'
```
**Return StreamingResponse:**
data: {"context":" Yes","score":100}
"""

##############################################################################################################################
# Note: using the await keyword here causes an issue. See https://github.com/sqlalchemy/sqlalchemy/discussions/9757
#

current_user = account_repo.read_account_by_username(username=jwt_payload.username)

# TODO: Only read session here @Micost
session = session_repo.read_create_sessions_by_uuid(
session_uuid=search_in_msg.sessionUuid, account_id=current_user.id, name=search_in_msg.message[:20]
)
match session.session_type:
case "rag":
# Verify dataset_name exist
if session.dataset_name is None:
return SearchResponse(context=None, score=0)
else:
context = rag_chat_repo.search_context(
input_msg=search_in_msg.message,
collection_name=session.dataset_name
)
case _:
return SearchResponse(context=None, score=0)
return SearchResponse(context=context, score=100)


@router.post(
"",
name="chat:chatbot",
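For quick manual testing of the new endpoint, here is a minimal Python client sketch. It assumes the chat router is mounted under `/api/chat` (so the full path is `/api/chat/search`) and uses placeholder values for the base URL and token.

```python
# Hedged sketch: base URL, token, and session UUID are placeholders.
import httpx

BASE_URL = "http://localhost:8000"          # assumed local deployment
TOKEN = "<JWT obtained via api/auth/verify>"

payload = {"sessionUuid": "string", "message": "do you know RMIT?"}

response = httpx.post(
    f"{BASE_URL}/api/chat/search",          # assumed full route path
    headers={
        "accept": "application/json",
        "Authorization": f"Bearer {TOKEN}",
        "Content-Type": "application/json",
    },
    json=payload,
    timeout=None,
)
response.raise_for_status()

# Expected shape per SearchResponse: {"context": "...", "score": 100.0},
# or {"context": null, "score": 0.0} when the session is not a RAG session.
print(response.json())
```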
9 changes: 6 additions & 3 deletions backend/src/api/routes/rag_datasets.py
@@ -86,6 +86,7 @@ async def get_dataset_list(
)
async def load_dataset(
rag_ds_create: RagDatasetCreate,
background_tasks: fastapi.BackgroundTasks,
token: str = fastapi.Depends(oauth2_scheme),
session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)),
ds_repo: DataSetCRUDRepository = fastapi.Depends(get_repository(repo_type=DataSetCRUDRepository)),
@@ -125,16 +126,17 @@ async def load_dataset(
# Here we don't use async because load_dataset is a sync function in HF ds
# status: bool = True if DatasetEng.load_dataset(rag_ds_create.dataset_name).get("insert_count") > 0 else False
session = session_repo.read_create_sessions_by_uuid(
session_uuid=rag_ds_create.sessionUuid, account_id=current_user.id, name="new session"
session_uuid=rag_ds_create.sessionUuid, account_id=current_user.id, name="new session", session_type="rag"
)
try:
# Here we use async because we need to update the session db
DatasetEng.load_dataset(rag_ds_create.dataset_name)
dataset_list = DatasetEng.validate_dataset(rag_ds_create.dataset_name)
status: bool = True
except Exception:
status: bool = False


async def load_dataset_task(dataset_name: str, ds_list: list):
DatasetEng.load_dataset(dataset_name, ds_list)
match status:
case True:
table_name = DatasetFormatter.format_dataset_by_name(
@@ -150,6 +152,7 @@ async def load_dataset(
table_name=table_name,
des=""
))
background_tasks.add_task(load_dataset_task, rag_ds_create.dataset_name, dataset_list)
case False:
return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, session_uuid=session.session_uuid, status=status)

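The change above defers the bulk insert to a FastAPI background task, so the route can respond as soon as the dataset has been validated and the table created. A stripped-down sketch of that pattern, with hypothetical names unrelated to this codebase:

```python
# Minimal illustration of fastapi.BackgroundTasks; names are made up.
import time
import fastapi

app = fastapi.FastAPI()

def bulk_insert(dataset_name: str, rows: list) -> None:
    # Stand-in for the slow, synchronous DatasetEng.load_dataset call.
    time.sleep(5)
    print(f"inserted {len(rows)} rows into {dataset_name}")

@app.post("/demo/load")
async def demo_load(dataset_name: str, background_tasks: fastapi.BackgroundTasks) -> dict:
    rows = [{"id": 1, "vector": [0.1, 0.2]}]  # pretend these were just validated
    # The response is returned first; FastAPI runs the task afterwards.
    background_tasks.add_task(bulk_insert, dataset_name, rows)
    return {"dataset_name": dataset_name, "status": True}
```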
20 changes: 20 additions & 0 deletions backend/src/models/schemas/chat.py
@@ -53,6 +53,26 @@ class ChatInMessage(BaseSchemaModel):
collection_name: Optional[str] | None = Field(default=None, title="Collection Name", description="Collection Name")


class SearchInMessage(BaseSchemaModel):
"""
Object for the request body of the search endpoint.
Attributes:
-----------
sessionUuid: Optional[str] | None
Session UUID
message: str
Message
"""

sessionUuid: Optional[str] | None = Field(..., title="Session UUID", description="Session UUID")
message: str = Field(..., title="Message", description="Message")

class SearchResponse(BaseSchemaModel):
context: str | None = Field(..., title="Context", description="Context")
score: float = Field(..., title="Score", description="Score of the search")


class ChatInResponse(BaseSchemaModel):
sessionUuid: str = Field(..., title="Session UUID", description="Session UUID")
message: str = Field(..., title="Message", description="Message")
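A small sketch of how the new request/response schemas behave. `BaseSchemaModel` is not shown in this diff, so the sketch models them directly on `pydantic.BaseModel`:

```python
# Hedged sketch: BaseSchemaModel is replaced by pydantic.BaseModel here.
from typing import Optional
from pydantic import BaseModel, Field

class SearchInMessage(BaseModel):
    sessionUuid: Optional[str] = Field(..., title="Session UUID", description="Session UUID")
    message: str = Field(..., title="Message", description="Message")

class SearchResponse(BaseModel):
    context: Optional[str] = Field(..., title="Context", description="Context")
    score: float = Field(..., title="Score", description="Score of the search")

req = SearchInMessage(sessionUuid="1234", message="do you know RMIT?")
res = SearchResponse(context=" Yes", score=100)
print(req)
print(res)
```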
4 changes: 2 additions & 2 deletions backend/src/repository/crud/chat.py
@@ -116,7 +116,7 @@ def append_ds_name_to_session(self, session_uuid: str, account_id: int, ds_name:
loguru.logger.info(f"Update session {session_uuid}")
return Session.from_dict(update_session)

def read_create_sessions_by_uuid(self, session_uuid: str, account_id: int, name: str) -> Session:
def read_create_sessions_by_uuid(self, session_uuid: str, account_id: int, name: str, session_type: str = "chat") -> Session:
try:
session=self.tbl.search().where(f"session_uuid = '{session_uuid}' , account_id = {account_id}", prefilter=True).limit(1).to_list()[0]
except Exception:
@@ -127,7 +127,7 @@ def read_create_sessions_by_uuid(self, session_uuid: str, account_id: int, name:
"session_uuid": uuid_id,
"account_id": account_id,
"name": name,
"session_type": "chat",
"session_type": session_type,
"dataset_name": "",
"created_at": current_time
}])
30 changes: 30 additions & 0 deletions backend/src/repository/rag/chat.py
@@ -105,8 +105,38 @@ async def inference(
except httpx.HTTPStatusError as e:
loguru.logger.error(f"Error response {e.response.status_code} while requesting {e.request.url!r}.")

def search_context(self, input_msg: str, collection_name: str) -> str:
"""
Search the context from the vector database
Args:
input_msg (str): input message
collection_name (str): collection name
Returns:
str: context
"""
try:
res = httpx_kit.sync_client.post(
InferenceHelper.instruct_embedding_url(),
headers={"Content-Type": "application/json"},
json={"content": input_msg},
timeout=httpx.Timeout(timeout=None),
)
res.raise_for_status()
embedd_input = res.json().get("embedding")
except Exception as e:
loguru.logger.error(e)
# collection name for testing
context = vector_db.search(
list(embedd_input), 1, table_name=DatasetFormatter.format_dataset_by_name(collection_name)
)
loguru.logger.info(f"Context: {context}")
return context

async def inference_with_rag(
self,
session_uuid: str,
input_msg: str,
collection_name: str,
temperature: float = 0.2,
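`search_context` is a two-step helper: it asks the inference service for an embedding of the query, then runs a nearest-neighbour search over the session's dataset table. Below is a self-contained sketch of the same flow with plain `httpx` and `lancedb`; the embedding URL, database path, and table name are assumptions, not the project's actual configuration.

```python
# Hedged sketch of the embed-then-search flow; endpoints and paths are assumed.
import httpx
import lancedb

EMBEDDING_URL = "http://localhost:8080/embedding"  # assumed inference endpoint
DB_PATH = "/tmp/lancedb"                           # assumed local LanceDB path

def search_context(input_msg: str, table_name: str, top_k: int = 1) -> list:
    # 1) Embed the query via the inference service.
    res = httpx.post(
        EMBEDDING_URL,
        headers={"Content-Type": "application/json"},
        json={"content": input_msg},
        timeout=None,
    )
    res.raise_for_status()
    embedding = res.json().get("embedding")

    # 2) Nearest-neighbour search in the LanceDB table.
    db = lancedb.connect(DB_PATH)
    tbl = db.open_table(table_name)
    return tbl.search(list(embedding)).limit(top_k).to_list()

# Example (requires the services above to be running):
# print(search_context("do you know RMIT?", "aisuko_squad01"))
```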
19 changes: 11 additions & 8 deletions backend/src/repository/rag_datasets_eng.py
@@ -29,23 +29,26 @@ def get_dataset_by_name(self, name: str):
pass

@classmethod
def load_dataset(cls, name: str) -> dict:
def validate_dataset(cls, name: str) -> list:
"""
Load dataset from the given name, must connect to the internet
No need to consider the memory usage
Validate the dataset by the given name
"""

ds = load_dataset(name)
# TODO: relying on the validation split doesn't make sense; it should be removed
ds_list = ds.get("validation").to_list()
for item in ds_list:
for key, value in list(item.items()):
if isinstance(value, list):
item['vector'] = item.pop(key)
name = DatasetFormatter.format_dataset_by_name(name) if name else None

# vector_db.create_collection(collection_name=name)
vector_db.create_table(table_name=name, data=[ds_list[0]])
return ds_list

@classmethod
def load_dataset(cls, name: str, ds_list: list) -> dict:
"""
Load dataset from the given name, must connect to the internet
No need to consider the memory usage
"""
name = DatasetFormatter.format_dataset_by_name(name) if name else None
return vector_db.insert_list(table_name=name, data_list=ds_list)
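
Splitting the old `load_dataset` into `validate_dataset` (download, normalise, create the table from the first record) and `load_dataset` (bulk insert) is what makes the background-task hand-off in the route possible. A hedged sketch of the intended call sequence, assuming the import path below:

```python
# Hypothetical call sequence mirroring the updated route logic.
from src.repository.rag_datasets_eng import DatasetEng  # assumed import path

dataset_name = "squad_v2"  # example Hugging Face dataset name only

# Synchronous step: fetch the dataset, rename list fields to "vector",
# and create the LanceDB table from the first record.
ds_list = DatasetEng.validate_dataset(dataset_name)

# Deferred step (run via fastapi.BackgroundTasks in the route):
# bulk-insert the full record list into the same table.
DatasetEng.load_dataset(dataset_name, ds_list)
```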
4 changes: 2 additions & 2 deletions backend/src/repository/vector_database.py
@@ -1,5 +1,4 @@
import loguru

from src.config.settings.const import DEFAULT_COLLECTION, DATASET_LANCEDB

import lancedb
@@ -22,8 +21,9 @@ def insert_list(self, table_name: str = DEFAULT_COLLECTION, data_list: list = []
try:
tbl = self.db.open_table(table_name)
tbl.add(data_list)
loguru.logger.info(f"Vector Databse --- Inserted {len(data_list)} records")
except Exception as e:
loguru.logger.info(f"Vector Databse --- Error: {e}")
loguru.logger.error(f"Vector Databse --- Error: {e}")

def search(self, data, n_results, table_name=DEFAULT_COLLECTION):
print(table_name)
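For reference, `insert_list` is a thin wrapper around opening a LanceDB table and appending records, now logging successes as well as errors. A standalone sketch with a made-up path, table name, and schema:

```python
# Hedged sketch of what insert_list wraps; path, table name, and schema are made up.
import lancedb
import loguru

db = lancedb.connect("/tmp/lancedb-demo")
table_name = "demo_table"

# Create the table from a single seed record so the schema exists.
seed = [{"id": 1, "vector": [0.1, 0.2], "context": "hello"}]
db.create_table(table_name, data=seed, mode="overwrite")

rows = [{"id": 2, "vector": [0.3, 0.4], "context": "world"}]
try:
    tbl = db.open_table(table_name)
    tbl.add(rows)
    loguru.logger.info(f"Vector Database --- Inserted {len(rows)} records")
except Exception as e:
    loguru.logger.error(f"Vector Database --- Error: {e}")
```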
