
Commit: runpipeline
jakep-allenai committed Oct 9, 2024
1 parent a90feda commit 8e5809d
Showing 1 changed file with 65 additions and 32 deletions.
pdelfin/data/runpipeline.py (97 changes: 65 additions & 32 deletions)
@@ -6,9 +6,10 @@
import argparse
import boto3
import json
+import hashlib
from pypdf import PdfReader
from tqdm import tqdm
-from typing import Generator
+from typing import Generator, List
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from urllib.parse import urlparse

@@ -18,6 +19,8 @@
from pdelfin.filter import PdfFilter

import logging
+import smart_open
+import posixpath  # Import posixpath for S3 path handling

logging.getLogger("pypdf").setLevel(logging.ERROR)

@@ -51,7 +54,7 @@ def fetch_s3_file(s3_url: str, local_path: str) -> str:
    s3.download_file(bucket_name, key, local_path)
    return local_path

-def process_pdf(pdf_path: str, no_filter: bool) -> Generator[dict, None, None]:
+def process_pdf(pdf_path: str, no_filter: bool) -> List[dict]:
    if pdf_path.startswith("s3://"):
        local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
        fetch_s3_file(pdf_path, local_pdf_path)
@@ -95,11 +98,34 @@ def expand_s3_glob(s3_glob: str) -> list:
    for page in page_iterator:
        for obj in page.get('Contents', []):
            key = obj['Key']
-            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, prefix + pattern):
+            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                matched_files.append(f"s3://{bucket_name}/{key}")

    return matched_files

+def compute_hash(content: str) -> str:
+    """Compute a 20-character SHA1 hash of the given content."""
+    sha1 = hashlib.sha1()
+    sha1.update(content.encode('utf-8'))
+    return sha1.hexdigest()[:20]
+
+def get_smart_open_write_path(output_path: str, hash_str: str) -> str:
+    """Generate the full output path with hash in the filename."""
+    parsed = urlparse(output_path)
+    if parsed.scheme in ('s3', 's3a', 's3n'):
+        bucket = parsed.netloc
+        key = parsed.path.lstrip('/')
+        # Ensure the key is treated as a directory by appending a slash if not present
+        if key and not key.endswith('/'):
+            key += '/'
+        # Use posixpath to correctly join S3 paths
+        full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
+        return f"s3://{bucket}/{full_key}"
+    else:
+        dir_path = output_path
+        filename = f"output_{hash_str}.jsonl"
+        return os.path.join(dir_path, filename)
+
def main():
    parser = argparse.ArgumentParser(
        description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
@@ -132,7 +158,7 @@ def main():
        "--output",
        type=str,
        default="mise_batch_data",
-        help="Output destination"
+        help="Output destination (can be a local path or an S3 URI)"
    )
    args = parser.parse_args()

@@ -167,22 +193,31 @@ def main():

    print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")

-    # Rest of the code remains the same
-    cur_file_num = 0
+    # Prepare for output
    output_dir = args.output
-    max_file_size = args.max_size_mb * 1024 * 1024
-    cur_file_size = 0
-    cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
-
-    # Ensure output directory exists
-    os.makedirs(output_dir, exist_ok=True)
+    max_file_size = args.max_size_mb * 1024 * 1024  # Convert MB to bytes

-    # Open the first file for writing
-    cur_file = open(cur_file_path, 'w')
+    # Determine if output is S3
+    parsed_output = urlparse(output_dir)
+    is_s3 = parsed_output.scheme in ('s3', 's3a', 's3n')

-    # Counter to track PDFs that produce at least one output
+    # Initialize variables for batching
+    batch = []
+    batch_size = 0
    pdfs_with_output = 0

+    # Function to write a batch
+    def write_batch(batch: List[dict]):
+        nonlocal output_dir
+        if not batch:
+            return
+        batch_content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
+        hash_str = compute_hash(batch_content)
+        output_path_with_hash = get_smart_open_write_path(output_dir, hash_str)
+        with smart_open.open(output_path_with_hash, 'w') as f_out:
+            f_out.write(batch_content)
+        print(f"Wrote batch to {output_path_with_hash}")
+
    # Using ProcessPoolExecutor to process files concurrently
    with ProcessPoolExecutor() as executor:
        futures = []
@@ -200,28 +235,26 @@ def main():

                    for request_obj in request_results:
                        request_json = json.dumps(request_obj)
-                        request_size = len(request_json.encode('utf-8'))  # Calculate size in bytes
-
-                        # Check if the current request can fit in the current file
-                        if cur_file_size + request_size > max_file_size:
-                            # Close the current file and create a new one
-                            cur_file.close()
-                            cur_file_num += 1
-                            cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
-                            cur_file = open(cur_file_path, 'w')
-                            cur_file_size = 0  # Reset file size
-
-                        # Write the JSON entry to the file
-                        cur_file.write(request_json)
-                        cur_file.write("\n")
-                        cur_file_size += request_size
+                        request_size = len(request_json.encode('utf-8')) + 1  # +1 for newline
+
+                        # Check if adding this entry would exceed the max size
+                        if batch_size + request_size > max_file_size:
+                            # Write the current batch
+                            write_batch(batch)
+                            # Reset the batch
+                            batch = []
+                            batch_size = 0
+
+                        # Add the entry to the batch
+                        batch.append(request_obj)
+                        batch_size += request_size

                    pb.update(1)
                except Exception as e:
                    print(f"Error processing a PDF: {str(e)}")

-    # Close the last open file
-    cur_file.close()
+    # Write any remaining batch
+    write_batch(batch)

    # Print the number of PDFs that resulted in at least one output
    print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
