Skip to content

Commit

Permalink
Add first draft of real incremental generator
Browse files Browse the repository at this point in the history
  • Loading branch information
k94ll13nn3 committed Oct 5, 2023
1 parent b6d4d3d commit bca9320
Show file tree
Hide file tree
Showing 11 changed files with 847 additions and 86 deletions.
2 changes: 1 addition & 1 deletion global.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
"rollForward": "latestMajor",
"allowPrerelease": false
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
<RepositoryType>git</RepositoryType>
<RepositoryUrl>https://github.com/k94ll13nn3/AutoConstructor</RepositoryUrl>
<PackageReleaseNotes>https://github.com/k94ll13nn3/AutoConstructor/blob/main/CHANGELOG.md</PackageReleaseNotes>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
Expand All @@ -37,9 +38,9 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="[4.0.1]" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="[4.6.0]" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="[4.0.1]" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="[4.6.0]" />
</ItemGroup>

<ItemGroup>
Expand Down
131 changes: 69 additions & 62 deletions src/AutoConstructor.Generator/AutoConstructorGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ namespace AutoConstructor.Generator;
[Generator]
public class AutoConstructorGenerator : IIncrementalGenerator
{
private static int _counter;

public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Register the attribute source
Expand All @@ -23,38 +25,38 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
i.AddSource(Source.InjectAttributeFullName, SourceText.From(Source.InjectAttributeText, Encoding.UTF8));
});

IncrementalValuesProvider<ClassDeclarationSyntax> classDeclarations = context.SyntaxProvider
.CreateSyntaxProvider(static (s, _) => IsSyntaxTargetForGeneration(s), static (ctx, _) => GetSemanticTargetForGeneration(ctx))
.Where(static m => m is not null)!;

IncrementalValueProvider<(Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes, AnalyzerConfigOptions options)> valueProvider =
context.CompilationProvider
.Combine(classDeclarations.Collect())
IncrementalValuesProvider<(MainNamedTypeSymbolInfo symbol, EquatableArray<FieldInfo> fields, Options options)> valueProvider = context.SyntaxProvider
.ForAttributeWithMetadataName(
Source.AttributeFullName,
static (node, _) => IsSyntaxTargetForGeneration(node),
static (context, _) => (ClassDeclarationSyntax)context.TargetNode)
.Where(static m => m is not null)
.Collect()
.Combine(context.AnalyzerConfigOptionsProvider.Select((c, _) => c.GlobalOptions))
.Select((c, _) => (compilation: c.Left.Left, classes: c.Left.Right, options: c.Right));
.Combine(context.CompilationProvider)
.SelectMany((c, _) =>
{
(ImmutableArray<ClassDeclarationSyntax> classes, AnalyzerConfigOptions options, Compilation compilation) = (c.Left.Left, c.Left.Right, c.Right);
return Execute(compilation, classes, new(), options);
});

context.RegisterSourceOutput(valueProvider, static (spc, source) => Execute(source.compilation, source.classes, spc, source.options));
context.RegisterSourceOutput(valueProvider, static (context, item) =>
{
CompilationUnitSyntax compilationUnit = GenerateAutoConstructor(item.symbol, item.fields, item.options);
context.AddSource($"{item.symbol.Filename}.g.cs", compilationUnit.GetText(Encoding.UTF8));
});
}

private static bool IsSyntaxTargetForGeneration(SyntaxNode node)
{
return node is ClassDeclarationSyntax classDeclarationSyntax && classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword);
}

private static ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
{
var classDeclarationSyntax = (ClassDeclarationSyntax)context.Node;

INamedTypeSymbol? symbol = context.SemanticModel.GetDeclaredSymbol(classDeclarationSyntax);

return symbol?.HasAttribute(Source.AttributeFullName, context.SemanticModel.Compilation) is true ? classDeclarationSyntax : null;
}

private static void Execute(Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes, SourceProductionContext context, AnalyzerConfigOptions options)
private static IEnumerable<(MainNamedTypeSymbolInfo symbol, EquatableArray<FieldInfo> fields, Options options)> Execute(Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes, SourceProductionContext context, AnalyzerConfigOptions analyzerOptions)
{
if (classes.IsDefaultOrEmpty)
{
return;
yield break;
}

IEnumerable<IGrouping<ISymbol?, ClassDeclarationSyntax>> classesBySymbol = Enumerable.Empty<IGrouping<ISymbol?, ClassDeclarationSyntax>>();
Expand All @@ -64,9 +66,25 @@ private static void Execute(Compilation compilation, ImmutableArray<ClassDeclara
}
catch (ArgumentException)
{
return;
yield break;
}

bool generateConstructorDocumentation = false;
if (analyzerOptions.TryGetValue("build_property.AutoConstructor_GenerateConstructorDocumentation", out string? generateConstructorDocumentationSwitch))
{
generateConstructorDocumentation = generateConstructorDocumentationSwitch.Equals("true", StringComparison.OrdinalIgnoreCase);
}

analyzerOptions.TryGetValue("build_property.AutoConstructor_ConstructorDocumentationComment", out string? constructorDocumentationComment);

bool emitNullChecks = false;
if (analyzerOptions.TryGetValue("build_property.AutoConstructor_DisableNullChecking", out string? disableNullCheckingSwitch))
{
emitNullChecks = disableNullCheckingSwitch.Equals("false", StringComparison.OrdinalIgnoreCase);
}

Options options = new(generateConstructorDocumentation, constructorDocumentationComment, emitNullChecks);

foreach (IGrouping<ISymbol?, ClassDeclarationSyntax> groupedClasses in classesBySymbol)
{
if (context.CancellationToken.IsCancellationRequested)
Expand Down Expand Up @@ -98,27 +116,21 @@ private static void Execute(Compilation compilation, ImmutableArray<ClassDeclara

filename += ".g.cs";

bool emitNullChecks = false;
if (options.TryGetValue("build_property.AutoConstructor_DisableNullChecking", out string? disableNullCheckingSwitch))
{
emitNullChecks = disableNullCheckingSwitch.Equals("false", StringComparison.OrdinalIgnoreCase);
}

List<FieldInfo> concatenatedFields = GetFieldsFromSymbol(compilation, symbol, emitNullChecks);

ExtractFieldsFromParent(compilation, symbol, emitNullChecks, concatenatedFields);

FieldInfo[] fields = concatenatedFields.ToArray();
EquatableArray<FieldInfo> fields = concatenatedFields.ToImmutableArray();

if (fields.Length == 0)
if (fields.IsEmpty)
{
// No need to report diagnostic, taken care by the analyzers.
continue;
}

if (fields.GroupBy(x => x.ParameterName).Any(g =>
g.Where(c => c.Type is not null).Select(c => c.Type).Distinct(SymbolEqualityComparer.Default).Count() > 1
|| (g.All(c => c.Type is null) && g.Select(c => c.FallbackType).Distinct(SymbolEqualityComparer.Default).Count() > 1)
g.Where(c => c.Type is not null).Select(c => c.Type).Count() > 1
|| (g.All(c => c.Type is null) && g.Select(c => c.FallbackType).Count() > 1)
))
{
foreach (ClassDeclarationSyntax classDeclaration in groupedClasses)
Expand All @@ -137,28 +149,23 @@ private static void Execute(Compilation compilation, ImmutableArray<ClassDeclara
.Count() == 1
&& symbol.Constructors.Any(d => !d.IsStatic && d.Parameters.Length == 0);

context.AddSource(filename, SourceText.From(GenerateAutoConstructor(symbol, fields, options, hasParameterlessConstructor), Encoding.UTF8));
yield return (new MainNamedTypeSymbolInfo(symbol, hasParameterlessConstructor, filename), fields, options);
}
}
}

private static string GenerateAutoConstructor(INamedTypeSymbol symbol, FieldInfo[] fields, AnalyzerConfigOptions options, bool hasParameterlessConstructor)
private static CompilationUnitSyntax GenerateAutoConstructor(MainNamedTypeSymbolInfo symbol, EquatableArray<FieldInfo> fields, Options options)
{
bool generateConstructorDocumentation = false;
if (options.TryGetValue("build_property.AutoConstructor_GenerateConstructorDocumentation", out string? generateConstructorDocumentationSwitch))
{
generateConstructorDocumentation = generateConstructorDocumentationSwitch.Equals("true", StringComparison.OrdinalIgnoreCase);
}

options.TryGetValue("build_property.AutoConstructor_ConstructorDocumentationComment", out string? constructorDocumentationComment);
bool generateConstructorDocumentation = options.GenerateConstructorDocumentation;
string? constructorDocumentationComment = options.ConstructorDocumentationComment;
if (string.IsNullOrWhiteSpace(constructorDocumentationComment))
{
constructorDocumentationComment = "Initializes a new instance of the {0} class.";
constructorDocumentationComment = $"Initializes a new instance of the {{0}} class. // Counter: {Interlocked.Increment(ref _counter)}";
}

var codeGenerator = new CodeGenerator();

if (Array.Exists(fields, f => f.Nullable))
if (fields.Any(f => f.Nullable))
{
codeGenerator.AddNullableAnnotation();
}
Expand All @@ -168,21 +175,21 @@ private static string GenerateAutoConstructor(INamedTypeSymbol symbol, FieldInfo
codeGenerator.AddDocumentation(string.Format(CultureInfo.InvariantCulture, constructorDocumentationComment, symbol.Name));
}

if (!symbol.ContainingNamespace.IsGlobalNamespace)
if (!symbol.IsGlobalNamespace)
{
codeGenerator.AddNamespace(symbol.ContainingNamespace);
}

foreach (INamedTypeSymbol containingType in symbol.GetContainingTypes())
foreach (NamedTypeSymbolInfo containingType in symbol.ContainingTypes)
{
codeGenerator.AddClass(containingType);
}

codeGenerator
.AddClass(symbol)
.AddConstructor(fields, hasParameterlessConstructor);
.AddConstructor(fields, symbol.HasParameterlessConstructor);

return codeGenerator.ToString();
return codeGenerator.GetCompilationUnit();
}

private static List<FieldInfo> GetFieldsFromSymbol(Compilation compilation, INamedTypeSymbol symbol, bool emitNullChecks)
Expand Down Expand Up @@ -243,28 +250,17 @@ private static FieldInfo GetFieldInfo(IFieldSymbol fieldSymbol, Compilation comp
}

return new FieldInfo(
injectedType,
injectedType?.ToDisplayString(),
parameterName,
fieldSymbol.AssociatedSymbol?.Name ?? fieldSymbol.Name,
initializer,
type,
type.ToDisplayString(),
IsNullable(type),
summaryText,
type.IsReferenceType && type.NullableAnnotation != NullableAnnotation.Annotated && emitNullChecks,
FieldType.Initialized);
}

private static bool IsNullable(ITypeSymbol typeSymbol)
{
bool isNullable = typeSymbol.IsReferenceType && typeSymbol.NullableAnnotation == NullableAnnotation.Annotated;
if (typeSymbol is INamedTypeSymbol namedSymbol)
{
isNullable |= namedSymbol.TypeArguments.Any(IsNullable);
}

return isNullable;
}

private static T? GetParameterValue<T>(string parameterName, ImmutableArray<IParameterSymbol> parameters, ImmutableArray<TypedConstant> arguments)
where T : class
{
Expand Down Expand Up @@ -304,11 +300,11 @@ private static void ExtractFieldsFromConstructedParent(List<FieldInfo> concatena
else
{
concatenatedFields.Add(new FieldInfo(
parameter.Type,
parameter.Type.ToDisplayString(),
parameter.Name,
string.Empty,
string.Empty,
parameter.Type,
parameter.Type.ToDisplayString(),
IsNullable(parameter.Type),
null,
false,
Expand All @@ -334,7 +330,7 @@ private static void ExtractFieldsFromGeneratedParent(Compilation compilation, bo
string.Empty,
string.Empty,
parameter.FallbackType,
IsNullable(parameter.FallbackType),
false,//IsNullable(parameter.FallbackType),
null,
false,
FieldType.PassedToBase));
Expand All @@ -343,4 +339,15 @@ private static void ExtractFieldsFromGeneratedParent(Compilation compilation, bo

ExtractFieldsFromParent(compilation, symbol, emitNullChecks, concatenatedFields);
}

private static bool IsNullable(ITypeSymbol typeSymbol)
{
bool isNullable = typeSymbol.IsReferenceType && typeSymbol.NullableAnnotation == NullableAnnotation.Annotated;
if (typeSymbol is INamedTypeSymbol namedSymbol)
{
isNullable |= namedSymbol.TypeArguments.Any(IsNullable);
}

return isNullable;
}
}
39 changes: 26 additions & 13 deletions src/AutoConstructor.Generator/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ public CodeGenerator AddDocumentation(string? constructorDocumentationComment)
return this;
}

public CodeGenerator AddNamespace(INamespaceSymbol namespaceSymbol)
public CodeGenerator AddNamespace(string namespaceSymbolDisplayString)
{
if (_current is not null)
{
throw new InvalidOperationException($"Method {nameof(AddNamespace)} must be called first.");
}

_current = GetNamespace(namespaceSymbol.ToDisplayString(), _addNullableAnnotation);
_current = GetNamespace(namespaceSymbolDisplayString, _addNullableAnnotation);
return this;
}

public CodeGenerator AddClass(INamedTypeSymbol classSymbol)
public CodeGenerator AddClass(NamedTypeSymbolInfo classSymbol)
{
string identifier = classSymbol.Name;
bool isStatic = classSymbol.IsStatic;
ITypeParameterSymbol[] typeParameterList = classSymbol.TypeParameters.ToArray();
EquatableArray<string> typeParameterList = classSymbol.TypeParameters;

ClassDeclarationSyntax classSyntax = GetClass(
identifier,
Expand Down Expand Up @@ -83,7 +83,7 @@ public CodeGenerator AddClass(INamedTypeSymbol classSymbol)
return this;
}

public CodeGenerator AddConstructor(FieldInfo[] parameters, bool symbolHasParameterlessConstructor)
public CodeGenerator AddConstructor(EquatableArray<FieldInfo> parameters, bool symbolHasParameterlessConstructor)
{
if (_current is ClassDeclarationSyntax classDeclarationSyntax)
{
Expand Down Expand Up @@ -112,6 +112,19 @@ public CodeGenerator AddConstructor(FieldInfo[] parameters, bool symbolHasParame
return this;
}

public CompilationUnitSyntax GetCompilationUnit()
{
if (_current is null)
{
throw new InvalidOperationException("No class was added to the generator.");
}

return CompilationUnit()
.AddMembers(_current)
.NormalizeWhitespace()
.WithTrailingTrivia(CarriageReturnLineFeed);
}

public override string ToString()
{
if (_current is null)
Expand Down Expand Up @@ -152,7 +165,7 @@ private static BaseNamespaceDeclarationSyntax GetNamespace(string identifier, bo
.WithNamespaceKeyword(Token(GetHeaderTrivia(addNullableAnnotation), SyntaxKind.NamespaceKeyword, TriviaList()));
}

private static ClassDeclarationSyntax GetClass(string identifier, bool addHeaderTrivia, bool addNullableAnnotation, bool isStatic, ITypeParameterSymbol[] typeParameterList)
private static ClassDeclarationSyntax GetClass(string identifier, bool addHeaderTrivia, bool addNullableAnnotation, bool isStatic, EquatableArray<string> typeParameterList)
{
SyntaxToken firstModifier = Token(isStatic ? SyntaxKind.StaticKeyword : SyntaxKind.PartialKeyword);
if (addHeaderTrivia)
Expand All @@ -166,15 +179,15 @@ private static ClassDeclarationSyntax GetClass(string identifier, bool addHeader
declaration = declaration.AddModifiers(Token(SyntaxKind.PartialKeyword));
}

if (typeParameterList.Length > 0)
if (!typeParameterList.IsEmpty)
{
declaration = declaration.AddTypeParameterListParameters(Array.ConvertAll(typeParameterList, GetTypeParameter));
declaration = declaration.AddTypeParameterListParameters(typeParameterList.Select(GetTypeParameter).ToArray());
}

return declaration;
}

private static ConstructorDeclarationSyntax GetConstructor(SyntaxToken identifier, FieldInfo[] parameters, string? constructorDocumentationComment, bool generateThisInitializer)
private static ConstructorDeclarationSyntax GetConstructor(SyntaxToken identifier, EquatableArray<FieldInfo> parameters, string? constructorDocumentationComment, bool generateThisInitializer)
{
FieldInfo[] constructorParameters = parameters
.GroupBy(x => x.ParameterName)
Expand Down Expand Up @@ -228,15 +241,15 @@ private static DocumentationCommentTriviaSyntax GetDocumentation(string construc

private static ParameterSyntax GetParameter(FieldInfo parameter)
{
ITypeSymbol parameterType = parameter.Type ?? parameter.FallbackType;
string parameterType = parameter.Type ?? parameter.FallbackType;

return Parameter(Identifier(parameter.ParameterName))
.WithType(ParseTypeName(parameterType.ToDisplayString()));
.WithType(ParseTypeName(parameterType));
}

private static TypeParameterSyntax GetTypeParameter(ITypeParameterSymbol identifier)
private static TypeParameterSyntax GetTypeParameter(string identifierName)
{
return TypeParameter(Identifier(identifier.Name));
return TypeParameter(Identifier(identifierName));
}

private static ArgumentSyntax GetArgument(FieldInfo parameter)
Expand Down
Loading

0 comments on commit bca9320

Please sign in to comment.