diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a4b5e89..3bf39dd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ ### Changed - Updated LLM implementations to handle message history consistently across providers. - The `id_prefix` parameter in the `LexicalGraphConfig` is deprecated. +- Changed the default behaviour of `FixedSizeSplitter` to avoid cutting words off in the chunks whenever possible. ### Fixed - IDs for the Document and Chunk nodes in the lexical graph are now randomly generated and unique across multiple runs, fixing issues in the lexical graph where relationships were created between chunks that were created by different pipeline runs. diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index c4559072..ea91ec2c 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -581,9 +581,12 @@ that can be processed within the LLM token limits: from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter - splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200) + splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False) splitter.run(text="Hello World. Life is beautiful.") +.. note:: + + The `approximate` flag is set to True by default to ensure clean chunk starts and ends (i.e. to avoid cutting words in the middle) whenever possible. 
Wrappers for LangChain and LlamaIndex text splitters are included in this package: diff --git a/examples/customize/build_graph/components/splitters/fixed_size_splitter.py b/examples/customize/build_graph/components/splitters/fixed_size_splitter.py index 8b2f2cc1..0b97f393 100644 --- a/examples/customize/build_graph/components/splitters/fixed_size_splitter.py +++ b/examples/customize/build_graph/components/splitters/fixed_size_splitter.py @@ -6,9 +6,10 @@ async def main() -> TextChunks: splitter = FixedSizeSplitter( - # optionally, configure chunk_size and chunk_overlap + # optionally, configure chunk_size, chunk_overlap, and approximate flag # chunk_size=4000, # chunk_overlap=200, + # approximate = False ) chunks = await splitter.run(text="text to split") return chunks diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py index e81d482c..ab11206d 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py @@ -17,7 +17,6 @@ import asyncio import logging -import neo4j from neo4j_graphrag.experimental.components.entity_relation_extractor import ( LLMEntityRelationExtractor, OnError, @@ -35,6 +34,8 @@ from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.llm import LLMInterface, OpenAILLM +import neo4j + logging.basicConfig(level=logging.INFO) @@ -83,7 +84,8 @@ async def define_and_run_pipeline( pipe = Pipeline() pipe.add_component(PdfLoader(), "pdf_loader") pipe.add_component( - FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter" + FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False), + "splitter", ) pipe.add_component(SchemaBuilder(), "schema") pipe.add_component( diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_text.py b/examples/customize/build_graph/pipeline/kg_builder_from_text.py index 
2beed124..907a0282 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_text.py @@ -16,7 +16,6 @@ import asyncio -import neo4j from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -37,6 +36,8 @@ from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.llm import LLMInterface, OpenAILLM +import neo4j + async def define_and_run_pipeline( neo4j_driver: neo4j.Driver, llm: LLMInterface @@ -58,7 +59,7 @@ async def define_and_run_pipeline( # define the components pipe.add_component( # chunk_size=50 for the sake of this demo - FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), + FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") diff --git a/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py b/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py index 4c332802..c2fbdec4 100644 --- a/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py @@ -2,7 +2,6 @@ import asyncio -import neo4j from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter @@ -14,6 +13,8 @@ from neo4j_graphrag.experimental.pipeline import Pipeline from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult +import neo4j + async def main(neo4j_driver: neo4j.Driver) -> PipelineResult: """This is where we define and run the Lexical Graph builder pipeline, instantiating @@ -27,7 +28,7 @@ async 
def main(neo4j_driver: neo4j.Driver) -> PipelineResult: pipe = Pipeline() # define the components pipe.add_component( - FixedSizeSplitter(chunk_size=20, chunk_overlap=1), + FixedSizeSplitter(chunk_size=20, chunk_overlap=1, approximate=False), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py index b7ba60e1..6867d906 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py @@ -7,7 +7,6 @@ import asyncio -import neo4j from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -29,6 +28,8 @@ from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.llm import LLMInterface, OpenAILLM +import neo4j + async def define_and_run_pipeline( neo4j_driver: neo4j.Driver, @@ -56,7 +57,7 @@ async def define_and_run_pipeline( pipe = Pipeline() # define the components pipe.add_component( - FixedSizeSplitter(chunk_size=200, chunk_overlap=50), + FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py index 7cfa4cbb..0fd354db 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py +++ 
b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py @@ -8,7 +8,6 @@ import asyncio -import neo4j from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -31,6 +30,8 @@ from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.llm import LLMInterface, OpenAILLM +import neo4j + async def build_lexical_graph( neo4j_driver: neo4j.Driver, @@ -47,7 +48,7 @@ async def build_lexical_graph( pipe = Pipeline() # define the components pipe.add_component( - FixedSizeSplitter(chunk_size=200, chunk_overlap=50), + FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py index 6add30d8..515387f4 100644 --- a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py +++ b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py @@ -18,12 +18,66 @@ from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks +def _adjust_chunk_start(text: str, approximate_start: int) -> int: + """ + Shift the starting index backward if it lands in the middle of a word. + If no whitespace is found, use the proposed start. + + Args: + text (str): The text being split. + approximate_start (int): The initial starting index of the chunk. + + Returns: + int: The adjusted starting index, ensuring the chunk does not begin in the + middle of a word if possible. 
+ """ + start = approximate_start + if start > 0 and not text[start].isspace() and not text[start - 1].isspace(): + while start > 0 and not text[start - 1].isspace(): + start -= 1 + + # fallback if no whitespace is found + if start == 0 and not text[0].isspace(): + start = approximate_start + return start + + +def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int: + """ + Shift the ending index backward if it lands in the middle of a word. + If no whitespace is found, use 'approximate_end'. + + Args: + text (str): The full text being split. + start (int): The adjusted starting index for this chunk. + approximate_end (int): The initial end index. + + Returns: + int: The adjusted ending index, ensuring the chunk does not end in the middle of + a word if possible. + """ + end = approximate_end + if end < len(text): + while end > start and not text[end].isspace() and not text[end - 1].isspace(): + end -= 1 + + # fallback if no whitespace is found + if end == start: + end = approximate_end + return end + + class FixedSizeSplitter(TextSplitter): - """Text splitter which splits the input text into fixed size chunks with optional overlap. + """Text splitter which splits the input text into fixed or approximate fixed size + chunks with optional overlap. Args: chunk_size (int): The number of characters in each chunk. - chunk_overlap (int): The number of characters from the previous chunk to overlap with each chunk. Must be less than `chunk_size`. + chunk_overlap (int): The number of characters from the previous chunk to overlap + with each chunk. Must be less than `chunk_size`. + approximate (bool): If True, avoids splitting words in the middle at chunk + boundaries. Defaults to True. 
+ Example: @@ -33,16 +87,21 @@ class FixedSizeSplitter(TextSplitter): from neo4j_graphrag.experimental.pipeline import Pipeline pipeline = Pipeline() - text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200) + text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=True) pipeline.add_component(text_splitter, "text_splitter") """ @validate_call - def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200) -> None: + def __init__( + self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True + ) -> None: + if chunk_size <= 0: + raise ValueError("chunk_size must be strictly greater than 0") if chunk_overlap >= chunk_size: raise ValueError("chunk_overlap must be strictly less than chunk_size") self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap + self.approximate = approximate @validate_call async def run(self, text: str) -> TextChunks: @@ -56,10 +115,35 @@ async def run(self, text: str) -> TextChunks: """ chunks = [] index = 0 - for i in range(0, len(text), self.chunk_size - self.chunk_overlap): - start = i - end = min(start + self.chunk_size, len(text)) + step = self.chunk_size - self.chunk_overlap + text_length = len(text) + approximate_start = 0 + skip_adjust_chunk_start = False + end = 0 + + while end < text_length: + if self.approximate: + start = ( + approximate_start + if skip_adjust_chunk_start + else _adjust_chunk_start(text, approximate_start) + ) + # adjust start and end to avoid cutting words in the middle + approximate_end = min(start + self.chunk_size, text_length) + end = _adjust_chunk_end(text, start, approximate_end) + # when avoiding splitting words in the middle is not possible, revert to + # initial chunk end and skip adjusting next chunk start + skip_adjust_chunk_start = end == approximate_end + else: + # apply fixed size splitting with possibly words cut in half at chunk + # boundaries + start = approximate_start + end = min(start + self.chunk_size, text_length) + 
chunk_text = text[start:end] chunks.append(TextChunk(text=chunk_text, index=index)) index += 1 + + approximate_start = start + step + return TextChunks(chunks=chunks) diff --git a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py index 0467201f..d03006dd 100644 --- a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py +++ b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py @@ -17,6 +17,8 @@ import pytest from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, + _adjust_chunk_end, + _adjust_chunk_start, ) from neo4j_graphrag.experimental.components.types import TextChunk @@ -26,7 +28,8 @@ async def test_split_text_no_overlap() -> None: text = "may thy knife chip and shatter" chunk_size = 5 chunk_overlap = 0 - splitter = FixedSizeSplitter(chunk_size, chunk_overlap) + approximate = False + splitter = FixedSizeSplitter(chunk_size, chunk_overlap, approximate) chunks = await splitter.run(text) expected_chunks = [ TextChunk(text="may t", index=0), @@ -47,7 +50,8 @@ async def test_split_text_with_overlap() -> None: text = "may thy knife chip and shatter" chunk_size = 10 chunk_overlap = 2 - splitter = FixedSizeSplitter(chunk_size, chunk_overlap) + approximate = False + splitter = FixedSizeSplitter(chunk_size, chunk_overlap, approximate) chunks = await splitter.run(text) expected_chunks = [ TextChunk(text="may thy kn", index=0), @@ -66,7 +70,8 @@ async def test_split_text_empty_string() -> None: text = "" chunk_size = 5 chunk_overlap = 1 - splitter = FixedSizeSplitter(chunk_size, chunk_overlap) + approximate = False + splitter = FixedSizeSplitter(chunk_size, chunk_overlap, approximate) chunks = await splitter.run(text) assert chunks.chunks == [] @@ -75,3 +80,135 @@ def test_invalid_chunk_overlap() -> None: with pytest.raises(ValueError) as excinfo: 
FixedSizeSplitter(5, 5) assert "chunk_overlap must be strictly less than chunk_size" in str(excinfo) + + +def test_invalid_chunk_size() -> None: + with pytest.raises(ValueError) as excinfo: + FixedSizeSplitter(0, 0) + assert "chunk_size must be strictly greater than 0" in str(excinfo) + + +@pytest.mark.parametrize( + "text, approximate_start, expected_start", + [ + # Case: approximate_start is at word boundary already + ("Hello World", 6, 6), + # Case: approximate_start is at a whitespace already + ("Hello World", 5, 5), + # Case: approximate_start is at the middle of word and no whitespace is found + ("Hello World", 2, 2), + # Case: approximate_start is at the middle of a word + ("Hello World", 8, 6), + # Case: approximate_start = 0 + ("Hello World", 0, 0), + ], +) +def test_adjust_chunk_start( + text: str, approximate_start: int, expected_start: int +) -> None: + """ + Test that the _adjust_chunk_start function correctly shifts + the start index to avoid breaking words, unless no whitespace is found. + """ + result = _adjust_chunk_start(text, approximate_start) + assert result == expected_start + + +@pytest.mark.parametrize( + "text, start, approximate_end, expected_end", + [ + # Case: approximate_end is at word boundary already + ("Hello World", 0, 5, 5), + # Case: approximate_end is at the middle of a word + ("Hello World", 0, 8, 6), + # Case: approximate_end is at the middle of word and no whitespace is found + ("Hello World", 0, 3, 3), + # Case: adjusted_end == start => fallback to approximate_end + ("Hello World", 6, 7, 7), + # Case: end>=len(text) + ("Hello World", 6, 15, 15), + ], +) +def test_adjust_chunk_end( + text: str, start: int, approximate_end: int, expected_end: int +) -> None: + """ + Test that the _adjust_chunk_end function correctly shifts + the end index to avoid breaking words, unless no whitespace is found. 
+ """ + result = _adjust_chunk_end(text, start, approximate_end) + assert result == expected_end + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "text, chunk_size, chunk_overlap, approximate, expected_chunks", + [ + # Case: approximate fixed size splitting + ( + "Hello World, this is a test message.", + 10, + 2, + True, + ["Hello ", "World, ", "this is a ", "a test ", "message."], + ), + # Case: fixed size splitting + ( + "Hello World, this is a test message.", + 10, + 2, + False, + ["Hello Worl", "rld, this ", "s is a tes", "est messag", "age."], + ), + # Case: short text => only one chunk + ( + "Short text", + 20, + 5, + True, + ["Short text"], + ), + # Case: short text => only one chunk + ( + "Short text", + 12, + 4, + True, + ["Short text"], + ), + # Case: text with no spaces + ( + "1234567890", + 5, + 1, + True, + ["12345", "56789", "90"], + ), + ], +) +async def test_fixed_size_splitter_run( + text: str, + chunk_size: int, + chunk_overlap: int, + approximate: bool, + expected_chunks: list[str], +) -> None: + """ + Test that 'FixedSizeSplitter.run' returns the expected chunks + for different configurations. + """ + splitter = FixedSizeSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + approximate=approximate, + ) + text_chunks = await splitter.run(text) + + # Verify number of chunks + assert len(text_chunks.chunks) == len(expected_chunks) + + # Verify content of each chunk + for i, expected_text in enumerate(expected_chunks): + assert text_chunks.chunks[i].text == expected_text + assert isinstance(text_chunks.chunks[i], TextChunk) + assert text_chunks.chunks[i].index == i