Skip to content

Commit

Permalink
[CIR][CIRGen] handle __builtin_elementwise_exp (#1376)
Browse files Browse the repository at this point in the history
  • Loading branch information
FantasqueX authored Feb 20, 2025
1 parent d48d459 commit 6492b9b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
48 changes: 29 additions & 19 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ static mlir::Value tryUseTestFPKind(CIRGenFunction &CGF, unsigned BuiltinID,
}

template <class Operation>
static RValue emitUnaryFPBuiltin(CIRGenFunction &CGF, const CallExpr &E) {
static RValue emitUnaryMaybeConstrainedFPBuiltin(CIRGenFunction &CGF,
const CallExpr &E) {
auto Arg = CGF.emitScalarExpr(E.getArg(0));

CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, &E);
Expand All @@ -75,6 +76,14 @@ static RValue emitUnaryFPBuiltin(CIRGenFunction &CGF, const CallExpr &E) {
return RValue::get(Call->getResult(0));
}

template <class Operation>
static RValue emitUnaryFPBuiltin(CIRGenFunction &CGF, const CallExpr &E) {
auto Arg = CGF.emitScalarExpr(E.getArg(0));
auto Call =
CGF.getBuilder().create<Operation>(Arg.getLoc(), Arg.getType(), Arg);
return RValue::get(Call->getResult(0));
}

template <typename Op>
static RValue emitUnaryMaybeConstrainedFPToIntBuiltin(CIRGenFunction &CGF,
const CallExpr &E) {
Expand Down Expand Up @@ -600,7 +609,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_ceilf16:
case Builtin::BI__builtin_ceill:
case Builtin::BI__builtin_ceilf128:
return emitUnaryFPBuiltin<cir::CeilOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::CeilOp>(*this, *E);

case Builtin::BIcopysign:
case Builtin::BIcopysignf:
Expand All @@ -623,7 +632,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_cosl:
case Builtin::BI__builtin_cosf128:
assert(!cir::MissingFeatures::fastMathFlags());
return emitUnaryFPBuiltin<cir::CosOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::CosOp>(*this, *E);

case Builtin::BIcosh:
case Builtin::BIcoshf:
Expand All @@ -644,7 +653,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_expl:
case Builtin::BI__builtin_expf128:
assert(!cir::MissingFeatures::fastMathFlags());
return emitUnaryFPBuiltin<cir::ExpOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::ExpOp>(*this, *E);

case Builtin::BIexp2:
case Builtin::BIexp2f:
Expand All @@ -655,7 +664,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_exp2l:
case Builtin::BI__builtin_exp2f128:
assert(!cir::MissingFeatures::fastMathFlags());
return emitUnaryFPBuiltin<cir::Exp2Op>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::Exp2Op>(*this, *E);

case Builtin::BI__builtin_exp10:
case Builtin::BI__builtin_exp10f:
Expand All @@ -672,7 +681,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_fabsf16:
case Builtin::BI__builtin_fabsl:
case Builtin::BI__builtin_fabsf128:
return emitUnaryFPBuiltin<cir::FAbsOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::FAbsOp>(*this, *E);

case Builtin::BIfloor:
case Builtin::BIfloorf:
Expand All @@ -682,7 +691,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_floorf16:
case Builtin::BI__builtin_floorl:
case Builtin::BI__builtin_floorf128:
return emitUnaryFPBuiltin<cir::FloorOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::FloorOp>(*this, *E);

case Builtin::BIfma:
case Builtin::BIfmaf:
Expand Down Expand Up @@ -745,7 +754,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_logl:
case Builtin::BI__builtin_logf128:
assert(!cir::MissingFeatures::fastMathFlags());
return emitUnaryFPBuiltin<cir::LogOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::LogOp>(*this, *E);

case Builtin::BIlog10:
case Builtin::BIlog10f:
Expand All @@ -756,7 +765,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_log10l:
case Builtin::BI__builtin_log10f128:
assert(!cir::MissingFeatures::fastMathFlags());
return emitUnaryFPBuiltin<cir::Log10Op>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::Log10Op>(*this, *E);

case Builtin::BIlog2:
case Builtin::BIlog2f:
Expand All @@ -767,7 +776,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_log2l:
case Builtin::BI__builtin_log2f128:
assert(!cir::MissingFeatures::fastMathFlags());
return emitUnaryFPBuiltin<cir::Log2Op>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::Log2Op>(*this, *E);

case Builtin::BInearbyint:
case Builtin::BInearbyintf:
Expand All @@ -776,7 +785,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_nearbyintf:
case Builtin::BI__builtin_nearbyintl:
case Builtin::BI__builtin_nearbyintf128:
return emitUnaryFPBuiltin<cir::NearbyintOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::NearbyintOp>(*this, *E);

case Builtin::BIpow:
case Builtin::BIpowf:
Expand All @@ -800,7 +809,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_rintf16:
case Builtin::BI__builtin_rintl:
case Builtin::BI__builtin_rintf128:
return emitUnaryFPBuiltin<cir::RintOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::RintOp>(*this, *E);

case Builtin::BIround:
case Builtin::BIroundf:
Expand All @@ -810,7 +819,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_roundf16:
case Builtin::BI__builtin_roundl:
case Builtin::BI__builtin_roundf128:
return emitUnaryFPBuiltin<cir::RoundOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::RoundOp>(*this, *E);

case Builtin::BIroundeven:
case Builtin::BIroundevenf:
Expand All @@ -831,7 +840,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_sinl:
case Builtin::BI__builtin_sinf128:
assert(!cir::MissingFeatures::fastMathFlags());
return emitUnaryFPBuiltin<cir::SinOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::SinOp>(*this, *E);

case Builtin::BIsqrt:
case Builtin::BIsqrtf:
Expand All @@ -842,7 +851,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_sqrtl:
case Builtin::BI__builtin_sqrtf128:
assert(!cir::MissingFeatures::fastMathFlags());
return emitUnaryFPBuiltin<cir::SqrtOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::SqrtOp>(*this, *E);

case Builtin::BI__builtin_elementwise_sqrt:
llvm_unreachable("BI__builtin_elementwise_sqrt NYI");
Expand Down Expand Up @@ -875,7 +884,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_truncf16:
case Builtin::BI__builtin_truncl:
case Builtin::BI__builtin_truncf128:
return emitUnaryFPBuiltin<cir::TruncOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::TruncOp>(*this, *E);

case Builtin::BIlround:
case Builtin::BIlroundf:
Expand Down Expand Up @@ -1344,7 +1353,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
if (mlir::isa<cir::VectorType>(cirTy))
eltTy = mlir::cast<cir::VectorType>(cirTy).getEltType();
if (mlir::isa<cir::SingleType, cir::DoubleType>(eltTy)) {
return emitUnaryFPBuiltin<cir::FAbsOp>(*this, *E);
return emitUnaryMaybeConstrainedFPBuiltin<cir::FAbsOp>(*this, *E);
}
llvm_unreachable("unsupported type for BI__builtin_elementwise_abs");
}
Expand All @@ -1365,8 +1374,9 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
llvm_unreachable("BI__builtin_elementwise_atan2 NYI");
case Builtin::BI__builtin_elementwise_ceil:
llvm_unreachable("BI__builtin_elementwise_ceil NYI");
case Builtin::BI__builtin_elementwise_exp:
llvm_unreachable("BI__builtin_elementwise_exp NYI");
case Builtin::BI__builtin_elementwise_exp: {
return emitUnaryFPBuiltin<cir::ExpOp>(*this, *E);
}
case Builtin::BI__builtin_elementwise_exp2:
llvm_unreachable("BI__builtin_elementwise_exp2 NYI");
case Builtin::BI__builtin_elementwise_log:
Expand Down
21 changes: 21 additions & 0 deletions clang/test/CIR/CodeGen/builtins-elementwise.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,24 @@ void test_builtin_elementwise_acos(float f, double d, vfloat4 vf4,
// LLVM: {{%.*}} = call <4 x double> @llvm.acos.v4f64(<4 x double> {{%.*}})
vd4 = __builtin_elementwise_acos(vd4);
}

void test_builtin_elementwise_exp(float f, double d, vfloat4 vf4,
vdouble4 vd4) {
// CIR-LABEL: test_builtin_elementwise_exp
// LLVM-LABEL: test_builtin_elementwise_exp
// CIR: {{%.*}} = cir.exp {{%.*}} : !cir.float
// LLVM: {{%.*}} = call float @llvm.exp.f32(float {{%.*}})
f = __builtin_elementwise_exp(f);

// CIR: {{%.*}} = cir.exp {{%.*}} : !cir.double
// LLVM: {{%.*}} = call double @llvm.exp.f64(double {{%.*}})
d = __builtin_elementwise_exp(d);

// CIR: {{%.*}} = cir.exp {{%.*}} : !cir.vector<!cir.float x 4>
// LLVM: {{%.*}} = call <4 x float> @llvm.exp.v4f32(<4 x float> {{%.*}})
vf4 = __builtin_elementwise_exp(vf4);

// CIR: {{%.*}} = cir.exp {{%.*}} : !cir.vector<!cir.double x 4>
// LLVM: {{%.*}} = call <4 x double> @llvm.exp.v4f64(<4 x double> {{%.*}})
vd4 = __builtin_elementwise_exp(vd4);
}

0 comments on commit 6492b9b

Please sign in to comment.