Skip to content

Commit

Permalink
Merge pull request oneapi-src#44 from Takumi-Honda/develop
Browse files Browse the repository at this point in the history
Fixup: convolution result (BWD_D) may be incorrect when stride_w>1, ic>16, and iw<17
  • Loading branch information
Takumi-Honda authored Nov 27, 2020
2 parents 676d012 + b08b7f2 commit a02206a
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/cpu/aarch64/jit_aarch64_sve_512_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1786,20 +1786,19 @@ void _jit_aarch64_sve_512_conv_bwd_data_kernel_f32<Vmm>::compute_loop_fma_core(
return prev_ofs;
};

auto bcast_load_30 = [&](int jj, int nb_oc_block, int aux_output_offset,
auto bcast_load_sw = [&](int jj, int nb_oc_block, int aux_output_offset,
int prev_ofs, int jj_end) {
if (((aux_output_offset & 0x3) == 0) && (aux_output_offset < LDRWMAX)
&& (aux_output_offset >= 0)) {
CGA64::ld1rw(zreg_inp_s(jj / stride_w, nb_oc_block), reg_p_all_ones,
CGA64::ld1rw(zreg_inp_s(jj, nb_oc_block), reg_p_all_ones,
xa::ptr(aux_reg_dst,
static_cast<int32_t>(aux_output_offset)));
} else {
if ((prev_ofs > -1) && ((aux_output_offset - prev_ofs) > 0)
&& ((aux_output_offset - prev_ofs) < LDRWMAX)
&& (((aux_output_offset - prev_ofs) & 0x3) == 0)) {

CGA64::ld1rw(zreg_inp_s(jj / stride_w, nb_oc_block),
reg_p_all_ones,
CGA64::ld1rw(zreg_inp_s(jj, nb_oc_block), reg_p_all_ones,
xa::ptr(reg_prev_bcast_addr,
static_cast<int32_t>(
aux_output_offset - prev_ofs)));
Expand All @@ -1817,8 +1816,8 @@ void _jit_aarch64_sve_512_conv_bwd_data_kernel_f32<Vmm>::compute_loop_fma_core(
reg_prev_bcast_addr, aux_reg_dst, ofs, reg_tmp_imm);
}

CGA64::ld1rw(zreg_inp_s(jj / stride_w, nb_oc_block),
reg_p_all_ones, xa::ptr(reg_prev_bcast_addr));
CGA64::ld1rw(zreg_inp_s(jj, nb_oc_block), reg_p_all_ones,
xa::ptr(reg_prev_bcast_addr));
prev_ofs = aux_output_offset;
}
}
Expand Down Expand Up @@ -1915,7 +1914,7 @@ void _jit_aarch64_sve_512_conv_bwd_data_kernel_f32<Vmm>::compute_loop_fma_core(
zreg_wei_s(wei_count));
} else {
int aux_output_offset = get_dst_offset(jj, oc, ki);
prev_ofs = bcast_load_30(jj, nb_ic_block,
prev_ofs = bcast_load_sw(jj % stride_w, nb_ic_block,
aux_output_offset, prev_ofs, jj_end);

CGA64::fmla(zreg_out_s(jj, ii), reg_p_all_ones,
Expand Down

0 comments on commit a02206a

Please sign in to comment.