
Commit

run pipeline
jakep-allenai committed Oct 9, 2024
1 parent 954b19a commit c2909f3
Showing 1 changed file with 86 additions and 36 deletions.
122 changes: 86 additions & 36 deletions pdelfin/data/runpipeline.py
@@ -17,7 +17,6 @@
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.filter import PdfFilter


pdf_filter = PdfFilter()

def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
@@ -43,7 +42,7 @@ def fetch_s3_file(s3_url: str, local_path: str) -> str:
parsed = urlparse(s3_url)
bucket_name = parsed.netloc
key = parsed.path.lstrip('/')

s3 = boto3.client('s3')
s3.download_file(bucket_name, key, local_path)
return local_path
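
As a side note, the bucket/key split in fetch_s3_file relies on urllib.parse.urlparse; a minimal standalone sketch of that split (the URL below is a made-up example, not from the repo):

from urllib.parse import urlparse

parsed = urlparse("s3://my-bucket/reports/2024/doc.pdf")  # hypothetical example URL
bucket_name = parsed.netloc      # 'my-bucket'
key = parsed.path.lstrip('/')    # 'reports/2024/doc.pdf'
print(bucket_name, key)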
@@ -58,12 +57,12 @@ def process_pdf(pdf_path: str, no_filter: bool) -> Generator[dict, None, None]:
if (not no_filter) and pdf_filter.filter_out_pdf(local_pdf_path):
print(f"Skipping {local_pdf_path} due to common filter")
return []

pretty_pdf_path = pdf_path

pdf = PdfReader(local_pdf_path)
num_pages = len(pdf.pages)

sample_pages = list(range(1, num_pages + 1))
result = []
for page in sample_pages:
@@ -75,44 +74,92 @@ def process_pdf(pdf_path: str, no_filter: bool) -> Generator[dict, None, None]:

return result

+def is_glob_pattern(path: str) -> bool:
+    return any(char in path for char in ['*', '?', '[', ']'])
+
+def expand_s3_glob(s3_glob: str) -> list:
+    parsed = urlparse(s3_glob)
+    bucket_name = parsed.netloc
+    prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
+    pattern = os.path.basename(parsed.path)
+
+    s3 = boto3.client('s3')
+    paginator = s3.get_paginator('list_objects_v2')
+    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
+
+    matched_files = []
+    for page in page_iterator:
+        for obj in page.get('Contents', []):
+            key = obj['Key']
+            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, prefix + pattern):
+                matched_files.append(f"s3://{bucket_name}/{key}")
+
+    return matched_files
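
For reference, the key matching above happens client-side with fnmatch over the listed keys (glob.fnmatch is just the fnmatch module); a small standalone sketch with made-up keys, which also shows that fnmatch's '*' crosses '/' boundaries:

import fnmatch

prefix = "pdfs/"                                            # made-up prefix
pattern = "*.pdf"                                           # made-up pattern
keys = ["pdfs/a.pdf", "pdfs/notes.txt", "pdfs/nested/b.pdf"]  # made-up keys

matched = [k for k in keys if k.endswith('.pdf') and fnmatch.fnmatch(k, prefix + pattern)]
print(matched)  # ['pdfs/a.pdf', 'pdfs/nested/b.pdf']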

def main():
-    parser = argparse.ArgumentParser(description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism")
-    parser.add_argument("--glob_path", type=str, help="Local or S3 path glob (e.g., *.pdf or s3://bucket/pdfs/*.pdf).")
-    parser.add_argument("--path_list", type=str, help="Path to a file containing paths to PDFs, one per line.")
-    parser.add_argument("--max_size_mb", type=int, default=250, help="Max number of mb's of entries to put in each birr workitem")
-    parser.add_argument("--no_filter", action="store_true", help="Disables the basic spam/language filtering so that ALL pdfs listed are used")
-    parser.add_argument("--output", type=str, default="mise_batch_data", help="Output destination")
+    parser = argparse.ArgumentParser(
+        description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
+    )
+    parser.add_argument(
+        "pdf_paths",
+        nargs='*',
+        help=(
+            "List of PDF paths to process. If a single argument contains glob patterns (e.g., *.pdf or s3://bucket/pdfs/*.pdf), "
+            "it will be expanded accordingly."
+        )
+    )
+    parser.add_argument(
+        "--path_list",
+        type=str,
+        help="Path to a file containing paths to PDFs, one per line."
+    )
+    parser.add_argument(
+        "--max_size_mb",
+        type=int,
+        default=250,
+        help="Max number of MBs of entries to put in each birr workitem"
+    )
+    parser.add_argument(
+        "--no_filter",
+        action="store_true",
+        help="Disables the basic spam/language filtering so that ALL pdfs listed are used"
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="mise_batch_data",
+        help="Output destination"
+    )
args = parser.parse_args()

pdf_paths = []

-    # Load PDF paths from glob or path_list using reservoir sampling
-    if args.glob_path:
-        if args.glob_path.startswith("s3://"):
-            # Handle S3 globbing using boto3 with pagination
-            parsed = urlparse(args.glob_path)
-            s3 = boto3.client('s3')
-            bucket_name = parsed.netloc
-            prefix = os.path.dirname(parsed.path.lstrip('/')) + "/"
-            paginator = s3.get_paginator('list_objects_v2')
-            page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
-
-            for page in page_iterator:
-                for obj in page.get('Contents', []):
-                    if obj['Key'].endswith('.pdf'):
-                        pdf_paths.append(f"s3://{bucket_name}/{obj['Key']}")
-
+    # Load PDF paths from positional arguments or path_list
+    if args.pdf_paths:
+        if len(args.pdf_paths) == 1 and is_glob_pattern(args.pdf_paths[0]):
+            glob_path = args.pdf_paths[0]
+            if glob_path.startswith("s3://"):
+                # Handle S3 globbing
+                expanded_paths = expand_s3_glob(glob_path)
+                pdf_paths.extend(expanded_paths)
+            else:
+                # Handle local filesystem globbing
+                expanded_paths = glob.glob(glob_path, recursive=True)
+                pdf_paths.extend(expanded_paths)
        else:
-            # Handle local globbing using glob.iglob()
-            for path in glob.iglob(args.glob_path, recursive=True):
-                pdf_paths.append(path)
-    elif args.path_list:
+            # Treat positional arguments as list of PDF paths
+            pdf_paths.extend(args.pdf_paths)
+
+    if args.path_list:
        with open(args.path_list, 'r') as f:
            for line in f:
-                n += 1
                path = line.strip()
-                pdf_paths.append(path)
+                if path:
+                    pdf_paths.append(path)

# Remove duplicates and shuffle
pdf_paths = list(set(pdf_paths))
random.shuffle(pdf_paths)

print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")

@@ -132,7 +179,7 @@ def main():
# Counter to track PDFs that produce at least one output
pdfs_with_output = 0

-    # Using ThreadPoolExecutor to process files concurrently
+    # Using ProcessPoolExecutor to process files concurrently
with ProcessPoolExecutor() as executor:
futures = []

@@ -142,7 +189,10 @@ def main():

for future in as_completed(futures):
try:
-                request_results = future.result() # Get the result from the thread
+                request_results = future.result() # Get the result from the process
+
+                if request_results:
+                    pdfs_with_output += 1 # Increment if there's at least one result

for request_obj in request_results:
request_json = json.dumps(request_obj)
@@ -165,12 +215,12 @@ def main():
pb.update(1)

except Exception as e:
print(f"Error processing {pdf_path}: {str(e)}")
print(f"Error processing a PDF: {str(e)}")

# Close the last open file
cur_file.close()

-    # Print or log the number of PDFs that resulted in at least one output
+    # Print the number of PDFs that resulted in at least one output
print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")

if __name__ == "__main__":
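
For reference, the ProcessPoolExecutor/as_completed pattern that main() now uses boils down to the following standalone sketch; process_pdf here is a stub, not the pipeline's real function:

from concurrent.futures import ProcessPoolExecutor, as_completed

def process_pdf(path):
    return [{"path": path}]  # stand-in for the real per-page query objects

if __name__ == "__main__":
    pdf_paths = ["a.pdf", "b.pdf", "c.pdf"]  # made-up paths
    pdfs_with_output = 0
    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(process_pdf, p) for p in pdf_paths]
        for future in as_completed(futures):
            results = future.result()
            if results:
                pdfs_with_output += 1
    print(f"PDFs with at least one output: {pdfs_with_output}")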
