Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce worst-case alg complexity of MemoryExtensions.IndexOfAny #53652

Merged
merged 3 commits into from
Jun 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,34 +1,60 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices;

namespace System.Buffers
{
public static partial class BoundedMemory
public static unsafe partial class BoundedMemory
{
// Non-Windows platforms cannot yet change the permissions of individual pages, so no
// poison page is set up here and the 'placement' argument is not consulted. The returned
// implementation is instead backed by a plain read+write block of unmanaged memory
// obtained through AllocHGlobal / FreeHGlobal.
private static UnixImplementation<T> AllocateWithoutDataPopulationUnix<T>(int elementCount, PoisonPagePlacement placement) where T : unmanaged
    => new UnixImplementation<T>(elementCount);

private sealed class UnixImplementation<T> : BoundedMemory<T> where T : unmanaged
{
private readonly T[] _buffer;
private readonly AllocHGlobalHandle _handle;
private readonly int _elementCount;
private readonly BoundedMemoryManager _memoryManager;

public UnixImplementation(int elementCount)
{
_buffer = new T[elementCount];
_handle = AllocHGlobalHandle.Allocate(checked(elementCount * (nint)sizeof(T)));
_elementCount = elementCount;
_memoryManager = new BoundedMemoryManager(this);
}

public override bool IsReadonly => false;

public override Memory<T> Memory => _buffer;
public override Memory<T> Memory => _memoryManager.Memory;

public override Span<T> Span => _buffer;
// Projects the native allocation as a span of _elementCount elements.
public override Span<T> Span
{
    get
    {
        // Hold a reference on the SafeHandle only while the raw pointer is being read.
        // The reference is dropped again before the span reaches the caller, so the span
        // itself does not keep the allocation alive.
        bool handleRefTaken = false;
        try
        {
            _handle.DangerousAddRef(ref handleRefTaken);

            void* start = (void*)_handle.DangerousGetHandle();
            return new Span<T>(start, _elementCount);
        }
        finally
        {
            if (handleRefTaken)
            {
                _handle.DangerousRelease();
            }
        }
    }
}

// Disposes the handle that owns the AllocHGlobal-backed buffer.
public override void Dispose()
{
// Not a no-op: the native allocation is owned by _handle and is released through it.
_handle.Dispose();
}

public override void MakeReadonly()
Expand All @@ -40,6 +66,82 @@ public override void MakeWriteable()
{
// no-op
}

/// <summary>
/// Adapts the native allocation owned by the enclosing <see cref="UnixImplementation{T}"/>
/// to the <see cref="MemoryManager{T}"/> contract so that it can be handed out as a
/// <see cref="Memory{T}"/>.
/// </summary>
private sealed class BoundedMemoryManager : MemoryManager<T>
{
    private readonly UnixImplementation<T> _impl;

    public BoundedMemoryManager(UnixImplementation<T> impl)
    {
        _impl = impl;
    }

    /// <summary>Gets a <see cref="Memory{T}"/> covering every element of the owner's buffer.</summary>
    public override Memory<T> Memory => CreateMemory(_impl._elementCount);

    protected override void Dispose(bool disposing)
    {
        // no-op; the handle will be disposed separately
    }

    /// <summary>Returns a span over the owner's native buffer.</summary>
    public override Span<T> GetSpan()
    {
        // Memory<T>.Span is serviced by this method for manager-backed memory, so it must
        // not throw: forward to the owner's Span accessor, which reads the pointer while
        // holding a reference on the handle.
        return _impl.Span;
    }

    /// <summary>Returns a handle pointing at the element at <paramref name="elementIndex"/>.</summary>
    /// <param name="elementIndex">Offset into the buffer; the one-past-the-end index is accepted.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="elementIndex"/> is negative or greater than the element count.
    /// </exception>
    public override MemoryHandle Pin(int elementIndex)
    {
        // The unsigned comparison rejects negative indices as well as indices past the end.
        if ((uint)elementIndex > (uint)_impl._elementCount)
        {
            throw new ArgumentOutOfRangeException(paramName: nameof(elementIndex));
        }

        // The SafeHandle reference is held only while the pointer is computed and is released
        // again before returning, so the MemoryHandle does not keep the allocation alive.
        bool refAdded = false;
        try
        {
            _impl._handle.DangerousAddRef(ref refAdded);
            return new MemoryHandle((T*)_impl._handle.DangerousGetHandle() + elementIndex);
        }
        finally
        {
            if (refAdded)
            {
                _impl._handle.DangerousRelease();
            }
        }
    }

    public override void Unpin()
    {
        // no-op - we don't unpin native memory
    }
}
}

// SafeHandle over a block of unmanaged memory obtained from Marshal.AllocHGlobal and
// returned with Marshal.FreeHGlobal.
private sealed class AllocHGlobalHandle : SafeHandle
{
    // Private: instances are only created through Allocate below.
    private AllocHGlobalHandle()
        : base(IntPtr.Zero, ownsHandle: true)
    {
    }

    // A zero pointer is the only value treated as "no allocation".
    public override bool IsInvalid
    {
        get { return handle == IntPtr.Zero; }
    }

    internal static AllocHGlobalHandle Allocate(nint byteLength)
    {
        // The wrapper is created before the native allocation is made.
        // This is for unit testing; don't bother setting up a CER on Full Framework.
        var allocation = new AllocHGlobalHandle();
        IntPtr block = Marshal.AllocHGlobal(byteLength);
        allocation.SetHandle(block);
        return allocation;
    }

    // No finalizer is declared here on purpose - SafeHandle's critical finalizer
    // calls ReleaseHandle for us.
    protected override bool ReleaseHandle()
    {
        Marshal.FreeHGlobal(handle);
        return true;
    }
}
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Runtime.ConstrainedExecution;
using System.Runtime.InteropServices;
using System.Security;

namespace System.Buffers
{
Expand Down Expand Up @@ -290,7 +287,6 @@ protected override bool ReleaseHandle() =>
UnsafeNativeMethods.VirtualFree(handle, IntPtr.Zero, VirtualAllocAllocationType.MEM_RELEASE);
}

[SuppressUnmanagedCodeSecurity]
private static class UnsafeNativeMethods
{
private const string KERNEL32_LIB = "kernel32.dll";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Xunit;

namespace System.SpanTests
{
public static partial class SpanTests
{
[Fact]
public static void IndexOfAny_LastIndexOfAny_AlgComplexity_Bytes()
{
    // Runs the forward and backward complexity checks with 1-byte elements.
    RunIndexOfAnyLastIndexOfAnyAlgComplexityTest<byte>();
}

[Fact]
public static void IndexOfAny_LastIndexOfAny_AlgComplexity_Chars()
{
    // Runs the forward and backward complexity checks with char elements.
    RunIndexOfAnyLastIndexOfAnyAlgComplexityTest<char>();
}

[Fact]
public static void IndexOfAny_LastIndexOfAny_AlgComplexity_Ints()
{
    // Runs the forward and backward complexity checks with int elements.
    RunIndexOfAnyLastIndexOfAnyAlgComplexityTest<int>();
}

[Fact]
public static void IndexOfAny_LastIndexOfAny_AlgComplexity_RefType()
{
    // Same idea as RunIndexOfAnyAlgComplexityTest (see the comments there). Reference types
    // cannot be placed in BoundedMemory, so instead the haystack contains "poisoned" objects
    // whose Equals throws; a search that walks past the expected match fails the test.

    Span<CustomEquatableType<int>> haystack = new CustomEquatableType<int>[8192];

    // One poison object near each end: reaching either of them means the search went too far.
    haystack[1024] = new CustomEquatableType<int>(default, isPoison: true);
    haystack[^1024] = new CustomEquatableType<int>(default, isPoison: true);

    Span<CustomEquatableType<int>> needle = Enumerable.Range(100, 20).Select(val => new CustomEquatableType<int>(val)).ToArray();

    // Both windows are views over the haystack: each contains the match slot (4096) and
    // exactly one poison object, located beyond the match in the direction of the search.
    Span<CustomEquatableType<int>> forwardWindow = haystack[2048..];
    Span<CustomEquatableType<int>> backwardWindow = haystack[..^2048];

    foreach (CustomEquatableType<int> candidate in needle)
    {
        haystack[4096] = candidate;

        Assert.Equal(2048, MemoryExtensions.IndexOfAny(forwardWindow, needle));
        Assert.Equal(2048, MemoryExtensions.IndexOfAny((ReadOnlySpan<CustomEquatableType<int>>)forwardWindow, needle));

        Assert.Equal(4096, MemoryExtensions.LastIndexOfAny(backwardWindow, needle));
        Assert.Equal(4096, MemoryExtensions.LastIndexOfAny((ReadOnlySpan<CustomEquatableType<int>>)backwardWindow, needle));
    }
}

// Materializes the needle set once and feeds the same array to the forward (IndexOfAny)
// and the backward (LastIndexOfAny) scenario.
private static void RunIndexOfAnyLastIndexOfAnyAlgComplexityTest<T>() where T : unmanaged, IEquatable<T>
{
    T[] sharedNeedles = GetIndexOfAnyNeedlesForAlgComplexityTest<T>().ToArray();

    RunIndexOfAnyAlgComplexityTest<T>(sharedNeedles);
    RunLastIndexOfAnyAlgComplexityTest<T>(sharedNeedles);
}

private static void RunIndexOfAnyAlgComplexityTest<T>(T[] needle) where T : unmanaged, IEquatable<T>
{
    // Notation used below:
    //   n = length of the haystack
    //   i = index of the first occurrence of any needle within the haystack
    //   l = length of the needle array
    //
    // The property under test is that IndexOfAny costs O(i * l) rather than O(n * l)
    // (and simply O(n * l) when nothing matches). This matters because callers commonly
    // call IndexOfAny right before slicing; when that happens in a loop, the whole loop
    // should stay bounded by O(n * l) instead of degrading to O(n^2 * l).
    //
    // To observe this, BoundedMemory allocates a scratch buffer followed by a poison page,
    // and MemoryMarshal is then used to deliberately stretch the span into that poison page.
    // An implementation that skips past the first match and keeps reading toward the end of
    // the span will surface as an AV inside this unit test.

    using BoundedMemory<T> guarded = BoundedMemory.Allocate<T>(4096, PoisonPagePlacement.After);
    Span<T> owned = guarded.Span;
    owned.Clear();

    // Same start address, but twice as long: the second half lies beyond the allocation.
    ref T start = ref MemoryMarshal.GetReference(owned);
    Span<T> span = MemoryMarshal.CreateSpan(ref start, owned.Length + 4096);

    foreach (T candidate in needle)
    {
        span[1024] = candidate;

        Assert.Equal(1024, MemoryExtensions.IndexOfAny(span, needle));
        Assert.Equal(1024, MemoryExtensions.IndexOfAny((ReadOnlySpan<T>)span, needle));
    }
}

private static void RunLastIndexOfAnyAlgComplexityTest<T>(T[] needle) where T : unmanaged, IEquatable<T>
{
    // Mirror image of RunIndexOfAnyAlgComplexityTest (see the comments there): the poison
    // page sits before the buffer and the search runs backward, since LastIndexOfAny is
    // the method under test.

    using BoundedMemory<T> guarded = BoundedMemory.Allocate<T>(4096, PoisonPagePlacement.Before);
    Span<T> owned = guarded.Span;
    owned.Clear();

    // Move the start 4096 elements back so the first half of the span lies before the allocation.
    ref T shiftedStart = ref Unsafe.Subtract(ref MemoryMarshal.GetReference(owned), 4096);
    Span<T> span = MemoryMarshal.CreateSpan(ref shiftedStart, owned.Length + 4096);

    foreach (T candidate in needle)
    {
        span[^1024] = candidate;

        Assert.Equal(span.Length - 1024, MemoryExtensions.LastIndexOfAny(span, needle));
        Assert.Equal(span.Length - 1024, MemoryExtensions.LastIndexOfAny((ReadOnlySpan<T>)span, needle));
    }
}

// Lazily produces 26 consecutive values, starting at the numeric value of 'a', each
// converted to T - i.e. [ 'a', 'b', 'c', ... ] or the equivalent in bytes, ints, etc.
private static IEnumerable<T> GetIndexOfAnyNeedlesForAlgComplexityTest<T>() where T : unmanaged
{
    return Enumerable
        .Range('a', 26)
        .Select(codePoint => (T)Convert.ChangeType(codePoint, typeof(T), CultureInfo.InvariantCulture));
}

#pragma warning disable CS0659 // Type overrides Object.Equals(object o) but does not override Object.GetHashCode()
private sealed class CustomEquatableType<T> : IEquatable<CustomEquatableType<T>> where T : IEquatable<T>
#pragma warning restore CS0659 // Type overrides Object.Equals(object o) but does not override Object.GetHashCode()
{
    // Reference-type wrapper around a value. An instance can be flagged as "poisoned", in
    // which case taking part in an equality comparison (on either side) throws.
    private readonly T _value;
    private readonly bool _isPoison;

    public CustomEquatableType(T value, bool isPoison = false)
    {
        _value = value;
        _isPoison = isPoison;
    }

    // Anything that is not a CustomEquatableType<T> is forwarded as null.
    public override bool Equals(object obj)
    {
        return Equals(obj as CustomEquatableType<T>);
    }

    public bool Equals(CustomEquatableType<T> other)
    {
        // The receiver's poison flag is checked first, even before the null check on 'other'.
        if (_isPoison)
        {
            throw new InvalidOperationException("This object is poisoned and its Equals method should not be called.");
        }

        if (other is null)
        {
            return false;
        }

        return other._isPoison
            ? throw new InvalidOperationException("The 'other' object is poisoned and should not be passed to Equals.")
            : _value.Equals(other._value);
    }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
<Compile Include="Span\IndexOf.byte.cs" />
<Compile Include="Span\IndexOf.char.cs" />
<Compile Include="Span\IndexOf.T.cs" />
<Compile Include="Span\IndexOfAny.AlgorithmicComplexity.cs" />
<Compile Include="Span\IndexOfAny.byte.cs" />
<Compile Include="Span\IndexOfAny.char.cs" />
<Compile Include="Span\IndexOfAny.T.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,14 +713,6 @@ ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
Unsafe.Add(ref valueRef, 2),
span.Length);
}
else
{
return SpanHelpers.IndexOfAny(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
span.Length,
ref valueRef,
values.Length);
}
}

if (Unsafe.SizeOf<T>() == sizeof(char))
Expand Down Expand Up @@ -888,14 +880,7 @@ ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int LastIndexOfAny<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> values) where T : IEquatable<T>
{
if (Unsafe.SizeOf<T>() == sizeof(byte) && RuntimeHelpers.IsBitwiseEquatable<T>())
return SpanHelpers.LastIndexOfAny(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
span.Length,
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(values)),
values.Length);

return SpanHelpers.LastIndexOfAny<T>(ref MemoryMarshal.GetReference(span), span.Length, ref MemoryMarshal.GetReference(values), values.Length);
return SpanHelpers.LastIndexOfAny(ref MemoryMarshal.GetReference(span), span.Length, ref MemoryMarshal.GetReference(values), values.Length);
}

/// <summary>
Expand Down
Loading