Skip to content

Commit

Permalink
DH-5082/nl_generations implementation (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza authored Dec 18, 2023
1 parent 5affa1d commit 11ea83b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 8 deletions.
4 changes: 4 additions & 0 deletions dataherald/repositories/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
DB_COLLECTION = "sql_generations"


class SQLGenerationNotFoundError(Exception):
pass


class SQLGenerationRepository:
def __init__(self, storage):
self.storage = storage
Expand Down
26 changes: 23 additions & 3 deletions dataherald/services/nl_generations.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
from dataherald.api.types.requests import NLGenerationRequest
from dataherald.config import System
from dataherald.repositories.nl_generations import NLGenerationRepository
from dataherald.repositories.sql_generations import (
SQLGenerationNotFoundError,
SQLGenerationRepository,
)
from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer
from dataherald.types import NLGeneration


class NLGenerationService:
def __init__(self, storage):
def __init__(self, system: System, storage):
self.system = system
self.storage = storage
self.nl_generation_repository = NLGenerationRepository(storage)

def create(
self, prompt_id: str, nl_generation_request: NLGenerationRequest
self, sql_generation_id: str, nl_generation_request: NLGenerationRequest
) -> NLGeneration:
pass
sql_generation_repository = SQLGenerationRepository(self.storage)
sql_generation = sql_generation_repository.find_by_id(sql_generation_id)
if not sql_generation:
raise SQLGenerationNotFoundError(
f"SQL Generation {sql_generation_id} not found"
)
nl_generator = GeneratesNlAnswer(self.system, self.storage)
nl_generation = nl_generator.execute(
sql_generation=sql_generation,
top_k=nl_generation_request.max_rows,
)
nl_generation.metadata = nl_generation_request.metadata
return self.nl_generation_repository.insert(nl_generation)
2 changes: 1 addition & 1 deletion dataherald/services/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def create(
completed_at=datetime.now(),
tokens_used=0,
created_at=datetime.now(),
metadata=sql_generation_request.metadata,
)
sql_generation = create_sql_query_status(
db=database, query=sql_generation.sql, sql_generation=sql_generation
Expand All @@ -72,4 +71,5 @@ def create(
)
sql_generation.evaluate = sql_generation_request.evaluate
sql_generation.confidence_score = confidence_score
sql_generation.metadata = sql_generation_request.metadata
return self.sql_generation_repository.insert(sql_generation)
8 changes: 4 additions & 4 deletions dataherald/sql_generator/generates_nl_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""

HUMAN_TEMPLATE = """ Answer the question given the sql query and the sql query result.
Question: {question}
Question: {prompt}
SQL query: {sql_query}
SQL query result: {sql_query_result}
"""
Expand All @@ -39,11 +39,11 @@ def execute(
top_k: int = 100,
) -> NLGeneration:
prompt_repository = PromptRepository(self.storage)
question = prompt_repository.find_by_id(sql_generation.prompt_id)
prompt = prompt_repository.find_by_id(sql_generation.prompt_id)

db_connection_repository = DatabaseConnectionRepository(self.storage)
database_connection = db_connection_repository.find_by_id(
question.db_connection_id
prompt.db_connection_id
)
self.llm = self.model.get_model(
database_connection=database_connection,
Expand Down Expand Up @@ -95,7 +95,7 @@ def execute(
)
chain = LLMChain(llm=self.llm, prompt=chat_prompt)
nl_resp = chain.run(
question=question.question,
prompt=prompt.text,
sql_query=sql_generation.sql,
sql_query_result="\n".join([str(row) for row in rows]),
)
Expand Down

0 comments on commit 11ea83b

Please sign in to comment.