Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
use float32 to store the reduction result of float16
Browse files Browse the repository at this point in the history
enable safe accumulation

fix bug

fix
  • Loading branch information
sxjscience committed May 21, 2019
1 parent 5bc08ce commit 53f31e1
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 129 deletions.
88 changes: 56 additions & 32 deletions src/operator/nn/layer_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
// Calculate mean
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
s, mean_data, req[0], workspace, in_data);
Tensor<xpu, 1, DType> mean_data_tensor = mean_data.FlatTo1D<xpu, DType>(s);
mean_data_tensor /= scalar<DType>(channel_size);
Expand All @@ -130,25 +130,25 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
const TBlob centered_out = outputs[0].reshape(red_src_shape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::square>(
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
s, std_data, req[0], workspace, centered_out);
Tensor<xpu, 1, DType> std_data_tensor = std_data.FlatTo1D<xpu, DType>(s);
std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
+ scalar<DType>(param.eps));
});
});
// Calculate data = data / std
BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
{outputs[0], outputs[layernorm::kStd]},
{kWriteTo}, {outputs[0]});
BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
{outputs[0], outputs[layernorm::kStd]},
{kWriteTo}, {outputs[0]});
// Calculate data = data * gamma
BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
{outputs[0], gamma},
{kWriteTo}, {outputs[0]});
BinaryBroadcastCompute<xpu, mshadow_op::mul>(attrs, ctx,
{outputs[0], gamma},
{kWriteTo}, {outputs[0]});
// Calculate data = data + beta
BinaryBroadcastCompute<xpu, op::mshadow_op::plus>(attrs, ctx,
{outputs[0], beta},
{kWriteTo}, {outputs[0]});
BinaryBroadcastCompute<xpu, mshadow_op::plus>(attrs, ctx,
{outputs[0], beta},
{kWriteTo}, {outputs[0]});
}

template<typename xpu>
Expand Down Expand Up @@ -233,19 +233,25 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2,
mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id());
// Compute normalized_data = (data - mean) / std
BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
{data, mean},
{kWriteTo}, {normalized_data});
BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
{normalized_data, std},
{kWriteTo}, {normalized_data});
BinaryBroadcastCompute<xpu, mshadow_op::minus>(attrs, ctx,
{data, mean},
{kWriteTo}, {normalized_data});
BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
{normalized_data, std},
{kWriteTo}, {normalized_data});
// Calculate grad_beta
if (req[2] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
ograd.reshape(red_exclude_src_shape));
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
ograd.reshape(red_exclude_src_shape));
} else {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
ograd.reshape(red_exclude_src_shape));
}
});
});
}
Expand All @@ -255,9 +261,15 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
if (req[1] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
ograd_mult.reshape(red_exclude_src_shape));
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
ograd_mult.reshape(red_exclude_src_shape));
} else {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
ograd_mult.reshape(red_exclude_src_shape));
}
});
});
}
Expand All @@ -274,9 +286,15 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
{kWriteTo}, {ograd_mult});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
} else {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
}
});
Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
red_out_tensor /= scalar<DType>(channel_size);
Expand All @@ -288,16 +306,22 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
{kWriteTo}, {ograd_mult});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
} else {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
}
});
Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
red_out_tensor /= scalar<DType>(- channel_size);
});
BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
{normalized_data, red_out},
{kAddTo}, {outputs[0]});
BinaryBroadcastCompute<xpu, mshadow_op::mul>(attrs, ctx,
{normalized_data, red_out},
{kAddTo}, {outputs[0]});
}
}

Expand Down
Loading

0 comments on commit 53f31e1

Please sign in to comment.