Skip to content

Commit

Permalink
xpu elementwise_add int32 (#5539)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang authored Mar 9, 2021
1 parent c20e240 commit 40c3744
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
2 changes: 2 additions & 0 deletions lite/kernels/host/range_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ REGISTER_LITE_KERNEL(range, kHost, kInt32, kAny, range_int32, def)
DATALAYOUT(kAny))})
.Finalize();

#ifdef LITE_BUILD_EXTRA
// float kernel has higher score when picking kernel.
using range_int32_f =
paddle::lite::kernels::host::RangeCompute<int, PRECISION(kFloat)>;
Expand All @@ -122,3 +123,4 @@ REGISTER_LITE_KERNEL(range, kHost, kFloat, kAny, range_int32_f, int32)
PRECISION(kInt32),
DATALAYOUT(kAny))})
.Finalize();
#endif // LITE_BUILD_EXTRA
34 changes: 23 additions & 11 deletions lite/kernels/xpu/elementwise_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ namespace lite {
namespace kernels {
namespace xpu {

void ElementwiseAddCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
template <class T>
void ElementwiseAddCompute<T>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();

auto& x_dim = param.X->dims();
auto& y_dim = param.Y->dims();
Expand All @@ -47,12 +48,12 @@ void ElementwiseAddCompute::Run() {
}

int ret =
xdnn::broadcast_add<float>(ctx.GetRawContext(),
param.X->data<float>(),
param.Y->data<float>(),
param.Out->mutable_data<float>(TARGET(kXPU)),
x_shape,
y_shape);
xdnn::broadcast_add<T>(ctx.GetRawContext(),
param.X->template data<T>(),
param.Y->template data<T>(),
param.Out->template mutable_data<T>(TARGET(kXPU)),
x_shape,
y_shape);

CHECK_EQ(ret, 0);
return;
Expand Down Expand Up @@ -163,13 +164,24 @@ REGISTER_LITE_KERNEL(elementwise_add,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::ElementwiseAddCompute,
def)
paddle::lite::kernels::xpu::ElementwiseAddCompute<float>,
float32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

REGISTER_LITE_KERNEL(elementwise_add,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::ElementwiseAddCompute<int>,
int32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.Finalize();

REGISTER_LITE_KERNEL(elementwise_mul,
kXPU,
kFloat,
Expand Down
1 change: 1 addition & 0 deletions lite/kernels/xpu/elementwise_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace lite {
namespace kernels {
namespace xpu {

template <class T>
class ElementwiseAddCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
Expand Down

0 comments on commit 40c3744

Please sign in to comment.