Skip to content

Commit

Permalink
Support expand (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai authored May 15, 2023
1 parent 9c39278 commit 43aa5b3
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 0 deletions.
18 changes: 18 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,23 @@ bool IsSupportedDataType(int32_t data_type) {
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != supported_data_types.end();
}

// Checks ONNX multidirectional (numpy-style) broadcastability of two shapes:
// after right-aligning them, every overlapping dimension pair must either be
// equal or contain a 1. Dimensions beyond the shorter rank always broadcast.
bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,
                                      std::vector<int64_t>& shape_b,
                                      const logging::Logger& logger) {
  const int64_t rank_a = static_cast<int64_t>(shape_a.size());
  const int64_t rank_b = static_cast<int64_t>(shape_b.size());
  const int64_t overlap = std::min(rank_a, rank_b);
  // Walk the trailing dimensions (right alignment), offset 1 = last dim.
  for (int64_t offset = 1; offset <= overlap; ++offset) {
    const int64_t dim_a = shape_a[rank_a - offset];
    const int64_t dim_b = shape_b[rank_b - offset];
    // Broadcastable tensors must either have each dimension the same size or equal to one.
    if (dim_a != dim_b && dim_a != 1 && dim_b != 1) {
      return false;
    }
  }
  return true;
}

} // namespace webnn
} // namespace onnxruntime
18 changes: 18 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ namespace webnn {

bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const logging::Logger& logger);

// Returns a human-readable rendering of a shape, e.g. "[1, 3, 224, 224]".
// Takes the shape by const reference (the original non-const reference
// prevented passing temporaries and wrongly implied mutation).
// @param shape  Dimension values; any ostream-able element type.
// @return The shape formatted as a bracketed, comma-separated list.
template <typename T>
std::string GetShapeString(const std::vector<T>& shape) {
  std::stringstream shape_info;
  shape_info << "[";
  for (size_t i = 0; i < shape.size(); i++) {
    if (i != 0) {
      shape_info << ", ";
    }
    shape_info << shape[i];
  }
  shape_info << "]";
  return shape_info.str();
}

bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
Expand Down Expand Up @@ -53,6 +67,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Conv", "conv2d"},
{"ConvTranspose", "convTranspose2d"},
{"Concat", "concat"},
{"Expand", "expand"},
{"Gemm", "gemm"},
{"MatMul", "matmul"},
{"GlobalAveragePool", "averagePool2d"},
Expand All @@ -77,5 +92,8 @@ constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 3> supported_data_typ

bool IsSupportedDataType(int32_t data_type);

// Returns true if shape_a and shape_b satisfy ONNX multidirectional
// (numpy-style) broadcasting: right-aligned overlapping dimensions must be
// equal or 1. The logger parameter is reserved for diagnostic output.
bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,
                                      std::vector<int64_t>& shape_b,
                                      const logging::Logger& logger);
} // namespace webnn
} // namespace onnxruntime
139 changes: 139 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/framework/tensorprotoutils.h"
#include "core/optimizer/initializer.h"
#include "core/providers/common.h"
#include "core/providers/cpu/tensor/reshape_helper.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace webnn {

// Maps the ONNX Expand operator onto the WebNN "expand" operation.
class ExpandOpBuilder : public BaseOpBuilder {
  // Add operator related.
 public:
  // Marks the "shape" input (input 1) as consumed at build time so it is not
  // also materialized as a WebNN constant operand.
  void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;

 private:
  // Emits the WebNN expand op for `node` into the model under construction.
  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

  // Operator support related.
 private:
  // Checks whether this particular Expand node can be handled by the WebNN EP.
  bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                         const logging::Logger& logger) const override;
};

// Reads the "shape" input of an Expand node from its constant initializer
// into `shape`. The initializer must be a non-empty 1-D INT64 tensor.
// Returns false (after logging) when validation or unpacking fails.
//
// Fix vs. original: metadata (rank/emptiness/dtype) is now validated BEFORE
// unpacking the raw data, so invalid tensors are rejected cheaply and the
// int64 reinterpretation below is only reached for tensors known to be INT64.
bool GetExpandShape(const onnx::TensorProto& tensor, std::vector<int64_t>& shape, const logging::Logger& logger) {
  const auto& dims = tensor.dims();
  if (dims.empty() || dims[0] == 0) {
    LOGS(logger, VERBOSE) << "The shape of expand cannot be empty.";
    return false;
  }
  if (dims.size() != 1) {
    LOGS(logger, VERBOSE) << "The shape of expand must be 1D.";
    return false;
  }
  if (tensor.data_type() != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
    LOGS(logger, VERBOSE) << "The shape element data type must be INT64.";
    return false;
  }
  // UnpackInitializerData handles raw/external data representations.
  std::vector<uint8_t> unpacked_tensor;
  auto status = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor);
  if (!status.IsOK()) {
    LOGS(logger, ERROR) << "Error while unpacking shape: " << status.ErrorMessage();
    return false;
  }
  const int64_t* shape_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
  shape = std::vector<int64_t>{shape_data, shape_data + dims[0]};
  return true;
}

// Add operator related.

// The second input ("shape") is read directly as an initializer when building
// the op, so skip emitting it as a separate WebNN constant operand.
void ExpandOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
  model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
}

// Emits a WebNN "expand" op: reads the constant target shape, converts it to
// int32, left-pads it with 1s up to the input's rank, and registers the
// resulting operand under the node's output name.
Status ExpandOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                              const Node& node,
                                              const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();

  // The target shape comes from the constant "shape" initializer (input 1).
  const auto& initializers(model_builder.GetInitializerTensors());
  const auto& shape_tensor = *initializers.at(input_defs[1]->Name());
  std::vector<int64_t> raw_shape;
  ORT_RETURN_IF_NOT(GetExpandShape(shape_tensor, raw_shape, logger), "Cannot get shape.");

  std::vector<int64_t> input_shape;
  ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input's shape.");

  // WebNN expects the new shape as 32-bit values; SafeInt guards the narrowing.
  std::vector<int32_t> new_shape;
  new_shape.reserve(raw_shape.size());
  for (const int64_t dim : raw_shape) {
    new_shape.push_back(SafeInt<int32_t>(dim));
  }
  if (new_shape.size() < input_shape.size()) {
    // Enlarge new shape to input.rank, right aligned with leading ones.
    new_shape.insert(new_shape.begin(), input_shape.size() - new_shape.size(), 1);
  }

  emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
  emscripten::val output =
      model_builder.GetBuilder().call<emscripten::val>("expand",
                                                       input, emscripten::val::array(new_shape));
  model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
  return Status::OK();
}

// Operator support related.

// Decides whether this Expand node is supported by the WebNN EP.
// Requirements checked: the target shape is a constant 1-D INT64 initializer,
// the input shape is known and non-scalar, the target rank does not exceed the
// input rank, and the input is broadcastable to the target shape.
bool ExpandOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                                        const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
  const auto& shape_name = input_defs[1]->Name();
  // The target shape must be known at graph-build time.
  if (!Contains(initializers, shape_name)) {
    LOGS(logger, VERBOSE) << "The shape must be a constant initializer.";
    return false;
  }

  std::vector<int64_t> new_shape;
  const auto& shape_tensor = *initializers.at(shape_name);
  if (!GetExpandShape(shape_tensor, new_shape, logger)) {
    LOGS(logger, VERBOSE) << "Cannot get shape.";
    return false;
  }

  std::vector<int64_t> input_shape;
  if (!GetShape(*input_defs[0], input_shape, logger)) {
    LOGS(logger, VERBOSE) << "Cannot get input's shape.";
    return false;
  }

  if (input_shape.empty()) {
    LOGS(logger, VERBOSE) << "Expand does not support empty input's shape.";
    return false;
  }

  if (new_shape.size() > input_shape.size()) {
    LOGS(logger, VERBOSE) << "The size of shape must be less than or equal to the rank of input.";
    // Bug fix: the original logged this violation but fell through without
    // rejecting, so nodes with target rank > input rank could be wrongly
    // claimed as supported (AddToModelBuilderImpl never pads the input side).
    return false;
  }

  if (!IsValidMultidirectionalBroadcast(input_shape, new_shape, logger)) {
    LOGS(logger, VERBOSE) << "The input cannot expand to shape " << GetShapeString(new_shape);
    return false;
  }

  return true;
}

// Registers an ExpandOpBuilder instance and maps `op_type` to it.
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto builder = std::make_unique<ExpandOpBuilder>();
  // The map stores a non-owning pointer; ownership lives in `builders`.
  op_registrations.op_builder_map.emplace(op_type, builder.get());
  op_registrations.builders.push_back(std::move(builder));
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateConcatOpBuilder("Concat", op_registrations);
}

{ // Expand
CreateExpandOpBuilder("Expand", op_registrations);
}

{ // Gemm/MatMul
CreateGemmOpBuilder("Gemm", op_registrations);
CreateGemmOpBuilder("MatMul", op_registrations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down

0 comments on commit 43aa5b3

Please sign in to comment.