diff --git a/Directory.Packages.props b/Directory.Packages.props index 543e3af..be9eada 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -3,6 +3,7 @@ + diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..6cc4dbc --- /dev/null +++ b/codecov.yml @@ -0,0 +1,8 @@ +coverage: + status: + project: + default: + threshold: 10% + patch: + default: + threshold: 10% diff --git a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs index ca8535e..5665fd8 100644 --- a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs +++ b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs @@ -57,7 +57,7 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel) : CSharpS private readonly SemanticModel semanticModel = semanticModel; private readonly HashSet removedParameters = []; private readonly Dictionary renamedLocalFunctions = []; - private readonly ImmutableArray.Builder diagnostics = ImmutableArray.CreateBuilder(); + private readonly ImmutableArray.Builder diagnostics = ImmutableArray.CreateBuilder(); private enum SyncOnlyDirectiveType { @@ -77,7 +77,7 @@ private enum SpecialMethod /// /// Gets the diagnostics messages. /// - public ImmutableArray Diagnostics => diagnostics.ToImmutable(); + public ImmutableArray Diagnostics => diagnostics.ToImmutable(); /// public override SyntaxNode? VisitConditionalAccessExpression(ConditionalAccessExpressionSyntax node) @@ -1307,7 +1307,7 @@ BinaryExpressionSyntax be if (syncOnlyDirectiveType == SyncOnlyDirectiveType.Invalid) { - var d = Diagnostic.Create(InvalidCondition, trivia.GetLocation(), trivia); + var d = ReportedDiagnostic.Create(InvalidCondition, trivia.GetLocation(), trivia.ToString()); diagnostics.Add(d); return null; } @@ -1318,7 +1318,7 @@ BinaryExpressionSyntax be { if (isStackSyncOnly ^ syncOnlyDirectiveType == SyncOnlyDirectiveType.SyncOnly) { - var d = Diagnostic.Create(InvalidNesting, trivia.GetLocation(), trivia); + var d = ReportedDiagnostic.Create(InvalidNesting, trivia.GetLocation(), trivia.ToString()); diagnostics.Add(d); return null; } @@ -1378,7 +1378,7 @@ BinaryExpressionSyntax be } else { - var d = Diagnostic.Create(InvalidElif, trivia.GetLocation(), trivia); + var d = ReportedDiagnostic.Create(InvalidElif, trivia.GetLocation(), trivia.ToString()); diagnostics.Add(d); return null; } diff --git a/src/Zomp.SyncMethodGenerator/ClassDeclaration.cs b/src/Zomp.SyncMethodGenerator/ClassDeclaration.cs index 978b68e..e1f8542 100644 --- a/src/Zomp.SyncMethodGenerator/ClassDeclaration.cs +++ b/src/Zomp.SyncMethodGenerator/ClassDeclaration.cs @@ -6,4 +6,4 @@ /// Class name. /// A list of modifiers. /// A list of type parameters. -internal sealed record ClassDeclaration(string ClassName, IEnumerable Modifiers, TypeParameterListSyntax? TypeParameterListSyntax); +internal sealed record ClassDeclaration(string ClassName, EquatableArray Modifiers, EquatableArray TypeParameterListSyntax); diff --git a/src/Zomp.SyncMethodGenerator/Helpers/EquatableArray.cs b/src/Zomp.SyncMethodGenerator/Helpers/EquatableArray.cs new file mode 100644 index 0000000..c7f0a94 --- /dev/null +++ b/src/Zomp.SyncMethodGenerator/Helpers/EquatableArray.cs @@ -0,0 +1,202 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Diagnostics.CodeAnalysis; + +namespace Zomp.SyncMethodGenerator.Helpers; + +/// +/// An immutable, equatable array. This is equivalent to but with value equality support. +/// +/// The type of values in the array. +/// +/// Modified from: https://github.com/dotnet/runtime/issues/77183#issuecomment-1284577055. +/// Remove this struct when the issue above is resolved. +/// +[ExcludeFromCodeCoverage] +internal readonly struct EquatableArray : IEquatable>, IEnumerable + where T : IEquatable +{ + /// + /// The underlying array. + /// + private readonly T[]? array; + + /// + /// Initializes a new instance of the struct. + /// + /// The input to wrap. + public EquatableArray(ImmutableArray array) + { + this.array = Unsafe.As, T[]?>(ref array); + } + + /// + /// Gets a value indicating whether the current array is empty. + /// + public bool IsEmpty + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => AsImmutableArray().IsEmpty; + } + + /// + /// Gets a reference to an item at a specified position within the array. + /// + /// The index of the item to retrieve a reference to. + /// A reference to an item at a specified position within the array. + public ref readonly T this[int index] + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref AsImmutableArray().ItemRef(index); + } + + /// + /// Implicitly converts an to . + /// + /// An instance from a given . + public static implicit operator EquatableArray(ImmutableArray array) + { + return FromImmutableArray(array); + } + + /// + /// Implicitly converts an to . + /// + /// An instance from a given . + public static implicit operator ImmutableArray(EquatableArray array) + { + return array.AsImmutableArray(); + } + + /// + /// Checks whether two values are the same. + /// + /// The first value. + /// The second value. + /// Whether and are equal. + public static bool operator ==(EquatableArray left, EquatableArray right) + { + return left.Equals(right); + } + + /// + /// Checks whether two values are not the same. + /// + /// The first value. + /// The second value. + /// Whether and are not equal. + public static bool operator !=(EquatableArray left, EquatableArray right) + { + return !left.Equals(right); + } + + /// + /// Creates an instance from a given . + /// + /// The input instance. + /// An instance from a given . + public static EquatableArray FromImmutableArray(ImmutableArray array) + { + return new(array); + } + + /// + public bool Equals(EquatableArray array) + { + return AsSpan().SequenceEqual(array.AsSpan()); + } + + /// + public override bool Equals([NotNullWhen(true)] object? obj) + { + return obj is EquatableArray array && Equals(this, array); + } + + /// + public override int GetHashCode() + { + if (this.array is not T[] array) + { + return 0; + } + + HashCode hashCode = default; + + foreach (T item in array) + { + hashCode.Add(item); + } + + return hashCode.ToHashCode(); + } + + /// + /// Gets an instance from the current . + /// + /// The from the current . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ImmutableArray AsImmutableArray() + { + return Unsafe.As>(ref Unsafe.AsRef(in array)); + } + + /// + /// Returns a wrapping the current items. + /// + /// A wrapping the current items. + public ReadOnlySpan AsSpan() + { + return AsImmutableArray().AsSpan(); + } + + /// + /// Copies the contents of this instance to a mutable array. + /// + /// The newly instantiated array. + public T[] ToArray() + { + return AsImmutableArray().ToArray(); + } + + /// + /// Gets an value to traverse items in the current array. + /// + /// An value to traverse items in the current array. + public ImmutableArray.Enumerator GetEnumerator() + { + return AsImmutableArray().GetEnumerator(); + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)AsImmutableArray()).GetEnumerator(); + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)AsImmutableArray()).GetEnumerator(); + } +} + +/// +/// Extensions for . +/// +[ExcludeFromCodeCoverage] +internal static class EquatableArray +{ + /// + /// Creates an instance from a given . + /// + /// The type of items in the input array. + /// The input instance. + /// An instance from a given . + public static EquatableArray AsEquatableArray(this ImmutableArray array) + where T : IEquatable + { + return new(array); + } +} diff --git a/src/Zomp.SyncMethodGenerator/MethodToGenerate.cs b/src/Zomp.SyncMethodGenerator/MethodToGenerate.cs index 47d5264..884a228 100644 --- a/src/Zomp.SyncMethodGenerator/MethodToGenerate.cs +++ b/src/Zomp.SyncMethodGenerator/MethodToGenerate.cs @@ -3,16 +3,22 @@ /// /// Represents a sync method to generate from its async version. /// +/// Index of the method in the source file. /// List of namespaces this method is under. /// True if namespace is file scoped. /// List of classes this method belongs to starting from the outer-most class. /// Name of the method. /// Implementation. /// Disables nullable for the method. +/// Diagnostics. +/// True if there are errors in . internal sealed record MethodToGenerate( - IEnumerable Namespaces, + int Index, + EquatableArray Namespaces, bool IsNamespaceFileScoped, - IEnumerable Classes, + EquatableArray Classes, string MethodName, string Implementation, - bool DisableNullable); + bool DisableNullable, + EquatableArray Diagnostics, + bool HasErrors); diff --git a/src/Zomp.SyncMethodGenerator/Models/ReportedDiagnostic.cs b/src/Zomp.SyncMethodGenerator/Models/ReportedDiagnostic.cs new file mode 100644 index 0000000..df4e0b0 --- /dev/null +++ b/src/Zomp.SyncMethodGenerator/Models/ReportedDiagnostic.cs @@ -0,0 +1,37 @@ +namespace Zomp.SyncMethodGenerator.Models; + +/// +/// Basic diagnostic description for reporting diagnostic inside the incremental pipeline. +/// +/// Diagnostic descriptor. +/// File path. +/// Text span. +/// Line span. +/// Trivia. +/// +internal sealed record ReportedDiagnostic(DiagnosticDescriptor Descriptor, string FilePath, TextSpan TextSpan, LinePositionSpan LineSpan, string Trivia) +{ + /// + /// Implicitly converts to . + /// + /// Diagnostic to convert. + public static implicit operator Diagnostic(ReportedDiagnostic diagnostic) + { + return Diagnostic.Create( + descriptor: diagnostic.Descriptor, + location: Location.Create(diagnostic.FilePath, diagnostic.TextSpan, diagnostic.LineSpan), + messageArgs: new object[] { diagnostic.Trivia }); + } + + /// + /// Creates a new from and . + /// + /// Descriptor. + /// Location. + /// Trivia. + /// A new . + public static ReportedDiagnostic Create(DiagnosticDescriptor descriptor, Location location, string trivia) + { + return new(descriptor, location.SourceTree?.FilePath ?? string.Empty, location.SourceSpan, location.GetLineSpan().Span, trivia); + } +} diff --git a/src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs b/src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs index 2d7489e..543f0a7 100644 --- a/src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs +++ b/src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs @@ -61,9 +61,9 @@ internal static string GenerateExtensionClass(MethodToGenerate methodToGenerate) { var indent = new string(' ', 4 * i); - var modifiers = string.Join(string.Empty, @class.Modifiers.Select(z => GetKeyword(z) + " ")); - var classDeclarationLine = $"{modifiers}partial class {@class.ClassName}{(@class.TypeParameterListSyntax is null ? string.Empty - : "<" + string.Join(", ", @class.TypeParameterListSyntax.Parameters.Select(z => z.ToString())) + ">")}"; + var modifiers = string.Join(string.Empty, @class.Modifiers.Select(z => GetKeyword((SyntaxKind)z) + " ")); + var classDeclarationLine = $"{modifiers}partial class {@class.ClassName}{(@class.TypeParameterListSyntax.IsEmpty ? string.Empty + : "<" + string.Join(", ", @class.TypeParameterListSyntax) + ">")}"; sbBegin.Append($$""" {{indent}}{{classDeclarationLine}} diff --git a/src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs b/src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs index 666bec5..0b9f8a0 100644 --- a/src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs +++ b/src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs @@ -12,6 +12,8 @@ public class SyncMethodSourceGenerator : IIncrementalGenerator public const string CreateSyncVersionAttribute = "CreateSyncVersionAttribute"; internal const string QualifiedCreateSyncVersionAttribute = $"{ThisAssembly.RootNamespace}.{CreateSyncVersionAttribute}"; + private static MethodToGenerate? last; + /// public void Initialize(IncrementalGeneratorInitializationContext context) { @@ -25,162 +27,161 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterPostInitializationOutput(ctx => ctx.AddSource( $"{CreateSyncVersionAttribute}.g.cs", SourceText.From(SourceGenerationHelper.CreateSyncVersionAttributeSource, Encoding.UTF8))); - IncrementalValuesProvider methodDeclarations = context.SyntaxProvider + var disableNullable = + context.CompilationProvider.Select((c, _) + => c is CSharpCompilation { LanguageVersion: < LanguageVersion.CSharp8 }); + + var methodDeclarations = context.SyntaxProvider .ForAttributeWithMetadataName( QualifiedCreateSyncVersionAttribute, predicate: static (s, _) => IsSyntaxTargetForGeneration(s), - transform: static (ctx, _) => (MethodDeclarationSyntax)ctx.TargetNode); + transform: static (ctx, ct) => ctx) + .Combine(disableNullable) + .Select((data, ct) => GetMethodToGenerate(data.Left, (MethodDeclarationSyntax)data.Left.TargetNode, data.Right, ct)!) + .WithTrackingName("GetMethodToGenerate") + .Where(static s => s is not null); - IncrementalValueProvider<(Compilation, ImmutableArray)> compilationAndMethods - = context.CompilationProvider.Combine(methodDeclarations.Collect()); + var sourceTexts = methodDeclarations + .Select(static (m, _) => GenerateSource(m)) + .WithTrackingName("GenerateSource"); context.RegisterSourceOutput( - compilationAndMethods, - static (spc, source) => Execute(source.Item1, source.Item2, spc)); + sourceTexts, + static (spc, source) => + { + foreach (var diagnostic in source.MethodToGenerate.Diagnostics) + { + spc.ReportDiagnostic(diagnostic); + } + + if (!source.MethodToGenerate.HasErrors) + { + spc.AddSource(source.Path, SourceText.From(source.Content, Encoding.UTF8)); + } + }); } private static bool IsSyntaxTargetForGeneration(SyntaxNode node) - => node is MethodDeclarationSyntax m && m.AttributeLists.Count > 0; + => node is MethodDeclarationSyntax { AttributeLists.Count: > 0 }; - private static void Execute(Compilation compilation, ImmutableArray methods, SourceProductionContext context) + private static (MethodToGenerate MethodToGenerate, string Path, string Content) GenerateSource(MethodToGenerate m) { - if (methods.IsDefaultOrEmpty) + var sourcePath = $"{string.Join(".", m.Namespaces)}" + + $".{string.Join(".", m.Classes.Select(c => c.ClassName))}" + + $".{m.MethodName + (m.Index == 1 ? string.Empty : "_" + m.Index)}.g.cs"; + + var source = SourceGenerationHelper.GenerateExtensionClass(m); + + return (m, sourcePath, source); + } + + private static MethodToGenerate? GetMethodToGenerate(GeneratorAttributeSyntaxContext context, MethodDeclarationSyntax methodDeclarationSyntax, bool disableNullable, CancellationToken ct) + { + // stop if we're asked to + ct.ThrowIfCancellationRequested(); + + if (context.TargetSymbol is not IMethodSymbol methodSymbol) { - // nothing to do yet - return; + // the attribute isn't on a method + return null; } - // I'm not sure if this is actually necessary, but `[LoggerMessage]` does it, so seems like a good idea! - IEnumerable distinctMethods = methods.Distinct(); + INamedTypeSymbol? attribute = context.SemanticModel.Compilation.GetTypeByMetadataName(QualifiedCreateSyncVersionAttribute); + if (attribute == null) + { + // nothing to do if this type isn't available + return null; + } - // Convert each MethodDeclarationSyntax to an MethodToGenerate - List methodsToGenerate = GetTypesToGenerate(context, compilation, distinctMethods, context.CancellationToken); + // find the index of the method in the containing type + var index = 1; - // If there were errors in the MethodDeclarationSyntax, we won't create an - // MethodToGenerate for it, so make sure we have something to generate - if (methodsToGenerate.Count > 0) + if (methodSymbol.ContainingType is { } containingType) { - // Generate the source code and add it to the output - var sourceDictionary = new Dictionary(); - foreach (var m in methodsToGenerate) + foreach (var member in containingType.GetMembers()) { - // Ensure there are no collisions in generated names - var i = 1; - while (true) + if (member.Equals(methodSymbol, SymbolEqualityComparer.Default)) { - var sourcePath = $"{string.Join(".", m.Namespaces)}" + - $".{string.Join(".", m.Classes.Select(c => c.ClassName))}" + - $".{m.MethodName + (i == 1 ? string.Empty : "_" + i)}.g.cs"; - - if (!sourceDictionary.ContainsKey(sourcePath)) - { - var source = SourceGenerationHelper.GenerateExtensionClass(m); - sourceDictionary.Add(sourcePath, source); - break; - } - - ++i; + break; } - } - foreach (var entry in sourceDictionary) - { - context.AddSource(entry.Key, SourceText.From(entry.Value, Encoding.UTF8)); + if (member.Name.Equals(methodSymbol.Name, StringComparison.Ordinal)) + { + ++index; + } } } - } - private static List GetTypesToGenerate(SourceProductionContext context, Compilation compilation, IEnumerable methodDeclarations, CancellationToken ct) - { - var methodsToGenerate = new List(); - INamedTypeSymbol? attribute = compilation.GetTypeByMetadataName(QualifiedCreateSyncVersionAttribute); - if (attribute == null) + if (!methodSymbol.IsAsync && !AsyncToSyncRewriter.IsTypeOfInterest(methodSymbol.ReturnType)) { - // nothing to do if this type isn't available - return methodsToGenerate; + return null; } - foreach (var methodDeclarationSyntax in methodDeclarations) + foreach (AttributeData attributeData in methodSymbol.GetAttributes()) { - // stop if we're asked to - ct.ThrowIfCancellationRequested(); - - SemanticModel semanticModel = compilation.GetSemanticModel(methodDeclarationSyntax.SyntaxTree); - - if (semanticModel.GetDeclaredSymbol(methodDeclarationSyntax, cancellationToken: ct) is not IMethodSymbol methodSymbol) + if (!attribute.Equals(attributeData.AttributeClass, SymbolEqualityComparer.Default)) { - // something went wrong continue; } - if (!methodSymbol.IsAsync && !AsyncToSyncRewriter.IsTypeOfInterest(methodSymbol.ReturnType)) + break; + } + + var classes = ImmutableArray.CreateBuilder(); + SyntaxNode? node = methodDeclarationSyntax; + while (node.Parent is not null) + { + node = node.Parent; + if (node is not ClassDeclarationSyntax classSyntax) { - continue; + break; } - var methodName = methodSymbol.ToString(); + var modifiers = ImmutableArray.CreateBuilder(); - foreach (AttributeData attributeData in methodSymbol.GetAttributes()) + foreach (var mod in classSyntax.Modifiers) { - if (!attribute.Equals(attributeData.AttributeClass, SymbolEqualityComparer.Default)) + var kind = mod.RawKind; + if (kind == (int)SyntaxKind.PartialKeyword) { continue; } - break; + modifiers.Add((ushort)kind); } - var classes = new List(); - SyntaxNode? node = methodDeclarationSyntax; - while (node.Parent is not null) - { - node = node.Parent; - if (node is not ClassDeclarationSyntax classSyntax) - { - break; - } - - var modifiers = new List(); - - foreach (var mod in classSyntax.Modifiers) - { - var kind = mod.RawKind; - if (kind == (int)SyntaxKind.PartialKeyword) - { - continue; - } - - modifiers.Add((SyntaxKind)kind); - } - - classes.Insert(0, new(classSyntax.Identifier.ValueText, modifiers, classSyntax.TypeParameterList)); - } + var typeParameters = ImmutableArray.CreateBuilder(); - if (classes.Count == 0) + foreach (var typeParameter in classSyntax.TypeParameterList?.Parameters ?? default) { - continue; + typeParameters.Add(typeParameter.Identifier.ValueText); } - var rewriter = new AsyncToSyncRewriter(semanticModel); - var sn = rewriter.Visit(methodDeclarationSyntax); - var content = sn.ToFullString(); + classes.Insert(0, new(classSyntax.Identifier.ValueText, modifiers.ToImmutable(), typeParameters.ToImmutable())); + } - var diagnostics = rewriter.Diagnostics; + if (classes.Count == 0) + { + return null; + } - var hasErrors = false; - foreach (var diagnostic in diagnostics) - { - context.ReportDiagnostic(diagnostic); - hasErrors |= diagnostic.Severity == DiagnosticSeverity.Error; - } + var rewriter = new AsyncToSyncRewriter(context.SemanticModel); + var sn = rewriter.Visit(methodDeclarationSyntax); + var content = sn.ToFullString(); - if (hasErrors) - { - continue; - } + var diagnostics = rewriter.Diagnostics; + + var hasErrors = false; + foreach (var diagnostic in diagnostics) + { + hasErrors |= diagnostic.Descriptor.DefaultSeverity == DiagnosticSeverity.Error; + } - var isNamespaceFileScoped = false; - var namespaces = new List(); + var isNamespaceFileScoped = false; + var namespaces = ImmutableArray.CreateBuilder(); + + if (!hasErrors) + { while (node is not null && node is not CompilationUnitSyntax) { switch (node) @@ -198,11 +199,11 @@ private static List GetTypesToGenerate(SourceProductionContext node = node.Parent; } - - var disableNullable = compilation is CSharpCompilation { LanguageVersion: < LanguageVersion.CSharp8 }; - methodsToGenerate.Add(new(namespaces, isNamespaceFileScoped, classes, methodDeclarationSyntax.Identifier.ValueText, content, disableNullable)); } - return methodsToGenerate; + var result = new MethodToGenerate(index, namespaces.ToImmutable(), isNamespaceFileScoped, classes.ToImmutable(), methodDeclarationSyntax.Identifier.ValueText, content, disableNullable, rewriter.Diagnostics, hasErrors); + + last = result; + return result; } } diff --git a/src/Zomp.SyncMethodGenerator/Zomp.SyncMethodGenerator.csproj b/src/Zomp.SyncMethodGenerator/Zomp.SyncMethodGenerator.csproj index a7c2bbc..e0c6791 100644 --- a/src/Zomp.SyncMethodGenerator/Zomp.SyncMethodGenerator.csproj +++ b/src/Zomp.SyncMethodGenerator/Zomp.SyncMethodGenerator.csproj @@ -23,6 +23,9 @@ + + + @@ -43,6 +46,7 @@ + @@ -58,6 +62,8 @@ + + diff --git a/tests/Generator.Tests/IncrementalGeneratorTests.cs b/tests/Generator.Tests/IncrementalGeneratorTests.cs new file mode 100644 index 0000000..04adb43 --- /dev/null +++ b/tests/Generator.Tests/IncrementalGeneratorTests.cs @@ -0,0 +1,116 @@ +using System.Reflection; +using Zomp.SyncMethodGenerator; + +namespace Generator.Tests; + +public class IncrementalGeneratorTests +{ + [Theory] + [InlineData( + IncrementalStepRunReason.Cached, + IncrementalStepRunReason.Unchanged, + IncrementalStepRunReason.Cached, + """ + using System; + using System.Threading.Tasks; + + class Test + { + public void ProgressMethod() { } + public Task ProgressMethodAsync() => Task.CompletedTask; + + [Zomp.SyncMethodGenerator.CreateSyncVersion] + public async Task CallProgressMethodAsync() + { + await ProgressMethodAsync(); + } + } + """, + """ + using System; + using System.Threading.Tasks; + + class Test + { + public void ProgressMethod() { } + public Task ProgressMethodAsync() => Task.Yield(); + + [Zomp.SyncMethodGenerator.CreateSyncVersion] + public async Task CallProgressMethodAsync() + { + await ProgressMethodAsync(); + } + } + """)] + [InlineData( + IncrementalStepRunReason.Modified, + IncrementalStepRunReason.Modified, + IncrementalStepRunReason.Modified, + """ + using System; + using System.Threading.Tasks; + + class Test + { + public void ProgressMethod() { } + public Task ProgressMethodAsync() => Task.CompletedTask; + + [Zomp.SyncMethodGenerator.CreateSyncVersion] + public async Task CallProgressMethodAsync() + { + } + } + """, + """ + using System; + using System.Threading.Tasks; + + class Test + { + public void ProgressMethod() { } + public Task ProgressMethodAsync() => Task.CompletedTask; + + [Zomp.SyncMethodGenerator.CreateSyncVersion] + public async Task CallProgressMethodAsync() + { + await ProgressMethodAsync(); + } + } + """)] + public void CheckGeneratorIsIncremental( + IncrementalStepRunReason sourceStepReason, + IncrementalStepRunReason executeStepReason, + IncrementalStepRunReason combineStepReason, + string source, + string sourceUpdated) + { + SyntaxTree baseSyntaxTree = CSharpSyntaxTree.ParseText(source); + + Compilation compilation = CSharpCompilation.Create( + "compilation", + new[] { baseSyntaxTree }, + new[] { MetadataReference.CreateFromFile(typeof(Binder).GetTypeInfo().Assembly.Location) }, + new CSharpCompilationOptions(OutputKind.ConsoleApplication)); + + ISourceGenerator sourceGenerator = new SyncMethodSourceGenerator().AsSourceGenerator(); + + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: new[] { sourceGenerator }, + driverOptions: new GeneratorDriverOptions(default, trackIncrementalGeneratorSteps: true)); + + // Run the generator + driver = driver.RunGenerators(compilation); + + // Update the compilation and rerun the generator + compilation = compilation.ReplaceSyntaxTree(baseSyntaxTree, CSharpSyntaxTree.ParseText(sourceUpdated)); + driver = driver.RunGenerators(compilation); + + GeneratorRunResult result = driver.GetRunResult().Results.Single(); + IEnumerable<(object Value, IncrementalStepRunReason Reason)> sourceOutputs = + result.TrackedOutputSteps.SelectMany(outputStep => outputStep.Value).SelectMany(output => output.Outputs); + var (value, reason) = Assert.Single(sourceOutputs); + Assert.Equal(sourceStepReason, reason); + Assert.Equal(executeStepReason, result.TrackedSteps["GetMethodToGenerate"].Single().Outputs[0].Reason); + Assert.Equal(combineStepReason, result.TrackedSteps["GenerateSource"].Single().Outputs[0].Reason); + } +}