Skip to content

Commit

Permalink
Make constructors private for parts of the ComInterfaceGenerator that…
Browse files Browse the repository at this point in the history
… should only be created from their static methods (#101740)

Moves ComClassInfo to its own file and moves the Func that produces them in the ComClassGenerator pipeline to a static method on ComClassInfo.

Makes the constructors for ComClassInfo, ComInterfaceContext, ComInterfaceInfo, and `ComMethoInfo private to enforce that they are only created from the static creation methods.

Creates more static SpecialTypeInfo types for sbyte, uint, short, and ushort.

Changes pattern matching on records to use property pattern matching rather than the deconstruction notation for records. Since the constructors are not accessible, neither are the deconstruct methods.
  • Loading branch information
jtschuster authored May 29, 2024
1 parent 9c9155d commit e36b937
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace Microsoft.Interop
[Generator]
public class ComClassGenerator : IIncrementalGenerator
{
private sealed record ComClassInfo(string ClassName, ContainingSyntaxContext ContainingSyntaxContext, ContainingSyntax ClassSyntax, SequenceEqualImmutableArray<string> ImplementedInterfacesNames);
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var unsafeCodeIsEnabled = context.CompilationProvider.Select((comp, ct) => comp.Options is CSharpCompilationOptions { AllowUnsafe: true }); // Unsafe code enabled
Expand All @@ -27,54 +26,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
static (node, ct) => node is ClassDeclarationSyntax,
static (context, ct) => context)
.Combine(unsafeCodeIsEnabled)
.Select((data, ct) =>
.Select(static (data, ct) =>
{
var context = data.Left;
var unsafeCodeIsEnabled = data.Right;
var type = (INamedTypeSymbol)context.TargetSymbol;
var syntax = (ClassDeclarationSyntax)context.TargetNode;
if (!unsafeCodeIsEnabled)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, syntax.Identifier.GetLocation()));
}

if (!syntax.IsInPartialContext(out _))
{
return DiagnosticOr<ComClassInfo>.From(
DiagnosticInfo.Create(
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}

ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
foreach (INamedTypeSymbol iface in type.AllInterfaces)
{
AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute);
if (generatedComInterfaceAttribute is not null)
{
var attributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute);
if (attributeData.Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper))
{
names.Add(iface.ToDisplayString());
}
}
}

if (names.Count == 0)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}


return DiagnosticOr<ComClassInfo>.From(
new ComClassInfo(
type.ToDisplayString(),
new ContainingSyntaxContext(syntax),
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
new(names.ToImmutable())));
return ComClassInfo.From(type, syntax, unsafeCodeIsEnabled);
});

var attributedClasses = context.FilterAndReportDiagnostics(attributedClassesOrDiagnostics);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Microsoft.Interop
{
internal sealed record ComClassInfo
{
public string ClassName { get; init; }
public ContainingSyntaxContext ContainingSyntaxContext { get; init; }
public ContainingSyntax ClassSyntax { get; init; }
public SequenceEqualImmutableArray<string> ImplementedInterfacesNames { get; init; }

private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax, SequenceEqualImmutableArray<string> implementedInterfacesNames)
{
ClassName = className;
ContainingSyntaxContext = containingSyntaxContext;
ClassSyntax = classSyntax;
ImplementedInterfacesNames = implementedInterfacesNames;
}

public static DiagnosticOr<ComClassInfo> From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, bool unsafeCodeIsEnabled)
{
if (!unsafeCodeIsEnabled)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, syntax.Identifier.GetLocation()));
}

if (!syntax.IsInPartialContext(out _))
{
return DiagnosticOr<ComClassInfo>.From(
DiagnosticInfo.Create(
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}

ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
foreach (INamedTypeSymbol iface in type.AllInterfaces)
{
AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute);
if (generatedComInterfaceAttribute is not null)
{
var attributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute);
if (attributeData.Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper))
{
names.Add(iface.ToDisplayString());
}
}
}

if (names.Count == 0)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}

return DiagnosticOr<ComClassInfo>.From(
new ComClassInfo(
type.ToDisplayString(),
new ContainingSyntaxContext(syntax),
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
new(names.ToImmutable())));
}

public bool Equals(ComClassInfo? other)
{
return other is not null
&& ClassName == other.ClassName
&& ContainingSyntaxContext.Equals(other.ContainingSyntaxContext)
&& ImplementedInterfacesNames.SequenceEqual(other.ImplementedInterfacesNames);
}

public override int GetHashCode()
{
return HashCode.Combine(ClassName, ContainingSyntaxContext, ImplementedInterfacesNames);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
Expand All @@ -9,8 +10,19 @@

namespace Microsoft.Interop
{
internal sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base, ComInterfaceOptions Options)
internal sealed record ComInterfaceContext
{
internal ComInterfaceInfo Info { get; init; }
internal ComInterfaceContext? Base { get; init; }
internal ComInterfaceOptions Options { get; init; }

private ComInterfaceContext(ComInterfaceInfo info, ComInterfaceContext? @base, ComInterfaceOptions options)
{
Info = info;
Base = @base;
Options = options;
}

/// <summary>
/// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,40 @@ namespace Microsoft.Interop
/// <summary>
/// Information about a Com interface, but not its methods.
/// </summary>
internal sealed record ComInterfaceInfo(
ManagedTypeInfo Type,
string ThisInterfaceKey, // For associating interfaces to its base
string? BaseInterfaceKey, // For associating interfaces to its base
InterfaceDeclarationSyntax Declaration,
ContainingSyntaxContext TypeDefinitionContext,
ContainingSyntax ContainingSyntax,
Guid InterfaceId,
ComInterfaceOptions Options,
Location DiagnosticLocation)
internal sealed record ComInterfaceInfo
{
public ManagedTypeInfo Type { get; init; }
public string ThisInterfaceKey { get; init; }
public string? BaseInterfaceKey { get; init; }
public InterfaceDeclarationSyntax Declaration { get; init; }
public ContainingSyntaxContext TypeDefinitionContext { get; init; }
public ContainingSyntax ContainingSyntax { get; init; }
public Guid InterfaceId { get; init; }
public ComInterfaceOptions Options { get; init; }
public Location DiagnosticLocation { get; init; }

private ComInterfaceInfo(
ManagedTypeInfo type,
string thisInterfaceKey,
string? baseInterfaceKey,
InterfaceDeclarationSyntax declaration,
ContainingSyntaxContext typeDefinitionContext,
ContainingSyntax containingSyntax,
Guid interfaceId,
ComInterfaceOptions options,
Location diagnosticLocation)
{
Type = type;
ThisInterfaceKey = thisInterfaceKey;
BaseInterfaceKey = baseInterfaceKey;
Declaration = declaration;
TypeDefinitionContext = typeDefinitionContext;
ContainingSyntax = containingSyntax;
InterfaceId = interfaceId;
Options = options;
DiagnosticLocation = diagnosticLocation;
}

public static DiagnosticOrInterfaceInfo From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax, StubEnvironment env, CancellationToken _)
{
if (env.Compilation.Options is not CSharpCompilationOptions { AllowUnsafe: true }) // Unsafe code enabled
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
Expand All @@ -14,12 +15,25 @@ namespace Microsoft.Interop
/// <summary>
/// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax.
/// </summary>
internal sealed record ComMethodInfo(
MethodDeclarationSyntax Syntax,
string MethodName,
SequenceEqualImmutableArray<AttributeInfo> Attributes,
bool IsUserDefinedShadowingMethod)
internal sealed record ComMethodInfo
{
public MethodDeclarationSyntax Syntax { get; init; }
public string MethodName { get; init; }
public SequenceEqualImmutableArray<AttributeInfo> Attributes { get; init; }
public bool IsUserDefinedShadowingMethod { get; init; }

private ComMethodInfo(
MethodDeclarationSyntax syntax,
string methodName,
SequenceEqualImmutableArray<AttributeInfo> attributes,
bool isUserDefinedShadowingMethod)
{
Syntax = syntax;
MethodName = methodName;
Attributes = attributes;
IsUserDefinedShadowingMethod = isUserDefinedShadowingMethod;
}

/// <summary>
/// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa.
/// </summary>
Expand Down Expand Up @@ -95,7 +109,6 @@ internal sealed record ComMethodInfo(
return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From(DiagnosticInfo.Create(GeneratorDiagnostics.MethodNotDeclaredInAttributedInterface, method.Locations.FirstOrDefault(), method.ToDisplayString()));
}


// Find the matching declaration syntax
MethodDeclarationSyntax? comMethodDeclaringSyntax = null;
foreach (var declaringSyntaxReference in method.DeclaringSyntaxReferences)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ public static ManagedTypeInfo CreateTypeInfoForTypeSymbol(ITypeSymbol type)
public sealed record SpecialTypeInfo(string FullTypeName, string DiagnosticFormattedName, SpecialType SpecialType) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName)
{
public static readonly SpecialTypeInfo Byte = new("byte", "byte", SpecialType.System_Byte);
public static readonly SpecialTypeInfo SByte = new("sbyte", "sbyte", SpecialType.System_SByte);
public static readonly SpecialTypeInfo Int16 = new("short", "short", SpecialType.System_Int16);
public static readonly SpecialTypeInfo UInt16 = new("ushort", "ushort", SpecialType.System_UInt16);
public static readonly SpecialTypeInfo Int32 = new("int", "int", SpecialType.System_Int32);
public static readonly SpecialTypeInfo UInt32 = new("uint", "uint", SpecialType.System_UInt32);
public static readonly SpecialTypeInfo Void = new("void", "void", SpecialType.System_Void);
public static readonly SpecialTypeInfo String = new("string", "string", SpecialType.System_String);
public static readonly SpecialTypeInfo Boolean = new("bool", "bool", SpecialType.System_Boolean);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ private ResolvedGenerator CreateNativeCollectionMarshaller(
marshallerType = marshallerType with
{
FullTypeName = marshallerTypeSyntax.ToString(),
DiagnosticFormattedName = marshallerTypeSyntax.ToString(),
DiagnosticFormattedName = marshallerTypeSyntax.ToString()
};
string newNativeTypeName = ReplacePlaceholderSyntaxWithUnmanagedTypeSyntax(marshallerData.NativeType.Syntax, marshalInfo, unmanagedElementType).ToFullString();
ManagedTypeInfo nativeType = marshallerData.NativeType with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ protected BoolMarshallerBase(ManagedTypeInfo nativeType, int trueValue, int fals

public ManagedTypeInfo AsNativeType(TypePositionInfo info)
{
Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Boolean));
Debug.Assert(info.ManagedType is SpecialTypeInfo { SpecialType: SpecialType.System_Boolean });
return _nativeType;
}

Expand Down Expand Up @@ -118,7 +118,7 @@ public sealed class ByteBoolMarshaller : BoolMarshallerBase
/// </summary>
/// <param name="signed">True if the byte should be signed, otherwise false</param>
public ByteBoolMarshaller(bool signed)
: base(new SpecialTypeInfo(signed ? "sbyte" : "byte", signed ? "sbyte" : "byte", signed ? SpecialType.System_SByte : SpecialType.System_Byte), trueValue: 1, falseValue: 0, compareToTrue: false)
: base(signed ? SpecialTypeInfo.SByte : SpecialTypeInfo.Byte, trueValue: 1, falseValue: 0, compareToTrue: false)
{
}
}
Expand All @@ -136,7 +136,7 @@ public sealed class WinBoolMarshaller : BoolMarshallerBase
/// </summary>
/// <param name="signed">True if the int should be signed, otherwise false</param>
public WinBoolMarshaller(bool signed)
: base(new SpecialTypeInfo(signed ? "int" : "uint", signed ? "int" : "uint", signed ? SpecialType.System_Int32 : SpecialType.System_UInt32), trueValue: 1, falseValue: 0, compareToTrue: false)
: base(signed ? SpecialTypeInfo.Int32 : SpecialTypeInfo.UInt32, trueValue: 1, falseValue: 0, compareToTrue: false)
{
}
}
Expand All @@ -149,7 +149,7 @@ public sealed class VariantBoolMarshaller : BoolMarshallerBase
private const short VARIANT_TRUE = -1;
private const short VARIANT_FALSE = 0;
public VariantBoolMarshaller()
: base(new SpecialTypeInfo("short", "short", SpecialType.System_Int16), trueValue: VARIANT_TRUE, falseValue: VARIANT_FALSE, compareToTrue: true)
: base(SpecialTypeInfo.Int16, trueValue: VARIANT_TRUE, falseValue: VARIANT_FALSE, compareToTrue: true)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public ResolvedGenerator Create(TypePositionInfo info, StubCodeContext context)
}

// Breaking change: [MarshalAs(UnmanagedType.Struct)] in object in unmanaged-to-managed scenarios will not respect VT_BYREF.
if (info is { RefKind: RefKind.In or RefKind.RefReadOnlyParameter, MarshallingAttributeInfo: NativeMarshallingAttributeInfo(ManagedTypeInfo(_, TypeNames.ComVariantMarshaller), _) }
if (info is { RefKind: RefKind.In or RefKind.RefReadOnlyParameter, MarshallingAttributeInfo: NativeMarshallingAttributeInfo(ManagedTypeInfo { DiagnosticFormattedName: TypeNames.ComVariantMarshaller }, _) }
&& context.Direction == MarshalDirection.UnmanagedToManaged)
{
gen = ResolvedGenerator.ResolvedWithDiagnostics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Microsoft.Interop
{
public sealed class Utf16CharMarshaller : IMarshallingGenerator
{
private static readonly ManagedTypeInfo s_nativeType = new SpecialTypeInfo("ushort", "ushort", SpecialType.System_UInt16);
private static readonly ManagedTypeInfo s_nativeType = SpecialTypeInfo.UInt16;

public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, StubCodeContext context)
{
Expand All @@ -35,7 +35,7 @@ public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, Stu

public ManagedTypeInfo AsNativeType(TypePositionInfo info)
{
Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Char));
Debug.Assert(info.ManagedType is SpecialTypeInfo {SpecialType: SpecialType.System_Char });
return s_nativeType;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ public ResolvedGenerator Create(
return ResolvedGenerator.Resolved(s_blittable);

// Pointer with no marshalling info
case { ManagedType: PointerTypeInfo(_, _, IsFunctionPointer: false), MarshallingAttributeInfo: NoMarshallingInfo }:
case { ManagedType: PointerTypeInfo{ IsFunctionPointer: false }, MarshallingAttributeInfo: NoMarshallingInfo }:
return ResolvedGenerator.Resolved(s_blittable);

// Function pointer with no marshalling info
case { ManagedType: PointerTypeInfo(_, _, IsFunctionPointer: true), MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }:
case { ManagedType: PointerTypeInfo { IsFunctionPointer: true }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }:
return ResolvedGenerator.Resolved(s_blittable);

// Bool with marshalling info
Expand Down
Loading

0 comments on commit e36b937

Please sign in to comment.