
Random Deletion Layer - Data Augmentation #152

Closed
aflah02 opened this issue Apr 29, 2022 · 11 comments · Fixed by #214
aflah02 (Collaborator) commented Apr 29, 2022

I've created this issue to discuss the Random Deletion Layer specifically while we figure out how to incorporate WordNet for Synonym Replacement.

I've adapted the design mentioned by @mattdangerw here for this layer:

class RandomDeletion(keras.layers.Layer):
    """Augments input by randomly deleting words.

    Args:
        probability: Probability of a word being chosen for deletion.
        max_deletions: The maximum number of words to delete.
        stop_word_only: If True, only stop words are candidates for deletion.

    Examples:

    Basic usage.
    >>> augmenter = keras_nlp.layers.RandomDeletion(
    ...     probability=0.3,
    ... )
    >>> augmenter(["dog dog dog dog dog"])
    <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'dog dog dog dog'], dtype=object)>
    """
    pass
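As an illustration of the intended behaviour (not the actual implementation), the word-level deletion semantics can be sketched in plain Python. The function name, the `max_deletions` cap, and the `rng` argument here are my own assumptions; the real layer would operate on tensors with TensorFlow string ops:

```python
import random


def random_word_deletion(text, probability, max_deletions=None, rng=None):
    """Delete each whitespace-separated word with the given probability.

    Plain-Python sketch of the proposed layer's semantics; the actual
    layer would use TensorFlow string ops so it can run inside a graph.
    """
    rng = rng or random.Random()
    kept, deleted = [], 0
    for word in text.split():
        under_cap = max_deletions is None or deleted < max_deletions
        if under_cap and rng.random() < probability:
            deleted += 1  # drop this word
            continue
        kept.append(word)
    return " ".join(kept)
```

With `probability=0.0` the input passes through unchanged, and `max_deletions` bounds how many words can be dropped in a single call.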
mattdangerw (Member) commented

Thanks, this looks good to me!

A few things we should consider...

  • We may want to offer both a character-level and a word-level version of this. Randomly deleted characters could be good for simulating typos.
  • For the word-level splitting, we may want to make split points configurable (we could support regex splitting like WordPieceTokenizer does, for example), e.g. split_pattern: a regex pattern to match delimiters to split on.
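To illustrate the configurable-split idea in the second bullet, here is a plain-Python sketch using the standard `re` module. The `split_pattern` name follows the suggestion above; this is only a sketch, since a graph-compatible layer would need a TF string op (e.g. tensorflow_text's regex_split) rather than Python's `re`:

```python
import re


def split_words(text, split_pattern=r"\s+"):
    """Split text into words on a configurable delimiter regex.

    Sketch only: the re module cannot run inside a TF graph, so the
    layer itself would need an equivalent graph-compatible op.
    """
    # Drop empty strings produced by leading/trailing delimiters.
    return [word for word in re.split(split_pattern, text) if word]
```

For example, `split_words("foo,bar, baz", split_pattern=r"[,\s]+")` splits on both commas and whitespace.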

aflah02 (Collaborator, Author) commented Apr 29, 2022

@mattdangerw Thanks for the review!

On the first point, yeah, we could do that, but I feel it might be better to have a separate layer which focuses on character-level deletions. We could also have it here as a parameter if that seems to match the design better.

On the second point, yup, I was thinking the same; I'll add it in as an option.

Just to confirm: this should work similarly to how the tokenizers work, right? The input could be anything from a scalar to a batch of tensors?

mattdangerw (Member) commented

I think so, yeah; the input could be a scalar dense tensor or a batched dense tensor.

We could also consider supporting pre-split ragged inputs (we do this in WordPiece, for example), in case even a configurable split regex is not enough for your splitting needs. Probably not something for a first version though.

aflah02 (Collaborator, Author) commented Apr 29, 2022

Thanks, I'll keep this in the back of my mind and maybe take it up as a next step once an initial layer is ready!

aflah02 (Collaborator, Author) commented May 20, 2022

Hey @mattdangerw
I've been thinking about this issue for some days now and I'm not quite sure of the best way to approach it. Should I go for a recursive approach, where reaching the level that holds strings is my base case? Or should I create a function and map it over the tensor in call? Or am I missing something, and there's a better way?

aflah02 (Collaborator, Author) commented May 26, 2022

Hey @mattdangerw
I've been working on this and I have an implementation here.
I'm currently stuck on handling non-scalar inputs: after tf.strings.split(input_) the input becomes a ragged tensor, which throws everything off, because the inner list is then treated as the sole element at index 0. It would be great if you could check this out and share any suggestions for a fix. I can't think of one myself, because even if I handle the 2D case specially, for inputs like augmenter([[["dog dog dog dog dog"], ["the cat hates to play in the rain"]]]) the issue keeps growing with nesting depth.
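The nesting problem can be reproduced in plain Python terms: splitting every string in a batch yields a ragged, nested structure, so row 0 becomes a list of words rather than a string (tf.strings.split analogously returns a tf.RaggedTensor whose rows have different lengths):

```python
batch = ["dog dog dog dog dog", "the cat hates to play in the rain"]

# Splitting every row produces a ragged structure: each element is now
# a list of words, and the rows have different lengths.
split = [text.split() for text in batch]

print(split[0])                      # a list of five words, not a string
print(len(split[0]), len(split[1]))  # 5 8 -- ragged rows
```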

mattdangerw (Member) commented

@aflah02 yeah, we definitely need to support the batched 2D case at a minimum. Potentially this could be done with RaggedTensor and no map function? Something like...

import tensorflow as tf

inputs = tf.constant(["this is a test", "this is another test"])
# Split each string into words; the result is a tf.RaggedTensor.
ragged_words = tf.strings.split(inputs)
# Draw one uniform sample per word and keep the words above the threshold.
mask = tf.random.uniform(ragged_words.flat_values.shape) > 0.25
mask = ragged_words.with_flat_values(mask)
deleted = tf.ragged.boolean_mask(ragged_words, mask)
# Join the surviving words of each row back into a single string.
deleted = tf.strings.reduce_join(deleted, axis=-1, separator=" ")

Would that work?
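For readers following along without TensorFlow at hand, the same mask-and-rejoin pipeline can be mirrored in plain Python (the function name and `rng` argument here are illustrative, not part of the proposed API):

```python
import random


def delete_words_batch(batch, probability, rng=None):
    """Batched word deletion via a per-word boolean keep-mask.

    Mirrors the ragged-tensor snippet above: split each string into
    words, draw one uniform sample per word, keep the words whose
    sample exceeds `probability`, then rejoin with spaces.
    """
    rng = rng or random.Random()
    out = []
    for text in batch:
        words = text.split()
        keep = [rng.random() > probability for _ in words]
        out.append(" ".join(w for w, k in zip(words, keep) if k))
    return out
```

The ragged-tensor version does the same thing in one shot over the whole batch, which is what makes it traceable inside a TF graph.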

aflah02 (Collaborator, Author) commented Jun 1, 2022

@mattdangerw
Thanks for the reviews!
In regards to this comment, for random character deletions I feel separate layers make more sense, since at the end of the day the layers can be stacked by the user. Having them together would essentially mean two separate pieces of logic operating in the same file, and keeping them separate would also make the code more maintainable. But I guess that's just my preference. What do you think?

mattdangerw (Member) commented

@aflah02 that makes sense to me. In that case, let's call this RandomWordDeletion; it's more specific to what we are doing here, and leaves us room to grow our offering.

mattdangerw (Member) commented

@aflah02 here's a traceable set of ragged tensor ops based on the TF Text RandomItemSelector:

https://colab.sandbox.google.com/gist/mattdangerw/efdf506bd1719192e93be371d6eb68c6/tracable-delete-fn.ipynb

aflah02 (Collaborator, Author) commented Jun 7, 2022

@mattdangerw Thanks a ton! This looks great and will help me a lot. I'll make sure to go through these ragged tensor ops for future use as well!
