dbmanager
jakep-allenai committed Oct 11, 2024
1 parent 2dccc4b commit f477a68
Showing 1 changed file with 63 additions and 41 deletions.
pdelfin/assemblepipeline.py
@@ -7,32 +7,64 @@
 from tqdm import tqdm
 from concurrent.futures import ProcessPoolExecutor, as_completed
 
+class DatabaseManager:
+    def __init__(self, db_path):
+        self.db_path = db_path
+        self.conn = sqlite3.connect(self.db_path)
+        self.cursor = self.conn.cursor()
+        self._initialize_tables()
+
+    def _initialize_tables(self):
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS index_table (
+                custom_id TEXT,
+                s3_path TEXT,
+                start_index BIGINT,
+                end_index BIGINT
+            )
+        """)
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS processed_files (
+                s3_path TEXT PRIMARY KEY,
+                etag TEXT
+            )
+        """)
+        self.conn.commit()
+
+    def is_file_processed(self, s3_path, etag):
+        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
+        result = self.cursor.fetchone()
+        return result is not None and result[0] == etag
+
+    def add_index_entries(self, index_entries):
+        if index_entries:
+            self.cursor.executemany("""
+                INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
+                VALUES (?, ?, ?, ?)
+            """, index_entries)
+            self.conn.commit()
+
+    def update_processed_file(self, s3_path, etag):
+        self.cursor.execute("""
+            INSERT INTO processed_files (s3_path, etag)
+            VALUES (?, ?)
+            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
+        """, (s3_path, etag))
+        self.conn.commit()
+
+    def close(self):
+        self.conn.close()
+
 def build_index(s3_path):
     # Hash the s3_path to get a cache key
     cache_key = hashlib.sha256(s3_path.encode('utf-8')).hexdigest()
     home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
     os.makedirs(home_cache_dir, exist_ok=True)
     db_path = os.path.join(home_cache_dir, 'index.db')
 
-    # Connect to sqlite and create tables if not exist
+    # Initialize the database manager
     print("Building page index at", db_path)
-    conn = sqlite3.connect(db_path)
-    cursor = conn.cursor()
-    cursor.execute("""
-        CREATE TABLE IF NOT EXISTS index_table (
-            custom_id TEXT,
-            s3_path TEXT,
-            start_index BIGINT,
-            end_index BIGINT
-        )
-    """)
-    cursor.execute("""
-        CREATE TABLE IF NOT EXISTS processed_files (
-            s3_path TEXT PRIMARY KEY,
-            etag TEXT
-        )
-    """)
-    conn.commit()
+    db_manager = DatabaseManager(db_path)
 
     s3 = boto3.client('s3')
     bucket, prefix = parse_s3_path(s3_path)
@@ -42,42 +74,34 @@ def build_index(s3_path):
 
     if not files:
         print("No .json or .jsonl files found in the specified S3 path.")
+        db_manager.close()
         return
 
     # Prepare a list of files that need processing
-    files_to_process = []
-    for key, etag in files.items():
-        cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (key,))
-        db_result = cursor.fetchone()
-        if db_result and db_result[0] == etag:
-            # File has already been processed with the same ETag
-            pass # Skip
-        else:
-            files_to_process.append((key, etag))
+    files_to_process = [
+        (key, etag) for key, etag in files.items()
+        if not db_manager.is_file_processed(key, etag)
+    ]
 
     if not files_to_process:
         print("All files are up to date. No processing needed.")
+        db_manager.close()
         return
 
     # Use ProcessPoolExecutor to process files with tqdm progress bar
     with ProcessPoolExecutor() as executor:
-        futures = [executor.submit(process_file, bucket, key, etag) for key, etag in files_to_process]
+        futures = [
+            executor.submit(process_file, bucket, key, etag)
+            for key, etag in files_to_process
+        ]
         for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
            s3_path, key, etag, index_entries = future.result()
            if index_entries:
-                cursor.executemany("""
-                    INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
-                    VALUES (?, ?, ?, ?)
-                """, index_entries)
+                db_manager.add_index_entries(index_entries)
            # Update the processed_files table
-            cursor.execute("""
-                INSERT INTO processed_files (s3_path, etag)
-                VALUES (?, ?)
-                ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
-            """, (key, etag))
-            conn.commit()
+            db_manager.update_processed_file(key, etag)
 
-    conn.close()
+    db_manager.close()
 
 def parse_s3_path(s3_path):
     if not s3_path.startswith('s3://'):
@@ -139,5 +163,3 @@ def process_jsonl_content(content, s3_path):
 
 # Step one, build an index of all the pages that were processed
 build_index(args.s3_path)
-
-
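
As a quick illustration of the refactor, below is a minimal standalone sketch of how the extracted DatabaseManager behaves. It assumes the class can be imported from pdelfin.assemblepipeline (or copied out of it); the temp directory, S3 paths, and ETag values are hypothetical stand-ins for illustration, not part of this commit.

import os
import tempfile

from pdelfin.assemblepipeline import DatabaseManager  # assumed importable

# Hypothetical paths and ETags, for illustration only.
db_path = os.path.join(tempfile.mkdtemp(), "index.db")
db = DatabaseManager(db_path)

# A file/ETag pair the database has never seen is reported as unprocessed.
assert not db.is_file_processed("s3://bucket/data.jsonl", "etag-1")

# Record (custom_id, s3_path, start_index, end_index) rows for the file,
# then mark it processed under its current ETag.
db.add_index_entries([("doc-1", "s3://bucket/data.jsonl", 0, 512)])
db.update_processed_file("s3://bucket/data.jsonl", "etag-1")

# The same ETag is now skipped; a changed ETag triggers reprocessing.
assert db.is_file_processed("s3://bucket/data.jsonl", "etag-1")
assert not db.is_file_processed("s3://bucket/data.jsonl", "etag-2")

db.close()

This mirrors how build_index consumes the class after this commit: is_file_processed filters the worklist, add_index_entries and update_processed_file run as each future completes, and close releases the connection.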