Commit cb9b6ef (1 parent: 01bc0b2)
Showing 1 changed file with 60 additions and 14 deletions.
# Uses a premade kenLM filter trained on good DCLM-filtered web data to help
# identify PDFs where the content has been very poorly parsed.
from functools import lru_cache

import kenlm
import torch
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer

KENLM_S3_PATH = "s3://ai2-oe-data/jakep/kenlm-dclm/5gramtok.bin"


@lru_cache()
def load_kenlm():
    local_path = cached_path(KENLM_S3_PATH)
    model = kenlm.Model(local_path)
    return model
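
For orientation, the kenlm model returned by load_kenlm() scores text directly through its score() method; a minimal usage sketch, with invented sample strings:

# Minimal sketch (sample strings are illustrative; requires access to the S3 model).
km = load_kenlm()
well_formed = km.score("The quick brown fox jumps over the lazy dog.")
garbled = km.score("Th3 qu1ck brwn fx jmps")
# score() returns a total log10 probability, so longer texts score lower;
# compare per token (divide by word count) when texts differ in length.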

@lru_cache()
def load_coherency_model(model_name: str = "distilgpt2"):
    """
    Loads the tokenizer and model, caching the result to avoid redundant loads.

    Args:
        model_name (str): The name of the pretrained model to load.

    Returns:
        tokenizer: The tokenizer associated with the model.
        model: The pretrained causal language model.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()  # Set the model to evaluation mode

    return tokenizer, model
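
Because the loader is memoized with lru_cache, repeated calls hand back the very same tokenizer and model objects instead of reloading weights per document; a quick illustrative check:

# lru_cache memoizes the returned (tokenizer, model) tuple,
# so every call after the first is essentially free.
tok_a, lm_a = load_coherency_model()
tok_b, lm_b = load_coherency_model()
assert tok_a is tok_b and lm_a is lm_b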

def get_document_coherency(text: str) -> float:
    """
    Calculates the coherency of a document based on the log likelihood of its tokens.
    Handles texts longer than the model's maximum token limit by splitting them into chunks.

    Args:
        text (str): The input text to evaluate.

    Returns:
        float: The average log likelihood per token as a measure of coherency.
    """
    tokenizer, model = load_coherency_model()

    # Determine the model's maximum number of tokens
    max_length = tokenizer.model_max_length - 1
    # Some tokenizers report a huge sentinel value indicating no limit;
    # fall back to the model config in that case
    if max_length > 1_000_000:
        max_length = model.config.max_position_embeddings

    # Tokenize the entire text
    tokens = tokenizer.encode(text, return_tensors="pt").squeeze(0)

    total_log_likelihood = 0.0
    total_tokens = 0

    # Split tokens into chunks that fit within the model's max length
    for i in range(0, len(tokens), max_length):
        chunk = tokens[i : i + max_length]
        # Add batch dimension and keep inputs on CPU for compatibility
        inputs = {"input_ids": chunk.unsqueeze(0).cpu()}

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])

        # outputs.loss is the mean negative log likelihood per token;
        # scale by chunk length to recover the chunk's total log likelihood
        log_likelihood = -outputs.loss.item() * chunk.size(0)
        total_log_likelihood += log_likelihood
        total_tokens += chunk.size(0)

    # Average log likelihood per token (closer to zero means more coherent)
    avg_log_likelihood = total_log_likelihood / total_tokens if total_tokens > 0 else 0.0

    return avg_log_likelihood
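
A minimal sketch of driving the new scorer end to end (the sample strings are invented for illustration; the distilgpt2 weights are downloaded on first use):

if __name__ == "__main__":
    # Well-parsed prose should score higher (closer to zero) than
    # text mangled by a bad PDF extraction.
    clean = "The committee approved the budget after a brief discussion."
    mangled = "Th e comm itt ee app rov ed t he bu dg et af ter a brie f disc."
    print("clean:  ", get_document_coherency(clean))
    print("mangled:", get_document_coherency(mangled))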