Skip to content

Commit

Permalink
Adds support for string equivalence, close daphne-eu#581
Browse files Browse the repository at this point in the history
This patch adds support for equality checks between strings.
Examples:
* string with program argument
```
a = "foo";
if (a == $b)
```
* compare program arguments
```
if ($a == $b)
```
  • Loading branch information
philipportner committed Feb 12, 2024
1 parent 8f5ca5c commit 4c78578
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 2 deletions.
38 changes: 38 additions & 0 deletions src/ir/daphneir/DaphneDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,20 @@ mlir::OpFoldResult mlir::daphne::ConcatOp::fold(FoldAdaptor adaptor) {
return {};
}

mlir::OpFoldResult mlir::daphne::StringEqOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();
assert(operands.size() == 2 && "binary op takes two operands");
if (!operands[0] || !operands[1] || !llvm::isa<StringAttr>(operands[0]) ||
!isa<StringAttr>(operands[1])) {
return {};
}

auto lhs = operands[0].cast<StringAttr>();
auto rhs = operands[1].cast<StringAttr>();

return mlir::BoolAttr::get(getContext(), lhs.getValue() == rhs.getValue());
}

mlir::OpFoldResult mlir::daphne::EwEqOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();
auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a == b; };
Expand Down Expand Up @@ -1151,6 +1165,30 @@ struct SimplifyDistributeRead : public mlir::OpRewritePattern<mlir::daphne::Dist
}
};

// The EwBinarySca kernel does not handle string types in any way. In order to
// support simple string equivalence checks this canonicalizer rewrites the
// EwEqOp to the StringEqOp if one of the operands is of daphne::StringType.
mlir::LogicalResult mlir::daphne::EwEqOp::canonicalize(
mlir::daphne::EwEqOp op, PatternRewriter &rewriter) {
mlir::Value lhs = op.getLhs();
mlir::Value rhs = op.getRhs();

const bool lhsIsStr = llvm::isa<mlir::daphne::StringType>(lhs.getType());
const bool rhsIsStr = llvm::isa<mlir::daphne::StringType>(rhs.getType());

if (!lhsIsStr && !rhsIsStr) return mlir::failure();

mlir::Type strTy = mlir::daphne::StringType::get(rewriter.getContext());
if (!lhsIsStr)
lhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, lhs);
if (!rhsIsStr)
rhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, rhs);

rewriter.replaceOpWithNewOp<mlir::daphne::StringEqOp>(
op, rewriter.getI1Type(), lhs, rhs);
return mlir::success();
}

/**
* @brief Replaces (1) `a + b` by `a concat b`, if `a` or `b` is a string,
* and (2) `a + X` by `X + a` (`a` scalar, `X` matrix/frame).
Expand Down
10 changes: 9 additions & 1 deletion src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ def Daphne_ConcatOp : Daphne_Op<"concat", [DataTypeSca, ValueTypeStr]> {
let hasFolder = 1;
}

def Daphne_StringEqOp : Daphne_Op<"stringEq", [ValueTypeStr]> {
let arguments = (ins StrScalar:$lhs, StrScalar:$rhs);
let results = (outs BoolScalar:$res);
let hasFolder = 1;
}

// ----------------------------------------------------------------------------
// Comparisons
// ----------------------------------------------------------------------------
Expand All @@ -337,7 +343,9 @@ class Daphne_EwCmpOp<string name, Type inputScalarType, list<Trait> traits = []>
//let results = (outs AnyTypeOf<[MatrixOf<[BoolScalar]>, BoolScalar, Unknown]>:$res);
}

def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]>;
def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]> {
let hasCanonicalizeMethod = 1;
}
def Daphne_EwNeqOp : Daphne_EwCmpOp<"ewNeq", AnyScalar, [Commutative, CUDASupport]>;
def Daphne_EwLtOp : Daphne_EwCmpOp<"ewLt" , AnyScalar>;
def Daphne_EwLeOp : Daphne_EwCmpOp<"ewLe" , AnyScalar>;
Expand Down
25 changes: 25 additions & 0 deletions src/runtime/local/kernels/StringEq.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright 2024 The DAPHNE Consortium
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <runtime/local/context/DaphneContext.h>

#include <string_view>

inline void stringEq(bool *res, const char *lhs, const char *rhs, DCTX(ctx)) {
*res = std::string_view(lhs) == std::string_view(rhs);
}
27 changes: 26 additions & 1 deletion src/runtime/local/kernels/kernels.json
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,31 @@
}
]
},
{
"kernelTemplate": {
"header": "StringEq.h",
"opName": "stringEq",
"returnType": "void",
"templateParams": [],
"runtimeParams": [
{
"type": "bool *",
"name": "res"
},
{
"type": "const char *",
"name": "lhs"
},
{
"type": "const char *",
"name": "rhs"
}
]
},
"instantiations": [
[]
]
},
{
"kernelTemplate": {
"header": "Concat.h",
Expand Down Expand Up @@ -4022,4 +4047,4 @@
[]
]
}
]
]
43 changes: 43 additions & 0 deletions test/codegen/stringeq.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: daphne-opt --canonicalize %s | FileCheck %s

func.func @string_string() {
%0 = "daphne.constant"() {value = "debug"} : () -> !daphne.String
%1 = "daphne.constant"() {value = "debug"} : () -> !daphne.String
// CHECK-NOT: daphne.ewEq
%2 = "daphne.ewEq"(%0, %1) : (!daphne.String, !daphne.String) -> !daphne.String
%3 = "daphne.cast"(%2) : (!daphne.String) -> i1
"daphne.print"(%0, %3, %3) : (!daphne.String, i1, i1) -> ()
"daphne.return"() : () -> ()
}

func.func @string_int() {
// CHECK-NOT: daphne.eqEq
// CHECK: daphne.cast
// CHECK: daphne.stringEq
%0 = "daphne.constant"() {value = "debug"} : () -> !daphne.String
%1 = "daphne.constant"() {value = 5 : si64} : () -> si64
%2 = "daphne.ewEq"(%0, %1) : (!daphne.String, si64) -> !daphne.String
%3 = "daphne.cast"(%2) : (!daphne.String) -> i1
"daphne.print"(%0, %3, %3) : (!daphne.String, i1, i1) -> ()
"daphne.return"() : () -> ()
}

func.func @int_int_do_not_canonicalize() {
%0 = "daphne.constant"() {value = 2 : si64} : () -> si64
%1 = "daphne.constant"() {value = 5 : si64} : () -> si64
%2 = "daphne.ewEq"(%0, %1) : (si64, si64) -> si64
// CHECK-NOT: daphne.stringEq
%3 = "daphne.cast"(%2) : (si64) -> i1
scf.if %3 {
%4 = "daphne.constant"() {value = "debug"} : () -> !daphne.String
%5 = "daphne.constant"() {value = true} : () -> i1
%6 = "daphne.constant"() {value = false} : () -> i1
"daphne.print"(%4, %5, %6) : (!daphne.String, i1, i1) -> ()
} else {
%4 = "daphne.constant"() {value = "release"} : () -> !daphne.String
%5 = "daphne.constant"() {value = true} : () -> i1
%6 = "daphne.constant"() {value = false} : () -> i1
"daphne.print"(%4, %5, %6) : (!daphne.String, i1, i1) -> ()
}
"daphne.return"() : () -> ()
}

0 comments on commit 4c78578

Please sign in to comment.