Skip to content

Commit

Permalink
Fix LangChain compatibility with SQLDatabase (mlflow#9192)
Browse files Browse the repository at this point in the history
Signed-off-by: dbczumar <[email protected]>
Signed-off-by: santiagxf <[email protected]>
  • Loading branch information
dbczumar authored and santiagxf committed Aug 7, 2023
1 parent b9cee34 commit 1397993
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 33 deletions.
95 changes: 63 additions & 32 deletions mlflow/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import os
import shutil
import types
import functools
from packaging import version
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, NamedTuple

import pandas as pd
import cloudpickle
Expand All @@ -33,6 +34,7 @@
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.schema import ColSpec, DataType, Schema
from mlflow.utils.annotations import experimental
from mlflow.utils.class_utils import _get_class_from_string
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
from mlflow.utils.environment import (
_CONDA_ENV_FILE_NAME,
Expand Down Expand Up @@ -68,6 +70,7 @@
_MODEL_TYPE_KEY = "model_type"
_LOADER_FN_FILE_NAME = "loader_fn.pkl"
_LOADER_FN_KEY = "loader_fn"
_LOADER_ARG_KEY = "loader_arg"
_PERSIST_DIR_NAME = "persist_dir_data"
_PERSIST_DIR_KEY = "persist_dir"
_UNSUPPORTED_MODEL_ERROR_MESSAGE = (
Expand Down Expand Up @@ -103,22 +106,54 @@ def get_default_conda_env():
return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())


def _get_map_of_special_chain_class_name_to_kwargs_name():
from langchain.chains import (
APIChain,
HypotheticalDocumentEmbedder,
RetrievalQA,
SQLDatabaseChain,
)
class _SpecialChainInfo(NamedTuple):
    """Metadata for a "special" LangChain chain class, i.e. one that requires a
    user-supplied ``loader_fn`` when being saved/loaded with MLflow.
    """

    # Name of the keyword argument the chain must be loaded with (e.g. "retriever",
    # "database"); recorded at save time and used to build load kwargs.
    loader_arg: str


def _get_special_chain_info_or_none(chain):
    """Return a ``_SpecialChainInfo`` for ``chain`` if it is (an instance or subclass
    instance of) a special chain class requiring a ``loader_fn``; otherwise ``None``.

    Uses ``isinstance`` checks so that user-defined subclasses of special chains
    are recognized as well.
    """
    for special_chain_class, loader_arg in _get_map_of_special_chain_class_to_loader_arg().items():
        if isinstance(chain, special_chain_class):
            return _SpecialChainInfo(loader_arg=loader_arg)
    # Explicit return for the "not special" case (PEP 8: be consistent when some
    # paths return a value).
    return None


@functools.lru_cache
def _get_map_of_special_chain_class_to_loader_arg():
    """Build (once per process, via ``lru_cache``) a mapping from "special" chain
    classes to the keyword-argument name (``loader_arg``) their loader expects.

    Classes are resolved from fully-qualified string names so that chains living in
    optional packages (e.g. ``langchain_experimental``) are only imported when
    available. Resolution failures are logged, not raised.
    """
    import langchain

    from mlflow.langchain.retriever_chain import _RetrieverChain

    class_name_to_loader_arg = {
        "langchain.chains.RetrievalQA": "retriever",
        "langchain.chains.APIChain": "requests_wrapper",
        "langchain.chains.HypotheticalDocumentEmbedder": "embeddings",
    }
    # NB: SQLDatabaseChain was migrated to langchain_experimental beginning with version 0.0.247
    if version.parse(langchain.__version__) <= version.parse("0.0.246"):
        class_name_to_loader_arg["langchain.chains.SQLDatabaseChain"] = "database"
    else:
        try:
            # Probe the standalone `langchain_experimental` package itself — this must
            # match the fully-qualified class path registered below; probing
            # `langchain.experimental` instead would not guarantee that
            # `langchain_experimental.sql.SQLDatabaseChain` is resolvable.
            import langchain_experimental  # noqa: F401

            class_name_to_loader_arg["langchain_experimental.sql.SQLDatabaseChain"] = "database"
        except ImportError:
            # Users may not have langchain_experimental installed, which is completely normal
            pass

    # _RetrieverChain is MLflow-internal, so it can be referenced directly.
    class_to_loader_arg = {
        _RetrieverChain: "retriever",
    }
    for class_name, loader_arg in class_name_to_loader_arg.items():
        try:
            cls = _get_class_from_string(class_name)
            class_to_loader_arg[cls] = loader_arg
        except Exception:
            logger.warning(
                "Unexpected import failure for class '%s'. Please file an issue at"
                " https://github.com/mlflow/mlflow/issues/.",
                class_name,
                exc_info=True,
            )

    return class_to_loader_arg


@experimental
Expand Down Expand Up @@ -297,8 +332,6 @@ def load_retriever(persist_directory):
def _validate_and_wrap_lc_model(lc_model, loader_fn):
import langchain

special_chains = _get_map_of_special_chain_class_name_to_kwargs_name()

if not isinstance(
lc_model,
(
Expand All @@ -312,25 +345,23 @@ def _validate_and_wrap_lc_model(lc_model, loader_fn):
)

_SUPPORTED_LLMS = {langchain.llms.openai.OpenAI, langchain.llms.huggingface_hub.HuggingFaceHub}
if (
isinstance(lc_model, langchain.chains.llm.LLMChain)
and type(lc_model.llm) not in _SUPPORTED_LLMS
if isinstance(lc_model, langchain.chains.llm.LLMChain) and not any(
isinstance(lc_model.llm, supported_llm) for supported_llm in _SUPPORTED_LLMS
):
logger.warning(
_UNSUPPORTED_LLM_WARNING_MESSAGE,
type(lc_model.llm).__name__,
)

if (
isinstance(lc_model, langchain.agents.agent.AgentExecutor)
and type(lc_model.agent.llm_chain.llm) not in _SUPPORTED_LLMS
if isinstance(lc_model, langchain.agents.agent.AgentExecutor) and not any(
isinstance(lc_model.agent.llm_chain.llm, supported_llm) for supported_llm in _SUPPORTED_LLMS
):
logger.warning(
_UNSUPPORTED_LLM_WARNING_MESSAGE,
type(lc_model.agent.llm_chain.llm).__name__,
)

if type(lc_model).__name__ in special_chains:
if special_chain_info := _get_special_chain_info_or_none(lc_model):
if isinstance(lc_model, langchain.chains.RetrievalQA) and version.parse(
langchain.__version__
) < version.parse("0.0.194"):
Expand All @@ -345,8 +376,8 @@ def _validate_and_wrap_lc_model(lc_model, loader_fn):
)
if not isinstance(loader_fn, types.FunctionType):
raise mlflow.MlflowException.invalid_parameter_value(
"The `loader_fn` must be a function that returns a {kwargs}.".format(
kwargs=special_chains[type(lc_model).__name__]
"The `loader_fn` must be a function that returns a {loader_arg}.".format(
loader_arg=special_chain_info.loader_arg
)
)

Expand Down Expand Up @@ -524,8 +555,6 @@ def _save_model(model, path, loader_fn, persist_dir):
model_data_path = os.path.join(path, _MODEL_DATA_FILE_NAME)
model_data_kwargs = {_MODEL_DATA_KEY: _MODEL_DATA_FILE_NAME}

special_chains = _get_map_of_special_chain_class_name_to_kwargs_name()

if isinstance(model, langchain.chains.llm.LLMChain):
model.save(model_data_path)
elif isinstance(model, langchain.agents.agent.AgentExecutor):
Expand Down Expand Up @@ -556,12 +585,13 @@ def _save_model(model, path, loader_fn, persist_dir):

model_data_kwargs[_AGENT_PRIMITIVES_DATA_KEY] = _AGENT_PRIMITIVES_FILE_NAME

elif type(model).__name__ in special_chains:
elif special_chain_info := _get_special_chain_info_or_none(model):
# Save loader_fn by pickling
loader_fn_path = os.path.join(path, _LOADER_FN_FILE_NAME)
with open(loader_fn_path, "wb") as f:
cloudpickle.dump(loader_fn, f)
model_data_kwargs[_LOADER_FN_KEY] = _LOADER_FN_FILE_NAME
model_data_kwargs[_LOADER_ARG_KEY] = special_chain_info.loader_arg

if persist_dir is not None:
if os.path.exists(persist_dir):
Expand Down Expand Up @@ -599,6 +629,7 @@ def _load_from_pickle(loader_fn_path, persist_dir):
def _load_model(
path,
model_type,
loader_arg=None,
agent_path=None,
tools_path=None,
agent_primitive_path=None,
Expand All @@ -608,15 +639,13 @@ def _load_model(
from langchain.chains.loading import load_chain
from mlflow.langchain.retriever_chain import _RetrieverChain

special_chains = _get_map_of_special_chain_class_name_to_kwargs_name()

model = None
if key := special_chains.get(model_type):
if loader_arg is not None:
if loader_fn_path is None:
raise mlflow.MlflowException.invalid_parameter_value(
"Missing file for loader_fn which is required to build the model."
)
kwargs = {key: _load_from_pickle(loader_fn_path, persist_dir)}
kwargs = {loader_arg: _load_from_pickle(loader_fn_path, persist_dir)}
if model_type == _RetrieverChain.__name__:
model = _RetrieverChain.load(path, **kwargs).retriever
else:
Expand Down Expand Up @@ -749,10 +778,12 @@ def _load_model_from_local_fs(local_model_path):
persist_dir = os.path.join(local_model_path, persist_dir_name)

model_type = flavor_conf.get(_MODEL_TYPE_KEY)
loader_arg = flavor_conf.get(_LOADER_ARG_KEY)

return _load_model(
lc_model_path,
model_type,
loader_arg,
agent_model_path,
tools_model_path,
agent_primitive_path,
Expand Down
1 change: 1 addition & 0 deletions mlflow/ml-package-versions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ langchain:
"google-search-results",
"psutil",
"faiss-cpu",
"langchain-experimental",
]
run: |
pytest tests/langchain/test_langchain_model_export.py
Expand Down
23 changes: 22 additions & 1 deletion tests/langchain/test_langchain_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
LLMChain,
RetrievalQA,
HypotheticalDocumentEmbedder,
SQLDatabaseChain,
)
from langchain.chains.api import open_meteo_docs
from langchain.chains.base import Chain
Expand All @@ -35,6 +34,7 @@
from langchain.requests import TextRequestsWrapper
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_experimental.sql import SQLDatabaseChain
from pydantic import BaseModel
from pyspark.sql import SparkSession
from typing import Any, List, Mapping, Optional, Dict
Expand Down Expand Up @@ -511,6 +511,27 @@ def test_log_and_load_api_chain():
assert loaded_model == apichain


def test_log_and_load_subclass_of_specialized_chain():
    """A user-defined subclass of a special chain (APIChain) must be recognized via
    isinstance checks and round-trip through log_model/load_model.
    """

    class APIChainSubclass(APIChain):
        pass

    llm = OpenAI(temperature=0)
    apichain_subclass = APIChainSubclass.from_llm_and_api_docs(
        llm, open_meteo_docs.OPEN_METEO_DOCS, verbose=True
    )

    with mlflow.start_run():
        logged_model = mlflow.langchain.log_model(
            apichain_subclass,
            "apichain_subclass",
            loader_fn=load_requests_wrapper,
        )

    # Reload the chain and confirm it round-trips unchanged.
    reloaded = mlflow.langchain.load_model(logged_model.model_uri)
    assert reloaded == apichain_subclass


def load_base_embeddings(_):
    """Test loader_fn: ignores its argument and returns a FakeEmbeddings of size 32."""
    embeddings = FakeEmbeddings(size=32)
    return embeddings

Expand Down

0 comments on commit 1397993

Please sign in to comment.