Skip to content

Commit

Permalink
[CPU] Moved jit_uni_eltwise_generic x64 to another files
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 30, 2025
1 parent 9a46c97 commit 149c684
Show file tree
Hide file tree
Showing 10 changed files with 1,299 additions and 1,292 deletions.
1,105 changes: 33 additions & 1,072 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp

Large diffs are not rendered by default.

64 changes: 1 addition & 63 deletions src/plugins/intel_cpu/src/nodes/eltwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,13 @@
#include "dnnl_postops_composer_legacy.h"
#include "executors/eltwise_list.hpp"
#include "nodes/executors/eltwise.hpp"
#include "nodes/kernels/jit_eltwise_call_args_ptrs.hpp"

#if defined(OPENVINO_ARCH_ARM64)
# include "kernels/aarch64/jit_uni_eltwise_generic.hpp"
#endif
#include "nodes/kernels/jit_eltwise_common.hpp"

namespace ov {
namespace intel_cpu {
namespace node {

#ifndef OPENVINO_ARCH_ARM64

// Parameters describing one eltwise kernel; jit_uni_eltwise_kernel keeps a copy of this struct as `jep_`.
// NOTE(review): comments marked "presumably" are inferred from the field names only - confirm against the
// code that fills this struct.
struct jit_eltwise_params {
// Number of entries of src_prc that eltwise_precision_helper::get_precision() takes into account.
size_t inputs_number;
// presumably the size of the input data - TODO confirm meaning and units
size_t input_size;

// Element types of the sources (one slot per possible input) and, presumably, of the destination.
ov::element::Type src_prc[MAX_ELTWISE_INPUTS];
ov::element::Type dst_prc;

// presumably the processed shape and the per-dimension offsets of each source, of the destination and of
// the "oc" data - TODO confirm
VectorDims dims;
VectorDims src_offsets[MAX_ELTWISE_INPUTS];
VectorDims dst_offsets;
VectorDims oc_offsets;

// presumably element sizes of each source / destination / "oc" data - TODO confirm units
size_t src_size[MAX_ELTWISE_INPUTS];
size_t dst_size;
size_t oc_size;

// presumably the amount of work handled by one kernel call - TODO confirm
size_t work_amount;
// NOTE(review): flag names suggest "pointers are provided at run time" and "saturate on output
// conversion" - verify against the kernel generator
bool use_runtime_ptrs;
bool do_output_saturation;
};

// Second argument block passed to the kernel entry point (see jit_uni_eltwise_kernel::ker_).
// Holds MAX_ELTWISE_DIM_RANK index values; presumably one per tensor dimension - TODO confirm.
struct jit_eltwise_call_args_indexes {
size_t indexes[MAX_ELTWISE_DIM_RANK];
};

class Eltwise;

/// Abstract base of the eltwise kernels: owns the kernel parameters and the entry point of the kernel code.
struct jit_uni_eltwise_kernel {
    /// Kernel entry point. Null after construction; presumably assigned by create_ker() in derived
    /// classes - confirm there.
    void (*ker_)(const jit_eltwise_call_args_ptrs*, const jit_eltwise_call_args_indexes*) = nullptr;

    /// Stores a copy of the kernel parameters; the entry point stays null.
    explicit jit_uni_eltwise_kernel(jit_eltwise_params jep) : jep_(std::move(jep)) {}
    virtual ~jit_uni_eltwise_kernel() = default;

    /// Implemented by concrete kernels.
    virtual void create_ker() = 0;

    /// Calls the kernel with both argument blocks. A null entry point is caught by assert only.
    void operator()(const jit_eltwise_call_args_ptrs* const_args, const jit_eltwise_call_args_indexes* indexes) {
        assert(ker_);
        (*ker_)(const_args, indexes);
    }

    /// Parameters the kernel was created with.
    jit_eltwise_params jep_;
};

#endif

enum class EltwiseImplType { reference = 0, optimized = 1, optimizedShapeAgnostic = 2 };

class Eltwise : public Node {
public:
class IEltwiseExecutor {
Expand Down Expand Up @@ -218,16 +166,6 @@ class Eltwise : public Node {
std::shared_ptr<EltwiseExecutor> eltwiseExecPtr = nullptr;
};

// Static helper that selects the execution precision for eltwise kernels.
class eltwise_precision_helper {
public:
// Returns the execution precision for the operations listed in eltwise_data.
//   inputs_number - number of entries of src_prc to take into account
//   src_prc       - precisions of the inputs (fixed-size array of MAX_ELTWISE_INPUTS slots)
//   eltwise_data  - descriptions of the eltwise operations
static ov::element::Type get_precision(const size_t inputs_number,
const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
const std::vector<EltwiseData>& eltwise_data);

private:
// Returns the precision sets supported for the given algorithm; each inner vector is one
// combination of element types.
static std::set<std::vector<element::Type>> get_supported_precisions(const Algorithm& algo);
};

} // namespace node
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@ using namespace Xbyak_aarch64;
using namespace dnnl::impl::cpu;
using namespace dnnl::impl::cpu::aarch64;

// Runs the kernel: forwards both argument blocks to the entry point stored in ker_.
// A null ker_ is caught by assert only, i.e. in builds where assertions are enabled.
void jit_uni_eltwise_kernel::operator()(const node::jit_eltwise_call_args_ptrs* const_args,
                                        const jit_eltwise_call_args_indexes* indexes) {
    assert(ker_);
    (*ker_)(const_args, indexes);
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
jit_uni_eltwise_generic<isa>::jit_uni_eltwise_generic(jit_eltwise_params jep,
std::vector<EltwiseData> eltwise_data,
Expand All @@ -35,7 +29,8 @@ template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_uni_eltwise_generic<isa>::generate() {
preamble();

auto const exec_prc = eltwise_precision_helper::get_precision(jep_.inputs_number, jep_.src_prc, eltwise_data_);
static const std::vector<element::Type> exec_precisions_priority = {element::f16, element::f32};
auto const exec_prc = eltwise_precision_helper::get_precision(jep_.inputs_number, jep_.src_prc, eltwise_data_, exec_precisions_priority);

eltwise_emitter = create_eltwise_emitter(eltwise_data_.front(), exec_prc);
for (size_t i = 1; i < eltwise_data_.size(); ++i) {
Expand Down Expand Up @@ -777,81 +772,19 @@ void jit_uni_eltwise_generic<isa>::apply_post_ops() {
}
}

namespace {
template struct jit_uni_eltwise_generic<cpu_isa_t::asimd>;

} // namespace aarch64

namespace {
// Functor that overwrites `precisions` with the result of T::get_supported_precisions().
// T must provide a static get_supported_precisions() whose result is assignable to
// std::set<std::vector<element::Type>>.
template <typename T>
struct SupportedPrecisions {
void operator()(std::set<std::vector<element::Type>>& precisions) {
precisions = T::get_supported_precisions();
}
};

/**
 * @brief Inserts into @p intersection the precision sets common to @p precisions1 and @p precisions2.
 *
 * Two entries match when their first element types are equal (entries are expected to consist of
 * identical element types). For every matching type one vector is produced: it is filled with that
 * type and has the size of the first matching entry of @p precisions1.
 * Empty entries carry no type and are skipped instead of being indexed.
 *
 * @param precisions1   first collection of precision sets; defines the size of the produced vectors
 * @param precisions2   second collection of precision sets
 * @param intersection  output set; it is not cleared, results are inserted into it
 */
static void set_intersection(const std::set<std::vector<element::Type>>& precisions1,
                             const std::set<std::vector<element::Type>>& precisions2,
                             std::set<std::vector<element::Type>>& intersection) {
    // element type -> number of inputs; emplace keeps the first size registered for a type
    std::map<element::Type, size_t> intersection_types;

    for (const auto& lhs : precisions1) {
        if (lhs.empty()) {
            continue;
        }
        for (const auto& rhs : precisions2) {
            // all element types of an entry are equal, so comparing the first ones is sufficient
            if (!rhs.empty() && lhs[0] == rhs[0]) {
                intersection_types.emplace(lhs[0], lhs.size());
                // any further match for this lhs would try to emplace the very same pair
                break;
            }
        }
    }

    for (const auto& entry : intersection_types) {
        intersection.insert(std::vector<element::Type>(entry.second, entry.first));
    }
}
} // namespace

/**
 * @brief Chooses the execution precision for the eltwise operations listed in @p eltwise_data.
 *
 * The precision sets supported by every operation are intersected and the first of {f16, f32} found
 * in the intersection is taken. If any of the first @p inputs_number input precisions differs from
 * that choice (including the case where nothing was chosen), f32 is used instead.
 *
 * @param inputs_number  number of entries of @p src_prc to take into account; must not exceed
 *                       MAX_ELTWISE_INPUTS
 * @param src_prc        precisions of the inputs
 * @param eltwise_data   descriptions of the operations; must not be empty
 * @return the selected execution precision
 * @throws via OPENVINO_THROW when the arguments are invalid or no precision could be selected
 */
ov::element::Type eltwise_precision_helper::get_precision(const size_t inputs_number,
                                                          const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
                                                          const std::vector<EltwiseData>& eltwise_data) {
    // front() on an empty vector and src_prc[i] past the end of the array are undefined behaviour
    if (eltwise_data.empty()) {
        OPENVINO_THROW("Eltwise jitter requires at least one eltwise operation to specify execution precision");
    }
    if (inputs_number > static_cast<size_t>(MAX_ELTWISE_INPUTS)) {
        OPENVINO_THROW("Eltwise jitter got more inputs than MAX_ELTWISE_INPUTS");
    }

    // keep only the precisions that every operation supports
    std::set<std::vector<element::Type>> supported_precision_intersection =
        get_supported_precisions(eltwise_data.front().algo);

    for (size_t i = 1; i < eltwise_data.size(); ++i) {
        const std::set<std::vector<element::Type>> prcs = get_supported_precisions(eltwise_data[i].algo);
        std::set<std::vector<element::Type>> prcs_intersect;

        set_intersection(supported_precision_intersection, prcs, prcs_intersect);

        // take over the result without copying it
        supported_precision_intersection.swap(prcs_intersect);
    }

    // candidates, most preferred first
    static const element::Type exec_precisions_priority[] = {element::f16, element::f32};

    ov::element::Type exec_prc = ov::element::undefined;
    for (const auto& prc : exec_precisions_priority) {
        const bool supported = std::any_of(supported_precision_intersection.begin(),
                                           supported_precision_intersection.end(),
                                           [&prc](const std::vector<element::Type>& precisions) {
                                               return std::find(precisions.begin(), precisions.end(), prc) !=
                                                      precisions.end();
                                           });
        if (supported) {
            exec_prc = prc;
            break;
        }
    }

    // an input whose precision differs from the chosen one forces execution in f32
    for (size_t i = 0; i < inputs_number; i++) {
        if (src_prc[i] != exec_prc) {
            exec_prc = ov::element::f32;
            break;
        }
    }

    if (exec_prc == ov::element::undefined) {
        OPENVINO_THROW("Eltwise jitter failed to specify execution precision for Eltwise node");
    }

    return exec_prc;
}

std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_precisions(const Algorithm& algo) {
std::set<std::vector<element::Type>> precisions;

Expand Down Expand Up @@ -909,8 +842,5 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
return precisions;
}

template struct jit_uni_eltwise_generic<cpu_isa_t::asimd>;

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

#include "emitters/plugin/aarch64/jit_eltwise_emitters.hpp"
#include "emitters/plugin/aarch64/jit_emitter.hpp"
#include "nodes/kernels/jit_eltwise_call_args_ptrs.hpp"
#include "nodes/kernels/jit_eltwise_common.hpp"
#include "utils/cpu_utils.hpp"
#include "utils/general_utils.h"

Expand All @@ -40,45 +40,6 @@ using namespace Xbyak_aarch64;
using namespace dnnl::impl::cpu;
using namespace dnnl::impl::cpu::aarch64;

// Parameters describing one eltwise kernel; jit_uni_eltwise_kernel keeps a copy of this struct as `jep_`.
// NOTE(review): comments marked "presumably" are inferred from the field names only - confirm against the
// code that fills this struct.
struct jit_eltwise_params {
// Number of entries of src_prc that eltwise_precision_helper::get_precision() takes into account.
size_t inputs_number;
// presumably the size of the input data - TODO confirm meaning and units
size_t input_size;

// Element types of the sources (one slot per possible input) and, presumably, of the destination.
ov::element::Type src_prc[MAX_ELTWISE_INPUTS];
ov::element::Type dst_prc;

// presumably the processed shape and the per-dimension offsets of each source, of the destination and of
// the "oc" data - TODO confirm
VectorDims dims;
VectorDims src_offsets[MAX_ELTWISE_INPUTS];
VectorDims dst_offsets;
VectorDims oc_offsets;

// presumably element sizes of each source / destination / "oc" data - TODO confirm units
size_t src_size[MAX_ELTWISE_INPUTS];
size_t dst_size;
size_t oc_size;

// presumably the amount of work handled by one kernel call - TODO confirm
size_t work_amount;
// NOTE(review): flag names suggest "pointers are provided at run time" and "saturate on output
// conversion" - verify against the kernel generator
bool use_runtime_ptrs;
bool do_output_saturation;
};

// Second argument block passed to the kernel entry point (see jit_uni_eltwise_kernel::ker_).
// Holds MAX_ELTWISE_DIM_RANK index values; presumably one per tensor dimension - TODO confirm.
struct jit_eltwise_call_args_indexes {
size_t indexes[MAX_ELTWISE_DIM_RANK];
};

struct jit_uni_eltwise_kernel {
void (*ker_)(const node::jit_eltwise_call_args_ptrs*, const jit_eltwise_call_args_indexes*);

void operator()(const node::jit_eltwise_call_args_ptrs* const_args, const jit_eltwise_call_args_indexes* indexes);

jit_uni_eltwise_kernel() {}
jit_uni_eltwise_kernel(jit_eltwise_params jep) : ker_(nullptr), jep_(std::move(jep)) {}
virtual ~jit_uni_eltwise_kernel() {}

virtual void create_ker() = 0;

jit_eltwise_params jep_;
};

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
public:
Expand Down Expand Up @@ -255,16 +216,6 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
std::vector<std::shared_ptr<jit_emitter>> post_op_emitters;
};

// Static helper that selects the execution precision for eltwise kernels.
class eltwise_precision_helper {
public:
// Returns the execution precision for the operations listed in eltwise_data.
//   inputs_number - number of entries of src_prc to take into account
//   src_prc       - precisions of the inputs (fixed-size array of MAX_ELTWISE_INPUTS slots)
//   eltwise_data  - descriptions of the eltwise operations
static ov::element::Type get_precision(const size_t inputs_number,
const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
const std::vector<EltwiseData>& eltwise_data);

private:
// Returns the precision sets supported for the given algorithm; each inner vector is one
// combination of element types.
static std::set<std::vector<element::Type>> get_supported_precisions(const Algorithm& algo);
};

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov

This file was deleted.

90 changes: 90 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/jit_eltwise_common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_eltwise_common.hpp"

namespace ov {
namespace intel_cpu {

/**
 * @brief Inserts into @p intersection the precision sets common to @p precisions1 and @p precisions2.
 *
 * Two entries match when their first element types are equal (entries are expected to consist of
 * identical element types). For every matching type one vector is produced: it is filled with that
 * type and has the size of the first matching entry of @p precisions1.
 * Empty entries carry no type and are skipped instead of being indexed.
 *
 * @param precisions1   first collection of precision sets; defines the size of the produced vectors
 * @param precisions2   second collection of precision sets
 * @param intersection  output set; it is not cleared, results are inserted into it
 */
static void set_intersection(const std::set<std::vector<element::Type>>& precisions1,
                             const std::set<std::vector<element::Type>>& precisions2,
                             std::set<std::vector<element::Type>>& intersection) {
    // element type -> number of inputs; emplace keeps the first size registered for a type
    std::map<element::Type, size_t> intersection_types;

    for (const auto& lhs : precisions1) {
        if (lhs.empty()) {
            continue;
        }
        for (const auto& rhs : precisions2) {
            // all element types of an entry are equal, so comparing the first ones is sufficient
            if (!rhs.empty() && lhs[0] == rhs[0]) {
                intersection_types.emplace(lhs[0], lhs.size());
                // any further match for this lhs would try to emplace the very same pair
                break;
            }
        }
    }

    for (const auto& entry : intersection_types) {
        intersection.insert(std::vector<element::Type>(entry.second, entry.first));
    }
}

/**
 * @brief Chooses the execution precision for the eltwise operations listed in @p eltwise_data.
 *
 * The precision sets supported by every operation are intersected. The first entry of
 * @p exec_precisions_priority that is both part of this intersection and equal to the precision of
 * the first input is taken. If any of the first @p inputs_number input precisions differs from that
 * choice (including the case where nothing was chosen), f32 is used instead.
 *
 * @param inputs_number             number of entries of @p src_prc to take into account; must not
 *                                  exceed MAX_ELTWISE_INPUTS
 * @param src_prc                   precisions of the inputs
 * @param eltwise_data              descriptions of the operations; must not be empty
 * @param exec_precisions_priority  candidate execution precisions, most preferred first
 * @return the selected execution precision
 * @throws via OPENVINO_ASSERT / OPENVINO_THROW when the arguments are invalid, when an operation
 *         reports a precision set with differing element types, or when no precision could be selected
 */
ov::element::Type eltwise_precision_helper::get_precision(const size_t inputs_number,
                                                          const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
                                                          const std::vector<EltwiseData>& eltwise_data,
                                                          const std::vector<element::Type>& exec_precisions_priority) {
    // front() on an empty vector and src_prc[i] past the end of the array are undefined behaviour
    OPENVINO_ASSERT(!eltwise_data.empty(), "Eltwise jitter requires at least one eltwise operation");
    OPENVINO_ASSERT(inputs_number <= static_cast<size_t>(MAX_ELTWISE_INPUTS),
                    "Eltwise jitter got more inputs than MAX_ELTWISE_INPUTS");

    // for element-wise operations all inputs must have the same precision
    auto has_same_precision = [](const std::vector<element::Type>& precisions) {
        return std::all_of(precisions.begin(), precisions.end(), [&precisions](const element::Type precision) {
            return precision == precisions[0];
        });
    };

    std::set<std::vector<element::Type>> supported_precision_intersection =
        get_supported_precisions(eltwise_data.front().algo);

    assert(std::all_of(supported_precision_intersection.begin(),
                       supported_precision_intersection.end(),
                       has_same_precision));

    // keep only the precisions that every operation supports
    for (size_t i = 1; i < eltwise_data.size(); ++i) {
        const std::set<std::vector<element::Type>> prcs = get_supported_precisions(eltwise_data[i].algo);
        std::set<std::vector<element::Type>> prcs_intersect;

        OPENVINO_ASSERT(std::all_of(prcs.begin(), prcs.end(), has_same_precision),
                        "for element-wise nodes all precisions have to be equal");

        set_intersection(supported_precision_intersection, prcs, prcs_intersect);

        // take over the result without copying it
        supported_precision_intersection.swap(prcs_intersect);
    }

    ov::element::Type exec_prc = ov::element::undefined;
    for (const auto& prc : exec_precisions_priority) {
        // a candidate is eligible only if it is the precision of the first input; this does not depend
        // on the supported sets, so it is checked once per candidate instead of once per set
        if (!(src_prc[0] == prc)) {
            continue;
        }
        const bool supported = std::any_of(supported_precision_intersection.begin(),
                                           supported_precision_intersection.end(),
                                           [&prc](const std::vector<element::Type>& precisions) {
                                               return std::find(precisions.begin(), precisions.end(), prc) !=
                                                      precisions.end();
                                           });
        if (supported) {
            exec_prc = prc;
            break;
        }
    }

    // an input whose precision differs from the chosen one forces execution in f32
    for (size_t i = 0; i < inputs_number; i++) {
        if (src_prc[i] != exec_prc) {
            exec_prc = ov::element::f32;
            break;
        }
    }

    if (exec_prc == ov::element::undefined) {
        OPENVINO_THROW("Eltwise jitter failed to specify execution precision for Eltwise node");
    }

    return exec_prc;
}

} // namespace intel_cpu
} // namespace ov
Loading

0 comments on commit 149c684

Please sign in to comment.