Skip to content

Commit

Permalink
Fix langchain tests (mlflow#10495)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Lok <[email protected]>
  • Loading branch information
daniellok-db authored Nov 24, 2023
1 parent ddfeea6 commit bf5eafb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion mlflow/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ def predict(
:return: Model predictions.
"""
import langchain
from langchain.schema.retriever import BaseRetriever

from mlflow.openai.utils import TEST_CONTENT, TEST_INTERMEDIATE_STEPS, TEST_SOURCE_DOCUMENTS

Expand All @@ -748,7 +749,7 @@ def predict(
(
langchain.chains.llm.LLMChain,
langchain.chains.RetrievalQA,
langchain.schema.retriever.BaseRetriever,
BaseRetriever,
),
):
mockContent = TEST_CONTENT
Expand Down
4 changes: 3 additions & 1 deletion tests/langchain/test_langchain_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,9 @@ def embed_query(self, text: str) -> List[float]:


def assert_equal_retrievers(retriever, expected_retreiver):
assert isinstance(retriever, langchain.schema.retriever.BaseRetriever)
from langchain.schema.retriever import BaseRetriever

assert isinstance(retriever, BaseRetriever)
assert isinstance(retriever, type(expected_retreiver))
assert isinstance(retriever.vectorstore, type(expected_retreiver.vectorstore))
assert retriever.tags == expected_retreiver.tags
Expand Down

0 comments on commit bf5eafb

Please sign in to comment.