First stab at document assembly
jakep-allenai committed Oct 9, 2024
1 parent 847064f commit c6bdf69
Showing 1 changed file with 105 additions and 53 deletions.
158 changes: 105 additions & 53 deletions pdelfin/assemblepipeline.py
@@ -88,10 +88,93 @@ def parse_s3_path(s3_path):
bucket_name, _, key = s3_path.partition('/')
return bucket_name, key

def process_document(s3_path, entries, output_dir):
"""
Processes a single document:
- Downloads the PDF
- Validates and assembles text
- Writes the output JSON if successful
- Returns processing results for aggregation
"""
try:
# Download the PDF locally
pdf_local_path = cached_path(s3_path, quiet=True)
pdf = PdfReader(pdf_local_path)
total_pages_in_pdf = len(pdf.pages)
except Exception as e:
logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
return {
'processed': 1,
'successful_documents': 0,
'successful_pages': 0,
'total_pages': 0
}

# Build mapping from pagenum to entry
entry_by_pagenum = {entry.pagenum: entry for entry in entries}

valid_entries = []
missing_pages = []
errors = []

# Iterate from 1 to total_pages_in_pdf inclusive
for page_num in range(1, total_pages_in_pdf + 1):
entry = entry_by_pagenum.get(page_num)
if entry is None:
missing_pages.append(page_num)
elif entry.error is not None or entry.finish_reason != 'stop':
errors.append(entry)
else:
valid_entries.append(entry)

if not missing_pages and not errors:
# Assemble text
valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
text = '\n'.join(entry.text for entry in valid_entries_sorted if entry.text)

# Generate a filename based on the s3_path
doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
output_filename = os.path.join(output_dir, f'{doc_hash}.json')

output_data = {
'source': s3_path,
'total_pages': total_pages_in_pdf,
'text': text
}

try:
with open(output_filename, 'w') as f_out:
json.dump(output_data, f_out)
return {
'processed': 1,
'successful_documents': 1,
'successful_pages': len(valid_entries),
'total_pages': total_pages_in_pdf
}
except Exception as e:
logging.error(f"Error writing output file {output_filename}: {e}")
return {
'processed': 1,
'successful_documents': 0,
'successful_pages': 0,
'total_pages': total_pages_in_pdf
}
else:
        error_pages = [e.pagenum for e in errors]
        logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {error_pages}')
return {
'processed': 1,
'successful_documents': 0,
'successful_pages': len(valid_entries),
'total_pages': total_pages_in_pdf
}

def main():
parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of worker threads')
args = parser.parse_args()

# Set up logging
@@ -135,59 +218,28 @@ def main():
total_pages = 0
successful_pages = 0

print("Processing documents...")
for s3_path, entries in tqdm(documents.items()):
try:
# Download the PDF locally
pdf_local_path = cached_path(s3_path, quiet=True)

pdf = PdfReader(pdf_local_path)
total_pages_in_pdf = len(pdf.pages)
except Exception as e:
logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
continue

total_pages += total_pages_in_pdf

# Build mapping from pagenum to entry
entry_by_pagenum = {entry.pagenum: entry for entry in entries}

valid_entries = []
missing_pages = []
errors = []

for page_num in range(total_pages_in_pdf):
entry = entry_by_pagenum.get(page_num)
if entry is None:
missing_pages.append(page_num)
elif entry.error is not None or entry.finish_reason != 'stop':
errors.append(entry)
else:
valid_entries.append(entry)

successful_pages += len(valid_entries)

if not missing_pages and not errors:
# Assemble text
valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
text = '\n'.join(entry.text for entry in valid_entries_sorted)

# Generate a filename based on the s3_path
doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')

output_data = {
'source': s3_path,
'total_pages': total_pages_in_pdf,
'text': text
}

with open(output_filename, 'w') as f_out:
json.dump(output_data, f_out)

successful_documents += 1
else:
logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')
print("Processing documents with ThreadPoolExecutor...")
with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
# Prepare futures
future_to_s3 = {
executor.submit(
process_document,
s3_path,
entries,
args.output_dir
): s3_path for s3_path, entries in documents.items()
}

# Use tqdm to display progress
for future in tqdm(as_completed(future_to_s3), total=len(future_to_s3)):
try:
result = future.result()
successful_documents += result.get('successful_documents', 0)
successful_pages += result.get('successful_pages', 0)
total_pages += result.get('total_pages', 0)
except Exception as e:
s3_path = future_to_s3[future]
logging.error(f"Error processing document {s3_path}: {e}")

print(f'Total documents: {total_documents}')
print(f'Successful documents: {successful_documents}')
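A minimal sketch of how the new process_document helper could be exercised on its own, outside the executor, e.g. as a quick smoke test. The S3 path, the fake entries, and the output directory are placeholders (only the attribute names pagenum, text, error, and finish_reason mirror what the code above reads), and the command-line invocation in the comment assumes the module calls main() when executed directly.

# Hypothetical smoke test; all paths and values below are placeholders.
# The full pipeline would normally be run as something like:
#   python pdelfin/assemblepipeline.py s3://bucket/birr-output/ --output_dir output --max_workers 8
import os
from types import SimpleNamespace

from pdelfin.assemblepipeline import process_document

fake_entries = [
    SimpleNamespace(pagenum=1, text="Page one text", error=None, finish_reason="stop"),
    SimpleNamespace(pagenum=2, text="Page two text", error=None, finish_reason="stop"),
]

os.makedirs("output", exist_ok=True)
result = process_document(
    "s3://example-bucket/docs/sample.pdf",  # placeholder; must resolve to a real two-page PDF for a successful run
    fake_entries,
    "output",
)
print(result)  # counters: processed, successful_documents, successful_pages, total_pages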
