From 334989fd7886ce7eb04076b44b2fa27846fd2a77 Mon Sep 17 00:00:00 2001 From: Nikolas Klauser Date: Sat, 3 Feb 2024 00:48:49 +0100 Subject: [PATCH] [CIR][LibOpt] Extend std::find optimization to all calls with raw pointers (#400) This also adds a missing check whether the pointer returned from `memchr` is null and changes the result to `last` in that case. --- clang/lib/CIR/Dialect/Transforms/LibOpt.cpp | 105 ++++++++++++-------- clang/test/CIR/Transforms/lib-opt-find.cpp | 50 ++++++++-- 2 files changed, 110 insertions(+), 45 deletions(-) diff --git a/clang/lib/CIR/Dialect/Transforms/LibOpt.cpp b/clang/lib/CIR/Dialect/Transforms/LibOpt.cpp index 2422613a5315..762ee961bcba 100644 --- a/clang/lib/CIR/Dialect/Transforms/LibOpt.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LibOpt.cpp @@ -120,29 +120,18 @@ static bool containerHasStaticSize(StructType t, unsigned &size) { } void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) { - // First and second operands need to be iterators begin() and end(). - // TODO: look over cir.loads until we have a mem2reg + other passes - // to help out here. - auto iterBegin = dyn_cast(findOp.getOperand(0).getDefiningOp()); - if (!iterBegin) - return; - if (!isa(findOp.getOperand(1).getDefiningOp())) - return; - - // Both operands have the same type, use iterBegin. - - // Look at this pointer to retrieve container information. - auto thisPtr = - iterBegin.getOperand().getType().cast().getPointee(); - auto containerTy = dyn_cast(thisPtr); - if (!containerTy) - return; - - if (!isSequentialContainer(containerTy)) - return; - - unsigned staticSize = 0; - if (!containerHasStaticSize(containerTy, staticSize)) + // template + // requires (sizeof(T) == 1 && is_integral_v) + // T* find(T* first, T* last, T value) { + // if (auto result = __builtin_memchr(first, value, last - first)) + // return result; + // return last; + // } + + auto first = findOp.getOperand(0); + auto last = findOp.getOperand(1); + auto value = findOp->getOperand(2); + if (!first.getType().isa() || !last.getType().isa()) return; // Transformation: @@ -150,9 +139,9 @@ void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) { // - Assert the Iterator is a pointer to primitive type. // - Check IterBeginOp is char sized. TODO: add other types that map to // char size. - auto iterResTy = iterBegin.getResult().getType().dyn_cast(); + auto iterResTy = findOp.getType().dyn_cast(); assert(iterResTy && "expected pointer type for iterator"); - auto underlyingDataTy = iterResTy.getPointee().dyn_cast(); + auto underlyingDataTy = iterResTy.getPointee().dyn_cast(); if (!underlyingDataTy || underlyingDataTy.getWidth() != 8) return; @@ -160,7 +149,7 @@ void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) { // - Check it's a pointer type. // - Load the pattern from memory // - cast it to `int`. - auto patternAddrTy = findOp.getOperand(2).getType().dyn_cast(); + auto patternAddrTy = value.getType().dyn_cast(); if (!patternAddrTy || patternAddrTy.getPointee() != underlyingDataTy) return; @@ -169,27 +158,65 @@ void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) { CIRBaseBuilderTy builder(getContext()); builder.setInsertionPointAfter(findOp.getOperation()); - auto memchrOp0 = builder.createBitcast( - iterBegin.getLoc(), iterBegin.getResult(), builder.getVoidPtrTy()); + auto memchrOp0 = + builder.createBitcast(first.getLoc(), first, builder.getVoidPtrTy()); // FIXME: get datalayout based "int" instead of fixed size 4. - auto loadPattern = builder.create( - findOp.getOperand(2).getLoc(), underlyingDataTy, findOp.getOperand(2)); + auto loadPattern = + builder.create(value.getLoc(), underlyingDataTy, value); auto memchrOp1 = builder.createIntCast( loadPattern, IntType::get(builder.getContext(), 32, true)); - // FIXME: get datalayout based "size_t" instead of fixed size 64. - auto uInt64Ty = IntType::get(builder.getContext(), 64, false); - auto memchrOp2 = builder.create( - findOp.getLoc(), uInt64Ty, mlir::cir::IntAttr::get(uInt64Ty, staticSize)); + const auto uInt64Ty = IntType::get(builder.getContext(), 64, false); // Build memchr op: // void *memchr(const void *s, int c, size_t n); - auto memChr = builder.create(findOp.getLoc(), memchrOp0, memchrOp1, - memchrOp2); - mlir::Operation *result = - builder.createBitcast(findOp.getLoc(), memChr.getResult(), iterResTy) - .getDefiningOp(); + auto memChr = [&] { + if (auto iterBegin = dyn_cast(first.getDefiningOp()); + iterBegin && isa(last.getDefiningOp())) { + // Both operands have the same type, use iterBegin. + + // Look at this pointer to retrieve container information. + auto thisPtr = + iterBegin.getOperand().getType().cast().getPointee(); + auto containerTy = dyn_cast(thisPtr); + + unsigned staticSize = 0; + if (containerTy && isSequentialContainer(containerTy) && + containerHasStaticSize(containerTy, staticSize)) { + return builder.create( + findOp.getLoc(), memchrOp0, memchrOp1, + builder.create( + findOp.getLoc(), uInt64Ty, + mlir::cir::IntAttr::get(uInt64Ty, staticSize))); + } + } + return builder.create( + findOp.getLoc(), memchrOp0, memchrOp1, + builder.create(findOp.getLoc(), uInt64Ty, last, first)); + }(); + + auto MemChrResult = + builder.createBitcast(findOp.getLoc(), memChr.getResult(), iterResTy); + + // if (result) + // return result; + // else + // return last; + auto NullPtr = builder.create( + findOp.getLoc(), first.getType(), ConstPtrAttr::get(first.getType(), 0)); + auto CmpResult = builder.create( + findOp.getLoc(), BoolType::get(builder.getContext()), CmpOpKind::eq, + NullPtr.getRes(), MemChrResult); + + auto result = builder.create( + findOp.getLoc(), CmpResult.getResult(), + [&](mlir::OpBuilder &ob, mlir::Location Loc) { + ob.create(Loc, last); + }, + [&](mlir::OpBuilder &ob, mlir::Location Loc) { + ob.create(Loc, MemChrResult); + }); findOp.replaceAllUsesWith(result); findOp.erase(); diff --git a/clang/test/CIR/Transforms/lib-opt-find.cpp b/clang/test/CIR/Transforms/lib-opt-find.cpp index a1a3f81d065d..4812e72d8037 100644 --- a/clang/test/CIR/Transforms/lib-opt-find.cpp +++ b/clang/test/CIR/Transforms/lib-opt-find.cpp @@ -3,16 +3,18 @@ #include "std-cxx.h" -int test_find(unsigned char n = 3) +int test1(unsigned char n = 3) { + // CHECK: test1 unsigned num_found = 0; // CHECK: %[[pattern_addr:.*]] = cir.alloca !u8i, cir.ptr , ["n" std::array v = {1, 2, 3, 4, 5, 6, 7, 8, 9}; auto f = std::find(v.begin(), v.end(), n); - // CHECK: %[[begin:.*]] = cir.call @_ZNSt5arrayIhLj9EE5beginEv - // CHECK: cir.call @_ZNSt5arrayIhLj9EE3endEv - // CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[begin]] : !cir.ptr), !cir.ptr + + // CHECK: %[[first:.*]] = cir.call @_ZNSt5arrayIhLj9EE5beginEv + // CHECK: %[[last:.*]] = cir.call @_ZNSt5arrayIhLj9EE3endEv + // CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[first]] : !cir.ptr), !cir.ptr // CHECK: %[[load_pattern:.*]] = cir.load %[[pattern_addr]] : cir.ptr , !u8i // CHECK: %[[pattern:.*]] = cir.cast(integral, %[[load_pattern:.*]] : !u8i), !s32i @@ -20,9 +22,45 @@ int test_find(unsigned char n = 3) // CHECK: %[[array_size:.*]] = cir.const(#cir.int<9> : !u64i) : !u64i // CHECK: %[[result_cast:.*]] = cir.libc.memchr(%[[cast_to_void]], %[[pattern]], %[[array_size]]) - // CHECK: cir.cast(bitcast, %[[result_cast]] : !cir.ptr), !cir.ptr + // CHECK: %[[memchr_res:.*]] = cir.cast(bitcast, %[[result_cast]] : !cir.ptr), !cir.ptr + // CHECK: %[[nullptr:.*]] = cir.const(#cir.ptr : !cir.ptr) : !cir.ptr + // CHECK: %[[cmp_res:.*]] = cir.cmp(eq, %[[nullptr]], %[[memchr_res]]) : !cir.ptr, !cir.bool + // CHECK: cir.ternary(%[[cmp_res]], true { + // CHECK: cir.yield %[[last]] : !cir.ptr + // CHECK: }, false { + // CHECK: cir.yield %[[memchr_res]] : !cir.ptr + // CHECK: }) : (!cir.bool) -> !cir.ptr + if (f != v.end()) num_found++; return num_found; -} \ No newline at end of file +} + +unsigned char* test2(unsigned char* first, unsigned char* last, unsigned char v) +{ + return std::find(first, last, v); + // CHECK: test2 + + // CHECK: %[[first_storage:.*]] = cir.alloca !cir.ptr, cir.ptr >, ["first", init] + // CHECK: %[[last_storage:.*]] = cir.alloca !cir.ptr, cir.ptr >, ["last", init] + // CHECK: %[[pattern_storage:.*]] = cir.alloca !u8i, cir.ptr , ["v", init] + // CHECK: %[[first:.*]] = cir.load %[[first_storage]] + // CHECK: %[[last:.*]] = cir.load %[[last_storage]] + // CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[first]] : !cir.ptr), !cir.ptr + // CHECK: %[[load_pattern:.*]] = cir.load %[[pattern_storage]] : cir.ptr , !u8i + // CHECK: %[[pattern:.*]] = cir.cast(integral, %[[load_pattern:.*]] : !u8i), !s32i + + // CHECK-NOT: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_( + // CHECK: %[[array_size:.*]] = cir.ptr_diff(%[[last]], %[[first]]) : !cir.ptr -> !u64i + + // CHECK: %[[result_cast:.*]] = cir.libc.memchr(%[[cast_to_void]], %[[pattern]], %[[array_size]]) + // CHECK: %[[memchr_res:.*]] = cir.cast(bitcast, %[[result_cast]] : !cir.ptr), !cir.ptr + // CHECK: %[[nullptr:.*]] = cir.const(#cir.ptr : !cir.ptr) : !cir.ptr + // CHECK: %[[cmp_res:.*]] = cir.cmp(eq, %[[nullptr]], %[[memchr_res]]) : !cir.ptr, !cir.bool + // CHECK: cir.ternary(%[[cmp_res]], true { + // CHECK: cir.yield %[[last]] : !cir.ptr + // CHECK: }, false { + // CHECK: cir.yield %[[memchr_res]] : !cir.ptr + // CHECK: }) : (!cir.bool) -> !cir.ptr +}