From aed33488286d557ac38e82e25a3f76eed9acef2d Mon Sep 17 00:00:00 2001 From: BHbean <1216808064@qq.com> Date: Wed, 10 Jul 2024 07:24:50 +0000 Subject: [PATCH] [Bug Fix] add broadcast mechanism before calculating PReLU when the input layout is NCHW [BUG FIX] avoid broadcast in certain cases [BUG FIX] reformat based on comments --- source/thead_rvv/fp32/prelu.c | 58 +++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/source/thead_rvv/fp32/prelu.c b/source/thead_rvv/fp32/prelu.c index d53b0f6c..cd281b7a 100644 --- a/source/thead_rvv/fp32/prelu.c +++ b/source/thead_rvv/fp32/prelu.c @@ -17,6 +17,7 @@ */ #include "rvv/rvv.h" +#include "reference/ref.h" int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha, struct csinn_tensor *output, struct csinn_prelu_params *params) @@ -53,22 +54,53 @@ int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha, output->layout = CSINN_LAYOUT_NC1HWC0; } } else if (input->layout == CSINN_LAYOUT_NCHW) { - for (int n = 0; n < input->dim[0]; ++n) { - for (int c = 0; c < input->dim[1]; ++c) { - float a = alpha_data[c]; - int inner_size = input->dim[2] * input->dim[3]; - while (inner_size > 0) { - int vl = vsetvl_e32m2(inner_size); - vfloat32m2_t _input = vle32_v_f32m2(input_data, vl); - vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl); - vfloat32m2_t _res = vfmul_vf_f32m2_m(_mask, _input, _input, a, vl); - vse32_v_f32m2(output_data, _res, vl); - input_data += vl; - output_data += vl; - inner_size -= vl; + if (alpha->dim[1] == csinn_tensor_size(alpha)) { + // simplify the calculation by avoiding broadcast + for (int n = 0; n < input->dim[0]; ++n) { + for (int c = 0; c < input->dim[1]; ++c) { + float a = alpha_data[c]; + int inner_size = input->dim[2] * input->dim[3]; + while (inner_size > 0) { + int vl = vsetvl_e32m2(inner_size); + vfloat32m2_t _input = vle32_v_f32m2(input_data, vl); + vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl); + 
vfloat32m2_t _res = vfmul_vf_f32m2_m(_mask, _input, _input, a, vl); + vse32_v_f32m2(output_data, _res, vl); + input_data += vl; + output_data += vl; + inner_size -= vl; + } } } + } else { + // broadcast alpha + int input_size = csinn_tensor_size(input); + float *alpha_data_b = shl_mem_alloc(input_size * sizeof(float)); + struct csinn_tensor *alpha_ = csinn_alloc_tensor(NULL); + csinn_tensor_copy(alpha_, input); + alpha_->data = alpha_data_b; + shl_ref_broadcast_to_shape_f32(alpha, alpha_, alpha_->dim, alpha_->dim_count); + alpha_data = (float *)alpha_->data; + + // calculation + while (input_size > 0) { + int vl = vsetvl_e32m2(input_size); + vfloat32m2_t _input = vle32_v_f32m2(input_data, vl); + vfloat32m2_t _a = vle32_v_f32m2(alpha_data, vl); + vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl); + vfloat32m2_t _res = vfmul_vv_f32m2_m(_mask, _input, _input, _a, vl); + vse32_v_f32m2(output_data, _res, vl); + input_data += vl; + alpha_data += vl; + output_data += vl; + input_size -= vl; + } + + // free memory and tensor + shl_mem_free(alpha_data_b); + csinn_free_tensor(alpha_); } + if (output->layout == CSINN_LAYOUT_NC1HWC0) { const int packn = csrr_vlenb() / sizeof(float); output->dim[1] *= packn;