Improve text splitter to avoid cutting words in chunks (#242)
* Improve text splitter to avoid cutting words in chunks
NathalieCharbel authored Jan 21, 2025
1 parent 68cba61 commit 191b1b1
Showing 10 changed files with 254 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -28,6 +28,7 @@
### Changed
- Updated LLM implementations to handle message history consistently across providers.
- The `id_prefix` parameter in the `LexicalGraphConfig` is deprecated.
+- Changed the default behaviour of `FixedSizeSplitter` to avoid cutting words in the middle of chunks whenever possible.

### Fixed
- IDs for the Document and Chunk nodes in the lexical graph are now randomly generated and unique across multiple runs, fixing issues in the lexical graph where relationships were created between chunks that were created by different pipeline runs.
5 changes: 4 additions & 1 deletion docs/source/user_guide_kg_builder.rst
@@ -581,9 +581,12 @@ that can be processed within the LLM token limits:
   from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter

-  splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
+  splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False)
   splitter.run(text="Hello World. Life is beautiful.")

+.. note::
+
+   The `approximate` flag is set to True by default to ensure clean chunk starts and ends (i.e. to avoid cutting words in the middle) whenever possible.
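
To make the effect of the flag concrete, here is a minimal sketch (not part of the commit; the printed chunk lists are hand-computed from the implementation in this diff, using a deliberately small chunk_size):

    import asyncio

    from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter

    text = "Hello World. Life is beautiful."

    # approximate=True (the new default): boundaries snap back to whitespace
    approx = FixedSizeSplitter(chunk_size=10, chunk_overlap=2)
    print([c.text for c in asyncio.run(approx.run(text=text)).chunks])
    # ['Hello ', 'World. ', 'Life is ', 'beautiful.']

    # approximate=False: exact fixed-size chunks, words may be cut in half
    exact = FixedSizeSplitter(chunk_size=10, chunk_overlap=2, approximate=False)
    print([c.text for c in asyncio.run(exact.run(text=text)).chunks])
    # ['Hello Worl', 'rld. Life ', 'e is beaut', 'utiful.']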

Wrappers for LangChain and LlamaIndex text splitters are included in this package:

@@ -6,9 +6,10 @@

async def main() -> TextChunks:
    splitter = FixedSizeSplitter(
-        # optionally, configure chunk_size and chunk_overlap
+        # optionally, configure chunk_size, chunk_overlap, and approximate flag
        # chunk_size=4000,
        # chunk_overlap=200,
+        # approximate=False
    )
    chunks = await splitter.run(text="text to split")
    return chunks
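
To run this example as a standalone script, a hypothetical driver (not part of the diff) could look like:

    import asyncio

    if __name__ == "__main__":
        chunks = asyncio.run(main())
        print(f"{len(chunks.chunks)} chunks produced")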
@@ -17,7 +17,6 @@
import asyncio
import logging

-import neo4j
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
LLMEntityRelationExtractor,
OnError,
@@ -35,6 +34,8 @@
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

+import neo4j
+
logging.basicConfig(level=logging.INFO)


@@ -83,7 +84,8 @@ async def define_and_run_pipeline(
pipe = Pipeline()
pipe.add_component(PdfLoader(), "pdf_loader")
pipe.add_component(
-        FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter"
+        FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False),
+        "splitter",
)
pipe.add_component(SchemaBuilder(), "schema")
pipe.add_component(
@@ -16,7 +16,6 @@

import asyncio

-import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -37,6 +36,8 @@
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

+import neo4j
+

async def define_and_run_pipeline(
neo4j_driver: neo4j.Driver, llm: LLMInterface
@@ -58,7 +59,7 @@ async def define_and_run_pipeline(
# define the components
pipe.add_component(
# chunk_size=50 for the sake of this demo
-        FixedSizeSplitter(chunk_size=4000, chunk_overlap=200),
+        FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
@@ -2,7 +2,6 @@

import asyncio

-import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
@@ -14,6 +13,8 @@
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult

+import neo4j
+

async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
"""This is where we define and run the Lexical Graph builder pipeline, instantiating
@@ -27,7 +28,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
pipe = Pipeline()
# define the components
pipe.add_component(
-        FixedSizeSplitter(chunk_size=20, chunk_overlap=1),
+        FixedSizeSplitter(chunk_size=20, chunk_overlap=1, approximate=False),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
@@ -7,7 +7,6 @@

import asyncio

-import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -29,6 +28,8 @@
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

+import neo4j
+

async def define_and_run_pipeline(
neo4j_driver: neo4j.Driver,
@@ -56,7 +57,7 @@ async def define_and_run_pipeline(
pipe = Pipeline()
# define the components
pipe.add_component(
-        FixedSizeSplitter(chunk_size=200, chunk_overlap=50),
+        FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
@@ -8,7 +8,6 @@

import asyncio

-import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -31,6 +30,8 @@
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

+import neo4j
+

async def build_lexical_graph(
neo4j_driver: neo4j.Driver,
@@ -47,7 +48,7 @@ async def build_lexical_graph(
pipe = Pipeline()
# define the components
pipe.add_component(
-        FixedSizeSplitter(chunk_size=200, chunk_overlap=50),
+        FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
@@ -18,12 +18,66 @@
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks


+def _adjust_chunk_start(text: str, approximate_start: int) -> int:
+    """
+    Shift the starting index backward if it lands in the middle of a word.
+    If no whitespace is found, use the proposed start.
+
+    Args:
+        text (str): The text being split.
+        approximate_start (int): The initial starting index of the chunk.
+
+    Returns:
+        int: The adjusted starting index, ensuring the chunk does not begin in the
+            middle of a word if possible.
+    """
+    start = approximate_start
+    if start > 0 and not text[start].isspace() and not text[start - 1].isspace():
+        while start > 0 and not text[start - 1].isspace():
+            start -= 1
+
+    # fallback if no whitespace is found
+    if start == 0 and not text[0].isspace():
+        start = approximate_start
+    return start
+
+
+def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int:
+    """
+    Shift the ending index backward if it lands in the middle of a word.
+    If no whitespace is found, use 'approximate_end'.
+
+    Args:
+        text (str): The full text being split.
+        start (int): The adjusted starting index for this chunk.
+        approximate_end (int): The initial end index.
+
+    Returns:
+        int: The adjusted ending index, ensuring the chunk does not end in the middle of
+            a word if possible.
+    """
+    end = approximate_end
+    if end < len(text):
+        while end > start and not text[end].isspace() and not text[end - 1].isspace():
+            end -= 1
+
+    # fallback if no whitespace is found
+    if end == start:
+        end = approximate_end
+    return end
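
As a quick illustration of these helpers (hypothetical REPL session; the values are hand-computed from the code above):

    >>> _adjust_chunk_start("hello world", approximate_start=8)
    6
    >>> _adjust_chunk_end("hello world", start=0, approximate_end=8)
    6

Index 8 falls inside "world", so both boundaries snap back to index 6, where that word starts.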


class FixedSizeSplitter(TextSplitter):
-    """Text splitter which splits the input text into fixed size chunks with optional overlap.
+    """Text splitter which splits the input text into fixed-size or approximately
+    fixed-size chunks with optional overlap.

    Args:
        chunk_size (int): The number of characters in each chunk.
-        chunk_overlap (int): The number of characters from the previous chunk to overlap with each chunk. Must be less than `chunk_size`.
+        chunk_overlap (int): The number of characters from the previous chunk to overlap
+            with each chunk. Must be less than `chunk_size`.
+        approximate (bool): If True, avoids splitting words in the middle at chunk
+            boundaries. Defaults to True.
Example:
@@ -33,16 +87,21 @@ class FixedSizeSplitter(TextSplitter):
from neo4j_graphrag.experimental.pipeline import Pipeline
pipeline = Pipeline()
-        text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
+        text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=True)
pipeline.add_component(text_splitter, "text_splitter")
"""

@validate_call
-    def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200) -> None:
+    def __init__(
+        self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True
+    ) -> None:
if chunk_size <= 0:
raise ValueError("chunk_size must be strictly greater than 0")
if chunk_overlap >= chunk_size:
raise ValueError("chunk_overlap must be strictly less than chunk_size")
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
+        self.approximate = approximate
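
The guards above reject degenerate configurations as soon as the splitter is built; a hypothetical REPL session, assuming the checks shown in this diff:

    >>> FixedSizeSplitter(chunk_size=100, chunk_overlap=100)
    Traceback (most recent call last):
        ...
    ValueError: chunk_overlap must be strictly less than chunk_size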

@validate_call
async def run(self, text: str) -> TextChunks:
@@ -56,10 +115,35 @@ async def run(self, text: str) -> TextChunks:
"""
chunks = []
index = 0
-        for i in range(0, len(text), self.chunk_size - self.chunk_overlap):
-            start = i
-            end = min(start + self.chunk_size, len(text))
+        step = self.chunk_size - self.chunk_overlap
+        text_length = len(text)
+        approximate_start = 0
+        skip_adjust_chunk_start = False
+        end = 0
+
+        while end < text_length:
+            if self.approximate:
+                # adjust start and end so that words are not cut in the middle
+                start = (
+                    approximate_start
+                    if skip_adjust_chunk_start
+                    else _adjust_chunk_start(text, approximate_start)
+                )
+                approximate_end = min(start + self.chunk_size, text_length)
+                end = _adjust_chunk_end(text, start, approximate_end)
+                # when a word cut cannot be avoided, keep the approximate chunk end
+                # and skip adjusting the next chunk's start, which guarantees
+                # forward progress
+                skip_adjust_chunk_start = end == approximate_end
+            else:
+                # plain fixed-size splitting; words may be cut in half at chunk
+                # boundaries
+                start = approximate_start
+                end = min(start + self.chunk_size, text_length)
+
            chunk_text = text[start:end]
            chunks.append(TextChunk(text=chunk_text, index=index))
            index += 1
+
+            approximate_start = start + step
+
return TextChunks(chunks=chunks)
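
For intuition about the fallback path, a small sanity check (hypothetical snippet; the expected output is hand-computed from the code above). When the text contains no whitespace at all, `_adjust_chunk_end` reverts to the approximate end and `skip_adjust_chunk_start` is set, so splitting degrades gracefully to plain fixed-size chunks while still honouring the requested overlap:

    import asyncio

    from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter

    splitter = FixedSizeSplitter(chunk_size=10, chunk_overlap=2)
    result = asyncio.run(splitter.run(text="abcdefghijklmnopqrstuvwxyz"))
    print([chunk.text for chunk in result.chunks])
    # expected: ['abcdefghij', 'ijklmnopqr', 'qrstuvwxyz'] (each chunk repeats the
    # previous chunk's last 2 characters)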