diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index e81f0be1..afc5f1bc 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -164,6 +164,16 @@ def create_prompt_and_sql_generation( ) -> SQLGenerationResponse: pass + @abstractmethod + def get_sql_generations( + self, prompt_id: str | None = None + ) -> list[SQLGenerationResponse]: + pass + + @abstractmethod + def get_sql_generation(self, sql_generation_id: str) -> SQLGenerationResponse: + pass + @abstractmethod def create_nl_generation( self, sql_generation_id: str, nl_generation_request: NLGenerationRequest @@ -186,3 +196,13 @@ def create_prompt_sql_and_nl_generation( nl_generation: NLGenerationRequest, ) -> NLGenerationResponse: pass + + @abstractmethod + def get_nl_generations( + self, sql_generation_id: str | None = None + ) -> list[NLGenerationResponse]: + pass + + @abstractmethod + def get_nl_generation(self, nl_generation_id: str) -> NLGenerationResponse: + pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 831f7f7b..d83761a5 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -335,7 +335,10 @@ def get_query_history(self, db_connection_id: str) -> list[QueryHistory]: def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL]: """Takes in a list of NL <> SQL pairs and stores them to be used in prompts to the LLM""" context_store = self.system.instance(ContextStore) - return context_store.add_golden_sqls(golden_sqls) + golden_sqls = context_store.add_golden_sqls(golden_sqls) + for golden_sql in golden_sqls: + golden_sql.created_at = str(golden_sql.created_at) + return golden_sqls @override def execute_sql_query(self, query: Query) -> tuple[str, dict]: @@ -521,11 +524,9 @@ def create_sql_generation( except InvalidId as e: raise HTTPException(status_code=400, detail=str(e)) from e except PromptNotFoundError as e: - raise HTTPException(status_code=404, detail="Prompt not found") from e + raise 
HTTPException(status_code=404, detail=str(e)) from e except SQLGenerationError as e: - raise HTTPException( - status_code=400, detail="Error raised during SQL generation" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e sql_generation_dict = sql_generation.dict() sql_generation_dict["created_at"] = str(sql_generation.created_at) sql_generation_dict["completed_at"] = str(sql_generation.completed_at) @@ -540,10 +541,8 @@ def create_prompt_and_sql_generation( prompt = prompt_service.create(prompt) except InvalidId as e: raise HTTPException(status_code=400, detail=str(e)) from e - except DatabaseConnectionNotFoundError: - raise HTTPException( # noqa: B904 - status_code=404, detail="Database connection not found" - ) + except DatabaseConnectionNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e sql_generation_service = SQLGenerationService(self.system, self.storage) try: @@ -551,16 +550,50 @@ def create_prompt_and_sql_generation( except InvalidId as e: raise HTTPException(status_code=400, detail=str(e)) from e except PromptNotFoundError as e: - raise HTTPException(status_code=404, detail="Prompt not found") from e + raise HTTPException(status_code=404, detail=str(e)) from e except SQLGenerationError as e: - raise HTTPException( - status_code=400, detail="Error raised during SQL generation" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e sql_generation_dict = sql_generation.dict() sql_generation_dict["created_at"] = str(sql_generation.created_at) sql_generation_dict["completed_at"] = str(sql_generation.completed_at) return SQLGenerationResponse(**sql_generation_dict) + @override + def get_sql_generations( + self, prompt_id: str | None = None + ) -> list[SQLGenerationResponse]: + sql_generation_service = SQLGenerationService(self.system, self.storage) + query = {} + if prompt_id: + query["prompt_id"] = prompt_id + sql_generations = sql_generation_service.get(query) + result = [] + for 
sql_generation in sql_generations: + sql_generation_dict = sql_generation.dict() + sql_generation_dict["created_at"] = str(sql_generation.created_at) + sql_generation_dict["completed_at"] = str(sql_generation.completed_at) + result.append(SQLGenerationResponse(**sql_generation_dict)) + return result + + @override + def get_sql_generation(self, sql_generation_id: str) -> SQLGenerationResponse: + sql_generation_service = SQLGenerationService(self.system, self.storage) + try: + sql_generations = sql_generation_service.get( + {"_id": ObjectId(sql_generation_id)} + ) + except InvalidId as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + if len(sql_generations) == 0: + raise HTTPException( + status_code=404, detail=f"SQL Generation {sql_generation_id} not found" + ) + sql_generation_dict = sql_generations[0].dict() + sql_generation_dict["created_at"] = str(sql_generations[0].created_at) + sql_generation_dict["completed_at"] = str(sql_generations[0].completed_at) + return SQLGenerationResponse(**sql_generation_dict) + @override def create_nl_generation( self, sql_generation_id: str, nl_generation_request: NLGenerationRequest @@ -573,13 +606,9 @@ def create_nl_generation( except InvalidId as e: raise HTTPException(status_code=400, detail=str(e)) from e except SQLGenerationNotFoundError as e: - raise HTTPException( - status_code=400, detail="SQL Generation not found" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e except NLGenerationError as e: - raise HTTPException( - status_code=400, detail="Error raised during NL generation" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e nl_generation_dict = nl_generation.dict() nl_generation_dict["created_at"] = str(nl_generation.created_at) return NLGenerationResponse(**nl_generation_dict) @@ -597,11 +626,9 @@ def create_sql_and_nl_generation( except InvalidId as e: raise HTTPException(status_code=400, detail=str(e)) from e except PromptNotFoundError as e: - raise 
HTTPException(status_code=404, detail="Prompt not found") from e + raise HTTPException(status_code=404, detail=str(e)) from e except SQLGenerationError as e: - raise HTTPException( - status_code=400, detail="Error raised during SQL generation" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e nl_generation_service = NLGenerationService(self.system, self.storage) try: @@ -611,13 +638,9 @@ def create_sql_and_nl_generation( except InvalidId as e: raise HTTPException(status_code=400, detail=str(e)) from e except SQLGenerationNotFoundError as e: - raise HTTPException( - status_code=400, detail="SQL Generation not found" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e except NLGenerationError as e: - raise HTTPException( - status_code=400, detail="Error raised during NL generation" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e nl_generation_dict = nl_generation.dict() nl_generation_dict["created_at"] = str(nl_generation.created_at) return NLGenerationResponse(**nl_generation_dict) @@ -634,10 +657,8 @@ def create_prompt_sql_and_nl_generation( prompt = prompt_service.create(prompt) except InvalidId as e: raise HTTPException(status_code=400, detail=str(e)) from e - except DatabaseConnectionNotFoundError: - raise HTTPException( # noqa: B904 - status_code=404, detail="Database connection not found" - ) + except DatabaseConnectionNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e # noqa: B904 sql_generation_service = SQLGenerationService(self.system, self.storage) try: @@ -645,11 +666,9 @@ def create_prompt_sql_and_nl_generation( except InvalidId as e: raise HTTPException(status_code=400, detail=str(e)) from e except PromptNotFoundError as e: - raise HTTPException(status_code=404, detail="Prompt not found") from e + raise HTTPException(status_code=404, detail=str(e)) from e except SQLGenerationError as e: - raise HTTPException( - status_code=400, detail="Error raised during 
SQL generation" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e nl_generation_service = NLGenerationService(self.system, self.storage) try: @@ -659,13 +678,44 @@ def create_prompt_sql_and_nl_generation( except InvalidId as e: raise HTTPException(status_code=400, detail=str(e)) from e except SQLGenerationNotFoundError as e: - raise HTTPException( - status_code=400, detail="SQL Generation not found" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e except NLGenerationError as e: - raise HTTPException( - status_code=400, detail="Error raised during NL generation" - ) from e + raise HTTPException(status_code=400, detail=str(e)) from e nl_generation_dict = nl_generation.dict() nl_generation_dict["created_at"] = str(nl_generation.created_at) return NLGenerationResponse(**nl_generation_dict) + + @override + def get_nl_generations( + self, sql_generation_id: str | None = None + ) -> list[NLGenerationResponse]: + nl_generation_service = NLGenerationService(self.system, self.storage) + query = {} + if sql_generation_id: + query["sql_generation_id"] = sql_generation_id + nl_generations = nl_generation_service.get(query) + result = [] + for nl_generation in nl_generations: + nl_generation_dict = nl_generation.dict() + nl_generation_dict["created_at"] = str(nl_generation.created_at) + result.append(NLGenerationResponse(**nl_generation_dict)) + return result + + @override + def get_nl_generation(self, nl_generation_id: str) -> NLGenerationResponse: + nl_generation_service = NLGenerationService(self.system, self.storage) + try: + nl_generations = nl_generation_service.get( + {"_id": ObjectId(nl_generation_id)} + ) + except InvalidId as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + if len(nl_generations) == 0: + raise HTTPException( + status_code=404, + detail=f"NL Generation {nl_generation_id} not found", + ) + nl_generation_dict = nl_generations[0].dict() + nl_generation_dict["created_at"] = 
str(nl_generations[0].created_at) + return NLGenerationResponse(**nl_generation_dict) diff --git a/dataherald/api/types/requests.py b/dataherald/api/types/requests.py index 331c09b7..a2738fd5 100644 --- a/dataherald/api/types/requests.py +++ b/dataherald/api/types/requests.py @@ -8,12 +8,12 @@ class PromptRequest(BaseModel): class SQLGenerationRequest(BaseModel): - model: str | None + finetuning_id: str | None evaluate: bool = False sql: str | None metadata: dict | None class NLGenerationRequest(BaseModel): - max_rows: int + max_rows: int = 100 metadata: dict | None diff --git a/dataherald/api/types/responses.py b/dataherald/api/types/responses.py index 8fc46d6c..0039d654 100644 --- a/dataherald/api/types/responses.py +++ b/dataherald/api/types/responses.py @@ -12,7 +12,7 @@ class PromptResponse(BaseModel): class SQLGenerationResponse(BaseModel): id: str prompt_id: str - model: str | None + finetuning_id: str | None status: str completed_at: str sql: str | None diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index 66aa00aa..1dcf24a7 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -64,18 +64,18 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL golden_sqls_repository = GoldenSQLRepository(self.db) retruned_golden_sqls = [] for record in golden_sqls: - tables = Parser(record.sql_query).tables - question = record.question + tables = Parser(record.sql).tables + prompt_text = record.prompt_text golden_sql = GoldenSQL( - question=question, - sql_query=record.sql_query, + prompt_text=prompt_text, + sql=record.sql, db_connection_id=record.db_connection_id, metadata=record.metadata, ) retruned_golden_sqls.append(golden_sql) golden_sql = golden_sqls_repository.insert(golden_sql) self.vector_store.add_record( - documents=question, + documents=prompt_text, db_connection_id=record.db_connection_id, collection=self.golden_sql_collection, metadata=[ diff --git 
a/dataherald/finetuning/openai_finetuning.py b/dataherald/finetuning/openai_finetuning.py index e36bfcc4..e7d5f52c 100644 --- a/dataherald/finetuning/openai_finetuning.py +++ b/dataherald/finetuning/openai_finetuning.py @@ -128,8 +128,8 @@ def create_fintuning_dataset(self): model = model_repository.find_by_id(self.fine_tuning_model.id) for golden_sql_id in self.fine_tuning_model.golden_sqls: golden_sql = golden_sqls_repository.find_by_id(golden_sql_id) - question = golden_sql.question - query = golden_sql.sql_query + question = golden_sql.prompt_text + query = golden_sql.sql system_prompt = FINETUNING_SYSTEM_INFORMATION + database_schema user_prompt = "User Question: " + question + "\n SQL: " assistant_prompt = query + "\n" diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 056b92ea..164598e2 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -174,6 +174,20 @@ def __init__(self, settings: Settings): tags=["SQL Generation"], ) + self.router.add_api_route( + "/api/v1/sql-generations", + self.get_sql_generations, + methods=["GET"], + tags=["SQL Generation"], + ) + + self.router.add_api_route( + "/api/v1/sql-generations/{sql_generation_id}", + self.get_sql_generation, + methods=["GET"], + tags=["SQL Generation"], + ) + self.router.add_api_route( "/api/v1/sql-generations/{sql_generation_id}/nl-generations", self.create_nl_generation, @@ -198,6 +212,20 @@ def __init__(self, settings: Settings): tags=["NL Generation"], ) + self.router.add_api_route( + "/api/v1/nl-generations", + self.get_nl_generations, + methods=["GET"], + tags=["NL Generation"], + ) + + self.router.add_api_route( + "/api/v1/nl-generations/{nl_generation_id}", + self.get_nl_generation, + methods=["GET"], + tags=["NL Generation"], + ) + self.router.add_api_route( "/api/v1/sql-query-executions", self.execute_sql_query, @@ -291,6 +319,14 @@ def create_prompt_and_sql_generation( ) -> SQLGenerationResponse: return 
self._api.create_prompt_and_sql_generation(prompt, sql_generation) + def get_sql_generations( + self, prompt_id: str | None = None + ) -> list[SQLGenerationResponse]: + return self._api.get_sql_generations(prompt_id) + + def get_sql_generation(self, sql_generation_id: str) -> SQLGenerationResponse: + return self._api.get_sql_generation(sql_generation_id) + def create_nl_generation( self, sql_generation_id: str, nl_generation_request: NLGenerationRequest ) -> NLGenerationResponse: @@ -316,6 +352,14 @@ def create_prompt_sql_and_nl_generation( prompt, sql_generation, nl_generation ) + def get_nl_generations( + self, sql_generation_id: str | None = None + ) -> list[NLGenerationResponse]: + return self._api.get_nl_generations(sql_generation_id) + + def get_nl_generation(self, nl_generation_id: str) -> NLGenerationResponse: + return self._api.get_nl_generation(nl_generation_id) + def root(self) -> dict[str, int]: return {"nanosecond heartbeat": self._api.heartbeat()} diff --git a/dataherald/services/nl_generations.py b/dataherald/services/nl_generations.py index f42e6f10..0b92f215 100644 --- a/dataherald/services/nl_generations.py +++ b/dataherald/services/nl_generations.py @@ -38,3 +38,6 @@ def create( raise NLGenerationError(e) from e nl_generation.metadata = nl_generation_request.metadata return self.nl_generation_repository.insert(nl_generation) + + def get(self, query) -> list[NLGeneration]: + return self.nl_generation_repository.find_by(query) diff --git a/dataherald/services/sql_generations.py b/dataherald/services/sql_generations.py index d8d1fcbf..eb0975ed 100644 --- a/dataherald/services/sql_generations.py +++ b/dataherald/services/sql_generations.py @@ -48,13 +48,13 @@ def create( ) else: # noqa: PLR5501 if ( - sql_generation_request.model is None - or sql_generation_request.model == "" + sql_generation_request.finetuning_id is None + or sql_generation_request.finetuning_id == "" ): sql_generator = DataheraldSQLAgent(self.system) else: sql_generator = 
DataheraldFinetuningAgent(self.system) - sql_generator.finetuned_llm_id = sql_generation_request.model + sql_generator.finetuning_id = sql_generation_request.finetuning_id try: sql_generation = sql_generator.generate_response( user_prompt=prompt, database_connection=db_connection @@ -72,3 +72,6 @@ def create( sql_generation.confidence_score = confidence_score sql_generation.metadata = sql_generation_request.metadata return self.sql_generation_repository.insert(sql_generation) + + def get(self, query) -> list[SQLGeneration]: + return self.sql_generation_repository.find_by(query) diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index d45b002d..856b2ea5 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -18,6 +18,7 @@ from langchain.chains.llm import LLMChain from langchain.schema import AgentAction from langchain.tools.base import BaseTool +from openai import OpenAI from overrides import override from pydantic import BaseModel, Field from sqlalchemy.exc import SQLAlchemyError @@ -26,6 +27,8 @@ from dataherald.db import DB from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus from dataherald.db_scanner.repository.base import TableDescriptionRepository +from dataherald.finetuning.openai_finetuning import OpenAIFineTuning +from dataherald.repositories.finetunings import FinetuningsRepository from dataherald.sql_database.base import SQLDatabase, SQLInjectionError from dataherald.sql_database.models.types import ( DatabaseConnection, @@ -188,82 +191,11 @@ class GenerateSQL(BaseSQLDatabaseTool, BaseTool): Use this tool to generate SQL queries. Pass the user question as the input to the tool. 
""" - finetuned_llm_id: str = Field(exclude=True) + finetuning_model_id: str = Field(exclude=True) args_schema: Type[BaseModel] = QuestionInput db_scan: List[TableDescription] - instructions: List[dict] | None = Field(exclude=True, default=None) api_key: str = Field(exclude=True) - def format_columns(self, table: TableDescription, top_k: int = 100) -> str: - """ - format_columns formats the columns. - - Args: - table: The table to format. - top_k: The number of categories to show. - - Returns: - The formatted columns in string format. - """ - columns_information = "" - for column in table.columns: - name = column.name - is_primary_key = column.is_primary_key - if is_primary_key: - primary_key_text = ( - f"this column is a primary key of the table {table.table_name}," - ) - else: - primary_key_text = "" - foreign_key = column.foreign_key - if foreign_key: - foreign_key_text = ( - f"this column has a foreign key to the table {foreign_key}," - ) - else: - foreign_key_text = "" - categories = column.categories - if categories: - if len(categories) <= top_k: - categories_text = f"Categories: {categories}," - else: - categories_text = "" - else: - categories_text = "" - if primary_key_text or foreign_key_text or categories_text: - columns_information += ( - f"{name}: {primary_key_text}{foreign_key_text}{categories_text}\n" - ) - return columns_information - - def format_database_schema( - self, db_scan: List[TableDescription], top_k: int = 100 - ) -> str: - """ - format_database_schema formats the database schema. - - Args: - db_scan: The database schema. - - Returns: - The formatted database schema in string format. 
- """ - schema_of_database = "" - for table in db_scan: - tables_schema = table.table_schema - schema_of_database += f"{tables_schema}\n" - schema_of_database += "# Categorical Columns:\n" - columns_information = self.format_columns(table, top_k) - schema_of_database += columns_information - sample_rows = table.examples - schema_of_database += "# Sample rows:\n" - for item in sample_rows: - for key, value in item.items(): - schema_of_database += f"{key}: {value}, " - schema_of_database += "\n" - schema_of_database += "\n\n" - return schema_of_database - @catch_exceptions() def _run( self, @@ -271,28 +203,20 @@ def _run( run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: """Execute the query, return the results or an error message.""" - system_prompt = FINETUNING_SYSTEM_INFORMATION + self.format_database_schema( + system_prompt = FINETUNING_SYSTEM_INFORMATION + OpenAIFineTuning.format_dataset( self.db_scan ) - if self.instructions: - user_prompt = "Database administrator rules that should be followed:\n" - for index, instruction in enumerate(self.instructions): - user_prompt += f"{index+1}) {instruction['instruction']}\n" - user_prompt += "\n\n" - user_prompt += "User Question: " + question - else: - user_prompt = "User Question: " + question - response = openai.ChatCompletion.create( - model=self.finetuned_llm_id, - api_key=self.api_key, + user_prompt = "User Question: " + question + "\n SQL: " + client = OpenAI(api_key=self.api_key) + response = client.chat.completions.create( + model=self.finetuning_model_id, temperature=0.0, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], ) - - return response.choices[0]["message"]["content"] + return response.choices[0].message.content async def _arun( self, @@ -347,7 +271,7 @@ class SQLDatabaseToolkit(BaseToolkit): instructions: List[dict] | None = Field(exclude=True, default=None) db_scan: List[TableDescription] = Field(exclude=True) api_key: 
str = Field(exclude=True) - finetuned_llm_id: str = Field(exclude=True) + finetuning_model_id: str = Field(exclude=True) @property def dialect(self) -> str: @@ -368,9 +292,8 @@ def get_tools(self) -> List[BaseTool]: GenerateSQL( db=self.db, db_scan=self.db_scan, - instructions=self.instructions, api_key=self.api_key, - finetuned_llm_id=self.finetuned_llm_id, + finetuning_model_id=self.finetuning_model_id, ) ) tools.append(SchemaSQLDatabaseTool(db=self.db, db_scan=self.db_scan)) @@ -384,7 +307,7 @@ class DataheraldFinetuningAgent(SQLGenerator): """ llm: Any = None - finetuned_llm_id: str = Field(exclude=True) + finetuning_id: str = Field(exclude=True) def create_sql_agent( self, @@ -404,8 +327,9 @@ def create_sql_agent( ) -> AgentExecutor: tools = toolkit.get_tools() admin_instructions = "" - for index, instruction in enumerate(toolkit.instructions): - admin_instructions += f"{index+1}) {instruction['instruction']}\n" + if toolkit.instructions: + for index, instruction in enumerate(toolkit.instructions): + admin_instructions += f"{index+1}) {instruction['instruction']}\n" prefix = prefix.format( dialect=toolkit.dialect, admin_instructions=admin_instructions ) @@ -458,6 +382,7 @@ def generate_response( response = SQLGeneration( prompt_id=user_prompt.id, created_at=datetime.datetime.now(), + finetuning_id=self.finetuning_id, ) self.llm = self.model.get_model( database_connection=database_connection, @@ -476,6 +401,8 @@ def generate_response( _, instructions = context_store.retrieve_context_for_question( user_prompt, number_of_samples=1 ) + finetunings_repository = FinetuningsRepository(storage) + finetuning = finetunings_repository.find_by_id(self.finetuning_id) self.database = SQLDatabase.get_sql_engine(database_connection) toolkit = SQLDatabaseToolkit( @@ -483,7 +410,7 @@ def generate_response( instructions=instructions, db_scan=db_scan, api_key=database_connection.decrypt_api_key(), - finetuned_llm_id=self.finetuned_llm_id, + 
finetuning_model_id=finetuning.model_id, ) agent_executor = self.create_sql_agent( toolkit=toolkit, @@ -504,12 +431,13 @@ def generate_response( return SQLGeneration( prompt_id=user_prompt.id, tokens_used=cb.total_tokens, - model="FineTuning_Agent", + finetuning_id=self.finetuning_id, completed_at=datetime.datetime.now(), sql="", status="INVALID", error=str(e), ) + sql_query = "" if "```sql" in result["output"]: sql_query = self.remove_markdown(result["output"]) else: diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index ced4dba5..224387c8 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -666,12 +666,12 @@ def generate_response( return SQLGeneration( prompt_id=user_prompt.id, tokens_used=cb.total_tokens, - model="RAG_AGENT", completed_at=datetime.datetime.now(), sql="", status="INVALID", error=str(e), ) + sql_query = "" if "```sql" in result["output"]: sql_query = self.remove_markdown(result["output"]) else: diff --git a/dataherald/types.py b/dataherald/types.py index 9c0daad8..1f1d495a 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -44,8 +44,8 @@ class Instruction(BaseModel): class GoldenSQLRequest(DBConnectionValidation): - question: str = Field(None, min_length=3) - sql_query: str = Field(None, min_length=3) + prompt_text: str = Field(None, min_length=3) + sql: str = Field(None, min_length=3) created_at: datetime = Field(default_factory=datetime.now) metadata: dict | None @@ -164,7 +164,7 @@ class Prompt(BaseModel): class SQLGeneration(BaseModel): id: str | None = None prompt_id: str - model: str | None + finetuning_id: str | None evaluate: bool = False sql: str | None status: str = "INVALID"