diff --git a/Makefile b/Makefile
index a49615a91a91..484f25ed1cd7 100644
--- a/Makefile
+++ b/Makefile
@@ -594,6 +594,7 @@ SOURCE_FILES = \
   StripAsserts.cpp \
   Substitute.cpp \
   Target.cpp \
+  TargetQueryOps.cpp \
   Tracing.cpp \
   TrimNoOps.cpp \
   Tuple.cpp \
@@ -778,6 +779,7 @@ HEADER_FILES = \
   StripAsserts.h \
   Substitute.h \
   Target.h \
+  TargetQueryOps.h \
   Tracing.h \
   TrimNoOps.h \
   Tuple.h \
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 3a89e132b9c3..b7b465aee1eb 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -160,6 +160,7 @@ set(HEADER_FILES
     StripAsserts.h
     Substitute.h
     Target.h
+    TargetQueryOps.h
     Tracing.h
     TrimNoOps.h
     Tuple.h
@@ -346,6 +347,7 @@ set(SOURCE_FILES
     StripAsserts.cpp
     Substitute.cpp
     Target.cpp
+    TargetQueryOps.cpp
     Tracing.cpp
     TrimNoOps.cpp
     Tuple.cpp
diff --git a/src/IR.cpp b/src/IR.cpp
index 804e41234f71..3454b48f7936 100644
--- a/src/IR.cpp
+++ b/src/IR.cpp
@@ -680,6 +680,11 @@ const char *const intrinsic_op_names[] = {
     "sorted_avg",
     "strict_float",
     "stringify",
+    "target_arch_is",
+    "target_bits",
+    "target_has_feature",
+    "target_natural_vector_size",
+    "target_os_is",
     "undef",
     "unreachable",
     "unsafe_promise_clamped",
diff --git a/src/IR.h b/src/IR.h
index f21f3a9a52ba..c04c5068cb24 100644
--- a/src/IR.h
+++ b/src/IR.h
@@ -612,6 +612,13 @@ struct Call : public ExprNode<Call> {
         sorted_avg,
         strict_float,
         stringify,
+
+        target_arch_is,
+        target_bits,
+        target_has_feature,
+        target_natural_vector_size,
+        target_os_is,
+
         undef,
         unreachable,
         unsafe_promise_clamped,
diff --git a/src/IROperator.cpp b/src/IROperator.cpp
index 6ee62d66015c..f2e91fd5307e 100644
--- a/src/IROperator.cpp
+++ b/src/IROperator.cpp
@@ -2735,4 +2735,24 @@ Expr concat_bits(const std::vector<Expr> &e) {
     return Call::make(t.with_bits(t.bits() * (int)e.size()), Call::concat_bits, e, Call::Intrinsic);
 }
 
+Expr target_arch_is(Target::Arch arch) {
+    return Call::make(Bool(), Call::target_arch_is, {Expr((int)arch)}, Call::PureIntrinsic);
+}
+
+Expr target_os_is(Target::OS os) {
+    return Call::make(Bool(), Call::target_os_is, {Expr((int)os)}, Call::PureIntrinsic);
+}
+
+Expr target_bits() {
+    return Call::make(Int(32), Call::target_bits, {}, Call::PureIntrinsic);
+}
+
+Expr target_has_feature(Target::Feature feat) {
+    return Call::make(Bool(), Call::target_has_feature, {Expr((int)feat)}, Call::PureIntrinsic);
+}
+
+Expr target_natural_vector_size(Type t) {
+    return Call::make(Int(32), Call::target_natural_vector_size, {make_zero(t.element_of())}, Call::PureIntrinsic);
+}
+
 }  // namespace Halide
diff --git a/src/IROperator.h b/src/IROperator.h
index ef2ef3526bb5..e5fe733569d5 100644
--- a/src/IROperator.h
+++ b/src/IROperator.h
@@ -11,6 +11,7 @@
 #include <map>
 
 #include "Expr.h"
+#include "Target.h"
 #include "Tuple.h"
 
 namespace Halide {
@@ -1689,6 +1690,46 @@ Expr rounding_mul_shift_right(Expr a, Expr b, Expr q);
 Expr rounding_mul_shift_right(Expr a, Expr b, int q);
 //@}
 
+/** Return a boolean Expr for the corresponding field of the Target
+ * being used during lowering; they can be useful in writing library
+ * code without having to plumb a Target through call sites, so that you
+ * can do things like
+ \code
+ Expr e = select(target_arch_is(Target::ARM), something, something_else);
+ \endcode
+ */
+//@{
+Expr target_arch_is(Target::Arch arch);
+Expr target_os_is(Target::OS os);
+Expr target_has_feature(Target::Feature feat);
+//@}
+
+/** Return the bit width of the Target used during lowering; this can be useful
+ * in writing library code without having to plumb a Target through call sites, so that you
+ * can do things like
+ \code
+ Expr e = select(target_bits() == 32, something, something_else);
+ \endcode
+ */
+Expr target_bits();
+
+/** Return the natural vector width for the given Type for the Target
+ * being used during lowering; this can be useful in writing library
+ * code without having to plumb a Target through call sites, so that you
+ * can do things like
+ \code
+ f.vectorize(x, target_natural_vector_size(Float(32)));
+ \endcode
+ */
+//@{
+Expr target_natural_vector_size(Type t);
+template<typename T>
+Expr target_natural_vector_size() {
+    return target_natural_vector_size(type_of<T>());
+}
+//@}
+
+
 }  // namespace Halide
 
 #endif
diff --git a/src/Lower.cpp b/src/Lower.cpp
index f092e2e711ef..19be543975f1 100644
--- a/src/Lower.cpp
+++ b/src/Lower.cpp
@@ -70,6 +70,7 @@
 #include "StrictifyFloat.h"
 #include "StripAsserts.h"
 #include "Substitute.h"
+#include "TargetQueryOps.h"
 #include "Tracing.h"
 #include "TrimNoOps.h"
 #include "UnifyDuplicateLets.h"
@@ -144,6 +145,8 @@ void lower_impl(const vector<Function> &output_funcs,
     // Create a deep-copy of the entire graph of Funcs.
     auto [outputs, env] = deep_copy(output_funcs, build_environment(output_funcs));
 
+    lower_target_query_ops(env, t);
+
     bool any_strict_float = strictify_float(env, t);
     result_module.set_any_strict_float(any_strict_float);
 
diff --git a/src/TargetQueryOps.cpp b/src/TargetQueryOps.cpp
new file mode 100644
--- /dev/null
+++ b/src/TargetQueryOps.cpp
@@ -0,0 +1,72 @@
+#include "TargetQueryOps.h"
+
+#include "Function.h"
+#include "IRMutator.h"
+#include "IROperator.h"
+
+namespace Halide {
+namespace Internal {
+
+namespace {
+
+/** Replaces each target-query intrinsic with a constant computed from
+ * the Target being lowered for. */
+class LowerTargetQueryOps : public IRMutator {
+    const Target &t;
+
+    using IRMutator::visit;
+
+    /** Return the integer constant held in the first argument of a
+     * target-query intrinsic. The assertions turn a malformed call into
+     * an internal error instead of a null dereference. */
+    static int64_t const_int_arg(const Call *call) {
+        internal_assert(!call->args.empty())
+            << "Target query intrinsic " << call->name << " has no arguments\n";
+        auto value = as_const_int(call->args[0]);
+        internal_assert(value)
+            << "Target query intrinsic " << call->name << " requires an integer constant argument\n";
+        return *value;
+    }
+
+    Expr visit(const Call *call) override {
+        if (call->is_intrinsic(Call::target_arch_is)) {
+            Target::Arch arch = (Target::Arch)const_int_arg(call);
+            return make_bool(t.arch == arch);
+        } else if (call->is_intrinsic(Call::target_has_feature)) {
+            Target::Feature feat = (Target::Feature)const_int_arg(call);
+            return make_bool(t.has_feature(feat));
+        } else if (call->is_intrinsic(Call::target_natural_vector_size)) {
+            // The argument is a zero of the queried element type; only
+            // its type is consulted.
+            internal_assert(!call->args.empty())
+                << "target_natural_vector_size has no arguments\n";
+            return Expr(t.natural_vector_size(call->args[0].type()));
+        } else if (call->is_intrinsic(Call::target_os_is)) {
+            Target::OS os = (Target::OS)const_int_arg(call);
+            return make_bool(t.os == os);
+        } else if (call->is_intrinsic(Call::target_bits)) {
+            return Expr(t.bits);
+        }
+
+        return IRMutator::visit(call);
+    }
+
+public:
+    explicit LowerTargetQueryOps(const Target &t)
+        : t(t) {
+    }
+};
+
+}  // namespace
+
+void lower_target_query_ops(std::map<std::string, Function> &env, const Target &t) {
+    // The mutator only reads the Target, so one instance serves every Function.
+    LowerTargetQueryOps ltqo(t);
+    for (auto &iter : env) {
+        Function &func = iter.second;
+        func.mutate(&ltqo);
+    }
+}
+
+}  // namespace Internal
+}  // namespace Halide
diff --git a/src/TargetQueryOps.h b/src/TargetQueryOps.h
new file mode 100644
--- /dev/null
+++ b/src/TargetQueryOps.h
@@ -0,0 +1,28 @@
+#ifndef HALIDE_TARGET_QUERY_OPS_H
+#define HALIDE_TARGET_QUERY_OPS_H
+
+/** \file
+ * Defines a lowering pass that replaces the target-query intrinsics
+ * (target_arch_is, target_os_is, target_bits, target_has_feature,
+ * target_natural_vector_size) with constants.
+ */
+
+#include <map>
+#include <string>
+
+namespace Halide {
+
+struct Target;
+
+namespace Internal {
+
+class Function;
+
+/** Rewrite every Function in env in place, replacing each target-query
+ * intrinsic with its constant value for Target t. */
+void lower_target_query_ops(std::map<std::string, Function> &env, const Target &t);
+
+}  // namespace Internal
+}  // namespace Halide
+
+#endif
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt
index 6d16d8612594..8ca5cfb05045 100644
--- a/test/correctness/CMakeLists.txt
+++ b/test/correctness/CMakeLists.txt
@@ -311,6 +311,7 @@ tests(GROUPS correctness
       strict_float_bounds.cpp
       strided_load.cpp
       target.cpp
+      target_query.cpp
       tiled_matmul.cpp
       tracing.cpp
       tracing_bounds.cpp
diff --git a/test/correctness/target_query.cpp b/test/correctness/target_query.cpp
new file mode 100644
--- /dev/null
+++ b/test/correctness/target_query.cpp
@@ -0,0 +1,55 @@
+#include "Halide.h"
+#include <stdio.h>
+
+using namespace Halide;
+
+int main(int argc, char **argv) {
+    // For simplicity, only run this test on hosts that we can predict.
+    Target t = get_host_target();
+    if (t.arch != Target::X86 || t.bits != 64 || t.os != Target::OSX) {
+        printf("[SKIP] This test only runs on x86-64-osx.\n");
+        return 0;
+    }
+
+    // Realize against an explicitly specified target, so the expected value
+    // below is determined by this feature list rather than by the host's.
+    Target t1 = Target(Target::OSX, Target::X86, 64,
+                       {Target::CUDA, Target::Debug});
+
+    Expr is_arm = target_arch_is(Target::ARM);
+    Expr is_x86 = target_arch_is(Target::X86);
+    Expr bits = target_bits();
+    Expr is_android = target_os_is(Target::Android);
+    Expr is_osx = target_os_is(Target::OSX);
+    Expr vec = target_natural_vector_size<float>();
+    Expr has_cuda = target_has_feature(Target::CUDA);
+    Expr has_vulkan = target_has_feature(Target::Vulkan);
+
+    Func f;
+    Var x;
+
+    // Each query contributes a distinct bit, so the result identifies
+    // exactly which queries evaluated to true.
+    f(x) = select(is_arm, 1, 0) +
+           select(is_x86, 2, 0) +
+           select(vec == 4, 4, 0) +
+           select(is_android, 8, 0) +
+           select(is_osx, 16, 0) +
+           select(bits == 32, 32, 0) +
+           select(bits == 64, 64, 0) +
+           select(has_cuda, 128, 0) +
+           select(has_vulkan, 256, 0);
+
+    Buffer<int> result = f.realize({1}, t1);
+
+    // Use an explicit check rather than assert(), which is compiled out
+    // when NDEBUG is defined.
+    const int expected = 2 + 4 + 16 + 64 + 128;
+    if (result(0) != expected) {
+        printf("result(0) = %d instead of %d\n", result(0), expected);
+        return 1;
+    }
+
+    printf("Success!\n");
+    return 0;
+}