add rrelu kernel #573

Merged: 14 commits merged into tensorflow:master on Oct 25, 2019
Conversation

@fsx950223 (Member) commented Oct 6, 2019

Kernel implementation of rrelu. @WindQAQ

@WindQAQ (Member) left a comment

@fsx950223 Thanks for the contribution! Generally LGTM. I know this is a WIP, but I want to address some points first. Thanks again for bringing rrelu up!

Review threads on tensorflow_addons/activations/rrelu.py and tensorflow_addons/activations/rrelu_test.py (resolved).
@fsx950223 (Member, Author) commented Oct 8, 2019

It seems formula (5) in https://arxiv.org/pdf/1505.00853.pdf is wrong. It should be x * (lower + upper) / 2.

@WindQAQ (Member) commented Oct 9, 2019

> It seems formula (5) in https://arxiv.org/pdf/1505.00853.pdf is wrong. It should be x * (lower + upper) / 2.

I can confirm that PyTorch computes it as x * (lower + upper) / 2:
https://github.com/pytorch/pytorch/blob/master/aten/src/THNN/generic/RReLU.c#L66-L72
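
For reference, a minimal NumPy sketch of the semantics being discussed (the function and its defaults are illustrative, not the kernel implemented in this PR): during training each negative element is scaled by an alpha drawn from U(lower, upper), while at inference the negative slope is the fixed mean (lower + upper) / 2.

```python
import numpy as np

def rrelu_reference(x, lower=0.125, upper=0.3333, training=False, rng=None):
    """Illustrative RReLU reference, not the PR's actual kernel."""
    x = np.asarray(x, dtype=np.float64)
    if training:
        rng = rng if rng is not None else np.random.default_rng(0)
        # One random slope per element, drawn uniformly from [lower, upper).
        alpha = rng.uniform(lower, upper, size=x.shape)
    else:
        # Inference uses the deterministic mean slope: x * (lower + upper) / 2.
        alpha = np.full_like(x, (lower + upper) / 2)
    return np.where(x >= 0, x, alpha * x)
```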

```python
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
lower = 0.2
upper = 0.2
result, alpha = rrelu(x, lower, upper, training=True)
```
Member Author:

Do you have any idea about the random test? @WindQAQ

Member:

I think we do not have to return alpha from the rrelu op. In your experiment, setting the seed in TF is not enough, right?

Member Author:

The with_alpha argument can control the return behavior, and I don't need a seed.

Member:

@fsx950223 I do not think it's good practice for a public API to have an argument that returns such things. Is alpha deterministic when the same TF seed is set?

@fsx950223 (Member, Author) commented Oct 11, 2019:

No, it isn't, and I don't know how to make alpha deterministic. Is there a better way to set the seed?

Member:

Umm, I think we could just test values with training=False. As for gradient testing, we should check both. How do you feel about this?
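
A minimal sketch of what such a value test with training=False could look like (the import path and keyword names for rrelu are assumptions based on the final addon API):

```python
import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import rrelu  # assumed import path

def test_rrelu_inference_values():
    x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32)
    lower, upper = 0.125, 0.3333
    result = rrelu(tf.constant(x), lower=lower, upper=upper, training=False)
    # With training=False the op is deterministic: negative inputs are scaled
    # by the mean slope (lower + upper) / 2, so no seed handling is needed.
    expected = np.where(x >= 0, x, x * (lower + upper) / 2)
    np.testing.assert_allclose(result.numpy(), expected, atol=1e-6)
```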

Member Author:

Changed.

@seanpmorgan (Member):

Hi @fsx950223, when time allows, would you mind refactoring to use the custom_op_library from #581?

@WindQAQ (Member) left a comment:

Generally LGTM. Some API design points still need to be discussed. Thanks!

```python
if with_alpha:
    return result, alpha
else:
    return result
```
Member:

I have no idea whether we should expose an argument to return alpha or not. My survey shows that Chainer does but PyTorch does not. What do you think, @fsx950223?

Member Author:

I return alpha just for the test. It could be removed if the test case can be fixed without alpha.

@fsx950223 (Member, Author) commented Oct 21, 2019:

The problem seems to have been solved by using Philox random.
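
For context, a minimal sketch of the idea (not the kernel's actual RNG wiring): TF's stateless random ops are Philox-based, so a fixed seed makes the sampled alpha, and therefore the training-mode output, reproducible in a test.

```python
import tensorflow as tf

x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])
lower, upper = 0.125, 0.3333

# Stateless (Philox-based) sampling: the same seed always yields the same alpha.
alpha = tf.random.stateless_uniform(
    tf.shape(x), seed=[4, 2], minval=lower, maxval=upper)
result = tf.where(x >= 0, x, alpha * x)
```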

```python
grad = t.gradient(result, x)
expect_grad = _ref_rrelu_grad(x, alpha, dtype)
self.assertAllCloseAccordingToType(grad, expect_grad, atol=1e-4)
```
Member:

Could we use tf.test.compute_gradient to check the gradients, just like what you did in hardtanh?
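
For reference, a minimal sketch of a tf.test.compute_gradient check in TF 2 style (the rrelu import path, keyword names, and tolerance are assumptions; the inputs are chosen away from 0, where the function is not smooth):

```python
import tensorflow as tf
from tensorflow_addons.activations import rrelu  # assumed import path

x = tf.constant([-2.1, -1.1, 1.1, 2.1], dtype=tf.float64)
theoretical, numerical = tf.test.compute_gradient(
    lambda t: rrelu(t, lower=0.2, upper=0.2, training=False), [x])
# The two Jacobians should agree when the function is smooth at the probed points.
tf.debugging.assert_near(theoretical[0], numerical[0], atol=1e-4)
```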

@fsx950223 (Member, Author) commented Oct 16, 2019:

The test fails when the training argument is True; it also fails when training is False and the input data is 0.

Member:

@fsx950223 Sorry for the late reply. Could you push the code so that I can test it locally? Thanks.

@fsx950223 (Member, Author) commented Oct 20, 2019:

> @fsx950223 Sorry for the late reply. Could you push the code so that I can test it locally? Thanks.

How do you debug the code? I can't step into the source code anymore since TF 2 was released. @WindQAQ

Member:

Edit: I successfully re-configured the environment with ./configure.sh on the Docker image earlier today.

Member Author:

The problem is that I can't go to the code definition.
[Screenshot from 2019-10-20 17-49-01]
I tried it in VSCode and PyCharm, but neither of them worked.

Member Author:

I believe _compute_numeric_jacobian is not suitable for this test case, because the activation has different gradients around 0. Please see the ReLU test case for reference.

Member:

> I believe _compute_numeric_jacobian is not suitable for this test case, because the activation has different gradients around 0. Please see the ReLU test case for reference.

I think it's OK to modify the input values to avoid the non-smooth part. Please see ReLU's gradient check:
https://github.com/tensorflow/tensorflow/blob/66ea3ed9b8cbbbf01b0eabb14e436883895e4bde/tensorflow/python/kernel_tests/relu_op_test.py#L123
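
A minimal sketch of that trick (values are illustrative): sample random inputs, then push anything near 0 away from the kink before feeding it to a gradient check like the one sketched above.

```python
import numpy as np
import tensorflow as tf

x = np.random.uniform(low=-1.0, high=1.0, size=(4, 10)).astype(np.float64)
# Move values away from the non-smooth point at 0, as the ReLU test does.
x[np.abs(x) < 0.1] = 0.1
x = tf.constant(x)
```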

Member:

> The problem is that I can't go to the code definition.
> [Screenshot from 2019-10-20 17-49-01]
> I tried it in VSCode and PyCharm, but neither of them worked.

Ohoh, got it. But can the code execute successfully even if the IDE cannot find the function definitions? Sorry, I usually use Vim, so it took me a while to understand what you meant... Thanks!

Review thread on tensorflow_addons/custom_ops/activations/BUILD (resolved).
@fsx950223 (Member, Author) commented Oct 20, 2019 via email.

@fsx950223 (Member, Author) commented Oct 20, 2019 via email.

@fsx950223 (Member, Author):

Do the CPU kernel and the GPU kernel have different behaviors?

@WindQAQ (Member) left a comment:

Thanks for the nice and long work!

Review threads on tensorflow_addons/activations/rrelu.py (resolved).
@WindQAQ (Member) left a comment:

@fsx950223 Nice PR. Thanks for the contribution!

@WindQAQ merged commit 8a49f91 into tensorflow:master on Oct 25, 2019.
5 participants