Ok, finally working nicely to build the page index
jakep-allenai committed Oct 10, 2024
1 parent 312ee8d commit 312847a
Showing 1 changed file with 53 additions and 49 deletions.
102 changes: 53 additions & 49 deletions pdelfin/assemblepipeline.py
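The diff below makes two related changes: the index database moves from duckdb to sqlite3, and file processing moves from a ThreadPoolExecutor, where each worker opened its own database connection, to a ProcessPoolExecutor whose workers only download and parse JSONL from S3 and return index entries; all database writes now happen on a single connection in the parent process. A minimal sketch of that single-writer pattern follows (the parse_chunk helper, example bucket, and column types are illustrative assumptions, not taken from the commit):

# Sketch only: worker processes return plain tuples, and the parent process
# owns the one sqlite3 connection and performs every insert.
import sqlite3
from concurrent.futures import ProcessPoolExecutor, as_completed

def parse_chunk(chunk_id):
    # Stand-in for the real per-file work (S3 download + JSONL parsing).
    path = f"s3://example-bucket/chunk_{chunk_id}.jsonl"  # illustrative path
    return [(f"doc-{chunk_id}-{i}", path, i * 100, (i + 1) * 100) for i in range(3)]

def main():
    conn = sqlite3.connect("example_index.db")
    cursor = conn.cursor()
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS index_table (
            custom_id TEXT,
            s3_path TEXT,
            start_index INTEGER,
            end_index INTEGER
        )
    """)  # column names match the commit; the types here are assumed
    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(parse_chunk, c) for c in range(10)]
        for future in as_completed(futures):
            cursor.executemany(
                "INSERT INTO index_table (custom_id, s3_path, start_index, end_index) "
                "VALUES (?, ?, ?, ?)",
                future.result(),
            )
            conn.commit()
    conn.close()

if __name__ == "__main__":
    main()

Keeping every write on one connection in one process sidesteps sharing a database handle across processes and matches sqlite's one-writer-at-a-time model.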
@@ -1,11 +1,11 @@
 import os
 import hashlib
 import boto3
-import duckdb
+import sqlite3
 import json
 import argparse
 from tqdm import tqdm
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, as_completed

 def build_index(s3_path):
     # Hash the s3_path to get a cache key
@@ -14,9 +14,9 @@ def build_index(s3_path):
     os.makedirs(cache_dir, exist_ok=True)
     db_path = os.path.join(cache_dir, 'index.db')

-    # Connect to duckdb and create tables if not exist
+    # Connect to sqlite and create tables if not exist
     print("Building page index at", db_path)
-    conn = duckdb.connect(database=db_path)
+    conn = sqlite3.connect(db_path)
     cursor = conn.cursor()
     cursor.execute("""
         CREATE TABLE IF NOT EXISTS index_table (
@@ -33,7 +33,6 @@ def build_index(s3_path):
         )
     """)
     conn.commit()
-    conn.close()

     s3 = boto3.client('s3')
     bucket, prefix = parse_s3_path(s3_path)
@@ -45,11 +44,40 @@ def build_index(s3_path):
         print("No .json or .jsonl files found in the specified S3 path.")
         return

-    # Use ThreadPoolExecutor to process files with tqdm progress bar
-    with ThreadPoolExecutor() as executor:
-        futures = [executor.submit(process_file, s3, bucket, key, etag, db_path) for key, etag in files.items()]
-        for _ in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
-            pass
+    # Prepare a list of files that need processing
+    files_to_process = []
+    for key, etag in files.items():
+        cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (key,))
+        db_result = cursor.fetchone()
+        if db_result and db_result[0] == etag:
+            # File has already been processed with the same ETag
+            pass # Skip
+        else:
+            files_to_process.append((key, etag))
+
+    if not files_to_process:
+        print("All files are up to date. No processing needed.")
+        return
+
+    # Use ProcessPoolExecutor to process files with tqdm progress bar
+    with ProcessPoolExecutor() as executor:
+        futures = [executor.submit(process_file, bucket, key, etag) for key, etag in files_to_process]
+        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
+            s3_path, key, etag, index_entries = future.result()
+            if index_entries:
+                cursor.executemany("""
+                    INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
+                    VALUES (?, ?, ?, ?)
+                """, index_entries)
+            # Update the processed_files table
+            cursor.execute("""
+                INSERT INTO processed_files (s3_path, etag)
+                VALUES (?, ?)
+                ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
+            """, (key, etag))
+            conn.commit()
+
+    conn.close()

 def parse_s3_path(s3_path):
     if not s3_path.startswith('s3://'):
@@ -71,47 +99,25 @@ def list_s3_files(s3, bucket, prefix):
                 files[key] = obj['ETag'].strip('"')
     return files

-def process_file(s3, bucket, key, etag, db_path):
+def process_file(bucket, key, etag):
+    s3 = boto3.client('s3') # Initialize s3 client in the worker process
     s3_path = f's3://{bucket}/{key}'
     try:
-        # Connect to duckdb
-        conn = duckdb.connect(database=db_path)
-        cursor = conn.cursor()
-
-        # Check if file has already been processed with the same ETag
-        cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (key,))
-        result = cursor.fetchone()
-
-        if result and result[0] == etag:
-            # File has already been processed with the same ETag
-            # Optionally, log that the file was skipped
-            # print(f"Skipping already processed file: {s3_path}")
-            conn.close()
-            return
-        else:
-            # Get the object
-            obj = s3.get_object(Bucket=bucket, Key=key)
-
-            # Read the content as bytes
-            content = obj['Body'].read()
-
-            # Process the file as JSONL
-            process_jsonl_content(content, s3_path, cursor)
-
-            # Update the processed_files table
-            cursor.execute("""
-                INSERT INTO processed_files (s3_path, etag)
-                VALUES (?, ?)
-                ON CONFLICT (s3_path) DO UPDATE SET etag=excluded.etag
-            """, (key, etag))
-
-            conn.commit()
-            conn.close()
+        # Get the object
+        obj = s3.get_object(Bucket=bucket, Key=key)
+        # Read the content as bytes
+        content = obj['Body'].read()
+        # Process the file as JSONL
+        index_entries = process_jsonl_content(content, s3_path)
+        # Return the necessary data to the main process
+        return s3_path, key, etag, index_entries
     except Exception as e:
         print(f"Error processing file {s3_path}: {e}")
+        return s3_path, key, etag, []

-def process_jsonl_content(content, s3_path, cursor):
+def process_jsonl_content(content, s3_path):
     start_index = 0
+    index_entries = []
     lines = content.splitlines(keepends=True)
     for line in lines:
         line_length = len(line)
@@ -120,13 +126,11 @@ def process_jsonl_content(content, s3_path, cursor):
             data = json.loads(line)
             custom_id = data.get('custom_id')
             if custom_id:
-                cursor.execute("""
-                    INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
-                    VALUES (?, ?, ?, ?)
-                """, (custom_id, s3_path, start_index, end_index))
+                index_entries.append((custom_id, s3_path, start_index, end_index))
         except json.JSONDecodeError:
             pass # Handle JSON decode errors if necessary
         start_index = end_index
+    return index_entries

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Build a local index of JSON files from S3.')
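
For context on how an index like this is typically consumed (an illustrative sketch, not part of the commit): look up a custom_id in the sqlite database, then fetch just that byte range from S3 with a ranged GET instead of downloading the whole JSONL file. The fetch_record helper below is hypothetical; only the index_table columns and the start/end byte offsets come from the code above.

import json
import sqlite3
import boto3

def fetch_record(db_path, custom_id):
    # Hypothetical helper: use the byte offsets stored in index_table to pull
    # a single JSON record out of a large JSONL file on S3.
    conn = sqlite3.connect(db_path)
    row = conn.execute(
        "SELECT s3_path, start_index, end_index FROM index_table WHERE custom_id = ?",
        (custom_id,),
    ).fetchone()
    conn.close()
    if row is None:
        return None
    s3_path, start, end = row
    bucket, key = s3_path[len("s3://"):].split("/", 1)
    s3 = boto3.client("s3")
    # HTTP byte ranges are inclusive, so the exclusive end offset needs a -1.
    obj = s3.get_object(Bucket=bucket, Key=key, Range=f"bytes={start}-{end - 1}")
    return json.loads(obj["Body"].read())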