Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DH-5089/fixing_the_new_agent #286

Merged
merged 4 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ def create_prompt_and_sql_generation(
) -> SQLGenerationResponse:
pass

@abstractmethod
def get_sql_generations(
    self, prompt_id: str | None = None
) -> list[SQLGenerationResponse]:
    """Return SQL generations as a list of ``SQLGenerationResponse``.

    Abstract: concrete API classes supply the behaviour.

    Args:
        prompt_id: Optional prompt identifier; presumably used by
            implementations to narrow the result to a single prompt.
    """

@abstractmethod
def get_sql_generation(self, sql_generation_id: str) -> SQLGenerationResponse:
    """Return the single SQL generation identified by ``sql_generation_id``.

    Abstract: concrete API classes supply the behaviour.
    """

@abstractmethod
def create_nl_generation(
self, sql_generation_id: str, nl_generation_request: NLGenerationRequest
Expand All @@ -186,3 +196,13 @@ def create_prompt_sql_and_nl_generation(
nl_generation: NLGenerationRequest,
) -> NLGenerationResponse:
pass

@abstractmethod
def get_nl_generations(
    self, sql_generation_id: str | None = None
) -> list[NLGenerationResponse]:
    """Return NL generations as a list of ``NLGenerationResponse``.

    Abstract: concrete API classes supply the behaviour.

    Args:
        sql_generation_id: Optional SQL generation identifier; presumably
            used by implementations to narrow the result to one SQL
            generation.
    """

@abstractmethod
def get_nl_generation(self, nl_generation_id: str) -> NLGenerationResponse:
    """Return the single NL generation identified by ``nl_generation_id``.

    Abstract: concrete API classes supply the behaviour.
    """
136 changes: 93 additions & 43 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,10 @@ def get_query_history(self, db_connection_id: str) -> list[QueryHistory]:
def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL]:
"""Takes in a list of NL <> SQL pairs and stores them to be used in prompts to the LLM"""
context_store = self.system.instance(ContextStore)
return context_store.add_golden_sqls(golden_sqls)
golden_sqls = context_store.add_golden_sqls(golden_sqls)
for golden_sql in golden_sqls:
golden_sql.created_at = str(golden_sql.created_at)
return golden_sqls

@override
def execute_sql_query(self, query: Query) -> tuple[str, dict]:
Expand Down Expand Up @@ -521,11 +524,9 @@ def create_sql_generation(
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except PromptNotFoundError as e:
raise HTTPException(status_code=404, detail="Prompt not found") from e
raise HTTPException(status_code=404, detail=str(e)) from e
except SQLGenerationError as e:
raise HTTPException(
status_code=400, detail="Error raised during SQL generation"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e
sql_generation_dict = sql_generation.dict()
sql_generation_dict["created_at"] = str(sql_generation.created_at)
sql_generation_dict["completed_at"] = str(sql_generation.completed_at)
Expand All @@ -540,27 +541,59 @@ def create_prompt_and_sql_generation(
prompt = prompt_service.create(prompt)
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except DatabaseConnectionNotFoundError:
raise HTTPException( # noqa: B904
status_code=404, detail="Database connection not found"
)
except DatabaseConnectionNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) from e

sql_generation_service = SQLGenerationService(self.system, self.storage)
try:
sql_generation = sql_generation_service.create(prompt.id, sql_generation)
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except PromptNotFoundError as e:
raise HTTPException(status_code=404, detail="Prompt not found") from e
raise HTTPException(status_code=404, detail=str(e)) from e
except SQLGenerationError as e:
raise HTTPException(
status_code=400, detail="Error raised during SQL generation"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e
sql_generation_dict = sql_generation.dict()
sql_generation_dict["created_at"] = str(sql_generation.created_at)
sql_generation_dict["completed_at"] = str(sql_generation.completed_at)
return SQLGenerationResponse(**sql_generation_dict)

@override
def get_sql_generations(
    self, prompt_id: str | None = None
) -> list[SQLGenerationResponse]:
    """List SQL generations, optionally restricted to one prompt.

    Args:
        prompt_id: When truthy, the service is queried with
            ``{"prompt_id": prompt_id}``; otherwise with an empty filter.

    Returns:
        One ``SQLGenerationResponse`` per record returned by the service,
        with ``created_at`` and ``completed_at`` converted via ``str()``.
    """
    service = SQLGenerationService(self.system, self.storage)
    filters = {"prompt_id": prompt_id} if prompt_id else {}
    return [
        SQLGenerationResponse(
            **{
                **record.dict(),
                # The response model declares these timestamps as strings.
                "created_at": str(record.created_at),
                "completed_at": str(record.completed_at),
            }
        )
        for record in service.get(filters)
    ]

@override
def get_sql_generation(self, sql_generation_id: str) -> SQLGenerationResponse:
    """Fetch one SQL generation by its id.

    Args:
        sql_generation_id: String form of the record's ``_id``.

    Returns:
        The first matching record as a ``SQLGenerationResponse``, with
        ``created_at`` and ``completed_at`` converted via ``str()``.

    Raises:
        HTTPException: 400 when ``InvalidId`` is raised while building the
            ``ObjectId`` or running the lookup; 404 when nothing matches.
    """
    service = SQLGenerationService(self.system, self.storage)
    try:
        lookup = {"_id": ObjectId(sql_generation_id)}
        matches = service.get(lookup)
    except InvalidId as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    if not matches:
        raise HTTPException(
            status_code=404, detail=f"SQL Generation {sql_generation_id} not found"
        )

    found = matches[0]
    payload = found.dict()
    # The response model declares these timestamps as strings.
    payload["created_at"] = str(found.created_at)
    payload["completed_at"] = str(found.completed_at)
    return SQLGenerationResponse(**payload)

@override
def create_nl_generation(
self, sql_generation_id: str, nl_generation_request: NLGenerationRequest
Expand All @@ -573,13 +606,9 @@ def create_nl_generation(
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except SQLGenerationNotFoundError as e:
raise HTTPException(
status_code=400, detail="SQL Generation not found"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e
except NLGenerationError as e:
raise HTTPException(
status_code=400, detail="Error raised during NL generation"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e
nl_generation_dict = nl_generation.dict()
nl_generation_dict["created_at"] = str(nl_generation.created_at)
return NLGenerationResponse(**nl_generation_dict)
Expand All @@ -597,11 +626,9 @@ def create_sql_and_nl_generation(
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except PromptNotFoundError as e:
raise HTTPException(status_code=404, detail="Prompt not found") from e
raise HTTPException(status_code=404, detail=str(e)) from e
except SQLGenerationError as e:
raise HTTPException(
status_code=400, detail="Error raised during SQL generation"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e

nl_generation_service = NLGenerationService(self.system, self.storage)
try:
Expand All @@ -611,13 +638,9 @@ def create_sql_and_nl_generation(
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except SQLGenerationNotFoundError as e:
raise HTTPException(
status_code=400, detail="SQL Generation not found"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e
except NLGenerationError as e:
raise HTTPException(
status_code=400, detail="Error raised during NL generation"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e
nl_generation_dict = nl_generation.dict()
nl_generation_dict["created_at"] = str(nl_generation.created_at)
return NLGenerationResponse(**nl_generation_dict)
Expand All @@ -634,22 +657,18 @@ def create_prompt_sql_and_nl_generation(
prompt = prompt_service.create(prompt)
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except DatabaseConnectionNotFoundError:
raise HTTPException( # noqa: B904
status_code=404, detail="Database connection not found"
)
except DatabaseConnectionNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) from e # noqa: B904

sql_generation_service = SQLGenerationService(self.system, self.storage)
try:
sql_generation = sql_generation_service.create(prompt.id, sql_generation)
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except PromptNotFoundError as e:
raise HTTPException(status_code=404, detail="Prompt not found") from e
raise HTTPException(status_code=404, detail=str(e)) from e
except SQLGenerationError as e:
raise HTTPException(
status_code=400, detail="Error raised during SQL generation"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e

nl_generation_service = NLGenerationService(self.system, self.storage)
try:
Expand All @@ -659,13 +678,44 @@ def create_prompt_sql_and_nl_generation(
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except SQLGenerationNotFoundError as e:
raise HTTPException(
status_code=400, detail="SQL Generation not found"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e
except NLGenerationError as e:
raise HTTPException(
status_code=400, detail="Error raised during NL generation"
) from e
raise HTTPException(status_code=400, detail=str(e)) from e
nl_generation_dict = nl_generation.dict()
nl_generation_dict["created_at"] = str(nl_generation.created_at)
return NLGenerationResponse(**nl_generation_dict)

@override
def get_nl_generations(
    self, sql_generation_id: str | None = None
) -> list[NLGenerationResponse]:
    """List NL generations, optionally restricted to one SQL generation.

    Args:
        sql_generation_id: When truthy, the service is queried with
            ``{"sql_generation_id": sql_generation_id}``; otherwise with
            an empty filter.

    Returns:
        One ``NLGenerationResponse`` per record returned by the service,
        with ``created_at`` converted via ``str()``.
    """
    service = NLGenerationService(self.system, self.storage)
    filters = {"sql_generation_id": sql_generation_id} if sql_generation_id else {}
    return [
        NLGenerationResponse(
            **{**record.dict(), "created_at": str(record.created_at)}
        )
        for record in service.get(filters)
    ]

@override
def get_nl_generation(self, nl_generation_id: str) -> NLGenerationResponse:
    """Fetch one NL generation by its id.

    Args:
        nl_generation_id: String form of the record's ``_id``.

    Returns:
        The first matching record as an ``NLGenerationResponse``, with
        ``created_at`` converted via ``str()``.

    Raises:
        HTTPException: 400 when ``InvalidId`` is raised while building the
            ``ObjectId`` or running the lookup; 404 when nothing matches.
    """
    service = NLGenerationService(self.system, self.storage)
    try:
        lookup = {"_id": ObjectId(nl_generation_id)}
        matches = service.get(lookup)
    except InvalidId as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    if not matches:
        raise HTTPException(
            status_code=404,
            detail=f"NL Generation {nl_generation_id} not found",
        )

    found = matches[0]
    payload = found.dict()
    # The response model declares this timestamp as a string.
    payload["created_at"] = str(found.created_at)
    return NLGenerationResponse(**payload)
4 changes: 2 additions & 2 deletions dataherald/api/types/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ class PromptRequest(BaseModel):


class SQLGenerationRequest(BaseModel):
model: str | None
finetuning_id: str | None
evaluate: bool = False
sql: str | None
metadata: dict | None


class NLGenerationRequest(BaseModel):
max_rows: int
max_rows: int = 100
metadata: dict | None
2 changes: 1 addition & 1 deletion dataherald/api/types/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class PromptResponse(BaseModel):
class SQLGenerationResponse(BaseModel):
id: str
prompt_id: str
model: str | None
finetuning_id: str | None
status: str
completed_at: str
sql: str | None
Expand Down
10 changes: 5 additions & 5 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,18 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
golden_sqls_repository = GoldenSQLRepository(self.db)
retruned_golden_sqls = []
for record in golden_sqls:
tables = Parser(record.sql_query).tables
question = record.question
tables = Parser(record.sql).tables
prompt_text = record.prompt_text
golden_sql = GoldenSQL(
question=question,
sql_query=record.sql_query,
prompt_text=prompt_text,
sql=record.sql,
db_connection_id=record.db_connection_id,
metadata=record.metadata,
)
retruned_golden_sqls.append(golden_sql)
golden_sql = golden_sqls_repository.insert(golden_sql)
self.vector_store.add_record(
documents=question,
documents=prompt_text,
db_connection_id=record.db_connection_id,
collection=self.golden_sql_collection,
metadata=[
Expand Down
4 changes: 2 additions & 2 deletions dataherald/finetuning/openai_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def create_fintuning_dataset(self):
model = model_repository.find_by_id(self.fine_tuning_model.id)
for golden_sql_id in self.fine_tuning_model.golden_sqls:
golden_sql = golden_sqls_repository.find_by_id(golden_sql_id)
question = golden_sql.question
query = golden_sql.sql_query
question = golden_sql.prompt_text
query = golden_sql.sql
system_prompt = FINETUNING_SYSTEM_INFORMATION + database_schema
user_prompt = "User Question: " + question + "\n SQL: "
assistant_prompt = query + "\n"
Expand Down
44 changes: 44 additions & 0 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,20 @@ def __init__(self, settings: Settings):
tags=["SQL Generation"],
)

self.router.add_api_route(
"/api/v1/sql-generations",
self.get_sql_generations,
methods=["GET"],
tags=["SQL Generation"],
)

self.router.add_api_route(
"/api/v1/sql-generations/{sql_generation_id}",
self.get_sql_generation,
methods=["GET"],
tags=["SQL Generation"],
)

self.router.add_api_route(
"/api/v1/sql-generations/{sql_generation_id}/nl-generations",
self.create_nl_generation,
Expand All @@ -198,6 +212,20 @@ def __init__(self, settings: Settings):
tags=["NL Generation"],
)

self.router.add_api_route(
"/api/v1/nl-generations",
self.get_nl_generations,
methods=["GET"],
tags=["NL Generation"],
)

self.router.add_api_route(
"/api/v1/nl-generations/{nl_generation_id}",
self.get_nl_generation,
methods=["GET"],
tags=["NL Generation"],
)

self.router.add_api_route(
"/api/v1/sql-query-executions",
self.execute_sql_query,
Expand Down Expand Up @@ -291,6 +319,14 @@ def create_prompt_and_sql_generation(
) -> SQLGenerationResponse:
return self._api.create_prompt_and_sql_generation(prompt, sql_generation)

def get_sql_generations(
    self, prompt_id: str | None = None
) -> list[SQLGenerationResponse]:
    """Route handler: forward the optional ``prompt_id`` to the API layer.

    Returns whatever ``self._api.get_sql_generations`` returns.
    """
    sql_generations = self._api.get_sql_generations(prompt_id)
    return sql_generations

def get_sql_generation(self, sql_generation_id: str) -> SQLGenerationResponse:
    """Route handler: forward ``sql_generation_id`` to the API layer.

    Returns whatever ``self._api.get_sql_generation`` returns.
    """
    sql_generation = self._api.get_sql_generation(sql_generation_id)
    return sql_generation

def create_nl_generation(
self, sql_generation_id: str, nl_generation_request: NLGenerationRequest
) -> NLGenerationResponse:
Expand All @@ -316,6 +352,14 @@ def create_prompt_sql_and_nl_generation(
prompt, sql_generation, nl_generation
)

def get_nl_generations(
    self, sql_generation_id: str | None = None
) -> list[NLGenerationResponse]:
    """Route handler: forward the optional ``sql_generation_id`` to the API layer.

    Returns whatever ``self._api.get_nl_generations`` returns.
    """
    nl_generations = self._api.get_nl_generations(sql_generation_id)
    return nl_generations

def get_nl_generation(self, nl_generation_id: str) -> NLGenerationResponse:
    """Route handler: forward ``nl_generation_id`` to the API layer.

    Returns whatever ``self._api.get_nl_generation`` returns.
    """
    nl_generation = self._api.get_nl_generation(nl_generation_id)
    return nl_generation

def root(self) -> dict[str, int]:
    """Return the API heartbeat under the ``"nanosecond heartbeat"`` key."""
    heartbeat = self._api.heartbeat()
    return {"nanosecond heartbeat": heartbeat}

Expand Down
3 changes: 3 additions & 0 deletions dataherald/services/nl_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ def create(
raise NLGenerationError(e) from e
nl_generation.metadata = nl_generation_request.metadata
return self.nl_generation_repository.insert(nl_generation)

def get(self, query) -> list[NLGeneration]:
    """Look up NL generations matching ``query``.

    Args:
        query: Filter passed unchanged to the repository's ``find_by``.

    Returns:
        Whatever ``self.nl_generation_repository.find_by`` returns for it.
    """
    repository = self.nl_generation_repository
    return repository.find_by(query)
Loading