diff --git a/src/libraries/System.Collections/src/System/Collections/BitArray.cs b/src/libraries/System.Collections/src/System/Collections/BitArray.cs
index 35a747764f7b16..0bc7e497c615c3 100644
--- a/src/libraries/System.Collections/src/System/Collections/BitArray.cs
+++ b/src/libraries/System.Collections/src/System/Collections/BitArray.cs
@@ -7,6 +7,7 @@
 using System.Runtime.CompilerServices;
 using System.Runtime.Intrinsics;
 using System.Runtime.Intrinsics.X86;
+using System.Runtime.Intrinsics.Arm;

 namespace System.Collections
 {
@@ -113,6 +114,14 @@ public BitArray(byte[] bytes)
             _version = 0;
         }

+        private static readonly Vector128<byte> s_bitMask128 = BitConverter.IsLittleEndian ?
+            Vector128.Create(0x80402010_08040201).AsByte() :
+            Vector128.Create(0x01020408_10204080).AsByte();
+
+        private const uint Vector128ByteCount = 16;
+        private const uint Vector128IntCount = 4;
+        private const uint Vector256ByteCount = 32;
+        private const uint Vector256IntCount = 8;
         public unsafe BitArray(bool[] values)
         {
             if (values == null)
@@ -123,7 +132,7 @@ public unsafe BitArray(bool[] values)
             m_array = new int[GetInt32ArrayLengthFromBitLength(values.Length)];
             m_length = values.Length;

-            int i = 0;
+            uint i = 0;

             if (values.Length < Vector256<byte>.Count)
             {
@@ -136,42 +145,84 @@ public unsafe BitArray(bool[] values)

             if (Avx2.IsSupported)
             {
+                // JIT does not support code hoisting for SIMD yet
+                Vector256<byte> zero = Vector256<byte>.Zero;
                 fixed (bool* ptr = values)
                 {
-                    for (; (i + Vector256<byte>.Count) <= values.Length; i += Vector256<byte>.Count)
+                    for (; (i + Vector256ByteCount) <= (uint)values.Length; i += Vector256ByteCount)
                     {
                         Vector256<byte> vector = Avx.LoadVector256((byte*)ptr + i);
-                        Vector256<byte> isFalse = Avx2.CompareEqual(vector, Vector256<byte>.Zero);
+                        Vector256<byte> isFalse = Avx2.CompareEqual(vector, zero);
                         int result = Avx2.MoveMask(isFalse);
-                        m_array[i / 32] = ~result;
+                        m_array[i / 32u] = ~result;
                     }
                 }
             }
             else if (Sse2.IsSupported)
             {
+                // JIT does not support code hoisting for SIMD yet
+                Vector128<byte> zero = Vector128<byte>.Zero;
                 fixed (bool* ptr = values)
                 {
-                    for (; (i + Vector128<byte>.Count * 2) <= values.Length; i += Vector128<byte>.Count * 2)
+                    for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u)
                     {
                         Vector128<byte> lowerVector = Sse2.LoadVector128((byte*)ptr + i);
-                        Vector128<byte> lowerIsFalse = Sse2.CompareEqual(lowerVector, Vector128<byte>.Zero);
+                        Vector128<byte> lowerIsFalse = Sse2.CompareEqual(lowerVector, zero);
                         int lowerPackedIsFalse = Sse2.MoveMask(lowerIsFalse);

                         Vector128<byte> upperVector = Sse2.LoadVector128((byte*)ptr + i + Vector128<byte>.Count);
-                        Vector128<byte> upperIsFalse = Sse2.CompareEqual(upperVector, Vector128<byte>.Zero);
+                        Vector128<byte> upperIsFalse = Sse2.CompareEqual(upperVector, zero);
                         int upperPackedIsFalse = Sse2.MoveMask(upperIsFalse);

-                        m_array[i / 32] = ~((upperPackedIsFalse << 16) | lowerPackedIsFalse);
+                        m_array[i / 32u] = ~((upperPackedIsFalse << 16) | lowerPackedIsFalse);
                     }
                 }
             }
+            else if (AdvSimd.Arm64.IsSupported)
+            {
+                // JIT does not support code hoisting for SIMD yet
+                // However comparison against zero can be replaced to cmeq against zero (vceqzq_s8)
+                // See dotnet/runtime#33972 for details
+                Vector128<byte> zero = Vector128<byte>.Zero;
+                fixed (bool* ptr = values)
+                {
+                    for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u)
+                    {
+                        // Same logic as SSE2 path, however we lack MoveMask (equivalent) instruction
+                        // As a workaround, mask out the relevant bit after comparison
+                        // and combine by ORing all of them together (In this case, adding all of them does the same thing)
+                        Vector128<byte> lowerVector = AdvSimd.LoadVector128((byte*)ptr + i);
+                        Vector128<byte> lowerIsFalse = AdvSimd.CompareEqual(lowerVector, zero);
+                        Vector128<byte> bitsExtracted1 = AdvSimd.And(lowerIsFalse, s_bitMask128);
+                        bitsExtracted1 = AdvSimd.Arm64.AddPairwise(bitsExtracted1, bitsExtracted1);
+                        bitsExtracted1 = AdvSimd.Arm64.AddPairwise(bitsExtracted1, bitsExtracted1);
+                        bitsExtracted1 = AdvSimd.Arm64.AddPairwise(bitsExtracted1, bitsExtracted1);
+                        Vector128<short> lowerPackedIsFalse = bitsExtracted1.AsInt16();
+
+                        Vector128<byte> upperVector = AdvSimd.LoadVector128((byte*)ptr + i + Vector128<byte>.Count);
+                        Vector128<byte> upperIsFalse = AdvSimd.CompareEqual(upperVector, zero);
+                        Vector128<byte> bitsExtracted2 = AdvSimd.And(upperIsFalse, s_bitMask128);
+                        bitsExtracted2 = AdvSimd.Arm64.AddPairwise(bitsExtracted2, bitsExtracted2);
+                        bitsExtracted2 = AdvSimd.Arm64.AddPairwise(bitsExtracted2, bitsExtracted2);
+                        bitsExtracted2 = AdvSimd.Arm64.AddPairwise(bitsExtracted2, bitsExtracted2);
+                        Vector128<short> upperPackedIsFalse = bitsExtracted2.AsInt16();
+
+                        int result = AdvSimd.Arm64.ZipLow(lowerPackedIsFalse, upperPackedIsFalse).AsInt32().ToScalar();
+                        if (!BitConverter.IsLittleEndian)
+                        {
+                            result = BinaryPrimitives.ReverseEndianness(result);
+                        }
+                        m_array[i / 32u] = ~result;
+                    }
+                }
+            }

 LessThan32:
-            for (; i < values.Length; i++)
+            for (; i < (uint)values.Length; i++)
             {
                 if (values[i])
                 {
-                    int elementIndex = Div32Rem(i, out int extraBits);
+                    int elementIndex = Div32Rem((int)i, out int extraBits);
                     m_array[elementIndex] |= 1 << extraBits;
                 }
             }
@@ -284,9 +335,13 @@ public void Set(int index, bool value)
         =========================================================================*/
         public void SetAll(bool value)
         {
-            int fillValue = value ? -1 : 0;
             int arrayLength = GetInt32ArrayLengthFromBitLength(Length);
-            m_array.AsSpan(0, arrayLength).Fill(fillValue);
+            Span<int> span = m_array.AsSpan(0, arrayLength);
+            if (value)
+                span.Fill(-1);
+            else
+                span.Clear();
+
             _version++;
         }

@@ -327,13 +382,13 @@ public unsafe BitArray And(BitArray value)
                 case 0:
                     goto Done;
             }

-            int i = 0;
+            uint i = 0;

             if (Avx2.IsSupported)
             {
                 fixed (int* leftPtr = thisArray)
                 fixed (int* rightPtr = valueArray)
                 {
-                    for (; i < count - (Vector256<int>.Count - 1); i += Vector256<int>.Count)
+                    for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
                     {
                         Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
                         Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
@@ -346,7 +401,7 @@ public unsafe BitArray And(BitArray value)
                 fixed (int* leftPtr = thisArray)
                 fixed (int* rightPtr = valueArray)
                 {
-                    for (; i < count - (Vector128<int>.Count - 1); i += Vector128<int>.Count)
+                    for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
                     {
                         Vector128<int> leftVec = Sse2.LoadVector128(leftPtr + i);
                         Vector128<int> rightVec = Sse2.LoadVector128(rightPtr + i);
@@ -354,8 +409,21 @@ public unsafe BitArray And(BitArray value)
                     }
                 }
             }
+            else if (AdvSimd.IsSupported)
+            {
+                fixed (int* leftPtr = thisArray)
+                fixed (int* rightPtr = valueArray)
+                {
+                    for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
+                    {
+                        Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
+                        Vector128<int> rightVec = AdvSimd.LoadVector128(rightPtr + i);
+                        AdvSimd.Store(leftPtr + i, AdvSimd.And(leftVec, rightVec));
+                    }
+                }
+            }

-            for (; i < count; i++)
+            for (; i < (uint)count; i++)
                 thisArray[i] &= valueArray[i];

         Done:
@@ -400,13 +468,13 @@ public unsafe BitArray Or(BitArray value)
                 case 0:
                     goto Done;
             }

-            int i = 0;
+            uint i = 0;

             if (Avx2.IsSupported)
             {
                 fixed (int* leftPtr = thisArray)
                 fixed (int* rightPtr = valueArray)
                 {
-                    for (; i < count - (Vector256<int>.Count - 1); i += Vector256<int>.Count)
+                    for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
                    {
                         Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
                         Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
@@ -419,7 +487,7 @@ public unsafe BitArray Or(BitArray value)
                 fixed (int* leftPtr = thisArray)
                 fixed (int* rightPtr = valueArray)
                 {
-                    for (; i < count - (Vector128<int>.Count - 1); i += Vector128<int>.Count)
+                    for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
                     {
                         Vector128<int> leftVec = Sse2.LoadVector128(leftPtr + i);
                         Vector128<int> rightVec = Sse2.LoadVector128(rightPtr + i);
@@ -427,8 +495,21 @@ public unsafe BitArray Or(BitArray value)
                     }
                 }
             }
+            else if (AdvSimd.IsSupported)
+            {
+                fixed (int* leftPtr = thisArray)
+                fixed (int* rightPtr = valueArray)
+                {
+                    for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
+                    {
+                        Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
+                        Vector128<int> rightVec = AdvSimd.LoadVector128(rightPtr + i);
+                        AdvSimd.Store(leftPtr + i, AdvSimd.Or(leftVec, rightVec));
+                    }
+                }
+            }

-            for (; i < count; i++)
+            for (; i < (uint)count; i++)
                 thisArray[i] |= valueArray[i];

         Done:
@@ -473,13 +554,13 @@ public unsafe BitArray Xor(BitArray value)
                 case 0:
                     goto Done;
             }

-            int i = 0;
+            uint i = 0;

             if (Avx2.IsSupported)
             {
                 fixed (int* leftPtr = m_array)
                 fixed (int* rightPtr = value.m_array)
                 {
-                    for (; i < count - (Vector256<int>.Count - 1); i += Vector256<int>.Count)
+                    for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
                     {
                         Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
                         Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
@@ -492,7 +573,7 @@ public unsafe BitArray Xor(BitArray value)
                 fixed (int* leftPtr = thisArray)
                 fixed (int* rightPtr = valueArray)
                 {
-                    for (; i < count - (Vector128<int>.Count - 1); i += Vector128<int>.Count)
+                    for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
                     {
                         Vector128<int> leftVec = Sse2.LoadVector128(leftPtr + i);
                         Vector128<int> rightVec = Sse2.LoadVector128(rightPtr + i);
@@ -500,8 +581,21 @@ public unsafe BitArray Xor(BitArray value)
                     }
                 }
             }
+            else if (AdvSimd.IsSupported)
+            {
+                fixed (int* leftPtr = thisArray)
+                fixed (int* rightPtr = valueArray)
+                {
+                    for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
+                    {
+                        Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
+                        Vector128<int> rightVec = AdvSimd.LoadVector128(rightPtr + i);
+                        AdvSimd.Store(leftPtr + i, AdvSimd.Xor(leftVec, rightVec));
+                    }
+                }
+            }

-            for (; i < count; i++)
+            for (; i < (uint)count; i++)
                 thisArray[i] ^= valueArray[i];

         Done:
@@ -538,13 +632,13 @@ public unsafe BitArray Not()
                 case 0:
                     goto Done;
             }

-            int i = 0;
+            uint i = 0;

             if (Avx2.IsSupported)
             {
                 Vector256<int> ones = Vector256.Create(-1);
                 fixed (int* ptr = thisArray)
                 {
-                    for (; i < count - (Vector256<int>.Count - 1); i += Vector256<int>.Count)
+                    for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
                     {
                         Vector256<int> vec = Avx.LoadVector256(ptr + i);
                         Avx.Store(ptr + i, Avx2.Xor(vec, ones));
@@ -556,15 +650,26 @@ public unsafe BitArray Not()
                 Vector128<int> ones = Vector128.Create(-1);
                 fixed (int* ptr = thisArray)
                 {
-                    for (; i < count - (Vector128<int>.Count - 1); i += Vector128<int>.Count)
+                    for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
                     {
                         Vector128<int> vec = Sse2.LoadVector128(ptr + i);
                         Sse2.Store(ptr + i, Sse2.Xor(vec, ones));
                     }
                 }
             }
+            else if (AdvSimd.IsSupported)
+            {
+                fixed (int* leftPtr = thisArray)
+                {
+                    for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
+                    {
+                        Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
+                        AdvSimd.Store(leftPtr + i, AdvSimd.Not(leftVec));
+                    }
+                }
+            }

-            for (; i < count; i++)
+            for (; i < (uint)count; i++)
                 thisArray[i] = ~thisArray[i];

         Done:
@@ -739,7 +844,7 @@ public int Length
         // On little endian machines, the lower 8 bits of int belong in the first byte, next lower 8 in the second and so on.
         // We place the bytes that contain the bits to its respective byte so that we can mask out only the relevant bits later.
         private static readonly Vector128<byte> s_lowerShuffleMask_CopyToBoolArray = Vector128.Create(0, 0x01010101_01010101).AsByte();
-        private static readonly Vector128<byte> s_upperShuffleMask_CopyToBoolArray = Vector128.Create(0x_02020202_02020202, 0x03030303_03030303).AsByte();
+        private static readonly Vector128<byte> s_upperShuffleMask_CopyToBoolArray = Vector128.Create(0x02020202_02020202, 0x03030303_03030303).AsByte();

         public unsafe void CopyTo(Array array, int index)
         {
@@ -826,7 +931,7 @@ public unsafe void CopyTo(Array array, int index)
                     throw new ArgumentException(SR.Argument_InvalidOffLen);
                 }

-                int i = 0;
+                uint i = 0;

                 if (m_length < BitsPerInt32)
                     goto LessThan32;
@@ -839,9 +944,9 @@ public unsafe void CopyTo(Array array, int index)

                     fixed (bool* destination = &boolArray[index])
                     {
-                        for (; (i + Vector256<byte>.Count) <= m_length; i += Vector256<byte>.Count)
+                        for (; (i + Vector256ByteCount) <= (uint)m_length; i += Vector256ByteCount)
                         {
-                            int bits = m_array[i / BitsPerInt32];
+                            int bits = m_array[i / (uint)BitsPerInt32];
                             Vector256<int> scalar = Vector256.Create(bits);
                             Vector256<byte> shuffled = Avx2.Shuffle(scalar.AsByte(), shuffleMask);
                             Vector256<byte> extracted = Avx2.And(shuffled, bitMask);
@@ -857,34 +962,72 @@ public unsafe void CopyTo(Array array, int index)
                 {
                     Vector128<byte> lowerShuffleMask = s_lowerShuffleMask_CopyToBoolArray;
                     Vector128<byte> upperShuffleMask = s_upperShuffleMask_CopyToBoolArray;
-                    Vector128<byte> bitMask = Vector128.Create(0x80402010_08040201).AsByte(); ;
                     Vector128<byte> ones = Vector128.Create((byte)1);

                     fixed (bool* destination = &boolArray[index])
                     {
-                        for (; (i + Vector128<byte>.Count * 2) <= m_length; i += Vector128<byte>.Count * 2)
+                        for (; (i + Vector128ByteCount * 2u) <= (uint)m_length; i += Vector128ByteCount * 2u)
                         {
-                            int bits = m_array[i / BitsPerInt32];
+                            int bits = m_array[i / (uint)BitsPerInt32];
                             Vector128<int> scalar = Vector128.CreateScalarUnsafe(bits);

                             Vector128<byte> shuffledLower = Ssse3.Shuffle(scalar.AsByte(), lowerShuffleMask);
-                            Vector128<byte> extractedLower = Sse2.And(shuffledLower, bitMask);
+                            Vector128<byte> extractedLower = Sse2.And(shuffledLower, s_bitMask128);
                             Vector128<byte> normalizedLower = Sse2.Min(extractedLower, ones);
                             Sse2.Store((byte*)destination + i, normalizedLower);

                             Vector128<byte> shuffledHigher = Ssse3.Shuffle(scalar.AsByte(), upperShuffleMask);
-                            Vector128<byte> extractedHigher = Sse2.And(shuffledHigher, bitMask);
+                            Vector128<byte> extractedHigher = Sse2.And(shuffledHigher, s_bitMask128);
                             Vector128<byte> normalizedHigher = Sse2.Min(extractedHigher, ones);
                             Sse2.Store((byte*)destination + i + Vector128<byte>.Count, normalizedHigher);
                         }
                     }
                 }
+                else if (AdvSimd.IsSupported)
+                {
+                    Vector128<byte> ones = Vector128.Create((byte)1);
+                    fixed (bool* destination = &boolArray[index])
+                    {
+                        for (; (i + Vector128ByteCount * 2u) <= (uint)m_length; i += Vector128ByteCount * 2u)
+                        {
+                            int bits = m_array[i / (uint)BitsPerInt32];
+                            // Same logic as SSSE3 path, except we do not have Shuffle instruction.
+                            // (TableVectorLookup could be an alternative - dotnet/runtime#1277)
+                            // Instead we use chained ZIP1/2 instructions:
+                            // (A0 is the byte containing LSB, A3 is the byte containing MSB)
+                            // bits (on Big endian) - A3 A2 A1 A0
+                            // bits (Little endian) / Byte reversal - A0 A1 A2 A3
+                            // v1 = Vector128.Create - A0 A1 A2 A3 A0 A1 A2 A3 A0 A1 A2 A3 A0 A1 A2 A3
+                            // v2 = ZipLow(v1, v1) - A0 A0 A1 A1 A2 A2 A3 A3 A0 A0 A1 A1 A2 A2 A3 A3
+                            // v3 = ZipLow(v2, v2) - A0 A0 A0 A0 A1 A1 A1 A1 A2 A2 A2 A2 A3 A3 A3 A3
+                            // shuffledLower = ZipLow(v3, v3) - A0 A0 A0 A0 A0 A0 A0 A0 A1 A1 A1 A1 A1 A1 A1 A1
+                            // shuffledHigher = ZipHigh(v3, v3) - A2 A2 A2 A2 A2 A2 A2 A2 A3 A3 A3 A3 A3 A3 A3 A3
+                            if (!BitConverter.IsLittleEndian)
+                            {
+                                bits = BinaryPrimitives.ReverseEndianness(bits);
+                            }
+                            Vector128<byte> vector = Vector128.Create(bits).AsByte();
+                            vector = AdvSimd.Arm64.ZipLow(vector, vector);
+                            vector = AdvSimd.Arm64.ZipLow(vector, vector);
+
+                            Vector128<byte> shuffledLower = AdvSimd.Arm64.ZipLow(vector, vector);
+                            Vector128<byte> extractedLower = AdvSimd.And(shuffledLower, s_bitMask128);
+                            Vector128<byte> normalizedLower = AdvSimd.Min(extractedLower, ones);
+                            AdvSimd.Store((byte*)destination + i, normalizedLower);
+
+                            Vector128<byte> shuffledHigher = AdvSimd.Arm64.ZipHigh(vector, vector);
+                            Vector128<byte> extractedHigher = AdvSimd.And(shuffledHigher, s_bitMask128);
+                            Vector128<byte> normalizedHigher = AdvSimd.Min(extractedHigher, ones);
+                            AdvSimd.Store((byte*)destination + i + Vector128<byte>.Count, normalizedHigher);
+                        }
+                    }
+                }

 LessThan32:
-                for (; i < m_length; i++)
+                for (; i < (uint)m_length; i++)
                 {
-                    int elementIndex = Div32Rem(i, out int extraBits);
-                    boolArray[index + i] = ((m_array[elementIndex] >> extraBits) & 0x00000001) != 0;
+                    int elementIndex = Div32Rem((int)i, out int extraBits);
+                    boolArray[(uint)index + i] = ((m_array[elementIndex] >> extraBits) & 0x00000001) != 0;
                 }
             }
             else
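Note on the MoveMask workaround used in the BitArray(bool[]) constructor above: the following standalone sketch shows the same mask-and-AddPairwise idea in isolation. The helper name, its ushort return type, and the scalar fallback are illustrative only and are not part of this patch; it assumes each input lane is either 0x00 or 0xFF (the shape CompareEqual produces) and a little-endian target, which is what .NET uses on ARM64.

    using System;
    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.Arm;

    static class MoveMaskSketch
    {
        // bit 0 for lane 0, bit 1 for lane 1, ... bit 7 for lane 7, repeated for lanes 8..15
        private static readonly Vector128<byte> s_bitMask =
            Vector128.Create(0x80402010_08040201).AsByte();

        // Packs the most significant bit of each of the 16 byte lanes into a 16-bit value,
        // mirroring what Sse2.MoveMask does on x86.
        public static ushort MoveMask(Vector128<byte> mask)
        {
            if (AdvSimd.Arm64.IsSupported)
            {
                // Keep one distinct bit per lane, then fold neighbouring lanes together.
                // Three pairwise additions reduce the 16 bytes to 2 bytes of packed bits.
                Vector128<byte> bits = AdvSimd.And(mask, s_bitMask);
                bits = AdvSimd.Arm64.AddPairwise(bits, bits);
                bits = AdvSimd.Arm64.AddPairwise(bits, bits);
                bits = AdvSimd.Arm64.AddPairwise(bits, bits);
                return bits.AsUInt16().ToScalar();
            }

            // Scalar fallback with the same semantics, for platforms without AdvSimd.
            ushort result = 0;
            for (int i = 0; i < 16; i++)
            {
                if ((mask.GetElement(i) & 0x80) != 0)
                    result |= (ushort)(1 << i);
            }
            return result;
        }
    }

Since the distinct bits never collide, the pairwise additions behave like ORs, which is why the constructor can combine lanes by adding; it then negates the packed value (~result) because the comparison marks the false elements.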
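Likewise, the ZIP1/ZIP2 chain in the AdvSimd branch of CopyTo above stands in for Ssse3.Shuffle. A minimal sketch of just that byte broadcast, using a hypothetical helper name that is not part of the patch:

    using System;
    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.Arm;

    static class ZipBroadcastSketch
    {
        // Spreads each of the four bytes of `bits` (A0 = least significant byte on little
        // endian) across 8 consecutive lanes, so a per-byte bit mask can be applied next.
        public static (Vector128<byte> Lower, Vector128<byte> Upper) BroadcastBytes(int bits)
        {
            if (!AdvSimd.Arm64.IsSupported)
                throw new PlatformNotSupportedException();

            Vector128<byte> v = Vector128.Create(bits).AsByte(); // A0 A1 A2 A3, repeated 4 times
            v = AdvSimd.Arm64.ZipLow(v, v);                      // A0 A0 A1 A1 A2 A2 A3 A3 ...
            v = AdvSimd.Arm64.ZipLow(v, v);                      // A0 A0 A0 A0 A1 A1 A1 A1 ...

            Vector128<byte> lower = AdvSimd.Arm64.ZipLow(v, v);  // A0 x8, A1 x8
            Vector128<byte> upper = AdvSimd.Arm64.ZipHigh(v, v); // A2 x8, A3 x8
            return (lower, upper);
        }
    }

ANDing the two halves with s_bitMask128 and clamping with Min then yields one bool byte per bit, exactly as the SSSE3 path does with its precomputed shuffle masks.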