Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve hashing and equality allocations/performance #5304

Merged
merged 12 commits into from
Jul 17, 2023
36 changes: 23 additions & 13 deletions build/Shared/EqualityUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;

namespace NuGet.Shared
Expand All @@ -22,7 +23,7 @@ internal static class EqualityUtility
/// <param name="keySelector">The function to extract the key from each item in the list</param>
/// <param name="orderComparer">An optional comparer for comparing keys</param>
/// <param name="sequenceComparer">An optional comparer for sequences</param>
internal static bool OrderedEquals<TSource, TKey>(this IEnumerable<TSource> self, IEnumerable<TSource> other, Func<TSource, TKey> keySelector, IComparer<TKey>? orderComparer = null, IEqualityComparer<TSource>? sequenceComparer = null)
internal static bool OrderedEquals<TSource, TKey>(this IEnumerable<TSource>? self, IEnumerable<TSource>? other, Func<TSource, TKey> keySelector, IComparer<TKey>? orderComparer = null, IEqualityComparer<TSource>? sequenceComparer = null)
{
Debug.Assert(orderComparer != null || typeof(TKey) != typeof(string), "Argument " + "orderComparer" + " must be provided if " + "TKey" + " is a string.");
Debug.Assert(sequenceComparer != null || typeof(TSource) != typeof(string), "Argument " + "sequenceComparer" + " must be provided if " + "TSource" + " is a string.");
Expand All @@ -48,7 +49,7 @@ internal static bool OrderedEquals<TSource, TKey>(this IEnumerable<TSource> self
/// <param name="keySelector">The function to extract the key from each item in the list</param>
/// <param name="orderComparer">An optional comparer for comparing keys</param>
/// <param name="sequenceComparer">An optional comparer for sequences</param>
internal static bool OrderedEquals<TSource, TKey>(this ICollection<TSource> self, ICollection<TSource> other, Func<TSource, TKey> keySelector, IComparer<TKey>? orderComparer = null, IEqualityComparer<TSource>? sequenceComparer = null)
internal static bool OrderedEquals<TSource, TKey>(this ICollection<TSource>? self, ICollection<TSource>? other, Func<TSource, TKey> keySelector, IComparer<TKey>? orderComparer = null, IEqualityComparer<TSource>? sequenceComparer = null)
{
Debug.Assert(orderComparer != null || typeof(TKey) != typeof(string), "Argument " + "orderComparer" + " must be provided if " + "TKey" + " is a string.");
Debug.Assert(sequenceComparer != null || typeof(TSource) != typeof(string), "Argument " + "sequenceComparer" + " must be provided if " + "TSource" + " is a string.");
Expand All @@ -69,6 +70,12 @@ internal static bool OrderedEquals<TSource, TKey>(this ICollection<TSource> self
return true;
}

if (self.Count == 1)
{
sequenceComparer ??= EqualityComparer<TSource>.Default;
return sequenceComparer.Equals(self.First(), other.First());
}

return self
.OrderBy(keySelector, orderComparer)
.SequenceEqual(other.OrderBy(keySelector, orderComparer), sequenceComparer);
Expand All @@ -84,7 +91,7 @@ internal static bool OrderedEquals<TSource, TKey>(this ICollection<TSource> self
/// <param name="keySelector">The function to extract the key from each item in the list</param>
/// <param name="orderComparer">An optional comparer for comparing keys</param>
/// <param name="sequenceComparer">An optional comparer for sequences</param>
internal static bool OrderedEquals<TSource, TKey>(this IList<TSource> self, IList<TSource> other, Func<TSource, TKey> keySelector, IComparer<TKey>? orderComparer = null, IEqualityComparer<TSource>? sequenceComparer = null)
internal static bool OrderedEquals<TSource, TKey>(this IList<TSource>? self, IList<TSource>? other, Func<TSource, TKey> keySelector, IComparer<TKey>? orderComparer = null, IEqualityComparer<TSource>? sequenceComparer = null)
{
Debug.Assert(orderComparer != null || typeof(TKey) != typeof(string), "Argument " + "orderComparer" + " must be provided if " + "TKey" + " is a string.");
Debug.Assert(sequenceComparer != null || typeof(TSource) != typeof(string), "Argument " + "sequenceComparer" + " must be provided if " + "TSource" + " is a string.");
Expand Down Expand Up @@ -120,8 +127,8 @@ internal static bool OrderedEquals<TSource, TKey>(this IList<TSource> self, ILis
/// null for equality.
/// </summary>
internal static bool SequenceEqualWithNullCheck<T>(
this IEnumerable<T> self,
IEnumerable<T> other,
this IEnumerable<T>? self,
IEnumerable<T>? other,
IEqualityComparer<T>? comparer = null)
{
bool identityEquals;
Expand All @@ -143,8 +150,8 @@ internal static bool SequenceEqualWithNullCheck<T>(
/// null for equality.
/// </summary>
internal static bool SequenceEqualWithNullCheck<T>(
this ICollection<T> self,
ICollection<T> other,
this ICollection<T>? self,
ICollection<T>? other,
IEqualityComparer<T>? comparer = null)
{
bool identityEquals;
Expand Down Expand Up @@ -176,8 +183,8 @@ internal static bool SequenceEqualWithNullCheck<T>(
/// null for equality.
/// </summary>
internal static bool SequenceEqualWithNullCheck<T>(
this IList<T> self,
IList<T> other,
this IList<T>? self,
IList<T>? other,
IEqualityComparer<T>? comparer = null)
{
bool identityEquals;
Expand Down Expand Up @@ -214,8 +221,8 @@ internal static bool SequenceEqualWithNullCheck<T>(
/// If one is null, both have to be null for equality.
/// </summary>
internal static bool SetEqualsWithNullCheck<T>(
this ISet<T> self,
ISet<T> other,
this ISet<T>? self,
ISet<T>? other,
IEqualityComparer<T>? comparer = null)
{
bool identityEquals;
Expand Down Expand Up @@ -315,7 +322,7 @@ internal static bool EqualsWithNullCheck<T>(T self, T other)
}

/// <summary>
/// Determines if the current string contains a value equal "false". Leading and trailing whitespace are trimmed and the comparision is case-insensitive
/// Determines if the current string contains a value equal "false". Leading and trailing whitespace are trimmed and the comparison is case-insensitive
/// </summary>
/// <param name="value">The string to compare.</param>
/// <returns><c>true</c> if the current string is equal to a value of "false", otherwise <c>false></c>.</returns>
Expand All @@ -324,7 +331,10 @@ internal static bool EqualsFalse(this string value)
return !string.IsNullOrWhiteSpace(value) && bool.FalseString.Equals(value.Trim(), StringComparison.OrdinalIgnoreCase);
}

private static bool TryIdentityEquals<T>(T? self, T? other, out bool equals)
private static bool TryIdentityEquals<T>(
[NotNullWhen(returnValue: false)] T? self,
[NotNullWhen(returnValue: false)] T? other,
out bool equals)
{
// Are they the same instance? This handles the case where both are null.
if (ReferenceEquals(self, other))
Expand Down
109 changes: 63 additions & 46 deletions build/Shared/HashCodeCombiner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,43 @@

using System;
using System.Collections.Generic;
using System.Linq;

namespace NuGet.Shared
{
/// <summary>
/// Hash code creator, based on the original NuGet hash code combiner/ASP hash code combiner implementations
/// </summary>
internal struct HashCodeCombiner
internal ref struct HashCodeCombiner
{
// seed from String.GetHashCode()
private const long Seed = 0x1505L;

private bool _initialized;
private long _combinedHash;
private long _combinedHash = Seed;

internal int CombinedHash
public HashCodeCombiner()
{
get { return _combinedHash.GetHashCode(); }
}

internal int CombinedHash => _combinedHash.GetHashCode();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in general when making refactorings, making minor stylistic changes only makes things harder to review sa every time it is reviewed, people would review something that's a no-op.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From memory the IDE showed a diagnostic suggesting this change. I try to reduce the number of diagnostics in the gutter/margin as part of refactoring/tidy ups.

If this style is not preferred, perhaps remove the diagnostic from the solution in the .editorconfig file. As it stands, it seems like a request to change this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #9, we touch on that, https://github.com/NuGet/NuGet.Client/blob/dev/docs/coding-guidelines.md.

The general idea is that it's preferred to use that in new code, but part of the internal feedback for the turnaround of some of these PRs is the sheer number of line changes.


private void AddHashCode(int i)
{
_combinedHash = ((_combinedHash << 5) + _combinedHash) ^ i;
}

internal void AddObject(int i)
{
CheckInitialized();
AddHashCode(i);
}

internal void AddObject(bool b)
{
CheckInitialized();
AddHashCode(b ? 1 : 0);
}

internal void AddObject<TValue>(TValue? o, IEqualityComparer<TValue> comparer)
where TValue : class
internal void AddObject<T>(T? o, IEqualityComparer<T> comparer)
where T : class
{
CheckInitialized();
if (o != null)
{
AddHashCode(comparer.GetHashCode(o));
Expand All @@ -55,7 +51,6 @@ internal void AddObject<TValue>(TValue? o, IEqualityComparer<TValue> comparer)
internal void AddObject<T>(T? o)
where T : class
{
CheckInitialized();
if (o != null)
{
AddHashCode(o.GetHashCode());
Expand All @@ -66,8 +61,6 @@ internal void AddObject<T>(T? o)
internal void AddStruct<T>(T? o)
where T : struct
{
CheckInitialized();

if (o.HasValue)
{
AddHashCode(o.GetHashCode());
Expand All @@ -78,62 +71,54 @@ internal void AddStruct<T>(T? o)
internal void AddStruct<T>(T o)
where T : struct
{
CheckInitialized();

AddHashCode(o.GetHashCode());
}

internal void AddStringIgnoreCase(string s)
internal void AddStringIgnoreCase(string? s)
{
CheckInitialized();
if (s != null)
{
AddHashCode(StringComparer.OrdinalIgnoreCase.GetHashCode(s));
}
}

internal void AddSequence<T>(IEnumerable<T> sequence) where T : notnull
internal void AddSequence<T>(IEnumerable<T>? sequence) where T : notnull
{
if (sequence != null)
{
CheckInitialized();
foreach (var item in sequence)
foreach (var item in sequence.NoAllocEnumerate())
{
AddHashCode(item.GetHashCode());
}
}
}

internal void AddSequence<T>(T[] array) where T : notnull
internal void AddSequence<T>(T[]? array) where T : notnull
{
if (array != null)
{
CheckInitialized();
foreach (var item in array)
{
AddHashCode(item.GetHashCode());
}
}
}

internal void AddSequence<T>(IList<T> list) where T : notnull
internal void AddSequence<T>(IList<T>? list) where T : notnull
{
if (list != null)
{
CheckInitialized();
var count = list.Count;
for (var i = 0; i < count; i++)
foreach (var item in list.NoAllocEnumerate())
{
AddHashCode(list[i].GetHashCode());
AddHashCode(item.GetHashCode());
}
}
}

internal void AddSequence<T>(IReadOnlyList<T> list) where T : notnull
internal void AddSequence<T>(IReadOnlyList<T>? list) where T : notnull
{
if (list != null)
{
CheckInitialized();
var count = list.Count;
for (var i = 0; i < count; i++)
{
Expand All @@ -142,18 +127,61 @@ internal void AddSequence<T>(IReadOnlyList<T> list) where T : notnull
}
}

internal void AddDictionary<TKey, TValue>(IEnumerable<KeyValuePair<TKey, TValue>> dictionary)
internal void AddUnorderedSequence<T>(IEnumerable<T>? list) where T : notnull
{
if (list != null)
{
int count = 0;
int hashCode = 0;
foreach (var item in list)
{
// XOR is commutative -- the order of operations doesn't matter
hashCode ^= item.GetHashCode();
count++;
}
AddHashCode(hashCode);
AddHashCode(count);
}
}

internal void AddUnorderedSequence<T>(IEnumerable<T>? list, IEqualityComparer<T> comparer) where T : notnull
{
if (list != null)
{
int count = 0;
int hashCode = 0;
foreach (var item in list)
{
// XOR is commutative -- the order of operations doesn't matter
hashCode ^= comparer.GetHashCode(item);
count++;
}
AddHashCode(hashCode);
AddHashCode(count);
}
}

internal void AddDictionary<TKey, TValue>(IEnumerable<KeyValuePair<TKey, TValue>>? dictionary)
where TKey : notnull
where TValue : notnull
{
if (dictionary != null)
{
CheckInitialized();
foreach (var pair in dictionary.OrderBy(x => x.Key))
int count = 0;
int dictionaryHash = 0;

foreach (var pair in dictionary.NoAllocEnumerate())
{
AddHashCode(pair.Key.GetHashCode());
AddHashCode(pair.Value.GetHashCode());
int keyHash = pair.Key.GetHashCode();
int valHash = pair.Value.GetHashCode();
int pairHash = ((keyHash << 5) + keyHash) ^ valHash;

// XOR is commutative -- the order of operations doesn't matter
dictionaryHash ^= pairHash;
count++;
}

AddHashCode(dictionaryHash + count);
}
}

Expand All @@ -165,7 +193,6 @@ internal static int GetHashCode<T1, T2>(T1 o1, T2 o2)
where T2 : notnull
{
var combiner = new HashCodeCombiner();
combiner.CheckInitialized();

combiner.AddHashCode(o1.GetHashCode());
combiner.AddHashCode(o2.GetHashCode());
Expand All @@ -182,22 +209,12 @@ internal static int GetHashCode<T1, T2, T3>(T1 o1, T2 o2, T3 o3)
where T3 : notnull
{
var combiner = new HashCodeCombiner();
combiner.CheckInitialized();

combiner.AddHashCode(o1.GetHashCode());
combiner.AddHashCode(o2.GetHashCode());
combiner.AddHashCode(o3.GetHashCode());

return combiner.CombinedHash;
}

private void CheckInitialized()
{
if (!_initialized)
{
_combinedHash = Seed;
_initialized = true;
}
}
}
}
Loading