Skip to content

Commit

Permalink
Handle cases where the LLM produces a valid JSON array
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Jan 31, 2025
1 parent d7d6674 commit bedf73d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ async def extract_for_chunk(
logger.debug(f"Invalid JSON: {llm_result.content}")
result = {"nodes": [], "relationships": []}
try:
chunk_graph = Neo4jGraph(**result)
chunk_graph = Neo4jGraph.model_validate(result)
except ValidationError as e:
if self.on_error == OnError.RAISE:
raise LLMGenerationError("LLM response has improper format") from e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,24 @@ async def test_extractor_llm_invalid_json() -> None:
await extractor.run(chunks=chunks)


@pytest.mark.asyncio
async def test_extractor_llm_invalid_json_is_a_list() -> None:
"""Test what happens when the returned JSON is a valid JSON list,
but it does not match the expected Pydantic model"""
llm = MagicMock(spec=LLMInterface)
llm.ainvoke.return_value = LLMResponse(
# missing "label" for entity
content='[{"nodes": [{"id": 0, "entity_type": "Person", "properties": {}}], "relationships": []}]'
)

extractor = LLMEntityRelationExtractor(
llm=llm,
)
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
with pytest.raises(LLMGenerationError):
await extractor.run(chunks=chunks)


@pytest.mark.asyncio
async def test_extractor_llm_badly_formatted_json_gets_fixed() -> None:
llm = MagicMock(spec=LLMInterface)
Expand Down

0 comments on commit bedf73d

Please sign in to comment.