From 16fba3688b168f5d3eeade1b3d8c12b2fa5d7298 Mon Sep 17 00:00:00 2001
From: Laurent Sorber
Date: Thu, 2 Jan 2025 14:13:23 +0000
Subject: [PATCH] improve block weighting

---
 test.py                         | 11 +++++++++++
 wtpsplit/__init__.py            | 21 +++++++++++++++++++++
 wtpsplit/evaluation/__init__.py |  3 +++
 wtpsplit/extract.py             | 16 +++++++++++++---
 wtpsplit/train/evaluate.py      |  3 +++
 5 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/test.py b/test.py
index 70de4a25..20f662e1 100644
--- a/test.py
+++ b/test.py
@@ -1,6 +1,17 @@
 # noqa: E501
 from wtpsplit import WtP, SaT
 
+def test_weighting():
+    sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])
+
+    text = "This is a test sentence This is another test sentence."
+    splits_default = sat.split(text, threshold=0.25)
+    splits_uniform = sat.split(text, threshold=0.25, weighting="uniform")
+    splits_hat = sat.split(text, threshold=0.25, weighting="hat")
+    expected_splits = ["This is a test sentence ", "This is another test sentence."]
+    assert splits_default == splits_uniform == splits_hat == expected_splits
+    assert "".join(splits_default) == text
+
 def test_split_ort():
     sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])
 
diff --git a/wtpsplit/__init__.py b/wtpsplit/__init__.py
index 33dd199b..fe9d4c97 100644
--- a/wtpsplit/__init__.py
+++ b/wtpsplit/__init__.py
@@ -3,6 +3,7 @@
 import os
 import warnings
 from pathlib import Path
+from typing import Literal
 
 # avoid the "None of PyTorch, TensorFlow, etc. have been found" warning.
 with contextlib.redirect_stderr(open(os.devnull, "w")):
@@ -141,6 +142,7 @@ def predict_proba(
         block_size: int = 512,
         batch_size=32,
         pad_last_batch: bool = False,
+        weighting: Literal["uniform", "hat"] = "uniform",
         remove_whitespace_before_inference: bool = False,
         outer_batch_size=1000,
         return_paragraph_probabilities=False,
@@ -156,6 +158,7 @@
                     block_size=block_size,
                     batch_size=batch_size,
                     pad_last_batch=pad_last_batch,
+                    weighting=weighting,
                     remove_whitespace_before_inference=remove_whitespace_before_inference,
                     outer_batch_size=outer_batch_size,
                     return_paragraph_probabilities=return_paragraph_probabilities,
@@ -171,6 +174,7 @@
                 block_size=block_size,
                 batch_size=batch_size,
                 pad_last_batch=pad_last_batch,
+                weighting=weighting,
                 remove_whitespace_before_inference=remove_whitespace_before_inference,
                 outer_batch_size=outer_batch_size,
                 return_paragraph_probabilities=return_paragraph_probabilities,
@@ -186,6 +190,7 @@ def _predict_proba(
         block_size: int,
         batch_size: int,
         pad_last_batch: bool,
+        weighting: Literal["uniform", "hat"],
         remove_whitespace_before_inference: bool,
         outer_batch_size: int,
         return_paragraph_probabilities: bool,
@@ -246,6 +251,7 @@
                     max_block_size=block_size,
                     batch_size=batch_size,
                     pad_last_batch=pad_last_batch,
+                    weighting=weighting,
                     verbose=verbose,
                 )[0]
             else:
@@ -290,6 +296,7 @@ def split(
         block_size: int = 512,
         batch_size=32,
         pad_last_batch: bool = False,
+        weighting: Literal["uniform", "hat"] = "uniform",
         remove_whitespace_before_inference: bool = False,
         outer_batch_size=1000,
         paragraph_threshold: float = 0.5,
@@ -308,6 +315,7 @@
                     block_size=block_size,
                     batch_size=batch_size,
                     pad_last_batch=pad_last_batch,
+                    weighting=weighting,
                     remove_whitespace_before_inference=remove_whitespace_before_inference,
                     outer_batch_size=outer_batch_size,
                     paragraph_threshold=paragraph_threshold,
@@ -326,6 +334,7 @@
                 block_size=block_size,
                 batch_size=batch_size,
                 pad_last_batch=pad_last_batch,
+                weighting=weighting,
                 remove_whitespace_before_inference=remove_whitespace_before_inference,
                 outer_batch_size=outer_batch_size,
                 paragraph_threshold=paragraph_threshold,
@@ -355,6 +364,7 @@ def _split(
         block_size: int,
         batch_size: int,
         pad_last_batch: bool,
+        weighting: Literal["uniform", "hat"],
         remove_whitespace_before_inference: bool,
         outer_batch_size: int,
         paragraph_threshold: float,
@@ -391,6 +401,7 @@ def _split(
             block_size=block_size,
             batch_size=batch_size,
             pad_last_batch=pad_last_batch,
+            weighting=weighting,
             remove_whitespace_before_inference=remove_whitespace_before_inference,
             outer_batch_size=outer_batch_size,
             return_paragraph_probabilities=do_paragraph_segmentation,
@@ -573,6 +584,7 @@ def predict_proba(
         block_size: int = 512,
         batch_size=32,
         pad_last_batch: bool = False,
+        weighting: Literal["uniform", "hat"] = "uniform",
         remove_whitespace_before_inference: bool = False,
         outer_batch_size=1000,
         return_paragraph_probabilities=False,
@@ -586,6 +598,7 @@
                     block_size=block_size,
                     batch_size=batch_size,
                     pad_last_batch=pad_last_batch,
+                    weighting=weighting,
                     remove_whitespace_before_inference=remove_whitespace_before_inference,
                     outer_batch_size=outer_batch_size,
                     return_paragraph_probabilities=return_paragraph_probabilities,
@@ -599,6 +612,7 @@
                 block_size=block_size,
                 batch_size=batch_size,
                 pad_last_batch=pad_last_batch,
+                weighting=weighting,
                 remove_whitespace_before_inference=remove_whitespace_before_inference,
                 outer_batch_size=outer_batch_size,
                 return_paragraph_probabilities=return_paragraph_probabilities,
@@ -612,6 +626,7 @@ def _predict_proba(
         block_size: int,
         batch_size: int,
         pad_last_batch: bool,
+        weighting: Literal["uniform", "hat"],
         remove_whitespace_before_inference: bool,
         outer_batch_size: int,
         return_paragraph_probabilities: bool,
@@ -657,6 +672,7 @@ def newline_probability_fn(logits):
                 max_block_size=block_size,
                 batch_size=batch_size,
                 pad_last_batch=pad_last_batch,
+                weighting=weighting,
                 verbose=verbose,
                 tokenizer=self.tokenizer,
             )
@@ -705,6 +721,7 @@ def split(
         block_size: int = 512,
         batch_size=32,
         pad_last_batch: bool = False,
+        weighting: Literal["uniform", "hat"] = "uniform",
         remove_whitespace_before_inference: bool = False,
         outer_batch_size=1000,
         paragraph_threshold: float = 0.5,
@@ -722,6 +739,7 @@
                     block_size=block_size,
                     batch_size=batch_size,
                     pad_last_batch=pad_last_batch,
+                    weighting=weighting,
                     remove_whitespace_before_inference=remove_whitespace_before_inference,
                     outer_batch_size=outer_batch_size,
                     paragraph_threshold=paragraph_threshold,
@@ -739,6 +757,7 @@
                 block_size=block_size,
                 batch_size=batch_size,
                 pad_last_batch=pad_last_batch,
+                weighting=weighting,
                 remove_whitespace_before_inference=remove_whitespace_before_inference,
                 outer_batch_size=outer_batch_size,
                 paragraph_threshold=paragraph_threshold,
@@ -756,6 +775,7 @@ def _split(
         block_size: int,
         batch_size: int,
         pad_last_batch: bool,
+        weighting: Literal["uniform", "hat"],
         paragraph_threshold: float,
         remove_whitespace_before_inference: bool,
         outer_batch_size: int,
@@ -784,6 +804,7 @@ def get_default_threshold(model_str: str):
             block_size=block_size,
             batch_size=batch_size,
             pad_last_batch=pad_last_batch,
+            weighting=weighting,
             remove_whitespace_before_inference=remove_whitespace_before_inference,
             outer_batch_size=outer_batch_size,
             return_paragraph_probabilities=do_paragraph_segmentation,
diff --git a/wtpsplit/evaluation/__init__.py b/wtpsplit/evaluation/__init__.py
index 0147ca90..0c3c0d70 100644
--- a/wtpsplit/evaluation/__init__.py
+++ b/wtpsplit/evaluation/__init__.py
@@ -1,6 +1,7 @@
 import subprocess
 import unicodedata
 import os
+from typing import Literal
 
 import numpy as np
 import regex as re
@@ -240,6 +241,7 @@ def our_sentencize(
     block_size=512,
     stride=64,
     batch_size=32,
+    weighting: Literal["uniform", "hat"] = "uniform",
 ):
     logits = extract(
         [text],
@@ -249,6 +251,7 @@
         max_block_size=block_size,
         batch_size=batch_size,
         pad_last_batch=False,
+        weighting=weighting,
         use_hidden_states=False,
         verbose=False,
     )[0]
diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py
index 7ecd2c2f..9d0c319d 100644
--- a/wtpsplit/extract.py
+++ b/wtpsplit/extract.py
@@ -1,6 +1,7 @@
 import math
 import sys
 import logging
+from typing import Literal
 
 import numpy as np
 from tqdm.auto import tqdm
@@ -93,6 +94,7 @@ def extract(
     batch_size,
     lang_code=None,
     pad_last_batch=False,
+    weighting: Literal["uniform", "hat"] = "uniform",
     verbose=False,
     tokenizer=None,
 ):
@@ -202,7 +204,7 @@
         for length in text_lengths
     ]
     # container for the number of chunks that any character was part of (to average chunk predictions)
-    all_counts = [np.zeros(length, dtype=np.int16) for length in text_lengths]
+    all_counts = [np.zeros(length, dtype=np.float16) for length in text_lengths]
 
     uses_lang_adapters = getattr(model.config, "language_adapter", "off") == "on"
     if uses_lang_adapters:
@@ -218,6 +220,13 @@
         )
     else:
         language_ids = None
+
+    # compute weights for the given weighting scheme
+    if weighting == "uniform":
+        weights = np.ones(block_size, dtype=np.float16)
+    elif weighting == "hat":
+        x = np.linspace(-(1 - 1 / block_size), 1 - 1 / block_size, block_size, dtype=np.float16)
+        weights = 1 - np.abs(x)
 
     # forward passes through all chunks
     for batch_idx in tqdm(range(n_batches), disable=not verbose):
@@ -255,8 +264,9 @@
 
             for i in range(start, end):
                 original_idx, start_char_idx, end_char_idx = locs[i]
-                all_logits[original_idx][start_char_idx:end_char_idx] += logits[i - start, : end_char_idx - start_char_idx]
-                all_counts[original_idx][start_char_idx:end_char_idx] += 1
+                n = end_char_idx - start_char_idx
+                all_logits[original_idx][start_char_idx:end_char_idx] += weights[:n, np.newaxis] * logits[i - start, :n]
+                all_counts[original_idx][start_char_idx:end_char_idx] += weights[:n]
 
     # so far, logits are summed, so we average them here
     all_logits = [(logits / counts[:, None]).astype(np.float16) for logits, counts in zip(all_logits, all_counts)]
diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py
index 8751245b..5e1d04a6 100644
--- a/wtpsplit/train/evaluate.py
+++ b/wtpsplit/train/evaluate.py
@@ -1,5 +1,6 @@
 import logging
 import sys
+from typing import Literal
 
 import numpy as np
 import pysbd
@@ -74,6 +75,7 @@ def evaluate_sentence(
     stride,
     block_size,
     batch_size,
+    weighting: Literal["uniform", "hat"] = "uniform",
     use_pysbd=False,
     positive_index=None,
     do_lowercase=False,
@@ -97,6 +99,7 @@
         stride=stride,
         max_block_size=block_size,
         batch_size=batch_size,
+        weighting=weighting,
     )
     logits = logits[0]
     if offsets_mapping is not None:
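
Illustration (not part of the patch): a minimal, self-contained sketch of the chunk averaging that the hunks in wtpsplit/extract.py implement. Characters near a chunk's edges are predicted with less context, so the "hat" scheme down-weights them before overlapping chunk predictions are averaged; "uniform" reproduces the previous plain mean. The helper names (make_weights, average_overlapping_chunks) and the toy inputs below are invented for this sketch; only the weight formula and the weighted sum/normalization mirror the patch.

import numpy as np


def make_weights(block_size: int, weighting: str = "uniform") -> np.ndarray:
    """Per-position weights for one block, as in the patched extract.py."""
    if weighting == "uniform":
        return np.ones(block_size, dtype=np.float16)
    # "hat": triangular weights, ~1 at the block center, 1/block_size at the edges.
    x = np.linspace(-(1 - 1 / block_size), 1 - 1 / block_size, block_size, dtype=np.float16)
    return 1 - np.abs(x)


def average_overlapping_chunks(chunk_logits, spans, text_length, block_size, weighting="hat"):
    """Weighted average of per-character logits over overlapping strided chunks.

    chunk_logits: one (chunk_len, n_labels) array per chunk
    spans: the (start_char_idx, end_char_idx) span each chunk covers
    """
    n_labels = chunk_logits[0].shape[1]
    summed = np.zeros((text_length, n_labels), dtype=np.float32)
    counts = np.zeros(text_length, dtype=np.float32)
    weights = make_weights(block_size, weighting)
    for logits, (start, end) in zip(chunk_logits, spans):
        n = end - start
        # accumulate weighted logits and the total weight seen per character
        summed[start:end] += weights[:n, np.newaxis] * logits[:n]
        counts[start:end] += weights[:n]
    return summed / counts[:, np.newaxis]


# Toy example: two chunks of size 8 with stride 4 over a 12-character text.
rng = np.random.default_rng(0)
chunks = [rng.normal(size=(8, 1)), rng.normal(size=(8, 1))]
avg = average_overlapping_chunks(chunks, [(0, 8), (4, 12)], text_length=12, block_size=8)
assert avg.shape == (12, 1)

This also explains the dtype change in the patch: all_counts switches from int16 to float16 because under "hat" weighting the accumulated denominators are fractional weight totals rather than integer chunk counts.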