Commit

gpt cleanup
jakep-allenai committed Oct 11, 2024
1 parent a45f86e commit 4fd6066
Showing 1 changed file with 39 additions and 36 deletions.
pdelfin/assemblepipeline.py: 75 changes (39 additions & 36 deletions)
@@ -11,7 +11,7 @@
from dataclasses import dataclass
from pypdf import PdfReader
from tqdm import tqdm
- from typing import Optional
+ from typing import Optional, List, Tuple, Dict
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed

@@ -36,12 +36,12 @@ def _initialize_tables(self):
page_num INTEGER,
start_index BIGINT,
length BIGINT,
- finish_reason STRING
- error STRING
+ finish_reason TEXT,
+ error TEXT
)
""")
self.cursor.execute("""
- CREATE INDEX IF NOT EXISTS idx_path ON index_table(s3_path)
+ CREATE INDEX IF NOT EXISTS idx_path ON page_results(s3_path)
""")
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS pdfs (
@@ -66,26 +66,26 @@ def _initialize_tables(self):

self.conn.commit()

- def get_metadata(self, key: str) -> str:
+ def get_metadata(self, key: str) -> Optional[str]:
self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
result = self.cursor.fetchone()
- return result[0]
+ return result[0] if result else None

def get_current_round(self):
- return int(self.get_metadata("round"))
+ round_value = self.get_metadata("round")
+ return int(round_value) if round_value else 0

def is_file_processed(self, s3_path, etag):
self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
result = self.cursor.fetchone()
return result is not None and result[0] == etag

- def add_index_entries(self, index_entries):
- # TODO MAke it take batchInferenceLines
+ def add_index_entries(self, index_entries: List['BatchInferenceLine']):
if index_entries:
self.cursor.executemany("""
- INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
- VALUES (?, ?, ?, ?)
- """, index_entries)
+ INSERT INTO page_results (s3_path, page_num, start_index, length, finish_reason, error)
+ VALUES (?, ?, ?, ?, ?, ?)
+ """, [(entry.s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
self.conn.commit()

def update_processed_file(self, s3_path, etag):
@@ -125,8 +125,7 @@ def parse_s3_path(s3_path):
bucket, _, prefix = path.partition('/')
return bucket, prefix


- def expand_s3_glob(s3_glob: str) -> dict[str, str]:
+ def expand_s3_glob(s3_glob: str) -> Dict[str, str]:
parsed = urlparse(s3_glob)
bucket_name = parsed.netloc
prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
@@ -147,19 +146,18 @@ def expand_s3_glob(s3_glob: str) -> dict[str, str]:
@dataclass(frozen=True)
class BatchInferenceLine:
s3_path: str
- page_num: int # 1 indexed!
+ page_num: int  # 1 indexed!
start_index: int
length: int
finish_reason: str
error: Optional[str]
- def parse_custom_id(custom_id: str) -> tuple[str, int]:

+ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
s3_path = custom_id[:custom_id.rindex("-")]
page_num = int(custom_id[custom_id.rindex("-") + 1:])

return s3_path, page_num
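
An aside on the identifier format assumed by parse_custom_id above: judging from the code, each record's custom_id appears to be the page's S3 path with a 1-indexed page number appended after a final hyphen, which is why rindex is used rather than index. A minimal sketch of the round trip, with a purely hypothetical path:

    # Hypothetical custom_id; the exact format is inferred from parse_custom_id above.
    custom_id = "s3://my-bucket/documents/report.pdf-3"
    s3_path, page_num = parse_custom_id(custom_id)
    assert s3_path == "s3://my-bucket/documents/report.pdf"  # hyphens inside the path are fine
    assert page_num == 3  # 1-indexed, matching BatchInferenceLine.page_num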

- def process_jsonl_content(s3_path) -> list[BatchInferenceLine]:
+ def process_jsonl_content(s3_path) -> List[BatchInferenceLine]:
content = get_s3_bytes(s3_path).decode("utf-8")

start_index = 0
@@ -174,31 +172,39 @@ def process_jsonl_content(s3_path) -> list[BatchInferenceLine]:

assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

- index_entries.append(BatchInferenceLine(s3_path, page_num, start_index, line_length,
-     finish_reason=data["outputs"][0]["finish_reason"], error=data.get("completion_error", None)))
+ index_entries.append(BatchInferenceLine(
+     s3_path=s3_path,
+     page_num=page_num,
+     start_index=start_index,
+     length=line_length,
+     finish_reason=data["outputs"][0]["finish_reason"],
+     error=data.get("completion_error", None)
+ ))
except json.JSONDecodeError:
pass # Handle JSON decode errors if necessary
except Exception as e:
print(f"Error processing line: {e}")

- start_index = start_index + line_length
+ start_index += line_length

return index_entries

def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
bucket, key = parse_s3_path(s3_path)

# Build the range header if start_index and/or end_index are specified
range_header = None
if start_index is not None or end_index is not None:
range_value = f"bytes={start_index or 0}-"
if end_index is not None:
range_value += str(end_index)
range_header = {'Range': range_value}

if range_header:
obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
else:
obj = s3.get_object(Bucket=bucket, Key=key)

return obj['Body'].read()
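
A brief usage note on get_s3_bytes above (the key and offsets below are hypothetical, not from this commit): start_index and end_index are translated into an HTTP Range header whose end byte is inclusive, so a record whose byte offset and length are known can be fetched on its own:

    # Hypothetical example: fetch 512 bytes starting at offset 2048.
    # get_s3_bytes builds the header "Range: bytes=2048-2559" (end byte inclusive).
    chunk = get_s3_bytes(
        "s3://example-bucket/inference_outputs/batch_0.jsonl",  # hypothetical key
        start_index=2048,
        end_index=2048 + 512 - 1,
    )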

def get_pdf_num_pages(s3_path: str) -> Optional[int]:
@@ -211,9 +217,8 @@ def get_pdf_num_pages(s3_path: str) -> Optional[int]:
return reader.get_num_pages()
except Exception as ex:
print(f"Warning, could not add {s3_path} due to {ex}")

return None


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
@@ -244,12 +249,12 @@ def get_pdf_num_pages(s3_path: str) -> Optional[int]:
future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
s3_path = future_to_path[future]
- if future.result() and not db.pdf_exists(s3_path):
-     db.add_pdf(s3_path, future.result(), "pending")
+ num_pages = future.result()
+ if num_pages and not db.pdf_exists(s3_path):
+     db.add_pdf(s3_path, num_pages, "pending")

print("\n")


# Now build an index of all the pages that were processed within the workspace so far
inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")

@@ -258,14 +263,12 @@ def get_pdf_num_pages(s3_path: str) -> Optional[int]:
if not db.is_file_processed(key, etag)
]

- future_to_path = {executor.submit(process_jsonl_content, s3_path): s3_path for s3_path, etag in inference_output_paths}
+ # Adjust the future_to_path to include etag
+ future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}

for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
- s3_path = future_to_path[future]

+ s3_path, etag = future_to_path[future]
inference_lines = future.result()

db.add_index_entries(inference_lines)

- db.update_processed_file(s3_path, etag=TODO)

+ db.update_processed_file(s3_path, etag=etag)

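For context on what the new start_index and length columns in page_results record, here is a minimal sketch, not part of this commit, of the offset bookkeeping idea behind process_jsonl_content: walk a JSONL payload line by line and note where each line starts and how long it is, so that a single record can later be re-read with a ranged GET. The sketch counts raw bytes and keeps line endings; the committed code decodes to UTF-8 first, so its arithmetic may differ:

    def index_jsonl_lines(raw: bytes) -> list[tuple[int, int]]:
        """Return (start_index, length) pairs, one per line of a JSONL payload."""
        entries = []
        start_index = 0
        for line in raw.splitlines(keepends=True):
            line_length = len(line)  # includes the trailing newline, if present
            entries.append((start_index, line_length))
            start_index += line_length
        return entries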