Added support for noisy dense layers. #2099

Merged
merged 45 commits into from
Sep 15, 2020

Conversation

LeonShams
Contributor

@LeonShams LeonShams commented Aug 18, 2020

Description

This PR adds support for noisy dense layers. A noisy dense layer is a dense layer whose weights and biases are perturbed by random noise, which helps agents explore in reinforcement learning environments. The technique is commonly used in Deep Q-Learning and was introduced by DeepMind in the 2017 paper Noisy Networks for Exploration.
Fixes: #2127
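
To make the description above concrete, here is a minimal, framework-free sketch of the factorised Gaussian noise variant described in the Noisy Networks for Exploration paper. It is an illustration only: the function names (scale_noise, noisy_dense_forward) and the nested-list layout are assumptions made for this example, and the thread does not state which noise variant the PR implements.

```python
import math
import random


def scale_noise(x: float) -> float:
    """f(x) = sign(x) * sqrt(|x|), the noise-scaling function from the paper."""
    return math.copysign(math.sqrt(abs(x)), x)


def noisy_dense_forward(inputs, mu_w, sigma_w, mu_b, sigma_b, rng):
    """Compute y_j = sum_i (mu_w[i][j] + sigma_w[i][j] * eps_w[i][j]) * x_i
                     + mu_b[j] + sigma_b[j] * eps_b[j].

    mu_w and sigma_w are [n_in][n_out] nested lists (hypothetical layout);
    mu_b and sigma_b are lists of length n_out.
    """
    n_in, n_out = len(mu_w), len(mu_b)
    # Factorised noise: one Gaussian sample per input unit and one per output unit.
    eps_in = [scale_noise(rng.gauss(0.0, 1.0)) for _ in range(n_in)]
    eps_out = [scale_noise(rng.gauss(0.0, 1.0)) for _ in range(n_out)]
    outputs = []
    for j in range(n_out):
        # Bias noise reuses the per-output sample.
        total = mu_b[j] + sigma_b[j] * eps_out[j]
        for i in range(n_in):
            # Weight noise is the product of the input and output samples.
            weight = mu_w[i][j] + sigma_w[i][j] * eps_in[i] * eps_out[j]
            total += weight * inputs[i]
        outputs.append(total)
    return outputs
```

With every sigma set to zero the function reduces to an ordinary dense layer, which is a convenient sanity check when comparing against a noise-free baseline.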

Type of change

Checklist:

  • I've properly formatted my code according to the guidelines
    • By running Black + Flake8
    • By running pre-commit hooks
  • This PR addresses an already submitted issue for TensorFlow Addons
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • This PR contains modifications to C++ custom-ops

How Has This Been Tested?

I tested the code with pytest, trained it on several reinforcement learning environments from the gym toolkit, and compared my results with those reported in the Noisy Networks for Exploration paper.

@LeonShams
Contributor Author

@facaiy @seanpmorgan please let me know if there are any problems or if you have any questions. Thanks!

@boring-cyborg boring-cyborg bot added the github label Aug 27, 2020
@LeonShams LeonShams requested a review from WindQAQ September 11, 2020 22:22
Member

@WindQAQ WindQAQ left a comment

I just figured out that a lot of the code is copied from tf.keras.layers.Dense. Therefore, a better design pattern is to subclass it:

class NoisyDense(tf.keras.layers.Dense):
    @typechecked
    def __init__(
        self,
        units: int,
        sigma: float = 0.5,
        activation: types.Activation = None,
        use_bias: bool = True,
        kernel_regularizer: types.Regularizer = None,
        bias_regularizer: types.Regularizer = None,
        activity_regularizer: types.Regularizer = None,
        kernel_constraint: types.Constraint = None,
        bias_constraint: types.Constraint = None,
        **kwargs
    ):
        super().__init__(xxxx)
        # DO OTHER THINGS

    def build(self, input_shape):
        # OVERRIDE

    def reset_noise(self):
        # DO SOMETHING

    def remove_noise(self):
        # DO SOMETHING

    def call(self, inputs, reset_noise=True):
        if reset_noise:
            self.reset_noise()
        # TODO(WindQAQ): Replace this with `dense()` once public.
        # prepare your self.kernel, self.bias here.
        return super().call(inputs)

    def get_config(self):
        # OVERRIDE
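
The skeleton above can be sketched without TensorFlow to show the control flow it proposes: resample (or skip resampling) the noise, prepare self.kernel and self.bias, then delegate to the parent's call. Everything here is an assumption made for illustration: the Dense class is a hypothetical single-unit stand-in for tf.keras.layers.Dense, sigma is a plain scalar rather than a trainable weight, and the noise is independent Gaussian.

```python
import random


class Dense:
    """Hypothetical stand-in for tf.keras.layers.Dense (one output unit)."""

    def __init__(self, kernel, bias):
        self.kernel = list(kernel)
        self.bias = bias

    def call(self, inputs):
        return sum(w * x for w, x in zip(self.kernel, inputs)) + self.bias


class NoisyDense(Dense):
    def __init__(self, kernel, bias, sigma=0.5, seed=None):
        super().__init__(kernel, bias)
        # Keep the noise-free parameters separate from the perturbed ones.
        self.mu_kernel = list(kernel)
        self.mu_bias = bias
        self.sigma = sigma
        self._rng = random.Random(seed)
        self.remove_noise()

    def reset_noise(self):
        """Draw a fresh noise sample for every parameter."""
        self.eps_kernel = [self._rng.gauss(0.0, 1.0) for _ in self.mu_kernel]
        self.eps_bias = self._rng.gauss(0.0, 1.0)

    def remove_noise(self):
        """Zero the noise so the layer behaves like a plain dense layer."""
        self.eps_kernel = [0.0] * len(self.mu_kernel)
        self.eps_bias = 0.0

    def call(self, inputs, reset_noise=True):
        if reset_noise:
            self.reset_noise()
        # Prepare self.kernel and self.bias, then reuse the parent's forward pass.
        self.kernel = [
            m + self.sigma * e for m, e in zip(self.mu_kernel, self.eps_kernel)
        ]
        self.bias = self.mu_bias + self.sigma * self.eps_bias
        return super().call(inputs)
```

Calling layer.call(x, reset_noise=False) right after layer.remove_noise() gives the deterministic output, which is the behaviour one would typically want at evaluation time.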

@LeonShams
Contributor Author

Regarding the super().__init__(xxxx), what would you like me to do about self.kernel_initializer and self.bias_initializer since they are not used in NoisyDense?

@WindQAQ
Member

WindQAQ commented Sep 11, 2020

Regarding the super().__init__(xxxx), what would you like me to do about self.kernel_initializer and self.bias_initializer since they are not used in NoisyDense?

Just pass them as None and remove them with delattr afterwards.

@LeonShams
Contributor Author

LeonShams commented Sep 12, 2020

Also, what would you like me to do about super(NoisyDense, self).get_config() since it will now get the config of Dense?

@WindQAQ
Member

WindQAQ commented Sep 12, 2020

Also, what would you like me to do about super(NoisyDense, self).get_config() since it will now get the config of Dense?

config = super().get_config()
config.pop("foo") # pop things you do not have
config["bar"] = bar # update things you have additionally
return config

@LeonShams
Contributor Author

LeonShams commented Sep 12, 2020

That will raise an error since kernel_initializer and bias_initializer were deleted.

Dense get_config():

config.update({
    ...
    'kernel_initializer':
        initializers.serialize(self.kernel_initializer),
    'bias_initializer':
        initializers.serialize(self.bias_initializer),
    ...
})
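
The failure described here can be reproduced with a small framework-free sketch. The Layer and Dense classes below are hypothetical stand-ins for the Keras classes, reduced to the one attribute that matters; the point is only that a parent get_config which reads an attribute will raise AttributeError once the subclass has removed that attribute with delattr.

```python
class Layer:
    """Hypothetical stand-in for tf.keras.layers.Layer."""

    def __init__(self, name=None):
        self.name = name

    def get_config(self):
        return {"name": self.name}


class Dense(Layer):
    """Hypothetical stand-in for tf.keras.layers.Dense."""

    def __init__(self, units, kernel_initializer=None, name=None):
        super().__init__(name=name)
        self.units = units
        self.kernel_initializer = kernel_initializer

    def get_config(self):
        config = super().get_config()
        config.update(
            {"units": self.units, "kernel_initializer": self.kernel_initializer}
        )
        return config


class NoisyDense(Dense):
    def __init__(self, units, sigma=0.5, name=None):
        super().__init__(units, kernel_initializer=None, name=name)
        # The initializer is unused here, so it is removed after construction.
        delattr(self, "kernel_initializer")
        self.sigma = sigma

    def get_config(self):
        # Dense.get_config reads self.kernel_initializer, which no longer exists.
        config = super().get_config()
        config["sigma"] = self.sigma
        return config


def try_get_config(layer):
    """Return the config, or a short message if serialization fails."""
    try:
        return layer.get_config()
    except AttributeError as err:
        return f"AttributeError: {err}"
```

In this sketch, try_get_config(NoisyDense(4)) returns the error message rather than a dictionary, while try_get_config(Dense(4)) still succeeds.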

@WindQAQ
Member

WindQAQ commented Sep 12, 2020

That will raise an error since kernel_initializer and bias_initializer were deleted.

Dense get_config():

config.update({
    ...
    'kernel_initializer':
        initializers.serialize(self.kernel_initializer),
    'bias_initializer':
        initializers.serialize(self.bias_initializer),
    ...
})

Try this

config = super(tf.keras.layers.Dense, self).get_config()
config["foo"] = foo # update your stuff

or, as an even hackier option:

self.bar = None
config = super().get_config()
config.pop("bar")
config["foo"] = foo

Neither approach looks great to me, so please add a TODO(WindQAQ): Get rid of this hacking way. comment and assign it to me 😃

https://colab.research.google.com/drive/14VzUnIZIDadej8s_fHlSKx-ieTs9rttV?usp=sharing
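
Both workarounds suggested in this comment can be tried on toy classes. This is a sketch under stated assumptions: Layer, Dense, and the shared _NoisyBase helper are hypothetical stand-ins, not the Keras or TensorFlow Addons classes. The first subclass uses super(Dense, self) so that attribute lookup starts after Dense in the method resolution order and lands on Layer.get_config; the second temporarily restores the deleted attribute and pops it from the result.

```python
class Layer:
    """Hypothetical stand-in for tf.keras.layers.Layer."""

    def __init__(self, name=None):
        self.name = name

    def get_config(self):
        return {"name": self.name}


class Dense(Layer):
    """Hypothetical stand-in for tf.keras.layers.Dense."""

    def __init__(self, units, kernel_initializer=None, name=None):
        super().__init__(name=name)
        self.units = units
        self.kernel_initializer = kernel_initializer

    def get_config(self):
        config = super().get_config()
        config.update(
            {"units": self.units, "kernel_initializer": self.kernel_initializer}
        )
        return config


class _NoisyBase(Dense):
    """Shared constructor for the two variants below (illustrative helper)."""

    def __init__(self, units, sigma=0.5, name=None):
        super().__init__(units, kernel_initializer=None, name=name)
        delattr(self, "kernel_initializer")
        self.sigma = sigma


class NoisyDenseSkipParent(_NoisyBase):
    def get_config(self):
        # Skip Dense.get_config entirely and start from Layer.get_config.
        config = super(Dense, self).get_config()
        # Every Dense field that is still wanted must be re-added by hand.
        config.update({"units": self.units, "sigma": self.sigma})
        return config


class NoisyDenseRestoreAttr(_NoisyBase):
    def get_config(self):
        # Put the deleted attribute back so Dense.get_config can read it.
        # Note that it stays set on the instance after this call.
        self.kernel_initializer = None
        config = super().get_config()
        config.pop("kernel_initializer")
        config["sigma"] = self.sigma
        return config
```

The trade-off is visible in the code: skipping the parent means the subclass has to list the parent's fields itself, while restoring the attribute keeps the parent's fields automatically at the cost of mutating the instance during serialization.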

@LeonShams
Contributor Author

The first method works well!

config = super(tf.keras.layers.Dense, self).get_config()
config["foo"] = foo # update your stuff

But is there a reason why you used config["foo"] instead of config.update({})?

@WindQAQ
Member

WindQAQ commented Sep 12, 2020

The first method works well!

config = super(tf.keras.layers.Dense, self).get_config()
config["foo"] = foo # update your stuff

But is there a reason why you used config["foo"] instead of config.update({})?

Basically no. I just typed it in a rush. config.update(xxx) LGTM.

@LeonShams LeonShams requested review from WindQAQ and bhack September 12, 2020 08:19
Member

@WindQAQ WindQAQ left a comment

Some documentation improvements are needed. Thanks.

@WindQAQ WindQAQ self-requested a review September 13, 2020 02:12
Member

@WindQAQ WindQAQ left a comment

@LeonShams Sorry, this should be the last comment. Everything else LGTM! Thank you.

@LeonShams
Contributor Author

No worries. I made the changes; let me know if there is anything else.

@LeonShams LeonShams requested a review from WindQAQ September 13, 2020 02:40
@WindQAQ WindQAQ merged commit 1c3c072 into tensorflow:master Sep 15, 2020
jrruijli pushed a commit to jrruijli/addons that referenced this pull request Dec 23, 2020
* Create noisy_dense.py

* Create noisy_dense_test.py

* Update __init__.py

* Fix minor typo

* Update noisy_dense_test.py

* Update comments

* Update comments

* Update noisy_dense.py

* fix typo

* Update noisy_dense.py

* Update noisy_dense_test.py

* Fix compliance issues

* Fix compliance issues

* Update comments

* Fix typo

* Update CODEOWNERS

* Update CODEOWNERS

* add use bias to config

* Update noisy_dense.py

* Update CODEOWNERS

* Revert "Update CODEOWNERS"

This reverts commit 82e979f.

* Update noisy_dense.py

* Update noisy_dense.py

* Update noisy_dense.py

* Update noisy_dense.py

* Revert "Update CODEOWNERS"

This reverts commit 840ab1c.

* Revert "Revert "Update CODEOWNERS""

This reverts commit 7852e62.

* Update noisy_dense.py

* Code reformatted with updated black

* Update noisy_dense.py

* Update noisy_dense.py

* Update noisy_dense.py

* Added support for manual noise reset

* support for noise removal

* tests for noise removal

* use typecheck and remove unicode,

* fix typo and code cleanup

* control noise removal through call

* Inherit from Dense instead of Layer

* Added missing comment

* Documentation and test improvement

* fix typo

* minor formatting changes

* minor formatting fix

Co-authored-by: schaall <[email protected]>
Development

Successfully merging this pull request may close these issues.

Noisy Networks for Exploration
4 participants