Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawnXuan committed Jan 23, 2025
1 parent 76b8675 commit db0ff1e
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions oneflow/user/ops/layer_norm_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,17 +274,17 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
CHECK_EQ_OR_RETURN(dy.shape(), x.shape());
CHECK_EQ_OR_RETURN(dy.shape(), x.shape()) << "dy and x shapes should be equal.";
const int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
CHECK_GT_OR_RETURN(begin_norm_axis, 0);
CHECK_GT_OR_RETURN(begin_norm_axis, 0) << "begin_norm_axis must be greater than 0.";
const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);
CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape);
CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape);
CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape) << "mean shape must match bn_param_shape.";
CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape) << "inv_variance shape must match bn_param_shape.";
dx->set_shape(dy.shape());
dx->set_is_dynamic(dy.is_dynamic());
if (ctx->has_input("_add_to_output", 0)) {
const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape());
CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape()) << "add_to_output shape must match dx shape.";
}

auto has_tensor = [ctx](const std::string& bn) -> bool {
Expand All @@ -300,8 +300,8 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
const bool has_beta_diff = has_tensor("beta_diff");
const bool has_gamma_diff = has_tensor("gamma_diff");
CHECK_GE_OR_RETURN(begin_params_axis, 1);
CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes());
CHECK_GE_OR_RETURN(begin_params_axis, 1) << "begin_params_axis must be greater than or equal to 1.";
CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes()) << "begin_params_axis must be less than the number of axes in dy shape.";
DimVector param_shape_dim_vec;
param_shape_dim_vec.insert(param_shape_dim_vec.end(),
dy.shape().dim_vec().cbegin() + begin_params_axis,
Expand Down

0 comments on commit db0ff1e

Please sign in to comment.