Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate some facts about inequalities with min/max #8475

Merged
merged 4 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/Simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,24 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) {
learn_lower_bound(v, i.bounds.min + 1);
}
}
const Min *min = lt->b.as<Min>();
if (min) {
// c < min(a, b) -> c < a, c < b
learn_true(lt->a < min->a);
learn_true(lt->a < min->b);
// c < min(a, b) -> !(a <= c), !(b <= c)
learn_false(min->a <= lt->a);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the inverse statements required?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question. Is this just to get a syntactic match on the <= form? The other cases here don't do it, but I think this is not exactly the same as the other cases.

Copy link
Contributor Author

@shoaibkamil shoaibkamil Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it was necessary: if I comment out the learn_false() lines, the test fails, because we can't eliminate the condition in the third specialization. I added all the learn_false() clauses, even though it may be true that only one or two are needed for this test case.

We have a rewrite that changes downsample = !(scale_factor_x > 1.0) && !(scale_factor_y > 1.0) into downsample = max(scale_factor_x, scale_factor_y) <= 1.0. Without these learn_false() clauses, the simplifier is unable to prove max(scale_factor_x, scale_factor_y) <= 1.0 --> !(1.0 < scale_factor_x) and similarly max(scale_factor_x, scale_factor_y) <= 1.0 --> !(1.0 < scale_factor_y), so it can't trim the conditions in that branch.

learn_false(min->b <= lt->a);
}
const Max *max = lt->a.as<Max>();
if (max) {
// max(a, b) < c -> a < c, b < c
learn_true(max->a < lt->b);
learn_true(max->b < lt->b);
// max(a, b) < c -> !(c <= a), !(c <= b)
learn_false(lt->b <= max->a);
learn_false(lt->b <= max->b);
}
} else if (const LE *le = fact.as<LE>()) {
const Variable *v = le->a.as<Variable>();
Simplify::ExprInfo i;
Expand All @@ -294,6 +312,24 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) {
learn_lower_bound(v, i.bounds.min);
}
}
const Min *min = le->b.as<Min>();
if (min) {
// c <= min(a, b) -> c <= a, c <= b
learn_true(le->a <= min->a);
learn_true(le->a <= min->b);
// c <= min(a, b) -> !(a < c), !(b < c)
learn_false(min->a < le->a);
learn_false(min->b < le->a);
}
const Max *max = le->a.as<Max>();
if (max) {
// max(a, b) <= c -> a <= c, b <= c
learn_true(max->a <= le->b);
learn_true(max->b <= le->b);
// max(a, b) <= c -> !(c < a), !(c < b)
learn_false(le->b < max->a);
learn_false(le->b < max->b);
}
} else if (const Call *c = Call::as_tag(fact)) {
learn_true(c->args[0]);
return;
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ tests(GROUPS correctness
sort_exprs.cpp
specialize.cpp
specialize_to_gpu.cpp
specialize_trim_condition.cpp
split_by_non_factor.cpp
split_fuse_rvar.cpp
split_reuse_inner_name_bug.cpp
Expand Down
78 changes: 78 additions & 0 deletions test/correctness/specialize_trim_condition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include "Halide.h"
#include "HalideRuntime.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

using namespace Halide;

// Running count of halide_trace_load events seen by my_trace; main resets it
// to zero before each realization.
int load_count = 0;

// Custom trace hook: bumps the global load counter for every load event and
// ignores every other kind of event. Always returns 0.
int my_trace(JITUserContext *user_context, const halide_trace_event_t *ev) {
    if (ev->event != halide_trace_load) {
        return 0;
    }
    ++load_count;
    return 0;
}

// Builds a pipeline with two nested selects guarded by "upsample" and
// "downsample" conditions, specializes it on combinations of those
// conditions, and uses the number of traced input loads to check which
// select branches were trimmed in each specialization.
//
// Returns 0 and prints "Success!" when every load count matches; prints a
// diagnostic and returns 1 otherwise. Explicit checks are used instead of
// assert() so the test still verifies the counts when built with NDEBUG.
int main(int argc, char **argv) {
    Param<float> scale_factor_x, scale_factor_y;
    ImageParam input(UInt(8), 2);

    Var x, y;

    Func f;
    Expr upsample_x = scale_factor_x > cast<float>(1.0f);
    Expr upsample_y = scale_factor_y > cast<float>(1.0f);
    Expr upsample = upsample_x && upsample_y;
    Expr downsample = !upsample_x && !upsample_y;

    f(x, y) = select(upsample, input(cast<int>(x / 2), cast<int>(y / 2)),
                     select(downsample, input(x * 2, y * 2), 0));

    // Route every load from the input through my_trace so loads can be counted.
    input.trace_loads();
    f.jit_handlers().custom_trace = &my_trace;

    // Impossible condition
    // f.specialize(upsample && downsample);
    f.specialize(upsample && !downsample);
    f.specialize(!upsample && downsample);
    f.specialize(!upsample && !downsample);
    f.specialize_fail("Unreachable condition");

    Buffer<uint8_t> img(16, 16);
    input.set(img);

    {
        // In this specialization, one of the select branches should be trimmed,
        // resulting in one load per output pixel
        load_count = 0;
        scale_factor_x.set(2.0f);
        scale_factor_y.set(2.0f);
        Buffer<uint8_t> out = f.realize({8, 8});
        if (load_count != 64) {
            printf("Upsample case: expected 64 loads, got %d\n", load_count);
            return 1;
        }
    }
    {
        // In this specialization, no select can be trimmed,
        // resulting in two loads per output pixel
        load_count = 0;
        scale_factor_x.set(0.5f);
        scale_factor_y.set(2.0f);
        Buffer<uint8_t> out = f.realize({8, 8});
        if (load_count != 128) {
            printf("Mixed case: expected 128 loads, got %d\n", load_count);
            return 1;
        }
    }
    {
        // In this specialization, one of the select branches should be trimmed,
        // resulting in one load per output pixel
        load_count = 0;
        scale_factor_x.set(0.5f);
        scale_factor_y.set(0.5f);
        Buffer<uint8_t> out = f.realize({8, 8});
        if (load_count != 64) {
            printf("Downsample case: expected 64 loads, got %d\n", load_count);
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}
Loading