Buildsilver script supports reservoir sampling so it can now sample 100M+ paths efficiently
jakep-allenai committed Sep 30, 2024
1 parent 8ec9e35 commit b4e9d6a
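The diff replaces eager path collection ("list everything, then sample") with streaming reservoir sampling (Algorithm R), so only a bounded reservoir of paths is held in memory regardless of how many objects the glob or path list yields. Below is a minimal standalone sketch of the update rule the diff applies to each path source; the helper name and the stream/k parameters are illustrative, not part of the commit:

import random

def reservoir_sample(stream, k):
    """Return a uniform random sample of up to k items from an iterable of unknown length."""
    reservoir = []
    for n, item in enumerate(stream, start=1):
        if len(reservoir) < k:
            reservoir.append(item)       # fill the reservoir with the first k items
        else:
            s = random.randint(1, n)     # uniform slot in [1, n], inclusive
            if s <= k:
                reservoir[s - 1] = item  # keep the new item with probability k / n
    return reservoir

Each final slot ends up holding any given stream item with probability k/n, which is why the script can default the reservoir to 10x num_sample_docs and still sample fairly from 100M+ paths.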
Showing 1 changed file with 49 additions and 14 deletions.
63 changes: 49 additions & 14 deletions pdelfin/silver_data/buildsilver.py
@@ -6,7 +6,6 @@
import argparse
import boto3
import json
from openai import OpenAI
from pypdf import PdfReader
from tqdm import tqdm
from typing import Generator
@@ -31,8 +30,6 @@ def _build_prompt(base_text: str) -> str:
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)

# Initialize OpenAI client
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
pdf_filter = PdfFilter()

def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
@@ -145,30 +142,68 @@ def main():
parser.add_argument("--first_n_pages", type=int, default=0, help="Always sample the first N pages of each PDF.")
parser.add_argument("--max_sample_pages", type=int, default=15, help="Max number of pages to sample per PDF.")
parser.add_argument("--output", type=str, default="openai_batch_data", help="Output destination")
parser.add_argument("--reservoir_size", type=int, default=None,
help="Size of the reservoir for sampling paths. Defaults to 10x num_sample_docs.")
args = parser.parse_args()

# Load PDF paths from glob or path_list
# Set default reservoir_size if not provided
if args.reservoir_size is None:
args.reservoir_size = 10 * args.num_sample_docs

# Initialize reservoir sampling variables
pdf_paths = []
n = 0 # Total number of items seen

# Load PDF paths from glob or path_list using reservoir sampling
if args.glob_path:
if args.glob_path.startswith("s3://"):
# Handle S3 globbing using boto3
# Handle S3 globbing using boto3 with pagination
parsed = urlparse(args.glob_path)
s3 = boto3.client('s3')
bucket_name = parsed.netloc
prefix = os.path.dirname(parsed.path.lstrip('/')) + "/"
response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
for obj in response.get('Contents', []):
if obj['Key'].endswith('.pdf'):
pdf_paths.append(f"s3://{bucket_name}/{obj['Key']}")
paginator = s3.get_paginator('list_objects_v2')
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

for page in page_iterator:
for obj in page.get('Contents', []):
if obj['Key'].endswith('.pdf'):
n += 1
path = f"s3://{bucket_name}/{obj['Key']}"
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
else:
# Handle local globbing
pdf_paths = glob.glob(args.glob_path)
# Handle local globbing using glob.iglob()
for path in glob.iglob(args.glob_path, recursive=True):
n += 1
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
elif args.path_list:
with open(args.path_list, 'r') as f:
pdf_paths = [line.strip() for line in f]

for line in f:
n += 1
path = line.strip()
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path

# Shuffle the reservoir
random.shuffle(pdf_paths)

print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")

# Rest of the code remains the same
cur_file_num = 0
output_dir = args.output
max_file_size = 99 * 1024 * 1024 # 99MB in bytes
@@ -184,7 +219,7 @@ def main():
# Counter to track PDFs that produce at least one output
pdfs_with_output = 0

    # Using ThreadPoolExecutor to process files concurrently
with ThreadPoolExecutor(max_workers=60) as executor:
futures = []

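For the S3 case, the old single list_objects_v2 call returned at most one page of keys; the new code walks every page through a boto3 paginator and feeds each key straight into the reservoir. A minimal sketch of that listing pattern follows, with an illustrative bucket/prefix and function name (only the paginator calls mirror the diff):

import boto3

def iter_s3_pdf_paths(bucket: str, prefix: str):
    # Lazily yield s3:// paths for every .pdf key under the prefix, one API page
    # at a time, so the full listing never has to fit in memory.
    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            if obj["Key"].endswith(".pdf"):
                yield f"s3://{bucket}/{obj['Key']}"

Combined with the reservoir update sketched above, memory use stays proportional to reservoir_size rather than to the number of keys listed.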
