Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[major]: Upgrade langchain-core to pydantic 2 #25986

Merged
merged 14 commits into from
Sep 3, 2024
1 change: 0 additions & 1 deletion libs/core/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
Expand Down
7 changes: 5 additions & 2 deletions libs/core/langchain_core/beta/runnables/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
Union,
)

from pydantic import ConfigDict

from langchain_core._api.beta_decorator import beta
from langchain_core.runnables.base import (
Runnable,
Expand Down Expand Up @@ -229,8 +231,9 @@ class ContextSet(RunnableSerializable):

keys: Mapping[str, Optional[Runnable]]

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

def __init__(
self,
Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
from abc import ABC, abstractmethod
from typing import List, Sequence, Union

from pydantic import BaseModel, Field

from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
get_buffer_string,
)
from langchain_core.pydantic_v1 import BaseModel, Field


class BaseChatMessageHistory(ABC):
Expand Down
22 changes: 15 additions & 7 deletions libs/core/langchain_core/documents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import mimetypes
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Generator, List, Literal, Mapping, Optional, Union, cast
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast

from pydantic import ConfigDict, Field, model_validator

from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils.pydantic import v1_repr

PathLike = Union[str, PurePath]

Expand Down Expand Up @@ -110,9 +112,10 @@ class Blob(BaseMedia):
path: Optional[PathLike] = None
"""Location where the original content was found."""

class Config:
arbitrary_types_allowed = True
frozen = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
frozen=True,
)

@property
def source(self) -> Optional[str]:
Expand All @@ -127,8 +130,9 @@ def source(self) -> Optional[str]:
return cast(Optional[str], self.metadata["source"])
return str(self.path) if self.path else None

@root_validator(pre=True)
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
@model_validator(mode="before")
@classmethod
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:
raise ValueError("Either data or path must be provided")
Expand Down Expand Up @@ -293,3 +297,7 @@ def __str__(self) -> str:
return f"page_content='{self.page_content}' metadata={self.metadata}"
else:
return f"page_content='{self.page_content}'"

def __repr__(self) -> str:
    """Return the string representation of this document.

    Delegates to ``v1_repr`` — presumably to keep ``repr()`` output in the
    pydantic-1 format during the pydantic 2 migration; confirm against
    ``langchain_core.utils.pydantic.v1_repr``.
    """
    # TODO(0.3): Remove this override after confirming unit tests!
    return v1_repr(self)
3 changes: 2 additions & 1 deletion libs/core/langchain_core/documents/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from abc import ABC, abstractmethod
from typing import Optional, Sequence

from pydantic import BaseModel

from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import run_in_executor


Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/embeddings/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import hashlib
from typing import List

from pydantic import BaseModel

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel


class FakeEmbeddings(Embeddings, BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/example_selectors/length_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import re
from typing import Callable, Dict, List

from pydantic import BaseModel, validator

from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, validator


def _get_length_based(text: str) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

from pydantic import BaseModel, ConfigDict

from langchain_core.documents import Document
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.vectorstores import VectorStore

if TYPE_CHECKING:
Expand Down Expand Up @@ -42,9 +43,10 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
vectorstore_kwargs: Optional[Dict[str, Any]] = None
"""Extra arguments passed to similarity_search function of the vectorstore."""

class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)

@staticmethod
def _example_to_text(
Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/graph_vectorstores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
Optional,
)

from pydantic import Field

from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
Expand All @@ -20,7 +22,6 @@
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_core.load import Serializable
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever

Expand Down
8 changes: 5 additions & 3 deletions libs/core/langchain_core/indexing/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
cast,
)

from pydantic import model_validator

from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from langchain_core.indexing.base import DocumentIndex, RecordManager
from langchain_core.pydantic_v1 import root_validator
from langchain_core.vectorstores import VectorStore

# Magic UUID to use as a namespace for hashing.
Expand Down Expand Up @@ -68,8 +69,9 @@ class _HashedDocument(Document):
def is_lc_serializable(cls) -> bool:
return False

@root_validator(pre=True)
def calculate_hashes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
"""Root validator to calculate content and metadata hash."""
content = values.get("page_content", "")
metadata = values.get("metadata", {})
Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/indexing/in_memory.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import uuid
from typing import Any, Dict, List, Optional, Sequence, cast

from pydantic import Field

from langchain_core._api import beta
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.indexing import UpsertResponse
from langchain_core.indexing.base import DeleteResponse, DocumentIndex
from langchain_core.pydantic_v1 import Field


@beta(message="Introduced in version 0.2.29. Underlying abstraction subject to change.")
Expand Down
12 changes: 10 additions & 2 deletions libs/core/langchain_core/language_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Union,
)

from pydantic import BaseModel, ConfigDict, Field, validator
from typing_extensions import TypeAlias, TypedDict

from langchain_core._api import deprecated
Expand All @@ -28,7 +29,6 @@
get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names

Expand Down Expand Up @@ -113,7 +113,11 @@ class BaseLanguageModel(

Caching is not currently supported for streaming methods of models.
"""
verbose: bool = Field(default_factory=_get_verbosity)
# Repr = False is consistent with pydantic 1 if verbose = False
# We can relax this for pydantic 2?
# TODO(Team): decide what to do here.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the TODO, since we know we want to undo this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken, the issue was that this ends up affecting caching behavior, with verbose=False and verbose=None being cached differently. I'll double-check, but I updated a TODO(0.3) for us to resolve.

eyurtsev marked this conversation as resolved.
Show resolved Hide resolved
# Modified just to get unit tests to pass.
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
Expand All @@ -126,6 +130,10 @@ class BaseLanguageModel(
)
"""Optional encoder to use for counting tokens."""

model_config = ConfigDict(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this new?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's required when using pydantic 2 since cache is an attribute on the chat model and the cache is not a base model

arbitrary_types_allowed=True,
)

@validator("verbose", pre=True, always=True, allow_reuse=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
Expand Down
44 changes: 27 additions & 17 deletions libs/core/langchain_core/language_models/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
cast,
)

from pydantic import (
BaseModel,
ConfigDict,
Field,
model_validator,
)

from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
Expand Down Expand Up @@ -57,11 +64,6 @@
RunInfo,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
Expand Down Expand Up @@ -193,14 +195,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

""" # noqa: E501

callback_manager: Optional[BaseCallbackManager] = deprecated(
name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
)(
Field(
default=None,
exclude=True,
description="Callback manager to add to the run trace.",
)
# TODO(0.3): Figure out how to re-apply deprecated decorator
# callback_manager: Optional[BaseCallbackManager] = deprecated(
# name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
# )(
# Field(
# default=None,
# exclude=True,
# description="Callback manager to add to the run trace.",
# )
# )
callback_manager: Optional[BaseCallbackManager] = Field(
default=None,
exclude=True,
description="Callback manager to add to the run trace.",
)

rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
Expand All @@ -218,8 +226,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
- If False (default), will always use streaming case if available.
"""

@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
"""Raise deprecation warning if callback_manager is used.

Args:
Expand All @@ -240,8 +249,9 @@ def raise_deprecation(cls, values: Dict) -> Dict:
values["callbacks"] = values.pop("callback_manager", None)
return values

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

# --- Runnable methods ---

Expand Down
12 changes: 7 additions & 5 deletions libs/core/langchain_core/language_models/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

import yaml
from pydantic import ConfigDict, Field, model_validator
from tenacity import (
RetryCallState,
before_sleep_log,
Expand Down Expand Up @@ -62,7 +63,6 @@
)
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from langchain_core.runnables.config import run_in_executor

Expand Down Expand Up @@ -300,11 +300,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED]"""

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
Expand Down
Loading
Loading