Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve vectorization of String.Split #64899

Merged
merged 9 commits into from
Mar 24, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,26 @@ public ref T this[int index]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Append(T item)
{
Span<T> span = _span;
int pos = _pos;
if (pos >= _span.Length)
Grow();
if ((uint)pos < (uint)span.Length)
{
span[pos] = item;
_pos = pos + 1;
}
else
{
AddWithResize(item);
}
}

// Slow path for Append: stores the item after growing the buffer.
// Marked NoInlining so this uncommon resize path stays out of line and is
// not pulled into the AggressiveInlining fast path of Append.
[MethodImpl(MethodImplOptions.NoInlining)]
private void AddWithResize(T item)
{
// This path is only expected to run when the buffer is exactly full.
Debug.Assert(_pos == _span.Length);
// Capture the write position before growing.
int pos = _pos;
// Grow() is defined elsewhere; presumably it replaces _span with a larger buffer (confirm).
Grow();
// Index through the _span field, which is read again after Grow(), rather than a cached local.
_span[pos] = item;
_pos = pos + 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1637,27 +1637,13 @@ private void MakeSeparatorList(ReadOnlySpan<char> separators, ref ValueListBuild
}

// Special-case the common cases of 1, 2, and 3 separators, with manual comparisons against each separator.
else if (separators.Length <= 3)
else if (separators.Length <= 3u)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it affect codegen?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it got rid of redundant range checks for separators. Doing (uint)separators.Length <= (uint)3 is one movsxd less, but I personally thought this was cleaner. However, I can see it being too obscure in its intent.

{
char sep0, sep1, sep2;
sep0 = separators[0];
sep1 = separators.Length > 1 ? separators[1] : sep0;
sep2 = separators.Length > 2 ? separators[2] : sep1;

if (Length >= 16 && Sse41.IsSupported)
{
MakeSeparatorListVectorized(ref sepListBuilder, sep0, sep1, sep2);
return;
}

for (int i = 0; i < Length; i++)
{
char c = this[i];
if (c == sep0 || c == sep1 || c == sep2)
{
sepListBuilder.Append(i);
}
}
MakeSeparatorListVectorized(ref sepListBuilder, sep0, sep1, sep2);
}

// Handle > 3 separators with a probabilistic map, ala IndexOfAny.
Expand Down Expand Up @@ -1686,77 +1672,60 @@ private void MakeSeparatorList(ReadOnlySpan<char> separators, ref ValueListBuild

private void MakeSeparatorListVectorized(ref ValueListBuilder<int> sepListBuilder, char c, char c2, char c3)
{
// Redundant test so we won't prejit remainder of this method
// on platforms without SSE.
if (!Sse41.IsSupported)
{
throw new PlatformNotSupportedException();
}

// Constant that allows for the truncation of 16-bit (FFFF/0000) values within a register to 4-bit (F/0)
Vector128<byte> shuffleConstant = Vector128.Create(0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF);

Vector128<ushort> v1 = Vector128.Create((ushort)c);
Vector128<ushort> v2 = Vector128.Create((ushort)c2);
Vector128<ushort> v3 = Vector128.Create((ushort)c3);

ref char c0 = ref MemoryMarshal.GetReference(this.AsSpan());
int cond = Length & -Vector128<ushort>.Count;
int i = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int -> nint, it will help to avoid redundant sign extensions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same variable is used as the index in the scalar (non-vectorized) version at the bottom. I'll see if I can find a middle way.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can always cast it to signed just once before the scalar version


for (; i < cond; i += Vector128<ushort>.Count)
if (Vector256.IsHardwareAccelerated && Length >= Vector256<ushort>.Count * 2)
{
Vector128<ushort> charVector = ReadVector(ref c0, i);
Vector128<ushort> cmp = Sse2.CompareEqual(charVector, v1);

cmp = Sse2.Or(Sse2.CompareEqual(charVector, v2), cmp);
cmp = Sse2.Or(Sse2.CompareEqual(charVector, v3), cmp);

if (Sse41.TestZ(cmp, cmp)) { continue; }

Vector128<byte> mask = Sse2.ShiftRightLogical(cmp.AsUInt64(), 4).AsByte();
mask = Ssse3.Shuffle(mask, shuffleConstant);
ref ushort source = ref Unsafe.As<char, ushort>(ref _firstChar);

uint lowBits = Sse2.ConvertToUInt32(mask.AsUInt32());
mask = Sse2.ShiftRightLogical(mask.AsUInt64(), 32).AsByte();
uint highBits = Sse2.ConvertToUInt32(mask.AsUInt32());
Vector256<ushort> v1 = Vector256.Create((ushort)c);
Vector256<ushort> v2 = Vector256.Create((ushort)c2);
Vector256<ushort> v3 = Vector256.Create((ushort)c3);

for (int idx = i; lowBits != 0; idx++)
int vector256ShortCount = Vector256<ushort>.Count;
for (; (i + vector256ShortCount) <= Length; i += vector256ShortCount)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider processing trailing elements via overlapping instead of scalar fallback

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a risk, though, that the code will start getting a bit complicated. I wanted to keep the code easy to follow, since it's only used for a specific scenario. If you still think it's worth it, I can definitely look into it.

Copy link
Member

@EgorBo EgorBo Feb 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling trailing elements in the same loop (or via a spilled iteration) shows nice improvements for small-to-medium-sized inputs; in theory it only adds an additional check inside the loop. Feel free to keep it as is; we can then follow up.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(i + vector256ShortCount) <= Length might overflow, it should be
i <= Length - vector256ShortCount

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides that, the `i <= len - count` version can keep `len - count` in a register, whilst `i + count` needs a repeated addition.

Also local vector256ShortCount isn't needed, as JIT will treat Vector256<ushort>.Count as constant.

{
if ((lowBits & 0xF) != 0)
Vector256<ushort> vector = Vector256.LoadUnsafe(ref source, (uint)i);
Vector256<ushort> cmp = Vector256.Equals(vector, v1) | Vector256.Equals(vector, v2) | Vector256.Equals(vector, v3);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider splitting this into temps for better pipelining, so that all the compare instructions are next to each other, and likewise the ORs.


uint mask = cmp.AsByte().ExtractMostSignificantBits() & 0b0101010101010101;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be a good idea to also use TestZ for faster out, e.g.

if (cmp == Vector256<ushort>.Zero)
    continue;

it's faster than movmsk

while (mask != 0)
{
sepListBuilder.Append(idx);
sepListBuilder.Append(i + BitOperations.TrailingZeroCount(mask) / 2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mask = BitOperations.ResetLowestSetBit(mask);
}

lowBits >>= 8;
}
}
else if (Vector128.IsHardwareAccelerated && Length >= Vector128<ushort>.Count * 2)
{
ref ushort source = ref Unsafe.As<char, ushort>(ref _firstChar);

Vector128<ushort> v1 = Vector128.Create((ushort)c);
Vector128<ushort> v2 = Vector128.Create((ushort)c2);
Vector128<ushort> v3 = Vector128.Create((ushort)c3);

for (int idx = i + 4; highBits != 0; idx++)
int vector128ShortCount = Vector128<ushort>.Count;
for (; (i + vector128ShortCount) <= Length; i += vector128ShortCount)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int vector128ShortCount = Vector128<ushort>.Count;
for (; (i + vector128ShortCount) <= Length; i += vector128ShortCount)
for (; i <= Length - Vector128<ushort>.Count; i += Vector128<ushort>.Count)

When i is of type nint just check if the comparison doesn't introduce any sign extensions -- please double check to be on the safe side.

{
if ((highBits & 0xF) != 0)
Vector128<ushort> vector = Vector128.LoadUnsafe(ref source, (uint)i);
Vector128<ushort> cmp = Vector128.Equals(vector, v1) | Vector128.Equals(vector, v2) | Vector128.Equals(vector, v3);

uint mask = cmp.AsByte().ExtractMostSignificantBits() & 0b0101010101010101;
while (mask != 0)
{
sepListBuilder.Append(idx);
sepListBuilder.Append(i + BitOperations.TrailingZeroCount(mask) / 2);
mask = BitOperations.ResetLowestSetBit(mask);
}

highBits >>= 8;
}
}

for (; i < Length; i++)
{
char curr = Unsafe.Add(ref c0, (IntPtr)(uint)i);
char curr = this[i];
if (curr == c || curr == c2 || curr == c3)
{
sepListBuilder.Append(i);
}
}

static Vector128<ushort> ReadVector(ref char c0, int offset)
{
ref char ci = ref Unsafe.Add(ref c0, (IntPtr)(uint)offset);
ref byte b = ref Unsafe.As<char, byte>(ref ci);
return Unsafe.ReadUnaligned<Vector128<ushort>>(ref b);
}
}

/// <summary>
Expand Down