Skip to content

Commit

Permalink
feat: add error handling and use EventSource for AI assistant (pinter…
Browse files Browse the repository at this point in the history
…est#1265)

* feat: add error handling and use EventSource

* fix linter

* address comments

* raise from
  • Loading branch information
jczhong84 authored and aidenprice committed Jan 3, 2024
1 parent 6219e2e commit e086b76
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 50 deletions.
3 changes: 2 additions & 1 deletion querybook/server/datasources/ai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from lib.ai_assistant import ai_assistant


@register("/ai/query_title/", methods=["POST"], custom_response=True)
@register("/ai/query_title/", custom_response=True)
def generate_query_title(query):
title_stream = ai_assistant.generate_title_from_query(
query=query, user_id=current_user.id
)

return Response(title_stream, mimetype="text/event-stream")
8 changes: 4 additions & 4 deletions querybook/server/lib/ai_assistant/ai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import threading

from .all_ai_assistants import get_ai_assistant_class
from .base_ai_assistant import ThreadedGenerator, ChainStreamHandler
from .base_ai_assistant import ChainStreamHandler, EventStream


class AIAssistant:
Expand All @@ -10,12 +10,12 @@ def __init__(self, provider: str, config: dict = {}):
self._assisant.set_config(config)

def _get_streaming_result(self, fn, kwargs):
g = ThreadedGenerator()
callback_handler = ChainStreamHandler(g)
event_stream = EventStream()
callback_handler = ChainStreamHandler(event_stream)
kwargs["callback_handler"] = callback_handler
thread = threading.Thread(target=fn, kwargs=kwargs)
thread.start()
return g
return event_stream

def generate_title_from_query(self, query, user_id=None):
return self._get_streaming_result(
Expand Down
11 changes: 9 additions & 2 deletions querybook/server/lib/ai_assistant/assistants/openai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
import openai

LOG = get_logger(__file__)

Expand All @@ -21,6 +22,12 @@ class OpenAIAssistant(BaseAIAssistant):
def name(self) -> str:
return "openai"

def _get_error_msg(self, error) -> str:
if isinstance(error, openai.error.AuthenticationError):
return "Invalid OpenAI API key"

return super()._get_error_msg(error)

@property
def title_generation_prompt_template(self) -> str:
system_template = "You are a helpful assistant that can summerize SQL queries."
Expand All @@ -29,7 +36,7 @@ def title_generation_prompt_template(self) -> str:
)
human_template = (
"Generate a concise summary with no more than 8 words for the query below. "
"Only respond the title without any explanation or leading words.\n"
"Only respond the title without any explanation or final period.\n"
"```\n{query}\n```\nTitle:"
)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
Expand All @@ -40,7 +47,7 @@ def title_generation_prompt_template(self) -> str:
def generate_sql_query(self):
pass

def generate_title_from_query(
def _generate_title_from_query(
self, query, stream=True, callback_handler=None, user_id=None
):
"""Generate title from SQL query using OpenAI's chat model."""
Expand Down
78 changes: 64 additions & 14 deletions querybook/server/lib/ai_assistant/base_ai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from abc import ABC, abstractmethod

import functools
import queue

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from pydantic.error_wrappers import ValidationError

from lib.logger import get_logger

LOG = get_logger(__file__)


class ThreadedGenerator:
"""Generator to facilitate streaming result from Langchain."""
class EventStream:
"""Generator to facilitate streaming result from Langchain.
The stream format is based on Server-Sent Events (SSE)."""

def __init__(self):
self.queue = queue.Queue()
Expand All @@ -20,26 +26,32 @@ def __next__(self):
raise item
return item

def send(self, data):
self.queue.put(data)
def send(self, data: str):
self.queue.put("data: " + data + "\n\n")

def close(self):
# the empty data is to make sure the client receives the close event
self.queue.put("event: close\ndata: \n\n")
self.queue.put(StopIteration)

def send_error(self, error: str):
self.queue.put("event: error\n")
self.queue.put(f"data: {error}\n\n")
self.close()


class ChainStreamHandler(StreamingStdOutCallbackHandler):
"""Callback handlder to stream the result to a generator."""

def __init__(self, gen: ThreadedGenerator):
def __init__(self, stream: EventStream):
super().__init__()
self.gen = gen
self.stream = stream

def on_llm_new_token(self, token: str, **kwargs):
self.gen.send(token)
self.stream.send(token)

def on_llm_end(self, response, **kwargs):
self.gen.send(StopIteration)
self.gen.close()
self.stream.close()


class BaseAIAssistant(ABC):
Expand All @@ -50,25 +62,63 @@ def name(self) -> str:
def set_config(self, config: dict):
self._config = config

def catch_error(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception as e:
LOG.error(e, exc_info=True)
err_msg = self._get_error_msg(e)
callback_handler = kwargs.get("callback_handler")
if callback_handler:
callback_handler.stream.send_error(err_msg)
else:
raise Exception(err_msg) from e

return wrapper

def _get_error_msg(self, error) -> str:
"""Override this method to return specific error messages for your own assistant."""
if isinstance(error, ValidationError):
return error.errors()[0].get("msg")

return str(error.args[0])

@abstractmethod
def generate_sql_query(
self, metastore_id: int, query_engine_id: int, question: str, tables: list[str]
):
raise NotImplementedError()

@abstractmethod
@catch_error
def generate_title_from_query(
self,
query,
stream=False,
stream=True,
callback_handler: ChainStreamHandler = None,
user_id=None,
):
"""Generate title from SQL query.
Args:
query (str): SQL query
stream (bool, optional): Whether to stream the result. Defaults to False.
callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Defaults to None.
stream (bool, optional): Whether to stream the result. Defaults to True.
callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Required if stream is True.
"""
return self._generate_title_from_query(
query=query,
stream=stream,
callback_handler=callback_handler,
user_id=user_id,
)

@abstractmethod
def _generate_title_from_query(
self,
query,
stream,
callback_handler,
user_id=None,
):
raise NotImplementedError()
53 changes: 24 additions & 29 deletions querybook/webapp/lib/datasource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,44 +157,39 @@ export function uploadDatasource<T = null>(
}

/**
* Stream data from a datasource.
* Stream data from a datasource using EventSource
*
* @param url The url to stream from
* @param data The data to send to the url
* @param params The data to send to the url
* @param onStraming Callback when data is received. The data is the accumulated data.
* @param onStramingEnd Callback when the stream ends
*/
async function streamDatasource(
function streamDatasource(
url: string,
data?: Record<string, unknown>,
onStraming?: (data: string) => void,
onStramingEnd?: () => void
params?: Record<string, unknown>,
onStreaming?: (data: string) => void,
onStreamingEnd?: () => void
) {
const resp = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json; charset=utf-8',
},
body: JSON.stringify(data),
});
if (resp.status !== 200) {
console.error(resp);
return;
}
const decoder = new TextDecoder();
const reader = resp.body.getReader();
const eventSource = new EventSource(
`${url}?params=${JSON.stringify(params)}`
);
let dataStream = '';
while (true) {
const { done, value } = await reader.read();
if (done) {
onStramingEnd?.();
break;
eventSource.addEventListener('message', (e) => {
dataStream += e.data;
onStreaming?.(dataStream);
});
eventSource.addEventListener('error', (e) => {
console.error(e);
eventSource.close();
onStreamingEnd?.();
if (e instanceof MessageEvent) {
toast.error(e.data);
}
dataStream += decoder.decode(value);
onStraming?.(dataStream);
}

return dataStream;
});
eventSource.addEventListener('close', (e) => {
eventSource.close();
onStreamingEnd?.();
});
}

export default {
Expand Down

0 comments on commit e086b76

Please sign in to comment.