From 7cfc5626ee1ca91b94c37998bed77c8501098d9e Mon Sep 17 00:00:00 2001 From: kzrnm Date: Sat, 27 Jan 2024 19:10:49 +0900 Subject: [PATCH] Optimize NumberToBigInteger --- .../src/System/Number.BigInteger.cs | 360 ++++++++++++------ .../Numerics/BigIntegerCalculator.SquMul.cs | 6 +- 2 files changed, 243 insertions(+), 123 deletions(-) diff --git a/src/libraries/System.Runtime.Numerics/src/System/Number.BigInteger.cs b/src/libraries/System.Runtime.Numerics/src/System/Number.BigInteger.cs index 169188a22133e2..65ab3ef40bf808 100644 --- a/src/libraries/System.Runtime.Numerics/src/System/Number.BigInteger.cs +++ b/src/libraries/System.Runtime.Numerics/src/System/Number.BigInteger.cs @@ -671,49 +671,33 @@ internal const int s_naiveThreshold = 20000; private static ParsingStatus NumberToBigInteger(ref NumberBuffer number, out BigInteger result) { - int currentBufferSize = 0; - - int totalDigitCount = 0; - int numberScale = number.Scale; - const int MaxPartialDigits = 9; const uint TenPowMaxPartial = 1000000000; - int[]? arrayFromPoolForResultBuffer = null; - - if (numberScale == int.MaxValue) + if ((uint)number.Scale >= int.MaxValue) { result = default; - return ParsingStatus.Overflow; + return number.Scale == int.MaxValue + ? ParsingStatus.Overflow + : ParsingStatus.Failed; } - if (numberScale < 0) + if (number.Scale <= s_naiveThreshold) { - result = default; - return ParsingStatus.Failed; + return Naive(ref number, out result); } - - try + else { - if (number.DigitsCount <= s_naiveThreshold) - { - return Naive(ref number, out result); - } - else - { - return DivideAndConquer(ref number, out result); - } - } - finally - { - if (arrayFromPoolForResultBuffer != null) - { - ArrayPool.Shared.Return(arrayFromPoolForResultBuffer); - } + return DivideAndConquer(ref number, out result); } - ParsingStatus Naive(ref NumberBuffer number, out BigInteger result) + static ParsingStatus Naive(ref NumberBuffer number, out BigInteger result) { + int numberScale = number.Scale; + + uint[]? arrayFromPoolForResultBuffer = null; + int currentBufferSize = 0; + int totalDigitCount = 0; Span stackBuffer = stackalloc uint[BigIntegerCalculator.StackAllocThreshold]; Span currentBuffer = stackBuffer; uint partialValue = 0; @@ -730,7 +714,18 @@ ParsingStatus Naive(ref NumberBuffer number, out BigInteger result) MultiplyAdd(ref currentBuffer, UInt32PowersOfTen[partialDigitCount], partialValue); } - result = NumberBufferToBigInteger(currentBuffer, number.IsNegative); + int trailingZeroCount = numberScale - totalDigitCount; + while (trailingZeroCount >= MaxPartialDigits) + { + MultiplyAdd(ref currentBuffer, TenPowMaxPartial, 0); + trailingZeroCount -= MaxPartialDigits; + } + if (trailingZeroCount > 0) + { + MultiplyAdd(ref currentBuffer, UInt32PowersOfTen[trailingZeroCount], 0); + } + + result = NumberBufferToBigInteger(currentBuffer.Slice(0, currentBufferSize), number.IsNegative); return ParsingStatus.OK; bool ProcessChunk(ReadOnlySpan chunkDigits, ref Span currentBuffer) @@ -791,22 +786,75 @@ bool ProcessChunk(ReadOnlySpan chunkDigits, ref Span currentBuffer) return true; } + + // This function should only be used for result buffer. + void MultiplyAdd(ref Span currentBuffer, uint multiplier, uint addValue) + { + Span curBits = currentBuffer.Slice(0, currentBufferSize); + uint carry = addValue; + + for (int i = 0; i < curBits.Length; i++) + { + ulong p = (ulong)multiplier * curBits[i] + carry; + curBits[i] = (uint)p; + carry = (uint)(p >> 32); + } + + if (carry == 0) + { + return; + } + + if (currentBufferSize == currentBuffer.Length) + { + uint[]? arrayToReturn = arrayFromPoolForResultBuffer; + + arrayFromPoolForResultBuffer = ArrayPool.Shared.Rent(checked(currentBufferSize * 2)); + Span newBuffer = new Span(arrayFromPoolForResultBuffer); + currentBuffer.CopyTo(newBuffer); + currentBuffer = newBuffer; + + if (arrayToReturn != null) + { + ArrayPool.Shared.Return(arrayToReturn); + } + } + + currentBuffer[currentBufferSize] = carry; + currentBufferSize++; + } } - ParsingStatus DivideAndConquer(ref NumberBuffer number, out BigInteger result) + static ParsingStatus DivideAndConquer(ref NumberBuffer number, out BigInteger result) { + // log_{2^32}(10^9) + const double digitRatio = 0.934292276687070661; + Span currentBuffer; - int[]? arrayFromPoolForMultiplier = null; + uint[]? arrayFromPoolForMultiplier = null; + uint[]? arrayFromPoolForResultBuffer = null; + uint[]? arrayFromPoolForResultBuffer2 = null; + uint[]? arrayFromPoolForTrailingZero = null; try { - totalDigitCount = Math.Min(number.DigitsCount, numberScale); + int totalDigitCount = Math.Min(number.DigitsCount, number.Scale); + int trailingZeroCount = number.Scale - totalDigitCount; int bufferSize = (totalDigitCount + MaxPartialDigits - 1) / MaxPartialDigits; - Span buffer = new uint[bufferSize]; - arrayFromPoolForResultBuffer = ArrayPool.Shared.Rent(bufferSize); - Span newBuffer = MemoryMarshal.Cast(arrayFromPoolForResultBuffer).Slice(0, bufferSize); + Span buffer = new Span(arrayFromPoolForResultBuffer = ArrayPool.Shared.Rent(bufferSize), 0, bufferSize); + Span newBuffer = new Span(arrayFromPoolForResultBuffer2 = ArrayPool.Shared.Rent(bufferSize), 0, bufferSize); newBuffer.Clear(); + int trailingZeroE9 = Math.DivRem(trailingZeroCount, MaxPartialDigits, out int trailingZeroRemainder); + int trailingZeroBufferLength = checked((int)(digitRatio * (trailingZeroE9 + Math.Max(trailingZeroRemainder, 1))) + 1); + Span trailingZeroBuffer = (trailingZeroBufferLength <= BigIntegerCalculator.StackAllocThreshold + ? stackalloc uint[BigIntegerCalculator.StackAllocThreshold] + : arrayFromPoolForTrailingZero = ArrayPool.Shared.Rent(trailingZeroBufferLength)).Slice(0, trailingZeroBufferLength); + + int currentTrailingZeroBufferLength = 1; + trailingZeroBuffer.Slice(1).Clear(); + trailingZeroBuffer[0] = UInt32PowersOfTen[trailingZeroRemainder]; + // Separate every MaxPartialDigits digits and store them in the buffer. // Buffers are treated as little-endian. That means, the array { 234567890, 1 } // represents the number 1234567890. @@ -855,13 +903,37 @@ ParsingStatus DivideAndConquer(ref NumberBuffer number, out BigInteger result) Debug.Assert(bufferIndex == -1); int blockSize = 1; - arrayFromPoolForMultiplier = ArrayPool.Shared.Rent(blockSize); - Span multiplier = MemoryMarshal.Cast(arrayFromPoolForMultiplier).Slice(0, blockSize); - multiplier[0] = TenPowMaxPartial; + int multiplierSize = 1; + Span multiplier = stackalloc uint[1] { TenPowMaxPartial }; // This loop is executed ceil(log_2(bufferSize)) times. while (true) { + if ((trailingZeroE9 & 1) != 0) + { + uint[]? previousTrailingZeroBufferFromPool = null; + Span previousTrailingZeroBuffer = new Span( + previousTrailingZeroBufferFromPool = ArrayPool.Shared.Rent(currentTrailingZeroBufferLength), + 0, currentTrailingZeroBufferLength); + + trailingZeroBuffer.Slice(0, currentTrailingZeroBufferLength).CopyTo(previousTrailingZeroBuffer); + trailingZeroBuffer.Slice(0, currentTrailingZeroBufferLength).Clear(); + if (multiplier.Length < previousTrailingZeroBuffer.Length) + BigIntegerCalculator.Multiply(previousTrailingZeroBuffer, multiplier, trailingZeroBuffer); + else + BigIntegerCalculator.Multiply(multiplier, previousTrailingZeroBuffer, trailingZeroBuffer); + + currentTrailingZeroBufferLength += multiplier.Length; + while (--currentTrailingZeroBufferLength >= 0 && trailingZeroBuffer[currentTrailingZeroBufferLength] == 0) ; + ++currentTrailingZeroBufferLength; + + if (previousTrailingZeroBufferFromPool != null) + ArrayPool.Shared.Return(previousTrailingZeroBufferFromPool); + + Debug.Assert(currentTrailingZeroBufferLength >= 1); + } + trailingZeroE9 >>= 1; + // merge each block pairs. // When buffer represents: // | A | B | C | D | @@ -878,9 +950,10 @@ ParsingStatus DivideAndConquer(ref NumberBuffer number, out BigInteger result) if (upperLen != 0) { Debug.Assert(blockSize == lowerLen); - Debug.Assert(blockSize == multiplier.Length); - Debug.Assert(multiplier.Length == lowerLen); - BigIntegerCalculator.Multiply(multiplier, curBuffer.Slice(blockSize, upperLen), curNewBuffer.Slice(0, len)); + Debug.Assert(blockSize >= multiplier.Length); + Debug.Assert(multiplier.Length >= curBuffer.Slice(blockSize, upperLen).TrimEnd(0u).Length); + + BigIntegerCalculator.Multiply(multiplier, curBuffer.Slice(blockSize, upperLen).TrimEnd(0u), curNewBuffer.Slice(0, len)); } long carry = 0; @@ -908,122 +981,167 @@ ParsingStatus DivideAndConquer(ref NumberBuffer number, out BigInteger result) Span tmp = buffer; buffer = newBuffer; newBuffer = tmp; - blockSize *= 2; + blockSize <<= 1; if (bufferSize <= blockSize) { break; } + multiplierSize <<= 1; newBuffer.Clear(); - int[]? arrayToReturn = arrayFromPoolForMultiplier; + uint[]? arrayToReturn = arrayFromPoolForMultiplier; + + Span newMultiplier = new Span( + arrayFromPoolForMultiplier = ArrayPool.Shared.Rent(multiplierSize), + 0, multiplierSize); + newMultiplier.Clear(); + BigIntegerCalculator.Square(multiplier, newMultiplier); + multiplier = newMultiplier; + + while (--multiplierSize >= 0 && multiplier[multiplierSize] == 0) ; + multiplier = multiplier.Slice(0, ++multiplierSize); + + if (arrayToReturn is not null) + { + ArrayPool.Shared.Return(arrayToReturn); + } + } + + while (trailingZeroE9 != 0) + { + multiplierSize <<= 1; + uint[]? arrayToReturn = arrayFromPoolForMultiplier; - arrayFromPoolForMultiplier = ArrayPool.Shared.Rent(blockSize); - Span newMultiplier = MemoryMarshal.Cast(arrayFromPoolForMultiplier).Slice(0, blockSize); + Span newMultiplier = new Span( + arrayFromPoolForMultiplier = ArrayPool.Shared.Rent(multiplierSize), + 0, multiplierSize); newMultiplier.Clear(); BigIntegerCalculator.Square(multiplier, newMultiplier); multiplier = newMultiplier; + + while (--multiplierSize >= 0 && multiplier[multiplierSize] == 0) ; + multiplier = multiplier.Slice(0, ++multiplierSize); + if (arrayToReturn is not null) { - ArrayPool.Shared.Return(arrayToReturn); + ArrayPool.Shared.Return(arrayToReturn); } + + if ((trailingZeroE9 & 1) != 0) + { + uint[]? previousTrailingZeroBufferFromPool = null; + Span previousTrailingZeroBuffer = new Span( + previousTrailingZeroBufferFromPool = ArrayPool.Shared.Rent(currentTrailingZeroBufferLength), + 0, currentTrailingZeroBufferLength); + + trailingZeroBuffer.Slice(0, currentTrailingZeroBufferLength).CopyTo(previousTrailingZeroBuffer); + trailingZeroBuffer.Slice(0, currentTrailingZeroBufferLength).Clear(); + if (multiplier.Length < previousTrailingZeroBuffer.Length) + BigIntegerCalculator.Multiply(previousTrailingZeroBuffer, multiplier, trailingZeroBuffer); + else + BigIntegerCalculator.Multiply(multiplier, previousTrailingZeroBuffer, trailingZeroBuffer); + + currentTrailingZeroBufferLength += multiplier.Length; + while (--currentTrailingZeroBufferLength >= 0 && trailingZeroBuffer[currentTrailingZeroBufferLength] == 0) ; + ++currentTrailingZeroBufferLength; + + if (previousTrailingZeroBufferFromPool != null) + ArrayPool.Shared.Return(previousTrailingZeroBufferFromPool); + + Debug.Assert(currentTrailingZeroBufferLength >= 1); + } + trailingZeroE9 >>= 1; } + // shrink buffer to the currently used portion. // First, calculate the rough size of the buffer from the ratio that the number // of digits follows. Then, shrink the size until there is no more space left. // The Ratio is calculated as: log_{2^32}(10^9) - const double digitRatio = 0.934292276687070661; - currentBufferSize = Math.Min((int)(bufferSize * digitRatio) + 1, bufferSize); - Debug.Assert(buffer.Length == currentBufferSize || buffer[currentBufferSize] == 0); - while (0 < currentBufferSize && buffer[currentBufferSize - 1] == 0) - { - currentBufferSize--; - } + int currentBufferSize = Math.Min((int)(bufferSize * digitRatio) + 1, bufferSize); + Debug.Assert(buffer.Length == currentBufferSize || buffer.Slice(currentBufferSize).Trim(0u).Length == 0); + while (--currentBufferSize >= 0 && buffer[currentBufferSize] == 0) ; + ++currentBufferSize; currentBuffer = buffer.Slice(0, currentBufferSize); - result = NumberBufferToBigInteger(currentBuffer, number.IsNegative); - } - finally - { - if (arrayFromPoolForMultiplier != null) + + trailingZeroBuffer = trailingZeroBuffer.Slice(0, currentTrailingZeroBufferLength); + if (trailingZeroBuffer.Length <= 1) { - ArrayPool.Shared.Return(arrayFromPoolForMultiplier); - } - } - return ParsingStatus.OK; - } + Debug.Assert(trailingZeroBuffer.Length == 1); + uint trailingZero = trailingZeroBuffer[0]; + if (trailingZero != 1) + { + int i = 0; + ulong carry = 0UL; - BigInteger NumberBufferToBigInteger(Span currentBuffer, bool signa) - { - int trailingZeroCount = numberScale - totalDigitCount; + for (; i < currentBuffer.Length; i++) + { + ulong digits = (ulong)currentBuffer[i] * trailingZero + carry; + currentBuffer[i] = unchecked((uint)digits); + carry = digits >> 32; + } + if (carry != 0) + { + currentBuffer = buffer.Slice(0, ++currentBufferSize); + currentBuffer[i] = (uint)carry; + } + } - while (trailingZeroCount >= MaxPartialDigits) - { - MultiplyAdd(ref currentBuffer, TenPowMaxPartial, 0); - trailingZeroCount -= MaxPartialDigits; - } + result = NumberBufferToBigInteger(currentBuffer, number.IsNegative); + } + else + { + int resultBufferLength = checked(currentBufferSize + trailingZeroBuffer.Length); + Span resultBuffer = (resultBufferLength <= BigIntegerCalculator.StackAllocThreshold + ? stackalloc uint[BigIntegerCalculator.StackAllocThreshold] + : arrayFromPoolForTrailingZero = ArrayPool.Shared.Rent(resultBufferLength)).Slice(0, resultBufferLength); + resultBuffer.Clear(); - if (trailingZeroCount > 0) - { - MultiplyAdd(ref currentBuffer, UInt32PowersOfTen[trailingZeroCount], 0); - } + if (trailingZeroBuffer.Length < currentBuffer.Length) + BigIntegerCalculator.Multiply(currentBuffer, trailingZeroBuffer, resultBuffer); + else + BigIntegerCalculator.Multiply(trailingZeroBuffer, currentBuffer, resultBuffer); - int sign; - uint[]? bits; + while (--resultBufferLength >= 0 && resultBuffer[resultBufferLength] == 0) ; + ++resultBufferLength; - if (currentBufferSize == 0) - { - sign = 0; - bits = null; - } - else if (currentBufferSize == 1 && currentBuffer[0] <= int.MaxValue) - { - sign = (int)(signa ? -currentBuffer[0] : currentBuffer[0]); - bits = null; + result = NumberBufferToBigInteger(resultBuffer.Slice(0, resultBufferLength), number.IsNegative); + } } - else + finally { - sign = signa ? -1 : 1; - bits = currentBuffer.Slice(0, currentBufferSize).ToArray(); - } + if (arrayFromPoolForMultiplier != null) + ArrayPool.Shared.Return(arrayFromPoolForMultiplier); + + if (arrayFromPoolForResultBuffer != null) + ArrayPool.Shared.Return(arrayFromPoolForResultBuffer); + + if (arrayFromPoolForResultBuffer2 != null) + ArrayPool.Shared.Return(arrayFromPoolForResultBuffer2); - return new BigInteger(sign, bits); + if (arrayFromPoolForTrailingZero != null) + ArrayPool.Shared.Return(arrayFromPoolForTrailingZero); + } + return ParsingStatus.OK; } - // This function should only be used for result buffer. - void MultiplyAdd(ref Span currentBuffer, uint multiplier, uint addValue) + static BigInteger NumberBufferToBigInteger(ReadOnlySpan currentBuffer, bool isNegative) { - Span curBits = currentBuffer.Slice(0, currentBufferSize); - uint carry = addValue; - - for (int i = 0; i < curBits.Length; i++) + if (currentBuffer.Length == 0) { - ulong p = (ulong)multiplier * curBits[i] + carry; - curBits[i] = (uint)p; - carry = (uint)(p >> 32); + return new BigInteger(0); } - - if (carry == 0) + else if (currentBuffer.Length == 1 && currentBuffer[0] <= int.MaxValue) { - return; + int v = (int)currentBuffer[0]; + if (isNegative) + v = -v; + return new BigInteger(v, null); } - - if (currentBufferSize == currentBuffer.Length) + else { - int[]? arrayToReturn = arrayFromPoolForResultBuffer; - - arrayFromPoolForResultBuffer = ArrayPool.Shared.Rent(checked(currentBufferSize * 2)); - Span newBuffer = MemoryMarshal.Cast(arrayFromPoolForResultBuffer); - currentBuffer.CopyTo(newBuffer); - currentBuffer = newBuffer; - - if (arrayToReturn != null) - { - ArrayPool.Shared.Return(arrayToReturn); - } + return new BigInteger(isNegative ? -1 : 1, currentBuffer.ToArray()); } - - currentBuffer[currentBufferSize] = carry; - currentBufferSize++; } } diff --git a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.SquMul.cs b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.SquMul.cs index a31153ceec8348..0530c84124dec6 100644 --- a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.SquMul.cs +++ b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.SquMul.cs @@ -163,6 +163,8 @@ public static void Multiply(ReadOnlySpan left, ReadOnlySpan right, S Debug.Assert(bits.Length >= left.Length + right.Length); Debug.Assert(!bits.ContainsAnyExcept(0u)); + bits = bits.Slice(0, left.Length + right.Length); + if (left.Length - right.Length < 3) { MultiplyNearLength(left, right, bits); @@ -176,7 +178,7 @@ public static void Multiply(ReadOnlySpan left, ReadOnlySpan right, S private static void MultiplyFarLength(ReadOnlySpan left, ReadOnlySpan right, Span bits) { Debug.Assert(left.Length - right.Length >= 3); - Debug.Assert(bits.Length >= left.Length + right.Length); + Debug.Assert(bits.Length == left.Length + right.Length); Debug.Assert(!bits.ContainsAnyExcept(0u)); // Executes different algorithms for computing z = a * b @@ -372,7 +374,7 @@ stackalloc uint[StackAllocThreshold] private static void MultiplyNearLength(ReadOnlySpan left, ReadOnlySpan right, Span bits) { Debug.Assert(left.Length - right.Length < 3); - Debug.Assert(bits.Length >= left.Length + right.Length); + Debug.Assert(bits.Length == left.Length + right.Length); Debug.Assert(!bits.ContainsAnyExcept(0u)); // Executes different algorithms for computing z = a * b