Skip to content

Commit

Permalink
[ILLink] Mark recursive interface implementations in MarkStep (#99922)
Browse files Browse the repository at this point in the history
Cache recursive interfaces for interface implementation dependency analysis. Use Cecil's TypeReferenceComparer for comparing type references.

Co-authored-by: Sven Boemer <[email protected]>
  • Loading branch information
jtschuster and sbomer authored Mar 27, 2024
1 parent 78c2c98 commit 1fa699e
Show file tree
Hide file tree
Showing 27 changed files with 1,412 additions and 14 deletions.
28 changes: 17 additions & 11 deletions src/tools/illink/src/linker/Linker.Steps/MarkStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using System.Reflection.Runtime.TypeParsing;
using System.Runtime.CompilerServices;
using System.Text.RegularExpressions;
using ILLink.Shared;
using ILLink.Shared.TrimAnalysis;
Expand Down Expand Up @@ -2450,26 +2450,27 @@ void MarkNamedProperty (TypeDefinition type, string property_name, in Dependency

void MarkInterfaceImplementations (TypeDefinition type)
{
if (!type.HasInterfaces)
var ifaces = Annotations.GetRecursiveInterfaces (type);
if (ifaces is null)
return;

foreach (var iface in type.Interfaces) {
foreach (var (ifaceType, impls) in ifaces) {
// Only mark interface implementations of interface types that have been marked.
// This enables stripping of interfaces that are never used
if (ShouldMarkInterfaceImplementation (type, iface))
MarkInterfaceImplementation (iface, new MessageOrigin (type));
if (ShouldMarkInterfaceImplementationList (type, impls, ifaceType))
MarkInterfaceImplementationList (impls, new MessageOrigin (type));
}
}

protected virtual bool ShouldMarkInterfaceImplementation (TypeDefinition type, InterfaceImplementation iface)

protected virtual bool ShouldMarkInterfaceImplementationList (TypeDefinition type, List<InterfaceImplementation> ifaces, TypeReference ifaceType)
{
if (Annotations.IsMarked (iface))
if (ifaces.All (Annotations.IsMarked))
return false;

if (!Context.IsOptimizationEnabled (CodeOptimizations.UnusedInterfaces, type))
return true;

if (Context.Resolve (iface.InterfaceType) is not TypeDefinition resolvedInterfaceType)
if (Context.Resolve (ifaceType) is not TypeDefinition resolvedInterfaceType)
return false;

if (Annotations.IsMarked (resolvedInterfaceType))
Expand Down Expand Up @@ -3764,8 +3765,7 @@ protected virtual void MarkInstruction (Instruction instruction, MethodDefinitio
ScopeStack.UpdateCurrentScopeInstructionOffset (instruction.Offset);
if (markForReflectionAccess) {
MarkMethodVisibleToReflection (methodReference, new DependencyInfo (dependencyKind, method), ScopeStack.CurrentScope.Origin);
}
else {
} else {
MarkMethod (methodReference, new DependencyInfo (dependencyKind, method), ScopeStack.CurrentScope.Origin);
}
break;
Expand Down Expand Up @@ -3825,6 +3825,12 @@ protected virtual void MarkInstruction (Instruction instruction, MethodDefinitio
}
}

void MarkInterfaceImplementationList (List<InterfaceImplementation> ifaces, MessageOrigin? origin = null, DependencyInfo? reason = null)
{
foreach (var iface in ifaces) {
MarkInterfaceImplementation (iface, origin, reason);
}
}

protected internal virtual void MarkInterfaceImplementation (InterfaceImplementation iface, MessageOrigin? origin = null, DependencyInfo? reason = null)
{
Expand Down
6 changes: 6 additions & 0 deletions src/tools/illink/src/linker/Linker/Annotations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using ILLink.Shared.TrimAnalysis;
using Mono.Cecil;
using Mono.Cecil.Cil;
Expand Down Expand Up @@ -717,5 +718,10 @@ public void EnqueueVirtualMethod (MethodDefinition method)
if (FlowAnnotations.RequiresVirtualMethodDataFlowAnalysis (method) || HasLinkerAttribute<RequiresUnreferencedCodeAttribute> (method))
VirtualMethodsWithAnnotationsToValidate.Add (method);
}

internal List<(TypeReference, List<InterfaceImplementation>)>? GetRecursiveInterfaces (TypeDefinition type)
{
return TypeMapInfo.GetRecursiveInterfaces (type);
}
}
}
171 changes: 171 additions & 0 deletions src/tools/illink/src/linker/Linker/MethodReferenceComparer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using Mono.Cecil;

namespace Mono.Linker
{
// Copied from https://github.com/jbevain/cecil/blob/master/Mono.Cecil/MethodReferenceComparer.cs
internal sealed class MethodReferenceComparer : EqualityComparer<MethodReference>
{
// Initialized lazily for each thread
[ThreadStatic]
static List<MethodReference>? xComparisonStack;

[ThreadStatic]
static List<MethodReference>? yComparisonStack;

public readonly ITryResolveMetadata _resolver;

public MethodReferenceComparer(ITryResolveMetadata resolver)
{
_resolver = resolver;
}

public override bool Equals (MethodReference? x, MethodReference? y)
{
return AreEqual (x, y, _resolver);
}

public override int GetHashCode (MethodReference obj)
{
return GetHashCodeFor (obj);
}

public static bool AreEqual (MethodReference? x, MethodReference? y, ITryResolveMetadata resolver)
{
if (ReferenceEquals (x, y))
return true;

if (x is null ^ y is null)
return false;

Debug.Assert (x is not null);
Debug.Assert (y is not null);

if (x.HasThis != y.HasThis)
return false;

#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
if (x.HasParameters != y.HasParameters)
return false;
#pragma warning restore RS0030

if (x.HasGenericParameters != y.HasGenericParameters)
return false;

#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
if (x.Parameters.Count != y.Parameters.Count)
return false;
#pragma warning restore RS0030

if (x.Name != y.Name)
return false;

if (!TypeReferenceEqualityComparer.AreEqual (x.DeclaringType, y.DeclaringType, resolver))
return false;

var xGeneric = x as GenericInstanceMethod;
var yGeneric = y as GenericInstanceMethod;
if (xGeneric != null || yGeneric != null) {
if (xGeneric == null || yGeneric == null)
return false;

if (xGeneric.GenericArguments.Count != yGeneric.GenericArguments.Count)
return false;

for (int i = 0; i < xGeneric.GenericArguments.Count; i++)
if (!TypeReferenceEqualityComparer.AreEqual (xGeneric.GenericArguments[i], yGeneric.GenericArguments[i], resolver))
return false;
}

var xResolved = resolver.TryResolve (x);
var yResolved = resolver.TryResolve (y);

if (xResolved != yResolved)
return false;

if (xResolved == null) {
// We couldn't resolve either method. In order for them to be equal, their parameter types _must_ match. But wait, there's a twist!
// There exists a situation where we might get into a recursive state: parameter type comparison might lead to comparing the same
// methods again if the parameter types are generic parameters whose owners are these methods. We guard against these by using a
// thread static list of all our comparisons carried out in the stack so far, and if we're in progress of comparing them already,
// we'll just say that they match.

xComparisonStack ??= new List<MethodReference> ();

yComparisonStack ??= new List<MethodReference> ();

for (int i = 0; i < xComparisonStack.Count; i++) {
if (xComparisonStack[i] == x && yComparisonStack[i] == y)
return true;
}

xComparisonStack.Add (x);

try {
yComparisonStack.Add (y);

try {
#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
for (int i = 0; i < x.Parameters.Count; i++) {
if (!TypeReferenceEqualityComparer.AreEqual (x.Parameters[i].ParameterType, y.Parameters[i].ParameterType, resolver))
return false;
}
#pragma warning restore RS0030
} finally {
yComparisonStack.RemoveAt (yComparisonStack.Count - 1);
}
} finally {
xComparisonStack.RemoveAt (xComparisonStack.Count - 1);
}
}

return true;
}

public static bool AreSignaturesEqual (MethodReference x, MethodReference y, ITryResolveMetadata resolver, TypeComparisonMode comparisonMode = TypeComparisonMode.Exact)
{
if (x.HasThis != y.HasThis)
return false;

#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
if (x.Parameters.Count != y.Parameters.Count)
return false;
#pragma warning restore RS0030

if (x.GenericParameters.Count != y.GenericParameters.Count)
return false;

#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
for (var i = 0; i < x.Parameters.Count; i++)
if (!TypeReferenceEqualityComparer.AreEqual (x.Parameters[i].ParameterType, y.Parameters[i].ParameterType, resolver, comparisonMode))
return false;
#pragma warning restore RS0030

if (!TypeReferenceEqualityComparer.AreEqual (x.ReturnType, y.ReturnType, resolver, comparisonMode))
return false;

return true;
}

public static int GetHashCodeFor (MethodReference obj)
{
// a very good prime number
const int hashCodeMultiplier = 486187739;

var genericInstanceMethod = obj as GenericInstanceMethod;
if (genericInstanceMethod != null) {
var hashCode = GetHashCodeFor (genericInstanceMethod.ElementMethod);
for (var i = 0; i < genericInstanceMethod.GenericArguments.Count; i++)
hashCode = hashCode * hashCodeMultiplier + TypeReferenceEqualityComparer.GetHashCodeFor (genericInstanceMethod.GenericArguments[i]);
return hashCode;
}

return TypeReferenceEqualityComparer.GetHashCodeFor (obj.DeclaringType) * hashCodeMultiplier + obj.Name.GetHashCode ();
}
}
}
17 changes: 17 additions & 0 deletions src/tools/illink/src/linker/Linker/TypeComparisonMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Mono.Linker
{
// Copied from https://github.com/jbevain/cecil/blob/master/Mono.Cecil/TypeComparisonMode.cs
internal enum TypeComparisonMode
{
Exact,
SignatureOnly,

/// <summary>
/// Types can be in different assemblies, as long as the module, assembly, and type names match they will be considered equal
/// </summary>
SignatureOnlyLoose
}
}
46 changes: 46 additions & 0 deletions src/tools/illink/src/linker/Linker/TypeMapInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ public void AddDefaultInterfaceImplementation (MethodDefinition @base, Interface
default_interface_implementations.AddToList (@base, new OverrideInformation (@base, defaultImplementationMethod, interfaceImplementor));
}

Dictionary<TypeDefinition, List<(TypeReference, List<InterfaceImplementation>)>> interfaces = new ();
protected virtual void MapType (TypeDefinition type)
{
MapVirtualMethods (type);
MapInterfaceMethodsInTypeHierarchy (type);
interfaces[type] = GetRecursiveInterfaceImplementations (type);

if (!type.HasNestedTypes)
return;
Expand All @@ -128,6 +130,50 @@ protected virtual void MapType (TypeDefinition type)
MapType (nested);
}

internal List<(TypeReference, List<InterfaceImplementation>)>? GetRecursiveInterfaces (TypeDefinition type)
{
if (interfaces.TryGetValue (type, out var value))
return value;
return null;
}

List<(TypeReference, List<InterfaceImplementation>)> GetRecursiveInterfaceImplementations (TypeDefinition type)
{
List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain = new ();

AddRecursiveInterfaces (type, [], firstImplementationChain, context);
Debug.Assert (firstImplementationChain.All (kvp => context.Resolve (kvp.Item1) == context.Resolve (kvp.Item2.Last ().InterfaceType)));

return firstImplementationChain;

static void AddRecursiveInterfaces (TypeReference typeRef, IEnumerable<InterfaceImplementation> pathToType, List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain, LinkContext Context)
{
var type = Context.TryResolve (typeRef);
if (type is null)
return;
// Get all explicit interfaces of this type
foreach (var iface in type.Interfaces) {
var interfaceType = iface.InterfaceType.TryInflateFrom (typeRef, Context);
if (interfaceType is null) {
continue;
}
if (!firstImplementationChain.Any (i => TypeReferenceEqualityComparer.AreEqual (i.Item1, interfaceType, Context))) {
firstImplementationChain.Add ((interfaceType, pathToType.Append (iface).ToList ()));
}
}

// Recursive interfaces after all direct interfaces to preserve Inherit/Implement tree order
foreach (var iface in type.Interfaces) {
// If we can't resolve the interface type we can't find recursive interfaces
var ifaceDirectlyOnType = iface.InterfaceType.TryInflateFrom (typeRef, Context);
if (ifaceDirectlyOnType is null) {
continue;
}
AddRecursiveInterfaces (ifaceDirectlyOnType, pathToType.Append (iface), firstImplementationChain, Context);
}
}
}

void MapInterfaceMethodsInTypeHierarchy (TypeDefinition type)
{
if (!type.HasInterfaces)
Expand Down
Loading

0 comments on commit 1fa699e

Please sign in to comment.