Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
alvoron committed Jan 22, 2025
1 parent eb17cb3 commit 20218d2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
14 changes: 11 additions & 3 deletions src/plugins/intel_cpu/src/nodes/executors/acl/acl_pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ bool AclPoolingExecutor::isSupported(const TensorInfo& srcTensorInfo,
pool_info->pool_type = pool_type;
pool_info->exclude_padding = exclude_padding;
if (dstDescsSize > 1) {
TensorInfo indTensorInfo = TensorInfo(shapeCast(*indDims), 1, arm_compute::DataType::U32, dataLayout);
auto indShape = shapeCast(*indDims);
if (dataLayout == arm_compute::DataLayout::NHWC) {
changeLayoutToNH_C({&indShape});
}
TensorInfo indTensorInfo = TensorInfo(indShape, 1, arm_compute::DataType::U32, dataLayout);
arm_compute::Status s =
arm_compute::NEPoolingLayer::validate(&srcTensorInfo, &dstTensorInfo, *pool_info, &indTensorInfo);
if (!s) {
Expand Down Expand Up @@ -176,9 +180,13 @@ bool AclPoolingExecutor::init(const PoolingAttrs& poolingAttrs,
nullptr))
return false;
auto indDims = dstDescs[1]->getShape().getStaticDims();
TensorInfo indTensorInfo = TensorInfo(shapeCast(indDims),
auto indShape = shapeCast(indDims);
if (dstTensorInfo.data_layout() == arm_compute::DataLayout::NHWC) {
changeLayoutToNH_C({&indShape});
}
TensorInfo indTensorInfo = TensorInfo(indShape,
1,
precisionToAclDataType(dstDescs[1]->getPrecision()),
arm_compute::DataType::U32,
getAclDataLayoutByMemoryDesc(dstDescs[1]));
indTensor.allocator()->init(indTensorInfo);
exec_func = [this, pool_info]() -> std::unique_ptr<IFunction> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class AclPoolingExecutorBuilder : public PoolingExecutorBuilder {
return false;
}

if (dstDescs.size() == 2u && dstDescs[1]->getPrecision() != ov::element::u32) {
if (dstDescs.size() == 2u &&
(dstDescs[1]->getPrecision() != ov::element::u32 && dstDescs[1]->getPrecision() != ov::element::i32)) {
DEBUG_LOG("AclPoolingExecutor supports U32 as indices precisions only. ",
"Passed indices precision: ",
dstDescs[1]->getPrecision());
Expand Down

0 comments on commit 20218d2

Please sign in to comment.