Skip to content

Commit

Permalink
Added ability to exclude assemblies from resoltion via attribute as w…
Browse files Browse the repository at this point in the history
…ell as via build property
  • Loading branch information
david-driscoll committed Dec 29, 2024
1 parent a9e3122 commit b22e73f
Show file tree
Hide file tree
Showing 440 changed files with 1,092 additions and 13,239 deletions.
24 changes: 7 additions & 17 deletions src/DependencyInjection.Analyzers/AssemblyCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,23 @@ IncrementalValueProvider<bool> hasAssemblyLoadContext
.Select((tuple, _) => tuple.Left)
.Collect();


public static ImmutableList<ResolvedSourceLocation> ResolveSources(
AssemblyProviderConfiguration configuration,
Compilation compilation,
HashSet<Diagnostic> diagnostics,
IReadOnlyList<Item> items
ImmutableList<Item> items,
ImmutableDictionary<string, IAssemblySymbol> assemblySymbols
)
{
var assemblySymbols = compilation
.References.Select(compilation.GetAssemblyOrModuleSymbol)
.Concat([compilation.Assembly])
.Select(
symbol => symbol switch { IAssemblySymbol assemblySymbol => assemblySymbol, IModuleSymbol moduleSymbol => moduleSymbol.ContainingAssembly, _ => null! }
)
.Where(z => z is { })
.ToImmutableHashSet<IAssemblySymbol>(SymbolEqualityComparer.Default);

var results = new List<ResolvedSourceLocation>();
foreach (var item in items)
{
var pa = new HashSet<IAssemblySymbol>(SymbolEqualityComparer.Default);
try
{
var filterAssemblies = assemblySymbols
.Values
.Where(z => item.AssemblyFilter.IsMatch(compilation, z))
.ToArray();

Expand All @@ -50,13 +45,8 @@ IReadOnlyList<Item> items
continue;
}

results.Add(
new(
item.Location,
GenerateDescriptors(compilation, filterAssemblies, pa).NormalizeWhitespace().ToFullString().Replace("\r", ""),
pa.Select(z => z.MetadataName).ToImmutableHashSet()
)
);
var descriptors = GenerateDescriptors(compilation, filterAssemblies, pa).NormalizeWhitespace().ToFullString().Replace("\r", "");
results.Add(new(item.Location, descriptors, pa.Select(z => z.MetadataName).ToImmutableHashSet()));
}
catch (Exception e)
{
Expand Down
100 changes: 60 additions & 40 deletions src/DependencyInjection.Analyzers/CompiledServiceScanningGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
var collectionProvider = assembliesSyntaxProvider
.Combine(reflectionSyntaxProvider)
.Combine(serviceDescriptorSyntaxProvider)
.Select((z, _) => (assemblies: z.Left.Left, reflection: z.Left.Right, serviceDescriptors: z.Right));
.Select((z, _) => ( assemblies: z.Left.Left, reflection: z.Left.Right, serviceDescriptors: z.Right ));
var generatedJsonProvider = context
.AdditionalTextsProvider.Where(z => z.Path.EndsWith(Constants.AssemblyJsonExtension, StringComparison.OrdinalIgnoreCase))
.Select(
Expand All @@ -44,10 +44,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
var source = text.GetText()?.ToString();
if (source is not { Length: > 100 })
{
return (path: Path.GetFileName(text.Path), source: new([], [], []));
return ( path: Path.GetFileName(text.Path), source: new([], [], [], false) );
}

return (path: Path.GetFileName(text.Path),
return ( path: Path.GetFileName(text.Path),
source: JsonSerializer.Deserialize(
source,
JsonSourceGenerationContext.Default.CompiledAssemblyProviderData
Expand All @@ -65,16 +65,20 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
var partialProvider = context
.AdditionalTextsProvider.Where(z => z.Path.EndsWith(Constants.PartialExtension, StringComparison.OrdinalIgnoreCase))
.Collect()
.Select((z, _) => z.ToFrozenDictionary(static z => Path.GetFileName(z.Path), static z =>
JsonSerializer.Deserialize(
z.GetText()?.ToString() ?? "",
JsonSourceGenerationContext.Default.SavedSourceLocation
)!
));
.Select(
(z, _) => z.ToFrozenDictionary(
static z => Path.GetFileName(z.Path),
static z =>
JsonSerializer.Deserialize(
z.GetText()?.ToString() ?? "",
JsonSourceGenerationContext.Default.SavedSourceLocation
)!
)
);
var additionalFilesProvider = generatedJsonProvider
.Combine(skipProvider)
.Combine(partialProvider)
.Select((z, _) => (generatedJson: z.Left.Left, skip: z.Left.Right, partial: z.Right));
.Select((z, _) => ( generatedJson: z.Left.Left, skip: z.Left.Right, partial: z.Right ));
context.RegisterImplementationSourceOutput(
context
.CompilationProvider
Expand All @@ -93,6 +97,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
),
static (context, request) =>
{
HashSet<string> excludedAssemblies = request.options.GlobalOptions.TryGetValue("build_property.ExcludeAssemblyFromCTP", out var assemblies) ? [..assemblies.Split([';', ','], StringSplitOptions.RemoveEmptyEntries)] : [];
var privateAssemblies = new HashSet<IAssemblySymbol>(SymbolEqualityComparer.Default);
var diagnostics = new HashSet<Diagnostic>();
var assemblyRequests = AssemblyCollection.GetAssemblyItems(request.compilation, diagnostics, request.assemblies, context.CancellationToken);
Expand All @@ -110,28 +115,38 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
);
var attributes = AssemblyProviderConfiguration.ToAssemblyAttributes(assemblyRequests, reflectionRequests, serviceDescriptorRequests).ToArray();

var config = new AssemblyProviderConfiguration(context, request.compilation, request.options, request.additionalFiles.generatedJson, request.additionalFiles.skip, request.additionalFiles.partial);

var assemblySymbols = request.compilation
.References.Select(request.compilation.GetAssemblyOrModuleSymbol)
.Concat([request.compilation.Assembly])
.Select(
symbol =>
{
if (symbol is IAssemblySymbol assemblySymbol) return assemblySymbol;

if (symbol is IModuleSymbol moduleSymbol) return moduleSymbol.ContainingAssembly;

// ReSharper disable once NullableWarningSuppressionIsUsed
return null!;
}
)
.Where(z => z is { })
.GroupBy(z => z.MetadataName, z => z, (s, symbols) => (Key: s, Symbol: symbols.First()))
.ToImmutableDictionary(z => z.Key, z => z.Symbol);
var assemblySymbols = request
.compilation
.References
.Select(request.compilation.GetAssemblyOrModuleSymbol)
.Concat([request.compilation.Assembly])
.Select(
symbol =>
{
if (symbol is IAssemblySymbol assemblySymbol) return assemblySymbol;

if (symbol is IModuleSymbol moduleSymbol) return moduleSymbol.ContainingAssembly;

// ReSharper disable once NullableWarningSuppressionIsUsed
return null!;
}
)
.Where(z => z is { })
.Where(z => excludedAssemblies.All(a => !z.MetadataName.StartsWith(a, StringComparison.OrdinalIgnoreCase)))
.GroupBy(z => z.MetadataName, z => z, (s, symbols) => ( Key: s, Symbol: symbols.First() ))
.ToImmutableDictionary(z => z.Key, z => z.Symbol);

var config = new AssemblyProviderConfiguration(
context,
request.compilation,
request.options,
request.additionalFiles.generatedJson,
request.additionalFiles.skip,
request.additionalFiles.partial
);

var resolvedData = config.FromAssemblyAttributes(
assemblySymbols,
ref assemblySymbols,
reflectionRequests,
serviceDescriptorRequests,
diagnostics
Expand All @@ -142,17 +157,21 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
serviceDescriptorRequests = serviceDescriptorRequests.AddRange(resolvedData.InternalServiceDescriptorRequests);

var assemblySources = AssemblyCollection.ResolveSources(
config,
request.compilation,
diagnostics,
assemblyRequests
assemblyRequests,
assemblySymbols
);
var reflectionSources = ReflectionCollection.ResolveSources(
config,
request.compilation,
diagnostics,
reflectionRequests,
request.compilation.Assembly
);
var serviceDescriptorSources = ServiceDescriptorCollection.ResolveSources(
config,
request.compilation,
diagnostics,
serviceDescriptorRequests,
Expand All @@ -162,14 +181,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
reflectionSources = reflectionSources.AddRange(resolvedData.ReflectionSources);
serviceDescriptorSources = serviceDescriptorSources.AddRange(resolvedData.ServiceDescriptorSources);

privateAssemblies.UnionWith(JoinAssemblies(assemblySymbols, assemblySources));
privateAssemblies.UnionWith(JoinAssemblies(assemblySymbols, reflectionSources));
privateAssemblies.UnionWith(JoinAssemblies(assemblySymbols, serviceDescriptorSources));

static IEnumerable<IAssemblySymbol> JoinAssemblies(System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<string, IAssemblySymbol>> assemblies, System.Collections.Generic.IEnumerable<ResolvedSourceLocation> sources)
{
return sources.SelectMany(z => z.PrivateAssemblies).Join(assemblies, z => z, z => z.Key, (_, a) => a.Value);
}
privateAssemblies.UnionWith(joinAssemblies(assemblySymbols, assemblySources));
privateAssemblies.UnionWith(joinAssemblies(assemblySymbols, reflectionSources));
privateAssemblies.UnionWith(joinAssemblies(assemblySymbols, serviceDescriptorSources));

var cu = CompilationUnit()
.WithUsings(
Expand Down Expand Up @@ -226,9 +240,15 @@ static IEnumerable<IAssemblySymbol> JoinAssemblies(System.Collections.Generic.IE
}

context.AddSource(
"Compiled_AssemblyProvider.g.cs",
"Compiled_AssemblyProvider.g.cs", // "CompiledTypeProvider.g.cs",
cu.NormalizeWhitespace().SyntaxTree.GetRoot().GetText(Encoding.UTF8)
);
return;

static IEnumerable<IAssemblySymbol> joinAssemblies(IEnumerable<KeyValuePair<string, IAssemblySymbol>> assemblies, IEnumerable<ResolvedSourceLocation> sources)
{
return sources.SelectMany(z => z.PrivateAssemblies).Join(assemblies, z => z, z => z.Key, (_, a) => a.Value);
}
}
);
}
Expand Down
Loading

0 comments on commit b22e73f

Please sign in to comment.