This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

net.Cast("float16") doesn't work: Check failed: (*in_type)[i] == dtype_param (2 vs. 0) : This layer requires uniform type. Expected 'float32' v.s. given 'float16' at 'gamma' #17164

Closed
Rainweic opened this issue Dec 24, 2019 · 12 comments · Fixed by #17212

Comments

@Rainweic

Description

When I use my own dataset to train GluonCV's SSD with float16, I get this error:

Check failed: (*in_type)[i] == dtype_param (2 vs. 0) : This layer requires uniform type. Expected 'float32' v.s. given 'float16' at 'gamma'

I found that some BatchNorm layers in SSD are implemented with the symbol API. Because of this, net.cast('float16') does not work as expected.

Error Message

MXNet 1.5.1:

Check failed: (*in_type)[i] == dtype_param (2 vs. 0) : This layer requires uniform type. Expected 'float32' v.s. given 'float16' at 'gamma'

To Reproduce


Example 1

import mxnet
import gluoncv

net = gluoncv.model_zoo.get_model(
    'ssd_512_resnet50_v1_voc',
    pretrained=False,
    pretrained_base=False,
    norm_layer=None,
    use_bn=False,
    norm_kwargs=None)
net.initialize()
net.cast("float16")  # the cast does not work as expected

one = mxnet.nd.zeros((1, 3, 512, 512), dtype="float16")

net(one)  # raises the "requires uniform type" check failure

Example 2

import mxnet as mx

data = mx.sym.var(name="data")
data = mx.sym.Convolution(data, num_filter=512, kernel=(3, 3), pad=(1, 1))
data = mx.sym.Activation(data, act_type="relu")
data = mx.sym.BatchNorm(data)

net = mx.gluon.SymbolBlock(data, mx.sym.var(name='data'))
net.cast("float16")  # the cast does not work as expected
net.initialize()

net(mx.nd.ones((1, 3, 512, 512), dtype='float16'))  # raises the "requires uniform type" check failure

Steps to reproduce

What have you tried to solve it?

Example 1
I don't know how to solve it.

Example 2
Change it like this:

import mxnet as mx

data = mx.sym.var(name="data", dtype='float16')
data = mx.sym.Convolution(data, num_filter=512, kernel=(3, 3), pad=(1, 1))
data = mx.sym.Activation(data, act_type="relu")
data = mx.sym.BatchNorm(data)

net = mx.gluon.SymbolBlock(data, mx.sym.var(name='data', dtype='float16'))
# net.cast("float16")
net.initialize()

net(mx.nd.ones((1, 3, 512, 512), dtype='float16'))

Environment

Ubuntu 18, MXNet 1.5.1
macOS, MXNet 1.5.1

@Rainweic Rainweic added the Bug label Dec 24, 2019
@kexinyu

kexinyu commented Dec 27, 2019

I'm having a similar issue, but I guess this explains why the cast loses efficacy for BatchNorm layer?
https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/nn/basic_layers.py#L359-L362

class BatchNorm(HybridBlock):
    ....
    def cast(self, dtype):
        if np.dtype(dtype).name == 'float16':
            dtype = 'float32'
        super(BatchNorm, self).cast(dtype)

so 'gamma' is still in float32, while the input is in float16, which causes the check failure.

@Rainweic
Author

Why is it set up this way? I'm confused. Does this mean BN can't be cast to fp16? Then how am I supposed to train and convert my model?

@samskalicky
Contributor

This is a known issue: the BatchNorm operator does not support float16. Supporting it needs some work, so this is currently a not-yet-implemented feature rather than a bug.
@mxnet-label-bot update [Feature ]

@samskalicky
Contributor

@mxnet-label-bot update [Feature request]

@kexinyu

kexinyu commented Dec 30, 2019

BatchNorm is a “blacklist” function for which 16 bits of precision may not be sufficient. So you want to ensure that inputs into BatchNorm layer use float32, or you may have convergence issues.
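
To make the precision concern concrete, here is a minimal sketch. It is an illustration only, not MXNet's kernel: it emulates float16 rounding with Python's standard `struct` module, and the helper names `to_fp16`, `mean_fp16_accumulator`, and `mean_wide_accumulator` are made up for this example. A running sum kept in float16 stops growing once it reaches 2048, so the mean of 4096 ones comes out as 0.5.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def mean_fp16_accumulator(values):
    """Mean where every partial sum is rounded back to float16."""
    total = 0.0
    for v in values:
        total = to_fp16(total + to_fp16(v))
    return to_fp16(total / len(values))


def mean_wide_accumulator(values):
    """Mean of float16 inputs with a wider accumulator (a Python float here)."""
    total = 0.0
    for v in values:
        total += to_fp16(v)
    return total / len(values)


ones = [1.0] * 4096
fp16_mean = mean_fp16_accumulator(ones)
wide_mean = mean_wide_accumulator(ones)
print(fp16_mean, wide_mean)  # 0.5 1.0
```

BatchNorm reduces over many elements to get per-channel means and variances, which is the kind of computation where a narrow accumulator can go wrong like this.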

@Rainweic
Author

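Oh, I see! Thank you~( ̄▽ ̄~)~ But I still have a question: BatchNorm is used heavily almost everywhere, so how do I make sure the inputs to the BN layers are fp32 while the other operators run in fp16? Could you give a simple code example?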

@kexinyu

kexinyu commented Dec 30, 2019

@samskalicky Then how does the current BatchNorm operator handle float16 inputs in mixed precision training?
https://github.com/apache/incubator-mxnet/blob/b6972bb055fc44481b072db3abb90e26ee27c787/src/operator/nn/batch_norm-inl.h#L292
Here, if the inputs use float16, DType becomes float16 but AccReal is float32.
https://github.com/apache/incubator-mxnet/blob/5fb29167a1a66480864486bf59c6b4e980ce7daa/src/operator/nn/batch_norm.cu#L241
Is gamma * (inp - mean) * invstd + beta okay with mixed precision operands?
https://github.com/apache/incubator-mxnet/blob/692f49f2b1b9df1bb226c586405226291c6095cf/src/operator/nn/convolution.cc#L297
Or are they the reason why UNIFORM_TYPE_CHECK fails with error
Check failed: (*in_type)[i] == dtype_param (2 vs. 0) : This layer requires uniform type. Expected 'float32' v.s. given 'float16' at 'gamma'
?
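
One way to picture the DType/AccReal split asked about above is the sketch below. It is a toy under stated assumptions, not the code behind the links: Python floats stand in for the float32 accumulation type, `to_fp16` emulates float16 storage via the standard `struct` module, and `batch_norm_mixed` is a hypothetical name. Inputs and outputs are float16, while the statistics, gamma, and beta stay in the wider type.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def batch_norm_mixed(inputs, gamma, beta, eps=1e-5):
    """Normalize float16 inputs using wider statistics and parameters."""
    x = [to_fp16(v) for v in inputs]            # storage type: float16
    n = len(x)
    mean = sum(x) / n                           # accumulated in the wider type
    var = sum((v - mean) ** 2 for v in x) / n
    invstd = 1.0 / math.sqrt(var + eps)
    # gamma * (inp - mean) * invstd + beta, rounded to float16 only on write
    return [to_fp16(gamma * (v - mean) * invstd + beta) for v in x]


out = batch_norm_mixed([1.0, 2.0, 3.0, 4.0], gamma=1.0, beta=0.0)
print(out)
```

In this picture, mixing a float16 input with float32 gamma and beta is intentional, which would explain why a gamma that has been cast to float16 trips the type check.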

@kexinyu

kexinyu commented Dec 30, 2019

I have the same question and I'm looking into it, haha. I'll let you know once I figure it out =w=

@Rainweic
Author

@samskalicky Thank you! But I'm still unsure about this: if BatchNorm doesn't support fp16, how can I train my model with fp16?

@Rainweic
Author

OK. By the way, if it's convenient, could you share your contact information?

@ptrendx
Member

ptrendx commented Jan 2, 2020

This error actually seems to come from the fact that the SymbolBlock does not really know which symbols are inside it, so it casts all of the parameters in the net.cast('float16') call. However, the BatchNorm layer is slightly special: it requires its parameters gamma and beta to be float32 even if the input is float16, in order not to lose precision. So the BatchNorm layer expects gamma to be float32, but the parameter given to it is float16.
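
The difference can be mimicked with a small pure-Python toy. Everything in it is a hypothetical stand-in, not MXNet code: `Param`, `ToyBlock`, `ToyBatchNorm`, and `check_batch_norm_types` only model the behavior described in this thread. `ToyBatchNorm` mirrors the `cast` override quoted earlier, while a plain `ToyBlock` casts every parameter blindly, as the SymbolBlock is described as doing.

```python
class Param:
    """Hypothetical parameter that only tracks a name and a dtype string."""

    def __init__(self, name, dtype="float32"):
        self.name = name
        self.dtype = dtype


class ToyBlock:
    """Stand-in for a block that casts all of its parameters blindly."""

    def __init__(self, params):
        self.params = params

    def cast(self, dtype):
        for p in self.params:
            p.dtype = dtype


class ToyBatchNorm(ToyBlock):
    """Mirrors the Gluon BatchNorm override: float16 requests keep float32."""

    def cast(self, dtype):
        if dtype == "float16":
            dtype = "float32"
        super().cast(dtype)


def check_batch_norm_types(input_dtype, params):
    """Toy type check: parameters must be float32 when the input is float16."""
    expected = "float32" if input_dtype == "float16" else input_dtype
    for p in params:
        if p.dtype != expected:
            raise TypeError(
                f"Expected '{expected}' v.s. given '{p.dtype}' at '{p.name}'")
    return True


# Layer-aware cast: gamma and beta stay float32, so the check passes.
gluon_params = [Param("gamma"), Param("beta")]
ToyBatchNorm(gluon_params).cast("float16")
gluon_style_ok = check_batch_norm_types("float16", gluon_params)

# Blind cast: gamma becomes float16, so the check fails.
symbol_params = [Param("gamma"), Param("beta")]
ToyBlock(symbol_params).cast("float16")
try:
    check_batch_norm_types("float16", symbol_params)
    symbol_style_error = None
except TypeError as err:
    symbol_style_error = str(err)

print(gluon_style_ok, symbol_style_error)
```

The second case produces a message of the same shape as the one in the issue title, with 'gamma' named as the offending parameter.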

@Rainweic I would recommend looking into AMP for training: https://mxnet.apache.org/api/python/docs/tutorials/performance/backend/amp.html

@zhreshold FYI

@Rainweic
Author

Rainweic commented Jan 3, 2020

Thank you! Let me try

ptrendx pushed a commit that referenced this issue Jan 6, 2020

* fix symbolblock with bn+fp16

* add unittest

* fix

* remove unused

* fix lint