Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Фикс функций и распределение кода по папкам #233

Merged
merged 3 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions libs/langchain_gigachat/langchain_gigachat/chat_models/gigachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
ToolCallChunk,
ToolMessage,
)
from langchain_core.output_parsers import (
JsonOutputKeyToolsParser,
JsonOutputParser,
PydanticOutputParser,
PydanticToolsParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
Expand All @@ -60,9 +66,7 @@
from pydantic import BaseModel

from langchain_gigachat.chat_models.base_gigachat import _BaseGigaChat
from langchain_gigachat.tools.gigachat_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
from langchain_gigachat.utils.function_calling import (
convert_to_gigachat_function,
convert_to_gigachat_tool,
)
Expand All @@ -76,7 +80,7 @@
r'<img\ssrc="(?P<UUID>.+?)"\sfuse=".+?"/>(?P<postfix>.+)?'
)
VIDEO_SEARCH_REGEX = re.compile(
r'<video\scover="(?P<cover_UUID>.+?)"\ssrc="(?P<UUID>.+?)"\sfuse="true"/>(?P<postfix>.+)?'
r'<video\scover="(?P<cover_UUID>.+?)"\ssrc="(?P<UUID>.+?)"\sfuse="true"/>(?P<postfix>.+)?' # noqa
)


Expand Down Expand Up @@ -588,22 +592,30 @@ def with_structured_output(
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
key_name = convert_to_gigachat_tool(schema)["function"]["name"]
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore
first_tool_only=True,
)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
key_name = convert_to_gigachat_tool(schema)["function"]["name"]
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore
first_tool_only=True,
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
llm = self.bind_tools([schema], tool_choice=key_name)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
llm = self
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
llm = self.bind_tools([schema], tool_choice=key_name)

if include_raw:
parser_assign = RunnablePassthrough.assign(
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import copy
from types import GenericAlias
from typing import Any, Dict, List, Type, Union

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from pydantic import BaseModel, model_validator


class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Parse an output that is one of sets of values."""

args_only: bool = True
"""Whether to only return the arguments to the function call."""

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
except KeyError as exc:
raise OutputParserException(
f"Could not parse function call: {exc}"
) from exc

if self.args_only:
return func_call["arguments"]
return func_call


class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as a pydantic object."""

pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
"""The pydantic schema to parse the output with.

If multiple schemas are provided, then the function name will be used to
determine which schema to use.
"""

@model_validator(mode="before")
@classmethod
def validate_schema(cls, values: dict) -> Any:
"""Validate the pydantic schema.

Args:
values: The values to validate.

Returns:
The validated values.

Raises:
ValueError: If the schema is not a pydantic schema.
"""
schema = values["pydantic_schema"]
if "args_only" not in values:
values["args_only"] = (
isinstance(schema, type)
and not isinstance(schema, GenericAlias)
and issubclass(schema, BaseModel)
)
elif values["args_only"] and isinstance(schema, dict):
msg = (
"If multiple pydantic schemas are provided then args_only should be"
" False."
)
raise ValueError(msg)
return values

def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.

Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.

Returns:
The parsed JSON object.
"""
_result = super().parse_result(result)
if self.args_only:
if hasattr(self.pydantic_schema, "model_validate"):
pydantic_args = self.pydantic_schema.model_validate(_result)
else:
pydantic_args = self.pydantic_schema.parse_obj(_result) # type: ignore
else:
fn_name = _result["name"]
_args = _result["arguments"]
if isinstance(self.pydantic_schema, dict):
pydantic_schema = self.pydantic_schema[fn_name]
else:
pydantic_schema = self.pydantic_schema
if hasattr(pydantic_schema, "model_validate"):
pydantic_args = pydantic_schema.model_validate(_args) # type: ignore
else:
pydantic_args = pydantic_schema.parse_obj(_args) # type: ignore
return pydantic_args


class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
"""Parse an output as an attribute of a pydantic object."""

attr_name: str
"""The name of the attribute to return."""

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)
Loading
Loading