diff --git a/src/ir/daphneir/DaphneDialect.cpp b/src/ir/daphneir/DaphneDialect.cpp index df0aff2d2..9030c9582 100644 --- a/src/ir/daphneir/DaphneDialect.cpp +++ b/src/ir/daphneir/DaphneDialect.cpp @@ -896,6 +896,20 @@ mlir::OpFoldResult mlir::daphne::ConcatOp::fold(FoldAdaptor adaptor) { return {}; } +mlir::OpFoldResult mlir::daphne::StringEqOp::fold(FoldAdaptor adaptor) { + ArrayRef operands = adaptor.getOperands(); + assert(operands.size() == 2 && "binary op takes two operands"); + if (!operands[0] || !operands[1] || !llvm::isa(operands[0]) || + !isa(operands[1])) { + return {}; + } + + auto lhs = operands[0].cast(); + auto rhs = operands[1].cast(); + + return mlir::BoolAttr::get(getContext(), lhs.getValue() == rhs.getValue()); +} + mlir::OpFoldResult mlir::daphne::EwEqOp::fold(FoldAdaptor adaptor) { ArrayRef operands = adaptor.getOperands(); auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a == b; }; @@ -1151,6 +1165,30 @@ struct SimplifyDistributeRead : public mlir::OpRewritePattern(lhs.getType()); + const bool rhsIsStr = llvm::isa(rhs.getType()); + + if (!lhsIsStr && !rhsIsStr) return mlir::failure(); + + mlir::Type strTy = mlir::daphne::StringType::get(rewriter.getContext()); + if (!lhsIsStr) + lhs = rewriter.create(op.getLoc(), strTy, lhs); + if (!rhsIsStr) + rhs = rewriter.create(op.getLoc(), strTy, rhs); + + rewriter.replaceOpWithNewOp( + op, rewriter.getI1Type(), lhs, rhs); + return mlir::success(); +} + /** * @brief Replaces (1) `a + b` by `a concat b`, if `a` or `b` is a string, * and (2) `a + X` by `X + a` (`a` scalar, `X` matrix/frame). diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index 2d846c7ff..35b437916 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -326,6 +326,12 @@ def Daphne_ConcatOp : Daphne_Op<"concat", [DataTypeSca, ValueTypeStr]> { let hasFolder = 1; } +def Daphne_StringEqOp : Daphne_Op<"stringEq", [ValueTypeStr]> { + let arguments = (ins StrScalar:$lhs, StrScalar:$rhs); + let results = (outs BoolScalar:$res); + let hasFolder = 1; +} + // ---------------------------------------------------------------------------- // Comparisons // ---------------------------------------------------------------------------- @@ -337,7 +343,9 @@ class Daphne_EwCmpOp traits = []> //let results = (outs AnyTypeOf<[MatrixOf<[BoolScalar]>, BoolScalar, Unknown]>:$res); } -def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]>; +def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]> { + let hasCanonicalizeMethod = 1; +} def Daphne_EwNeqOp : Daphne_EwCmpOp<"ewNeq", AnyScalar, [Commutative, CUDASupport]>; def Daphne_EwLtOp : Daphne_EwCmpOp<"ewLt" , AnyScalar>; def Daphne_EwLeOp : Daphne_EwCmpOp<"ewLe" , AnyScalar>; diff --git a/src/runtime/local/kernels/StringEq.h b/src/runtime/local/kernels/StringEq.h new file mode 100644 index 000000000..4d3eed900 --- /dev/null +++ b/src/runtime/local/kernels/StringEq.h @@ -0,0 +1,25 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +inline void stringEq(bool *res, const char *lhs, const char *rhs, DCTX(ctx)) { + *res = std::string_view(lhs) == std::string_view(rhs); +} diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 696f4e188..55d1381c3 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -696,6 +696,31 @@ } ] }, + { + "kernelTemplate": { + "header": "StringEq.h", + "opName": "stringEq", + "returnType": "void", + "templateParams": [], + "runtimeParams": [ + { + "type": "bool *", + "name": "res" + }, + { + "type": "const char *", + "name": "lhs" + }, + { + "type": "const char *", + "name": "rhs" + } + ] + }, + "instantiations": [ + [] + ] + }, { "kernelTemplate": { "header": "Concat.h", @@ -4022,4 +4047,4 @@ [] ] } -] \ No newline at end of file +] diff --git a/test/codegen/stringeq.mlir b/test/codegen/stringeq.mlir new file mode 100644 index 000000000..3a5d18fe8 --- /dev/null +++ b/test/codegen/stringeq.mlir @@ -0,0 +1,43 @@ +// RUN: daphne-opt --canonicalize %s | FileCheck %s + +func.func @string_string() { + %0 = "daphne.constant"() {value = "debug"} : () -> !daphne.String + %1 = "daphne.constant"() {value = "debug"} : () -> !daphne.String + // CHECK-NOT: daphne.ewEq + %2 = "daphne.ewEq"(%0, %1) : (!daphne.String, !daphne.String) -> !daphne.String + %3 = "daphne.cast"(%2) : (!daphne.String) -> i1 + "daphne.print"(%0, %3, %3) : (!daphne.String, i1, i1) -> () + "daphne.return"() : () -> () +} + +func.func @string_int() { + // CHECK-NOT: daphne.eqEq + // CHECK: daphne.cast + // CHECK: daphne.stringEq + %0 = "daphne.constant"() {value = "debug"} : () -> !daphne.String + %1 = "daphne.constant"() {value = 5 : si64} : () -> si64 + %2 = "daphne.ewEq"(%0, %1) : (!daphne.String, si64) -> !daphne.String + %3 = "daphne.cast"(%2) : (!daphne.String) -> i1 + "daphne.print"(%0, %3, %3) : (!daphne.String, i1, i1) -> () + "daphne.return"() : () -> () +} + +func.func @int_int_do_not_canonicalize() { + %0 = "daphne.constant"() {value = 2 : si64} : () -> si64 + %1 = "daphne.constant"() {value = 5 : si64} : () -> si64 + %2 = "daphne.ewEq"(%0, %1) : (si64, si64) -> si64 + // CHECK-NOT: daphne.stringEq + %3 = "daphne.cast"(%2) : (si64) -> i1 + scf.if %3 { + %4 = "daphne.constant"() {value = "debug"} : () -> !daphne.String + %5 = "daphne.constant"() {value = true} : () -> i1 + %6 = "daphne.constant"() {value = false} : () -> i1 + "daphne.print"(%4, %5, %6) : (!daphne.String, i1, i1) -> () + } else { + %4 = "daphne.constant"() {value = "release"} : () -> !daphne.String + %5 = "daphne.constant"() {value = true} : () -> i1 + %6 = "daphne.constant"() {value = false} : () -> i1 + "daphne.print"(%4, %5, %6) : (!daphne.String, i1, i1) -> () + } + "daphne.return"() : () -> () +}