Trying distilgpt2 instead of kenlm
jakep-allenai committed Sep 17, 2024
1 parent 01bc0b2 commit cb9b6ef
Showing 1 changed file with 60 additions and 14 deletions.
pdelfin/filter/coherency.py (74 changes: 60 additions, 14 deletions)
@@ -1,21 +1,67 @@
 # Uses a premade kenLM filter trained on good DCLM filtered web data to help identify pdfs where the
 # content has been very poorly parsed
-import kenlm
 
 from functools import lru_cache
-from cached_path import cached_path
 
-KENLM_S3_PATH = "s3://ai2-oe-data/jakep/kenlm-dclm/5gramtok.bin"
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 
 @lru_cache()
-def load_kenlm():
-    local_path = cached_path(KENLM_S3_PATH)
-    model = kenlm.Model(local_path)
-
-    return model
+def load_coherency_model(model_name: str = "distilgpt2"):
+    """
+    Loads the tokenizer and model, caching the result to avoid redundant loads.
+    Args:
+        model_name (str): The name of the pretrained model to load.
+    Returns:
+        tokenizer: The tokenizer associated with the model.
+        model: The pretrained causal language model.
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+    model.eval()  # Set the model to evaluation mode
+
+    return tokenizer, model
 
 def get_document_coherency(text: str) -> float:
-    model = load_kenlm()
-
-    return model.score(text)
+    """
+    Calculates the coherency of a document based on the log likelihood of its tokens.
+    Handles texts longer than the model's maximum token limit by splitting them into chunks.
+    Args:
+        text (str): The input text to evaluate.
+    Returns:
+        float: The average log likelihood per token as a measure of coherency.
+    """
+    tokenizer, model = load_coherency_model()
+
+    # Determine the model's maximum number of tokens
+    max_length = tokenizer.model_max_length - 1
+    # Some tokenizers have a default value indicating no limit; use model config if so
+    if max_length > 1_000_000:
+        max_length = model.config.max_position_embeddings
+
+    # Tokenize the entire text
+    tokens = tokenizer.encode(text, return_tensors='pt').squeeze(0)
+
+    total_log_likelihood = 0.0
+    total_tokens = 0
+
+    # Split tokens into chunks that fit within the model's max length
+    for i in range(0, len(tokens), max_length):
+        chunk = tokens[i:i + max_length]
+        inputs = chunk.unsqueeze(0)  # Add batch dimension
+
+        # Move inputs to CPU (ensure compatibility)
+        inputs = {k: v.cpu() for k, v in {'input_ids': inputs}.items()}
+
+        with torch.no_grad():
+            outputs = model(**inputs, labels=inputs['input_ids'])
+            # Compute log likelihood for the chunk
+            log_likelihood = -outputs.loss.item() * chunk.size(0)
+            total_log_likelihood += log_likelihood
+            total_tokens += chunk.size(0)
+
+    # Calculate the average log likelihood per token
+    avg_log_likelihood = total_log_likelihood / total_tokens if total_tokens > 0 else 0.0
+
+    return avg_log_likelihood
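
A note on the chunk scoring in the diff above: for a Hugging Face causal LM, outputs.loss is the mean cross-entropy over the shifted next-token predictions, so -loss approximates the average log likelihood per token of a chunk; the new code scales it by the chunk length to accumulate a total before averaging across chunks. A minimal sketch of that relationship, assuming distilgpt2 can be downloaded locally:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("distilgpt2")
lm = AutoModelForCausalLM.from_pretrained("distilgpt2").eval()

ids = tok.encode("A short sanity-check sentence.", return_tensors="pt")
with torch.no_grad():
    # loss is the mean next-token cross-entropy, so -loss is roughly the
    # average log likelihood per token of this text under distilgpt2
    loss = lm(input_ids=ids, labels=ids).loss
print(f"avg log likelihood per token: {-loss.item():.3f}")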

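For context, a minimal usage sketch of the new scoring function. This is not part of the commit: the import path simply follows the changed file above, and the -6.0 cutoff is an illustrative assumption rather than a project default.

from pdelfin.filter.coherency import get_document_coherency

clean_text = "The quick brown fox jumps over the lazy dog. " * 20
garbled_text = "Th3 qu1ck br0wn f0x jmps 0vr th3 l4zy d0g. " * 20

clean_score = get_document_coherency(clean_text)      # average log likelihood per token
garbled_score = get_document_coherency(garbled_text)  # expected to be lower (more negative)

# Higher (less negative) scores mean distilgpt2 finds the text more fluent;
# a caller might drop documents whose score falls below a tuned cutoff.
if garbled_score < -6.0:
    print(f"Likely poorly parsed PDF (coherency={garbled_score:.2f})")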