Skip to content

Commit

Permalink
Merge pull request #2 from Haoming-jpg/Haoming-tool.py
Browse files Browse the repository at this point in the history
Re-write the MongoDB query checker
  • Loading branch information
Haoming-jpg authored Nov 24, 2023
2 parents 079d20c + 53b1d82 commit 8fc76a1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
18 changes: 9 additions & 9 deletions libs/langchain/langchain/tools/mongo_database/prompt.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# flake8: noqa
QUERY_CHECKER = """
{query}
Double check the {client} query above for common mistakes, including:
- Improper use of $nin operator with null values
- Using $merge instead of $concat for combining arrays
- Incorrect use of $not or $ne for exclusive ranges
- Data type mismatch in query conditions
- Properly referencing field names in queries
- Using the correct syntax for aggregation functions
- Casting to the correct BSON data type
- Using the proper fields for $lookup in aggregations
Double check the MongoDB query above for common mistakes, including:
- Correct syntax for query operators (e.g., $match, $group, $project)
- Properly matching nested fields in the documents
- Using the appropriate array operators (e.g., $elemMatch)
- Utilizing indexes for performance optimization
- Handling data type mismatch in queries
- Ensuring proper field names and key names in queries
- Using the correct projection operators for desired output
- Properly structuring aggregation pipelines if applicable
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
Expand Down
8 changes: 3 additions & 5 deletions libs/langchain/langchain/tools/mongo_database/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ def _init_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["client", "query"]
template=QUERY_CHECKER, input_variables=["query"]
),
)

if values["llm_chain"].prompt.input_variables != ["client", "query"]:
if values["llm_chain"].prompt.input_variables != ["query"]:
raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'client']"
"LLM chain for QueryCheckerTool must have input variables ['query']"
)

return values
Expand All @@ -116,7 +116,6 @@ def _run(
"""Use the LLM to check the query."""
return self.llm_chain.predict(
query=query,
client=self.db.client,
callbacks=run_manager.get_child() if run_manager else None,
)

Expand All @@ -127,6 +126,5 @@ async def _arun(
) -> str:
return await self.llm_chain.apredict(
query=query,
client=self.db.client,
callbacks=run_manager.get_child() if run_manager else None,
)

0 comments on commit 8fc76a1

Please sign in to comment.