Skip to content

Commit

Permalink
DH-5099/save sql_generation on initial then update (#288)
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza authored Dec 19, 2023
1 parent d27b39d commit 576dfa1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
10 changes: 10 additions & 0 deletions dataherald/repositories/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def insert(self, sql_generation: SQLGeneration) -> SQLGeneration:
)
return sql_generation

def update(self, sql_generation: SQLGeneration) -> SQLGeneration:
sql_generation_dict = sql_generation.dict(exclude={"id"})
sql_generation_dict["prompt_id"] = str(sql_generation.prompt_id)
self.storage.update_or_create(
DB_COLLECTION,
{"_id": ObjectId(sql_generation.id)},
sql_generation_dict,
)
return sql_generation

def find_one(self, query: dict) -> SQLGeneration | None:
row = self.storage.find_one(DB_COLLECTION, query)
if not row:
Expand Down
23 changes: 16 additions & 7 deletions dataherald/services/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def __init__(self, system: System, storage):
def create(
self, prompt_id: str, sql_generation_request: SQLGenerationRequest
) -> SQLGeneration:
initial_sql_generation = SQLGeneration(
prompt_id=prompt_id,
created_at=datetime.now(),
)
self.sql_generation_repository.insert(initial_sql_generation)
prompt_repository = PromptRepository(self.storage)
prompt = prompt_repository.find_by_id(prompt_id)
if not prompt:
Expand All @@ -43,11 +48,8 @@ def create(
database = SQLDatabase.get_sql_engine(db_connection)
if sql_generation_request.sql is not None:
sql_generation = SQLGeneration(
prompt_id=prompt_id,
sql=sql_generation_request.sql,
tokens_used=0,
created_at=datetime.now(),
completed_at=datetime.now(),
)
sql_generation = create_sql_query_status(
db=database, query=sql_generation.sql, sql_generation=sql_generation
Expand All @@ -61,6 +63,9 @@ def create(
else:
sql_generator = DataheraldFinetuningAgent(self.system)
sql_generator.finetuning_id = sql_generation_request.finetuning_id
initial_sql_generation.finetuning_id = (
sql_generation_request.finetuning_id
)
try:
sql_generation = sql_generator.generate_response(
user_prompt=prompt, database_connection=db_connection
Expand All @@ -74,10 +79,14 @@ def create(
sql_generation=sql_generation,
database_connection=db_connection,
)
sql_generation.evaluate = sql_generation_request.evaluate
sql_generation.confidence_score = confidence_score
sql_generation.metadata = sql_generation_request.metadata
return self.sql_generation_repository.insert(sql_generation)
initial_sql_generation.evaluate = sql_generation_request.evaluate
initial_sql_generation.confidence_score = confidence_score
initial_sql_generation.sql = sql_generation.sql
initial_sql_generation.tokens_used = sql_generation.tokens_used
initial_sql_generation.completed_at = datetime.now()
initial_sql_generation.metadata = initial_sql_generation.metadata
initial_sql_generation.status = sql_generation.status
return self.sql_generation_repository.update(initial_sql_generation)

def get(self, query) -> list[SQLGeneration]:
return self.sql_generation_repository.find_by(query)
Expand Down

0 comments on commit 576dfa1

Please sign in to comment.