This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix formula
khaotik committed Mar 20, 2021
1 parent c4141d7 commit 7c8abb1
Showing 7 changed files with 32 additions and 32 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/optimizer/adabelief.py
@@ -118,9 +118,9 @@ def step(self, indices, weights, grads, states):

# preprocess grad
grad *= self.rescale_grad
grad += wd * weight
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
grad += wd * weight
if self.correct_bias:
coef1 = 1. - self.beta1**t
coef2 = 1. - self.beta2**t
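
For orientation, here is a minimal NumPy sketch of the step order this hunk implies: rescale the gradient, fold the weight-decay term in, then clip, and build the AdaBelief variance from (grad - mean)**2. It is a reading of the diff, not MXNet code; the function name adabelief_step, the in-place array updates, the default values, and the bias-correction lines beyond coef1/coef2 are illustrative assumptions.

    import numpy as np

    def adabelief_step(weight, grad, mean, var, t, lr=1e-3, beta1=0.9, beta2=0.999,
                       epsilon=1e-8, wd=0.0, rescale_grad=1.0, clip_gradient=None,
                       correct_bias=True):
        # Gradient preprocessing in the order used after this change:
        # rescale, add weight decay, then clip.
        grad = grad * rescale_grad
        grad = grad + wd * weight
        if clip_gradient is not None:
            grad = np.clip(grad, -clip_gradient, clip_gradient)

        # AdaBelief moments: the second moment tracks (grad - mean)**2, not grad**2.
        mean[:] = beta1 * mean + (1.0 - beta1) * grad
        var[:] = beta2 * var + (1.0 - beta2) * (grad - mean) ** 2 + epsilon

        # Bias correction suggested by the coef1/coef2 lines (assumed form).
        if correct_bias:
            coef1 = 1.0 - beta1 ** t
            coef2 = 1.0 - beta2 ** t
            lr = lr * np.sqrt(coef2) / coef1

        weight[:] = weight - lr * mean / (np.sqrt(var) + epsilon)
        return weight
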
31 changes: 16 additions & 15 deletions src/operator/contrib/adabelief-inl.h
@@ -32,6 +32,7 @@

namespace mxnet {
namespace op {
namespace adabelief {

struct AdaBeliefParam : public dmlc::Parameter<AdaBeliefParam> {
float lr;
@@ -107,17 +108,17 @@ struct MPAdaBeliefKernel {
const float param_rescale_grad, const float param_epsilon) {
float w = weight32[i];
float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
scaled_grad += param_wd * w;
if (param_clip_gradient >= 0.f)
scaled_grad = mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
scaled_grad += param_wd * weight_data[i];

const float mean = param_beta1 * (mean_data[i] - scaled_grad) + scaled_grad;
const float adj = mshadow_op::square::Map(mean - scaled_grad);
const float var = param_beta2 * var_data[i] + (1.f - param_beta2) * adj + param_epsilon;
const float adj = mshadow_op::square::Map(scaled_grad - mean);
const float var = param_beta2*(var_data[i] - adj) + adj + param_epsilon;

w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon));
mean_data[i] = mean;
var_data[i] = var;
w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
}
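
Written out, the per-element update that MPAdaBeliefKernel computes after this change (my transcription of the lines above, with symbols named after the kernel parameters; note that epsilon enters both the variance and the denominator) is:

    \begin{aligned}
    g_t &= \operatorname{clip}\!\big(\text{rescale\_grad}\cdot g + \text{wd}\cdot w_{t-1},\ \pm\text{clip\_gradient}\big),\\
    m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t,\\
    s_t &= \beta_2 s_{t-1} + (1-\beta_2)\,(g_t-m_t)^2 + \epsilon,\\
    w_t &= w_{t-1} - \eta\,\frac{\mathrm{lr}\cdot m_t}{\sqrt{s_t}+\epsilon}.
    \end{aligned}

The kernel spells the moment updates as beta1*(m - g) + g and beta2*(s - d) + d, which expand to the forms above.
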
@@ -174,18 +175,19 @@ struct AdaBeliefUpdate {
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

grad = scalar<DType>(rescale_grad) * grad;
grad = scalar<DType>(rescale_grad) * grad + scalar<DType>(param.wd) * weight;
if (param.clip_gradient >= 0.0f)
grad = F<clip>(grad, DType(param.clip_gradient));

mean = scalar<DType>(param.beta1) * mean + scalar<DType>(1.f-param.beta1) * grad;
var = scalar<DType>(param.beta2) * var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
var = scalar<DType>(param.beta2) * var +
scalar<DType>(1.f-param.beta2) * F<square>(grad - mean) +
scalar<DType>(param.epsilon);

Assign(out, req[0],
weight -
scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
(scalar<DType>(param.wd) * weight)));
mean / (F<square_root>(var) + scalar<DType>(param.epsilon))));
});
}
};
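
The kernel form beta2*(s - d) + d and the tensor form beta2*s + (1 - beta2)*d used in AdaBeliefUpdate above describe the same quantity; the only algebra involved is:

    \beta_2\,(s_{t-1} - d_t) + d_t \;=\; \beta_2\,s_{t-1} + (1-\beta_2)\,d_t,
    \qquad d_t = (g_t - m_t)^2 .
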
@@ -352,21 +354,21 @@ template<typename MPDType, bool has_mixed_precision>
struct MultiMPAdaBeliefKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const MultiKernelParam<DType, MPDType>& param,
const OpReqType req, const float rescale_grad){
const OpReqType req, const float rescale_grad) {
for (int index = 0; index < param.count; ++index) {
if ((size_t)i < param.sizes[index]) {
MPDType w = has_mixed_precision ? param.weights32[index][i]:
MPDType(param.weights[index][i]);
MPDType scaled_grad = static_cast<MPDType>(rescale_grad)*
static_cast<MPDType>(param.grad_data[index][i]);

if (param.clip_gradient >= 0.f)
scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient) ;
scaled_grad += param.wds[index] * w;
if (param.clip_gradient >= 0.f)
scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient);

const auto mean = param.beta1 * (param.mean_data[index][i]- scaled_grad) + scaled_grad;
const auto mean = param.beta1 * (param.mean_data[index][i] - scaled_grad) + scaled_grad;
const auto adj = mshadow_op::square::Map(mean - scaled_grad);
const auto var = param.var_data[index][i] + (1.f - param.beta2) * adj + param.epsilon;
const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj + param.epsilon;

param.mean_data[index][i] = mean;
param.var_data[index][i] = var;
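
In recurrence form, and reading the adjacent variance lines as old then new, the change to this multi-tensor kernel restores the beta_2 factor that was missing from the running variance:

    s_t = s_{t-1} + (1-\beta_2)\,d_t + \epsilon
    \quad\longrightarrow\quad
    s_t = \beta_2\,(s_{t-1} - d_t) + d_t + \epsilon,
    \qquad d_t = (g_t - m_t)^2 .
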
@@ -444,7 +446,6 @@ static inline void MultiAdaBeliefUpdate(const nnvm::NodeAttrs& attrs,
});
}

namespace adabelief {
template<typename xpu>
void GetScaleFloat(mshadow::Stream<xpu> *s, const TBlob &scale_blob, float *pScalef);

@@ -497,8 +498,8 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs,
MultiAdaBeliefUpdate<xpu, _single_precision, 5>
(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
}
} // namespace adabelief

} // namespace adabelief
} // namespace op
} // namespace mxnet

12 changes: 6 additions & 6 deletions src/operator/contrib/adabelief.cc
@@ -27,6 +27,7 @@

namespace mxnet {
namespace op {
namespace adabelief {

DMLC_REGISTER_PARAMETER(AdaBeliefParam);
DMLC_REGISTER_PARAMETER(MultiAdaBeliefParam);
@@ -65,7 +66,7 @@ the update is skipped.
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3, 4};
})
.set_attr<FCompute>("FCompute<cpu>", adabelief::MPUpdate<cpu, MPAdaBeliefUpdate<cpu>>)
.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, MPAdaBeliefUpdate<cpu>>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -110,7 +111,7 @@ the update is skipped.
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.set_attr<FCompute>("FCompute<cpu>", adabelief::MPUpdate<cpu, AdaBeliefUpdate<cpu>>)
.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, AdaBeliefUpdate<cpu>>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -120,14 +121,12 @@ the update is skipped.
"the update is skipped.")
.add_arguments(AdaBeliefParam::__FIELDS__());

namespace adabelief {
template<>
void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float *pScalef) {
MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType,
*pScalef = static_cast<float>(*scale_blob.dptr<DType>());
)
}
}

static std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
std::vector<std::string> ret;
@@ -195,7 +194,7 @@ the update is skipped.
return ret;
})

.set_attr<FCompute>("FCompute<cpu>", adabelief::multiMPUpdate<cpu, false>)
.set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, false>)
.add_argument("data", "NDArray-or-Symbol[]", "data")
.add_arguments(MultiAdaBeliefParam::__FIELDS__());

@@ -252,9 +251,10 @@ the update is skipped.
return ret;
})

.set_attr<FCompute>("FCompute<cpu>", adabelief::multiMPUpdate<cpu, true>)
.set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, true>)
.add_argument("data", "NDArray-or-Symbol[]", "data")
.add_arguments(MultiAdaBeliefParam::__FIELDS__());

}
} // namespace op
} // namespace mxnet
5 changes: 2 additions & 3 deletions src/operator/contrib/adabelief.cu
@@ -27,7 +26,6 @@

namespace mxnet {
namespace op {

namespace adabelief {
template<>
void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float *pScalef) {
@@ -43,10 +42,10 @@ void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float
}

NNVM_REGISTER_OP(_adabelief_update)
.set_attr<FCompute>("FCompute<gpu>", adabelief::MPUpdate<gpu, AdaBeliefUpdate<gpu>>);
.set_attr<FCompute>("FCompute<gpu>", adabelief::MPUpdate<gpu, adabelief::AdaBeliefUpdate<gpu>>);

NNVM_REGISTER_OP(_mp_adabelief_update)
.set_attr<FCompute>("FCompute<gpu>", adabelief::MPUpdate<gpu, MPAdaBeliefUpdate<gpu>>);
.set_attr<FCompute>("FCompute<gpu>", adabelief::MPUpdate<gpu, adabelief::MPAdaBeliefUpdate<gpu>>);

NNVM_REGISTER_OP(_multi_adabelief_update)
.set_attr<FCompute>("FCompute<gpu>", adabelief::multiMPUpdate<gpu, false>);
6 changes: 3 additions & 3 deletions src/operator/contrib/adamw-inl.h
@@ -32,6 +32,7 @@

namespace mxnet {
namespace op {
namespace adamw {

struct AdamWParam : public dmlc::Parameter<AdamWParam> {
float lr;
@@ -114,7 +115,7 @@ struct MPAdamWKernel {
float var = var_data[i] = param_beta2 * var_data[i] +
(1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad);

w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+ param_wd * w);
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
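
For contrast with the AdaBelief kernels above, MPAdamWKernel keeps the weight-decay term decoupled inside the update rather than folding it into the gradient; the lines shown correspond to the following, where m_t is the first moment maintained just above this hunk:

    v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,
    \qquad
    w_t = w_{t-1} - \eta\left(\frac{\mathrm{lr}\cdot m_t}{\sqrt{v_t}+\epsilon} + \mathrm{wd}\cdot w_{t-1}\right).
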
@@ -441,7 +442,6 @@ static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs,
});
}

namespace adamw {
template<typename xpu>
static void GetScaleFloat(mshadow::Stream<xpu> *s, const TBlob &scale_blob, float *pScalef);

@@ -494,8 +494,8 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs,
MultiAdamWUpdate<xpu, Adam_single_precision, 5>
(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
}
} // namespace adamw

} // namespace adamw
} // namespace op
} // namespace mxnet

4 changes: 2 additions & 2 deletions src/operator/contrib/adamw.cc
@@ -27,6 +27,7 @@

namespace mxnet {
namespace op {
namespace adamw {

DMLC_REGISTER_PARAMETER(AdamWParam);
DMLC_REGISTER_PARAMETER(MultiAdamWParam);
@@ -118,14 +119,12 @@ the update is skipped.
"the update is skipped.")
.add_arguments(AdamWParam::__FIELDS__());

namespace adamw {
template<>
void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float *pScalef) {
MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType,
*pScalef = static_cast<float>(*scale_blob.dptr<DType>());
)
}
}

static std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
std::vector<std::string> ret;
@@ -255,5 +254,6 @@ the update is skipped.
.add_arguments(MultiAdamWParam::__FIELDS__());


} // namespace adamw
} // namespace op
} // namespace mxnet
4 changes: 2 additions & 2 deletions src/operator/contrib/adamw.cu
@@ -27,8 +27,8 @@

namespace mxnet {
namespace op {

namespace adamw{

template<>
void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float *pScalef) {
MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
@@ -40,7 +39,6 @@ void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float
*pScalef = static_cast<float>(scale);
})
}
}

NNVM_REGISTER_OP(_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", adamw::MPUpdate<gpu, AdamWUpdate<gpu>>);
@@ -54,5 +53,6 @@ NNVM_REGISTER_OP(_multi_adamw_update)
NNVM_REGISTER_OP(_multi_mp_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", adamw::multiMPUpdate<gpu, true>);

} // namespace adamw
} // namespace op
} // namespace mxnet
