
Local pdf support
jakep-allenai committed Jan 28, 2025
1 parent dbe5487 commit 7882944
Showing 2 changed files with 15 additions and 10 deletions.
19 changes: 9 additions & 10 deletions olmocr/pipeline.py
@@ -915,31 +915,30 @@ async def main():
     # Expand s3 paths
     if args.pdfs.startswith("s3://"):
         logger.info(f"Expanding s3 glob at {args.pdfs}")
-        s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
+        pdf_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
     elif any(char in args.pdfs for char in {"*", "?", "[", "]"}):
         logger.info(f"Expanding local glob at {args.pdfs}")
-        s3_work_paths = glob.glob(args.pdfs)
+        pdf_work_paths = glob.glob(args.pdfs)
     elif os.path.exists(args.pdfs):
         logger.info(f"Loading file at {args.pdfs}")
         with open(args.pdfs, "r") as f:
-            s3_work_paths = list(filter(None, (line.strip() for line in f)))
+            pdf_work_paths = list(filter(None, (line.strip() for line in f)))
     else:
         raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")

-    s3_work_paths = set(s3_work_paths)
-    logger.info(f"Found {len(s3_work_paths):,} total pdf paths to add")
+    pdf_work_paths = set(pdf_work_paths)
+    logger.info(f"Found {len(pdf_work_paths):,} total pdf paths to add")

     # Estimate average pages per pdf
-    sample_size = min(100, len(s3_work_paths))
-    sampled_pdfs = random.sample(list(s3_work_paths), sample_size)
+    sample_size = min(100, len(pdf_work_paths))
+    sampled_pdfs = random.sample(list(pdf_work_paths), sample_size)
     page_counts = []

     for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
         try:
             # Download the PDF to a temp file
             with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
-                s3_bucket, s3_key = parse_s3_path(pdf)
-                pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
+                tmp_file.write(get_s3_bytes(pdf_s3, pdf))
                 tmp_file.flush()
                 reader = PdfReader(tmp_file.name)
                 page_counts.append(len(reader.pages))
@@ -956,7 +955,7 @@ async def main():
     logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")

     # Now call populate_queue
-    await work_queue.populate_queue(s3_work_paths, items_per_group)
+    await work_queue.populate_queue(pdf_work_paths, items_per_group)

     if args.stats:
         print_stats(args)
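Taken together, the rename to pdf_work_paths and the switch to get_s3_bytes mean the page-sampling loop above no longer assumes its inputs live on S3. A minimal sketch of that sampling step under the new code path, assuming pdf_s3 is the boto3 S3 client the pipeline constructs elsewhere and that PdfReader comes from pypdf; count_pages is a hypothetical helper, not part of the commit:

import tempfile

from pypdf import PdfReader  # assumed source of pipeline.py's PdfReader

from olmocr.s3_utils import get_s3_bytes

def count_pages(pdf_s3, pdf_path: str) -> int:
    # pdf_path may be "s3://bucket/key.pdf" or, after this commit, a plain
    # local path; get_s3_bytes (see the s3_utils.py hunk below) handles both.
    with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
        tmp_file.write(get_s3_bytes(pdf_s3, pdf_path))
        tmp_file.flush()
        return len(PdfReader(tmp_file.name).pages)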
6 changes: 6 additions & 0 deletions olmocr/s3_utils.py
@@ -94,6 +94,12 @@ def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:


 def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
+    # Fall back for local files
+    if os.path.exists(s3_path):
+        assert start_index is None and end_index is None, "Range query not supported yet"
+        with open(s3_path, "rb") as f:
+            return f.read()
+
     bucket, key = parse_s3_path(s3_path)

     # Build the range header if start_index and/or end_index are specified
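A short usage sketch of the new fallback; the paths below are hypothetical, and pdf_s3 stands in for the boto3 client the pipeline already passes around:

# Both calls return the raw bytes of the PDF.
local_bytes = get_s3_bytes(pdf_s3, "/tmp/example.pdf")         # path exists locally: read straight from disk
remote_bytes = get_s3_bytes(pdf_s3, "s3://my-bucket/doc.pdf")  # otherwise: parsed by parse_s3_path and fetched from S3

Note that the local branch deliberately rejects start_index/end_index: range reads remain S3-only for now, as the assert message says.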
