Skip to content

Commit

Permalink
Refactored to have a more efficient batchwriter, and also not allow t…
Browse files Browse the repository at this point in the history
…oo many running futures
  • Loading branch information
jakep-allenai committed Oct 23, 2024
1 parent d99096e commit 38dc5a2
Showing 1 changed file with 144 additions and 66 deletions.
210 changes: 144 additions & 66 deletions pdelfin/birrpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sqlite3
import json
import argparse
import glob
import uuid
import tempfile
import datetime
import posixpath
Expand All @@ -19,6 +19,7 @@
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
from urllib.parse import urlparse
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, as_completed

from pdelfin.data.renderpdf import render_pdf_to_base64png
Expand Down Expand Up @@ -241,85 +242,119 @@ def close(self):
self.conn.close()


# Writes batches of lines out to a set of files, keeping each file below some maximum size
class BatchWriter:
def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
def __init__(
self,
output_prefix: str,
max_size_mb: int = 250,
after_flush: Optional[Callable[[List[Any]], Any]] = None,
):
self.output_prefix = output_prefix
self.max_size = max_size_mb * 1024 * 1024 # Convert MB to bytes
self.batch = []
self.batch_objects = []
self.batch_size = 0
self.after_flush = after_flush
self.threads = []
self.temp_file = None # The temporary file object
self.temp_file_path = None # Path to the temporary file

parsed = urlparse(output_prefix)
self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
self.is_s3 = parsed.scheme in ("s3", "s3a", "s3n")

if not self.is_s3:
os.makedirs(output_prefix, exist_ok=True)

def _compute_hash(self, content: str) -> str:
"""Compute a 20-character SHA1 hash of the given content."""
sha1 = hashlib.sha1()
sha1.update(content.encode('utf-8'))
return sha1.hexdigest()[:20]

def _get_output_path(self, hash_str: str) -> str:
"""Generate the full output path with hash in the filename."""
parsed = urlparse(self.output_prefix)
if self.is_s3:
bucket = parsed.netloc
key = parsed.path.lstrip('/')
if key and not key.endswith('/'):
key += '/'
full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
return f"s3://{bucket}/{full_key}"
else:
filename = f"output_{hash_str}.jsonl"
return os.path.join(self.output_prefix, filename)

def write_line(self, line: Optional[str]):
if line is None or not line.strip():
def write_line(self, obj: Optional[Any]):
if obj is None:
return

line_size = len(line.encode('utf-8')) + 1 # +1 for newline
line_bytes = json.dumps(obj, ensure_ascii=False).encode("utf-8")
line_size = len(line_bytes) + 1 # +1 for newline

if self.batch_size + line_size > self.max_size:
self._write_batch()

self.batch.append(line)
if self.batch_size == 0:
# Open a new temporary file
self.temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
self.temp_file_path = self.temp_file.name

self.temp_file.write(line_bytes + b"\n")
self.batch_objects.append(obj)
self.batch_size += line_size

def _write_batch(self):
if not self.batch:
if self.batch_size == 0:
return

batch_lines = self.batch.copy()
batch_content = "\n".join(batch_lines) + "\n"
hash_str = self._compute_hash(batch_content)
output_path = self._get_output_path(hash_str)
# Close the temp file
self.temp_file.flush()
self.temp_file.close()

# Start a new thread to write the batch
# Start a new thread to upload the temp file
thread = threading.Thread(
target=self._write_batch_to_file,
args=(batch_content, output_path, batch_lines)
target=self._write_batch_to_file, args=(self.temp_file_path, self.batch_objects)
)
thread.start()
self.threads.append(thread)

# Clear the batch and batch_size
self.batch = []
# Reset batch_objects and batch_size
self.batch_objects = []
self.batch_size = 0
self.temp_file = None
self.temp_file_path = None

def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
# Compute hash based on file content
hash_str = self._compute_hash(temp_file_path)
output_path = self._get_output_path(hash_str)

def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
if self.is_s3:
put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
# Use s3 upload_file
parsed = urlparse(output_path)
bucket = parsed.netloc
key = parsed.path.lstrip("/")

# Use the s3 client directly
try:
workspace_s3.upload_file(temp_file_path, bucket, key)
except Exception as e:
print(f"Failed to upload {temp_file_path} to {output_path}: {e}")
else:
with open(output_path, 'w', encoding='utf-8') as f_out:
f_out.write(batch_content)
# Move the temp file to the output path
os.rename(temp_file_path, output_path)

# After writing, call the after_flush callback if it is set
if self.after_flush:
self.after_flush(batch_lines)
self.after_flush(batch_objects)

# Delete the temporary file
os.remove(temp_file_path)

def _compute_hash(self, temp_file_path: str) -> str:
"""Compute a 20-character SHA1 hash of the file content."""
sha1 = hashlib.sha1()
with open(temp_file_path, "rb") as f:
while True:
data = f.read(1024*1024)
if not data:
break
sha1.update(data)
return sha1.hexdigest()[:20]

def _get_output_path(self, hash_str: str) -> str:
"""Generate the full output path with hash in the filename."""
parsed = urlparse(self.output_prefix)
if self.is_s3:
bucket = parsed.netloc
key = parsed.path.lstrip("/")
if key and not key.endswith("/"):
key += "/"
full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
return f"s3://{bucket}/{full_key}"
else:
filename = f"output_{hash_str}.jsonl"
return os.path.join(self.output_prefix, filename)

def close(self):
self._write_batch()
Expand Down Expand Up @@ -520,11 +555,11 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option

return dolma_doc

def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
db = DatabaseManager(s3_workspace)

for line in dolma_doc_lines:
db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
for doc in dolma_docs:
db.update_pdf_status(doc["metadata"]["Source-File"], "completed")

def get_current_round(s3_workspace: str) -> int:
path = s3_workspace[5:]
Expand Down Expand Up @@ -610,13 +645,13 @@ def get_current_round(s3_workspace: str) -> int:
print("Indexing all batch inference sent to this workspace")
inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")

inference_output_paths = [
(s3_path, etag) for s3_path, etag in inference_output_paths.items()
inference_output_paths = {
s3_path: etag for s3_path, etag in inference_output_paths.items()
if not db.is_file_processed(s3_path, etag)
]
}

print(f"Found {len(inference_output_paths):,} new batch inference results to index")
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}

for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
s3_path, etag = future_to_path[future]
Expand All @@ -638,29 +673,72 @@ def get_current_round(s3_workspace: str) -> int:
potentially_done_pdfs = db.get_pdfs_by_status("pending")
else:
print(f"\nCreating batch inference files for new PDFs")
future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim, args.target_anchor_text_len): pdf for pdf in db.get_pdfs_by_status("pending")}
pdf_list = list(db.get_pdfs_by_status("pending"))
pdf_iter = iter(pdf_list)
pending_futures = {}
potentially_done_pdfs = []
lines_written = 0
new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)

for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
pdf = future_to_path[future]
inference_lines = future.result()

if len(inference_lines) == 0:
potentially_done_pdfs.append(pdf)

for line in inference_lines:
lines_written += 1

if line is not None:
new_inference_writer.write_line(json.dumps(line))
total_pdfs = len(pdf_list)
max_pending = 5000

with tqdm(total=total_pdfs) as pbar:
# Submit initial batch of futures
for _ in range(min(max_pending, total_pdfs)):
pdf = next(pdf_iter)
future = executor.submit(
build_pdf_queries,
args.workspace,
pdf,
current_round,
args.target_longest_image_dim,
args.target_anchor_text_len,
)
pending_futures[future] = pdf

while pending_futures:
# Wait for the next future to complete
done, _ = concurrent.futures.wait(
pending_futures.keys(),
return_when=concurrent.futures.FIRST_COMPLETED,
)

for future in done:
pdf = pending_futures.pop(future)
inference_lines = future.result()

if len(inference_lines) == 0:
potentially_done_pdfs.append(pdf)

for line in inference_lines:
lines_written += 1

if line is not None:
new_inference_writer.write_line(line)

pbar.update(1)

# Submit a new future if there are more PDFs
try:
pdf = next(pdf_iter)
new_future = executor.submit(
build_pdf_queries,
args.workspace,
pdf,
current_round,
args.target_longest_image_dim,
args.target_anchor_text_len,
)
pending_futures[new_future] = pdf
except StopIteration:
pass # No more PDFs to process

new_inference_writer.close()

if lines_written > 0:
print(f"Added {lines_written:,} new batch inference requests")


# Now, finally, assemble any potentially done docs into dolma documents
print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
Expand All @@ -671,7 +749,7 @@ def get_current_round(s3_workspace: str) -> int:
dolma_doc = future.result()

if dolma_doc is not None:
new_output_writer.write_line(json.dumps(dolma_doc))
new_output_writer.write_line(dolma_doc)

new_output_writer.close()

Expand Down

0 comments on commit 38dc5a2

Please sign in to comment.