Skip to content

Commit

Permalink
copyfromhost/copytohost are not needed for mkldnn ep (#1532)
Browse files Browse the repository at this point in the history
* memcpy is not necessary for mkldnn ep to copy from/to host.

* update
  • Loading branch information
linkerzhang authored Aug 1, 2019
1 parent 624411b commit 1cf5ebc
Showing 1 changed file with 3 additions and 28 deletions.
31 changes: 3 additions & 28 deletions onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,18 @@
#pragma warning(disable : 4996)
#endif

#include "mkldnn_execution_provider.h"
#include "core/framework/allocator.h"
#include "core/framework/memcpy.h"
#include "core/framework/kernel_registry.h"
#include "mkldnn_fwd.h"
#include "core/framework/compute_capability.h"
#include "core/framework/kernel_registry.h"
#include "core/providers/mkldnn/subgraph/mkldnn_func_kernel.h"
#include "mkldnn_execution_provider.h"
#include "mkldnn_fwd.h"

namespace onnxruntime {

constexpr const char* MKLDNN = "MklDnn";
constexpr const char* MKLDNN_CPU = "MklDnnCpu";

namespace mkl_dnn {

ONNX_OPERATOR_KERNEL_EX(
MemcpyFromHost,
kOnnxDomain,
1,
kMklDnnExecutionProvider,
KernelDefBuilder().InputMemoryType<OrtMemTypeCPUInput>(0).TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
Memcpy);

ONNX_OPERATOR_KERNEL_EX(
MemcpyToHost,
kOnnxDomain,
1,
kMklDnnExecutionProvider,
KernelDefBuilder().OutputMemoryType<OrtMemTypeCPUOutput>(0).TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
Memcpy);

} // namespace mkl_dnn

MKLDNNExecutionProvider::MKLDNNExecutionProvider(const MKLDNNExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kMklDnnExecutionProvider} {
DeviceAllocatorRegistrationInfo default_allocator_info({OrtMemTypeDefault,
Expand Down Expand Up @@ -65,8 +44,6 @@ MKLDNNExecutionProvider::~MKLDNNExecutionProvider() {
namespace mkl_dnn {
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 7, Gemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 6, Relu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 6, Sum);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 7, BatchNormalization);
Expand All @@ -81,8 +58,6 @@ void RegisterMKLDNNKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 7, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, MemcpyToHost)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 6, Relu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 6, Sum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 7, BatchNormalization)>,
Expand Down

0 comments on commit 1cf5ebc

Please sign in to comment.