Skip to content

Commit

Permalink
Generate sync classes
Browse files Browse the repository at this point in the history
  • Loading branch information
GerardSmit committed Mar 30, 2024
1 parent a3d7afb commit bdd854f
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 43 deletions.
7 changes: 6 additions & 1 deletion src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@ namespace Zomp.SyncMethodGenerator
/// <summary>
/// An attribute that can be used to automatically generate a synchronous version of an async method. Must be used in a partial class.
/// </summary>
[System.AttributeUsage(System.AttributeTargets.Method)]
[System.AttributeUsage(System.AttributeTargets.Method | System.AttributeTargets.Class, Inherited = false, AllowMultiple = false)]
internal class {{SyncMethodSourceGenerator.CreateSyncVersionAttribute}} : System.Attribute
{
/// <summary>
/// Gets or sets a value indicating whether "#nullable enable" directive will be omitted from generated code. False by default.
/// </summary>
public bool {{SyncMethodSourceGenerator.OmitNullableDirective}} { get; set; }
/// <summary>
/// Gets or sets the name of the generated method or class. If not set, the name will be the same as the original method or class.
/// </summary>
public string {{SyncMethodSourceGenerator.Name}} { get; set; }
}
#endif
}
Expand Down
170 changes: 129 additions & 41 deletions src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ public class SyncMethodSourceGenerator : IIncrementalGenerator
internal const string QualifiedCreateSyncVersionAttribute = $"{ThisAssembly.RootNamespace}.{CreateSyncVersionAttribute}";

internal const string OmitNullableDirective = "OmitNullableDirective";

private static MethodToGenerate? last;
internal const string Name = "Name";

/// <inheritdoc/>
public void Initialize(IncrementalGeneratorInitializationContext context)
Expand All @@ -33,22 +32,38 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
context.CompilationProvider.Select((c, _)
=> c is CSharpCompilation { LanguageVersion: < LanguageVersion.CSharp8 });

var methodDeclarations = context.SyntaxProvider
var methodSourceTexts = context.SyntaxProvider
.ForAttributeWithMetadataName(
QualifiedCreateSyncVersionAttribute,
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
transform: static (ctx, ct) => ctx)
predicate: static (s, _) => s is MethodDeclarationSyntax { AttributeLists.Count: > 0 },
transform: static (ctx, _) => ctx)
.Combine(disableNullable)
.Select((data, ct) => GetMethodToGenerate(data.Left, (MethodDeclarationSyntax)data.Left.TargetNode, data.Right, ct)!)
.WithTrackingName("GetMethodToGenerate")
.Where(static s => s is not null);
.Where(static s => s is not null)
.Select(static (m, _) => GenerateSource(m))
.WithTrackingName("GenerateMethodSource");

AddSourceTexts(context, methodSourceTexts);

var sourceTexts = methodDeclarations
var classSourceTexts = context.SyntaxProvider
.ForAttributeWithMetadataName(
QualifiedCreateSyncVersionAttribute,
predicate: static (s, _) => s is TypeDeclarationSyntax { AttributeLists.Count: > 0 },
transform: static (ctx, _) => ctx)
.Combine(disableNullable)
.SelectMany((data, ct) => GetClassToGenerate(data.Left, (TypeDeclarationSyntax)data.Left.TargetNode, data.Right, ct)!)
.WithTrackingName("GetClassToGenerate")
.Select(static (m, _) => GenerateSource(m))
.WithTrackingName("GenerateSource");
.WithTrackingName("GenerateClassSource");

AddSourceTexts(context, classSourceTexts);
}

private static void AddSourceTexts(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<(MethodToGenerate MethodToGenerate, string Path, string Content)> classSourceTexts)
{
context.RegisterSourceOutput(
sourceTexts,
classSourceTexts,
static (spc, source) =>
{
foreach (var diagnostic in source.MethodToGenerate.Diagnostics)
Expand All @@ -63,9 +78,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
});
}

private static bool IsSyntaxTargetForGeneration(SyntaxNode node)
=> node is MethodDeclarationSyntax { AttributeLists.Count: > 0 };

private static (MethodToGenerate MethodToGenerate, string Path, string Content) GenerateSource(MethodToGenerate m)
{
static string BuildClassName(ClassDeclaration c)
Expand All @@ -87,15 +99,81 @@ static string BuildClassName(ClassDeclaration c)
return (m, sourcePath, source);
}

private static MethodToGenerate? GetMethodToGenerate(GeneratorAttributeSyntaxContext context, MethodDeclarationSyntax methodDeclarationSyntax, bool disableNullable, CancellationToken ct)
private static ImmutableArray<MethodToGenerate> GetClassToGenerate(GeneratorAttributeSyntaxContext context, TypeDeclarationSyntax typeDeclartionSyntax, bool disableNullable, CancellationToken ct)
{
if (context.TargetSymbol is not ITypeSymbol typeSymbol)
{
// the attribute isn't on a method
return default;
}

INamedTypeSymbol? attribute = context.SemanticModel.Compilation.GetTypeByMetadataName(QualifiedCreateSyncVersionAttribute);
if (attribute == null)
{
// nothing to do if this type isn't available
return default;
}

AttributeData syncMethodGeneratorAttributeData = null!;

foreach (AttributeData attributeData in typeSymbol.GetAttributes())
{
if (!attribute.Equals(attributeData.AttributeClass, SymbolEqualityComparer.Default))
{
continue;
}

syncMethodGeneratorAttributeData = attributeData;
break;
}

var className = syncMethodGeneratorAttributeData.NamedArguments.FirstOrDefault(c => c.Key == Name).Value.Value as string;
var array = ImmutableArray.CreateBuilder<MethodToGenerate>();

foreach (var member in typeDeclartionSyntax.Members)
{
if (member is not MethodDeclarationSyntax method)
{
continue;
}

var m = GetMethodToGenerate(context, method, disableNullable, ct, attributeOptional: true);

if (m is null)
{
continue;
}

if (className != null)
{
m = m with
{
Classes = ImmutableArray.Create(CreateClassDeclaration(typeDeclartionSyntax, className)),
};
}

array.Add(m);
}

return array.ToImmutable();
}

private static MethodToGenerate? GetMethodToGenerate(GeneratorAttributeSyntaxContext context, MethodDeclarationSyntax methodDeclarationSyntax, bool disableNullable, CancellationToken ct, bool attributeOptional = false)
{
// stop if we're asked to
ct.ThrowIfCancellationRequested();

if (context.TargetSymbol is not IMethodSymbol methodSymbol)
{
// the attribute isn't on a method
return null;
var symbolFromModel = context.SemanticModel.GetDeclaredSymbol(methodDeclarationSyntax);

if (symbolFromModel is null)
{
// the attribute isn't on a method
return null;
}

methodSymbol = symbolFromModel;
}

INamedTypeSymbol? attribute = context.SemanticModel.Compilation.GetTypeByMetadataName(QualifiedCreateSyncVersionAttribute);
Expand Down Expand Up @@ -129,7 +207,10 @@ static string BuildClassName(ClassDeclaration c)
return null;
}

AttributeData syncMethodGeneratorAttributeData = null!;
string? name = null;
var explicitDisableNullable = false;

AttributeData? syncMethodGeneratorAttributeData = null;

foreach (AttributeData attributeData in methodSymbol.GetAttributes())
{
Expand All @@ -142,7 +223,12 @@ static string BuildClassName(ClassDeclaration c)
break;
}

var explicitDisableNullable = syncMethodGeneratorAttributeData.NamedArguments.FirstOrDefault(c => c.Key == OmitNullableDirective) is { Value.Value: true };
if (syncMethodGeneratorAttributeData != null)
{
explicitDisableNullable = syncMethodGeneratorAttributeData.NamedArguments.FirstOrDefault(c => c.Key == OmitNullableDirective) is { Value.Value: true };
name = syncMethodGeneratorAttributeData.NamedArguments.FirstOrDefault(c => c.Key == Name).Value.Value as string;
}

disableNullable |= explicitDisableNullable;

var classes = ImmutableArray.CreateBuilder<ClassDeclaration>();
Expand All @@ -155,27 +241,7 @@ static string BuildClassName(ClassDeclaration c)
break;
}

var modifiers = ImmutableArray.CreateBuilder<ushort>();

foreach (var mod in classSyntax.Modifiers)
{
var kind = mod.RawKind;
if (kind == (int)SyntaxKind.PartialKeyword)
{
continue;
}

modifiers.Add((ushort)kind);
}

var typeParameters = ImmutableArray.CreateBuilder<string>();

foreach (var typeParameter in classSyntax.TypeParameterList?.Parameters ?? default)
{
typeParameters.Add(typeParameter.Identifier.ValueText);
}

classes.Insert(0, new(classSyntax.Identifier.ValueText, modifiers.ToImmutable(), typeParameters.ToImmutable()));
classes.Insert(0, CreateClassDeclaration(classSyntax));
}

if (classes.Count == 0)
Expand Down Expand Up @@ -219,9 +285,31 @@ static string BuildClassName(ClassDeclaration c)
}
}

var result = new MethodToGenerate(index, namespaces.ToImmutable(), isNamespaceFileScoped, classes.ToImmutable(), methodDeclarationSyntax.Identifier.ValueText, content, disableNullable, rewriter.Diagnostics, hasErrors);
return new MethodToGenerate(index, namespaces.ToImmutable(), isNamespaceFileScoped, classes.ToImmutable(), name ?? methodDeclarationSyntax.Identifier.ValueText, content, disableNullable, rewriter.Diagnostics, hasErrors);
}

private static ClassDeclaration CreateClassDeclaration(TypeDeclarationSyntax classSyntax, string? name = null)
{
var modifiers = ImmutableArray.CreateBuilder<ushort>();

foreach (var mod in classSyntax.Modifiers)
{
var kind = mod.RawKind;
if (kind == (int)SyntaxKind.PartialKeyword)
{
continue;
}

modifiers.Add((ushort)kind);
}

var typeParameters = ImmutableArray.CreateBuilder<string>();

foreach (var typeParameter in classSyntax.TypeParameterList?.Parameters ?? default)
{
typeParameters.Add(typeParameter.Identifier.ValueText);
}

last = result;
return result;
return new ClassDeclaration(name ?? classSyntax.Identifier.ValueText, modifiers.ToImmutable(), typeParameters.ToImmutable());
}
}
42 changes: 42 additions & 0 deletions tests/Generator.Tests/ClassTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
namespace Generator.Tests;

public class ClassTests
{
[Fact]
public Task ConvertClass() => $$"""
namespace Test;
interface IAsyncFileSystem
{
Task<string> ReadAllTextAsync(string path);
Task WriteAllTextAsync(string path, string contents);
}
interface IFileSystem : IAsyncFileSystem
{
string ReadAllText(string path);
void WriteAllText(string path, string contents);
}
[Zomp.SyncMethodGenerator.CreateSyncVersion(Name = "FileSystem")]
internal class AsyncFileSystem : IAsyncFileSystem
{
public virtual Task<string> ReadAllTextAsync(string path) => Task.FromResult("");
public virtual Task WriteAllTextAsync(string path, string contents) => Task.CompletedTask;
}
internal partial class FileSystem : AsyncFileSystem, IFileSystem
{
public override Task<string> ReadAllTextAsync(string path) => Task.FromResult(ReadAllText(path));
public override Task WriteAllTextAsync(string path, string contents)
{
WriteAllText(path, contents);
return Task.CompletedTask;
}
}
""".Verify(false, true);
}
2 changes: 1 addition & 1 deletion tests/Generator.Tests/IncrementalGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,6 @@ public void CheckGeneratorIsIncremental(
var (value, reason) = Assert.Single(sourceOutputs);
Assert.Equal(sourceStepReason, reason);
Assert.Equal(executeStepReason, result.TrackedSteps["GetMethodToGenerate"].Single().Outputs[0].Reason);
Assert.Equal(combineStepReason, result.TrackedSteps["GenerateSource"].Single().Outputs[0].Reason);
Assert.Equal(combineStepReason, result.TrackedSteps["GenerateMethodSource"].Single().Outputs[0].Reason);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//HintName: Test.Test.FileSystem.ReadAllTextAsync.g.cs
// <auto-generated/>
#nullable enable
namespace Test;
internal partial class FileSystem
{
public virtual string ReadAllText(string path) => "";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//HintName: Test.Test.FileSystem.WriteAllTextAsync.g.cs
// <auto-generated/>
#nullable enable
namespace Test;
internal partial class FileSystem
{
public virtual void WriteAllText(string path, string contents) { }
}

0 comments on commit bdd854f

Please sign in to comment.