Making my parquets
jakep-allenai committed Feb 14, 2025
1 parent 51cfdbd commit 8297955
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions olmocr/train/convertjsontoparquet.py
@@ -10,6 +10,7 @@
import argparse
import glob
import json
+import multiprocessing
import re
import sqlite3
import tempfile
@@ -18,6 +19,7 @@
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Set
import concurrent.futures
+from urllib.parse import urlparse

import boto3
from tqdm import tqdm
@@ -32,11 +34,17 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    it will return "de80a57e6c57b45796d2e020173227f7eae44232".
    """
    # Allow an optional "-<number>" at the end.
-    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
-    match = re.match(pattern, pretty_pdf_path)
-    if match:
-        return match.group(1) + match.group(2)
-    return None
+    if pretty_pdf_path.startswith("s3://ai2-s2-pdfs/"):
+        pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
+        match = re.match(pattern, pretty_pdf_path)
+        if match:
+            return match.group(1) + match.group(2)
+        return None
+    elif pretty_pdf_path.startswith("s3://ai2-oe-data/reganh/iabooks/"):
+        return urlparse(pretty_pdf_path).path.split("/")[-1]
+    else:
+        raise NotImplementedError()
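
The new branching keeps the existing regex for ai2-s2-pdfs keys (the 4-character shard directory is folded back into the full hash) and, for the iabooks prefix, simply takes the last path component of the key. A minimal standalone sketch of both behaviors; the s2-pdfs URL reconstructs the example from the function's docstring, and the iabooks key is a hypothetical example:

import re
from urllib.parse import urlparse

# s2-pdfs style: shard prefix + remainder reassembled into the full hash.
url = "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf"
m = re.match(r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$", url)
assert m.group(1) + m.group(2) == "de80a57e6c57b45796d2e020173227f7eae44232"

# iabooks style: the object's file name itself is used as the identifier (hypothetical key).
url = "s3://ai2-oe-data/reganh/iabooks/some_scanned_book.pdf"
assert urlparse(url).path.split("/")[-1] == "some_scanned_book.pdf"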


def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
"""
@@ -293,14 +301,17 @@ def main():
    os.makedirs(pdfs_dir, exist_ok=True)

    # Create a temporary directory for caching PDFs.
-    pdf_cache_dir = tempfile.mkdtemp(prefix="pdf_cache_")
+    pdf_cache_dir = "/tmp/pdf_cache"
+    os.makedirs(pdf_cache_dir, exist_ok=True)

    print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")

    # ---------------------------------------------------------------------
    # Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor.
    unique_s3_urls: Set[str] = set()
    print("Scanning input files to collect unique PDF URLs...")
-    with concurrent.futures.ProcessPoolExecutor() as executor:
+    num_cpus = multiprocessing.cpu_count()
+    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 4) as executor:
        results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
    for url_set in results:
        unique_s3_urls |= url_set
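
For the scanning step the commit sizes the pool explicitly at four workers per CPU instead of relying on the executor default. The sketch below reproduces that pattern under the assumption that per-file scanning is mostly I/O-bound; scan_file_for_s3_urls here is a toy stand-in for illustration, not the repository's implementation:

import concurrent.futures
import multiprocessing
import re
from typing import List, Set

from tqdm import tqdm

S3_PDF_RE = re.compile(r"s3://\S+?\.pdf")  # rough pattern, for illustration only

def scan_file_for_s3_urls(path: str) -> Set[str]:
    # Toy stand-in: pull anything that looks like an S3 PDF URL out of one input file.
    with open(path, "r", encoding="utf-8") as f:
        return set(S3_PDF_RE.findall(f.read()))

def collect_unique_urls(files: List[str]) -> Set[str]:
    num_cpus = multiprocessing.cpu_count()
    unique: Set[str] = set()
    # Oversubscribe the pool (4x CPUs) since workers spend much of their time waiting on file reads.
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 4) as executor:
        results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
    for url_set in results:
        unique |= url_set
    return unique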
@@ … @@
    # Step 2: Download all unique PDFs to the cache directory.
    pdf_cache: Dict[str, str] = {}
    print("Caching PDFs from S3...")
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
        future_to_url = {
            executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
            for s3_url in unique_s3_urls
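
The caching step fans out one future per unique URL and harvests results with as_completed. Below is a minimal sketch of that submit/as_completed pattern; download_pdf_to_cache here is a hypothetical stand-in (the repository's helper is defined elsewhere in the file), and the cache-naming scheme is invented for illustration:

import concurrent.futures
import hashlib
import os
from typing import Dict, Set
from urllib.parse import urlparse

import boto3
from tqdm import tqdm

def download_pdf_to_cache(s3_url: str, cache_dir: str) -> str:
    # Hypothetical stand-in: fetch one PDF to a deterministic local path, skipping files already cached.
    parsed = urlparse(s3_url)
    local_path = os.path.join(cache_dir, hashlib.sha1(s3_url.encode()).hexdigest() + ".pdf")
    if not os.path.exists(local_path):
        boto3.client("s3").download_file(parsed.netloc, parsed.path.lstrip("/"), local_path)
    return local_path

def cache_pdfs(unique_s3_urls: Set[str], cache_dir: str, num_cpus: int) -> Dict[str, str]:
    pdf_cache: Dict[str, str] = {}
    # Heavier oversubscription (8x CPUs) because each worker is mostly waiting on the network.
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
        future_to_url = {executor.submit(download_pdf_to_cache, url, cache_dir): url for url in unique_s3_urls}
        for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Caching PDFs"):
            url = future_to_url[future]
            try:
                pdf_cache[url] = future.result()
            except Exception as e:
                print(f"Failed to cache {url}: {e}")
    return pdf_cache

A ThreadPoolExecutor (as before this commit) is usually sufficient for network-bound downloads; switching to processes mainly pays off if CPU-side work such as hashing or validation happens inside the worker.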
@@ -361,12 +372,6 @@ def main():
    else:
        print("No valid rows to write. Exiting.")

-    # Optionally clean up the PDF cache directory.
-    try:
-        shutil.rmtree(pdf_cache_dir)
-        print(f"Cleaned up PDF cache directory: {pdf_cache_dir}")
-    except Exception as e:
-        print(f"Error cleaning up PDF cache directory: {e}")

if __name__ == "__main__":
    main()
