Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve text splitter to avoid cutting words in chunks #242

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
### Changed
- Updated LLM implementations to handle message history consistently across providers.
- The `id_prefix` parameter in the `LexicalGraphConfig` is deprecated.
- Changed the default behaviour of `FixedSizeSplitter` to avoid cutting words in the middle of chunks, whenever possible.

### Fixed
- IDs for the Document and Chunk nodes in the lexical graph are now randomly generated and unique across multiple runs, fixing issues in the lexical graph where relationships were created between chunks that were created by different pipeline runs.
Expand Down
5 changes: 4 additions & 1 deletion docs/source/user_guide_kg_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -581,9 +581,12 @@ that can be processed within the LLM token limits:

from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter

splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False)
splitter.run(text="Hello World. Life is beautiful.")

.. note::

    The `approximate` flag is set to True by default to ensure clean chunk starts and ends (i.e. to avoid cutting words in the middle) whenever possible.

Wrappers for LangChain and LlamaIndex text splitters are included in this package:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

async def main() -> TextChunks:
splitter = FixedSizeSplitter(
# optionally, configure chunk_size and chunk_overlap
# optionally, configure chunk_size, chunk_overlap, and approximate flag
# chunk_size=4000,
# chunk_overlap=200,
# approximate = False
)
chunks = await splitter.run(text="text to split")
return chunks
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import asyncio
import logging

import neo4j
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
LLMEntityRelationExtractor,
OnError,
Expand All @@ -35,6 +34,8 @@
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

import neo4j

logging.basicConfig(level=logging.INFO)


Expand Down Expand Up @@ -83,7 +84,8 @@ async def define_and_run_pipeline(
pipe = Pipeline()
pipe.add_component(PdfLoader(), "pdf_loader")
pipe.add_component(
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter"
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False),
"splitter",
)
pipe.add_component(SchemaBuilder(), "schema")
pipe.add_component(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import asyncio

import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
Expand All @@ -37,6 +36,8 @@
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

import neo4j


async def define_and_run_pipeline(
neo4j_driver: neo4j.Driver, llm: LLMInterface
Expand All @@ -58,7 +59,7 @@ async def define_and_run_pipeline(
# define the components
pipe.add_component(
# chunk_size=50 for the sake of this demo
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200),
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio

import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
Expand All @@ -14,6 +13,8 @@
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult

import neo4j


async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
"""This is where we define and run the Lexical Graph builder pipeline, instantiating
Expand All @@ -27,7 +28,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
pipe = Pipeline()
# define the components
pipe.add_component(
FixedSizeSplitter(chunk_size=20, chunk_overlap=1),
FixedSizeSplitter(chunk_size=20, chunk_overlap=1, approximate=False),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import asyncio

import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
Expand All @@ -29,6 +28,8 @@
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

import neo4j


async def define_and_run_pipeline(
neo4j_driver: neo4j.Driver,
Expand Down Expand Up @@ -56,7 +57,7 @@ async def define_and_run_pipeline(
pipe = Pipeline()
# define the components
pipe.add_component(
FixedSizeSplitter(chunk_size=200, chunk_overlap=50),
FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import asyncio

import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
Expand All @@ -31,6 +30,8 @@
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

import neo4j


async def build_lexical_graph(
neo4j_driver: neo4j.Driver,
Expand All @@ -47,7 +48,7 @@ async def build_lexical_graph(
pipe = Pipeline()
# define the components
pipe.add_component(
FixedSizeSplitter(chunk_size=200, chunk_overlap=50),
FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,66 @@
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks


def _adjust_chunk_start(text: str, approximate_start: int) -> int:
"""
Shift the starting index backward if it lands in the middle of a word.
If no whitespace is found, use the proposed start.

Args:
text (str): The text being split.
approximate_start (int): The initial starting index of the chunk.

Returns:
int: The adjusted starting index, ensuring the chunk does not begin in the
middle of a word if possible.
"""
start = approximate_start
if start > 0 and not text[start].isspace() and not text[start - 1].isspace():
while start > 0 and not text[start - 1].isspace():
stellasia marked this conversation as resolved.
Show resolved Hide resolved
start -= 1

# fallback if no whitespace is found
if start == 0 and not text[0].isspace():
start = approximate_start
return start


def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int:
stellasia marked this conversation as resolved.
Show resolved Hide resolved
"""
Shift the ending index backward if it lands in the middle of a word.
If no whitespace is found, use 'approximate_end'.

Args:
text (str): The full text being split.
start (int): The adjusted starting index for this chunk.
approximate_end (int): The initial end index.

Returns:
int: The adjusted ending index, ensuring the chunk does not end in the middle of
a word if possible.
"""
end = approximate_end
if end < len(text):
while end > start and not text[end].isspace() and not text[end - 1].isspace():
end -= 1

# fallback if no whitespace is found
if end == start:
end = approximate_end
return end


class FixedSizeSplitter(TextSplitter):
    """Text splitter which splits the input text into fixed or approximate fixed size
    chunks with optional overlap.

    Args:
        chunk_size (int): The number of characters in each chunk.
        chunk_overlap (int): The number of characters from the previous chunk to overlap
            with each chunk. Must be less than `chunk_size`.
        approximate (bool): If True, avoids splitting words in the middle at chunk
            boundaries. Defaults to True.

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter
        from neo4j_graphrag.experimental.pipeline import Pipeline

        pipeline = Pipeline()
        text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=True)
        pipeline.add_component(text_splitter, "text_splitter")
    """

    @validate_call
    def __init__(
        self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True
    ) -> None:
        if chunk_size <= 0:
            raise ValueError("chunk_size must be strictly greater than 0")
        if chunk_overlap >= chunk_size:
            raise ValueError("chunk_overlap must be strictly less than chunk_size")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.approximate = approximate

    @validate_call
    async def run(self, text: str) -> TextChunks:
        """Splits a piece of text into chunks.

        Args:
            text (str): The text to be split.

        Returns:
            TextChunks: A list of chunks.
        """
        chunks = []
        index = 0
        step = self.chunk_size - self.chunk_overlap
        text_length = len(text)
        approximate_start = 0
        skip_adjust_chunk_start = False
        end = 0
        # Start index of the previously emitted chunk; used to guarantee the
        # loop always makes forward progress (see BUGFIX below).
        previous_start = -1

        while end < text_length:
            if self.approximate:
                start = (
                    approximate_start
                    if skip_adjust_chunk_start
                    else _adjust_chunk_start(text, approximate_start)
                )
                # BUGFIX: if the word-boundary adjustment moved the start back
                # to (or before) the previous chunk's start — possible when the
                # step (chunk_size - chunk_overlap) is smaller than the word
                # being walked over — the loop state would repeat and run()
                # would never terminate (e.g. chunk_size=7, chunk_overlap=5 on
                # " abcd efgh"). Fall back to the unadjusted start so each
                # chunk starts strictly after the previous one.
                if start <= previous_start:
                    start = approximate_start
                # adjust end to avoid cutting words in the middle
                approximate_end = min(start + self.chunk_size, text_length)
                end = _adjust_chunk_end(text, start, approximate_end)
                # when avoiding splitting words in the middle is not possible,
                # revert to initial chunk end and skip adjusting next chunk start
                skip_adjust_chunk_start = end == approximate_end
            else:
                # apply fixed size splitting with possibly words cut in half at
                # chunk boundaries
                start = approximate_start
                end = min(start + self.chunk_size, text_length)

            chunk_text = text[start:end]
            chunks.append(TextChunk(text=chunk_text, index=index))
            index += 1

            previous_start = start
            approximate_start = start + step

        return TextChunks(chunks=chunks)
Loading
Loading