diff --git a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Unix.cs b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Unix.cs index 8718608152a3fe..b127a2a72491e9 100644 --- a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Unix.cs +++ b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Unix.cs @@ -1,34 +1,60 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Runtime.InteropServices; + namespace System.Buffers { - public static partial class BoundedMemory + public static unsafe partial class BoundedMemory { private static UnixImplementation AllocateWithoutDataPopulationUnix(int elementCount, PoisonPagePlacement placement) where T : unmanaged { // On non-Windows platforms, we don't yet have support for changing the permissions of individual pages. + // We'll instead use AllocHGlobal / FreeHGlobal to carve out a r+w section of unmanaged memory. + return new UnixImplementation(elementCount); } private sealed class UnixImplementation : BoundedMemory where T : unmanaged { - private readonly T[] _buffer; + private readonly AllocHGlobalHandle _handle; + private readonly int _elementCount; + private readonly BoundedMemoryManager _memoryManager; public UnixImplementation(int elementCount) { - _buffer = new T[elementCount]; + _handle = AllocHGlobalHandle.Allocate(checked(elementCount * (nint)sizeof(T))); + _elementCount = elementCount; + _memoryManager = new BoundedMemoryManager(this); } public override bool IsReadonly => false; - public override Memory Memory => _buffer; + public override Memory Memory => _memoryManager.Memory; - public override Span Span => _buffer; + public override Span Span + { + get + { + bool refAdded = false; + try + { + _handle.DangerousAddRef(ref refAdded); + return new Span((void*)_handle.DangerousGetHandle(), _elementCount); + } + finally + { + if (refAdded) + { + _handle.DangerousRelease(); + } + } + } + } public override void Dispose() { - // no-op + _handle.Dispose(); } public override void MakeReadonly() @@ -40,6 +66,82 @@ public override void MakeWriteable() { // no-op } + + private sealed class BoundedMemoryManager : MemoryManager + { + private readonly UnixImplementation _impl; + + public BoundedMemoryManager(UnixImplementation impl) + { + _impl = impl; + } + + public override Memory Memory => CreateMemory(_impl._elementCount); + + protected override void Dispose(bool disposing) + { + // no-op; the handle will be disposed separately + } + + public override Span GetSpan() + { + throw new NotImplementedException(); + } + + public override MemoryHandle Pin(int elementIndex) + { + if ((uint)elementIndex > (uint)_impl._elementCount) + { + throw new ArgumentOutOfRangeException(paramName: nameof(elementIndex)); + } + + bool refAdded = false; + try + { + _impl._handle.DangerousAddRef(ref refAdded); + return new MemoryHandle((T*)_impl._handle.DangerousGetHandle() + elementIndex); + } + finally + { + if (refAdded) + { + _impl._handle.DangerousRelease(); + } + } + } + + public override void Unpin() + { + // no-op - we don't unpin native memory + } + } + } + + private sealed class AllocHGlobalHandle : SafeHandle + { + // Called by P/Invoke when returning SafeHandles + private AllocHGlobalHandle() + : base(IntPtr.Zero, ownsHandle: true) + { + } + + internal static AllocHGlobalHandle Allocate(nint byteLength) + { + AllocHGlobalHandle retVal = new AllocHGlobalHandle(); + retVal.SetHandle(Marshal.AllocHGlobal(byteLength)); // this is for unit testing; don't bother setting up a CER on Full Framework + return retVal; + } + + // Do not provide a finalizer - SafeHandle's critical finalizer will + // call ReleaseHandle for you. + + public override bool IsInvalid => (handle == IntPtr.Zero); + + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } } } } diff --git a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Windows.cs b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Windows.cs index 8c672e9d21e9e2..3f256b970fe762 100644 --- a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Windows.cs +++ b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Windows.cs @@ -1,10 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers; -using System.Runtime.ConstrainedExecution; using System.Runtime.InteropServices; -using System.Security; namespace System.Buffers { @@ -290,7 +287,6 @@ protected override bool ReleaseHandle() => UnsafeNativeMethods.VirtualFree(handle, IntPtr.Zero, VirtualAllocAllocationType.MEM_RELEASE); } - [SuppressUnmanagedCodeSecurity] private static class UnsafeNativeMethods { private const string KERNEL32_LIB = "kernel32.dll"; diff --git a/src/libraries/System.Memory/tests/Span/IndexOfAny.AlgorithmicComplexity.cs b/src/libraries/System.Memory/tests/Span/IndexOfAny.AlgorithmicComplexity.cs new file mode 100644 index 00000000000000..9b395998577cbb --- /dev/null +++ b/src/libraries/System.Memory/tests/Span/IndexOfAny.AlgorithmicComplexity.cs @@ -0,0 +1,149 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Xunit; + +namespace System.SpanTests +{ + public static partial class SpanTests + { + [Fact] + public static void IndexOfAny_LastIndexOfAny_AlgComplexity_Bytes() + => RunIndexOfAnyLastIndexOfAnyAlgComplexityTest(); + + [Fact] + public static void IndexOfAny_LastIndexOfAny_AlgComplexity_Chars() + => RunIndexOfAnyLastIndexOfAnyAlgComplexityTest(); + + [Fact] + public static void IndexOfAny_LastIndexOfAny_AlgComplexity_Ints() + => RunIndexOfAnyLastIndexOfAnyAlgComplexityTest(); + + [Fact] + public static void IndexOfAny_LastIndexOfAny_AlgComplexity_RefType() + { + // Similar to RunIndexOfAnyAlgComplexityTest (see comments there), but we can't use + // BoundedMemory because we're dealing with ref types. Instead, we'll trap the call to + // Equals and use that to fail the test. + + Span> haystack = new CustomEquatableType[8192]; + haystack[1024] = new CustomEquatableType(default, isPoison: true); // fail the test if we iterate this far + haystack[^1024] = new CustomEquatableType(default, isPoison: true); + + Span> needle = Enumerable.Range(100, 20).Select(val => new CustomEquatableType(val)).ToArray(); + for (int i = 0; i < needle.Length; i++) + { + haystack[4096] = needle[i]; + Assert.Equal(2048, MemoryExtensions.IndexOfAny(haystack[2048..], needle)); + Assert.Equal(2048, MemoryExtensions.IndexOfAny((ReadOnlySpan>)haystack[2048..], needle)); + Assert.Equal(4096, MemoryExtensions.LastIndexOfAny(haystack[..^2048], needle)); + Assert.Equal(4096, MemoryExtensions.LastIndexOfAny((ReadOnlySpan>)haystack[..^2048], needle)); + } + } + + private static void RunIndexOfAnyLastIndexOfAnyAlgComplexityTest() where T : unmanaged, IEquatable + { + T[] needles = GetIndexOfAnyNeedlesForAlgComplexityTest().ToArray(); + RunIndexOfAnyAlgComplexityTest(needles); + RunLastIndexOfAnyAlgComplexityTest(needles); + } + + private static void RunIndexOfAnyAlgComplexityTest(T[] needle) where T : unmanaged, IEquatable + { + // For the following paragraphs, let: + // n := length of haystack + // i := index of first occurrence of any needle within haystack + // l := length of needle array + // + // This test ensures that the complexity of IndexOfAny is O(i * l) rather than O(n * l), + // or just O(n * l) if no needle is found. The reason for this is that it's common for + // callers to invoke IndexOfAny immediately before slicing, and when this is called in + // a loop, we want the entire loop to be bounded by O(n * l) rather than O(n^2 * l). + // + // We test this by utilizing the BoundedMemory infrastructure to allocate a poison page + // after the scratch buffer, then we intentionally use MemoryMarshal to manipulate the + // scratch buffer so that it extends into the poison page. If the runtime skips past the + // first occurrence of the needle and attempts to read all the way to the end of the span, + // this will manifest as an AV within this unit test. + + using BoundedMemory boundedMem = BoundedMemory.Allocate(4096, PoisonPagePlacement.After); + Span span = boundedMem.Span; + span.Clear(); + + span = MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(span), span.Length + 4096); + + for (int i = 0; i < needle.Length; i++) + { + span[1024] = needle[i]; + Assert.Equal(1024, MemoryExtensions.IndexOfAny(span, needle)); + Assert.Equal(1024, MemoryExtensions.IndexOfAny((ReadOnlySpan)span, needle)); + } + } + + private static void RunLastIndexOfAnyAlgComplexityTest(T[] needle) where T : unmanaged, IEquatable + { + // Similar to RunIndexOfAnyAlgComplexityTest (see comments there), but we run backward + // since we're testing LastIndexOfAny. + + using BoundedMemory boundedMem = BoundedMemory.Allocate(4096, PoisonPagePlacement.Before); + Span span = boundedMem.Span; + span.Clear(); + + span = MemoryMarshal.CreateSpan(ref Unsafe.Subtract(ref MemoryMarshal.GetReference(span), 4096), span.Length + 4096); + + for (int i = 0; i < needle.Length; i++) + { + span[^1024] = needle[i]; + Assert.Equal(span.Length - 1024, MemoryExtensions.LastIndexOfAny(span, needle)); + Assert.Equal(span.Length - 1024, MemoryExtensions.LastIndexOfAny((ReadOnlySpan)span, needle)); + } + } + + // returns [ 'a', 'b', 'c', ... ], or the equivalent in bytes, ints, etc. + private static IEnumerable GetIndexOfAnyNeedlesForAlgComplexityTest() where T : unmanaged + { + for (int i = 0; i < 26; i++) + { + yield return (T)Convert.ChangeType('a' + i, typeof(T), CultureInfo.InvariantCulture); + } + } + +#pragma warning disable CS0659 // Type overrides Object.Equals(object o) but does not override Object.GetHashCode() + private sealed class CustomEquatableType : IEquatable> where T : IEquatable +#pragma warning restore CS0659 // Type overrides Object.Equals(object o) but does not override Object.GetHashCode() + { + private readonly T _value; + private readonly bool _isPoison; + + public CustomEquatableType(T value, bool isPoison = false) + { + _value = value; + _isPoison = isPoison; + } + + public override bool Equals(object obj) => Equals(obj as CustomEquatableType); + + public bool Equals(CustomEquatableType other) + { + if (_isPoison) + { + throw new InvalidOperationException("This object is poisoned and its Equals method should not be called."); + } + + if (other is null) { return false; } + if (other._isPoison) + { + throw new InvalidOperationException("The 'other' object is poisoned and should not be passed to Equals."); + } + + return _value.Equals(other._value); + } + } + } +} diff --git a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj index b9e9e02a838414..6176b956641aef 100644 --- a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj +++ b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj @@ -75,6 +75,7 @@ + diff --git a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs index 2d0fd54f7f2355..31608d643a2924 100644 --- a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs +++ b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs @@ -713,14 +713,6 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(span)), Unsafe.Add(ref valueRef, 2), span.Length); } - else - { - return SpanHelpers.IndexOfAny( - ref Unsafe.As(ref MemoryMarshal.GetReference(span)), - span.Length, - ref valueRef, - values.Length); - } } if (Unsafe.SizeOf() == sizeof(char)) @@ -888,14 +880,7 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(span)), [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int LastIndexOfAny(this ReadOnlySpan span, ReadOnlySpan values) where T : IEquatable { - if (Unsafe.SizeOf() == sizeof(byte) && RuntimeHelpers.IsBitwiseEquatable()) - return SpanHelpers.LastIndexOfAny( - ref Unsafe.As(ref MemoryMarshal.GetReference(span)), - span.Length, - ref Unsafe.As(ref MemoryMarshal.GetReference(values)), - values.Length); - - return SpanHelpers.LastIndexOfAny(ref MemoryMarshal.GetReference(span), span.Length, ref MemoryMarshal.GetReference(values), values.Length); + return SpanHelpers.LastIndexOfAny(ref MemoryMarshal.GetReference(span), span.Length, ref MemoryMarshal.GetReference(values), values.Length); } /// diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs index 7ba50af0e642f3..b9efaf24ee1a2a 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs @@ -51,49 +51,6 @@ public static int IndexOf(ref byte searchSpace, int searchSpaceLength, ref byte return -1; } - public static int IndexOfAny(ref byte searchSpace, int searchSpaceLength, ref byte value, int valueLength) - { - Debug.Assert(searchSpaceLength >= 0); - Debug.Assert(valueLength >= 0); - - if (valueLength == 0) - return -1; // A zero-length set of values is always treated as "not found". - - int offset = -1; - for (int i = 0; i < valueLength; i++) - { - int tempIndex = IndexOf(ref searchSpace, Unsafe.Add(ref value, i), searchSpaceLength); - if ((uint)tempIndex < (uint)offset) - { - offset = tempIndex; - // Reduce space for search, cause we don't care if we find the search value after the index of a previously found value - searchSpaceLength = tempIndex; - - if (offset == 0) - break; - } - } - return offset; - } - - public static int LastIndexOfAny(ref byte searchSpace, int searchSpaceLength, ref byte value, int valueLength) - { - Debug.Assert(searchSpaceLength >= 0); - Debug.Assert(valueLength >= 0); - - if (valueLength == 0) - return -1; // A zero-length set of values is always treated as "not found". - - int offset = -1; - for (int i = 0; i < valueLength; i++) - { - int tempIndex = LastIndexOf(ref searchSpace, Unsafe.Add(ref value, i), searchSpaceLength); - if (tempIndex > offset) - offset = tempIndex; - } - return offset; - } - // Adapted from IndexOf(...) [MethodImpl(MethodImplOptions.AggressiveOptimization)] public static bool Contains(ref byte searchSpace, byte value, int length) diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs index 339aa157f8b7e0..f5de74e1166e4a 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs @@ -593,21 +593,67 @@ public static int IndexOfAny(ref T searchSpace, int searchSpaceLength, ref T if (valueLength == 0) return -1; // A zero-length set of values is always treated as "not found". - int index = -1; - for (int i = 0; i < valueLength; i++) + // For the following paragraph, let: + // n := length of haystack + // i := index of first occurrence of any needle within haystack + // l := length of needle array + // + // We use a naive non-vectorized search because we want to bound the complexity of IndexOfAny + // to O(i * l) rather than O(n * l), or just O(n * l) if no needle is found. The reason for + // this is that it's common for callers to invoke IndexOfAny immediately before slicing, + // and when this is called in a loop, we want the entire loop to be bounded by O(n * l) + // rather than O(n^2 * l). + + if (typeof(T).IsValueType) { - int tempIndex = IndexOf(ref searchSpace, Unsafe.Add(ref value, i), searchSpaceLength); - if ((uint)tempIndex < (uint)index) + // Calling ValueType.Equals (devirtualized), which takes 'this' byref. We'll make + // a byval copy of the candidate from the search space in the outer loop, then in + // the inner loop we'll pass a ref (as 'this') to each element in the needle. + + for (int i = 0; i < searchSpaceLength; i++) { - index = tempIndex; - // Reduce space for search, cause we don't care if we find the search value after the index of a previously found value - searchSpaceLength = tempIndex; + T candidate = Unsafe.Add(ref searchSpace, i); + for (int j = 0; j < valueLength; j++) + { + if (Unsafe.Add(ref value, j).Equals(candidate)) + { + return i; + } + } + } + } + else + { + // Calling IEquatable.Equals (virtual dispatch). We'll perform the null check + // in the outer loop instead of in the inner loop to save some branching. - if (index == 0) - break; + for (int i = 0; i < searchSpaceLength; i++) + { + T candidate = Unsafe.Add(ref searchSpace, i); + if (candidate is not null) + { + for (int j = 0; j < valueLength; j++) + { + if (candidate.Equals(Unsafe.Add(ref value, j))) + { + return i; + } + } + } + else + { + for (int j = 0; j < valueLength; j++) + { + if (Unsafe.Add(ref value, j) is null) + { + return i; + } + } + } } } - return index; + + return -1; // not found } public static int LastIndexOf(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable @@ -939,14 +985,52 @@ public static int LastIndexOfAny(ref T searchSpace, int searchSpaceLength, re if (valueLength == 0) return -1; // A zero-length set of values is always treated as "not found". - int index = -1; - for (int i = 0; i < valueLength; i++) + // See comments in IndexOfAny(ref T, int, ref T, int) above regarding algorithmic complexity concerns. + // This logic is similar, but it runs backward. + + if (typeof(T).IsValueType) { - int tempIndex = LastIndexOf(ref searchSpace, Unsafe.Add(ref value, i), searchSpaceLength); - if (tempIndex > index) - index = tempIndex; + for (int i = searchSpaceLength - 1; i >= 0; i--) + { + T candidate = Unsafe.Add(ref searchSpace, i); + for (int j = 0; j < valueLength; j++) + { + if (Unsafe.Add(ref value, j).Equals(candidate)) + { + return i; + } + } + } } - return index; + else + { + for (int i = searchSpaceLength - 1; i >= 0; i--) + { + T candidate = Unsafe.Add(ref searchSpace, i); + if (candidate is not null) + { + for (int j = 0; j < valueLength; j++) + { + if (candidate.Equals(Unsafe.Add(ref value, j))) + { + return i; + } + } + } + else + { + for (int j = 0; j < valueLength; j++) + { + if (Unsafe.Add(ref value, j) is null) + { + return i; + } + } + } + } + } + + return -1; // not found } public static bool SequenceEqual(ref T first, ref T second, int length) where T : IEquatable