Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1. Enabling inline for Polygeist Ops 2. fix for unary pre decrement op to have lvalue by reference #382

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
36 changes: 36 additions & 0 deletions lib/polygeist/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,46 @@

#include "polygeist/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "polygeist/Ops.h"

using namespace mlir;
using namespace mlir::polygeist;

//===----------------------------------------------------------------------===//
// PolygeistDialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for handling inlining with polygeist
/// operations.
struct PolygeistInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;

//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//

/// Returns true if the given region 'src' can be inlined into the region
/// 'dest' that is attached to an operation registered to the current dialect.
/// 'wouldBeCloned' is set if the region is cloned into its new location
/// rather than moved, indicating there may be other users.
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
IRMapping &valueMapping) const final {
return true;
}

/// Returns true if the given operation 'op', that is registered to this
/// dialect, can be inlined into the given region, false otherwise.
bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
IRMapping &valueMapping) const final {
return true;
}

/// Polygeist regions should be analyzed recursively.
bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
};
} // namespace
//===----------------------------------------------------------------------===//
// Polygeist dialect.
//===----------------------------------------------------------------------===//
Expand All @@ -22,6 +57,7 @@ void PolygeistDialect::initialize() {
#define GET_OP_LIST
#include "polygeist/PolygeistOps.cpp.inc"
>();
addInterfaces<PolygeistInlinerInterface>();
}

#include "polygeist/PolygeistOpsDialect.cpp.inc"
71 changes: 71 additions & 0 deletions test/polygeist-opt/inline.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: polygeist-opt --inline --allow-unregistered-dialect %s | FileCheck %s

module {
func.func @eval_cost(%arg0: !llvm.ptr, %arg1: i1) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>> {
%1 = "polygeist.pointer2memref"(%arg0) : (!llvm.ptr) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>>
return %1 : memref<1x!llvm.struct<(f32, array<3 x f32>)>>
}
func.func @eval_res(%arg0: memref<1x!llvm.struct<(f32, array<3 x f32>)>>) -> !llvm.struct<(f32, array<3 x f32>)>
{
%c0_i1 = arith.constant 0 : i1
%alloca = memref.alloca() : memref<1x!llvm.struct<(f32, array<3 x f32>)>>
%0 = "polygeist.memref2pointer"(%alloca) : (memref<1x!llvm.struct<(f32, array<3 x f32>)>>) -> !llvm.ptr
%1 = func.call @eval_cost(%0, %c0_i1) : (!llvm.ptr, i1) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>>
%2 = affine.load %1[0] : memref<1x!llvm.struct<(f32, array<3 x f32>)>>
return %2 : !llvm.struct<(f32, array<3 x f32>)>
}

// CHECK-LABEL: func.func @eval_cost(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
// CHECK-SAME: %[[VAL_1:.*]]: i1) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>> {
// CHECK: %[[VAL_2:.*]] = "polygeist.pointer2memref"(%[[VAL_0]]) : (!llvm.ptr) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>>
// CHECK: return %[[VAL_2]] : memref<1x!llvm.struct<(f32, array<3 x f32>)>>
// CHECK: }

// CHECK-LABEL: func.func @eval_res(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x!llvm.struct<(f32, array<3 x f32>)>>) -> !llvm.struct<(f32, array<3 x f32>)> {
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<!llvm.struct<(f32, array<3 x f32>)>>
// CHECK: %[[VAL_2:.*]] = affine.load %[[VAL_1]][] : memref<!llvm.struct<(f32, array<3 x f32>)>>
// CHECK: return %[[VAL_2]] : !llvm.struct<(f32, array<3 x f32>)>
// CHECK: }

func.func private @use(%arg0: index, %arg1: index) -> index{
%0 = arith.addi %arg0, %arg1 : index
return %0 : index
}

func.func @f1(%gd : index, %bd : index) {
%mc0 = arith.constant 0 : index
%mc4 = arith.constant 4 : index
%mc1024 = arith.constant 1024 : index
%err = "polygeist.gpu_wrapper"() ({
affine.parallel (%a1, %a2, %a3) = (0, 0, 0) to (%gd, %mc4, %bd) {
"polygeist.noop"(%a3, %mc0, %mc0) {polygeist.noop_type="gpu_kernel.thread_only"} : (index, index, index) -> ()
%a1r = func.call @use(%a1,%mc4) : (index, index) -> (index)
%a2r = func.call @use(%a2,%a1r) : (index, index) -> (index)
%a3r = func.call @use(%a3,%a2r) : (index, index) -> (index)
"test.something"(%a3r) : (index) -> ()
}
"polygeist.polygeist_yield"() : () -> ()
}) : () -> index
return
}
// CHECK-LABEL: func.func @f1(
// CHECK-SAME: %[[VAL_0:.*]]: index,
// CHECK-SAME: %[[VAL_1:.*]]: index) {
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_4:.*]] = "polygeist.gpu_wrapper"() ({
// CHECK: affine.parallel (%[[VAL_5:.*]], %[[VAL_6:.*]], %[[VAL_7:.*]]) = (0, 0, 0) to (symbol(%[[VAL_0]]), 4, symbol(%[[VAL_1]])) {
// CHECK: "polygeist.noop"(%[[VAL_7]], %[[VAL_2]], %[[VAL_2]]) {polygeist.noop_type = "gpu_kernel.thread_only"} : (index, index, index) -> ()
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_5]], %[[VAL_3]] : index
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_6]], %[[VAL_8]] : index
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_7]], %[[VAL_9]] : index
// CHECK: "test.something"(%[[VAL_10]]) : (index) -> ()
// CHECK: }
// CHECK: "polygeist.polygeist_yield"() : () -> ()
// CHECK: }) : () -> index
// CHECK: return
// CHECK: }

}
8 changes: 4 additions & 4 deletions tools/cgeist/Lib/clang-mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2323,10 +2323,10 @@ ValueCategory MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) {
builder.create<ConstantIntOp>(loc, 1, ty.cast<mlir::IntegerType>()));
}
sub.store(loc, builder, next);
return ValueCategory(
(U->getOpcode() == clang::UnaryOperator::Opcode::UO_PostDec) ? prev
: next,
/*isReference*/ false);
if (U->getOpcode() == clang::UnaryOperator::Opcode::UO_PreDec)
return sub;
else
return ValueCategory(prev, /*isReference*/ false);
}
case clang::UnaryOperator::Opcode::UO_Real:
case clang::UnaryOperator::Opcode::UO_Imag: {
Expand Down
2 changes: 1 addition & 1 deletion tools/cgeist/Test/Verification/arrayconsllvm.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: cgeist %s --function=* -S | FileCheck %s
// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s

struct AIntDivider {
AIntDivider() : divisor(3) {}
Expand Down
2 changes: 1 addition & 1 deletion tools/cgeist/Test/Verification/arrayconsmemrefinner.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: cgeist %s --function=* -S | FileCheck %s
// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s

struct AIntDivider {
AIntDivider() : divisor(3) {}
Expand Down
4 changes: 2 additions & 2 deletions tools/cgeist/Test/Verification/base_cast.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: cgeist %s --function=* -S | FileCheck %s
// RUN: cgeist %s --function=* --struct-abi=0 -memref-abi=0 -S | FileCheck %s --check-prefix CHECK-STR
// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s
// RUN: cgeist %s --no-inline --function=* --struct-abi=0 -memref-abi=0 -S | FileCheck %s --check-prefix CHECK-STR


struct A {
Expand Down
2 changes: 1 addition & 1 deletion tools/cgeist/Test/Verification/base_nostructabi.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: cgeist %s --function=* --struct-abi=0 -memref-abi=0 -S | FileCheck %s
// RUN: cgeist %s --no-inline --function=* --struct-abi=0 -memref-abi=0 -S | FileCheck %s

void run0(void*);
void run1(void*);
Expand Down
2 changes: 1 addition & 1 deletion tools/cgeist/Test/Verification/base_with_virt.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: cgeist %s --function=* -S | FileCheck %s
// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s

class M {
};
Expand Down
2 changes: 1 addition & 1 deletion tools/cgeist/Test/Verification/caff.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: cgeist %s --function=* -S | FileCheck %s
// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s

struct AOperandInfo {
void* data;
Expand Down
2 changes: 1 addition & 1 deletion tools/cgeist/Test/Verification/capture.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: cgeist %s --function=* -S | FileCheck %s
// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s

extern "C" {

Expand Down
60 changes: 34 additions & 26 deletions tools/cgeist/Test/Verification/consabi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,50 @@ QStream ilaunch_kernel(QStream x) {
}

// CHECK-LABEL: func.func @_Z14ilaunch_kernel7QStream(
// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: !llvm.struct<(struct<(f64, f64)>, i32)>) -> !llvm.struct<(struct<(f64, f64)>, i32)>
// CHECK: %[[VAL_1:[A-Za-z0-9_]*]] = memref.alloca() : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = memref.cast %[[VAL_1]] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> to memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = memref.alloca() : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: %[[VAL_4:[A-Za-z0-9_]*]] = memref.cast %[[VAL_3]] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> to memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: affine.store %[[VAL_0]], %[[VAL_3]][0] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: call @_ZN7QStreamC1EOS_(%[[VAL_2]], %[[VAL_4]]) : (memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>, memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>) -> ()
// CHECK: %[[VAL_5:[A-Za-z0-9_]*]] = affine.load %[[VAL_1]][0] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: return %[[VAL_5]] : !llvm.struct<(struct<(f64, f64)>, i32)>
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(struct<(f64, f64)>, i32)>) -> !llvm.struct<(struct<(f64, f64)>, i32)> attributes {llvm.linkage = #llvm.linkage<external>} {
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: %[[VAL_2:.*]] = memref.alloca() : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: affine.store %[[VAL_0]], %[[VAL_2]][0] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: %[[VAL_3:.*]] = "polygeist.memref2pointer"(%[[VAL_1]]) : (memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>) -> !llvm.ptr
// CHECK: %[[VAL_4:.*]] = "polygeist.memref2pointer"(%[[VAL_2]]) : (memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>) -> !llvm.ptr
// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr -> f64
// CHECK: llvm.store %[[VAL_5]], %[[VAL_3]] : f64, !llvm.ptr
// CHECK: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_4]][1] : (!llvm.ptr) -> !llvm.ptr, f64
// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> f64
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_3]][1] : (!llvm.ptr) -> !llvm.ptr, f64
// CHECK: llvm.store %[[VAL_7]], %[[VAL_8]] : f64, !llvm.ptr
// CHECK: %[[VAL_9:.*]] = llvm.getelementptr %[[VAL_4]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)>
// CHECK: %[[VAL_10:.*]] = llvm.load %[[VAL_9]] : !llvm.ptr -> i32
// CHECK: %[[VAL_11:.*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)>
// CHECK: llvm.store %[[VAL_10]], %[[VAL_11]] : i32, !llvm.ptr
// CHECK: %[[VAL_12:.*]] = affine.load %[[VAL_1]][0] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>
// CHECK: return %[[VAL_12]] : !llvm.struct<(struct<(f64, f64)>, i32)>
// CHECK: }

// CHECK-LABEL: func.func @_ZN7QStreamC1EOS_(
// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>,
// CHECK-SAME: %[[VAL_1:[A-Za-z0-9_]*]]: memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>)
// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = "polygeist.memref2pointer"(%[[VAL_0]]) : (memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>) -> !llvm.ptr
// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = "polygeist.memref2pointer"(%[[VAL_1]]) : (memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>) -> !llvm.ptr
// CHECK: %[[VAL_4:[A-Za-z0-9_]*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> f64
// CHECK-SAME: %[[VAL_0:.*]]: memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[VAL_2:.*]] = "polygeist.memref2pointer"(%[[VAL_0]]) : (memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>) -> !llvm.ptr
// CHECK: %[[VAL_3:.*]] = "polygeist.memref2pointer"(%[[VAL_1]]) : (memref<?x!llvm.struct<(struct<(f64, f64)>, i32)>>) -> !llvm.ptr
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> f64
// CHECK: llvm.store %[[VAL_4]], %[[VAL_2]] : f64, !llvm.ptr
// CHECK: %[[VAL_5:[A-Za-z0-9_]*]] = llvm.getelementptr %[[VAL_3]][1] : (!llvm.ptr) -> !llvm.ptr, f64
// CHECK: %[[VAL_6:[A-Za-z0-9_]*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> f64
// CHECK: %[[VAL_7:[A-Za-z0-9_]*]] = llvm.getelementptr %[[VAL_2]][1] : (!llvm.ptr) -> !llvm.ptr, f64
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_3]][1] : (!llvm.ptr) -> !llvm.ptr, f64
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> f64
// CHECK: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_2]][1] : (!llvm.ptr) -> !llvm.ptr, f64
// CHECK: llvm.store %[[VAL_6]], %[[VAL_7]] : f64, !llvm.ptr
// CHECK: %[[VAL_8:[A-Za-z0-9_]*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)>
// CHECK: %[[VAL_9:[A-Za-z0-9_]*]] = llvm.load %[[VAL_8]] : !llvm.ptr -> i32
// CHECK: %[[VAL_10:[A-Za-z0-9_]*]] = llvm.getelementptr %[[VAL_2]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)>
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr -> i32
// CHECK: %[[VAL_10:.*]] = llvm.getelementptr %[[VAL_2]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)>
// CHECK: llvm.store %[[VAL_9]], %[[VAL_10]] : i32, !llvm.ptr
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @_ZN1DC1EOS_(
// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: memref<?x2xf64>,
// CHECK-SAME: %[[VAL_1:[A-Za-z0-9_]*]]: memref<?x2xf64>)
// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = affine.load %[[VAL_1]][0, 0] : memref<?x2xf64>
// CHECK-SAME: %[[VAL_0:.*]]: memref<?x2xf64>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<?x2xf64>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[VAL_2:.*]] = affine.load %[[VAL_1]][0, 0] : memref<?x2xf64>
// CHECK: affine.store %[[VAL_2]], %[[VAL_0]][0, 0] : memref<?x2xf64>
// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = affine.load %[[VAL_1]][0, 1] : memref<?x2xf64>
// CHECK: %[[VAL_3:.*]] = affine.load %[[VAL_1]][0, 1] : memref<?x2xf64>
// CHECK: affine.store %[[VAL_3]], %[[VAL_0]][0, 1] : memref<?x2xf64>
// CHECK: return
// CHECK: }

// CHECK: }
8 changes: 4 additions & 4 deletions tools/cgeist/Test/Verification/cugen.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ void start(double* w) {
}

// CHECK: func.func @_Z5startPd(%arg0: memref<?xf64>)
// CHECK-NEXT: %cst = arith.constant 2.000000e+00 : f64
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c20 = arith.constant 20 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-DAG: %cst = arith.constant 2.000000e+00 : f64
// CHECK-DAG: %c0 = arith.constant 0 : index
// CHECK-DAG: %c20 = arith.constant 20 : index
// CHECK-DAG: %c1 = arith.constant 1 : index
// CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c20) step (%c1) {
// CHECK-NEXT: memref.store %cst, %arg0[%arg1] : memref<?xf64>
// CHECK-NEXT: scf.yield
Expand Down
Loading
Loading