diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Numbers.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Numbers.enso index 5bce5d9ed060..d1d279ffb4e9 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Numbers.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Numbers.enso @@ -9,7 +9,9 @@ import project.Function.Function import project.Internal.Rounding_Helpers import project.Nothing.Nothing import project.Panic.Panic + from project.Data.Boolean import Boolean, False, True +from project.Errors.Common import Unsupported_Argument_Types polyglot java import java.lang.Double polyglot java import java.lang.Long @@ -613,10 +615,12 @@ type Decimal 2.5 . round use_bankers=True == 2 round : Integer -> Boolean -> Integer | Decimal ! Illegal_Argument - round self decimal_places=0 use_bankers=False = - Illegal_Argument.handle_java_exception <| Arithmetic_Error.handle_java_exception <| - decimal_result = Core_Math_Utils.roundDouble self decimal_places use_bankers - if decimal_places > 0 then decimal_result else decimal_result.truncate + round self (decimal_places:Integer=0) (use_bankers:Boolean=False) = + report_unsupported cp = Error.throw (Illegal_Argument.Error cp.payload.message) + Panic.catch Unsupported_Argument_Types handler=report_unsupported + Decimal.round_decimal_builtin self decimal_places use_bankers + + round_decimal_builtin n decimal_places use_bankers = @Builtin_Method "Decimal.round" ## Compute the negation of this. @@ -919,25 +923,15 @@ type Integer 12250 . round -2 use_bankers=True == 12200 round : Integer -> Boolean -> Integer ! Illegal_Argument - round self decimal_places=0 use_bankers=False = - ## It's already an integer so unless decimal_places is - negative, the value is unchanged. - if decimal_places >= 0 then self else - Rounding_Helpers.check_decimal_places decimal_places <| Rounding_Helpers.check_round_input self <| - scale = 10 ^ -decimal_places - halfway = scale.div 2 - remainder = self % scale - scaled_down = self.div scale - result_unnudged = scaled_down * scale - case self >= 0 of - True -> - half_goes_up = if use_bankers then (scaled_down % 2) != 0 else True - round_up = if half_goes_up then remainder >= halfway else remainder > halfway - if round_up then result_unnudged + scale else result_unnudged - False -> - half_goes_up = if use_bankers then (scaled_down % 2) == 0 else False - round_up = if half_goes_up then remainder < -halfway else remainder <= -halfway - if round_up then result_unnudged - scale else result_unnudged + round self (decimal_places:Integer=0) (use_bankers:Boolean=False) = + ## We reject values outside the range of `long` here, so we don't also + do this check in the Java. + Rounding_Helpers.check_round_input self <| + report_unsupported cp = Error.throw (Illegal_Argument.Error cp.payload.message) + Panic.catch Unsupported_Argument_Types handler=report_unsupported + Integer.round_integer_builtin self decimal_places use_bankers + + round_integer_builtin n decimal_places use_bankers = @Builtin_Method "Integer.round" ## Compute the negation of this. diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/number/decimal/RoundNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/number/decimal/RoundNode.java new file mode 100644 index 000000000000..bdaef39105dc --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/number/decimal/RoundNode.java @@ -0,0 +1,74 @@ +package org.enso.interpreter.node.expression.builtin.number.decimal; + +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.api.profiles.CountingConditionProfile; +import com.oracle.truffle.api.profiles.PrimitiveValueProfile; +import java.math.BigDecimal; +import java.math.BigInteger; +import org.enso.interpreter.dsl.BuiltinMethod; +import org.enso.interpreter.node.expression.builtin.number.utils.BigIntegerOps; +import org.enso.interpreter.node.expression.builtin.number.utils.RoundHelpers; +import org.enso.interpreter.runtime.number.EnsoBigInteger; + +@BuiltinMethod( + type = "Decimal", + name = "round", + description = "Decimal ceiling, converting to a small or big integer depending on size.") +public class RoundNode extends Node { + private final CountingConditionProfile fitsProfile = CountingConditionProfile.create(); + + private final PrimitiveValueProfile constantPlacesDecimalPlaces = PrimitiveValueProfile.create(); + + private final PrimitiveValueProfile constantPlacesUseBankers = PrimitiveValueProfile.create(); + + private final BranchProfile decimalPlacesOutOfRangeProfile = BranchProfile.create(); + + private final BranchProfile outOfRangeProfile = BranchProfile.create(); + + Object execute(double n, long dp, boolean ub) { + long decimalPlaces = constantPlacesDecimalPlaces.profile(dp); + boolean useBankers = constantPlacesUseBankers.profile(ub); + + if (decimalPlaces < RoundHelpers.ROUND_MIN_DECIMAL_PLACES + || decimalPlaces > RoundHelpers.ROUND_MAX_DECIMAL_PLACES) { + decimalPlacesOutOfRangeProfile.enter(); + RoundHelpers.decimalPlacesOutOfRangePanic(this, decimalPlaces); + } + + boolean inRange = n >= RoundHelpers.ROUND_MIN_DOUBLE && n <= RoundHelpers.ROUND_MAX_DOUBLE; + if (!inRange) { + outOfRangeProfile.enter(); + if (Double.isNaN(n) || Double.isInfinite(n)) { + RoundHelpers.specialValuePanic(this, n); + } else { + RoundHelpers.argumentOutOfRangePanic(this, n); + } + } + + // Algorithm taken from https://stackoverflow.com/a/7211688. + double scale = Math.pow(10.0, decimalPlaces); + double scaled = n * scale; + double roundBase = Math.floor(scaled); + double roundMidpoint = (roundBase + 0.5) / scale; + boolean evenIsUp = n >= 0 ? (((long) scaled) % 2) != 0 : (((long) scaled) % 2) == 0; + boolean halfGoesUp = useBankers ? evenIsUp : n >= 0; + boolean doRoundUp = halfGoesUp ? n >= roundMidpoint : n > roundMidpoint; + double resultUncast = doRoundUp ? ((roundBase + 1.0) / scale) : (roundBase / scale); + if (decimalPlaces > 0) { + return resultUncast; + } else { + if (fitsProfile.profile(BigIntegerOps.fitsInLong(resultUncast))) { + return (long) resultUncast; + } else { + return new EnsoBigInteger(toBigInteger(resultUncast)); + } + } + } + + @TruffleBoundary + private BigInteger toBigInteger(double n) { + return BigDecimal.valueOf(n).toBigIntegerExact(); + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/number/smallInteger/RoundNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/number/smallInteger/RoundNode.java new file mode 100644 index 000000000000..83aeee35e020 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/number/smallInteger/RoundNode.java @@ -0,0 +1,54 @@ +package org.enso.interpreter.node.expression.builtin.number.smallInteger; + +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.api.profiles.CountingConditionProfile; +import com.oracle.truffle.api.profiles.PrimitiveValueProfile; +import org.enso.interpreter.dsl.BuiltinMethod; +import org.enso.interpreter.node.expression.builtin.number.utils.RoundHelpers; + +@BuiltinMethod( + type = "Integer", + name = "round", + description = "Decimal ceiling, converting to a small or big integer depending on size.") +public class RoundNode extends Node { + private final CountingConditionProfile fitsProfile = CountingConditionProfile.create(); + + private final PrimitiveValueProfile constantPlacesDecimalPlaces = PrimitiveValueProfile.create(); + + private final PrimitiveValueProfile constantPlacesUseBankers = PrimitiveValueProfile.create(); + + private final BranchProfile decimalPlacesOutOfRangeProfile = BranchProfile.create(); + + Object execute(long n, long dp, boolean ub) { + var decimalPlaces = constantPlacesDecimalPlaces.profile(dp); + + // We don't check if `n` is out of range here, since the Enso wrapper does. + if (decimalPlaces < RoundHelpers.ROUND_MIN_DECIMAL_PLACES + || decimalPlaces > RoundHelpers.ROUND_MAX_DECIMAL_PLACES) { + decimalPlacesOutOfRangeProfile.enter(); + RoundHelpers.decimalPlacesOutOfRangePanic(this, decimalPlaces); + } + + if (decimalPlaces >= 0) { + return n; + } + + var useBankers = constantPlacesUseBankers.profile(ub); + long scale = (long) Math.pow(10, -decimalPlaces); + long halfway = scale / 2; + long remainder = n % scale; + long scaledDown = n / scale; + long resultUnnudged = scaledDown * scale; + + if (n >= 0) { + boolean halfGoesUp = useBankers ? (scaledDown % 2) != 0 : true; + boolean roundUp = halfGoesUp ? remainder >= halfway : remainder > halfway; + return roundUp ? resultUnnudged + scale : resultUnnudged; + } else { + boolean halfGoesUp = useBankers ? (scaledDown % 2) == 0 : false; + boolean roundUp = halfGoesUp ? remainder < -halfway : remainder <= -halfway; + return roundUp ? resultUnnudged - scale : resultUnnudged; + } + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/number/utils/RoundHelpers.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/number/utils/RoundHelpers.java new file mode 100644 index 000000000000..a81a4f0ca592 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/number/utils/RoundHelpers.java @@ -0,0 +1,63 @@ +package org.enso.interpreter.node.expression.builtin.number.utils; + +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.nodes.Node; +import org.enso.interpreter.runtime.EnsoContext; +import org.enso.interpreter.runtime.builtin.Builtins; +import org.enso.interpreter.runtime.error.PanicException; + +public class RoundHelpers { + /** Minimum value for the `decimal_places` parameter to `roundDouble`. */ + public static final double ROUND_MIN_DECIMAL_PLACES = -15; + + /** Maximum value for the `decimal_places` parameter to `roundDouble`. */ + public static final double ROUND_MAX_DECIMAL_PLACES = 15; + + /** Minimum value for the `n` parameter to `roundDouble`. */ + public static final double ROUND_MIN_DOUBLE = -99999999999999.0; + + /** Minimum value for the `n` parameter to `roundDouble`. */ + public static final double ROUND_MAX_DOUBLE = 99999999999999.0; + + public static final double ROUND_MIN_LONG = -99999999999999L; + + /** Minimum value for the `n` parameter to `roundDouble`. */ + public static final double ROUND_MAX_LONG = 99999999999999L; + + @TruffleBoundary + public static void decimalPlacesOutOfRangePanic(Node node, long decimalPlaces) + throws PanicException { + String msg = + "round: decimalPlaces must be between " + + ROUND_MIN_DECIMAL_PLACES + + " and " + + ROUND_MAX_DECIMAL_PLACES + + " (inclusive), but was " + + decimalPlaces; + Builtins builtins = EnsoContext.get(node).getBuiltins(); + throw new PanicException( + builtins.error().makeUnsupportedArgumentsError(new Object[] {decimalPlaces}, msg), node); + } + + @TruffleBoundary + public static void argumentOutOfRangePanic(Node node, double n) throws PanicException { + String msg = + "Error: `round` can only accept values between " + + ROUND_MIN_DOUBLE + + " and " + + ROUND_MAX_DOUBLE + + " (inclusive), but was " + + n; + Builtins builtins = EnsoContext.get(node).getBuiltins(); + throw new PanicException( + builtins.error().makeUnsupportedArgumentsError(new Object[] {n}, msg), node); + } + + @TruffleBoundary + public static void specialValuePanic(Node node, double n) throws PanicException { + String msg = "Error: `round` cannot accept " + (Double.isNaN(n) ? "NaN" : "Inf") + " values "; + Builtins builtins = EnsoContext.get(node).getBuiltins(); + throw new PanicException( + builtins.error().makeUnsupportedArgumentsError(new Object[] {n}, msg), node); + } +} diff --git a/lib/scala/common-polyglot-core-utils/src/main/java/org/enso/polyglot/common_utils/Core_Math_Utils.java b/lib/scala/common-polyglot-core-utils/src/main/java/org/enso/polyglot/common_utils/Core_Math_Utils.java index a34e24eccadb..d21081795fd9 100644 --- a/lib/scala/common-polyglot-core-utils/src/main/java/org/enso/polyglot/common_utils/Core_Math_Utils.java +++ b/lib/scala/common-polyglot-core-utils/src/main/java/org/enso/polyglot/common_utils/Core_Math_Utils.java @@ -32,6 +32,7 @@ public class Core_Math_Utils { * @throws IllegalArgumentException if `decimalPlaces` is outside the allowed range. */ public static double roundDouble(double n, int decimalPlaces, boolean useBankers) { + if (decimalPlaces < ROUND_MIN_DECIMAL_PLACES || decimalPlaces > ROUND_MAX_DECIMAL_PLACES) { String msg = "round: decimalPlaces must be between " diff --git a/test/Benchmarks/src/Main.enso b/test/Benchmarks/src/Main.enso index d3cc7418f6ba..cafadf13d224 100644 --- a/test/Benchmarks/src/Main.enso +++ b/test/Benchmarks/src/Main.enso @@ -1,10 +1,12 @@ import project.Vector.Distinct import project.Vector.Operations +import project.Numeric all_benchmarks = vec_ops = Operations.all vec_distinct = Distinct.collect_benches - [vec_ops, vec_distinct] + numeric = Numeric.bench + [vec_ops, vec_distinct, numeric] main = all_benchmarks.each suite-> diff --git a/test/Benchmarks/src/Numeric.enso b/test/Benchmarks/src/Numeric.enso index ab9016a7b052..414741877a11 100644 --- a/test/Benchmarks/src/Numeric.enso +++ b/test/Benchmarks/src/Numeric.enso @@ -4,13 +4,15 @@ from Standard.Test import Bench, Faker ## Bench Utilities ============================================================ -vector_size = 1000000 +vector_size = 100000 iter_size = 100 num_iterations = 10 # The Benchmarks ============================================================== -bench = +collect_benches group_builder = + bench_measure ~act name = group_builder.specify name act + ## No specific significance to this constant, just fixed to make generated set deterministic fixed_random_seed = 1644575867 faker = Faker.new fixed_random_seed @@ -25,21 +27,26 @@ bench = name = pair.at 0 fun = pair.at 1 IO.println <| "Benchmarking decimal " + name - Bench.measure (decimals.map fun) name iter_size num_iterations + bench_measure (decimals.map fun) name funs.map pair-> name = pair.at 0 fun = pair.at 1 IO.println <| "Benchmarking integer" + name - Bench.measure (integers.map fun) name iter_size num_iterations + bench_measure (integers.map fun) name [True, False].map use_bankers-> [0, -2, 2].map decimal_places-> - name = "round decimal_places=" + decimal_places.to_text + " use_bankers=" + use_bankers.to_text + name = "round decimal_places=" + (decimal_places.to_text.replace '-' '_') + " use_bankers=" + use_bankers.to_text fun = _.round decimal_places use_bankers - IO.println <| "Benchmarking decimal " + name - Bench.measure (decimals.map fun) name iter_size num_iterations - IO.println <| "Benchmarking integer " + name - Bench.measure (integers.map fun) name iter_size num_iterations + bench_measure (decimals.map fun) "decimal "+name + bench_measure (integers.map fun) "integer "+name + +bench = + options = Bench.options . size iter_size . iter num_iterations + + Bench.build builder-> + builder.group "Numbers" options group_builder-> + collect_benches group_builder -main = bench +main = bench . run_main diff --git a/test/Tests/src/Data/Numbers_Spec.enso b/test/Tests/src/Data/Numbers_Spec.enso index 75d5d2ce12b7..7adf2f43c7dd 100644 --- a/test/Tests/src/Data/Numbers_Spec.enso +++ b/test/Tests/src/Data/Numbers_Spec.enso @@ -1,6 +1,7 @@ from Standard.Base import all import Standard.Base.Errors.Common.Arithmetic_Error import Standard.Base.Errors.Common.Incomparable_Values +import Standard.Base.Errors.Common.Type_Error import Standard.Base.Errors.Illegal_Argument.Illegal_Argument from Standard.Base.Data.Numbers import Number_Parse_Error @@ -659,9 +660,9 @@ spec = 3.1 . round -16 . should_fail_with Illegal_Argument Test.specify "NaN/Inf" <| - Number.nan . round . should_fail_with Arithmetic_Error - Number.positive_infinity . round . should_fail_with Arithmetic_Error - Number.negative_infinity . round . should_fail_with Arithmetic_Error + Number.nan . round . should_fail_with Illegal_Argument + Number.positive_infinity . round . should_fail_with Illegal_Argument + Number.negative_infinity . round . should_fail_with Illegal_Argument Test.specify "Floating point imperfect representation counter-examples" <| 1.225 . round 2 use_bankers=True . should_equal 1.22 # Actual result 1.23 @@ -762,6 +763,11 @@ spec = -12250 . round -2 use_bankers=True . should_equal -12200 -12251 . round -2 use_bankers=True . should_equal -12300 + Test.specify "Handles incorrect argument types" <| + Test.expect_panic_with (123 . round "two") Type_Error + Test.expect_panic_with (123 . round use_bankers="no") Type_Error + Test.expect_panic_with (123 . round use_bankers=0) Type_Error + Test.specify "Returns the correct type" <| 231 . round 1 . should_be_a Integer 231 . round 0 . should_be_a Integer @@ -771,14 +777,22 @@ spec = Test.specify "Input out of range" <| 100000000000000 . round -2 . should_fail_with Illegal_Argument -100000000000000 . round -2 . should_fail_with Illegal_Argument + 100000000000000 . round . should_fail_with Illegal_Argument + -100000000000000 . round . should_fail_with Illegal_Argument + 100000000000000 . round 1 . should_fail_with Illegal_Argument + -100000000000000 . round 1 . should_fail_with Illegal_Argument + 99999999999999 . round . should_equal 99999999999999 + -99999999999999 . round . should_equal -99999999999999 99999999999999 . round -2 . should_equal 100000000000000 -99999999999999 . round -2 . should_equal -100000000000000 - Test.specify "Input out of range is ignored when the implementation returns its argument immediately" <| - 100000000000000 . round . should_equal 100000000000000 - -100000000000000 . round . should_equal -100000000000000 - 100000000000000 . round 1 . should_equal 100000000000000 - -100000000000000 . round 1 . should_equal -100000000000000 + Test.specify "Reject bigints before reaching the Java" <| + 922337203685477580700000 . round . should_fail_with Illegal_Argument + -922337203685477580700000 . round . should_fail_with Illegal_Argument + + Test.specify "Can handle small numbers computed from bigints" <| + (922337203685477580712345 - 922337203685477580700000) . round . should_equal 12345 + ((99999999999998 * 1000).div 1000) . round . should_equal 99999999999998 Test.group "Decimal.truncate" diff --git a/test/Tests/src/Semantic/Meta_Spec.enso b/test/Tests/src/Semantic/Meta_Spec.enso index 10ba52e77762..1634e97df3fe 100644 --- a/test/Tests/src/Semantic/Meta_Spec.enso +++ b/test/Tests/src/Semantic/Meta_Spec.enso @@ -261,7 +261,7 @@ spec = Meta.meta Integer . methods . sort . should_equal ['bit_shift_l', 'round', 'truncate'] Test.specify "static methods of Integer" <| - Meta.meta (Meta.type_of Integer) . methods . sort . should_equal ['bit_shift_l', 'parse', 'parse_builtin', 'round', 'truncate'] + Meta.meta (Meta.type_of Integer) . methods . sort . should_equal ['bit_shift_l', 'parse', 'parse_builtin', 'round', 'round_integer_builtin', 'truncate'] Test.specify "methods of Any" <| Meta.meta Any . methods . should_contain "to_text"