Skip to content

Commit

Permalink
[CIR] introduce cir.float for floating-point types
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancern committed Jan 30, 2024
1 parent 9bc8b47 commit 2bb9383
Show file tree
Hide file tree
Showing 49 changed files with 727 additions and 430 deletions.
24 changes: 24 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,30 @@ def IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//

def FloatAttr : CIR_Attr<"Float", "float", [TypedAttrInterface]> {
let summary = "An Attribute containing a floating-point value";
let description = [{
A float attribute is a literal attribute that represents a floating-point
value of the specified floating-point type.
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APFloat":$value);
let builders = [
AttrBuilderWithInferredContext<(ins "Type":$type,
"const APFloat &":$value), [{
return $_get(type.getContext(), type, value);
}]>,
];
let extraClassDeclaration = [{
static FloatAttr getZero(mlir::cir::FloatType type);
}];
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// ConstPointerAttr
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2539,8 +2539,8 @@ def IterEndOp : CIR_Op<"iterator_end"> {

class UnaryFPToFPBuiltinOp<string mnemonic>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins AnyFloat:$src);
let results = (outs AnyFloat:$result);
let arguments = (ins CIR_AnyFloat:$src);
let results = (outs CIR_AnyFloat:$result);
let summary = "libc builtin equivalent ignoring "
"floating point exceptions and errno";
let assemblyFormat = "$src `:` type($src) attr-dict";
Expand Down
37 changes: 37 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,43 @@

#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"

//===----------------------------------------------------------------------===//
// CIR FloatType
//
// The base type for all floating-point types.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace cir {

class SingleType;
class DoubleType;

class FloatType : public Type {
public:
using Type::Type;

// Convenience factories.
static SingleType getSingle(MLIRContext *ctx);
static DoubleType getDouble(MLIRContext *ctx);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);

/// Return the bitwidth of this float type.
unsigned getWidth() const;

/// Return the width of the mantissa of this type.
/// The width includes the integer bit.
unsigned getFPMantissaWidth() const;

/// Return the float semantics of this floating-point type.
const llvm::fltSemantics &getFloatSemantics() const;
};

} // namespace cir
} // namespace mlir

//===----------------------------------------------------------------------===//
// CIR Dialect Tablegen'd Types
//===----------------------------------------------------------------------===//
Expand Down
38 changes: 35 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ include "clang/CIR/Dialect/IR/CIRDialect.td"
include "clang/CIR/Interfaces/ASTAttrInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// CIR Types
//===----------------------------------------------------------------------===//

class CIR_Type<string name, string typeMnemonic, list<Trait> traits = []> :
TypeDef<CIR_Dialect, name, traits> {
class CIR_Type<string name, string typeMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: TypeDef<CIR_Dialect, name, traits, baseCppClass> {
let mnemonic = typeMnemonic;
}

Expand Down Expand Up @@ -94,6 +96,36 @@ def SInt16 : SInt<16>;
def SInt32 : SInt<32>;
def SInt64 : SInt<64>;

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//

class CIR_FloatType<string name, string mnemonic>
: CIR_Type<name, mnemonic,
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>],
"::mlir::cir::FloatType"> {}

def CIR_Single : CIR_FloatType<"Single", "float"> {
let summary = "CIR single-precision float type";
let description = [{
Floating-point type that represents the `float` type in C/C++. Its
underlying floating-point format is the IEEE-754 binary32 format.
}];
}

def CIR_Double : CIR_FloatType<"Double", "double"> {
let summary = "CIR double-precision float type";
let description = [{
Floating-point type that represents the `double` type in C/C++. Its
underlying floating-point format is the IEEE-754 binar64 format.
}];
}

// Constraints

def CIR_AnyFloat: Type<
CPred<"$_self.isa<::mlir::FloatType, ::mlir::cir::FloatType>()">>;

//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -300,7 +332,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,

def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_BoolType, CIR_ArrayType, CIR_VectorType,
CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo, AnyFloat,
CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo, CIR_AnyFloat,
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
19 changes: 14 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return mlir::cir::IntAttr::get(ty, 0);
if (ty.isa<mlir::FloatType>())
return mlir::FloatAttr::get(ty, 0.0);
if (auto fltType = ty.dyn_cast<mlir::cir::FloatType>())
return mlir::cir::FloatAttr::getZero(fltType);
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
return getZeroAttr(arrTy);
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
Expand All @@ -250,12 +252,18 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
if (const auto intVal = attr.dyn_cast<mlir::cir::IntAttr>())
return intVal.isNullValue();

if (const auto fpVal = attr.dyn_cast<mlir::FloatAttr>()) {
if (attr.isa<mlir::FloatAttr, mlir::cir::FloatAttr>()) {
auto fpVal = [&attr] {
if (auto fpAttr = attr.dyn_cast<mlir::cir::FloatAttr>())
return fpAttr.getValue();
return attr.cast<mlir::FloatAttr>().getValue();
}();

bool ignored;
llvm::APFloat FV(+0.0);
FV.convert(fpVal.getValue().getSemantics(),
llvm::APFloat::rmNearestTiesToEven, &ignored);
return FV.bitwiseIsEqual(fpVal.getValue());
FV.convert(fpVal.getSemantics(), llvm::APFloat::rmNearestTiesToEven,
&ignored);
return FV.bitwiseIsEqual(fpVal);
}

if (const auto structVal = attr.dyn_cast<mlir::cir::ConstStructAttr>()) {
Expand Down Expand Up @@ -471,7 +479,8 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
bool isSized(mlir::Type ty) {
if (ty.isIntOrFloat() ||
ty.isa<mlir::cir::PointerType, mlir::cir::StructType,
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType>())
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType,
mlir::cir::FloatType>())
return true;
assert(0 && "Unimplemented size for type");
return false;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1479,6 +1479,8 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
assert(0 && "not implemented");
else {
mlir::Type ty = CGM.getCIRType(DestType);
if (ty.isa<mlir::cir::FloatType>())
return CGM.getBuilder().getAttr<mlir::cir::FloatAttr>(ty, Init);
return builder.getFloatAttr(ty, Init);
}
}
Expand Down
14 changes: 9 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
}
mlir::Value VisitFloatingLiteral(const FloatingLiteral *E) {
mlir::Type Ty = CGF.getCIRType(E->getType());
if (Ty.isa<mlir::cir::FloatType>())
return Builder.create<mlir::cir::ConstantOp>(
CGF.getLoc(E->getExprLoc()), Ty,
Builder.getAttr<mlir::cir::FloatAttr>(Ty, E->getValue()));
return Builder.create<mlir::cir::ConstantOp>(
CGF.getLoc(E->getExprLoc()), Ty,
Builder.getFloatAttr(Ty, E->getValue()));
Expand Down Expand Up @@ -1201,7 +1205,7 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
llvm_unreachable("NYI");

assert(!UnimplementedFeature::cirVectorType());
if (Ops.LHS.getType().isa<mlir::FloatType>()) {
if (Ops.LHS.getType().isa<mlir::FloatType, mlir::cir::FloatType>()) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFSub(Ops.LHS, Ops.RHS);
}
Expand Down Expand Up @@ -1669,20 +1673,20 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
llvm_unreachable("NYI: signed bool");
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::bool_to_int;
} else if (DstTy.isa<mlir::FloatType>()) {
} else if (DstTy.isa<mlir::FloatType, mlir::cir::FloatType>()) {
CastKind = mlir::cir::CastKind::bool_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (CGF.getBuilder().isInt(SrcTy)) {
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::integral;
} else if (DstTy.isa<mlir::FloatType>()) {
} else if (DstTy.isa<mlir::FloatType, mlir::cir::FloatType>()) {
CastKind = mlir::cir::CastKind::int_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (SrcTy.isa<mlir::FloatType>()) {
} else if (SrcTy.isa<mlir::FloatType, mlir::cir::FloatType>()) {
if (CGF.getBuilder().isInt(DstTy)) {
// If we can't recognize overflow as undefined behavior, assume that
// overflow saturates. This protects against normal optimizations if we
Expand All @@ -1692,7 +1696,7 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
if (Builder.getIsFPConstrained())
llvm_unreachable("NYI");
CastKind = mlir::cir::CastKind::float_to_int;
} else if (DstTy.isa<mlir::FloatType>()) {
} else if (DstTy.isa<mlir::FloatType, mlir::cir::FloatType>()) {
// TODO: split this to createFPExt/createFPTrunc
return Builder.createFloatingCast(Src, DstTy);
} else {
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,

// TODO: HalfTy
// TODO: BFloatTy
FloatTy = builder.getF32Type();
DoubleTy = builder.getF64Type();
FloatTy = ::mlir::cir::FloatType::getSingle(builder.getContext());
DoubleTy = ::mlir::cir::FloatType::getDouble(builder.getContext());
// TODO(cir): perhaps we should abstract long double variations into a custom
// cir.long_double type. Said type would also hold the semantics for lowering.
LongDouble80BitsTy = builder.getF80Type();
LongDouble80BitsTy = ::mlir::FloatType::getF80(builder.getContext());

// TODO: PointerWidthInBits
PointerAlignInBytes =
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypeCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ struct CIRGenTypeCache {
// mlir::Type HalfTy, BFloatTy;
// TODO(cir): perhaps we should abstract long double variations into a custom
// cir.long_double type. Said type would also hold the semantics for lowering.
mlir::FloatType FloatTy, DoubleTy, LongDouble80BitsTy;
mlir::cir::SingleType FloatTy;
mlir::cir::DoubleType DoubleTy;
mlir::FloatType LongDouble80BitsTy;

/// int
mlir::Type UIntTy;
Expand Down
54 changes: 54 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,60 @@ LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

//===----------------------------------------------------------------------===//
// FloatAttr definitions
//===----------------------------------------------------------------------===//

Attribute cir::FloatAttr::parse(AsmParser &parser, Type odsType) {
double value;

if (!odsType.isa<cir::FloatType>())
return {};
auto ty = odsType.cast<cir::FloatType>();

if (parser.parseLess())
return {};

if (parser.parseFloat(value))
parser.emitError(parser.getCurrentLocation(),
"expected floating-point value");

if (parser.parseGreater())
return {};

auto losesInfo = false;
APFloat convertedValue{value};
convertedValue.convert(ty.getFloatSemantics(), llvm::RoundingMode::TowardZero,
&losesInfo);

return cir::FloatAttr::get(ty, convertedValue);
}

void cir::FloatAttr::print(AsmPrinter &printer) const {
printer << '<' << getValue() << '>';
}

cir::FloatAttr cir::FloatAttr::getZero(mlir::cir::FloatType type) {
return get(type, APFloat::getZero(type.getFloatSemantics()));
}

LogicalResult
cir::FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type,
APFloat value) {
auto fltType = type.dyn_cast<cir::FloatType>();
if (!fltType) {
emitError() << "expected floating-point type";
return failure();
}
if (APFloat::SemanticsToEnum(fltType.getFloatSemantics()) !=
APFloat::SemanticsToEnum(value.getSemantics())) {
emitError() << "floating-point semantics mismatch";
return failure();
}

return success();
}

//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 8 additions & 7 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return success();
}

if (attrType.isa<mlir::cir::IntAttr, FloatAttr>()) {
if (attrType
.isa<mlir::cir::IntAttr, mlir::FloatAttr, mlir::cir::FloatAttr>()) {
auto at = attrType.cast<TypedAttr>();
if (at.getType() != opType) {
return op->emitOpError("result type (")
Expand Down Expand Up @@ -423,13 +424,13 @@ LogicalResult CastOp::verify() {
return success();
}
case cir::CastKind::floating: {
if (!srcType.dyn_cast<mlir::FloatType>() ||
!resType.dyn_cast<mlir::FloatType>())
if (!srcType.isa<mlir::FloatType, mlir::cir::FloatType>() ||
!resType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requries floating for source and result";
return success();
}
case cir::CastKind::float_to_int: {
if (!srcType.dyn_cast<mlir::FloatType>())
if (!srcType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requires floating for source";
if (!resType.dyn_cast<mlir::cir::IntType>())
return emitOpError() << "requires !IntegerType for result";
Expand All @@ -450,7 +451,7 @@ LogicalResult CastOp::verify() {
return success();
}
case cir::CastKind::float_to_bool: {
if (!srcType.isa<mlir::FloatType>())
if (!srcType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requires float for source";
if (!resType.isa<mlir::cir::BoolType>())
return emitOpError() << "requires !cir.bool for result";
Expand All @@ -466,14 +467,14 @@ LogicalResult CastOp::verify() {
case cir::CastKind::int_to_float: {
if (!srcType.isa<mlir::cir::IntType>())
return emitOpError() << "requires !cir.int for source";
if (!resType.isa<mlir::FloatType>())
if (!resType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requires !cir.float for result";
return success();
}
case cir::CastKind::bool_to_float: {
if (!srcType.isa<mlir::cir::BoolType>())
return emitOpError() << "requires !cir.bool for source";
if (!resType.isa<mlir::FloatType>())
if (!resType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requires !cir.float for result";
return success();
}
Expand Down
Loading

0 comments on commit 2bb9383

Please sign in to comment.