Skip to content

Commit

Permalink
Testing coherence with distilgpt2, but it doesn't work great
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Sep 17, 2024
1 parent cb9b6ef commit 57e80aa
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
10 changes: 0 additions & 10 deletions pdelfin/filter/coherency.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@

@lru_cache()
def load_coherency_model(model_name: str = "distilgpt2"):
"""
Loads the tokenizer and model, caching the result to avoid redundant loads.
Args:
model_name (str): The name of the pretrained model to load.
Returns:
tokenizer: The tokenizer associated with the model.
model: The pretrained causal language model.
"""
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval() # Set the model to evaluation mode
Expand Down
17 changes: 14 additions & 3 deletions tests/test_coherency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@
import unittest

from pdelfin.filter.coherency import get_document_coherency
from pdelfin.extract_text import get_document_text
from pdelfin.extract_text import get_document_text, get_page_text


class TestCoherencyScores(unittest.TestCase):
def testBadOcr1(self):
good_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "instructions_and_schematics.pdf"))
bad_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "handwriting_bad_ocr.pdf"))
ocr1_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "handwriting_bad_ocr.pdf"))
ocr2_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "some_ocr1.pdf"))

print("Good", get_document_coherency(good_text))
print("Bad", get_document_coherency(bad_text))
print("Bad1", get_document_coherency(ocr1_text))
print("Bad2", get_document_coherency(ocr2_text))

def testTwoColumnMisparse(self):
pdftotext_text = get_page_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf"), page_num=2, pdf_engine="pdftotext")
pymupdf_text = get_page_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf"), page_num=2, pdf_engine="pymupdf")

print("pdftotext_text", get_document_coherency(pdftotext_text))
print("pymupdf_text", get_document_coherency(pymupdf_text))


0 comments on commit 57e80aa

Please sign in to comment.