Anchor is fixed to sample text elements better
jakep-allenai committed Oct 8, 2024
1 parent c8a4d14 commit 97291b3
Showing 3 changed files with 102 additions and 19 deletions.
112 changes: 96 additions & 16 deletions pdelfin/prompts/anchor.py
@@ -87,11 +87,11 @@ def _transform_point(x, y, m):
y_new = m[1]*x + m[3]*y + m[5]
return x_new, y_new

@dataclass
@dataclass(frozen=True)
class Element:
pass

@dataclass
@dataclass(frozen=True)
class BoundingBox:
x0: float
y0: float
@@ -102,23 +102,24 @@ class BoundingBox:
def from_rectangle(rect: RectangleObject) -> "BoundingBox":
return BoundingBox(rect[0], rect[1], rect[2], rect[3])

@dataclass
@dataclass(frozen=True)
class TextElement(Element):
text: str
x: float
y: float

@dataclass
@dataclass(frozen=True)
class ImageElement(Element):
name: str
bbox: BoundingBox

@dataclass
@dataclass(frozen=True)
class PageReport:
mediabox: BoundingBox
text_elements: List[TextElement]
image_elements: List[ImageElement]


def _pdf_report(local_pdf_path: str, page: int) -> PageReport:
reader = PdfReader(local_pdf_path)
page = reader.pages[page - 1]
@@ -219,27 +220,106 @@ def bboxes_overlap(b1: BoundingBox, b2: BoundingBox, tolerance: float) -> bool:
return merged_images


def _linearize_pdf_report(report: PageReport) -> str:
def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
result = ""

result += f"Page dimensions: {report.mediabox.x1:.1f}x{report.mediabox.y1:.1f}\n"

#images = report.image_elements
images = _merge_image_elements(report.image_elements)

for index, element in enumerate(images):
result += f"[Image {element.bbox.x0:.0f}x{element.bbox.y0:.0f} to {element.bbox.x1:.0f}x{element.bbox.y1:.0f}]"
# Process image elements
image_strings = []
for element in images:
image_str = f"[Image {element.bbox.x0:.0f}x{element.bbox.y0:.0f} to {element.bbox.x1:.0f}x{element.bbox.y1:.0f}]"
# Use element's unique identifier (e.g., id or position) for comparison
image_strings.append((element, image_str))

for index, element in enumerate(report.text_elements):
# Process text elements
text_strings = []
for element in report.text_elements:
if len(element.text.strip()) == 0:
continue

element_text = ftfy.fix_text(element.text)
# Replace square brackets with something else not to throw off the syntax
element_text = element_text.replace("[", "\[").replace("]", "\[")

# Need to use ftfy to fix text, because occasionally there are invalid surrogate pairs and other UTF issues that cause
# pyarrow to fail to load the json later
result += f"[{element.x:.0f}x{element.y:.0f}]{element_text}"
# Replace square brackets with escaped brackets
element_text = element_text.replace("[", "\\[").replace("]", "\\]")

text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}"
text_strings.append((element, text_str))

# Combine all elements with their positions for sorting
all_elements = []
for elem, s in image_strings:
position = (elem.bbox.x0, elem.bbox.y0)
all_elements.append(('image', elem, s, position))
for elem, s in text_strings:
position = (elem.x, elem.y)
all_elements.append(('text', elem, s, position))

# Calculate total length
total_length = len(result) + sum(len(s) for _, _, s, _ in all_elements)

if total_length <= max_length:
# Include all elements
for _, _, s, _ in all_elements:
result += s
return result

# Identify elements with min/max coordinates
edge_elements = set()

if images:
min_x0_image = min(images, key=lambda e: e.bbox.x0)
max_x1_image = max(images, key=lambda e: e.bbox.x1)
min_y0_image = min(images, key=lambda e: e.bbox.y0)
max_y1_image = max(images, key=lambda e: e.bbox.y1)
edge_elements.update([min_x0_image, max_x1_image, min_y0_image, max_y1_image])

if report.text_elements:
text_elements = [e for e in report.text_elements if len(e.text.strip()) > 0]
min_x_text = min(text_elements, key=lambda e: e.x)
max_x_text = max(text_elements, key=lambda e: e.x)
min_y_text = min(text_elements, key=lambda e: e.y)
max_y_text = max(text_elements, key=lambda e: e.y)
edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text])

# Keep track of element IDs to prevent duplication
selected_element_ids = set()
selected_elements = []

# Include edge elements first
for elem_type, elem, s, position in all_elements:
if elem in edge_elements and id(elem) not in selected_element_ids:
selected_elements.append((elem_type, elem, s, position))
selected_element_ids.add(id(elem))

# Calculate remaining length
current_length = len(result) + sum(len(s) for _, _, s, _ in selected_elements)
remaining_length = max_length - current_length

# Exclude edge elements from the pool
remaining_elements = [
(elem_type, elem, s, position) for elem_type, elem, s, position in all_elements
if id(elem) not in selected_element_ids
]

# Sort remaining elements by their positions (e.g., x-coordinate and then y-coordinate)
remaining_elements.sort(key=lambda x: (x[3][0], x[3][1]))

# Add elements until reaching max_length
for elem_type, elem, s, position in remaining_elements:
if current_length + len(s) > max_length:
break
selected_elements.append((elem_type, elem, s, position))
selected_element_ids.add(id(elem))
current_length += len(s)

# Sort selected elements by their positions to maintain logical order
selected_elements.sort(key=lambda x: (x[3][0], x[3][1]))

# Build the final result
for _, _, s, _ in selected_elements:
result += s

return result
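
For reference, a minimal, self-contained sketch of the selection logic added above: emit everything when the rendered report fits in the budget; otherwise always keep the elements sitting at the extreme coordinates, fill the remaining budget in (x, y) order, and re-sort the final selection by position. The Snippet type, the cap_by_length name, and the sample values below are made up for illustration and are not part of the pdelfin API.

from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class Snippet:
    x: float
    y: float
    text: str

def cap_by_length(snippets: List[Snippet], max_length: int = 4000) -> str:
    # Render every element the same way the page report does: "[XxY]text"
    rendered = [(s, f"[{s.x:.0f}x{s.y:.0f}]{s.text}") for s in snippets]

    # If the full report fits, keep everything.
    if sum(len(r) for _, r in rendered) <= max_length:
        return "".join(r for _, r in rendered)

    # Otherwise always keep the elements at the extreme coordinates,
    # so the overall extent of the page survives truncation.
    edges = {min(snippets, key=lambda s: s.x), max(snippets, key=lambda s: s.x),
             min(snippets, key=lambda s: s.y), max(snippets, key=lambda s: s.y)}
    selected = [(s, r) for s, r in rendered if s in edges]
    budget = max_length - sum(len(r) for _, r in selected)

    # Fill the remaining budget in (x, y) order, stopping at the first
    # element that would push the output past the cap.
    for s, r in sorted(((s, r) for s, r in rendered if s not in edges),
                       key=lambda p: (p[0].x, p[0].y)):
        if len(r) > budget:
            break
        selected.append((s, r))
        budget -= len(r)

    # Re-sort the selection by position so the output stays in a stable order.
    selected.sort(key=lambda p: (p[0].x, p[0].y))
    return "".join(r for _, r in selected)

if __name__ == "__main__":
    page = [Snippet(36, 20, "footer"), Snippet(72, 700, "Title"),
            Snippet(90, 400, "body " * 50), Snippet(500, 760, "logo")]
    print(cap_by_length(page, max_length=120))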

6 changes: 3 additions & 3 deletions pdelfin/silver_data/convertsilver_openai.py
@@ -185,10 +185,10 @@ def main():
description="Transform JSONL files by extracting and renaming specific fields."
)
parser.add_argument(
'--rewrite_finetuning_prompt',
'--rewrite_prompt',
action='store_true',
default=False,
help="Rewrites the input prompt from standard OPENAI instruction format into our finetuned format"
help="Rewrites the input prompt by reloading the pdf from source"
)
parser.add_argument(
'input_dir',
@@ -233,7 +233,7 @@ def main():
# Process files in parallel
with ProcessPoolExecutor(max_workers=max_jobs) as executor:
future_to_file = {
executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
executor.submit(process_file, input_file, output_file, args.rewrite_prompt): input_file
for input_file, output_file in tasks
}

3 changes: 3 additions & 0 deletions tests/test_anchor.py
@@ -73,6 +73,7 @@ def testLargePromptHint1(self):

print(anchor_text)
print(len(anchor_text))
self.assertLess(len(anchor_text), 1000)

def testLargePromptHint2(self):
local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "large_prompt_hint2.pdf")
@@ -81,6 +81,7 @@ def testLargePromptHint2(self):

print(anchor_text)
print(len(anchor_text))
self.assertLess(len(anchor_text), 4000)

def testNewsPaperPromptHint(self):
local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "newspaper.pdf")
@@ -89,6 +91,7 @@ def testNewsPaperPromptHint(self):

print(anchor_text)
print(len(anchor_text))
self.assertLess(len(anchor_text), 4000)



