From 7c8abb17cb65830a3e34229b21150c28d5a00e7b Mon Sep 17 00:00:00 2001
From: khaotik
Date: Sat, 20 Mar 2021 16:06:40 +0800
Subject: [PATCH] fix formula

---
 python/mxnet/optimizer/adabelief.py  |  2 +-
 src/operator/contrib/adabelief-inl.h | 31 ++++++++++++++--------------
 src/operator/contrib/adabelief.cc    | 12 +++++------
 src/operator/contrib/adabelief.cu    |  5 ++---
 src/operator/contrib/adamw-inl.h     |  6 +++---
 src/operator/contrib/adamw.cc        |  4 ++--
 src/operator/contrib/adamw.cu        |  4 ++--
 7 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/python/mxnet/optimizer/adabelief.py b/python/mxnet/optimizer/adabelief.py
index 25ad26159f50..e9f31eebf29f 100644
--- a/python/mxnet/optimizer/adabelief.py
+++ b/python/mxnet/optimizer/adabelief.py
@@ -118,9 +118,9 @@ def step(self, indices, weights, grads, states):
             # preprocess grad
             grad *= self.rescale_grad
+            grad += wd * weight
             if self.clip_gradient is not None:
                 grad = clip(grad, -self.clip_gradient, self.clip_gradient)
-            grad += wd * weight
 
             if self.correct_bias:
                 coef1 = 1. - self.beta1**t
                 coef2 = 1. - self.beta2**t
diff --git a/src/operator/contrib/adabelief-inl.h b/src/operator/contrib/adabelief-inl.h
index 48c942b61584..7ca3de5dc93f 100644
--- a/src/operator/contrib/adabelief-inl.h
+++ b/src/operator/contrib/adabelief-inl.h
@@ -32,6 +32,7 @@
 
 namespace mxnet {
 namespace op {
+namespace adabelief {
 
 struct AdaBeliefParam : public dmlc::Parameter<AdaBeliefParam> {
   float lr;
@@ -107,17 +108,17 @@ struct MPAdaBeliefKernel {
                                   const float param_rescale_grad, const float param_epsilon) {
     float w = weight32[i];
     float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
+    scaled_grad += param_wd * w;
     if (param_clip_gradient >= 0.f)
       scaled_grad = mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
-    scaled_grad += param_wd * weight_data[i];
 
     const float mean = param_beta1 * (mean_data[i] - scaled_grad) + scaled_grad;
-    const float adj = mshadow_op::square::Map(mean - scaled_grad);
-    const float var = param_beta2 * var_data[i] + (1.f - param_beta2) * adj + param_epsilon;
+    const float adj = mshadow_op::square::Map(scaled_grad - mean);
+    const float var = param_beta2*(var_data[i] - adj) + adj + param_epsilon;
+    w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon));
     mean_data[i] = mean;
     var_data[i] = var;
-    w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon));
     weight32[i] = w;
     KERNEL_ASSIGN(out_data[i], req, w);
   }
 }
@@ -174,18 +175,19 @@ struct AdaBeliefUpdate {
     Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
 
-    grad = scalar<DType>(rescale_grad) * grad;
+    grad = scalar<DType>(rescale_grad) * grad + scalar<DType>(param.wd) * weight;
     if (param.clip_gradient >= 0.0f)
       grad = F<clip>(grad, DType(param.clip_gradient));
 
     mean = scalar<DType>(param.beta1) * mean + scalar<DType>(1.f-param.beta1) * grad;
-    var = scalar<DType>(param.beta2) * var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
+    var = scalar<DType>(param.beta2) * var +
+          scalar<DType>(1.f-param.beta2) * F<square>(grad - mean) +
+          scalar<DType>(param.epsilon);
 
     Assign(out, req[0],
            weight -
           scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
-           mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
-           (scalar<DType>(param.wd) * weight)));
+           mean / (F<square_root>(var) + scalar<DType>(param.epsilon))));
     });
   }
 };
@@ -352,7 +354,7 @@ template
 struct MultiMPAdaBeliefKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, const MultiKernelParam<DType, MPDType>& param,
-                                  const OpReqType req, const float rescale_grad){
+                                  const OpReqType req, const float rescale_grad) {
     for (int index = 0; index < param.count; ++index) {
       if ((size_t)i < param.sizes[index]) {
        MPDType w = has_mixed_precision ? param.weights32[index][i]:
@@ -360,13 +362,13 @@ struct MultiMPAdaBeliefKernel {
 
         MPDType scaled_grad = static_cast<MPDType>(rescale_grad)*
                               static_cast<MPDType>(param.grad_data[index][i]);
-        if (param.clip_gradient >= 0.f)
-          scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient) ;
         scaled_grad += param.wds[index] * w;
+        if (param.clip_gradient >= 0.f)
+          scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient);
 
-        const auto mean = param.beta1 * (param.mean_data[index][i]- scaled_grad) + scaled_grad;
+        const auto mean = param.beta1 * (param.mean_data[index][i] - scaled_grad) + scaled_grad;
         const auto adj = mshadow_op::square::Map(mean - scaled_grad);
-        const auto var = param.var_data[index][i] + (1.f - param.beta2) * adj + param.epsilon;
+        const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj + param.epsilon;
 
         param.mean_data[index][i] = mean;
         param.var_data[index][i] = var;
@@ -444,7 +446,6 @@ static inline void MultiAdaBeliefUpdate(const nnvm::NodeAttrs& attrs,
   });
 }
 
-namespace adabelief {
 template<typename xpu>
 void GetScaleFloat(mshadow::Stream<xpu> *s, const TBlob &scale_blob, float *pScalef);
 
@@ -497,8 +498,8 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs,
     MultiAdaBeliefUpdate (attrs, ctx, inputs_wo_scale, req, outputs, scalef);
 }
 
-}  // namespace adabelief
 
+}  // namespace adabelief
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adabelief.cc b/src/operator/contrib/adabelief.cc
index 49a3ef6ca38c..f881c5ab7b16 100644
--- a/src/operator/contrib/adabelief.cc
+++ b/src/operator/contrib/adabelief.cc
@@ -27,6 +27,7 @@
 
 namespace mxnet {
 namespace op {
+namespace adabelief {
 
 DMLC_REGISTER_PARAMETER(AdaBeliefParam);
 DMLC_REGISTER_PARAMETER(MultiAdaBeliefParam);
@@ -65,7 +66,7 @@ the update is skipped.
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3, 4};
   })
-.set_attr<FCompute>("FCompute", adabelief::MPUpdate>)
+.set_attr<FCompute>("FCompute", MPUpdate>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -110,7 +111,7 @@ the update is skipped.
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3};
   })
-.set_attr<FCompute>("FCompute", adabelief::MPUpdate>)
+.set_attr<FCompute>("FCompute", MPUpdate>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -120,14 +121,12 @@ the update is skipped.
                     "the update is skipped.")
 .add_arguments(AdaBeliefParam::__FIELDS__());
 
-namespace adabelief {
 template<>
 void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float *pScalef) {
   MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType,
     *pScalef = static_cast<float>(*scale_blob.dptr<DType>());
   )
 }
-}
 
 static std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
   std::vector<std::string> ret;
@@ -195,7 +194,7 @@ the update is skipped.
 
   return ret;
 })
-.set_attr<FCompute>("FCompute", adabelief::multiMPUpdate)
+.set_attr<FCompute>("FCompute", multiMPUpdate)
 .add_argument("data", "NDArray-or-Symbol[]", "data")
 .add_arguments(MultiAdaBeliefParam::__FIELDS__());
 
@@ -252,9 +251,10 @@ the update is skipped.
 
   return ret;
 })
-.set_attr<FCompute>("FCompute", adabelief::multiMPUpdate)
+.set_attr<FCompute>("FCompute", multiMPUpdate)
 .add_argument("data", "NDArray-or-Symbol[]", "data")
 .add_arguments(MultiAdaBeliefParam::__FIELDS__());
+}
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adabelief.cu b/src/operator/contrib/adabelief.cu
index 2ee3ca4e6643..f82322e38f79 100644
--- a/src/operator/contrib/adabelief.cu
+++ b/src/operator/contrib/adabelief.cu
@@ -27,7 +27,6 @@
 
 namespace mxnet {
 namespace op {
-
 namespace adabelief {
 template<>
 void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float *pScalef) {
@@ -43,10 +42,10 @@ void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float
 }
 
 NNVM_REGISTER_OP(_adabelief_update)
-.set_attr<FCompute>("FCompute", adabelief::MPUpdate>);
+.set_attr<FCompute>("FCompute", adabelief::MPUpdate>);
 
 NNVM_REGISTER_OP(_mp_adabelief_update)
-.set_attr<FCompute>("FCompute", adabelief::MPUpdate>);
+.set_attr<FCompute>("FCompute", adabelief::MPUpdate>);
 
 NNVM_REGISTER_OP(_multi_adabelief_update)
 .set_attr<FCompute>("FCompute", adabelief::multiMPUpdate);
diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 8a1f3f249145..56c5ea227862 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -32,6 +32,7 @@
 
 namespace mxnet {
 namespace op {
+namespace adamw {
 
 struct AdamWParam : public dmlc::Parameter<AdamWParam> {
   float lr;
@@ -114,7 +115,7 @@ struct MPAdamWKernel {
     float var = var_data[i] = param_beta2 * var_data[i] +
                 (1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad);
 
-    w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+    w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
           + param_wd * w);
     weight32[i] = w;
     KERNEL_ASSIGN(out_data[i], req, w);
@@ -441,7 +442,6 @@ static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs,
   });
 }
 
-namespace adamw {
 template<typename xpu>
 static void GetScaleFloat(mshadow::Stream<xpu> *s, const TBlob &scale_blob, float *pScalef);
 
@@ -494,8 +494,8 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs,
     MultiAdamWUpdate (attrs, ctx, inputs_wo_scale, req, outputs, scalef);
 }
 
-}  // namespace adamw
 
+}  // namespace adamw
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index 46920aa5cb9b..0766de78f15a 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -27,6 +27,7 @@
 
 namespace mxnet {
 namespace op {
+namespace adamw {
 
 DMLC_REGISTER_PARAMETER(AdamWParam);
 DMLC_REGISTER_PARAMETER(MultiAdamWParam);
@@ -118,14 +119,12 @@ the update is skipped.
                     "the update is skipped.")
 .add_arguments(AdamWParam::__FIELDS__());
 
-namespace adamw {
 template<>
 void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float *pScalef) {
   MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType,
     *pScalef = static_cast<float>(*scale_blob.dptr<DType>());
   )
 }
-}
 
 static std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
   std::vector<std::string> ret;
@@ -255,5 +254,6 @@ the update is skipped.
 .add_arguments(MultiAdamWParam::__FIELDS__());
 
+}  // namespace adamw
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index c07767a1e4d6..a58656d18ac8 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -27,8 +27,8 @@
 
 namespace mxnet {
 namespace op {
-
 namespace adamw{
+
 template<>
 void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float *pScalef) {
   MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
@@ -40,7 +40,6 @@ void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float
     *pScalef = static_cast<float>(scale);
   })
 }
-}
 
 NNVM_REGISTER_OP(_adamw_update)
 .set_attr<FCompute>("FCompute", adamw::MPUpdate>);
@@ -54,5 +53,6 @@ NNVM_REGISTER_OP(_multi_adamw_update)
 .set_attr<FCompute>("FCompute", adamw::multiMPUpdate);
 
 NNVM_REGISTER_OP(_multi_mp_adamw_update)
 .set_attr<FCompute>("FCompute", adamw::multiMPUpdate);
+}  // namespace adamw
 }  // namespace op
 }  // namespace mxnet
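
Illustration (not part of the patch): the update rule that the Python optimizer and the C++
kernels converge on after this change can be sketched in NumPy as below. Weight decay is
folded into the gradient before clipping, and the second moment tracks the squared deviation
(grad - mean) with epsilon accumulated into var, as in the AdaBelief paper. The helper name
adabelief_step is only for illustration; bias correction (correct_bias) and the multi-tensor /
mixed-precision paths are omitted, and parameter names follow AdaBeliefParam.

import numpy as np

def adabelief_step(weight, grad, mean, var, lr, eta=1.0, beta1=0.9, beta2=0.999,
                   epsilon=1e-8, wd=0.0, rescale_grad=1.0, clip_gradient=None):
    # preprocess grad: rescale, add decoupled weight decay, then clip
    grad = rescale_grad * grad
    grad = grad + wd * weight
    if clip_gradient is not None:
        grad = np.clip(grad, -clip_gradient, clip_gradient)
    # first moment: EMA of the gradient
    mean = beta1 * mean + (1.0 - beta1) * grad
    # second moment: EMA of the squared deviation of grad from its EMA,
    # with epsilon added into var itself, not only to the denominator
    var = beta2 * var + (1.0 - beta2) * (grad - mean) ** 2 + epsilon
    # parameter update
    weight = weight - eta * (lr * mean / (np.sqrt(var) + epsilon))
    return weight, mean, var

The kernels above write the same variance update as beta2 * (var - adj) + adj with
adj = (grad - mean)^2, which is algebraically equal to beta2 * var + (1 - beta2) * adj.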