Skip to content

Commit

Permalink
Add "if" statement for loop unrolling in rms kernel. (#27215)
Browse files Browse the repository at this point in the history
### Details:
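- *Add an "if" guard around the 4x-unrolled loop in the RMS kernel so that the loop is emitted only when the data size covers at least four full vectors; this fixes a
segmentation fault in the tiny-random-sd3 model*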

### Tickets:
 - *CVS-152057*
  • Loading branch information
mangguo321 authored Oct 28, 2024
1 parent 65dd174 commit 78864ca
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
33 changes: 18 additions & 15 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,25 @@ void jit_rms_kernel<isa>::generate() {
// x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
// sum(x^2)
align(16);
Xbyak::Label loop_4reg;
L(loop_4reg);
{
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 1);
vfmadd231ps(vmm_sum1, vmm_src, vmm_src);
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 2);
vfmadd231ps(vmm_sum2, vmm_src, vmm_src);
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 3);
vfmadd231ps(vmm_sum3, vmm_src, vmm_src);

add(reg_src, vec_size * m_jcp.src_prc.size() * 4);
dec(reg_size);
jnz(loop_4reg);
if ((m_jcp.data_size / (vec_size * 4)) != 0) {
Xbyak::Label loop_4reg;
L(loop_4reg);
{
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 1);
vfmadd231ps(vmm_sum1, vmm_src, vmm_src);
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 2);
vfmadd231ps(vmm_sum2, vmm_src, vmm_src);
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 3);
vfmadd231ps(vmm_sum3, vmm_src, vmm_src);

add(reg_src, vec_size * m_jcp.src_prc.size() * 4);
dec(reg_size);
jnz(loop_4reg);
}
}

// 1 ~ 3 vmm
for (size_t i = m_jcp.data_size / (vec_size * 4) * 4; i < m_jcp.data_size / vec_size; i++) {
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ const std::vector<std::vector<InputShape>> shapes{
{ov::Shape{1024 + 16 + 1}, ov::Shape{1024 + 16 + 1}}}
},
},
// small data size
{
// data shape
{ov::test::InputShape{ov::PartialShape{-1, -1, 31},
{ov::Shape{1, 8, 31}, ov::Shape{2, 3, 31}}}
},
// scale shape
{ov::test::InputShape{ov::PartialShape{31},
{ov::Shape{31}, ov::Shape{31}}}
},
},
// scale is scalar
{
// data shape
Expand Down

0 comments on commit 78864ca

Please sign in to comment.