Skip to content

Commit

Permalink
[core, langchain] modelio code improvements (#15277)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchase17 authored Dec 28, 2023
1 parent 694bbb1 commit b868031
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 7 deletions.
2 changes: 1 addition & 1 deletion libs/core/langchain_core/example_selectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class BaseExampleSelector(ABC):

@abstractmethod
def add_example(self, example: Dict[str, str]) -> Any:
"""Add new example to store for a key."""
"""Add new example to store."""

@abstractmethod
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/output_parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
BaseLLMOutputParser,
BaseOutputParser,
)
from langchain_core.output_parsers.json import SimpleJsonOutputParser
from langchain_core.output_parsers.json import JsonOutputParser, SimpleJsonOutputParser
from langchain_core.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
Expand All @@ -30,4 +30,5 @@
"BaseCumulativeTransformOutputParser",
"SimpleJsonOutputParser",
"XMLOutputParser",
"JsonOutputParser",
]
11 changes: 11 additions & 0 deletions libs/core/langchain_core/output_parsers/format_instructions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# flake8: noqa

# Prompt template asking for output that conforms to a JSON schema.
# Literal braces in the example are doubled ({{ and }}) so that they survive
# str.format(); {schema} is the only placeholder to be filled in.
JSON_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
Here is the output schema:
```
{schema}
```"""
28 changes: 26 additions & 2 deletions libs/core/langchain_core/output_parsers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import json
import re
from json import JSONDecodeError
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Type

import jsonpatch # type: ignore[import]

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.pydantic_v1 import BaseModel


def _replace_new_line(match: re.Match[str]) -> str:
Expand Down Expand Up @@ -170,7 +172,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
return json_obj


class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object.
When used in streaming mode, it will yield partial JSON objects containing
Expand All @@ -180,6 +182,8 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
describing the difference between the previous and the current object.
"""

pydantic_object: Optional[Type[BaseModel]] = None

def _diff(self, prev: Optional[Any], next: Any) -> Any:
    """Return the ``patch`` attribute of the JSON patch from ``prev`` to ``next``."""
    json_patch = jsonpatch.make_patch(prev, next)
    return json_patch.patch

Expand All @@ -190,6 +194,26 @@ def parse(self, text: str) -> Any:
except JSONDecodeError as e:
raise OutputParserException(f"Invalid json output: {text}") from e

def get_format_instructions(self) -> str:
    """Return instructions telling the model how to format its JSON output.

    Returns:
        ``"Return a JSON object."`` when ``pydantic_object`` is ``None``;
        otherwise ``JSON_FORMAT_INSTRUCTIONS`` formatted with the model's
        JSON schema, minus its top-level ``"title"`` and ``"type"`` keys.
    """
    if self.pydantic_object is None:
        return "Return a JSON object."

    # Build a new top-level dict instead of deleting keys in place: the dict
    # returned by ``schema()`` may be cached and shared by the model class,
    # so mutating it would strip "title"/"type" for every later caller.
    reduced_schema = {
        key: value
        for key, value in self.pydantic_object.schema().items()
        if key not in ("title", "type")
    }
    # json.dumps ensures the JSON in context is well-formed with double quotes.
    schema_str = json.dumps(reduced_schema)
    return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str)

@property
def _type(self) -> str:
    """Return the type string reported by this parser."""
    type_name = "simple_json_output_parser"
    return type_name


# For backwards compatibility: keep the name SimpleJsonOutputParser available
# as an alias bound to the JsonOutputParser class.
SimpleJsonOutputParser = JsonOutputParser
6 changes: 5 additions & 1 deletion libs/core/langchain_core/output_parsers/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def get_format_instructions(self) -> str:
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

def parse(self, text: str) -> Dict[str, List[Any]]:
text = text.strip("`").strip("xml")
# Try to find XML string within triple backticks
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
# If match found, use the content within the backticks
text = match.group(2)
encoding_match = self.encoding_matcher.search(text)
if encoding_match:
text = encoding_match.group(2)
Expand Down
1 change: 1 addition & 0 deletions libs/core/tests/unit_tests/output_parsers/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"BaseCumulativeTransformOutputParser",
"SimpleJsonOutputParser",
"XMLOutputParser",
"JsonOutputParser",
]


Expand Down
8 changes: 6 additions & 2 deletions libs/langchain/langchain/output_parsers/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ class DatetimeOutputParser(BaseOutputParser[datetime]):

def get_format_instructions(self) -> str:
examples = comma_list(_generate_random_datetime_strings(self.format))
return f"""Write a datetime string that matches the
following pattern: "{self.format}". Examples: {examples}"""
return (
f"Write a datetime string that matches the "
f"following pattern: '{self.format}'.\n\n"
f"Examples: {examples}\n\n"
f"Return ONLY this string, no other words!"
)

def parse(self, response: str) -> datetime:
try:
Expand Down

0 comments on commit b868031

Please sign in to comment.