From 61dfb1018eaf3bcdd9a8c6ea3c3a5e19897c0be3 Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Sat, 25 Jan 2025 02:18:54 +0000
Subject: [PATCH] stash

---
 .../builders/impl/activation_op_builder.cc    |  5 +--
 .../coreml/builders/impl/argmax_op_builder.cc |  7 +---
 .../builders/impl/batch_norm_op_builder.cc    |  5 +--
 .../coreml/builders/impl/binary_op_builder.cc |  7 +---
 .../coreml/builders/impl/builder_utils.cc     |  2 -
 .../coreml/builders/impl/builder_utils.h      |  2 -
 .../coreml/builders/impl/cast_op_builder.cc   |  8 +---
 .../coreml/builders/impl/clip_op_builder.cc   |  5 +--
 .../coreml/builders/impl/concat_op_builder.cc |  2 -
 .../coreml/builders/impl/conv_op_builder.cc   | 12 +-----
 .../builders/impl/convtranspose_op_builder.cc |  2 -
 .../builders/impl/depthtospace_op_builder.cc  |  2 -
 .../coreml/builders/impl/gemm_op_builder.cc   | 10 +----
 .../builders/impl/gridsample_op_builder.cc    |  2 -
 .../builders/impl/normalization_op_builder.cc |  4 --
 .../coreml/builders/impl/pool_op_builder.cc   |  5 +--
 .../builders/impl/reduction_op_builder.cc     |  5 +--
 .../builders/impl/reshape_op_builder.cc       |  5 +--
 .../coreml/builders/impl/resize_op_builder.cc |  5 +--
 .../coreml/builders/impl/shape_op_builder.cc  |  5 +--
 .../coreml/builders/impl/slice_op_builder.cc  | 18 +++------
 .../builders/impl/softmax_op_builder.cc       |  5 +--
 .../coreml/builders/impl/split_op_builder.cc  |  5 +--
 .../builders/impl/squeeze_op_builder.cc       |  8 +---
 .../builders/impl/transpose_op_builder.cc     |  5 +--
 .../coreml/builders/impl/unary_op_builder.cc  |  5 +--
 .../coreml/builders/model_builder.cc          | 37 ++-----------------
 .../providers/coreml/builders/model_builder.h | 10 ++---
 .../core/providers/coreml/coreml_options.cc   | 12 ------
 .../core/providers/coreml/model/host_utils.h  |  3 +-
 .../core/providers/xnnpack/nn/max_pool.cc     |  2 +-
 .../test/contrib_ops/layer_norm_op_test.cc    |  2 +-
 .../providers/coreml/coreml_basic_test.cc     |  2 +-
 .../cpu/activation/activation_op_test.cc      |  6 +--
 .../cpu/activation/activation_op_test.h       |  2 +-
 .../cpu/math/element_wise_ops_test.cc         |  8 ++--
 .../test/providers/cpu/math/matmul_test.cc    |  4 +-
 .../providers/cpu/nn/batch_norm_op_test.cc    |  2 +-
 .../test/providers/cpu/nn/conv_fp16_test.cc   |  4 +-
 .../providers/cpu/nn/group_norm_op_test.cc    |  2 +-
 .../providers/cpu/nn/instance_norm_op_test.cc |  2 +-
 .../providers/cpu/nn/pool_fp16_op_test.cc     |  2 +-
 .../test/providers/cpu/nn/pool_op_test.cc     |  2 +-
 .../cpu/reduction/reduction_ops_test.cc       |  6 +--
 44 files changed, 60 insertions(+), 194 deletions(-)

diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc
index 4481a5172966b..3fffc6d0a68c4 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc
@@ -97,7 +97,6 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                   const logging::Logger& logger) const {
   const auto& op_type(node.OpType());
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
     // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation
@@ -166,9 +165,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     model_builder.AddOperation(std::move(op));
-  } else
-#endif  // (COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
     if (op_type == "Sigmoid") {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc
index 6169090a36014..dfa01c8187741 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc
@@ -32,7 +32,6 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   const int64_t keepdims = helper.Get("keepdims", 1);
   const bool removedim = keepdims != 1;
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
     // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.reduction
@@ -46,9 +45,7 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     // the output of ArgMax must be int32
     AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype);
     model_builder.AddOperation(std::move(op));
-  } else
-#endif  // (COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     auto* coreml_argmax = layer->mutable_argmax();
     coreml_argmax->set_axis(axis);
     coreml_argmax->set_removedim(removedim);
@@ -91,11 +88,9 @@ bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node,
     return false;
   }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (input_params.create_mlprogram) {
     return true;
   }
-#endif
 
   // If there are multiple downstream nodes and cast (toint32) is one of them
   // not supported, exit here
diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
index 442194cb31cbc..e547f2e42e527 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
@@ -57,7 +57,6 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
   const auto eps = helper.Get("epsilon", 1e-5f);
   const auto channels = scale_tensor.dims()[0];
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
     // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.batch_norm
@@ -78,9 +77,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
     AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
-  } else
-#endif  // (COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     auto* coreml_batch_norm = layer->mutable_batchnorm();
     coreml_batch_norm->set_channels(channels);
     coreml_batch_norm->set_epsilon(eps);
diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc
index 0482620b269a4..d7c78e05362ed 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc
@@ -56,7 +56,6 @@ bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger
   }
 }
 }  // namespace
 
-#if defined(COREML_ENABLE_MLPROGRAM)
 static std::vector<int64_t> InferOutputShape(const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
   std::vector<int64_t> output_shape;
   int64_t i_a = 0, j_b = 0;
@@ -112,14 +111,12 @@ static void AddVariadicInputs(std::unique_ptr<CoreML::Specification::MILSpec::Operation>* op,
-  } else
-#endif  // (COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
     if (op_type == "Add") {
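
Note: InferOutputShape in the binary_op_builder diff above implements ONNX multidirectional
(numpy-style) broadcasting. As a reference for that rule only, a minimal standalone sketch
follows; BroadcastShapes is a hypothetical name and this is an illustration, not the EP code.

// Illustrative sketch of ONNX multidirectional broadcasting between two shapes.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> BroadcastShapes(const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
  const std::size_t rank = std::max(a.size(), b.size());
  std::vector<int64_t> out(rank, 1);
  for (std::size_t i = 0; i < rank; ++i) {
    // align the shapes from the trailing dimension; missing dims behave as 1
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) throw std::invalid_argument("shapes are not broadcastable");
    out[rank - 1 - i] = std::max(da, db);
  }
  return out;
}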
diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
index 6f9bb35c27d80..684653aa21273 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
@@ -150,7 +150,6 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight,
                         gsl::span data);
 
-#if defined(COREML_ENABLE_MLPROGRAM)
 //
 // MLProgram utils
 //
@@ -174,6 +173,5 @@ void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& outp
 /// Number of spatial dims in input. Generally rank - 2 (ignore N and C dims).
 void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type,
                        const NodeAttrHelper& helper, int num_spatial_dims);
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
 }  // namespace coreml
 }  // namespace onnxruntime
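
The cast_op_builder diff below only handles the ArgMax+Cast(int32) fusion described in its
leading comment. A minimal sketch of that fusion decision, using a hypothetical SimpleNode
graph type (the real checks live in ArgMaxOpBuilder::IsOpSupportedImpl and CastOpBuilder):

#include <cstdint>
#include <string>
#include <vector>

struct SimpleNode {
  std::string op_type;
  int64_t cast_to = 0;  // TensorProto_DataType value when op_type == "Cast"
  std::vector<const SimpleNode*> consumers;
};

// True when an ArgMax output feeds exactly one Cast-to-int32 node, in which case
// the Cast can be folded away and ArgMax emitted with an int32 output directly.
bool CanFuseArgMaxCast(const SimpleNode& argmax) {
  if (argmax.op_type != "ArgMax" || argmax.consumers.size() != 1) return false;
  const SimpleNode& next = *argmax.consumers[0];
  return next.op_type == "Cast" && next.cast_to == 6;  // 6 == TensorProto_DataType_INT32
}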
diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
index 7c7363d4c81ad..8abee92451338 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
@@ -27,9 +27,8 @@ class CastOpBuilder : public BaseOpBuilder {
 Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder,
                                             [[maybe_unused]] const Node& node,
                                             [[maybe_unused]] const logging::Logger& logger) const {
-// This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type.
-// The ArgMax is fused with the Cast node and produces an int32 output.
-#if defined(COREML_ENABLE_MLPROGRAM)
+  // This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type.
+  // The ArgMax is fused with the Cast node and produces an int32 output.
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
     // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.cast
@@ -73,7 +72,6 @@ Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model
     AddOperationOutput(*op, *node.OutputDefs()[0], cast_to_type);
     model_builder.AddOperation(std::move(op));
   }
-#endif
 
   return Status::OK();
 }
@@ -134,7 +132,6 @@ bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, [[maybe_unused]] co
     return false;
   }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (input_params.create_mlprogram) {
     if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
          input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
@@ -152,7 +149,6 @@ bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, [[maybe_unused]] co
       return false;
     }
   }
-#endif
 
   // only support int64 coming from ArgMax (check for ArgMax is done in IsOpSupportedImpl())
   if (input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc
index f7046c213a8cb..9e68070a0e693 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc
@@ -64,7 +64,6 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   bool has_min = min != std::numeric_limits<float>::lowest();
   bool has_max = max != std::numeric_limits<float>::max();
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
 
@@ -121,9 +120,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
       AddOperationOutput(*op, output);
       model_builder.AddOperation(std::move(op));
-  } else
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     // TODO: CoreML has a Clip layer for NeuralNetwork. Added in CoreML 4. We could potentially use that if available
     // to simplify.
     // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#cliplayerparams
diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc
index 9ea0030290abd..34ce2438095ad 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc
@@ -26,7 +26,6 @@ class ConcatOpBuilder : public BaseOpBuilder {
 Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                               const Node& node,
                                               const logging::Logger& logger) const {
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;  // NOLINT
 
@@ -45,7 +44,6 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
   } else  // NOLINT
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
   {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
diff --git a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc
index 38125957bf481..18823bcc78d19 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc
@@ -52,7 +52,6 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
   NodeAttrHelper helper(node);
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
 
@@ -89,9 +88,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
     AddOperationOutput(*conv_op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(conv_op));
-  } else
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
     auto strides = helper.Get("strides", std::vector<int64_t>{1, 1});
@@ -225,14 +222,11 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
   const auto& weight_name = input_defs[1]->Name();
   const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name);
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (input_params.create_mlprogram) {
     // ML Program supports non-const weight, 1D, 2D and 3D.
     // keep to 1D and 2D for consistency with the NeuralNetwork implementation for now.
     // add 3D support as/when needed.
-  } else
-#endif  // defined (COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     if (!weight) {
       LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be a constant initializer";
       return false;
@@ -257,7 +251,6 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
   NodeAttrHelper helper(node);
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   // spec says same_lower is supported in CoreML 5. it lies. CoreML 6 is required otherwise you get
   // `Unexpected value for parameter pad_type[0] "same_lower" not in ("custom", "same", "valid").`
   // We _could_ manually calculate the pads, but not implementing that until we have a real use case to justify
@@ -269,7 +262,6 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
       return false;
     }
   }
-#endif
 
   // there's no equivalent to allow a manual kernel shape in CoreML.
   // it's OK if a specified kernel_shape matches kH and kW dims of the weight input.
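
The conv_op_builder diff above declines SAME_LOWER on CoreML 5 rather than computing pads
manually. For reference, the standard ONNX auto_pad arithmetic per spatial dimension is
sketched below; ComputeSamePads is a hypothetical helper, not the EP implementation.

#include <algorithm>
#include <cstdint>
#include <utility>

// Returns {pad_begin, pad_end} for one spatial dim so that out = ceil(in / stride).
std::pair<int64_t, int64_t> ComputeSamePads(int64_t in, int64_t stride, int64_t kernel,
                                            int64_t dilation, bool same_lower) {
  const int64_t out = (in + stride - 1) / stride;           // ceil(in / stride)
  const int64_t effective_k = (kernel - 1) * dilation + 1;  // dilated kernel extent
  const int64_t total = std::max<int64_t>((out - 1) * stride + effective_k - in, 0);
  const int64_t small = total / 2;
  const int64_t big = total - small;
  // SAME_LOWER places the extra padding at the start, SAME_UPPER at the end.
  return same_lower ? std::make_pair(big, small) : std::make_pair(small, big);
}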
diff --git a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc
index 5b6d9d72ab3c9..2e2c898b0e10a 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc
@@ -28,7 +28,6 @@ class ConvTransposeOpBuilder : public BaseOpBuilder {
 Status ConvTransposeOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder,
                                                      [[maybe_unused]] const Node& node,
                                                      const logging::Logger& /*logger*/) const {
-#if defined(COREML_ENABLE_MLPROGRAM)
   using namespace CoreML::Specification::MILSpec;  // NOLINT
   const auto input_defs = node.InputDefs();
   const auto output_defs = node.OutputDefs();
@@ -80,7 +79,6 @@ Status ConvTransposeOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuild
   AddOperationOutput(*op, *output_defs[0]);
   model_builder.AddOperation(std::move(op));
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
 
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc
index fec14dfd093a0..1a74b1eea97fe 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc
@@ -33,7 +33,6 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   NodeAttrHelper helper(node);
   int64_t blocksize = *helper.GetInt64("blocksize");  // required attribute
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;  // NOLINT
 
@@ -105,7 +104,6 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
       model_builder.AddOperation(std::move(reshape2));
     }
   } else  // NOLINT
-#endif  // if defined(COREML_ENABLE_MLPROGRAM)
   {
     const auto& output_name = output_defs[0]->Name();
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
index e685c09ef43ca..4f84f7c36259c 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
@@ -33,7 +33,6 @@ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod
   const auto& input_defs(node.InputDefs());
   const bool is_gemm = op == "Gemm";
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     // we have to transpose the weight input of Gemm if transB is false, and potentially override the bias shape
     if (is_gemm) {
@@ -58,9 +57,7 @@ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod
         }
       }
     }
-  } else
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     // We have already embedded the weights (matrix B and C(if any)) into the coreml layer
     // No need to copy them later to reduce memory consumption
     model_builder.AddInitializerToSkip(input_defs[1]->Name());
@@ -123,7 +120,6 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
   const auto K = transB ? b1 : b0;
   const auto N = transB ? b0 : b1;
   // we already checked it and the dtype must exist.
-#if defined(COREML_ENABLE_MLPROGRAM)
   auto input_dtype = a.TypeAsProto()->tensor_type().elem_type();
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
@@ -207,9 +203,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
       AddOperationOutput(*matmul_op, *node.OutputDefs()[0]);
       model_builder.AddOperation(std::move(matmul_op));
     }
-  } else
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     auto* coreml_inner_product = layer->mutable_innerproduct();
 
     *layer->mutable_input()->Add() = a.Name();
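
To clarify the K/N derivation in the gemm_op_builder diff above (K = transB ? b1 : b0 and
N = transB ? b0 : b1, with B stored as [b0, b1]): a naive reference Gemm, illustrative only.

#include <cstddef>
#include <vector>

// A is [M, K]; B is [K, N] when !transB and [N, K] when transB; Y is [M, N].
void GemmRef(const std::vector<float>& A, const std::vector<float>& B, std::vector<float>& Y,
             std::size_t M, std::size_t K, std::size_t N, bool transB) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        acc += A[m * K + k] * (transB ? B[n * K + k] : B[k * N + n]);
      }
      Y[m * N + n] = acc;
    }
  }
}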
diff --git a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc
index 6dcf14c16f111..f558f423752e8 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc
@@ -42,7 +42,6 @@ class GridSampleOpBuilder : public BaseOpBuilder {
 Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder,
                                                   [[maybe_unused]] const Node& node,
                                                   [[maybe_unused]] const logging::Logger& logger) const {
-#if defined(COREML_ENABLE_MLPROGRAM)
   using namespace CoreML::Specification::MILSpec;  // NOLINT
   // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.resample
@@ -80,7 +79,6 @@ Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder&
   AddOperationOutput(*op, *output_defs[0]);
   model_builder.AddOperation(std::move(op));
-#endif
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc
index b4dc8d1647ad0..c0db144602ee2 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc
@@ -49,7 +49,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(
   if (node.OpType() == "GroupNormalization") {
     return AddGroupNormToModelBuilderImpl(model_builder, node, logger);
   }
-#if defined(COREML_ENABLE_MLPROGRAM)
   const auto& input_defs = node.InputDefs();
   NodeAttrHelper helper(node);
   const auto& scale_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name());
@@ -94,7 +93,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(
     AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
   }
-#endif  // (COREML_ENABLE_MLPROGRAM)
 
   return Status::OK();
 }
 
@@ -103,7 +101,6 @@ Status NormalizationOpBuilder::AddGroupNormToModelBuilderImpl(
     [[maybe_unused]] ModelBuilder& model_builder,
     [[maybe_unused]] const Node& node,
     [[maybe_unused]] const logging::Logger& logger) const {
-#if defined(COREML_ENABLE_MLPROGRAM)
   const auto& input_defs = node.InputDefs();
   NodeAttrHelper helper(node);
   // CoreML doesn't support GroupNorm natively yet.
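
As the comment above says, CoreML has no native GroupNorm, so AddGroupNormToModelBuilderImpl
decomposes it into per-group normalization followed by a channel-wise mul and add. A naive
NCHW reference of that math (illustrative sketch, not the EP decomposition):

#include <cmath>
#include <cstddef>
#include <vector>

void GroupNormRef(std::vector<float>& x, std::size_t N, std::size_t C, std::size_t HW,
                  std::size_t groups, const std::vector<float>& gamma,
                  const std::vector<float>& beta, float eps) {
  const std::size_t cg = C / groups;  // channels per group
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t g = 0; g < groups; ++g) {
      const std::size_t base = (n * C + g * cg) * HW;
      const std::size_t count = cg * HW;
      double mean = 0.0, var = 0.0;
      for (std::size_t i = 0; i < count; ++i) mean += x[base + i];
      mean /= count;
      for (std::size_t i = 0; i < count; ++i) var += (x[base + i] - mean) * (x[base + i] - mean);
      var /= count;
      const double inv_std = 1.0 / std::sqrt(var + eps);
      for (std::size_t i = 0; i < count; ++i) {
        const std::size_t c = g * cg + i / HW;  // absolute channel index for gamma/beta
        x[base + i] = static_cast<float>((x[base + i] - mean) * inv_std) * gamma[c] + beta[c];
      }
    }
  }
}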
@@ -184,7 +181,6 @@ Status NormalizationOpBuilder::AddGroupNormToModelBuilderImpl(
     model_builder.AddOperation(std::move(mul));
     model_builder.AddOperation(std::move(add));
   }
-#endif  // (COREML_ENABLE_MLPROGRAM)
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc
index 17910ba6fd486..e43eef75007cc 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc
@@ -29,7 +29,6 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   const auto& op_type = node.OpType();
   const auto& input_defs = node.InputDefs();
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
 
@@ -91,9 +90,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
-  } else
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
     auto* coreml_pool = layer->mutable_pooling();
diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
index d533b867bd454..a4609eb2a0584 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
@@ -71,7 +71,6 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
   const bool keepdims = helper.Get("keepdims", 1) != 0;
   const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0;
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
 
@@ -103,9 +102,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
     AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
-  } else
-#endif  // (COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
     if (op_type == "ReduceSum") {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc
index 27d24d9c21893..b35d6971623ed 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc
@@ -50,7 +50,6 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   // ReshapeHelper applies the ONNX rules to create the concrete output shape
   ReshapeHelper helper(TensorShape(input_shape), new_shape);
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
 
@@ -64,9 +63,7 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     AddOperationOutput(*reshape_op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(reshape_op));
-  } else
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
     *layer->mutable_reshapestatic()->mutable_targetshape() = {new_shape.cbegin(), new_shape.cend()};
diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc
index 7ff66e4a79e37..837573003e515 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc
@@ -212,7 +212,6 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
     num_sizes = output_sizes.size();
   }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;  // NOLINT
 
@@ -279,9 +278,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
     AddOperationOutput(*op, *output_defs[0]);
     model_builder.AddOperation(std::move(op));
-  } else  // NOLINT
-#endif
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
     auto* coreml_upsample = layer->mutable_upsample();
diff --git a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc
index 243f949bdd48e..d1c87b033d323 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc
@@ -25,7 +25,6 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
                                              const logging::Logger& /*logger*/) const {
   const auto& input_defs = node.InputDefs();
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
     NodeAttrHelper node_attr_helper{node};
@@ -63,9 +62,7 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
       AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype);
       model_builder.AddOperation(std::move(op));
     }
-  } else  // NOLINT
-#endif
-  {
+  } else {
     auto layer = model_builder.CreateNNLayer(node);
     layer->mutable_getshape();
     *layer->mutable_input()->Add() = input_defs[0]->Name();
diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc
index 6b3fe75fa592d..368e47e40f831 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc
@@ -127,7 +127,6 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   SliceOp::PrepareForComputeMetadata compute_metadata{data_shape};
   ORT_RETURN_IF_ERROR(PrepareSliceComputeMetadata(node, model_builder.GetGraphViewer(), compute_metadata));
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;  // NOLINT
     // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.slice_by_index
@@ -178,9 +177,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
 
     model_builder.AddOperation(std::move(op));
 
-  } else  // NOLINT
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     auto layer = model_builder.CreateNNLayer(node);
     *layer->mutable_input()->Add() = input_defs[0]->Name();
     *layer->mutable_output()->Add() = output_defs[0]->Name();
@@ -222,7 +219,6 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node,
     return false;
   }
 
-#ifdef COREML_ENABLE_MLPROGRAM
   // The [Doc](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.slice_by_index)
   // says ML Program slice_by_index supports fp16 in CoreML 5 (iOS 15).
   // It's incorrect and CoreML 6+ (iOS16, CoreML spec version >= 7) is required otherwise only float is supported.
   // CoreML 6:https://github.com/apple/coremltools/blob/c3ea4cf56fef1176417246c1b85363417f3e713d/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py#L495
@@ -230,13 +226,11 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node,
   if (input_params.create_mlprogram && input_params.coreml_version >= 6 &&
       input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
-  } else
-#endif  // nolint
-  if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
-      input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
-    LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported";
-    return false;
-  }
+  } else if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
+             input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
+    LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported";
+    return false;
+  }
 
   return true;
 }
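
The ML Program path in the slice diff above maps ONNX Slice onto slice_by_index, which
follows numpy basic slicing. A 1-D sketch of those semantics for reference; it assumes
begin/end are already normalized to in-range values (the real op also takes per-axis
begin/end/squeeze masks):

#include <cstdint>
#include <vector>

std::vector<float> SliceByIndex1D(const std::vector<float>& data,
                                  int64_t begin, int64_t end, int64_t stride) {
  std::vector<float> out;
  if (stride > 0) {
    for (int64_t i = begin; i < end; i += stride) out.push_back(data[i]);
  } else {
    for (int64_t i = begin; i > end; i += stride) out.push_back(data[i]);  // negative stride walks backwards
  }
  return out;
}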
diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc
index c6e331feed326..2411cd459fecd 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc
@@ -37,7 +37,6 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   const auto axis = helper.Get("axis", axis_default_value);
   auto axis_nonnegative = HandleNegativeAxis(axis, data_shape.size());
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   // CoreML's softmax matches onnx's softmax behavior since opset 13.
   // For opset < 13, we need to reshape to 2D and set axis to -1 to simulate onnx softmax behavior.
   // [B,D,...](onnx softmax opset 12, axis=1)->[B,D*...](CoreML softmax, axis=-1)->[B,D,...](reshape back)
@@ -78,9 +77,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
       AddOperationOutput(*reshape2, *node.OutputDefs()[0]);
       model_builder.AddOperation(std::move(reshape2));
     }
-  } else  // NOLINT
-#endif
-  {
+  } else {
     if (node.SinceVersion() >= 13 || (data_shape.size() == 2)) {
       auto* coreml_softmaxnd = layer->mutable_softmaxnd();
       coreml_softmaxnd->set_axis(axis);
diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc
index 6372f3136123b..717d344982473 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc
@@ -56,7 +56,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     return std::make_tuple(remainder, chunk_size);
   };
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
     std::unique_ptr<Operation> split_op = model_builder.CreateOperation(node, "split");
@@ -95,9 +94,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     }
 
     model_builder.AddOperation(std::move(split_op));
-  } else
-#endif
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
     auto* coreml_splitnd = layer->mutable_splitnd();
     coreml_splitnd->set_axis(axis);
diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
index a1b3a18265c70..81bef11906b74 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
@@ -58,7 +58,6 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
   }
 }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
 void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder,
                                        const Node& node, const logging::Logger& logger) {
   const auto& input_defs(node.InputDefs());
@@ -74,7 +73,6 @@ void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder,
   AddOperationOutput(*op, *node.OutputDefs()[0]);
   model_builder.AddOperation(std::move(op));
 }
-#endif
 
 Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                const Node& node,
@@ -83,7 +81,7 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   auto* coreml_squeeze = layer->mutable_squeeze();
   TensorShapeVector axes;
   GetAxes(model_builder, node, axes);
-#if defined(COREML_ENABLE_MLPROGRAM)
+
   const auto& input_defs(node.InputDefs());
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
@@ -105,9 +103,7 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     }
     AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
-  } else  // NOLINT
-#endif
-  {
+  } else {
     if (axes.empty()) {
       coreml_squeeze->set_squeezeall(true);
     } else {
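
For the softmax_op_builder diff above: the opset<13 emulation flattens to 2-D at `axis`,
applies a row-wise softmax with axis=-1, then reshapes back (a buffer no-op since the layout
is unchanged). A naive reference of the row-wise step, illustrative only:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// rows = product of dims before `axis`, cols = product of dims from `axis` on.
void Softmax2D(std::vector<float>& x, std::size_t rows, std::size_t cols) {
  for (std::size_t r = 0; r < rows; ++r) {
    float* row = x.data() + r * cols;
    float max_v = row[0];
    for (std::size_t c = 1; c < cols; ++c) max_v = std::max(max_v, row[c]);  // for numerical stability
    float sum = 0.0f;
    for (std::size_t c = 0; c < cols; ++c) sum += (row[c] = std::exp(row[c] - max_v));
    for (std::size_t c = 0; c < cols; ++c) row[c] /= sum;
  }
}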
diff --git a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc
index 831c4cf4d08ba..5bb7e4c11967a 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc
@@ -34,7 +34,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     ORT_RETURN_IF_NOT(perm.size() == input_dims, "Perm and input should have same dimension");
   }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
 
@@ -44,9 +43,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
-  } else
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
     *layer->mutable_transpose()->mutable_axes() = {perm.cbegin(), perm.cend()};
diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc
index bc3cad004aec1..dd495894ab8bb 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc
@@ -25,7 +25,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   const auto& op_type(node.OpType());
   const auto& input_defs(node.InputDefs());
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
 
@@ -58,9 +57,7 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
     AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
-  } else  // NOLINT
-#endif  // defined (COREML_ENABLE_MLPROGRAM)
-  {
+  } else {
     std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
     if (op_type == "Sqrt") {
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc
index f8952301d59a9..3551f5759201e 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc
@@ -17,20 +17,17 @@
 #include "core/providers/coreml/shape_utils.h"
 #include "core/optimizer/initializer.h"
 
-#if defined(COREML_ENABLE_MLPROGRAM)
 // includes from coremltools-src in _deps
 #include "modelpackage/src/ModelPackage.hpp"
 #include "mlmodel/src/MILBlob/Blob/StorageWriter.hpp"
 using MILBlob::Blob::StorageWriter;
-#endif
-
 using namespace CoreML::Specification;
 
 namespace onnxruntime {
 namespace coreml {
 
 namespace {
-#if defined(COREML_ENABLE_MLPROGRAM)
+
 // Should the initializer be written to file or kept as an immediate value
 bool ShouldWriteInitializerToWeightsFile(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
   // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/load.py#L51-L57
@@ -388,8 +385,6 @@ void CreateEmptyFile(const std::string& filename) {
   ORT_ENFORCE(file.is_open(), "Failed to open file ", filename);
 }
 
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-
 std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
                                const GraphViewer& graph_viewer,
                                const logging::Logger& logger) {
@@ -479,7 +474,6 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
   }
 
   if (create_ml_program_) {
-#if defined(COREML_ENABLE_MLPROGRAM)
     coreml_model_->set_specificationversion(CoreMLSpecVersion());
     MILSpec::Program& mlprogram = *coreml_model_->mutable_mlprogram();
     mlprogram.set_version(1);
@@ -503,12 +497,6 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
                                                "CoreML Model Weights");
     auto weights_info = mlpackage_->findItem(weights_id);
     weights_file_writer_ = std::make_unique<StorageWriter>(weights_info->path() + "/weight.bin");
-#else
-    // should never happen due to handling in coreml_execution_provider.cc
-    // throw here so all other code in this class can assume create_ml_program_ is only ever true in a build
-    // where ML Program support is enabled.
-    ORT_THROW("ML Program is not enabled in this build");
-#endif
   } else {
     // We support CoreML Specification Version 4 (Core ML 3)
     coreml_model_->set_specificationversion(4);
@@ -561,7 +549,6 @@ void ModelBuilder::AddLayer(std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer) {
 /*
  * ML Program related helpers
  */
-#if defined(COREML_ENABLE_MLPROGRAM)
 const std::string& ModelBuilder::GetSafeName(const std::string& name) {
   // Check the name is valid according to the MILSpec rules
   // `Identifiers, generally used for names and keys, must match the regular expression [A-Za-z\_][A-Za-z0-9\_@]*.`
@@ -737,8 +724,6 @@ std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::st
   return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value));
 }
 
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
-
 /*
  * General implementation
  */
@@ -775,13 +760,10 @@ Status ModelBuilder::RegisterInitializers() {
       continue;
     }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
     if (create_ml_program_) {
       MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(tensor, *weights_file_writer_);
       ORT_IGNORE_RETURN_VALUE(AddConstantOperation(name, std::move(coreml_tensor)));
-    } else
-#endif
-    {
+    } else {
       std::unique_ptr<NeuralNetworkLayer> layer = std::make_unique<NeuralNetworkLayer>();
       layer->set_name(GetUniqueName("initializer_" + name));
 
@@ -915,7 +897,6 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
     return Status::OK();
   }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (create_ml_program_) {
     if (is_input) {
       // the model inputs need to be wired up as args to the 'main' function.
@@ -935,7 +916,6 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
       *mlprogram_main_block_->mutable_outputs()->Add() = name;
     }
   }
-#endif  // defined(COREML_ENABLE_MLPROGRAM)
 
   return Status::OK();
 }
@@ -980,11 +960,9 @@ Status ModelBuilder::CreateModel() {
   ORT_RETURN_IF_ERROR(ProcessNodes());
   ORT_RETURN_IF_ERROR(RegisterModelOutputs());
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (create_ml_program_) {
     SanitizeNames();
   }
-#endif
 
   return Status::OK();
 }
@@ -992,7 +970,6 @@ Status ModelBuilder::SaveModel() {
   std::string output_path = model_output_path_;
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (create_ml_program_) {
     // we need to jump through some hoops to get the model path the ML Program load wants.
     std::string tmp_model_path = model_output_path_ + "/tmp/model.mlmodel";
@@ -1003,7 +980,6 @@ Status ModelBuilder::SaveModel() {
     auto model_info = mlpackage_->findItem(model_id);
     output_path = model_info->path();
   }
-#endif
 
   // scope this so the stream is closed and flushed by the ofstream dtor
   {
     ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Saving the CoreML model failed. Path=", output_path);
   }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
   // need to delete the ModelPackage instance for it to write out the manifest. clear out the other ML Program
   // related types as well.
   mlprogram_main_block_ = nullptr;
   mlpackage_.reset();
   weights_file_writer_.reset();
-#endif
 
   return Status::OK();
 }
 
 Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
-#if defined(COREML_ENABLE_MLPROGRAM)
   if (create_ml_program_) {
     // we need to provide the sanitized names for model inputs/outputs so that info is captured.
     // the input/output matching when we execute the model from the CoreML EP is based on order, so the change
     // in name is not an issue.
@@ -1058,9 +1031,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
                                       std::move(scalar_outputs_),
                                       std::move(int64_outputs_),
                                       logger_, coreml_options_);
-  } else
-#endif
-  {
+  } else {
     model = std::make_unique<Model>(model_output_path_,
                                     std::move(onnx_input_names_),
                                     std::move(onnx_output_names_),
@@ -1073,7 +1044,6 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
   return model->LoadModel();  // load using CoreML API, including compilation
 }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
 std::string_view ModelBuilder::AddConstant(std::string_view op_type, std::string_view value_type,
                                            const ONNX_NAMESPACE::TensorProto& tensor,
                                            std::optional<gsl::span<const int64_t>> shape) {
@@ -1114,7 +1084,6 @@ std::string_view ModelBuilder::AddConstant(std::string_view op_type, std::string
   return ret;
 }
-#endif
 
 // static
 Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger,
                            int32_t coreml_version, const CoreMLOptions& coreml_options,
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h
index 28c7dc42da581..59963e07d1266 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.h
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.h
@@ -9,7 +9,7 @@
 #include "core/providers/coreml/model/model.h"
 #include "core/providers/coreml/coreml_options.h"
 
-#if defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_COREML)
 // coremltools classes
 namespace MPL {
 class ModelPackage;
@@ -58,7 +58,7 @@ class ModelBuilder {
 
   // Returns true if we are creating an ML Program
   bool CreateMLProgram() const {
-#if defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_COREML)
     return create_ml_program_;
 #else
     return false;
@@ -76,7 +76,7 @@ class ModelBuilder {
   // Add layer to the Core ML NeuralNetwork model
   void AddLayer(std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer);
 
-#if defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_COREML)
   /*
    * MLProgram helpers
   */
@@ -176,7 +176,7 @@ class ModelBuilder {
   const logging::Logger& Logger() const { return logger_; }
 
  private:
-#if defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_COREML)
   template <typename T>
   std::string_view AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span<const T> value,
                                    std::optional<gsl::span<const int64_t>> shape = std::nullopt);
@@ -237,7 +237,7 @@ class ModelBuilder {
   uint32_t name_token_{0};
   std::unordered_set<std::string> unique_names_;
 
-#if defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_COREML)
   // mlprogram_main_ is the main block of the CoreML ML Program.
   // It is set in CreateModel to the CoreML Model.mlprogram.functions['main'].block_specializations['CoreML']
   // entry we create.
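
GetSafeName (in the model_builder.cc diff above) enforces the MILSpec identifier rule quoted
there: names must match [A-Za-z_][A-Za-z0-9_@]*. A sketch of that sanitization step only;
SanitizeMILName is a hypothetical name, and the real implementation additionally guarantees
uniqueness and caches the mapping.

#include <cctype>
#include <string>

std::string SanitizeMILName(const std::string& name) {
  std::string safe = name.empty() ? std::string("_") : name;
  // a leading digit (or any other invalid first char) gets an underscore prefix
  if (!std::isalpha(static_cast<unsigned char>(safe[0])) && safe[0] != '_') {
    safe.insert(safe.begin(), '_');
  }
  for (char& ch : safe) {
    if (!std::isalnum(static_cast<unsigned char>(ch)) && ch != '_' && ch != '@') {
      ch = '_';  // replace every other invalid character
    }
  }
  return safe;
}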
diff --git a/onnxruntime/core/providers/coreml/coreml_options.cc b/onnxruntime/core/providers/coreml/coreml_options.cc
index 14ae55de9266b..c441a2eff56e0 100644
--- a/onnxruntime/core/providers/coreml/coreml_options.cc
+++ b/onnxruntime/core/providers/coreml/coreml_options.cc
@@ -15,18 +15,6 @@ CoreMLOptions::CoreMLOptions(uint32_t coreml_flags) {
   create_mlprogram_ = (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0;
   enable_on_subgraph_ = (coreml_flags & COREML_FLAG_ENABLE_ON_SUBGRAPH) != 0;
 
-#if defined(COREML_ENABLE_MLPROGRAM)
-  if (coreml::util::CoreMLVersion() < MINIMUM_COREML_MLPROGRAM_VERSION && create_mlprogram_ != 0) {
-    LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork.";
-    create_mlprogram_ = false;
-  }
-#else
-  if (create_mlprogram_ != 0) {
-    LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork.";
-    create_mlprogram_ = false;
-  }
-#endif
-
   compute_units_ = 0;  // 0 for all
 
   if (coreml_flags & COREML_FLAG_USE_CPU_ONLY) {
diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h
index 145c64e5320d3..df5f8ad4b0790 100644
--- a/onnxruntime/core/providers/coreml/model/host_utils.h
+++ b/onnxruntime/core/providers/coreml/model/host_utils.h
@@ -54,8 +54,7 @@
 
 #endif
 
-#define MINIMUM_COREML_VERSION 3            // first version we support
-#define MINIMUM_COREML_MLPROGRAM_VERSION 5  // first version where ML Program was available
+#define MINIMUM_COREML_VERSION 5  // first version we support
 
 namespace onnxruntime {
 namespace coreml {
diff --git a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc
index c828ae9400174..8d972f7d63bc1 100644
--- a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc
+++ b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc
@@ -57,7 +57,7 @@ bool MaxPool::IsOnnxNodeSupported(const NodeUnit& node_unit,
   // input of maxpool could be fp16/fp32/fp64,i8/u8 according to ONNX
   if (x_type == nullptr ||
       (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
-// because pool_fp16_op_test can be enabled by other preprocessor, for example, COREML_ENABLE_MLPROGRAM
+// because pool_fp16_op_test can be enabled by other preprocessor, for example, USE_COREML
 #ifdef XNNPACK_FP16_SUPPORTED
        x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 &&
 #endif
diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index 4611dc9082734..e22445edc0f5b 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -404,7 +404,7 @@ TYPED_TEST(LayerNormTest, LayerNorm17_opset) {
   // Execution provider entry invalid.
   // when other EPs support layer-norm fp16, this test should be updated to include them.
   if (std::is_same<TypeParam, MLFloat16>::value) {
-#if !defined(COREML_ENABLE_MLPROGRAM)
+#if !defined(USE_COREML)
     return;
 #endif
   }
diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
index a9aa78b7a3229..3505193b77683 100644
--- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc
+++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
@@ -246,7 +246,7 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) {
 #endif
 }
 
-#if defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_COREML)
 // Names in CoreML cannot start with [0-9] or contain anything but "[a-z][A-Z][0-9]_"
 // Test that we fix invalid names in model inputs, initializers and outputs.
 // This is only enforced for ML Program, so we only do name sanitization when creating an ML Program format model.
diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
index 724118d7419d2..9201da348e75c 100644
--- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
+++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
@@ -125,7 +125,7 @@ TEST_F(ActivationOpTest, Relu) {
       {}, {}, /*is_tensorrt_supported=*/false,
       /*opset_version= */ 14);
-#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML)
   TestActivationOp<MLFloat16>(
       "Relu",
       input_values_fp16,
@@ -139,7 +139,7 @@ TEST_F(ActivationOpTest, Relu) {
 #endif  // MLAS_F16VEC_INTRINSICS_SUPPORTED
 }
 
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML)
 TEST_F(ActivationOpTest, Sigmoid_fp16) {
 #ifdef USE_CUDA
   int min_cuda_architecture = 530;
@@ -413,7 +413,7 @@ TEST_F(ActivationOpTest, LeakyRelu) {
                  {{"alpha", alpha}}, {});
 }
 
-#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML)
 TEST_F(ActivationOpTest, LeakyRelu_fp16) {
   OpTester test("LeakyRelu", 11);
   float alpha = 0.01f;  // oneDNN set alpha equal to 0.01
diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h
index 59813f433dc41..04d116e29d3b0 100644
--- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h
+++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h
@@ -105,7 +105,7 @@ class ActivationOpTest : public ::testing::Test {
     std::random_device rd;
     std::mt19937 gen(rd());
     std::uniform_real_distribution<float> dist(low, high);
-#ifdef COREML_ENABLE_MLPROGRAM
+#ifdef USE_COREML
     // please check onnxruntime/onnxruntime/core/providers/coreml/builders/helper.cc:81
     std::vector<int> batch_size_list = {1, 2, 4, 9, 100};
 #else
diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
index 5fd83ac1ad61b..61b31c8c5c656 100644
--- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -32,7 +32,7 @@ void TestBinaryFloat16(const char* op_name, bool enable_bf16 = true) {
   {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-#ifdef COREML_ENABLE_MLPROGRAM
+#ifdef USE_COREML
     execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
 #elif USE_CUDA
     execution_providers.push_back(DefaultCudaExecutionProvider());
@@ -76,7 +76,7 @@ void TestUnaryFloat16(const char* op_name, bool run_bf16 = true) {
   {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-#ifdef COREML_ENABLE_MLPROGRAM
+#ifdef USE_COREML
     execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
 #elif USE_CUDA
     execution_providers.push_back(DefaultCudaExecutionProvider());
@@ -1409,7 +1409,7 @@ TEST(MathOpTest, Pow_float16_float16) {
           dims, {1.0f, 256.0f, 2.0f, 1.0f}, false);
 }
 
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML)
 TEST(MathOpTest, Pow_float_float16) {
   OpTester test("Pow", 12);
   std::vector<int64_t> dims{4};
@@ -1423,7 +1423,7 @@ TEST(MathOpTest, Pow_float_float16) {
   execution_providers.push_back(DefaultCudaExecutionProvider());
 #elif USE_ROCM
   execution_providers.push_back(DefaultRocmExecutionProvider());
-#elif COREML_ENABLE_MLPROGRAM
+#elif USE_COREML
   execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
 #endif
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc
index 298e870f348fc..dd8cbed15e5ef 100644
--- a/onnxruntime/test/providers/cpu/math/matmul_test.cc
+++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc
@@ -210,7 +210,7 @@ TEST(MathOpTest, MatMulFloatType) {
   RunMatMulTest<float>(7, false, true);
 }
 
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) || defined(USE_XNNPACK)
 TEST(MathOpTest, MatMulFloat16) {
 #ifdef USE_CUDA
   int min_cuda_architecture = 530;
@@ -276,7 +276,7 @@ TEST(MathOpTest, MatMulZeroKInt32Type) {
   RunMatMulZeroKTest<int32_t>();
 }
 
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) || defined(USE_XNNPACK)
 TEST(MathOpTest, MatMul_Float16) {
 #ifdef USE_CUDA
   int min_cuda_architecture = 530;
diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
index 08c4e608aada3..a25e71898a9e8 100644
--- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
@@ -704,7 +704,7 @@ TEST(BatchNormTest, NonSpatial_Complicated) {
 }
 
 // Only CUDA and ROCm kernels have float 16 support
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML)
 TEST(BatchNormTest, BatchNorm2d_fp16) {
   vector<float> X{-0.91221f, -0.283559f, 0.937637f, 2.09818f, -0.100199f, -0.608113f, 0.444562f, -1.07505f, 0.940591f,
                   -0.922262f, 0.0931303f, 0.69611f, 1.55187f, 0.159808f, 0.914874f, -1.24856f, -1.98928f, -0.331621f,
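
The fp16 tests enabled above all follow the same pattern: convert float reference data to
MLFloat16 and run against the CoreML EP via DefaultCoreMLExecutionProvider(true). A hedged
sketch of the conversion step only (MLFloat16 is the ORT half type; ToFloat16 is a
hypothetical helper, not a test utility that exists in the tree):

#include <vector>
#include "core/framework/float16.h"  // onnxruntime::MLFloat16 (assumed header location)

std::vector<onnxruntime::MLFloat16> ToFloat16(const std::vector<float>& values) {
  std::vector<onnxruntime::MLFloat16> out;
  out.reserve(values.size());
  for (float v : values) {
    out.emplace_back(v);  // MLFloat16 is constructible from float
  }
  return out;
}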
diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
index 4253e36e02548..d1350db8ec12e 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
@@ -3,7 +3,7 @@
 
 #include "core/mlas/inc/mlas.h"
 
-#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK)
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML) || defined(USE_XNNPACK)
 
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
@@ -30,7 +30,7 @@ struct ConvOpAndTestAttributes {
 /*
 Please note that we have predefined macros at the head of the file:
-#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML)
When either of these macros is defined, these unit tests are enabled.
If attributes.activation is set the NhwcFusedConv contrib op is used.
diff --git a/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc
index ac517193a2c77..3d8d188867023 100644
--- a/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc
@@ -6,7 +6,7 @@
 #include "test/common/tensor_op_test_utils.h"
 #include "test/util/include/default_providers.h"
 
-#ifdef COREML_ENABLE_MLPROGRAM
+#ifdef USE_COREML
 using namespace std;
 namespace onnxruntime {
 namespace test {
diff --git a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc
index 341bb8a4fc957..46b74f2c2eb9d 100644
--- a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc
@@ -121,7 +121,7 @@ TEST(InstanceNormalizationOpTest, InstanceNormBatch2) {
 }
 
 // Only CUDA and ROCm kernels have float 16 support
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML)
 TEST(InstanceNormalizationOpTest, InstanceNormBatch1_fp16) {
   OpTester test("InstanceNormalization");
diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
index d4e0af5011525..c14fc1fb62ae5 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
@@ -3,7 +3,7 @@
 
 #include "core/mlas/inc/mlas.h"
 
-#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK)
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML) || defined(USE_XNNPACK)
 
 #include "core/providers/cpu/nn/pool.h"
 #include "gtest/gtest.h"
diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
index 24a8c8491b632..f1d612276174f 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
@@ -70,7 +70,7 @@ TEST(PoolTest, MaxPool) {
 
 // Only CUDA kernel has float 16 support
 // Disable for now, still investigating the issue with cudnn lib
-#if defined(USE_CUDA) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_CUDA) || defined(USE_COREML)
 TEST(PoolTest, MaxPool_F16) {
 #if defined(USE_CUDA)
   int min_cuda_architecture = 530;
diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
index 61a16d41e3e59..19bfc0f90709f 100644
--- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
@@ -1375,7 +1375,7 @@ TEST(ReductionOpTest, ReduceMax_double) {
   test.Run();
 }
 
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML)
 TEST(ReductionOpTest, ReduceMax_half) {
   OpTester test("ReduceMax");
   test.AddAttribute("axes", std::vector<int64_t>{1, 2});
@@ -2158,7 +2158,7 @@ TEST(ReductionOpTest, ReduceMin_double) {
   test.Run();
 }
 
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML)
 TEST(ReductionOpTest, ReduceMin_half) {
   OpTester test("ReduceMin");
   test.AddAttribute("axes", std::vector<int64_t>{0, 2});
@@ -2356,7 +2356,7 @@ TEST(ReductionOpTest, ReduceSum_int32) {
   test.Run();
 }
 
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML)
 TEST(ReductionOpTest, ReduceSumHalfHalf) {
   OpTester test("ReduceSum");
   test.AddAttribute("keepdims", (int64_t)0);
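
End-to-end usage note: after this change, requesting an ML Program model is purely a runtime
flag whenever the CoreML EP is compiled in (USE_COREML); the build-time COREML_ENABLE_MLPROGRAM
gate is gone. A sketch using the public C API flag (header paths and the model name are
placeholders; illustrative only):

#include "onnxruntime_cxx_api.h"
#include "coreml_provider_factory.h"  // COREML_FLAG_CREATE_MLPROGRAM

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // NeuralNetwork remains the fallback when COREML_FLAG_CREATE_MLPROGRAM is not set.
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(so, COREML_FLAG_CREATE_MLPROGRAM));
  Ort::Session session(env, "model.onnx", so);
  return 0;
}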