Skip to content

Commit

Permalink
[CIR][CIRGen] Support for C++20 three-way comparison (#485)
Browse files Browse the repository at this point in the history
This patch adds CIRGen support for the C++20 three-way comparison
operator `<=>`. The binary operator is directly lowered to existing CIR
operations.

Most of the changes are tests.
  • Loading branch information
Lancern authored Mar 19, 2024
1 parent 8bf4c69 commit 43551d4
Show file tree
Hide file tree
Showing 9 changed files with 731 additions and 4 deletions.
62 changes: 62 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,68 @@ def ConstPtrAttr : CIR_Attr<"ConstPtr", "ptr", [TypedAttrInterface]> {
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// CmpThreeWayInfoAttr
//===----------------------------------------------------------------------===//

def CmpOrdering_Strong : I32EnumAttrCase<"Strong", 1, "strong">;
def CmpOrdering_Partial : I32EnumAttrCase<"Partial", 2, "partial">;

def CmpOrdering : I32EnumAttr<
"CmpOrdering", "three-way comparison ordering kind",
[CmpOrdering_Strong, CmpOrdering_Partial]
> {
let cppNamespace = "::mlir::cir";
}

def CmpThreeWayInfoAttr : CIR_Attr<"CmpThreeWayInfo", "cmp3way_info"> {
let summary = "Holds information about a three-way comparison operation";
let description = [{
The `#cmp3way_info` attribute contains information about a three-way
comparison operation `cir.cmp3way`.

The `ordering` parameter gives the ordering kind of the three-way comparison
operation. It may be either strong ordering or partial ordering.

Given the two input operands of the three-way comparison operation `lhs` and
`rhs`, the `lt`, `eq`, `gt`, and `unordered` parameters gives the result
value that should be produced by the three-way comparison operation when the
ordering between `lhs` and `rhs` is `lhs < rhs`, `lhs == rhs`, `lhs > rhs`,
or neither, respectively.
}];

let parameters = (ins "CmpOrdering":$ordering, "int64_t":$lt, "int64_t":$eq,
"int64_t":$gt,
OptionalParameter<"std::optional<int64_t>">:$unordered);

let builders = [
AttrBuilder<(ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt), [{
return $_get($_ctxt, CmpOrdering::Strong, lt, eq, gt, std::nullopt);
}]>,
AttrBuilder<(ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt,
"int64_t":$unordered), [{
return $_get($_ctxt, CmpOrdering::Partial, lt, eq, gt, unordered);
}]>,
];

let extraClassDeclaration = [{
/// Get attribute alias name for this attribute.
std::string getAlias() const;
}];

let assemblyFormat = [{
`<`
$ordering `,`
`lt` `=` $lt `,`
`eq` `=` $eq `,`
`gt` `=` $gt
(`,` `unordered` `=` $unordered^)?
`>`
}];

let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
// DataMemberAttr
//===----------------------------------------------------------------------===//
Expand Down
49 changes: 49 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,55 @@ def BitPopcountOp : CIR_BitOp<"bit.popcount", UIntOfWidths<[16, 32, 64]>> {
}];
}

//===----------------------------------------------------------------------===//
// CmpThreeWayOp
//===----------------------------------------------------------------------===//

def CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
let summary = "Compare two values with C++ three-way comparison semantics";
let description = [{
The `cir.cmp3way` operation models the `<=>` operator in C++20. It takes two
operands with the same type and produces a result indicating the ordering
between the two input operands.

The result of the operation is a signed integer that indicates the ordering
between the two input operands.

There are two kinds of ordering: strong ordering and partial ordering.
Comparing different types of values yields different kinds of orderings.
The `info` parameter gives the ordering kind and other necessary information
about the comparison.

Example:

```mlir
!s32i = !cir.int<s, 32>

#cmp3way_strong = #cmp3way_info<strong, lt = -1, eq = 0, gt = 1>
#cmp3way_partial = #cmp3way_info<strong, lt = -1, eq = 0, gt = 1, unordered = 2>

%0 = cir.const(#cir.int<0> : !s32i) : !s32i
%1 = cir.const(#cir.int<1> : !s32i) : !s32i
%2 = cir.cmp3way(%0 : !s32i, %1, #cmp3way_strong) : !s8i

%3 = cir.const(#cir.fp<0.0> : !cir.float) : !cir.float
%4 = cir.const(#cir.fp<1.0> : !cir.float) : !cir.float
%5 = cir.cmp3way(%3 : !cir.float, %4, #cmp3way_partial) : !s8i
```
}];

let results = (outs CIR_IntType:$result);
let arguments = (ins CIR_AnyType:$lhs, CIR_AnyType:$rhs,
CmpThreeWayInfoAttr:$info);

let assemblyFormat = [{
`(` $lhs `:` type($lhs) `,` $rhs `,` qualified($info) `)`
`:` type($result) attr-dict
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
Expand Down
48 changes: 48 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,20 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return mlir::cir::TypeInfoAttr::get(anonStruct.getType(), fieldsAttr);
}

mlir::cir::CmpThreeWayInfoAttr getCmpThreeWayInfoStrongOrdering(
const llvm::APSInt &lt, const llvm::APSInt &eq, const llvm::APSInt &gt) {
return mlir::cir::CmpThreeWayInfoAttr::get(
getContext(), lt.getSExtValue(), eq.getSExtValue(), gt.getSExtValue());
}

mlir::cir::CmpThreeWayInfoAttr getCmpThreeWayInfoPartialOrdering(
const llvm::APSInt &lt, const llvm::APSInt &eq, const llvm::APSInt &gt,
const llvm::APSInt &unordered) {
return mlir::cir::CmpThreeWayInfoAttr::get(
getContext(), lt.getSExtValue(), eq.getSExtValue(), gt.getSExtValue(),
unordered.getSExtValue());
}

mlir::cir::DataMemberAttr getDataMemberAttr(mlir::cir::DataMemberType ty,
size_t memberIndex) {
return mlir::cir::DataMemberAttr::get(getContext(), ty, memberIndex);
Expand Down Expand Up @@ -598,6 +612,11 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::ContinueOp>(loc);
}

mlir::cir::CmpOp createCompare(mlir::Location loc, mlir::cir::CmpOpKind kind,
mlir::Value lhs, mlir::Value rhs) {
return create<mlir::cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}

mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
mlir::Value src, mlir::Value len) {
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);
Expand Down Expand Up @@ -824,6 +843,35 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}
}

mlir::cir::CmpThreeWayOp
createThreeWayCmpStrong(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
const llvm::APSInt &ltRes, const llvm::APSInt &eqRes,
const llvm::APSInt &gtRes) {
assert(ltRes.getBitWidth() == eqRes.getBitWidth() &&
ltRes.getBitWidth() == gtRes.getBitWidth() &&
"the three comparison results must have the same bit width");
auto cmpResultTy = getSIntNTy(ltRes.getBitWidth());
auto infoAttr = getCmpThreeWayInfoStrongOrdering(ltRes, eqRes, gtRes);
return create<mlir::cir::CmpThreeWayOp>(loc, cmpResultTy, lhs, rhs,
infoAttr);
}

mlir::cir::CmpThreeWayOp
createThreeWayCmpPartial(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
const llvm::APSInt &ltRes, const llvm::APSInt &eqRes,
const llvm::APSInt &gtRes,
const llvm::APSInt &unorderedRes) {
assert(ltRes.getBitWidth() == eqRes.getBitWidth() &&
ltRes.getBitWidth() == gtRes.getBitWidth() &&
ltRes.getBitWidth() == unorderedRes.getBitWidth() &&
"the four comparison results must have the same bit width");
auto cmpResultTy = getSIntNTy(ltRes.getBitWidth());
auto infoAttr =
getCmpThreeWayInfoPartialOrdering(ltRes, eqRes, gtRes, unorderedRes);
return create<mlir::cir::CmpThreeWayOp>(loc, cmpResultTy, lhs, rhs,
infoAttr);
}

mlir::cir::GetRuntimeMemberOp createGetIndirectMember(mlir::Location loc,
mlir::Value objectPtr,
mlir::Value memberPtr) {
Expand Down
60 changes: 59 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/Attributes.h"

#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/RecordLayout.h"
#include "clang/AST/StmtVisitor.h"
Expand Down Expand Up @@ -261,7 +262,7 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
}

void VisitBinComma(const BinaryOperator *E) { llvm_unreachable("NYI"); }
void VisitBinCmp(const BinaryOperator *E) { llvm_unreachable("NYI"); }
void VisitBinCmp(const BinaryOperator *E);
void VisitCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *E) {
llvm_unreachable("NYI");
}
Expand Down Expand Up @@ -1024,6 +1025,63 @@ void AggExprEmitter::withReturnValueSlot(
}
}

void AggExprEmitter::VisitBinCmp(const BinaryOperator *E) {
assert(CGF.getContext().hasSameType(E->getLHS()->getType(),
E->getRHS()->getType()));
const ComparisonCategoryInfo &CmpInfo =
CGF.getContext().CompCategories.getInfoForType(E->getType());
assert(CmpInfo.Record->isTriviallyCopyable() &&
"cannot copy non-trivially copyable aggregate");

QualType ArgTy = E->getLHS()->getType();

if (!ArgTy->isIntegralOrEnumerationType() && !ArgTy->isRealFloatingType() &&
!ArgTy->isNullPtrType() && !ArgTy->isPointerType() &&
!ArgTy->isMemberPointerType() && !ArgTy->isAnyComplexType())
llvm_unreachable("aggregate three-way comparison");

auto Loc = CGF.getLoc(E->getSourceRange());

if (E->getType()->isAnyComplexType())
llvm_unreachable("NYI");

auto LHS = CGF.buildAnyExpr(E->getLHS()).getScalarVal();
auto RHS = CGF.buildAnyExpr(E->getRHS()).getScalarVal();

mlir::Value ResultScalar;
if (ArgTy->isNullPtrType()) {
ResultScalar =
CGF.builder.getConstInt(Loc, CmpInfo.getEqualOrEquiv()->getIntValue());
} else {
auto LtRes = CmpInfo.getLess()->getIntValue();
auto EqRes = CmpInfo.getEqualOrEquiv()->getIntValue();
auto GtRes = CmpInfo.getGreater()->getIntValue();
if (!CmpInfo.isPartial()) {
// Strong ordering.
ResultScalar = CGF.builder.createThreeWayCmpStrong(Loc, LHS, RHS, LtRes,
EqRes, GtRes);
} else {
// Partial ordering.
auto UnorderedRes = CmpInfo.getUnordered()->getIntValue();
ResultScalar = CGF.builder.createThreeWayCmpPartial(
Loc, LHS, RHS, LtRes, EqRes, GtRes, UnorderedRes);
}
}

// Create the return value in the destination slot.
EnsureDest(Loc, E->getType());
LValue DestLV = CGF.makeAddrLValue(Dest.getAddress(), E->getType());

// Emit the address of the first (and only) field in the comparison category
// type, and initialize it from the constant integer value produced above.
const FieldDecl *ResultField = *CmpInfo.Record->field_begin();
LValue FieldLV = CGF.buildLValueForFieldInitialization(
DestLV, ResultField, ResultField->getName());
CGF.buildStoreThroughLValue(RValue::get(ResultScalar), FieldLV);

// All done! The result is in the Dest slot.
}

void AggExprEmitter::VisitInitListExpr(InitListExpr *E) {
// TODO(cir): use something like CGF.ErrorUnsupported
if (E->hadArrayRangeDesignator())
Expand Down
52 changes: 52 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,58 @@ LogicalResult cir::FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

//===----------------------------------------------------------------------===//
// CmpThreeWayInfoAttr definitions
//===----------------------------------------------------------------------===//

std::string CmpThreeWayInfoAttr::getAlias() const {
std::string alias = "cmp3way_info";

if (getOrdering() == CmpOrdering::Strong)
alias.append("_strong_");
else
alias.append("_partial_");

auto appendInt = [&](int64_t value) {
if (value < 0) {
alias.push_back('n');
value = -value;
}
alias.append(std::to_string(value));
};

alias.append("lt");
appendInt(getLt());
alias.append("eq");
appendInt(getEq());
alias.append("gt");
appendInt(getGt());

if (auto unordered = getUnordered()) {
alias.append("un");
appendInt(unordered.value());
}

return alias;
}

LogicalResult
CmpThreeWayInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,
CmpOrdering ordering, int64_t lt, int64_t eq,
int64_t gt, std::optional<int64_t> unordered) {
// The presense of unordered must match the value of ordering.
if (ordering == CmpOrdering::Strong && unordered) {
emitError() << "strong ordering does not include unordered ordering";
return failure();
}
if (ordering == CmpOrdering::Partial && !unordered) {
emitError() << "partial ordering lacks unordered ordering";
return failure();
}

return success();
}

//===----------------------------------------------------------------------===//
// DataMemberAttr definitions
//===----------------------------------------------------------------------===//
Expand Down
19 changes: 19 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
os << "fn_attr";
return AliasResult::FinalAlias;
}
if (auto cmpThreeWayInfoAttr =
attr.dyn_cast<mlir::cir::CmpThreeWayInfoAttr>()) {
os << cmpThreeWayInfoAttr.getAlias();
return AliasResult::FinalAlias;
}

return AliasResult::NoAlias;
}
Expand Down Expand Up @@ -870,6 +875,20 @@ Block *BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// CmpThreeWayOp
//===----------------------------------------------------------------------===//

LogicalResult CmpThreeWayOp::verify() {
// Type of the result must be a signed integer type.
if (!getType().isSigned()) {
emitOpError() << "result type of cir.cmp3way must be a signed integer type";
return failure();
}

return success();
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 43551d4

Please sign in to comment.