
Implementation of Normalizations #14

Merged: 30 commits into tensorflow:master, Mar 9, 2019

Conversation

@Smokrow (Contributor) commented Jan 15, 2019

First draft of a file for normalizations, implementing #6.
I have implemented GroupNorm, InstanceNorm, and LayerNorm, plus a first test case for GroupNorm (I will add a few more, and will also implement some for the Layer/Instance norms).

Could you give me a quick feedback on the implementation or if something is missing?

Thanks in advance
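
For context, here is a minimal sketch of how the proposed layers might be used once merged. The import path and class name follow this PR's file layout but are assumptions, not a final API:

    import tensorflow as tf

    # Assumed import path, following the file added in this PR; the final
    # module location and class name may differ after review.
    from tensorflow_addons.layers.python.normalizations import GroupNormalization

    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(6, 3, input_shape=(20, 20, 3)),
        # Normalize the 6 output channels in 3 groups of 2 channels each.
        GroupNormalization(groups=3, axis=-1),
    ])
    model.compile(optimizer="sgd", loss="mse")
    model.summary()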

@seanpmorgan self-requested a review January 16, 2019 01:26
@seanpmorgan (Member) left a comment:

@Smokrow Tyvm for the PR. In general it looks good; a few changes requested.

Regarding test cases, as you mentioned there are some more things we'll want to check, including some of the tests from: https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/layers/python/layers/normalization_test.py

[Resolved review threads on tensorflow_addons/layers/python/normalizations.py]
@guillaumekln (Contributor) commented:
Looks like a layer normalization layer has been added as tf.keras.layers.experimental.LayerNormalization, so it should probably be removed from this PR.

@seanpmorgan (Member) commented:
@karmel as discussed in our monthly meeting... LayerNorm was added to keras experimental, which kind of stomps on the implementation we were going to add. Could you please update us on the keras experimental roadmap?

Additionally, Addons was going to add GroupNorm as a generalized normalization case (see the figure below). Any thoughts on how we should proceed?

[Figure: comparison of Batch Norm, Layer Norm, Instance Norm, and Group Norm from the Group Normalization paper]
https://arxiv.org/pdf/1803.08494.pdf
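
The figure's point in code: with a groups parameter, a single GroupNormalization layer can recover the other norms as special cases (class name assumed from this PR; the equivalences are as described in the paper):

    # Assumed class from this PR, used only to illustrate the figure above.
    from tensorflow_addons.layers.python.normalizations import GroupNormalization

    num_channels = 6  # example channel count of the incoming tensor

    # One group spanning all channels: equivalent to Layer Norm.
    layer_norm_like = GroupNormalization(groups=1, axis=-1)

    # One channel per group: equivalent to Instance Norm.
    instance_norm_like = GroupNormalization(groups=num_channels, axis=-1)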

@karmel (Contributor) commented Feb 5, 2019

Notes from discussions with the Keras team:

In general, this makes more sense in Addons than in experimental. It was added to experimental as a stopgap for some migration work, but the scope of use cases is fairly narrow, so we all agree Addons is the better fit.

In the future, we will make sure to de-dupe with Addons before adding to experimental, and to prefer Addons unless we are actually adding a Layer/etc. that will end up in core, but has some API kinks to work out.

So, for this PR, if you could go ahead and dedupe with the experimental implementation and push here, we will remove the experimental implementation and use this one instead.

Thanks, all, for working on this.

@seanpmorgan (Member) commented:
@Smokrow Would you be able to modify the test coverage to be as extensive as the reference in tf.keras experimental? Then we can merge and request the removal of LayerNormalization from core.
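
As a rough illustration of the kind of check such a test suite could include (a sketch only: the import path, class name, and default gamma=1 / beta=0 initialization are assumptions based on this PR, not the merged tests):

    import numpy as np
    import tensorflow as tf

    # Assumed import path and class from this PR.
    from tensorflow_addons.layers.python.normalizations import GroupNormalization

    class GroupNormSketchTest(tf.test.TestCase):
        def test_output_is_normalized(self):
            # With the assumed default gamma=1 / beta=0, the output should
            # have roughly zero mean and unit variance overall, since every
            # (sample, group) block is standardized independently.
            layer = GroupNormalization(groups=2, axis=-1)
            x = tf.constant(np.random.random((4, 10, 10, 6)), tf.float32)
            y = self.evaluate(layer(x))
            self.assertAllClose(y.mean(), 0.0, atol=1e-3)
            self.assertAllClose(y.var(), 1.0, atol=5e-2)

    if __name__ == "__main__":
        tf.test.main()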

@Smokrow (Contributor, Author) commented Feb 8, 2019

On it. Thanks for clarifying 👍

@facaiy (Member) commented Feb 10, 2019

Hi, could you resolve conflicts?

@Smokrow (Contributor, Author) commented Feb 10, 2019

This is still not finished.

@Smokrow (Contributor, Author) commented Feb 10, 2019

@seanpmorgan I am having some trouble running the tests.

I am getting the following while running test_groupnorm_flat and test_groupnorm_conv:

Traceback (most recent call last):
  File "/usr/local/lib/python2.7/dist-packages/absl/third_party/unittest3_backport/case.py", line 37, in testPartExecutor
    yield
  File "/usr/local/lib/python2.7/dist-packages/absl/third_party/unittest3_backport/case.py", line 162, in run
    testMethod()
  File "normalizations_test.py", line 82, in test_groupnorm_conv
    model.fit(np.random.random((10,20, 20, 3)))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training.py", line 989, in fit
    steps_name='steps_per_epoch')
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training_arrays.py", line 330, in model_iteration
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/backend.py", line 3123, in __call__
    outputs = self._graph_fn(*converted_inputs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/function.py", line 452, in __call__
    return self._call_flat(args)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/function.py", line 521, in _call_flat
    outputs = self._inference_function.call(ctx, args)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/function.py", line 291, in call
    (len(args), len(list(self.signature.input_arg))))
ValueError: Arguments and signature arguments do not match: 15 17

Looks like it is deep down in the eager execution, but I am quite confused about what's going on down there 😄
Greets

@facaiy (Member) commented Feb 12, 2019

I checked out the feature branch and failed to run the bazel tests.

[Screenshot: bazel test failure output, Feb 12, 2019]

@Smokrow Hi, Moritz. Could you fix the bazel configuration and make sure we can reproduce your problem? That would be greatly helpful for debugging :-)

@Smokrow (Contributor, Author) commented Feb 17, 2019

@seanpmorgan ready for review

@seanpmorgan self-assigned this Feb 19, 2019
@facaiy (Member) left a comment:

I left some comments; most of them are trivial code style problems. Thanks for the PR :-)

[Resolved review threads on tensorflow_addons/layers/python/normalizations.py]

Review thread on the following snippet:
group_axes.insert(1, self.groups)

# reshape inputs to new group shape
group_shape = [group_axes[0], self.groups] + group_axes[2:]
Member commented:
Can it handle both channel-first and channel-last format?

@Smokrow (Contributor, Author) replied:

Yes. At this point the ordering would be [batch, group, channels, steps].

@Smokrow (Contributor, Author) replied:

@facaiy Since the user configures the axis to work on, I am not quite sure we need special handling for "channels first" vs. "channels last". When somebody sets the axis exactly on the channel axis, I think they should be allowed to do that.
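
To make that ordering concrete, here is a small NumPy illustration of the reshape under discussion (shapes chosen for the example; the PR code derives them from the input shape and the configured axis):

    import numpy as np

    # Channels-first sequence input: [batch, channels, steps].
    x = np.random.random((10, 6, 20))
    groups = 2

    # Splitting the channel axis into (groups, channels_per_group) gives the
    # [batch, group, channels, steps] ordering mentioned above.
    grouped = x.reshape(10, groups, 6 // groups, 20)
    print(grouped.shape)  # (10, 2, 3, 20)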

[Additional resolved review threads on tensorflow_addons/layers/python/normalizations.py, tensorflow_addons/layers/python/normalizations_test.py, and tensorflow_addons/layers/BUILD]
@facaiy requested a review from seanpmorgan February 23, 2019 22:06
[Resolved review threads on tensorflow_addons/layers/python/normalizations.py]

Review thread on the following docstring snippet:
groups: Integer, the number of groups for Group Normalization.
Can be in the range [1, N] where N is the input dimension.
The input dimension must be divisible by the number of groups.
axis: Integer, the axis that should be normalized
Reviewer comment:

In the case of a 4D input tensor, the axes that are normalized over are either (C/G, H, W) or (Batch, G), depending on your definition of "normalized", but it is in no way just "C".

The BatchNorm layer takes axis=C, so if you want to make an analogy here, it would be axis=[Batch, G]. This analogy is a bit ugly, so I think a better way is to still define axis to be the channel dimension, and write clearer documentation about what this layer actually does for 2D and 4D tensors, respectively.

The same comment applies to the other norms.

@Smokrow (Contributor, Author) replied:

I think it is better to remove this explanation completely, since it does not really belong in the code docs. There are currently Colab notebooks planned for better documentation and explanation; I am writing one for the layer/group/instance normalization layers. If you want, I can reference you in the PR when it is finished.
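
Until the notebook lands, here is a NumPy sketch of the behavior in question for a 4D channels-last tensor. This is a reference illustration of group normalization as described in the paper, not the PR's implementation:

    import numpy as np

    def group_norm_reference(x, groups, eps=1e-5):
        # x: [N, H, W, C]. The user-facing `axis` points at C, but the
        # statistics are computed over (H, W, C/G) within each of the
        # N * G (sample, group) blocks.
        n, h, w, c = x.shape
        g = x.reshape(n, h, w, groups, c // groups)
        mean = g.mean(axis=(1, 2, 4), keepdims=True)
        var = g.var(axis=(1, 2, 4), keepdims=True)
        return ((g - mean) / np.sqrt(var + eps)).reshape(n, h, w, c)

    y = group_norm_reference(np.random.random((8, 20, 20, 6)), groups=2)
    print(y.shape)  # (8, 20, 20, 6)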

@Smokrow (Contributor, Author) commented Mar 7, 2019

@seanpmorgan @facaiy ready for review 👍

@facaiy (Member) commented Mar 9, 2019

@Smokrow Thanks for your work, Moritz!

@ppwwyyxx Yuxin, would you mind taking a look again? Thank you very much for your review :-)

@ppwwyyxx commented Mar 9, 2019

I don't have other comments. The description of the "axis" argument is still not very accurate, but it seems there are other plans to address it.

Commits added to the branch:

* Remove tf.logging as part of TF2
* Add normalization layers to init
* Update READMEs
@googlebot commented:
So there's good news and bad news.

👍 The good news is that everyone that needs to sign a CLA (the pull request submitter and all commit authors) have done so. Everything is all good there.

😕 The bad news is that it appears that one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that here in the pull request.

Note to project maintainer: This is a terminal state, meaning the cla/google commit status will not change from this state. It's up to you to confirm consent of all the commit author(s), set the cla label to yes (if enabled on your project), and then merge this pull request when appropriate.

ℹ️ Googlers: Go here for more info.

@googlebot added cla: no and removed cla: yes labels Mar 9, 2019
@seanpmorgan (Member) commented:
Thanks so much for the contribution, Moritz. I made some minor formatting changes and integrations with the project, if you want to check them out. Looking forward to the example/demo notebook. Thanks for the review @ppwwyyxx; we'll be sure to address the axis ambiguity with the demos, and if that is not sufficient we can iterate on the parameters, as we're still early in the release cycle.

@seanpmorgan merged commit 423fdb6 into tensorflow:master Mar 9, 2019
Squadrick pushed a commit to Squadrick/addons that referenced this pull request Mar 26, 2019
* Implemented GroupNorm, InstanceNorm and LayerNorm