From 843adbbe758556339a207c1118879af09f7c65ab Mon Sep 17 00:00:00 2001
From: Edward Ma
Date: Sun, 21 Nov 2021 08:54:44 -0800
Subject: [PATCH] [#249] Special handling of stopwords for ContextualWordEmbsAug

---
 nlpaug/augmenter/word/context_word_embs.py    | 75 ++++++++++++++-----
 nlpaug/augmenter/word/word_augmenter.py       | 50 ++++++++++++-
 test/augmenter/word/test_context_word_embs.py |  1 -
 test/augmenter/word/test_word.py              | 55 ++++++++++++++
 4 files changed, 160 insertions(+), 21 deletions(-)

diff --git a/nlpaug/augmenter/word/context_word_embs.py b/nlpaug/augmenter/word/context_word_embs.py
index 22956ed..86af743 100755
--- a/nlpaug/augmenter/word/context_word_embs.py
+++ b/nlpaug/augmenter/word/context_word_embs.py
@@ -4,6 +4,7 @@
 
 import string
 import os
+import re
 import logging
 
 from nlpaug.augmenter.word import WordAugmenter
@@ -45,9 +46,8 @@ class ContextualWordEmbsAug(WordAugmenter):
     :param str model_path: Model name or model path. It uses transformers to load the model. Tested with
         'bert-base-uncased', 'bert-base-cased', 'distilbert-base-uncased', 'roberta-base', 'distilroberta-base',
         'facebook/bart-base', 'squeezebert/squeezebert-uncased'.
-    :param str model_type: Type of model. For BERT model, use 'bert'. For XLNet model, use 'xlnet'.
-        For RoBERTa/LongFormer model, use 'roberta'. For BART model, use 'bart'. If no value is provided, will
-        determine from model name.
+    :param str model_type: Type of model. For BERT model, use 'bert'. For RoBERTa/LongFormer model, use 'roberta'.
+        For BART model, use 'bart'. If no value is provided, the type will be determined from the model name.
     :param str action: Either 'insert' or 'substitute'. If value is 'insert', a new word will be injected at a random
         position according to contextual word embeddings calculation. If value is 'substitute', a word will be
         replaced according to contextual embeddings calculation.
@@ -58,7 +58,8 @@ class ContextualWordEmbsAug(WordAugmenter):
     :param float aug_p: Percentage of words that will be augmented.
     :param int aug_min: Minimum number of words that will be augmented.
     :param int aug_max: Maximum number of words that will be augmented. If None is passed, the number of augmentations
         is calculated via aug_p. If the result calculated from aug_p is smaller than aug_max, the aug_p result is
        used. Otherwise, aug_max is used.
-    :param list stopwords: List of words which will be skipped from augment operation.
+    :param list stopwords: List of words which will be skipped from augment operation. Do NOT include the UNKNOWN word.
+        UNKNOWN word of BERT is [UNK]. UNKNOWN word of RoBERTa and BART is <unk>.
     :param str stopwords_regex: Regular expression for matching words which will be skipped from augment operation.
     :param str device: Default value is CPU. If value is CPU, it uses CPU for processing. If value is CUDA, it uses
         GPU for processing. Possible values include 'cuda' and 'cpu'.
        (May be able to use other options.)
@@ -89,8 +90,21 @@ def __init__(self, model_path='bert-base-uncased', model_type='', action="substi
             model_path=model_path, model_type=self.model_type, device=device, force_reload=force_reload,
             batch_size=batch_size, top_k=top_k, silence=silence)
 
         # Override stopwords
-        if stopwords is not None and self.model_type in ['xlnet', 'roberta']:
-            stopwords = [self.stopwords]
+        if stopwords:
+            # Build one alternation over all stopwords. The lookbehind/lookahead keep the
+            # neighbouring whitespace/punctuation out of the match, and re.escape protects
+            # regex metacharacters inside the stopwords themselves.
+            prefix_reg = r'(?<=\s|\W)'
+            suffix_reg = r'(?=\s|\W)'
+            stopword_reg = '(' + ')|('.join([prefix_reg + re.escape(s) + suffix_reg for s in stopwords]) + ')'
+            self.stopword_reg = re.compile(stopword_reg)
+
+            # Pattern for finding the reserved (unknown) token again after augmentation.
+            reserve_word_reg = '(' + prefix_reg + re.escape(self.model.get_unknown_token()) + suffix_reg + ')'
+            self.reserve_word_reg = re.compile(reserve_word_reg)
+
         self.device = self.model.device
 
     """
@@ -127,11 +141,12 @@ def check_model_type(self):
         return ''
 
     def is_stop_words(self, token):
-        if self.model_type in ['bert', 'electra']:
-            return super().is_stop_words(token)
-        elif self.model_type in ['xlnet', 'roberta', 'bart']:
-            return self.stopwords is not None and token.replace(self.model.get_subword_prefix(), '').lower() in self.stopwords
-        return False
+        # Stopwords have already been swapped for the model's unknown token in
+        # split_text, so only that token needs to be treated as a stopword here.
+        return token == self.model.get_unknown_token()
 
     def skip_aug(self, token_idxes, tokens):
         results = []
@@ -162,13 +177,20 @@ def skip_aug(self, token_idxes, tokens):
 
     def split_text(self, data):
         # Expect to have warning for "Token indices sequence length is longer than the specified maximum sequence length for this model"
+
+        # Handle stopwords first. https://github.com/makcedward/nlpaug/issues/247
+        if self.stopwords:
+            preprocessed_data, reserved_stopwords = self.replace_stopword_by_reserved_word(
+                data, self.stopword_reg, self.model.get_unknown_token())
+        else:
+            preprocessed_data, reserved_stopwords = data, None
+
         orig_log_level = logging.getLogger('transformers.' + 'tokenization_utils_base').getEffectiveLevel()
         logging.getLogger('transformers.' + 'tokenization_utils_base').setLevel(logging.ERROR)
-        tokens = self.model.get_tokenizer().tokenize(data)
+        tokens = self.model.get_tokenizer().tokenize(preprocessed_data)
         logging.getLogger('transformers.' + 'tokenization_utils_base').setLevel(orig_log_level)
 
         if self.model.get_model().config.max_position_embeddings == -1:  # e.g. no max length restriction for XLNet
-            return data, None, tokens, None # Head text, tail text, head token, tail token
+            return (preprocessed_data, None, tokens, None), reserved_stopwords # (Head text, tail text, head token, tail token), reserved_stopwords
 
         ids = self.model.get_tokenizer().convert_tokens_to_ids(tokens[:self.max_num_token])
         head_text = self.model.get_tokenizer().decode(ids).strip()
@@ -179,7 +201,7 @@ def split_text(self, data):
         ids = self.model.get_tokenizer().convert_tokens_to_ids(tokens[self.max_num_token:])
         tail_text = self.model.get_tokenizer().decode(ids).strip()
 
-        return head_text, tail_text, tokens[:self.max_num_token], tokens[self.max_num_token:]
+        return (head_text, tail_text, tokens[:self.max_num_token], tokens[self.max_num_token:]), reserved_stopwords
 
     def insert(self, data):
         if not data:
@@ -194,7 +216,12 @@ def insert(self, data):
             all_data = [data]
 
         # If length of input is larger than max allowed input, only augment heading part
-        split_results = [self.split_text(d) for d in all_data] # head_text, tail_text, head_tokens, tail_tokens
+        split_results = []  # head_text, tail_text, head_tokens, tail_tokens
+        reserved_stopwords = []
+        for d in all_data:
+            split_result, reserved_stopword = self.split_text(d)
+            split_results.append(split_result)
+            reserved_stopwords.append(reserved_stopword)
 
         # Pick target word for augmentation
         for i, split_result in enumerate(split_results):
@@ -292,7 +319,7 @@ def insert(self, data):
                         split_results[aug_input_pos][6][j] = -1
 
         augmented_texts = []
-        for split_result in split_results:
+        for split_result, reserved_stopword_tokens in zip(split_results, reserved_stopwords):
             tail_text, head_doc = split_result[1], split_result[5]
 
             head_tokens = head_doc.get_augmented_tokens()
@@ -302,8 +329,11 @@ def insert(self, data):
 
             ids = self.model.get_tokenizer().convert_tokens_to_ids(head_tokens)
             augmented_text = self.model.get_tokenizer().decode(ids)
-            if tail_text is not None:
+            if tail_text:
                 augmented_text += ' ' + tail_text
+            if reserved_stopword_tokens:
+                augmented_text = self.replace_reserve_word_by_stopword(
+                    augmented_text, self.reserve_word_reg, reserved_stopword_tokens)
 
             augmented_texts.append(augmented_text)
 
         if isinstance(data, list):
@@ -324,7 +354,12 @@ def substitute(self, data):
             all_data = [data]
 
         # If length of input is larger than max allowed input, only augment heading part
-        split_results = [self.split_text(d) for d in all_data] # head_text, tail_text, head_tokens, tail_tokens
+        split_results = []  # head_text, tail_text, head_tokens, tail_tokens
+        reserved_stopwords = []
+        for d in all_data:
+            split_result, reserved_stopword = self.split_text(d)
+            split_results.append(split_result)
+            reserved_stopwords.append(reserved_stopword)
 
         # Pick target word for augmentation
         for i, split_result in enumerate(split_results):
@@ -444,7 +479,7 @@ def substitute(self, data):
                         split_results[aug_input_pos][6][j] = -1
 
         augmented_texts = []
-        for split_result in split_results:
+        for split_result, reserved_stopword_tokens in zip(split_results, reserved_stopwords):
             tail_text, head_doc = split_result[1], split_result[5]
 
             head_tokens = head_doc.get_augmented_tokens()
@@ -456,6 +491,8 @@ def substitute(self, data):
             augmented_text = self.model.get_tokenizer().decode(ids)
             if tail_text is not None:
                 augmented_text += ' ' + tail_text
+            if reserved_stopword_tokens:
+                augmented_text = self.replace_reserve_word_by_stopword(
+                    augmented_text, self.reserve_word_reg, reserved_stopword_tokens)
             augmented_texts.append(augmented_text)
 
         if isinstance(data, list):
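
Reviewer's note: a minimal, self-contained sketch of the round trip that split_text now performs (not part of the patch; '[UNK]' is assumed as the reserved token, as for bert-base-uncased). Stopwords are swapped for the reserved token right-to-left before tokenization, then swapped back after augmentation:

    import re

    stopwords = ['[id]', '[year]']
    reserve_word = '[UNK]'

    # Same construction as in __init__: the lookbehind/lookahead keep the
    # neighbouring whitespace/punctuation out of the match, and re.escape
    # protects regex metacharacters inside the stopwords themselves.
    prefix_reg, suffix_reg = r'(?<=\s|\W)', r'(?=\s|\W)'
    stopword_reg = re.compile(
        '(' + ')|('.join(prefix_reg + re.escape(s) + suffix_reg for s in stopwords) + ')')
    reserve_word_reg = re.compile('(' + prefix_reg + re.escape(reserve_word) + suffix_reg + ')')

    text = 'My id is [id], and I born in [year]'

    # Forward pass: pad so the lookarounds also match at the edges, then replace
    # right-to-left so earlier match positions stay valid.
    padded, reserved = ' ' + text + ' ', []
    for m in reversed(list(stopword_reg.finditer(padded))):
        reserved.append(m.group())
        padded = padded[:m.start()] + reserve_word + padded[m.end():]
    replaced = padded[1:-1]
    print(replaced)  # My id is [UNK], and I born in [UNK]
    print(reserved)  # ['[year]', '[id]']

    # Backward pass: restore the originals, again right-to-left, consuming
    # `reserved` in the matching (reversed) order.
    padded = ' ' + replaced + ' '
    for m, orig in zip(reversed(list(reserve_word_reg.finditer(padded))), reserved):
        padded = padded[:m.start()] + orig + padded[m.end():]
    print(padded[1:-1])  # My id is [id], and I born in [year]
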
diff --git a/nlpaug/augmenter/word/word_augmenter.py b/nlpaug/augmenter/word/word_augmenter.py
index 950706a..510cdbd 100755
--- a/nlpaug/augmenter/word/word_augmenter.py
+++ b/nlpaug/augmenter/word/word_augmenter.py
@@ -18,7 +18,7 @@ def __init__(self, action, name='Word_Aug', aug_min=1, aug_max=10, aug_p=0.3, st
         self.tokenizer = tokenizer or Tokenizer.tokenizer
         self.reverse_tokenizer = reverse_tokenizer or Tokenizer.reverse_tokenizer
         self.stopwords = stopwords
-        self.stopwords_regex = re.compile(stopwords_regex) if stopwords_regex is not None else stopwords_regex
+        self.stopwords_regex = re.compile(stopwords_regex) if stopwords_regex else stopwords_regex
 
     @classmethod
     def clean(cls, data):
@@ -133,3 +133,51 @@ def get_word_case(cls, word):
         if word[0].isupper():
             return 'capitalize'
         return 'unknown'
+
+    def replace_stopword_by_reserved_word(self, text, stopword_reg, reserve_word):
+        reserved_stopwords = []
+
+        # Pad spaces so that the lookbehind/lookahead also match at both ends.
+        replaced_text = ' ' + text + ' '
+        # Iterate right-to-left so that earlier match positions stay valid while replacing.
+        for m in reversed(list(stopword_reg.finditer(replaced_text))):
+            start, end, token = m.start(), m.end(), m.group()
+            # Replace the stopword by the reserved word.
+            replaced_text = replaced_text[:start] + reserve_word + replaced_text[end:]
+            # Reversed order, but it will be consumed in reversed order later too.
+            reserved_stopwords.append(token)
+
+        # Trim the padding.
+        replaced_text = replaced_text[1:-1]
+
+        return replaced_text, reserved_stopwords
+
+    def replace_reserve_word_by_stopword(self, text, reserve_word_reg, original_stopwords):
+        # Pad spaces so that the lookbehind/lookahead also match at both ends.
+        replaced_text = ' ' + text + ' '
+        matched = list(reserve_word_reg.finditer(replaced_text))[::-1]
+
+        # TODO: the counts can differ when the input already contained the reserved
+        # word before augmentation; that case is not handled yet.
+
+        for m, orig_stopword in zip(matched, original_stopwords):
+            start, end = m.start(), m.end()
+            # Replace the reserved word by the original stopword.
+            replaced_text = replaced_text[:start] + orig_stopword + replaced_text[end:]
+
+        # Trim the padding.
+        replaced_text = replaced_text[1:-1]
+
+        return replaced_text
+
+    def preprocess(self, data):
+        # Hook for subclass-specific preprocessing; no-op by default.
+        ...
+
+    def postprocess(self, data):
+        # Hook for subclass-specific postprocessing; no-op by default.
+        ...
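
Side note on re.escape in the helpers above (illustration only, not part of the patch): stopwords such as '[id]' contain regex metacharacters, so they must be escaped before being joined into the alternation; otherwise '[id]' would be interpreted as a character class rather than literal text:

    import re

    print(re.escape('[id]'))  # \[id\]

    pattern = re.compile(r'(?<=\s|\W)' + re.escape('[id]') + r'(?=\s|\W)')
    print(bool(pattern.search(' my [id] ')))  # True: the literal '[id]' is found
    print(bool(pattern.search(' my id ')))    # False: no literal '[id]' present
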
diff --git a/test/augmenter/word/test_context_word_embs.py b/test/augmenter/word/test_context_word_embs.py
index e4c57b0..2f0a3a7 100755
--- a/test/augmenter/word/test_context_word_embs.py
+++ b/test/augmenter/word/test_context_word_embs.py
@@ -209,7 +209,6 @@ def substitute(self, aug, data):
             self.assertTrue(aug.model.get_subword_prefix() not in augmented_text)
 
     def substitute_stopwords(self, aug, data):
-        print('->substitute_stopwords')
         original_stopwords = aug.stopwords
         if isinstance(data, list):
             aug.stopwords = [t.lower() for t in data[0].split(' ')[:3]]
diff --git a/test/augmenter/word/test_word.py b/test/augmenter/word/test_word.py
index c3ef6ca..b83f244 100755
--- a/test/augmenter/word/test_word.py
+++ b/test/augmenter/word/test_word.py
@@ -214,6 +214,61 @@ def test_stopwords(self):
             self.assertTrue(
                 'quick' not in augmented_text or 'over' not in augmented_text or 'lazy' not in augmented_text)
 
+    # https://github.com/makcedward/nlpaug/issues/247
+    def test_stopword_for_preprocess(self):
+        stopwords = ["[id]", "[year]"]
+        reserve_word = '[UNK]'  # unknown token of bert-base-uncased
+        texts = [
+            "My id is [id], and I born in [year]",  # with stopword as last word
+            "[id] id is [id], and I born in [year]",  # with stopword as first word
+            "[id] [id] Id is [year] [id]",  # continuous stopwords
+            "[id] [id] Id is [year] [id]",  # continuous stopwords with space
+            "My id is [id], and I [id] born in [year] a[year] [year]b aa[year]",  # with similar stopwords
+            "My id is [id], and I born [UNK] [year]",  # already has the reserved word. NOT handled for now
+        ]
+        expected_replaced_texts = [
+            'My id is [UNK], and I born in [UNK]',
+            '[UNK] id is [UNK], and I born in [UNK]',
+            '[UNK] [UNK] Id is [UNK] [UNK]',
+            '[UNK] [UNK] Id is [UNK] [UNK]',
+            'My id is [UNK], and I [UNK] born in [UNK] a[year] [year]b aa[year]',
+            'My id is [UNK], and I born [UNK] [UNK]',
+        ]
+        expected_reserved_tokens = [
+            ['[year]', '[id]'],
+            ['[year]', '[id]', '[id]'],
+            ['[id]', '[year]', '[id]', '[id]'],
+            ['[id]', '[year]', '[id]', '[id]'],
+            ['[year]', '[id]', '[id]'],
+            ['[year]', '[id]'],
+        ]
+        expected_reversed_texts = [
+            'My id is [id], and I born in [year]',
+            '[id] id is [id], and I born in [year]',
+            '[id] [id] Id is [year] [id]',
+            '[id] [id] Id is [year] [id]',
+            'My id is [id], and I [id] born in [year] a[year] [year]b aa[year]',
+            'My id is [UNK], and I born [id] [year]',
+        ]
+
+        augs = [
+            naw.ContextualWordEmbsAug(
+                model_path='bert-base-uncased', action="insert", stopwords=stopwords),
+            naw.ContextualWordEmbsAug(
+                model_path='bert-base-uncased', action="substitute", stopwords=stopwords)
+        ]
+
+        for aug in augs:
+            for expected_text, expected_reserved_token_list, expected_reversed_text, text in zip(
+                    expected_replaced_texts, expected_reserved_tokens, expected_reversed_texts, texts):
+                replaced_text, reserved_stopwords = aug.replace_stopword_by_reserved_word(
+                    text, aug.stopword_reg, reserve_word)
+                self.assertEqual(expected_text, replaced_text)
+                self.assertEqual(expected_reserved_token_list, reserved_stopwords)
+
+                reversed_text = aug.replace_reserve_word_by_stopword(
+                    replaced_text, aug.reserve_word_reg, reserved_stopwords)
+                self.assertEqual(expected_reversed_text, reversed_text)
+
     # https://github.com/makcedward/nlpaug/issues/81
     def test_stopwords_regex(self):
         text = 'The quick brown fox jumps over the lazy dog.'
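
Finally, an end-to-end usage sketch of the behaviour this patch enables (not part of the patch; the augmented output varies by model and sampling, and the model weights are downloaded on first use):

    import nlpaug.augmenter.word as naw

    aug = naw.ContextualWordEmbsAug(
        model_path='bert-base-uncased', action='substitute',
        stopwords=['[id]', '[year]'])

    text = 'My id is [id], and I born in [year]'
    augmented = aug.augment(text)

    # With this patch, '[id]' and '[year]' should come back verbatim while the
    # surrounding words may be substituted by the language model.
    print(augmented)
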