Skip to content

Commit

Permalink
[CPU] Moved jit_uni_eltwise_generic x64 to another files
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 30, 2025
1 parent 9a46c97 commit 7c4ae63
Show file tree
Hide file tree
Showing 10 changed files with 1,307 additions and 1,299 deletions.
1,105 changes: 33 additions & 1,072 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp

Large diffs are not rendered by default.

64 changes: 1 addition & 63 deletions src/plugins/intel_cpu/src/nodes/eltwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,13 @@
#include "dnnl_postops_composer_legacy.h"
#include "executors/eltwise_list.hpp"
#include "nodes/executors/eltwise.hpp"
#include "nodes/kernels/jit_eltwise_call_args_ptrs.hpp"

#if defined(OPENVINO_ARCH_ARM64)
# include "kernels/aarch64/jit_uni_eltwise_generic.hpp"
#endif
#include "nodes/kernels/jit_eltwise_common.hpp"

namespace ov {
namespace intel_cpu {
namespace node {

#ifndef OPENVINO_ARCH_ARM64

struct jit_eltwise_params {
size_t inputs_number;
size_t input_size;

ov::element::Type src_prc[MAX_ELTWISE_INPUTS];
ov::element::Type dst_prc;

VectorDims dims;
VectorDims src_offsets[MAX_ELTWISE_INPUTS];
VectorDims dst_offsets;
VectorDims oc_offsets;

size_t src_size[MAX_ELTWISE_INPUTS];
size_t dst_size;
size_t oc_size;

size_t work_amount;
bool use_runtime_ptrs;
bool do_output_saturation;
};

struct jit_eltwise_call_args_indexes {
size_t indexes[MAX_ELTWISE_DIM_RANK];
};

class Eltwise;

struct jit_uni_eltwise_kernel {
void (*ker_)(const jit_eltwise_call_args_ptrs*, const jit_eltwise_call_args_indexes*);

void operator()(const jit_eltwise_call_args_ptrs* const_args, const jit_eltwise_call_args_indexes* indexes) {
assert(ker_);
ker_(const_args, indexes);
}

explicit jit_uni_eltwise_kernel(jit_eltwise_params jep) : ker_(nullptr), jep_(std::move(jep)) {}
virtual ~jit_uni_eltwise_kernel() {}

virtual void create_ker() = 0;

jit_eltwise_params jep_;
};

#endif

enum class EltwiseImplType { reference = 0, optimized = 1, optimizedShapeAgnostic = 2 };

class Eltwise : public Node {
public:
class IEltwiseExecutor {
Expand Down Expand Up @@ -218,16 +166,6 @@ class Eltwise : public Node {
std::shared_ptr<EltwiseExecutor> eltwiseExecPtr = nullptr;
};

class eltwise_precision_helper {
public:
static ov::element::Type get_precision(const size_t inputs_number,
const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
const std::vector<EltwiseData>& eltwise_data);

private:
static std::set<std::vector<element::Type>> get_supported_precisions(const Algorithm& algo);
};

} // namespace node
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@ using namespace Xbyak_aarch64;
using namespace dnnl::impl::cpu;
using namespace dnnl::impl::cpu::aarch64;

void jit_uni_eltwise_kernel::operator()(const node::jit_eltwise_call_args_ptrs* const_args,
const jit_eltwise_call_args_indexes* indexes) {
assert(ker_);
ker_(const_args, indexes);
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
jit_uni_eltwise_generic<isa>::jit_uni_eltwise_generic(jit_eltwise_params jep,
std::vector<EltwiseData> eltwise_data,
Expand All @@ -35,7 +29,8 @@ template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_uni_eltwise_generic<isa>::generate() {
preamble();

auto const exec_prc = eltwise_precision_helper::get_precision(jep_.inputs_number, jep_.src_prc, eltwise_data_);
static const std::vector<element::Type> exec_precisions_priority = {element::f16, element::f32};
auto const exec_prc = eltwise_precision_helper::get_precision(jep_.inputs_number, jep_.src_prc, eltwise_data_, exec_precisions_priority);

eltwise_emitter = create_eltwise_emitter(eltwise_data_.front(), exec_prc);
for (size_t i = 1; i < eltwise_data_.size(); ++i) {
Expand All @@ -52,11 +47,11 @@ void jit_uni_eltwise_generic<isa>::generate() {
for (size_t i = 0; i < jep.inputs_number; i++) {
ldr(start_to_offsets,
ptr(reg_const_params,
static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, src_offsets) +
static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, src_offsets) +
i * sizeof(size_t))));
ldr(get_src_reg(i),
ptr(reg_const_params,
static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, src_ptr[0]) + i * sizeof(size_t))));
static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, src_ptr[0]) + i * sizeof(size_t))));
XReg offset_reg = get_aux_gpr(0); // X_TMP_0;
XReg index_reg = get_aux_gpr(1); // X_TMP_1;
for (int j = 0; j < offset_count; j++) {
Expand All @@ -67,8 +62,8 @@ void jit_uni_eltwise_generic<isa>::generate() {
}

ldr(start_to_offsets,
ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_offsets))));
ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_ptr))));
ptr(reg_const_params, static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, dst_offsets))));
ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, dst_ptr))));
XReg offset_reg = get_aux_gpr(0); // X_TMP_0;
XReg index_reg = get_aux_gpr(1); // X_TMP_1;
for (int j = 0; j < offset_count; j++) {
Expand All @@ -80,7 +75,7 @@ void jit_uni_eltwise_generic<isa>::generate() {
mov(reg_oc_off, 0);

ldr(reg_work_amount,
ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, work_amount))));
ptr(reg_const_params, static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, work_amount))));
} else {
auto init_ptrs_with_offsets = [this, offset_count, param2](XReg pointer, const std::vector<size_t>& offsets) {
for (int j = 0; j < offset_count; j++) {
Expand All @@ -98,11 +93,11 @@ void jit_uni_eltwise_generic<isa>::generate() {
for (size_t i = 0; i < jep.inputs_number; i++) {
ldr(get_src_reg(i),
ptr(param1,
static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, src_ptr) + i * sizeof(size_t))));
static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, src_ptr) + i * sizeof(size_t))));
init_ptrs_with_offsets(get_src_reg(i), jep.src_offsets[i]);
}

ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_ptr))));
ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, dst_ptr))));
init_ptrs_with_offsets(reg_dst, jep.dst_offsets);

mov(reg_oc_off, 0);
Expand Down Expand Up @@ -777,80 +772,21 @@ void jit_uni_eltwise_generic<isa>::apply_post_ops() {
}
}

namespace {
template struct jit_uni_eltwise_generic<cpu_isa_t::asimd>;

} // namespace aarch64

namespace {
template <typename T>
struct SupportedPrecisions {
void operator()(std::set<std::vector<element::Type>>& precisions) {
precisions = T::get_supported_precisions();
}
};

static void set_intersection(const std::set<std::vector<element::Type>>& precisions1,
const std::set<std::vector<element::Type>>& precisions2,
std::set<std::vector<element::Type>>& intersection) {
std::map<element::Type, size_t> intersection_types;

for (auto it1 = precisions1.begin(); it1 != precisions1.end(); ++it1) {
for (auto it2 = precisions2.begin(); it2 != precisions2.end(); ++it2) {
const auto& it1_precisions = *it1;
// all element types are equal
if (it1_precisions[0] == (*it2)[0]) {
// first precisions size is used
intersection_types.emplace(it1_precisions[0], it1_precisions.size());
}
}
}

for (auto it = intersection_types.begin(); it != intersection_types.end(); ++it) {
intersection.insert(std::vector<element::Type>(it->second, it->first));
}
}
} // namespace

ov::element::Type eltwise_precision_helper::get_precision(const size_t inputs_number,
const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
const std::vector<EltwiseData>& eltwise_data) {
ov::element::Type exec_prc = ov::element::undefined;

const auto algorithm = eltwise_data.front().algo;
std::set<std::vector<element::Type>> supported_precision_intersection = get_supported_precisions(algorithm);

for (size_t i = 1; i < eltwise_data.size(); ++i) {
std::set<std::vector<element::Type>> prcs = get_supported_precisions(eltwise_data[i].algo);
std::set<std::vector<element::Type>> prcs_intersect = {};

set_intersection(supported_precision_intersection, prcs, prcs_intersect);

supported_precision_intersection = prcs_intersect;
}

static const element::Type exec_precisions_priority[] = {element::f16, element::f32};

for (const auto prc : exec_precisions_priority) {
if (std::any_of(supported_precision_intersection.begin(),
supported_precision_intersection.end(),
[&prc](const std::vector<element::Type>& precisions) {
return std::find(precisions.begin(), precisions.end(), prc) != precisions.end();
})) {
exec_prc = prc;
break;
}
}

for (size_t i = 0; i < inputs_number; i++) {
if (src_prc[i] != exec_prc) {
exec_prc = ov::element::f32;
break;
}
}

if (exec_prc == ov::element::undefined) {
OPENVINO_THROW("Eltwise jitter failed to specify execution precision for Eltwise node");
}

return exec_prc;
}
using namespace aarch64;

std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_precisions(const Algorithm& algo) {
std::set<std::vector<element::Type>> precisions;
Expand Down Expand Up @@ -909,8 +845,5 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
return precisions;
}

template struct jit_uni_eltwise_generic<cpu_isa_t::asimd>;

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

#include "emitters/plugin/aarch64/jit_eltwise_emitters.hpp"
#include "emitters/plugin/aarch64/jit_emitter.hpp"
#include "nodes/kernels/jit_eltwise_call_args_ptrs.hpp"
#include "nodes/kernels/jit_eltwise_common.hpp"
#include "utils/cpu_utils.hpp"
#include "utils/general_utils.h"

Expand All @@ -40,45 +40,6 @@ using namespace Xbyak_aarch64;
using namespace dnnl::impl::cpu;
using namespace dnnl::impl::cpu::aarch64;

struct jit_eltwise_params {
size_t inputs_number;
size_t input_size;

ov::element::Type src_prc[MAX_ELTWISE_INPUTS];
ov::element::Type dst_prc;

VectorDims dims;
VectorDims src_offsets[MAX_ELTWISE_INPUTS];
VectorDims dst_offsets;
VectorDims oc_offsets;

size_t src_size[MAX_ELTWISE_INPUTS];
size_t dst_size;
size_t oc_size;

size_t work_amount;
bool use_runtime_ptrs;
bool do_output_saturation;
};

struct jit_eltwise_call_args_indexes {
size_t indexes[MAX_ELTWISE_DIM_RANK];
};

struct jit_uni_eltwise_kernel {
void (*ker_)(const node::jit_eltwise_call_args_ptrs*, const jit_eltwise_call_args_indexes*);

void operator()(const node::jit_eltwise_call_args_ptrs* const_args, const jit_eltwise_call_args_indexes* indexes);

jit_uni_eltwise_kernel() {}
jit_uni_eltwise_kernel(jit_eltwise_params jep) : ker_(nullptr), jep_(std::move(jep)) {}
virtual ~jit_uni_eltwise_kernel() {}

virtual void create_ker() = 0;

jit_eltwise_params jep_;
};

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
public:
Expand All @@ -89,8 +50,6 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
std::vector<ov::intel_cpu::Type> ops_list,
dnnl::post_ops post_ops);

jit_uni_eltwise_generic() {}

void create_ker() override {
jit_generator::create_kernel();
ker_ = (decltype(ker_))jit_ker();
Expand Down Expand Up @@ -255,16 +214,6 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
std::vector<std::shared_ptr<jit_emitter>> post_op_emitters;
};

class eltwise_precision_helper {
public:
static ov::element::Type get_precision(const size_t inputs_number,
const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
const std::vector<EltwiseData>& eltwise_data);

private:
static std::set<std::vector<element::Type>> get_supported_precisions(const Algorithm& algo);
};

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov

This file was deleted.

Loading

0 comments on commit 7c4ae63

Please sign in to comment.