ACL EP GEMM improvements (#2780)
When possible, we use a fully connected layer instead of the GEMM implementation.
This lets the library pick the best implementation based on the input data.
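
The change hinges on the Gemm attributes: ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, and when alpha is 1 and beta is 0 or 1 this reduces to a plain matrix product with an optional bias, which is exactly what a fully connected layer computes. Below is a minimal standalone sketch of that check; the helper name can_use_fully_connected is illustrative only, the diff evaluates the same expression inline into the local flag FC.

// Sketch only: mirrors the FC condition introduced by this commit, kept
// free of ONNX Runtime / ACL headers so it compiles on its own.
#include <iostream>

// Hypothetical helper; the real kernel computes the same expression inline.
bool can_use_fully_connected(float alpha, float beta) {
  // Gemm: Y = alpha * op(A) * op(B) + beta * C.
  // alpha == 1 keeps the product unscaled; beta == 0 (no bias) or
  // beta == 1 (bias added as-is) matches a fully connected layer.
  return alpha == 1.0f && (beta == 1.0f || beta == 0.0f);
}

int main() {
  std::cout << std::boolalpha
            << can_use_fully_connected(1.0f, 1.0f) << "\n"   // true  -> NEFullyConnectedLayer
            << can_use_fully_connected(1.0f, 0.0f) << "\n"   // true  -> NEFullyConnectedLayer, no bias
            << can_use_fully_connected(0.5f, 1.0f) << "\n";  // false -> NEGEMM (or the CPU Gemm fallback)
  return 0;
}

When the condition holds, the diff below also skips the explicit NETranspose pass and imports W directly, presumably because the fully connected layer handles the weight layout itself.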
Andrews548 authored and jywu-msft committed Jan 7, 2020
1 parent f22bffe commit fdc0106
Showing 1 changed file with 20 additions and 38 deletions.
58 changes: 20 additions & 38 deletions onnxruntime/core/providers/acl/math/gemm.h
@@ -21,6 +21,7 @@
 //NEON
 #include "arm_compute/runtime/NEON/functions/NEGEMM.h"
 #include "arm_compute/runtime/NEON/functions/NETranspose.h"
+#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"

 #undef GEMM_ACL
 #define CACHE_TRANSPOSED_DATA
@@ -29,7 +30,7 @@ namespace onnxruntime {
 namespace acl {

 typedef struct {
-  std::shared_ptr<arm_compute::NEGEMM> layer;
+  std::shared_ptr<arm_compute::IFunction> layer;
   std::shared_ptr<arm_compute::Tensor> a, b, c, d;
   std::shared_ptr<arm_compute::MemoryManagerOnDemand> mm_layer;
 } ACLNEGEMM;
@@ -51,7 +52,6 @@ class Gemm : public onnxruntime::Gemm<T> {
     ORT_ENFORCE(info.GetAttr<float>("beta", &beta_).IsOK());
   }

-#ifdef GEMM_ACL
   Status Compute(OpKernelContext* context) const override {
     const auto X = context->Input<Tensor>(0);
     const auto W = context->Input<Tensor>(1);
@@ -66,6 +66,8 @@ class Gemm : public onnxruntime::Gemm<T> {
     int64_t N = helper.N();
     auto Y = context->Output(0, TensorShape({M, N}));

+    bool FC = ((alpha_ == 1 && beta_ == 1) || (alpha_ == 1 && beta_ == 0));
+
     int64_t K = helper.K();
     LOGS_DEFAULT(VERBOSE) << "Gemm ACL:" << std::endl;
     if (X) LOGS_DEFAULT(VERBOSE) << "X " << X->Shape().ToString().c_str() << std::endl;
@@ -93,7 +95,7 @@
       tGEMM.d->allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(N, M), tGEMM.a->info()->format()));

       // transpose
-      if (trans_B_ == CblasTrans) {
+      if (!FC && trans_B_ == CblasTrans) {
         auto trans_layer = std::make_shared<arm_compute::NETranspose>();
         tGEMM.b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32));

@@ -113,13 +115,23 @@
       }

       tGEMM.mm_layer = ACLCreateMemoryManager();
-      tGEMM.layer = std::make_shared<arm_compute::NEGEMM>(tGEMM.mm_layer);

-      // configure GEMM
-      tGEMM.layer->configure(tGEMM.a.get(), tGEMM.b.get(), tGEMM.c.get(), tGEMM.d.get(), alpha_, beta_, arm_compute::GEMMInfo());
+      if(FC) {
+        auto layer = std::make_shared<arm_compute::NEFullyConnectedLayer>(tGEMM.mm_layer);
+        layer->configure(tGEMM.a.get(), tGEMM.b.get(), (B != nullptr && beta_ != 0) ? tGEMM.c.get() : nullptr, tGEMM.d.get());
+        tGEMM.layer = std::move(layer);
+      } else {
+#ifdef GEMM_ACL
+        auto layer = std::make_shared<arm_compute::NEGEMM>(tGEMM.mm_layer);
+        layer->configure(tGEMM.a.get(), tGEMM.b.get(), (B != nullptr && beta_ != 0) ? tGEMM.c.get() : nullptr, tGEMM.d.get(), alpha_, beta_, arm_compute::GEMMInfo());
+        tGEMM.layer = std::move(layer);
+#else
+        return onnxruntime::Gemm<T>::Compute(context);
+#endif
+      }

       // non-transpose
-      if (trans_B_ != CblasTrans) {
+      if (FC || trans_B_ != CblasTrans) {
         const T* b_data = W->template Data<T>();
         ACLImportMemory(tGEMM.b->allocator(), (void*)b_data, W->Shape().Size() * 4);
       }
@@ -132,7 +144,7 @@ class Gemm : public onnxruntime::Gemm<T> {
       pGEMM = &it->second;

       // transpose
-      if (trans_B_ == CblasTrans) {
+      if (!FC && trans_B_ == CblasTrans) {
 #ifndef CACHE_TRANSPOSED_DATA
         auto trans_layer = std::make_shared<arm_compute::NETranspose>();
         pGEMM->b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32));
@@ -193,36 +205,6 @@ class Gemm : public onnxruntime::Gemm<T> {
     gemmLayers.erase(this);
   }

-#else
-
-  Status Compute(OpKernelContext* context) const override {
-    const auto X = context->Input<Tensor>(0);
-    const auto W = context->Input<Tensor>(1);
-    const auto B = context->Input<Tensor>(2);
-
-    GemmHelper helper(X->Shape(), trans_A_ != CblasNoTrans, W->Shape(), trans_B_ != CblasNoTrans, B->Shape());
-
-    if (!helper.State().IsOK())
-      return helper.State();
-
-    int64_t M = helper.M();
-    int64_t N = helper.N();
-    auto Y = context->Output(0, TensorShape({M, N}));
-
-    int64_t K = helper.K();
-    LOGS_DEFAULT(VERBOSE) << "Gemm CPU:" << std::endl;
-    if (X) LOGS_DEFAULT(VERBOSE) << "X " << X->Shape().ToString().c_str() << std::endl;
-    if (W) LOGS_DEFAULT(VERBOSE) << "W " << W->Shape().ToString().c_str() << std::endl;
-    if (B) LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str() << std::endl;
-    LOGS_DEFAULT(VERBOSE) << "Y " << Y->Shape().ToString().c_str() << std::endl;
-    LOGS_DEFAULT(VERBOSE) << "M " << (int)M << ", N " << (int)N << ", K " << (int)K << std::endl;
-    LOGS_DEFAULT(VERBOSE) << "Alfa " << alpha_ << ", Beta " << beta_ << std::endl;
-    LOGS_DEFAULT(VERBOSE) << std::endl;
-
-    return onnxruntime::Gemm<T>::Compute(context);
-  }
-#endif
-
 private:
  static thread_local std::map<OpKernel*, ACLNEGEMM> gemmLayers;


