Skip to content

Commit

Permalink
Use json-repair package to fix LLM generated json (#226)
Browse files Browse the repository at this point in the history
* Use json-repair package to fix LLM generated json

* Removed redundant repaired_json.strip()

* Rename test to test_extractor_llm_unfixable_json

* Readded llama-index to dependencies

* Removed print

* Renamed JSONRepairError to InvalidJSONError

* Use cast for repaired_json instead of isinstance

* Add InvalidJSONError hyperlink to API docs
  • Loading branch information
willtai authored Dec 10, 2024
1 parent c166afc commit 0ac06b7
Show file tree
Hide file tree
Showing 9 changed files with 1,233 additions and 1,243 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

## Next

### Added
- Integrated json-repair package to handle and repair invalid JSON generated by LLMs.
- Introduced InvalidJSONError exception for handling cases where JSON repair fails.

### Changed
- Updated LLM prompts to include stricter instructions for generating valid JSON.

### Fixed
- Added schema functions to the documentation.

Expand Down
9 changes: 9 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ Errors

* :class:`neo4j_graphrag.experimental.pipeline.exceptions.PipelineStatusUpdateError`

* :class:`neo4j_graphrag.experimental.pipeline.exceptions.InvalidJSONError`


Neo4jGraphRagError
==================
Expand Down Expand Up @@ -509,3 +511,10 @@ PipelineStatusUpdateError

.. autoclass:: neo4j_graphrag.experimental.pipeline.exceptions.PipelineStatusUpdateError
:show-inheritance:


InvalidJSONError
================

.. autoclass:: neo4j_graphrag.experimental.pipeline.exceptions.InvalidJSONError
:show-inheritance:
2,340 changes: 1,152 additions & 1,188 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ llama-index = {version = "^0.10.55", optional = true }
openai = {version = "^1.51.1", optional = true }
anthropic = { version = "^0.36.0", optional = true}
sentence-transformers = {version = "^3.0.0", optional = true }
json-repair = "^0.30.2"

[tool.poetry.group.dev.dependencies]
urllib3 = "<2"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import enum
import json
import logging
import re
from datetime import datetime
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Union, cast

import json_repair

from pydantic import ValidationError, validate_call

Expand All @@ -36,6 +37,7 @@
TextChunks,
)
from neo4j_graphrag.experimental.pipeline.component import Component
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
from neo4j_graphrag.generation.prompts import ERExtractionTemplate, PromptTemplate
from neo4j_graphrag.llm import LLMInterface

Expand Down Expand Up @@ -100,28 +102,15 @@ def balance_curly_braces(json_string: str) -> str:
return "".join(fixed_json)


def fix_invalid_json(invalid_json_string: str) -> str:
# Fix missing quotes around field names
invalid_json_string = re.sub(
r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', invalid_json_string
)

# Fix missing quotes around string values, correctly ignoring null, true, false, and numeric values
invalid_json_string = re.sub(
r"(?<=:\s)(?!(null|true|false|\d+\.?\d*))([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=[,}])",
r'"\2"',
invalid_json_string,
)

# Correct the specific issue: remove trailing commas within arrays or objects before closing braces or brackets
invalid_json_string = re.sub(r",\s*(?=[}\]])", "", invalid_json_string)
def fix_invalid_json(raw_json: str) -> str:
repaired_json = json_repair.repair_json(raw_json)
repaired_json = cast(str, repaired_json).strip()

# Normalize excessive curly braces
invalid_json_string = re.sub(r"{{+", "{", invalid_json_string)
invalid_json_string = re.sub(r"}}+", "}", invalid_json_string)

# Balance curly braces
return balance_curly_braces(invalid_json_string)
if repaired_json == '""':
raise InvalidJSONError("JSON repair resulted in an empty or invalid JSON.")
if not repaired_json:
raise InvalidJSONError("JSON repair resulted in an empty string.")
return repaired_json


class EntityRelationExtractor(Component, abc.ABC):
Expand Down Expand Up @@ -223,24 +212,18 @@ async def extract_for_chunk(
)
llm_result = await self.llm.ainvoke(prompt)
try:
result = json.loads(llm_result.content)
except json.JSONDecodeError:
logger.info(
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}. Trying to fix it."
)
fixed_content = fix_invalid_json(llm_result.content)
try:
result = json.loads(fixed_content)
except json.JSONDecodeError as e:
if self.on_error == OnError.RAISE:
raise LLMGenerationError(
f"LLM response is not valid JSON {fixed_content}: {e}"
)
else:
logger.error(
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}"
)
result = {"nodes": [], "relationships": []}
llm_generated_json = fix_invalid_json(llm_result.content)
result = json.loads(llm_generated_json)
except (json.JSONDecodeError, InvalidJSONError) as e:
if self.on_error == OnError.RAISE:
raise LLMGenerationError(
f"LLM response is not valid JSON {llm_result.content}: {e}"
)
else:
logger.error(
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}"
)
result = {"nodes": [], "relationships": []}
try:
chunk_graph = Neo4jGraph(**result)
except ValidationError as e:
Expand Down
6 changes: 6 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ class PipelineStatusUpdateError(Neo4jGraphRagError):
"""Raised when trying an invalid change of state (e.g. DONE => DOING)"""

pass


class InvalidJSONError(Neo4jGraphRagError):
    """Raised when JSON repair fails to produce valid JSON.

    Signals that repairing LLM-generated JSON did not yield a usable
    document, for example when the repaired output is an empty string.
    """
6 changes: 5 additions & 1 deletion src/neo4j_graphrag/generation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ class ERExtractionTemplate(PromptTemplate):
Do respect the source and target node types for relationship and
the relationship direction.
Do not return any additional information other than the JSON in it.
Make sure you adhere to the following rules to produce valid JSON objects:
- Do not return any additional information other than the JSON in it.
- Omit any backticks around the JSON - simply output the JSON on its own.
- The JSON object must not be wrapped in a list - it is its own JSON object.
- Property names must be enclosed in double quotes
Examples:
{examples}
Expand Down
18 changes: 10 additions & 8 deletions src/neo4j_graphrag/llm/mistralai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ def invoke(self, input: str) -> LLMResponse:
messages=self.get_messages(input),
**self.model_params,
)
if response is None or response.choices is None or not response.choices:
content = ""
else:
content = response.choices[0].message.content or ""
content: str = ""
if response and response.choices:
possible_content = response.choices[0].message.content
if isinstance(possible_content, str):
content = possible_content
return LLMResponse(content=content)
except SDKError as e:
raise LLMGenerationError(e)
Expand All @@ -111,10 +112,11 @@ async def ainvoke(self, input: str) -> LLMResponse:
messages=self.get_messages(input),
**self.model_params,
)
if response is None or response.choices is None or not response.choices:
content = ""
else:
content = response.choices[0].message.content or ""
content: str = ""
if response and response.choices:
possible_content = response.choices[0].message.content
if isinstance(possible_content, str):
content = possible_content
return LLMResponse(content=content)
except SDKError as e:
raise LLMGenerationError(e)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import json
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from neo4j_graphrag.exceptions import LLMGenerationError
Expand All @@ -31,6 +31,7 @@
TextChunk,
TextChunks,
)
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
from neo4j_graphrag.llm import LLMInterface, LLMResponse


Expand Down Expand Up @@ -144,16 +145,17 @@ async def test_extractor_llm_ainvoke_failed() -> None:


@pytest.mark.asyncio
async def test_extractor_llm_badly_formatted_json() -> None:
async def test_extractor_llm_unfixable_json() -> None:
llm = MagicMock(spec=LLMInterface)
llm.ainvoke.return_value = LLMResponse(
content='{"nodes": [{"id": "0", "label": "Person", "properties": {}}], "relationships": [}'
content='{"nodes": [{"id": "0", "label": "Person", "properties": {}}], "relationships": }'
)

extractor = LLMEntityRelationExtractor(
llm=llm,
)
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

with pytest.raises(LLMGenerationError):
await extractor.run(chunks=chunks)

Expand All @@ -177,7 +179,7 @@ async def test_extractor_llm_invalid_json() -> None:


@pytest.mark.asyncio
async def test_extractor_llm_badly_formatted_json_do_not_raise() -> None:
async def test_extractor_llm_badly_formatted_json_gets_fixed() -> None:
llm = MagicMock(spec=LLMInterface)
llm.ainvoke.return_value = LLMResponse(
content='{"nodes": [{"id": "0", "label": "Person", "properties": {}}], "relationships": [}'
Expand All @@ -190,7 +192,11 @@ async def test_extractor_llm_badly_formatted_json_do_not_raise() -> None:
)
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
res = await extractor.run(chunks=chunks)
assert res.nodes == []

assert len(res.nodes) == 1
assert res.nodes[0].label == "Person"
assert res.nodes[0].properties == {"chunk_index": 0}
assert res.nodes[0].embedding_properties is None
assert res.relationships == []


Expand All @@ -205,6 +211,14 @@ async def test_extractor_custom_prompt() -> None:
llm.ainvoke.assert_called_once_with("this is my prompt")


def test_fix_invalid_json_empty_result() -> None:
    """fix_invalid_json raises InvalidJSONError when the repair output is empty."""
    unrepairable_input = "invalid json"

    # Replace json_repair.repair_json so it returns an empty string, which
    # is the condition under test.
    with patch("json_repair.repair_json", return_value=""), pytest.raises(
        InvalidJSONError
    ):
        fix_invalid_json(unrepairable_input)


def test_fix_unquoted_keys() -> None:
json_string = '{name: "John", age: "30"}'
expected_result = '{"name": "John", "age": "30"}'
Expand Down

0 comments on commit 0ac06b7

Please sign in to comment.