|
| 1 | +// Copyright (C) 2018-2025 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | +#pragma once |
| 5 | + |
| 6 | +#include "convolution_shape_inference_util.hpp" |
| 7 | +#include "openvino/op/convolution.hpp" |
| 8 | +#include "ov_ops/convolution.hpp" |
| 9 | +#include "utils.hpp" |
| 10 | + |
| 11 | +namespace ov { |
| 12 | +namespace op { |
| 13 | + |
| 14 | +template <class TOp, |
| 15 | + class TShape, |
| 16 | + class TRShape = result_shape_t<TShape>, |
| 17 | + typename std::enable_if<std::is_same<TOp, internal::Convolution>::value>::type* = nullptr> |
| 18 | +std::vector<TRShape> shape_infer(const TOp* op, |
| 19 | + const std::vector<TShape>& input_shapes, |
| 20 | + CoordinateDiff& pads_begin, |
| 21 | + CoordinateDiff& pads_end) { |
| 22 | + NODE_VALIDATION_CHECK(op, input_shapes.size() >= 2); |
| 23 | + using namespace ov::util; |
| 24 | + |
| 25 | + const auto num_spatial = convolution::calculate_num_spatial(op, input_shapes); |
| 26 | + |
| 27 | + auto output_shapes = std::vector<TRShape>(1); |
| 28 | + auto& output_shape = output_shapes[0]; |
| 29 | + if (num_spatial != util::num_spatial_undefined) { |
| 30 | + const auto& data_shape = input_shapes[0]; |
| 31 | + const auto& filters_shape = input_shapes[1]; |
| 32 | + const auto data_rank = data_shape.rank(); |
| 33 | + const auto filters_rank = filters_shape.rank(); |
| 34 | + |
| 35 | + if (op->get_groups() > 1) { |
| 36 | + convolution::resize_empty_padding(num_spatial, pads_begin, pads_end); |
| 37 | + if (is_attr_validation_required(op)) { |
| 38 | + convolution::validate::data_shape(op, data_shape); |
| 39 | + |
| 40 | + NODE_VALIDATION_CHECK(op, |
| 41 | + data_rank.compatible(filters_rank - 1), |
| 42 | + "Data batch and filters rank do not match (data batch shape: ", |
| 43 | + data_shape, |
| 44 | + ", filters shape: ", |
| 45 | + filters_shape, |
| 46 | + ")."); |
| 47 | + |
| 48 | + convolution::validate::common_attributes(op, num_spatial, pads_begin, pads_end); |
| 49 | + } |
| 50 | + convolution::apply_padding(op, data_shape, filters_shape, pads_begin, pads_end); |
| 51 | + |
| 52 | + output_shape.reserve(util::spatial_dim_offset + num_spatial); |
| 53 | + output_shape.emplace_back(data_rank.is_static() ? data_shape[0] : dim::inf_bound); |
| 54 | + |
| 55 | + if (filters_rank.is_static()) { |
| 56 | + auto groups = filters_shape[0]; |
| 57 | + |
| 58 | + if (data_rank.is_static() && filters_shape[2].is_static()) { |
| 59 | + NODE_VALIDATION_CHECK( |
| 60 | + op, |
| 61 | + groups.merge(groups, groups, (data_shape[1] / filters_shape[2].get_length())), |
| 62 | + "Input channels dimension of data batch is incompatible with filter groups or input channels."); |
| 63 | + } |
| 64 | + |
| 65 | + groups *= filters_shape[1]; |
| 66 | + output_shape.push_back(std::move(groups)); |
| 67 | + } else { |
| 68 | + output_shape.emplace_back(dim::inf_bound); |
| 69 | + } |
| 70 | + } else { |
| 71 | + convolution::resize_empty_padding(num_spatial, pads_begin, pads_end); |
| 72 | + convolution::validate::filter_shape(op, filters_shape, data_shape); |
| 73 | + if (is_attr_validation_required(op)) { |
| 74 | + convolution::validate::data_shape(op, data_shape); |
| 75 | + convolution::validate::common_attributes(op, num_spatial, pads_begin, pads_end); |
| 76 | + } |
| 77 | + convolution::apply_padding(op, data_shape, filters_shape, pads_begin, pads_end); |
| 78 | + |
| 79 | + output_shape.reserve(util::spatial_dim_offset + num_spatial); |
| 80 | + output_shape.emplace_back(data_rank.is_static() ? data_shape[0] : dim::inf_bound); |
| 81 | + output_shape.emplace_back(filters_rank.is_static() ? filters_shape[0] : dim::inf_bound); |
| 82 | + } |
| 83 | + |
| 84 | + convolution::append_spatial_shape(op, data_shape, filters_shape, pads_begin, pads_end, output_shape); |
| 85 | + } else { |
| 86 | + output_shape = PartialShape::dynamic(); |
| 87 | + } |
| 88 | + |
| 89 | + return output_shapes; |
| 90 | +} |
| 91 | + |
| 92 | +} // namespace op |
| 93 | +} // namespace ov |
0 commit comments