-
Notifications
You must be signed in to change notification settings - Fork 251
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Random Deletion Working * Added to init * WOrking * Working * Current Status * Working Layer More Tests to be Added * Fixed Scalar Case * Added Comments * Minor Fixes * Major Refactors and Fixes, ToDo - Docs, Tests * Fixed Shape Issues for Scalar Lists * Finalized Tests and DocString * Ran Stylers Added More Descriptive DocString * Fixed Failing Docstring Tests * Removed Map Call and Unsupported Test * Shape Fixes * Working * Working * Changing Parent Class * Changes * Formatter Ran * Finalized * Addresed Review Comments * Fornatter * Added new Tests * Fan Formatter * Skip Works * New Randomness * All Testing Done * Review Changes * Addressed all Review Comments * Copy edits for docstrings * Finishes * Changed Tokenizer Import * Addressed Reviews * Fix typo Co-authored-by: Matt Watson <[email protected]>
- Loading branch information
1 parent
7c338db
commit 58f4bbb
Showing
3 changed files
with
493 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,266 @@ | ||
# Copyright 2022 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import random | ||
|
||
import tensorflow as tf | ||
from tensorflow import keras | ||
from tensorflow.python.ops.ragged import ragged_array_ops | ||
|
||
|
||
class RandomDeletion(keras.layers.Layer): | ||
"""Augments input by randomly deleting tokens. | ||
This layer comes in handy when you need to generate new data using deletion | ||
augmentation as described in the paper [EDA: Easy Data Augmentation | ||
Techniques for Boosting Performance on Text Classification Tasks] | ||
(https://arxiv.org/pdf/1901.11196.pdf). The layer expects the inputs to be | ||
pretokenized so that each token can be individually treated as a possible | ||
candidate for deletion. | ||
Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and | ||
either rank-1 or rank-2. | ||
Args: | ||
rate: The probability of a token being chosen for deletion. | ||
max_deletions: The maximum number of tokens to delete. | ||
skip_list: A list of token values that should not be considered | ||
candidates for deletion. | ||
skip_fn: A function that takes as input a scalar tensor token and | ||
returns as output a scalar tensor True/False value. A value of | ||
True indicates that the token should not be considered a | ||
candidate for deletion. This function must be tracable--it | ||
should consist of tensorflow operations. | ||
skip_py_fn: A function that takes as input a python token value and | ||
returns as output `True` or `False`. A value of True | ||
indicates that should not be considered a candidate for deletion. | ||
Unlike the `skip_fn` argument, this argument need not be | ||
tracable--it can be any python function. | ||
seed: A seed for the rng. | ||
Examples: | ||
Word level usage. | ||
>>> keras.utils.set_random_seed(1337) | ||
>>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"]) | ||
>>> augmenter=keras_nlp.layers.RandomDeletion(rate=0.4, seed=42) | ||
>>> augmented=augmenter(inputs) | ||
>>> tf.strings.reduce_join(augmented, separator=" ", axis=-1) | ||
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'], | ||
dtype=object)> | ||
Character level usage. | ||
>>> keras.utils.set_random_seed(1337) | ||
>>> inputs=tf.strings.unicode_split(["Hey Dude", "Speed Up"], "UTF-8") | ||
>>> augmenter=keras_nlp.layers.RandomDeletion(rate=0.4, seed=42) | ||
>>> augmented=augmenter(inputs) | ||
>>> tf.strings.reduce_join(augmented, axis=-1) | ||
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'H Dude', b'pedUp'], | ||
dtype=object)> | ||
Usage with skip_list. | ||
>>> keras.utils.set_random_seed(1337) | ||
>>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"]) | ||
>>> augmenter=keras_nlp.layers.RandomDeletion(rate=0.4, | ||
... skip_list=["Keras", "Tensorflow"], seed=42) | ||
>>> augmented=augmenter(inputs) | ||
>>> tf.strings.reduce_join(augmented, separator=" ", axis=-1) | ||
<tf.Tensor: shape=(2,), dtype=string, | ||
numpy=array([b'I like', b'Keras Tensorflow'], dtype=object)> | ||
Usage with skip_fn. | ||
>>> def skip_fn(word): | ||
... return tf.strings.regex_full_match(word, r"\\pP") | ||
>>> keras.utils.set_random_seed(1337) | ||
>>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"]) | ||
>>> augmenter=keras_nlp.layers.RandomDeletion(rate=0.4, | ||
... skip_fn=skip_fn, seed=42) | ||
>>> augmented=augmenter(inputs) | ||
>>> tf.strings.reduce_join(augmented, separator=" ", axis=-1) | ||
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'], | ||
dtype=object)> | ||
Usage with skip_py_fn. | ||
>>> def skip_py_fn(word): | ||
... return len(word) < 4 | ||
>>> keras.utils.set_random_seed(1337) | ||
>>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"]) | ||
>>> augmenter=RandomDeletion(rate=0.4, | ||
... skip_py_fn=skip_py_fn, seed=42) | ||
>>> augmented=augmenter(inputs) | ||
>>> tf.strings.reduce_join(augmented, separator=" ", axis=-1) | ||
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'Hey I', b'and'], | ||
dtype=object)> | ||
""" | ||
|
||
def __init__( | ||
self, | ||
rate, | ||
max_deletions=None, | ||
skip_list=None, | ||
skip_fn=None, | ||
skip_py_fn=None, | ||
seed=None, | ||
name=None, | ||
**kwargs, | ||
): | ||
# Check dtype and provide a default. | ||
if "dtype" not in kwargs or kwargs["dtype"] is None: | ||
kwargs["dtype"] = tf.int32 | ||
else: | ||
dtype = tf.dtypes.as_dtype(kwargs["dtype"]) | ||
if not dtype.is_integer and dtype != tf.string: | ||
raise ValueError( | ||
"Output dtype must be one of `'string'`, `'int32'`, and " | ||
f"`'int64'`. Received: dtype={dtype}" | ||
) | ||
|
||
super().__init__(name=name, **kwargs) | ||
self.rate = rate | ||
self.max_deletions = max_deletions | ||
self.seed = random.randint(1, 1e9) if seed is None else seed | ||
self._generator = tf.random.Generator.from_seed(self.seed) | ||
self.skip_list = skip_list | ||
self.skip_fn = skip_fn | ||
self.skip_py_fn = skip_py_fn | ||
|
||
if self.rate > 1 or self.rate < 0: | ||
raise ValueError( | ||
"Rate must be between 0 and 1 (both inclusive)." | ||
f"Received: rate={rate}" | ||
) | ||
|
||
if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2: | ||
raise ValueError( | ||
"Exactly one of `skip_list`, `skip_fn`, `skip_py_fn` must be " | ||
"provided." | ||
) | ||
|
||
if self.skip_list: | ||
self.StaticHashTable = tf.lookup.StaticHashTable( | ||
tf.lookup.KeyValueTensorInitializer( | ||
tf.convert_to_tensor(self.skip_list), | ||
tf.convert_to_tensor([True] * len(self.skip_list)), | ||
), | ||
default_value=False, | ||
) | ||
|
||
def call(self, inputs): | ||
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): | ||
inputs = tf.convert_to_tensor(inputs) | ||
|
||
input_is_1d = False | ||
if inputs.shape.rank < 1 or inputs.shape.rank > 2: | ||
raise ValueError( | ||
"Input must either be rank 1 or rank 2. Received input with " | ||
f"rank={inputs.shape.rank}" | ||
) | ||
elif inputs.shape.rank == 1: | ||
input_is_1d = True | ||
# Add a new axis at the beginning. | ||
inputs = tf.expand_dims(inputs, axis=0) | ||
if isinstance(inputs, tf.Tensor): | ||
# Convert to ragged tensor. | ||
inputs = tf.RaggedTensor.from_tensor(inputs) | ||
|
||
skip_masks = None | ||
if self.skip_list: | ||
skip_masks = self.StaticHashTable.lookup(inputs.flat_values) | ||
elif self.skip_fn: | ||
skip_masks = tf.map_fn( | ||
self.skip_fn, inputs.flat_values, fn_output_signature=tf.bool | ||
) | ||
elif self.skip_py_fn: | ||
|
||
def string_fn(token): | ||
return self.skip_py_fn(token.numpy().decode("utf-8")) | ||
|
||
def int_fn(token): | ||
return self.skip_py_fn(token.numpy()) | ||
|
||
py_fn = string_fn if inputs.dtype == tf.string else int_fn | ||
|
||
skip_masks = tf.map_fn( | ||
lambda x: tf.py_function(py_fn, [x], tf.bool), | ||
inputs.flat_values, | ||
fn_output_signature=tf.bool, | ||
) | ||
|
||
positions_flat = tf.range(tf.size(inputs.flat_values)) | ||
positions = inputs.with_flat_values(positions_flat) | ||
if skip_masks is not None: | ||
skip_masks = tf.logical_not(skip_masks) | ||
skip_masks.set_shape([None]) | ||
positions = ragged_array_ops.boolean_mask( | ||
positions, inputs.with_flat_values(skip_masks) | ||
) | ||
|
||
# Figure out how many we are going to select. | ||
token_counts = tf.cast(inputs.row_lengths(), "float32") | ||
num_to_select = tf.random.stateless_binomial( | ||
shape=tf.shape(token_counts), | ||
seed=self._generator.make_seeds()[:, 0], | ||
counts=token_counts, | ||
probs=self.rate, | ||
) | ||
if self.max_deletions is not None: | ||
num_to_select = tf.math.minimum(num_to_select, self.max_deletions) | ||
num_to_select = tf.cast(num_to_select, "int64") | ||
|
||
# Shuffle and trim to items that are going to be selected. | ||
def _shuffle_and_trim(x): | ||
positions, top_n = x | ||
shuffled = tf.random.shuffle(positions, seed=self.seed) | ||
return shuffled[:top_n] | ||
|
||
selected_for_mask = tf.map_fn( | ||
_shuffle_and_trim, | ||
(positions, num_to_select), | ||
fn_output_signature=tf.RaggedTensorSpec( | ||
ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype | ||
), | ||
) | ||
selected_for_mask.flat_values.set_shape([None]) | ||
|
||
# Construct the mask which is a boolean RT | ||
# Scatter 0's to positions that have been selector for deletion. | ||
update_values = tf.zeros_like(selected_for_mask.flat_values, "int32") | ||
update_indices = selected_for_mask.flat_values | ||
update_indices = tf.expand_dims(update_indices, -1) | ||
update_indices = tf.cast(update_indices, "int32") | ||
mask_flat = tf.ones_like(inputs.flat_values, dtype="int32") | ||
mask_flat = tf.tensor_scatter_nd_update( | ||
mask_flat, update_indices, update_values | ||
) | ||
mask = tf.cast(inputs.with_flat_values(mask_flat), "bool") | ||
|
||
inputs = tf.ragged.boolean_mask(inputs, mask) | ||
|
||
if input_is_1d: | ||
inputs = tf.squeeze(inputs, axis=0) | ||
|
||
return inputs | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"rate": self.rate, | ||
"max_deletions": self.max_deletions, | ||
"seed": self.seed, | ||
"skip_list": self.skip_list, | ||
"skip_fn": self.skip_fn, | ||
"skip_py_fn": self.skip_py_fn, | ||
} | ||
) | ||
return config |
Oops, something went wrong.