Skip to content

Commit

Permalink
UnresolvedSymbol is now accepted by Vector.sort (#6334)
Browse files Browse the repository at this point in the history
`Vector.sort` does some custom method dispatch logic which always expected a function as `by` and `on` arguments. At the same time, `UnresolvedSymbol` is treated like a (to be resolved) `Function` and under normal circumstances there would be no difference between `_.foo` and `.foo` provided as arguments.

Rather than adding an additional phase that does some form of eta-expansion, to accomodate for this custom dispatch, this change only fixes the problem locally. We accept `Function` and `UnresolvedSymbol` and perform the resolution on the fly. Ideally, we would have a specialization on the latter but again, it would be dependent on the contents of the `Vector` so unclear if that is better.

Closes #6276,

# Important Notes
There was a suggestion to somehow modify our codegen to accomodate for this scenario but I went against it. In fact a lot of name literals have `isMethod` flag and that information is used in the passes but it should not control how (late) codegen is done. If we were to make this more generic, I would suggest maybe to add separate eta-expansion pass. But it could affect other things and could be potentially a significant change with limited potential initially, so potential future work item.
  • Loading branch information
hubertp authored Apr 20, 2023
1 parent dd4dce2 commit 6d3151f
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
import org.enso.interpreter.dsl.AcceptsError;
import org.enso.interpreter.dsl.BuiltinMethod;
import org.enso.interpreter.node.callable.dispatch.CallOptimiserNode;
import org.enso.interpreter.node.callable.resolver.MethodResolverNode;
import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode;
import org.enso.interpreter.node.expression.builtin.meta.EqualsNode;
import org.enso.interpreter.node.expression.builtin.meta.TypeOfNode;
import org.enso.interpreter.node.expression.builtin.text.AnyToTextNode;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.UnresolvedSymbol;
import org.enso.interpreter.runtime.callable.atom.Atom;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.Array;
Expand All @@ -39,6 +41,7 @@
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.error.WarningsLibrary;
import org.enso.interpreter.runtime.error.WithWarnings;
import org.enso.interpreter.runtime.library.dispatch.TypesLibrary;
import org.enso.interpreter.runtime.state.State;

/**
Expand Down Expand Up @@ -179,10 +182,12 @@ Object sortGeneric(
long problemBehaviorNum,
@CachedLibrary(limit = "10") InteropLibrary interop,
@CachedLibrary(limit = "5") WarningsLibrary warningsLib,
@CachedLibrary(limit = "5") TypesLibrary typesLib,
@Cached LessThanNode lessThanNode,
@Cached EqualsNode equalsNode,
@Cached TypeOfNode typeOfNode,
@Cached AnyToTextNode toTextNode,
@Cached MethodResolverNode methodResolverNode,
@Cached(value = "build()", uncached = "build()") HostValueToEnsoNode hostValueToEnsoNode,
@Cached(value = "build()", uncached = "build()") CallOptimiserNode callNode) {
var problemBehavior = ProblemBehavior.fromInt((int) problemBehaviorNum);
Expand Down Expand Up @@ -230,7 +235,9 @@ Object sortGeneric(
less,
equal,
greater,
interop);
interop,
typesLib,
methodResolverNode);
}
group.elems.sort(javaComparator);
if (javaComparator.hasWarnings()) {
Expand Down Expand Up @@ -647,6 +654,94 @@ private int getPrimitiveValueCost(Object object) {
}
}

/**
* Helper class that returns the comparator function.
*
* The class is introduced to handle the presence of {@code UnresolvedSymbol},
* as the comparator function, which has to be first resolved before it
* can be used to compare values.
*/
private abstract class Compare {

/**
* Test if the comparator function has self argument.
*
* @param definedOn the value on which the function is defined on.
* @return true if self argument is present, false otherwise.
*/
abstract boolean hasFunctionSelfArgument(Object definedOn);

/**
* Return a comparator function.
*
* @param arg the value on which the function is defined on.
* @return a non-null comparator function.
*/
abstract Function get(Object arg);

}

private final class CompareFromFunction extends Compare {

private final Function function;

private CompareFromFunction(Function function) {
this.function = function;
}

@Override
boolean hasFunctionSelfArgument(Object definedOn) {
if (function.getSchema().getArgumentsCount() > 0) {
return function.getSchema().getArgumentInfos()[0].getName().equals("self");
} else {
return false;
}
}

@Override
Function get(Object arg) {
return function;
}
}

private class CompareFromUnresolvedSymbol extends Compare {

private final UnresolvedSymbol unresolvedSymbol;
private final MethodResolverNode methodResolverNode;
private final TypesLibrary typesLibrary;

private @CompilerDirectives.CompilationFinal Function resolvedFunction;

private CompareFromUnresolvedSymbol(UnresolvedSymbol unresolvedSymbol,
MethodResolverNode methodResolvedNode,
TypesLibrary typesLibrary) {
this.unresolvedSymbol = unresolvedSymbol;
this.methodResolverNode = methodResolvedNode;
this.typesLibrary = typesLibrary;

}

@Override
boolean hasFunctionSelfArgument(Object definedOn) {
ensureSymbolIsResolved(definedOn);
return resolvedFunction.getSchema().getArgumentsCount() > 0 &&
resolvedFunction.getSchema().getArgumentInfos()[0].getName().equals("self");

}

@Override
Function get(Object arg) {
ensureSymbolIsResolved(arg);
return resolvedFunction;
}

private void ensureSymbolIsResolved(Object definedOn) {
if (resolvedFunction == null) {
resolvedFunction = methodResolverNode.expectNonNull(definedOn, typesLibrary.getType(definedOn), unresolvedSymbol);
}
}
}

/**
* Comparator for any values. This comparator compares the values by calling back to Enso (by
* {@link #compareFunc}), rather than using compare nodes (i.e. {@link LessThanNode}). directly,
Expand All @@ -659,9 +754,9 @@ private final class GenericSortComparator extends SortComparator {
* Either function from `by` parameter to the `Vector.sort` method, or the `compare` function
* extracted from the comparator for the appropriate group.
*/
private final Function compareFunc;
private final Compare compareFunc;

private final Function onFunc;
private final Compare onFunc;
private final boolean hasCustomOnFunc;
private final Type comparator;
private final CallOptimiserNode callNode;
Expand All @@ -682,20 +777,22 @@ private GenericSortComparator(
Atom less,
Atom equal,
Atom greater,
InteropLibrary interop) {
InteropLibrary interop,
TypesLibrary typesLibrary,
MethodResolverNode methodResolverNode) {
super(toTextNode, problemBehavior, interop);
assert compareFunc != null;
assert comparator != null;
this.comparator = comparator;
this.state = state;
this.ascending = ascending;
this.compareFunc = checkAndConvertByFunc(compareFunc);
this.compareFunc = checkAndConvertByFunc(compareFunc, typesLibrary, methodResolverNode);
if (interop.isNull(onFunc)) {
this.hasCustomOnFunc = false;
this.onFunc = null;
} else {
this.hasCustomOnFunc = true;
this.onFunc = checkAndConvertOnFunc(onFunc);
this.onFunc = checkAndConvertOnFunc(onFunc, typesLibrary, methodResolverNode);
}
this.callNode = callNode;
this.less = less;
Expand All @@ -709,19 +806,19 @@ public int compare(Object x, Object y) {
Object yConverted;
if (hasCustomOnFunc) {
// onFunc cannot have `self` argument, we assume it has just one argument.
xConverted = callNode.executeDispatch(onFunc, null, state, new Object[] {x});
yConverted = callNode.executeDispatch(onFunc, null, state, new Object[] {y});
xConverted = callNode.executeDispatch(onFunc.get(x), null, state, new Object[]{x});
yConverted = callNode.executeDispatch(onFunc.get(y), null, state, new Object[]{y});
} else {
xConverted = x;
yConverted = y;
}
Object[] args;
if (hasFunctionSelfArgument(compareFunc)) {
if (compareFunc.hasFunctionSelfArgument(xConverted)) {
args = new Object[] {comparator, xConverted, yConverted};
} else {
args = new Object[] {xConverted, yConverted};
}
Object res = callNode.executeDispatch(compareFunc, null, state, args);
Object res = callNode.executeDispatch(compareFunc.get(xConverted), null, state, args);
if (res == less) {
return ascending ? -1 : 1;
} else if (res == equal) {
Expand All @@ -738,43 +835,43 @@ public int compare(Object x, Object y) {
}
}

private boolean hasFunctionSelfArgument(Function function) {
if (function.getSchema().getArgumentsCount() > 0) {
return function.getSchema().getArgumentInfos()[0].getName().equals("self");
} else {
return false;
}
}

/**
* Checks value given for {@code by} parameter and converts it to {@link Function}. Throw a
* dataflow error otherwise.
*/
private Function checkAndConvertByFunc(Object byFuncObj) {
private Compare checkAndConvertByFunc(Object byFuncObj, TypesLibrary typesLibrary, MethodResolverNode methodResolverNode) {
return checkAndConvertFunction(
byFuncObj, "Unsupported argument for `by`, expected a method with two arguments", 2, 3);
byFuncObj, "Unsupported argument for `by`, expected a method with two arguments", 2, 3,
typesLibrary, methodResolverNode);
}

/**
* Checks the value given for {@code on} parameter and converts it to {@link Function}. Throws a
* dataflow error otherwise.
*/
private Function checkAndConvertOnFunc(Object onFuncObj) {
private Compare checkAndConvertOnFunc(Object onFuncObj, TypesLibrary typesLibrary, MethodResolverNode methodResolverNode) {
return checkAndConvertFunction(
onFuncObj, "Unsupported argument for `on`, expected a method with one argument", 1, 1);
onFuncObj, "Unsupported argument for `on`, expected a method with one argument", 1, 1,
typesLibrary, methodResolverNode);
}

/**
* @param minArgCount Minimal count of arguments without a default value.
* @param maxArgCount Maximal count of argument without a default value.
* @param methodResolverNode node for resolving unresolved symbols.
* @param typesLibrary types library for resolving the dispatch type for unresolved symbols.
*/
private Function checkAndConvertFunction(
Object funcObj, String errMsg, int minArgCount, int maxArgCount) {
private Compare checkAndConvertFunction(
Object funcObj, String errMsg, int minArgCount, int maxArgCount,
TypesLibrary typesLibrary, MethodResolverNode methodResolverNode) {
if (funcObj instanceof UnresolvedSymbol unresolved) {
return new CompareFromUnresolvedSymbol(unresolved, methodResolverNode, typesLibrary);
}
var err = new IllegalArgumentException(errMsg + ", got " + funcObj);
if (funcObj instanceof Function func) {
var argCount = getNumberOfNonDefaultArguments(func);
if (minArgCount <= argCount && argCount <= maxArgCount) {
return func;
return new CompareFromFunction(func);
} else {
throw err;
}
Expand Down
8 changes: 8 additions & 0 deletions test/Benchmarks/src/Vector/Sort.enso
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,19 @@ make_partially_sorted_vec n =
run_length.put (run_length.get - 1)
num

type Int
Value v

identity self = self


# The Benchmarks ==============================================================

bench =
sorted_vec = make_sorted_ascending_vec vector_size
partially_sorted_vec = make_partially_sorted_vec vector_size
random_vec = Utils.make_random_vec vector_size
random_vec_wrapped = random_vec.map (v -> Int.Value v)
projection = x -> x % 10
comparator = l -> r -> Ordering.compare l r

Expand All @@ -66,6 +72,8 @@ bench =
Bench.measure (random_vec.sort) "Random Elements Ascending" iter_size num_iterations
Bench.measure (random_vec.sort Sort_Direction.Descending) "Random Elements Descending" iter_size num_iterations
Bench.measure (random_vec.sort on=projection) "Sorting with a Custom Projection" iter_size num_iterations
Bench.measure (random_vec_wrapped.sort on=(_.identity)) "Sorting with an identity function" iter_size num_iterations
Bench.measure (random_vec_wrapped.sort on=(.identity)) "Sorting with an (unresolved) identity function" iter_size num_iterations
Bench.measure (random_vec.sort by=comparator) "Sorting with the Default_Ordered_Comparator" iter_size num_iterations

main = bench
1 change: 1 addition & 0 deletions test/Tests/src/Data/Vector_Spec.enso
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ type_spec name alter = Test.group name <|
small_vec = alter [T.Value 1 8, T.Value 1 3, T.Value -20 0, T.Value -1 1, T.Value -1 10, T.Value 4 0]
small_expected = [T.Value -20 0, T.Value 4 0, T.Value -1 1, T.Value 1 3, T.Value 1 8, T.Value -1 10]
small_vec.sort (on = _.b) . should_equal small_expected
small_vec.sort (on = .b) . should_equal small_expected

Test.specify "should be able to use a custom compare function" <|
small_vec = alter [2, 7, -3, 383, -392, 28, -90]
Expand Down

0 comments on commit 6d3151f

Please sign in to comment.