Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Number.round as a builtin #7460

Merged
merged 27 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
cf4d188
decimal
GregoryTravis Aug 1, 2023
cc78415
integer
GregoryTravis Aug 1, 2023
ba63c03
bench size
GregoryTravis Aug 1, 2023
cf8d9ef
Merge branch 'develop' into wip/gmt/7401-round-builtin
GregoryTravis Aug 2, 2023
e80e6b6
move stuff into tb methods
GregoryTravis Aug 2, 2023
ee2cc72
bench builder api
GregoryTravis Aug 2, 2023
6aee4a4
Merge branch 'develop' into wip/gmt/7401-round-builtin
GregoryTravis Aug 4, 2023
6849e36
additional optimizations, handling one error
GregoryTravis Aug 4, 2023
d9499c6
cleanup
GregoryTravis Aug 4, 2023
f4d2014
Merge branch 'develop' into wip/gmt/7401-round-builtin
mergify[bot] Aug 8, 2023
665786c
Allow invocation of Numeric benchmarks via JMH
JaroslavTulach Aug 8, 2023
dc9ebe0
Have to use PrimitiveValueProfile when profiling primitive values
JaroslavTulach Aug 8, 2023
efeb2d7
Make sure compiler knows a branch yields an exception
JaroslavTulach Aug 8, 2023
b57862b
Double.NaN => ROUND_MIN_DOUBLE is false. Also Double.Infinity <= ROUN…
JaroslavTulach Aug 8, 2023
94dac30
Profile to avoid outOfRange exception code
JaroslavTulach Aug 8, 2023
239c25b
Some optimizations for smallInteger.RoundNode
JaroslavTulach Aug 10, 2023
843f8ef
Check decimalPlaces first to finish as soon as possible
JaroslavTulach Aug 10, 2023
249a36b
Merge branch 'develop' into wip/gmt/7401-round-builtin
GregoryTravis Aug 10, 2023
3034fad
fix up error handling, tests
GregoryTravis Aug 10, 2023
87a1b40
cleanup
GregoryTravis Aug 10, 2023
62063cc
cleanup
GregoryTravis Aug 10, 2023
e7e67e8
rename builtin stubs
GregoryTravis Aug 10, 2023
d004233
handle bigints, do range check in enso
GregoryTravis Aug 10, 2023
64f0da6
bigint tests
GregoryTravis Aug 10, 2023
5fbeb1a
remove old, restore bench
GregoryTravis Aug 10, 2023
70d23ae
added test for type errors
GregoryTravis Aug 11, 2023
92cab6e
fix meta_spec
GregoryTravis Aug 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions distribution/lib/Standard/Base/0.0.0-dev/src/Data/Numbers.enso
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import project.Function.Function
import project.Internal.Rounding_Helpers
import project.Nothing.Nothing
import project.Panic.Panic

from project.Data.Boolean import Boolean, False, True
from project.Errors.Common import Unsupported_Argument_Types

polyglot java import java.lang.Double
polyglot java import java.lang.Long
Expand Down Expand Up @@ -613,10 +615,12 @@ type Decimal

2.5 . round use_bankers=True == 2
round : Integer -> Boolean -> Integer | Decimal ! Illegal_Argument
round self decimal_places=0 use_bankers=False =
Illegal_Argument.handle_java_exception <| Arithmetic_Error.handle_java_exception <|
decimal_result = Core_Math_Utils.roundDouble self decimal_places use_bankers
if decimal_places > 0 then decimal_result else decimal_result.truncate
round self (decimal_places:Integer=0) (use_bankers:Boolean=False) =
report_unsupported cp = Error.throw (Illegal_Argument.Error cp.payload.message)
Panic.catch Unsupported_Argument_Types handler=report_unsupported
Decimal.round_decimal_builtin self decimal_places use_bankers

round_decimal_builtin n decimal_places use_bankers = @Builtin_Method "Decimal.round"

## Compute the negation of this.

Expand Down Expand Up @@ -919,25 +923,15 @@ type Integer

12250 . round -2 use_bankers=True == 12200
round : Integer -> Boolean -> Integer ! Illegal_Argument
round self decimal_places=0 use_bankers=False =
## It's already an integer so unless decimal_places is
negative, the value is unchanged.
if decimal_places >= 0 then self else
Rounding_Helpers.check_decimal_places decimal_places <| Rounding_Helpers.check_round_input self <|
scale = 10 ^ -decimal_places
halfway = scale.div 2
remainder = self % scale
scaled_down = self.div scale
result_unnudged = scaled_down * scale
case self >= 0 of
True ->
half_goes_up = if use_bankers then (scaled_down % 2) != 0 else True
round_up = if half_goes_up then remainder >= halfway else remainder > halfway
if round_up then result_unnudged + scale else result_unnudged
False ->
half_goes_up = if use_bankers then (scaled_down % 2) == 0 else False
round_up = if half_goes_up then remainder < -halfway else remainder <= -halfway
if round_up then result_unnudged - scale else result_unnudged
round self (decimal_places:Integer=0) (use_bankers:Boolean=False) =
## We reject values outside the range of `long` here, so we don't also
do this check in the Java.
Rounding_Helpers.check_round_input self <|
report_unsupported cp = Error.throw (Illegal_Argument.Error cp.payload.message)
Panic.catch Unsupported_Argument_Types handler=report_unsupported
Integer.round_integer_builtin self decimal_places use_bankers

round_integer_builtin n decimal_places use_bankers = @Builtin_Method "Integer.round"

## Compute the negation of this.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.enso.interpreter.node.expression.builtin.number.decimal;

import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.CountingConditionProfile;
import com.oracle.truffle.api.profiles.PrimitiveValueProfile;
import java.math.BigDecimal;
import java.math.BigInteger;
import org.enso.interpreter.dsl.BuiltinMethod;
import org.enso.interpreter.node.expression.builtin.number.utils.BigIntegerOps;
import org.enso.interpreter.node.expression.builtin.number.utils.RoundHelpers;
import org.enso.interpreter.runtime.number.EnsoBigInteger;

@BuiltinMethod(
type = "Decimal",
name = "round",
description = "Decimal ceiling, converting to a small or big integer depending on size.")
public class RoundNode extends Node {
private final CountingConditionProfile fitsProfile = CountingConditionProfile.create();

private final PrimitiveValueProfile constantPlacesDecimalPlaces = PrimitiveValueProfile.create();

private final PrimitiveValueProfile constantPlacesUseBankers = PrimitiveValueProfile.create();

private final BranchProfile decimalPlacesOutOfRangeProfile = BranchProfile.create();

private final BranchProfile outOfRangeProfile = BranchProfile.create();

Object execute(double n, long dp, boolean ub) {
long decimalPlaces = constantPlacesDecimalPlaces.profile(dp);
boolean useBankers = constantPlacesUseBankers.profile(ub);

if (decimalPlaces < RoundHelpers.ROUND_MIN_DECIMAL_PLACES
|| decimalPlaces > RoundHelpers.ROUND_MAX_DECIMAL_PLACES) {
decimalPlacesOutOfRangeProfile.enter();
RoundHelpers.decimalPlacesOutOfRangePanic(this, decimalPlaces);
}

boolean inRange = n >= RoundHelpers.ROUND_MIN_DOUBLE && n <= RoundHelpers.ROUND_MAX_DOUBLE;
if (!inRange) {
outOfRangeProfile.enter();
if (Double.isNaN(n) || Double.isInfinite(n)) {
RoundHelpers.specialValuePanic(this, n);
} else {
RoundHelpers.argumentOutOfRangePanic(this, n);
}
}

// Algorithm taken from https://stackoverflow.com/a/7211688.
double scale = Math.pow(10.0, decimalPlaces);
double scaled = n * scale;
double roundBase = Math.floor(scaled);
double roundMidpoint = (roundBase + 0.5) / scale;
boolean evenIsUp = n >= 0 ? (((long) scaled) % 2) != 0 : (((long) scaled) % 2) == 0;
boolean halfGoesUp = useBankers ? evenIsUp : n >= 0;
boolean doRoundUp = halfGoesUp ? n >= roundMidpoint : n > roundMidpoint;
double resultUncast = doRoundUp ? ((roundBase + 1.0) / scale) : (roundBase / scale);
if (decimalPlaces > 0) {
return resultUncast;
} else {
if (fitsProfile.profile(BigIntegerOps.fitsInLong(resultUncast))) {
return (long) resultUncast;
} else {
return new EnsoBigInteger(toBigInteger(resultUncast));
}
}
}

@TruffleBoundary
private BigInteger toBigInteger(double n) {
return BigDecimal.valueOf(n).toBigIntegerExact();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.enso.interpreter.node.expression.builtin.number.smallInteger;

import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.CountingConditionProfile;
import com.oracle.truffle.api.profiles.PrimitiveValueProfile;
import org.enso.interpreter.dsl.BuiltinMethod;
import org.enso.interpreter.node.expression.builtin.number.utils.RoundHelpers;

@BuiltinMethod(
type = "Integer",
name = "round",
description = "Decimal ceiling, converting to a small or big integer depending on size.")
public class RoundNode extends Node {
private final CountingConditionProfile fitsProfile = CountingConditionProfile.create();

private final PrimitiveValueProfile constantPlacesDecimalPlaces = PrimitiveValueProfile.create();

private final PrimitiveValueProfile constantPlacesUseBankers = PrimitiveValueProfile.create();

private final BranchProfile decimalPlacesOutOfRangeProfile = BranchProfile.create();

Object execute(long n, long dp, boolean ub) {
var decimalPlaces = constantPlacesDecimalPlaces.profile(dp);

// We don't check if `n` is out of range here, since the Enso wrapper does.
if (decimalPlaces < RoundHelpers.ROUND_MIN_DECIMAL_PLACES
|| decimalPlaces > RoundHelpers.ROUND_MAX_DECIMAL_PLACES) {
decimalPlacesOutOfRangeProfile.enter();
RoundHelpers.decimalPlacesOutOfRangePanic(this, decimalPlaces);
}

if (decimalPlaces >= 0) {
return n;
}

var useBankers = constantPlacesUseBankers.profile(ub);
long scale = (long) Math.pow(10, -decimalPlaces);
long halfway = scale / 2;
long remainder = n % scale;
long scaledDown = n / scale;
long resultUnnudged = scaledDown * scale;

if (n >= 0) {
boolean halfGoesUp = useBankers ? (scaledDown % 2) != 0 : true;
boolean roundUp = halfGoesUp ? remainder >= halfway : remainder > halfway;
return roundUp ? resultUnnudged + scale : resultUnnudged;
} else {
boolean halfGoesUp = useBankers ? (scaledDown % 2) == 0 : false;
boolean roundUp = halfGoesUp ? remainder < -halfway : remainder <= -halfway;
return roundUp ? resultUnnudged - scale : resultUnnudged;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package org.enso.interpreter.node.expression.builtin.number.utils;

import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.builtin.Builtins;
import org.enso.interpreter.runtime.error.PanicException;

public class RoundHelpers {
/** Minimum value for the `decimal_places` parameter to `roundDouble`. */
public static final double ROUND_MIN_DECIMAL_PLACES = -15;

/** Maximum value for the `decimal_places` parameter to `roundDouble`. */
public static final double ROUND_MAX_DECIMAL_PLACES = 15;

/** Minimum value for the `n` parameter to `roundDouble`. */
public static final double ROUND_MIN_DOUBLE = -99999999999999.0;

/** Minimum value for the `n` parameter to `roundDouble`. */
public static final double ROUND_MAX_DOUBLE = 99999999999999.0;

public static final double ROUND_MIN_LONG = -99999999999999L;

/** Minimum value for the `n` parameter to `roundDouble`. */
public static final double ROUND_MAX_LONG = 99999999999999L;

@TruffleBoundary
public static void decimalPlacesOutOfRangePanic(Node node, long decimalPlaces)
throws PanicException {
String msg =
"round: decimalPlaces must be between "
+ ROUND_MIN_DECIMAL_PLACES
+ " and "
+ ROUND_MAX_DECIMAL_PLACES
+ " (inclusive), but was "
+ decimalPlaces;
Builtins builtins = EnsoContext.get(node).getBuiltins();
throw new PanicException(
builtins.error().makeUnsupportedArgumentsError(new Object[] {decimalPlaces}, msg), node);
}

@TruffleBoundary
public static void argumentOutOfRangePanic(Node node, double n) throws PanicException {
String msg =
"Error: `round` can only accept values between "
+ ROUND_MIN_DOUBLE
+ " and "
+ ROUND_MAX_DOUBLE
+ " (inclusive), but was "
+ n;
Builtins builtins = EnsoContext.get(node).getBuiltins();
throw new PanicException(
builtins.error().makeUnsupportedArgumentsError(new Object[] {n}, msg), node);
}

@TruffleBoundary
public static void specialValuePanic(Node node, double n) throws PanicException {
String msg = "Error: `round` cannot accept " + (Double.isNaN(n) ? "NaN" : "Inf") + " values ";
Builtins builtins = EnsoContext.get(node).getBuiltins();
throw new PanicException(
builtins.error().makeUnsupportedArgumentsError(new Object[] {n}, msg), node);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class Core_Math_Utils {
* @throws IllegalArgumentException if `decimalPlaces` is outside the allowed range.
*/
public static double roundDouble(double n, int decimalPlaces, boolean useBankers) {

if (decimalPlaces < ROUND_MIN_DECIMAL_PLACES || decimalPlaces > ROUND_MAX_DECIMAL_PLACES) {
String msg =
"round: decimalPlaces must be between "
Expand Down
4 changes: 3 additions & 1 deletion test/Benchmarks/src/Main.enso
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import project.Vector.Distinct
import project.Vector.Operations
import project.Numeric

all_benchmarks =
vec_ops = Operations.all
vec_distinct = Distinct.collect_benches
[vec_ops, vec_distinct]
numeric = Numeric.bench
[vec_ops, vec_distinct, numeric]

main =
all_benchmarks.each suite->
Expand Down
27 changes: 17 additions & 10 deletions test/Benchmarks/src/Numeric.enso
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ from Standard.Test import Bench, Faker

## Bench Utilities ============================================================

vector_size = 1000000
vector_size = 100000
iter_size = 100
num_iterations = 10

# The Benchmarks ==============================================================

bench =
collect_benches group_builder =
bench_measure ~act name = group_builder.specify name act

## No specific significance to this constant, just fixed to make generated set deterministic
fixed_random_seed = 1644575867
faker = Faker.new fixed_random_seed
Expand All @@ -25,21 +27,26 @@ bench =
name = pair.at 0
fun = pair.at 1
IO.println <| "Benchmarking decimal " + name
Bench.measure (decimals.map fun) name iter_size num_iterations
bench_measure (decimals.map fun) name

funs.map pair->
name = pair.at 0
fun = pair.at 1
IO.println <| "Benchmarking integer" + name
Bench.measure (integers.map fun) name iter_size num_iterations
bench_measure (integers.map fun) name

[True, False].map use_bankers->
[0, -2, 2].map decimal_places->
name = "round decimal_places=" + decimal_places.to_text + " use_bankers=" + use_bankers.to_text
name = "round decimal_places=" + (decimal_places.to_text.replace '-' '_') + " use_bankers=" + use_bankers.to_text
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to_text.replace could be removed when https://github.com/enso-org/enso/pull/7519/files#r1288361395 is in.

fun = _.round decimal_places use_bankers
IO.println <| "Benchmarking decimal " + name
Bench.measure (decimals.map fun) name iter_size num_iterations
IO.println <| "Benchmarking integer " + name
Bench.measure (integers.map fun) name iter_size num_iterations
bench_measure (decimals.map fun) "decimal "+name
bench_measure (integers.map fun) "integer "+name

bench =
options = Bench.options . size iter_size . iter num_iterations

Bench.build builder->
builder.group "Numbers" options group_builder->
collect_benches group_builder

main = bench
main = bench . run_main
30 changes: 22 additions & 8 deletions test/Tests/src/Data/Numbers_Spec.enso
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from Standard.Base import all
import Standard.Base.Errors.Common.Arithmetic_Error
import Standard.Base.Errors.Common.Incomparable_Values
import Standard.Base.Errors.Common.Type_Error
import Standard.Base.Errors.Illegal_Argument.Illegal_Argument

from Standard.Base.Data.Numbers import Number_Parse_Error
Expand Down Expand Up @@ -659,9 +660,9 @@ spec =
3.1 . round -16 . should_fail_with Illegal_Argument

Test.specify "NaN/Inf" <|
Number.nan . round . should_fail_with Arithmetic_Error
Number.positive_infinity . round . should_fail_with Arithmetic_Error
Number.negative_infinity . round . should_fail_with Arithmetic_Error
Number.nan . round . should_fail_with Illegal_Argument
Number.positive_infinity . round . should_fail_with Illegal_Argument
Number.negative_infinity . round . should_fail_with Illegal_Argument

Test.specify "Floating point imperfect representation counter-examples" <|
1.225 . round 2 use_bankers=True . should_equal 1.22 # Actual result 1.23
Expand Down Expand Up @@ -762,6 +763,11 @@ spec =
-12250 . round -2 use_bankers=True . should_equal -12200
-12251 . round -2 use_bankers=True . should_equal -12300

Test.specify "Handles incorrect argument types" <|
Test.expect_panic_with (123 . round "two") Type_Error
Test.expect_panic_with (123 . round use_bankers="no") Type_Error
Test.expect_panic_with (123 . round use_bankers=0) Type_Error

Test.specify "Returns the correct type" <|
231 . round 1 . should_be_a Integer
231 . round 0 . should_be_a Integer
Expand All @@ -771,14 +777,22 @@ spec =
Test.specify "Input out of range" <|
100000000000000 . round -2 . should_fail_with Illegal_Argument
-100000000000000 . round -2 . should_fail_with Illegal_Argument
100000000000000 . round . should_fail_with Illegal_Argument
-100000000000000 . round . should_fail_with Illegal_Argument
100000000000000 . round 1 . should_fail_with Illegal_Argument
-100000000000000 . round 1 . should_fail_with Illegal_Argument
99999999999999 . round . should_equal 99999999999999
-99999999999999 . round . should_equal -99999999999999
99999999999999 . round -2 . should_equal 100000000000000
-99999999999999 . round -2 . should_equal -100000000000000

Test.specify "Input out of range is ignored when the implementation returns its argument immediately" <|
100000000000000 . round . should_equal 100000000000000
-100000000000000 . round . should_equal -100000000000000
100000000000000 . round 1 . should_equal 100000000000000
-100000000000000 . round 1 . should_equal -100000000000000
Test.specify "Reject bigints before reaching the Java" <|
922337203685477580700000 . round . should_fail_with Illegal_Argument
-922337203685477580700000 . round . should_fail_with Illegal_Argument

Test.specify "Can handle small numbers computed from bigints" <|
(922337203685477580712345 - 922337203685477580700000) . round . should_equal 12345
((99999999999998 * 1000).div 1000) . round . should_equal 99999999999998

Test.group "Decimal.truncate"

Expand Down
Loading