Skip to content

Commit

Permalink
Updated reduce.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rrawatt committed Jan 24, 2025
1 parent a377c20 commit b656138
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def __init__(self, *args, **kwargs):
)
self.intermediates = {}
self.lineage_keys = self.config.get("output", {}).get("lineage", [])
self.scratchpad_cache = {}

def clear_scratchpad_cache(self):
self.scratchpad_cache.clear()

def syntax_check(self) -> None:
"""
Expand Down Expand Up @@ -774,7 +778,8 @@ def _increment_fold(
batch: List[Dict],
current_output: Optional[Dict],
scratchpad: Optional[str] = None,
) -> Tuple[Optional[Dict], str, float]:
) -> Tuple[Optional[Dict], str, float]:

"""
Perform an incremental fold operation on a batch of items.
Expand All @@ -789,6 +794,10 @@ def _increment_fold(
Tuple[Optional[Dict], str, float]: A tuple containing the folded output (or None if processing failed),
the prompt used, and the cost of the fold operation.
"""
cache_key = (key, tuple(sorted(batch, key=lambda x: str(x))), current_output, scratchpad)
if cache_key in self.scratchpad_cache:
return self.scratchpad_cache[cache_key]

if current_output is None:
return self._batch_reduce(key, batch, scratchpad)

Expand Down Expand Up @@ -833,6 +842,7 @@ def _increment_fold(

folded_output.update(dict(zip(self.config["reduce_key"], key)))
fold_cost = response.total_cost
self.scratchpad_cache[cache_key] = (folded_output, fold_prompt, fold_cost)

return folded_output, fold_prompt, fold_cost

Expand All @@ -854,6 +864,9 @@ def _merge_results(
Tuple[Optional[Dict], str, float]: A tuple containing the merged output (or None if processing failed),
the prompt used, and the cost of the merge operation.
"""
cache_key = (key, tuple(sorted(outputs, key=lambda x: str(x))))
if cache_key in self.scratchpad_cache:
return self.scratchpad_cache[cache_key]
start_time = time.time()
merge_prompt = strict_render(self.config["merge_prompt"], {
"outputs": outputs,
Expand Down Expand Up @@ -891,6 +904,7 @@ def _merge_results(
)[0]
merged_output.update(dict(zip(self.config["reduce_key"], key)))
merge_cost = response.total_cost
self.scratchpad_cache[cache_key] = (merged_output, merge_prompt, merge_cost)
return merged_output, merge_prompt, merge_cost

return None, merge_prompt, merge_cost
Expand Down Expand Up @@ -948,6 +962,9 @@ def _update_merge_time(self, time: float) -> None:
def _batch_reduce(
self, key: Tuple, group_list: List[Dict], scratchpad: Optional[str] = None
) -> Tuple[Optional[Dict], str, float]:
cache_key = (key, tuple(sorted(group_list, key=lambda x: str(x))), scratchpad)
if cache_key in self.scratchpad_cache:
return self.scratchpad_cache[cache_key]
"""
Perform a batch reduce operation on a group of items.
Expand Down Expand Up @@ -999,6 +1016,6 @@ def _batch_reduce(
manually_fix_errors=self.manually_fix_errors,
)[0]
output.update(dict(zip(self.config["reduce_key"], key)))

self.scratchpad_cache[cache_key] = (output, prompt, item_cost)
return output, prompt, item_cost
return None, prompt, item_cost

0 comments on commit b656138

Please sign in to comment.