feat: add error handling and use EventSource for AI assistant #1265

Merged · 4 commits · Jun 12, 2023
querybook/server/datasources/ai_assistant.py (2 additions, 1 deletion)
@@ -5,9 +5,10 @@
from lib.ai_assistant import ai_assistant


@register("/ai/query_title/", methods=["POST"], custom_response=True)
@register("/ai/query_title/", custom_response=True)
def generate_query_title(query):
title_stream = ai_assistant.generate_title_from_query(
query=query, user_id=current_user.id
)

return Response(title_stream, mimetype="text/event-stream")
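The dropped methods=["POST"] is significant: the browser EventSource API can only issue GET requests, and the endpoint now returns a Server-Sent Events response. Below is a minimal, hedged sketch of the resulting endpoint shape in plain Flask; the app, the fake_title_stream generator, and the route registration are illustrative stand-ins, since the real endpoint goes through Querybook's @register decorator and streams from ai_assistant.

from flask import Flask, Response

app = Flask(__name__)


def fake_title_stream():
    # Each yielded chunk is a complete SSE frame ("data: ...\n\n"),
    # matching the frames EventStream produces later in this PR.
    yield "data: Top \n\n"
    yield "data: sellers by region\n\n"
    yield "event: close\ndata: \n\n"


@app.route("/ai/query_title/")  # GET by default, which is all EventSource can send
def generate_query_title():
    return Response(fake_title_stream(), mimetype="text/event-stream")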
querybook/server/lib/ai_assistant/ai_assistant.py (4 additions, 4 deletions)
@@ -1,7 +1,7 @@
import threading

from .all_ai_assistants import get_ai_assistant_class
from .base_ai_assistant import ThreadedGenerator, ChainStreamHandler
from .base_ai_assistant import ChainStreamHandler, EventStream


class AIAssistant:
@@ -10,12 +10,12 @@ def __init__(self, provider: str, config: dict = {}):
self._assisant.set_config(config)

def _get_streaming_result(self, fn, kwargs):
g = ThreadedGenerator()
callback_handler = ChainStreamHandler(g)
event_stream = EventStream()
callback_handler = ChainStreamHandler(event_stream)
kwargs["callback_handler"] = callback_handler
thread = threading.Thread(target=fn, kwargs=kwargs)
thread.start()
return g
return event_stream

def generate_title_from_query(self, query, user_id=None):
return self._get_streaming_result(
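_get_streaming_result now runs the LLM call on a worker thread and immediately returns the EventStream, so the HTTP response can start flushing frames before the model finishes. A hedged sketch of that producer/consumer handoff follows; fake_llm_call stands in for an assistant method, the import path is inferred from the file layout in this PR, and EventStream is assumed to be iterable as its __next__ method (shown later in this diff) implies.

import threading

# EventStream and ChainStreamHandler are the classes changed later in this PR.
from lib.ai_assistant.base_ai_assistant import ChainStreamHandler, EventStream


def fake_llm_call(callback_handler=None):
    # Stand-in for e.g. _generate_title_from_query: emit tokens, then finish.
    for token in ["Top ", "sellers ", "by region"]:
        callback_handler.on_llm_new_token(token)
    callback_handler.on_llm_end(response=None)


event_stream = EventStream()
handler = ChainStreamHandler(event_stream)
threading.Thread(target=fake_llm_call, kwargs={"callback_handler": handler}).start()

for frame in event_stream:  # blocks on the queue until the worker pushes each SSE frame
    print(frame, end="")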
querybook/server/lib/ai_assistant/assistants/openai_assistant.py (9 additions, 2 deletions)
@@ -8,6 +8,7 @@
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
import openai

LOG = get_logger(__file__)

@@ -21,6 +22,12 @@ class OpenAIAssistant(BaseAIAssistant):
def name(self) -> str:
return "openai"

def _get_error_msg(self, error) -> str:
if isinstance(error, openai.error.AuthenticationError):
return "Invalid OpenAI API key"

return super()._get_error_msg(error)

@property
def title_generation_prompt_template(self) -> str:
system_template = "You are a helpful assistant that can summerize SQL queries."
@@ -29,7 +36,7 @@ def title_generation_prompt_template(self) -> str:
)
human_template = (
"Generate a concise summary with no more than 8 words for the query below. "
"Only respond the title without any explanation or leading words.\n"
"Only respond the title without any explanation or final period.\n"
"```\n{query}\n```\nTitle:"
)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
@@ -40,7 +47,7 @@ def generate_sql_query(self):
def generate_sql_query(self):
pass

def generate_title_from_query(
def _generate_title_from_query(
self, query, stream=True, callback_handler=None, user_id=None
):
"""Generate title from SQL query using OpenAI's chat model."""
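The new _get_error_msg override is the hook that turns raw SDK exceptions into user-facing messages before they are streamed to the client. A hedged sketch of extending the same mapping follows; RateLimitError exists alongside AuthenticationError in the pre-1.0 openai SDK used here, but the extra case is illustrative and not part of this PR, and the logic is shown as a standalone function rather than a method.

import openai


def friendly_openai_error(error: Exception) -> str:
    # Same mapping idea as OpenAIAssistant._get_error_msg.
    if isinstance(error, openai.error.AuthenticationError):
        return "Invalid OpenAI API key"
    if isinstance(error, openai.error.RateLimitError):  # illustrative extra case
        return "OpenAI rate limit exceeded, please retry in a moment"
    return str(error)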
querybook/server/lib/ai_assistant/base_ai_assistant.py (64 additions, 14 deletions)
@@ -1,12 +1,18 @@
from abc import ABC, abstractmethod

import functools
import queue

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from pydantic.error_wrappers import ValidationError

from lib.logger import get_logger

LOG = get_logger(__file__)


class ThreadedGenerator:
"""Generator to facilitate streaming result from Langchain."""
class EventStream:
"""Generator to facilitate streaming result from Langchain.
The stream format is based on Server-Sent Events (SSE)."""

def __init__(self):
self.queue = queue.Queue()
@@ -20,26 +26,32 @@ def __next__(self):
raise item
return item

def send(self, data):
self.queue.put(data)
def send(self, data: str):
self.queue.put("data: " + data + "\n\n")

def close(self):
# the empty data is to make sure the client receives the close event
self.queue.put("event: close\ndata: \n\n")
self.queue.put(StopIteration)

def send_error(self, error: str):
self.queue.put("event: error\n")
self.queue.put(f"data: {error}\n\n")
self.close()


class ChainStreamHandler(StreamingStdOutCallbackHandler):
"""Callback handlder to stream the result to a generator."""

def __init__(self, gen: ThreadedGenerator):
def __init__(self, stream: EventStream):
super().__init__()
self.gen = gen
self.stream = stream

def on_llm_new_token(self, token: str, **kwargs):
self.gen.send(token)
self.stream.send(token)

def on_llm_end(self, response, **kwargs):
self.gen.send(StopIteration)
self.gen.close()
self.stream.close()


class BaseAIAssistant(ABC):
@@ -50,25 +62,63 @@ def name(self) -> str:
def set_config(self, config: dict):
self._config = config

def catch_error(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception as e:
LOG.error(e, exc_info=True)
err_msg = self._get_error_msg(e)
callback_handler = kwargs.get("callback_handler")
Collaborator: callback_handler is required, can you make it more explicit (not just in the except block)? It would also be useful to note that callback_handler must be passed as a kwarg, not a positional arg.

Collaborator (Author): Actually, callback_handler is not always required, only when the function is a streaming one, and catch_error also handles non-streaming functions. Will update.


if callback_handler:
callback_handler.stream.send_error(err_msg)
else:
raise Exception(err_msg) from e

return wrapper

def _get_error_msg(self, error) -> str:
"""Override this method to return specific error messages for your own assistant."""
if isinstance(error, ValidationError):
return error.errors()[0].get("msg")

return str(error.args[0])

@abstractmethod
def generate_sql_query(
self, metastore_id: int, query_engine_id: int, question: str, tables: list[str]
):
raise NotImplementedError()

@abstractmethod
@catch_error
def generate_title_from_query(
self,
query,
stream=False,
stream=True,
callback_handler: ChainStreamHandler = None,
user_id=None,
):
"""Generate title from SQL query.

Args:
query (str): SQL query
stream (bool, optional): Whether to stream the result. Defaults to False.
callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Defaults to None.
stream (bool, optional): Whether to stream the result. Defaults to True.
callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Required if stream is True.
"""
return self._generate_title_from_query(
query=query,
stream=stream,
callback_handler=callback_handler,
user_id=user_id,
)

@abstractmethod
def _generate_title_from_query(
self,
query,
stream,
callback_handler,
user_id=None,
):
raise NotImplementedError()
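Tying the pieces together: because catch_error looks up kwargs.get("callback_handler"), the handler must be passed as a keyword argument (the point raised in the review thread above), and when it is present a failure is delivered to the client as SSE error and close events instead of propagating as an exception. The hedged end-to-end sketch below uses a toy subclass; ToyAssistant is illustrative, the import path is inferred from the file layout, and it assumes BaseAIAssistant needs no constructor arguments and has no abstract methods beyond those visible in this diff.

from lib.ai_assistant.base_ai_assistant import (
    BaseAIAssistant,
    ChainStreamHandler,
    EventStream,
)


class ToyAssistant(BaseAIAssistant):
    @property
    def name(self) -> str:
        return "toy"

    def generate_sql_query(self, metastore_id, query_engine_id, question, tables):
        pass

    def _generate_title_from_query(self, query, stream, callback_handler, user_id=None):
        raise ValueError("model unavailable")  # simulate a provider failure


stream = EventStream()
handler = ChainStreamHandler(stream)

# callback_handler must be passed by keyword: catch_error reads kwargs.get("callback_handler").
ToyAssistant().generate_title_from_query(query="SELECT 1", callback_handler=handler)

for frame in stream:  # the exception became SSE frames instead of being re-raised
    print(frame, end="")
# prints the "event: error" frame carrying "model unavailable", then the "event: close" frame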
querybook/webapp/lib/datasource.ts (24 additions, 29 deletions)
@@ -157,44 +157,39 @@ export function uploadDatasource<T = null>(
}

/**
* Stream data from a datasource.
* Stream data from a datasource using EventSource
*
* @param url The url to stream from
* @param data The data to send to the url
* @param params The data to send to the url
* @param onStraming Callback when data is received. The data is the accumulated data.
* @param onStramingEnd Callback when the stream ends
*/
async function streamDatasource(
function streamDatasource(
url: string,
data?: Record<string, unknown>,
onStraming?: (data: string) => void,
onStramingEnd?: () => void
params?: Record<string, unknown>,
onStreaming?: (data: string) => void,
onStreamingEnd?: () => void
) {
const resp = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json; charset=utf-8',
},
body: JSON.stringify(data),
});
if (resp.status !== 200) {
console.error(resp);
return;
}
const decoder = new TextDecoder();
const reader = resp.body.getReader();
const eventSource = new EventSource(
`${url}?params=${JSON.stringify(params)}`
);
let dataStream = '';
while (true) {
const { done, value } = await reader.read();
if (done) {
onStramingEnd?.();
break;
eventSource.addEventListener('message', (e) => {
dataStream += e.data;
onStreaming?.(dataStream);
});
eventSource.addEventListener('error', (e) => {
console.error(e);
eventSource.close();
onStreamingEnd?.();
if (e instanceof MessageEvent) {
toast.error(e.data);
}
dataStream += decoder.decode(value);
onStraming?.(dataStream);
}

return dataStream;
});
eventSource.addEventListener('close', (e) => {
eventSource.close();
onStreamingEnd?.();
});
}

export default {