Commit
[#249] Special handle stopwords for ContextEmbsAug
makcedward committed Nov 21, 2021
1 parent e058a13 commit 843adbb
Showing 4 changed files with 160 additions and 21 deletions.
75 changes: 56 additions & 19 deletions nlpaug/augmenter/word/context_word_embs.py
@@ -4,6 +4,7 @@

import string
import os
import re
import logging

from nlpaug.augmenter.word import WordAugmenter
@@ -45,9 +46,8 @@ class ContextualWordEmbsAug(WordAugmenter):
:param str model_path: Model name or model path. It uses transformers to load the model. Tested with
'bert-base-uncased', 'bert-base-cased', 'distilbert-base-uncased', 'roberta-base', 'distilroberta-base',
'facebook/bart-base', 'squeezebert/squeezebert-uncased'.
:param str model_type: Type of model. For BERT model, use 'bert'. For XLNet model, use 'xlnet'.
For RoBERTa/LongFormer model, use 'roberta'. For BART model, use 'bart'. If no value is provided, will
determine from model name.
:param str model_type: Type of model. For BERT model, use 'bert'. For RoBERTa/LongFormer model, use 'roberta'.
For BART model, use 'bart'. If no value is provided, will determine from model name.
:param str action: Either 'insert' or 'substitute'. If value is 'insert', a new word will be injected at a random
position according to contextual word embeddings calculation. If value is 'substitute', a word will be replaced
according to contextual embeddings calculation.
@@ -58,7 +58,8 @@ class ContextualWordEmbsAug(WordAugmenter):
:param int aug_max: Maximum number of words to be augmented. If None is passed, the number of augmentations is
calculated via aug_p. If the result calculated from aug_p is smaller than aug_max, the aug_p result is used.
Otherwise, aug_max is used.
:param list stopwords: List of words which will be skipped from augment operation.
:param list stopwords: List of words which will be skipped from augment operation. Do NOT include the UNKNOWN word.
The UNKNOWN word of BERT is [UNK]; for RoBERTa and BART it is <unk>.
:param str stopwords_regex: Regular expression for matching words which will be skipped from augment operation.
:param str device: Default value is CPU. If value is CPU, it uses CPU for processing. If value is CUDA, it uses GPU
for processing. Possible values include 'cuda' and 'cpu'. (Other options may also be usable.)
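
A minimal usage sketch of how these parameters combine with this commit's stopword handling (assumes nlpaug at this commit plus a transformers setup that can download 'bert-base-uncased'; the augmented output itself varies by model and sampling):

import nlpaug.augmenter.word as naw

# Stopwords such as '[id]' and '[year]' are protected from augmentation: internally
# they are swapped for the reserved '[UNK]' marker before tokenization and restored
# in the augmented output afterwards.
aug = naw.ContextualWordEmbsAug(
    model_path='bert-base-uncased',
    action='substitute',
    stopwords=['[id]', '[year]'],
)
print(aug.augment('My id is [id], and I was born in [year]'))
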
@@ -89,8 +90,21 @@ def __init__(self, model_path='bert-base-uncased', model_type='', action="substi
model_path=model_path, model_type=self.model_type, device=device, force_reload=force_reload,
batch_size=batch_size, top_k=top_k, silence=silence)
# Override stopwords
if stopwords is not None and self.model_type in ['xlnet', 'roberta']:
stopwords = [self.stopwords]
# if stopwords and self.model_type in ['xlnet', 'roberta']:
# stopwords = [self.stopwords]

if stopwords:
prefix_reg = '(?<=\s|\W)'
suffix_reg = '(?=\s|\W)'
stopword_reg = '('+')|('.join([prefix_reg + re.escape(s) + suffix_reg for s in stopwords])+')'
self.stopword_reg = re.compile(stopword_reg)

prefix_reg = '(?<=\s|\W)'
suffix_reg = '(?=\s|\W)'
reserve_word_reg = '(' + prefix_reg + re.escape(self.model.get_unknown_token()) + suffix_reg + ')'
self.reserve_word_reg = re.compile(reserve_word_reg)
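# Both patterns anchor the escaped literal between a lookbehind and a lookahead on
# whitespace or a non-word character, so a stopword or the unknown token only matches
# as a whole token (e.g. '[year]' is not matched inside 'a[year]' or '[year]b');
# the helpers in WordAugmenter pad one space around the text so tokens at the very
# start or end of the input still satisfy the lookarounds.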


self.device = self.model.device

"""
@@ -127,11 +141,12 @@ def check_model_type(self):
return ''

def is_stop_words(self, token):
if self.model_type in ['bert', 'electra']:
return super().is_stop_words(token)
elif self.model_type in ['xlnet', 'roberta', 'bart']:
return self.stopwords is not None and token.replace(self.model.get_subword_prefix(), '').lower() in self.stopwords
return False
return token == '[UNK]'
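# split_text has already swapped every user-supplied stopword for the reserved
# '[UNK]' marker, so treating '[UNK]' as the only stop word here is what keeps
# those positions out of the augmentation candidates until they are restored.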
# if self.model_type in ['bert', 'electra']:
# return super().is_stop_words(token)
# elif self.model_type in ['xlnet', 'roberta', 'bart']:
# return self.stopwords is not None and token.replace(self.model.get_subword_prefix(), '').lower() in self.stopwords
# return False

def skip_aug(self, token_idxes, tokens):
results = []
@@ -162,13 +177,20 @@ def skip_aug(self, token_idxes, tokens):

def split_text(self, data):
# Expect to have warning for "Token indices sequence length is longer than the specified maximum sequence length for this model"

# Handle stopwords first #https://github.com/makcedward/nlpaug/issues/247
if self.stopwords:
preprocessed_data, reserved_stopwords = self.replace_stopword_by_reserved_word(data, self.stopword_reg, '[UNK]')
else:
preprocessed_data, reserved_stopwords = data, None

orig_log_level = logging.getLogger('transformers.' + 'tokenization_utils_base').getEffectiveLevel()
logging.getLogger('transformers.' + 'tokenization_utils_base').setLevel(logging.ERROR)
tokens = self.model.get_tokenizer().tokenize(data)
tokens = self.model.get_tokenizer().tokenize(preprocessed_data)
logging.getLogger('transformers.' + 'tokenization_utils_base').setLevel(orig_log_level)

if self.model.get_model().config.max_position_embeddings == -1: # e.g. No max length restriction for XLNet
return data, None, tokens, None # Head text, tail text, head token, tail token
return (preprocessed_data, None, tokens, None), reserved_stopwords # (Head text, tail text, head token, tail token), reserved_stopwords

ids = self.model.get_tokenizer().convert_tokens_to_ids(tokens[:self.max_num_token])
head_text = self.model.get_tokenizer().decode(ids).strip()
@@ -179,7 +201,7 @@ def split_text(self, data):
ids = self.model.get_tokenizer().convert_tokens_to_ids(tokens[self.max_num_token:])
tail_text = self.model.get_tokenizer().decode(ids).strip()

return head_text, tail_text, tokens[:self.max_num_token], tokens[self.max_num_token:]
return (head_text, tail_text, tokens[:self.max_num_token], tokens[self.max_num_token:]), reserved_stopwords

def insert(self, data):
if not data:
@@ -194,7 +216,12 @@ def insert(self, data):
all_data = [data]

# If length of input is larger than max allowed input, only augment heading part
split_results = [self.split_text(d) for d in all_data] # head_text, tail_text, head_tokens, tail_tokens
split_results = [] # head_text, tail_text, head_tokens, tail_tokens
reserved_stopwords = []
for d in all_data:
split_result, reserved_stopword = self.split_text(d)
split_results.append(split_result)
reserved_stopwords.append(reserved_stopword)

# Pick target word for augmentation
for i, split_result in enumerate(split_results):
@@ -292,7 +319,7 @@ def insert(self, data):
split_results[aug_input_pos][6][j] = -1

augmented_texts = []
for split_result in split_results:
for split_result, reserved_stopword_tokens in zip(split_results, reserved_stopwords):
tail_text, head_doc = split_result[1], split_result[5]

head_tokens = head_doc.get_augmented_tokens()
@@ -302,8 +329,11 @@ def insert(self, data):

ids = self.model.get_tokenizer().convert_tokens_to_ids(head_tokens)
augmented_text = self.model.get_tokenizer().decode(ids)
if tail_text is not None:
if tail_text:
augmented_text += ' ' + tail_text
if reserved_stopwords:
augmented_text = self.replace_reserve_word_by_stopword(augmented_text, self.reserve_word_reg, reserved_stopword_tokens)

augmented_texts.append(augmented_text)

if isinstance(data, list):
@@ -324,7 +354,12 @@ def substitute(self, data):
all_data = [data]

# If length of input is larger than max allowed input, only augment heading part
split_results = [self.split_text(d) for d in all_data] # head_text, tail_text, head_tokens, tail_tokens
split_results = [] # head_text, tail_text, head_tokens, tail_tokens
reserved_stopwords = []
for d in all_data:
split_result, reserved_stopword = self.split_text(d)
split_results.append(split_result)
reserved_stopwords.append(reserved_stopword)

# Pick target word for augmentation
for i, split_result in enumerate(split_results):
@@ -444,7 +479,7 @@ def substitute(self, data):
split_results[aug_input_pos][6][j] = -1

augmented_texts = []
for split_result in split_results:
for split_result, reserved_stopword_tokens in zip(split_results, reserved_stopwords):
tail_text, head_doc = split_result[1], split_result[5]

head_tokens = head_doc.get_augmented_tokens()
@@ -456,6 +491,8 @@ def substitute(self, data):
augmented_text = self.model.get_tokenizer().decode(ids)
if tail_text is not None:
augmented_text += ' ' + tail_text
if reserved_stopwords:
augmented_text = self.replace_reserve_word_by_stopword(augmented_text, self.reserve_word_reg, reserved_stopword_tokens)
augmented_texts.append(augmented_text)

if isinstance(data, list):
50 changes: 49 additions & 1 deletion nlpaug/augmenter/word/word_augmenter.py
@@ -18,7 +18,7 @@ def __init__(self, action, name='Word_Aug', aug_min=1, aug_max=10, aug_p=0.3, st
self.tokenizer = tokenizer or Tokenizer.tokenizer
self.reverse_tokenizer = reverse_tokenizer or Tokenizer.reverse_tokenizer
self.stopwords = stopwords
self.stopwords_regex = re.compile(stopwords_regex) if stopwords_regex is not None else stopwords_regex
self.stopwords_regex = re.compile(stopwords_regex) if stopwords_regex else stopwords_regex

@classmethod
def clean(cls, data):
@@ -133,3 +133,51 @@ def get_word_case(cls, word):
if word[0].isupper():
return 'capitalize'
return 'unknown'

def replace_stopword_by_reserved_word(self, text, stopword_reg, reserve_word):
replaced_text = ''
reserved_stopwords = []

# pad space for easy handling
replaced_text = ' ' + text + ' '
for m in reversed(list(stopword_reg.finditer(replaced_text))):
# Get position excluding prefix and suffix
start, end, token = m.start(), m.end(), m.group()
# replace stopword by reserve word
replaced_text = replaced_text[:start] + reserve_word + replaced_text[end:]
reserved_stopwords.append(token) # reversed order, but it will be consumed in reversed order later too

# trim
replaced_text = replaced_text[1:-1]

return replaced_text, reserved_stopwords
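
As a concrete illustration (a sketch only, reusing the first case of the new test in test_word.py and an aug built as in the earlier usage sketch, i.e. a ContextualWordEmbsAug with stopwords=['[id]', '[year]']):

text = 'My id is [id], and I born in [year]'
replaced_text, reserved = aug.replace_stopword_by_reserved_word(text, aug.stopword_reg, '[UNK]')
# replaced_text == 'My id is [UNK], and I born in [UNK]'
# reserved == ['[year]', '[id]']   (collected right-to-left, and consumed in that same order later)
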

def replace_reserve_word_by_stopword(self, text, reserve_word_aug, original_stopwords):
# pad space for easy handling
replaced_text = ' ' + text + ' '
matched = list(reserve_word_aug.finditer(replaced_text))[::-1]

# TODO:?
if len(matched) != len(original_stopwords):
pass
if len(matched) > len(original_stopwords):
pass
if len(matched) < len(original_stopwords):
pass

for m, orig_stopword in zip(matched, original_stopwords):
# Get position excluding prefix and suffix
start, end = m.start(), m.end()
# replace reserve word by the original stopword
replaced_text = replaced_text[:start] + orig_stopword + replaced_text[end:]

# trim
replaced_text = replaced_text[1:-1]

return replaced_text
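
The reverse direction then completes the round trip once augmentation is done (continuing the sketch above; reserve_word_reg is the compiled pattern for the '[UNK]' marker):

restored = aug.replace_reserve_word_by_stopword(replaced_text, aug.reserve_word_reg, reserved)
# restored == 'My id is [id], and I born in [year]'
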

def preprocess(self, data):
...

def postprocess(self, data):
...
1 change: 0 additions & 1 deletion test/augmenter/word/test_context_word_embs.py
@@ -209,7 +209,6 @@ def substitute(self, aug, data):
self.assertTrue(aug.model.get_subword_prefix() not in augmented_text)

def substitute_stopwords(self, aug, data):
print('->substitute_stopwords')
original_stopwords = aug.stopwords
if isinstance(data, list):
aug.stopwords = [t.lower() for t in data[0].split(' ')[:3]]
55 changes: 55 additions & 0 deletions test/augmenter/word/test_word.py
@@ -214,6 +214,61 @@ def test_stopwords(self):
self.assertTrue(
'quick' not in augmented_text or 'over' not in augmented_text or 'lazy' not in augmented_text)

# https://github.com/makcedward/nlpaug/issues/247
def test_stopword_for_preprocess(self):
stopwords = ["[id]", "[year]"]
texts = [
"My id is [id], and I born in [year]", # with stopwords as last word
"[id] id is [id], and I born in [year]", # with stopwords as first word
"[id] [id] Id is [year] [id]", # continuous stopwords
"[id] [id] Id is [year] [id]", # continuous stopwords with space
"My id is [id], and I [id] born in [year] a[year] [year]b aa[year]", # with similar stopwords
"My id is [id], and I born [UNK] [year]", # already have reserved word. NOT handling now
]
expected_replaced_texts = [
'My id is [UNK], and I born in [UNK]',
'[UNK] id is [UNK], and I born in [UNK]',
'[UNK] [UNK] Id is [UNK] [UNK]',
'[UNK] [UNK] Id is [UNK] [UNK]',
'My id is [UNK], and I [UNK] born in [UNK] a[year] [year]b aa[year]',
"My id is [UNK], and I born [UNK] [UNK]",
]
expected_reserved_tokens = [
['[year]', '[id]'],
['[year]', '[id]', '[id]'],
['[id]', '[year]', '[id]', '[id]'],
['[id]', '[year]', '[id]', '[id]'],
['[year]', '[id]', '[id]'],
['[year]', '[id]']
]
expected_reversed_texts = [
'My id is [id], and I born in [year]',
'[id] id is [id], and I born in [year]',
'[id] [id] Id is [year] [id]',
'[id] [id] Id is [year] [id]',
'My id is [id], and I [id] born in [year] a[year] [year]b aa[year]',
'My id is [UNK], and I born [id] [year]'
]

augs = [
naw.ContextualWordEmbsAug(
model_path='bert-base-uncased', action="insert", stopwords=stopwords),
naw.ContextualWordEmbsAug(
model_path='bert-base-uncased', action="substitute", stopwords=stopwords)
]

for aug in augs:
for expected_text, expected_reserved_token_list, expected_reversed_text, text in zip(
expected_replaced_texts, expected_reserved_tokens, expected_reversed_texts, texts):
replaced_text, reserved_stopwords = aug.replace_stopword_by_reserved_word(
text, aug.stopword_reg, '[UNK]')
assert expected_text == replaced_text
assert expected_reserved_token_list == reserved_stopwords

reversed_text = aug.replace_reserve_word_by_stopword(
replaced_text, aug.reserve_word_reg, reserved_stopwords)
assert expected_reversed_text == reversed_text

# https://github.com/makcedward/nlpaug/issues/81
def test_stopwords_regex(self):
text = 'The quick brown fox jumps over the lazy dog.'