diff --git a/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/BUILD b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/BUILD
new file mode 100644
index 0000000000000..db46e8d6c978c
--- /dev/null
+++ b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/BUILD
@@ -0,0 +1 @@
+python_sources()
diff --git a/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/README.md b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/README.md
new file mode 100644
index 0000000000000..82136ca286341
--- /dev/null
+++ b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/README.md
@@ -0,0 +1,11 @@
+# LlamaIndex Experimental Natural Language Retrievers
+
+This experimental feature enables retrieving information with natural language from:
+
+- Pandas dataframes
+- CSV files
+- JSON objects
+
+Compared to other approaches, this uses [DuckDB](https://duckdb.org/) to run SQL queries instead of executing generated Python code. This matters because it avoids the security concerns of running arbitrary code. The DuckDB session is in-memory, and the original data cannot be altered by the retriever.
+
+The schema is also used to generate a description of the data set and what it can be used for. The description and ontology are then used to calculate a ranking score against the query bundle.
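+
+## Usage
+
+A minimal sketch of how one of these retrievers might be used. The CSV file, LLM
+choice, and question below are placeholders; any LLM supported by LlamaIndex should
+work, and this API is experimental and may change.
+
+```python
+from llama_index.experimental.retrievers.natrual_language import NLCSVRetriever
+from llama_index.llms.openai import OpenAI  # or any other LLM integration
+
+retriever = NLCSVRetriever(
+    csv_path="people.csv",  # placeholder file
+    llm=OpenAI(model="gpt-4o-mini"),
+    similarity_top_k=5,
+)
+
+nodes = retriever.retrieve("How many people live in Berlin?")
+for node in nodes:
+    print(node.score, node.get_content())
+```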
diff --git a/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/__init__.py b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/__init__.py
new file mode 100644
index 0000000000000..eedef066dc92a
--- /dev/null
+++ b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/__init__.py
@@ -0,0 +1,7 @@
+from llama_index.experimental.retrievers.natrual_language.nl_csv_retriever import (
+    NLCSVRetriever,
+)
+from llama_index.experimental.retrievers.natrual_language.nl_data_frame_retierver import (
+    NLDataframeRetriever,
+)
+from llama_index.experimental.retrievers.natrual_language.nl_json_retriever import (
+    NLJsonRetriever,
+)
+
+__all__ = ["NLDataframeRetriever", "NLCSVRetriever", "NLJsonRetriever"]
diff --git a/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/nl_csv_retriever.py b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/nl_csv_retriever.py
new file mode 100644
index 0000000000000..e12aa6c23a890
--- /dev/null
+++ b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/nl_csv_retriever.py
@@ -0,0 +1,36 @@
+from llama_index.core.callbacks.base import CallbackManager
+from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
+from llama_index.core.llms import LLM
+from llama_index.core.prompts import BasePromptTemplate
+from typing import Optional
+
+import logging
+import pandas as pd
+
+from llama_index.experimental.retrievers.natrual_language.nl_data_frame_retierver import (
+    NLDataframeRetriever,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class NLCSVRetriever(NLDataframeRetriever):
+    """Natural-language retriever over a CSV file, loaded into a DataFrame."""
+
+    def __init__(
+        self,
+        csv_path: str,
+        llm: LLM,
+        name: Optional[str] = None,
+        text_to_sql_prompt: Optional[BasePromptTemplate] = None,
+        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
+        callback_manager: Optional[CallbackManager] = None,
+        verbose: bool = False,
+    ):
+        df = pd.read_csv(csv_path)
+
+        super().__init__(
+            df=df,
+            llm=llm,
+            text_to_sql_prompt=text_to_sql_prompt,
+            similarity_top_k=similarity_top_k,
+            verbose=verbose,
+            name=name,
+            callback_manager=callback_manager,
+        )
diff --git a/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/nl_data_frame_retierver.py b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/nl_data_frame_retierver.py
new file mode 100644
index 0000000000000..995b0225bb145
--- /dev/null
+++ b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/nl_data_frame_retierver.py
@@ -0,0 +1,217 @@
+from llama_index.core.callbacks.base import CallbackManager
+from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
+from llama_index.core.llms import LLM
+from llama_index.core.llms.utils import resolve_llm
+from llama_index.core.utils import print_text
+from llama_index.core.retrievers import BaseRetriever
+from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
+from llama_index.core.prompts import BasePromptTemplate
+from llama_index.core.prompts.base import PromptTemplate
+from llama_index.core.prompts.prompt_type import PromptType
+from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
+from typing import List, Optional
+
+import asyncio
+import logging
+import pandas as pd
+import duckdb
+import re
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_OWL_GENERATOR_PROMPT_TMPL = """\
+Given the schema: {schema}
+generate the OWL representation of the schema.
+Make sure to take into account types as described in schema.org.
+Each row is an instance of the schema and each column is a property of the schema.
+Some properties could be entities and others could be literals.
+For entities, make sure to use or extend the schema.org vocabulary.
+When a group of columns is related, use the schema.org vocabulary to represent the relationship.
+Some column groups might identify other entities, some might be a description of the entity, some might be a location, etc.
+Try to use the most specific schema.org vocabulary possible.
+Do not introduce new labels when possible; for example, if a column is a date, use the schema.org vocabulary for date.
+If a column is called first name, use the schema.org vocabulary for givenName.
+
+If the table name is table_data or it is too generic, do not name the schema after the table name; instead, name it after the domain of the table.
+
+Emit the OWL representation of the schema only.
+"""
+
+DEFAULT_OWL_GENERATOR_PROMPT = PromptTemplate(
+    DEFAULT_OWL_GENERATOR_PROMPT_TMPL,
+    prompt_type=PromptType.CUSTOM,
+)
+
+DEFAULT_USE_DETECTION_TMPL = """\
+Given the schema: {schema}
+describe what this retriever is useful for.
+Describe what kind of information the retriever can provide and the type of data it can access.
+"""
+
+DEFAULT_USE_DETECTION_PROMPT = PromptTemplate(
+    DEFAULT_USE_DETECTION_TMPL,
+    prompt_type=PromptType.CUSTOM,
+)
+
+DEFAULT_RESULT_RANKING_TMPL = """\
+Given the schema: {schema}
+and the query: {query}
+how relevant is the schema?
+The relevance must be a number between 0 and 1, where 1 indicates that the schema is able to model the domain of the query and 0 indicates that the schema is not able to model the domain of the query.
+
+Produce only the numeric value and nothing else.
+relevance:
+"""
+
+DEFAULT_RESULT_RANKING_PROMPT = PromptTemplate(
+    DEFAULT_RESULT_RANKING_TMPL,
+    prompt_type=PromptType.CUSTOM,
+)
+
+
+class NLDataframeRetriever(BaseRetriever):
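+    """Natural-language retriever over a pandas DataFrame.
+
+    The DataFrame is copied into an in-memory DuckDB table. At retrieval time the
+    LLM translates the query into SQL, the SQL runs against DuckDB, and each
+    returned row becomes a node whose score is an LLM-judged relevance of the
+    schema to the query.
+    """
+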
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        llm: LLM,
+        name: Optional[str] = None,
+        text_to_sql_prompt: Optional[BasePromptTemplate] = None,
+        schema_to_owl_prompt: Optional[BasePromptTemplate] = None,
+        schema_use_detection_prompt: Optional[BasePromptTemplate] = None,
+        result_ranking_prompt: Optional[BasePromptTemplate] = None,
+        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
+        callback_manager: Optional[CallbackManager] = None,
+        verbose: bool = False,
+    ):
+        self._llm = resolve_llm(llm)
+        self._similarity_top_k = similarity_top_k
+        self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_PROMPT
+        self._result_ranking_prompt = (
+            result_ranking_prompt or DEFAULT_RESULT_RANKING_PROMPT
+        )
+        self._schema_to_owl_prompt = (
+            schema_to_owl_prompt or DEFAULT_OWL_GENERATOR_PROMPT
+        )
+        self._schema_use_detection_prompt = (
+            schema_use_detection_prompt or DEFAULT_USE_DETECTION_PROMPT
+        )
+        data_source = df.copy(deep=True)
+        data_source.rename(columns=lambda x: re.sub(r"\s+", "_", x), inplace=True)
+        table_name = name or "data_table"
+        self._connection = duckdb.connect()
+        # CREATE TABLE ... AS SELECT both creates and populates the table from the
+        # local DataFrame via DuckDB's replacement scan, so no separate INSERT is needed.
+        self._connection.sql(f"CREATE TABLE {table_name} AS SELECT * FROM data_source")
+        self._schema = self._connection.table(table_name)
+        self._description = None
+        self._owl = None
+        self._schema_str = self._create_schema()
+
+        super().__init__(callback_manager=callback_manager, verbose=verbose)
+
+    def get_owl(self) -> str:
+        if self._owl is None:
+            owl_str = self._llm.predict(
+                self._schema_to_owl_prompt, schema=self._schema_str
+            )
+
+            logger.info(f"Schema OWL representation: {owl_str}")
+
+            self._owl = owl_str
+        return self._owl
+
+    def get_description(self) -> str:
+        if self._description is None:
+            description = self._llm.predict(
+                self._schema_use_detection_prompt, schema=self._schema_str
+            )
+
+            logger.info(f"Schema Description: {description}")
+
+            self._description = description
+
+        return self._description
+
+    def _create_schema(self) -> str:
+        table_desc_list = []
+        table_desc_list.append(f"{self._schema.alias} (")
+        for idx, column in enumerate(self._schema.columns):
+            table_desc_list.append(f" {column} {self._schema.types[idx]},")
+        table_desc_list.append(")")
+
+        tables_desc_str = "\n".join(table_desc_list)
+        logger.info(f"Schema: {tables_desc_str}")
+
+        return tables_desc_str
+
+    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        tables_desc_str = self._schema_str
+
+        if self._verbose:
+            print_text(
+                f"Executing retrieval for {query_bundle.query_str} using Schema: {tables_desc_str}\n",
+                color="llama_pink",
+            )
+
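+        # Ask the LLM to translate the natural-language query into SQL against the
+        # schema above, and separately to rate how well the schema matches the query.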
+        response_str = self._llm.predict(
+            self._text_to_sql_prompt,
+            query_str=query_bundle.query_str,
+            schema=tables_desc_str,
+            dialect="DuckDB",
+        )
+
+        rank_str = self._llm.predict(
+            self._result_ranking_prompt,
+            query=query_bundle.query_str,
+            schema=tables_desc_str,
+        )
+
+        score = 1.0
+        try:
+            score = float(rank_str.strip())
+        except ValueError as parsing_error:
+            logger.error(
+                f"Error in parsing the rank value: {rank_str} : {parsing_error}"
+            )
+
+        query = self._parse_response_to_sql(response_str)
+        if self._verbose:
+            print_text(
+                f"Executing query: {query}\n",
+                color="llama_turquoise",
+            )
+
+        logger.info(f"Executing query: {query}")
+
+        results = []
+        if self._similarity_top_k <= 0:
+            results = self._connection.execute(query).fetchall()
+        else:
+            cursor = self._connection.execute(query)
+
+            # Fetch at most similarity_top_k rows, then drain the cursor.
+            for _ in range(self._similarity_top_k):
+                result = cursor.fetchone()
+                if result is None:
+                    break
+                results.append(result)
+
+            cursor.fetchall()
+
+        retrieved_nodes: List[NodeWithScore] = []
+        for result in results:
+            retrieved_nodes.append(
+                NodeWithScore(node=TextNode(text=f"{result}"), score=score)
+            )
+
+        return retrieved_nodes
+
+    def _parse_response_to_sql(self, response: str) -> str:
+        """Parse an LLM response into a bare SQL statement."""
+        # Find and remove SQLResult part
+        sql_result_start = response.find("SQLResult:")
+        if sql_result_start != -1:
+            response = response[:sql_result_start]
+        # Strip a leading "SQLQuery:" marker if the LLM echoed the prompt format.
+        sql_query_start = response.find("SQLQuery:")
+        if sql_query_start != -1:
+            response = response[sql_query_start + len("SQLQuery:") :]
+        return response.strip()
+
+    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        return asyncio.run(self._aretrieve(query_bundle))
diff --git a/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/nl_json_retriever.py b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/nl_json_retriever.py
new file mode 100644
index 0000000000000..162b35fd90cdd
--- /dev/null
+++ b/llama-index-experimental/llama_index/experimental/retrievers/natrual_language/nl_json_retriever.py
@@ -0,0 +1,38 @@
+from llama_index.core.callbacks.base import CallbackManager
+from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
+from llama_index.core.llms import LLM
+from llama_index.core.prompts import BasePromptTemplate
+from typing import Optional
+
+import logging
+import pandas as pd
+import json
+
+from llama_index.experimental.retrievers.natrual_language.nl_data_frame_retierver import (
+    NLDataframeRetriever,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class NLJsonRetriever(NLDataframeRetriever):
+    """Natural-language retriever over a JSON file, loaded into a DataFrame."""
+
+    def __init__(
+        self,
+        json_path: str,
+        llm: LLM,
+        name: Optional[str] = None,
+        text_to_sql_prompt: Optional[BasePromptTemplate] = None,
+        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
+        callback_manager: Optional[CallbackManager] = None,
+        verbose: bool = False,
+    ):
+        with open(json_path) as json_file:
+            data = json.load(json_file)
+        df = pd.DataFrame.from_dict(data, orient="columns")
+
+        super().__init__(
+            df=df,
+            llm=llm,
+            text_to_sql_prompt=text_to_sql_prompt,
+            similarity_top_k=similarity_top_k,
+            verbose=verbose,
+            name=name,
+            callback_manager=callback_manager,
+        )
diff --git a/llama-index-experimental/pyproject.toml b/llama-index-experimental/pyproject.toml
index aa5271424d45d..f19c7fe1f40a4 100644
--- a/llama-index-experimental/pyproject.toml
+++ b/llama-index-experimental/pyproject.toml
@@ -25,13 +25,14 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-experimental"
 readme = "README.md"
-version = "0.5.3"
+version = "0.5.4"
 
 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
 llama-index-core = "^0.12.13"
 llama-index-finetuning = "^0.3.0"
 pandas = "*"
+duckdb = "^1.0.0"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"