[Bug Fix] add broadcast mechanism before calculating PReLU when the input layout is NCHW

[BUG FIX] avoid broadcast in certain cases

[BUG FIX] reformat based on comments
BHbean committed Aug 6, 2024
1 parent 3a7d230 commit aed3348
Showing 1 changed file with 45 additions and 13 deletions.
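For context before the diff: the fix splits the NCHW path in two. When alpha holds exactly one slope per channel (e.g. input [2, 3, 8, 8] with alpha [1, 3, 1, 1], so csinn_tensor_size(alpha) == alpha->dim[1] == 3), each H×W plane reuses a single scalar and no broadcast buffer is needed; otherwise alpha is first broadcast to the full input shape and applied element-wise. A minimal scalar sketch of that control flow (the helper name and flat-array interface are illustrative, not part of the SHL API):

#include <stddef.h>

// Scalar sketch of the fixed control flow (illustrative, not the RVV kernel).
// PReLU per element: y = x if x >= 0, else a * x.
static void prelu_nchw_sketch(const float *x, const float *a, size_t a_size,
                              float *y, size_t n, size_t c, size_t hw)
{
    if (a_size == c) {
        // Fast path: per-channel alpha, one scalar slope per H*W plane.
        for (size_t in = 0; in < n; ++in) {
            for (size_t ic = 0; ic < c; ++ic) {
                const float *xp = x + (in * c + ic) * hw;
                float *yp = y + (in * c + ic) * hw;
                for (size_t i = 0; i < hw; ++i) {
                    yp[i] = (xp[i] < 0.0f) ? xp[i] * a[ic] : xp[i];
                }
            }
        }
    } else {
        // General path: assumes a[] was already broadcast to the full
        // n*c*hw input shape, as shl_ref_broadcast_to_shape_f32 does in
        // the diff below.
        for (size_t i = 0; i < n * c * hw; ++i) {
            y[i] = (x[i] < 0.0f) ? x[i] * a[i] : x[i];
        }
    }
}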
58 changes: 45 additions & 13 deletions source/thead_rvv/fp32/prelu.c
@@ -17,6 +17,7 @@
  */
 
 #include "rvv/rvv.h"
+#include "reference/ref.h"
 
 int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
                        struct csinn_tensor *output, struct csinn_prelu_params *params)
@@ -53,22 +54,53 @@ int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
             output->layout = CSINN_LAYOUT_NC1HWC0;
         }
     } else if (input->layout == CSINN_LAYOUT_NCHW) {
-        for (int n = 0; n < input->dim[0]; ++n) {
-            for (int c = 0; c < input->dim[1]; ++c) {
-                float a = alpha_data[c];
-                int inner_size = input->dim[2] * input->dim[3];
-                while (inner_size > 0) {
-                    int vl = vsetvl_e32m2(inner_size);
-                    vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
-                    vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
-                    vfloat32m2_t _res = vfmul_vf_f32m2_m(_mask, _input, _input, a, vl);
-                    vse32_v_f32m2(output_data, _res, vl);
-                    input_data += vl;
-                    output_data += vl;
-                    inner_size -= vl;
+        if (alpha->dim[1] == csinn_tensor_size(alpha)) {
+            // simplify the calculation by avoiding broadcast
+            for (int n = 0; n < input->dim[0]; ++n) {
+                for (int c = 0; c < input->dim[1]; ++c) {
+                    float a = alpha_data[c];
+                    int inner_size = input->dim[2] * input->dim[3];
+                    while (inner_size > 0) {
+                        int vl = vsetvl_e32m2(inner_size);
+                        vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
+                        vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
+                        vfloat32m2_t _res = vfmul_vf_f32m2_m(_mask, _input, _input, a, vl);
+                        vse32_v_f32m2(output_data, _res, vl);
+                        input_data += vl;
+                        output_data += vl;
+                        inner_size -= vl;
+                    }
+                }
+            }
+        } else {
+            // broadcast alpha
+            int input_size = csinn_tensor_size(input);
+            float *alpha_data_b = shl_mem_alloc(input_size * sizeof(float));
+            struct csinn_tensor *alpha_ = csinn_alloc_tensor(NULL);
+            csinn_tensor_copy(alpha_, input);
+            alpha_->data = alpha_data_b;
+            shl_ref_broadcast_to_shape_f32(alpha, alpha_, alpha_->dim, alpha_->dim_count);
+            alpha_data = (float *)alpha_->data;
+
+            // calculation
+            while (input_size > 0) {
+                int vl = vsetvl_e32m2(input_size);
+                vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
+                vfloat32m2_t _a = vle32_v_f32m2(alpha_data, vl);
+                vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
+                vfloat32m2_t _res = vfmul_vv_f32m2_m(_mask, _input, _input, _a, vl);
+                vse32_v_f32m2(output_data, _res, vl);
+                input_data += vl;
+                alpha_data += vl;
+                output_data += vl;
+                input_size -= vl;
+            }
+
+            // free memory and tensor
+            shl_mem_free(alpha_data_b);
+            csinn_free_tensor(alpha_);
+        }

         if (output->layout == CSINN_LAYOUT_NC1HWC0) {
             const int packn = csrr_vlenb() / sizeof(float);
             output->dim[1] *= packn;
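A note on the two RVV idioms the diff relies on: the while (size > 0) loops are strip-mined, with vsetvl_e32m2 returning how many 32-bit lanes the hardware covers per pass (capped by the remaining count), so one loop handles body and tail without scalar cleanup; and the masked multiply (vfmul_vf_f32m2_m / vfmul_vv_f32m2_m) writes x * a only into the lanes flagged negative by vmflt_vf_f32m2_b16, while inactive lanes keep the maskedoff operand (_input), which is exactly the PReLU select. A self-contained sketch combining both, using the same pre-v1.0 intrinsic naming as the file (the function name is illustrative):

#include <riscv_vector.h>

// Strip-mined PReLU over a flat buffer with one scalar slope, mirroring
// the kernel's masked-multiply idiom: lanes where x < 0 become x * a,
// all other lanes pass through via the maskedoff operand (_x).
void prelu_flat_rvv_sketch(const float *x, float *y, float a, int n)
{
    while (n > 0) {
        int vl = vsetvl_e32m2(n);                                 // lanes this pass
        vfloat32m2_t _x = vle32_v_f32m2(x, vl);                   // load vl inputs
        vbool16_t _neg = vmflt_vf_f32m2_b16(_x, 0.0f, vl);        // mask: x < 0
        vfloat32m2_t _y = vfmul_vf_f32m2_m(_neg, _x, _x, a, vl);  // masked x * a
        vse32_v_f32m2(y, _y, vl);                                 // store vl outputs
        x += vl;
        y += vl;
        n -= vl;
    }
}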
