diff --git a/src/Simplify.cpp b/src/Simplify.cpp index 6bb34fc4db15..284eaa76568d 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -276,6 +276,24 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { learn_lower_bound(v, i.bounds.min + 1); } } + const Min *min = lt->b.as(); + if (min) { + // c < min(a, b) -> c < a, c < b + learn_true(lt->a < min->a); + learn_true(lt->a < min->b); + // c < min(a, b) -> !(a <= c), !(b <= c) + learn_false(min->a <= lt->a); + learn_false(min->b <= lt->a); + } + const Max *max = lt->a.as(); + if (max) { + // max(a, b) < c -> a < c, b < c + learn_true(max->a < lt->b); + learn_true(max->b < lt->b); + // max(a, b) < c -> !(c <= a), !(c <= b) + learn_false(lt->b <= max->a); + learn_false(lt->b <= max->b); + } } else if (const LE *le = fact.as()) { const Variable *v = le->a.as(); Simplify::ExprInfo i; @@ -294,6 +312,24 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { learn_lower_bound(v, i.bounds.min); } } + const Min *min = le->b.as(); + if (min) { + // c <= min(a, b) -> c <= a, c <= b + learn_true(le->a <= min->a); + learn_true(le->a <= min->b); + // c <= min(a, b) -> !(a < c), !(b < c) + learn_false(min->a < le->a); + learn_false(min->b < le->a); + } + const Max *max = le->a.as(); + if (max) { + // max(a, b) <= c -> a <= c, b <= c + learn_true(max->a <= le->b); + learn_true(max->b <= le->b); + // max(a, b) <= c -> !(c < a), !(c < b) + learn_false(le->b < max->a); + learn_false(le->b < max->b); + } } else if (const Call *c = Call::as_tag(fact)) { learn_true(c->args[0]); return; diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 8ca5cfb05045..d3ca7ead1586 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -296,6 +296,7 @@ tests(GROUPS correctness sort_exprs.cpp specialize.cpp specialize_to_gpu.cpp + specialize_trim_condition.cpp split_by_non_factor.cpp split_fuse_rvar.cpp split_reuse_inner_name_bug.cpp diff --git a/test/correctness/specialize_trim_condition.cpp b/test/correctness/specialize_trim_condition.cpp new file mode 100644 index 000000000000..c2753e17c09f --- /dev/null +++ b/test/correctness/specialize_trim_condition.cpp @@ -0,0 +1,78 @@ +#include "Halide.h" +#include "HalideRuntime.h" +#include +#include +#include + +using namespace Halide; + +int load_count = 0; + +// A trace that records the number of loads +int my_trace(JITUserContext *user_context, const halide_trace_event_t *ev) { + + if (ev->event == halide_trace_load) { + load_count++; + } + return 0; +} + +int main(int argc, char **argv) { + Param scale_factor_x, scale_factor_y; + ImageParam input(UInt(8), 2); + + Var x, y; + + Func f; + Expr upsample_x = scale_factor_x > cast(1.0f); + Expr upsample_y = scale_factor_y > cast(1.0f); + Expr upsample = upsample_x && upsample_y; + Expr downsample = !upsample_x && !upsample_y; + + f(x, y) = select(upsample, input(cast(x / 2), cast(y / 2)), + select(downsample, input(x * 2, y * 2), 0)); + + input.trace_loads(); + f.jit_handlers().custom_trace = &my_trace; + + // Impossible condition + // f.specialize(upsample && downsample); + f.specialize(upsample && !downsample); + f.specialize(!upsample && downsample); + f.specialize(!upsample && !downsample); + f.specialize_fail("Unreachable condition"); + + Buffer img(16, 16); + input.set(img); + + { + // In this specialization, one of the select branches should be trimmed, + // resulting in one load per output pixel + load_count = 0; + scale_factor_x.set(2.0f); + scale_factor_y.set(2.0f); + Buffer out = f.realize({8, 8}); + assert(load_count == 64); + } + { + // In this specialization, no select can be trimmed, + // resulting in two loads per output pixel + load_count = 0; + scale_factor_x.set(0.5f); + scale_factor_y.set(2.0f); + Buffer out = f.realize({8, 8}); + assert(load_count == 128); + } + { + // In this specialization, one of the select branches should be trimmed, + // resulting in one load per output pixel + load_count = 0; + scale_factor_x.set(0.5f); + scale_factor_y.set(0.5f); + Buffer out = f.realize({8, 8}); + assert(load_count == 64); + } + + printf("Success!\n"); + return 0; +}