fix(chunker): correctly determine chunk midpoint when empty chunks are present (#1800)

Previously, ["foo", '', "bar", 'baz'] was token-counted as 'foobarbaz' rather than 'foo  bar baz' when finding the midpoint index.
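The core of the issue is that joining subchunks with "" collapses empty entries and fuses their neighbors, while joining with the separator preserves the original spacing. A minimal sketch, assuming a toy whitespace tokenizer (count_tokens below is an illustrative stand-in, not griptape's tokenizer):

    subchunks = ["foo", "", "bar", "baz"]

    def count_tokens(text: str) -> int:
        # Toy stand-in for a real tokenizer: counts whitespace-delimited words.
        return len(text.split())

    print(count_tokens("".join(subchunks)))   # old join: "foobarbaz" -> 1 token
    print(count_tokens(" ".join(subchunks)))  # separator join: "foo  bar baz" -> 3 tokens

With the lower count, the midpoint search operated on token totals that did not match the text actually being split.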
collindutter committed Mar 4, 2025
1 parent 4cf9fcb commit 0c1935d
Showing 3 changed files with 33 additions and 19 deletions.
6 changes: 3 additions & 3 deletions griptape/chunkers/base_chunker.py
@@ -62,7 +62,7 @@ def _chunk_recursively(self, chunk: str, current_separator: Optional[ChunkSepara

     if len(non_empty_subchunks) > 1:
         # Find what combination of subchunks results in the most balanced split of the chunk.
-        midpoint_index = self.__find_midpoint_index(subchunks, half_token_count)
+        midpoint_index = self.__find_midpoint_index(separator, subchunks, half_token_count)

         # Create the two subchunks based on the best separator.
         first_subchunk, second_subchunk = self.__get_subchunks(separator, subchunks, midpoint_index)
@@ -98,12 +98,12 @@ def __get_subchunks(self, separator: ChunkSeparator, subchunks: list[str], balan

        return first_subchunk, second_subchunk

-    def __find_midpoint_index(self, subchunks: list[str], half_token_count: int) -> int:
+    def __find_midpoint_index(self, separator: ChunkSeparator, subchunks: list[str], half_token_count: int) -> int:
        midpoint_index = -1
        best_midpoint_distance = float("inf")

        for index, _ in enumerate(subchunks):
-            subchunk_tokens_count = self.tokenizer.count_tokens("".join(subchunks[: index + 1]))
+            subchunk_tokens_count = self.tokenizer.count_tokens(separator.value.join(subchunks[: index + 1]))

            midpoint_distance = abs(subchunk_tokens_count - half_token_count)
            if midpoint_distance < best_midpoint_distance:
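To make the change above concrete, here is a rough, self-contained sketch of the midpoint search with the separator-aware join. The standalone find_midpoint_index and the whitespace-based count_tokens are illustrative assumptions, not griptape's actual method or tokenizer:

    def count_tokens(text: str) -> int:
        # Toy stand-in for a real tokenizer: counts whitespace-delimited words.
        return len(text.split())

    def find_midpoint_index(separator: str, subchunks: list[str], half_token_count: int) -> int:
        # Pick the prefix of subchunks whose joined token count is closest to half of the chunk.
        midpoint_index = -1
        best_midpoint_distance = float("inf")
        for index, _ in enumerate(subchunks):
            # Joining with the separator (not "") keeps empty subchunks from fusing their neighbors.
            subchunk_tokens_count = count_tokens(separator.join(subchunks[: index + 1]))
            midpoint_distance = abs(subchunk_tokens_count - half_token_count)
            if midpoint_distance < best_midpoint_distance:
                best_midpoint_distance = midpoint_distance
                midpoint_index = index
        return midpoint_index

    # With an empty subchunk present, the prefix counts now reflect the real spacing:
    print(find_midpoint_index(" ", ["foo", "", "bar", "baz"], half_token_count=2))  # -> 2

Here the search lands on index 2 (the prefix "foo  bar"), which is what a balanced split of the original text calls for.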
14 changes: 8 additions & 6 deletions tests/unit/chunkers/test_markdown_chunker.py
@@ -24,7 +24,7 @@ def test_chunk(self, chunker):
        ]
        chunks = chunker.chunk("".join(text))

-        assert len(chunks) == 6
+        assert len(chunks) == 7

        for chunk in chunks:
            assert chunker.tokenizer.count_tokens(chunk.value) <= MAX_TOKENS
@@ -33,12 +33,14 @@ def test_chunk(self, chunker):
        assert chunks[1].value.startswith("## Header 2\nfoo-0")
        assert chunks[2].value.startswith("foo-0.")
        assert chunks[3].value.startswith("## Header 3\nfoo-0")
-        assert chunks[4].value.startswith("foo-10.")
-        assert chunks[5].value.startswith("foo-16.")
+        assert chunks[4].value.startswith("foo-5.")
+        assert chunks[5].value.startswith("foo-12.")
+        assert chunks[6].value.startswith("foo-19.")

        assert chunks[0].value.endswith(". foo-5.")
        assert chunks[1].value.endswith(". foo-5.")
        assert chunks[2].value.endswith(". foo-5.")
-        assert chunks[3].value.endswith(". foo-9.")
-        assert chunks[4].value.endswith(". foo-15.")
-        assert chunks[5].value.endswith(". foo-24.")
+        assert chunks[3].value.endswith(". foo-4.")
+        assert chunks[4].value.endswith(". foo-11.")
+        assert chunks[5].value.endswith(". foo-18.")
+        assert chunks[6].value.endswith(". foo-24.")
32 changes: 22 additions & 10 deletions tests/unit/chunkers/test_text_chunker.py
@@ -56,12 +56,12 @@ def test_large_chunks(self, chunker):
            assert chunker.tokenizer.count_tokens(chunk.value) <= MAX_TOKENS

        assert chunks[0].value.startswith("foo-0!")
-        assert chunks[1].value.startswith("foo-11!")
-        assert chunks[2].value.startswith("foo-17!")
+        assert chunks[1].value.startswith("foo-7!")
+        assert chunks[2].value.startswith("foo-13!")
        assert chunks[3].value.startswith("foo-0.")

-        assert chunks[0].value.endswith("! foo-10!")
-        assert chunks[1].value.endswith("! foo-16!")
+        assert chunks[0].value.endswith("! foo-6!")
+        assert chunks[1].value.endswith("! foo-12!")
        assert chunks[2].value.endswith("! foo-24!")
        assert chunks[3].value.endswith(". foo-11.")

@@ -92,19 +92,19 @@ def test_separators(self, chunker):
            assert chunker.tokenizer.count_tokens(chunk.value) <= MAX_TOKENS

        assert chunks[0].value.startswith("foo-0!")
-        assert chunks[1].value.startswith("foo-11!")
-        assert chunks[2].value.startswith("foo-17!")
+        assert chunks[1].value.startswith("foo-7!")
+        assert chunks[2].value.startswith("foo-13!")
        assert chunks[3].value.startswith("foo-0.")
        assert chunks[4].value.startswith("foo-0?")
-        assert chunks[5].value.startswith("foo-9?")
+        assert chunks[5].value.startswith("foo-7?")
        assert chunks[6].value.startswith("foo-0")
        assert chunks[7].value.startswith("foo-8")

-        assert chunks[0].value.endswith("! foo-10!")
-        assert chunks[1].value.endswith("! foo-16!")
+        assert chunks[0].value.endswith("! foo-6!")
+        assert chunks[1].value.endswith("! foo-12!")
        assert chunks[2].value.endswith("! foo-24!")
        assert chunks[3].value.endswith(". foo-11.")
-        assert chunks[4].value.endswith("? foo-8?")
+        assert chunks[4].value.endswith("? foo-6?")
        assert chunks[5].value.endswith("? foo-12?")
        assert chunks[6].value.endswith(" foo-7")
        assert chunks[7].value.endswith(" foo-16")
@@ -138,3 +138,15 @@ def test_artifact_reference(self, chunker):

        for chunk in chunks:
            assert chunk.reference is None
+
+    def test_midpoint_index_empty_subchunks(self, chunker):
+        # This tests that a midpoint index is correctly found when there are some empty subchunks
+        # Previously ["foo", '', "bar", 'baz'] would be token counted as 'foobarbaz' rather than 'foo  bar baz'
+        # when calculating the midpoint index.
+        # https://github.com/griptape-ai/griptape/issues/1796
+        chunker.max_tokens = 3
+
assert len(chunker.chunk("foo bar baz")) == 1
assert len(chunker.chunk("foo bar baz ")) == 2

assert len(chunker.chunk("foo bar baz")) == 2
