Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiled type provider was getting types properly #1326

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
23 changes: 10 additions & 13 deletions src/DependencyInjection.Analyzers/AssemblyCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ IncrementalValueProvider<bool> hasAssemblyLoadContext
return valueProvider
.CreateSyntaxProvider((node, _) => IsValidMethod(node), (syntaxContext, _) => GetMethod(syntaxContext))
.Combine(hasAssemblyLoadContext)
.Where(z => z is { Right: true, Left: { method: { }, selector: { }, }, })
.Where(z => z is { Right: true, Left: { method: { }, selector: { } } })
.Select((tuple, _) => tuple.Left)
.Collect();
}
Expand Down Expand Up @@ -90,11 +90,6 @@ CollectRequest request
);
}

private static bool IsValidMethod(SyntaxNode node)
{
return GetMethod(node) is { method: { }, selector: { } };
}

public static MethodDeclarationSyntax Execute(
Request request
)
Expand All @@ -104,7 +99,7 @@ Request request

var assemblySymbols = compilation
.References.Select(compilation.GetAssemblyOrModuleSymbol)
.Concat([compilation.Assembly,])
.Concat([compilation.Assembly])
.Select(
symbol => symbol switch
{
Expand Down Expand Up @@ -140,26 +135,26 @@ GeneratorSyntaxContext context
|| baseData.selector is null
|| context.SemanticModel.GetTypeInfo(baseData.selector).ConvertedType is not INamedTypeSymbol
{
TypeArguments: [{ Name: "IAssemblyProviderAssemblySelector", }, ..,],
TypeArguments: [{ Name: IReflectionAssemblySelector }, ..],
})
return default;

return ( baseData.method, baseData.selector, semanticModel: context.SemanticModel );
}

public static (InvocationExpressionSyntax method, ExpressionSyntax selector ) GetMethod(SyntaxNode node)
{
return node is InvocationExpressionSyntax
public static (InvocationExpressionSyntax method, ExpressionSyntax selector ) GetMethod(SyntaxNode node) =>
node is InvocationExpressionSyntax
{
Expression: MemberAccessExpressionSyntax
{
Name.Identifier.Text: "GetAssemblies",
},
ArgumentList.Arguments: [{ Expression: { } expression, },],
ArgumentList.Arguments: [{ Expression: { } expression }],
} invocationExpressionSyntax
? ( invocationExpressionSyntax, expression )
: default;
}

private static bool IsValidMethod(SyntaxNode node) => GetMethod(node) is { method: { }, selector: { } };

private static BlockSyntax GenerateDescriptors(Compilation compilation, IEnumerable<IAssemblySymbol> assemblies, HashSet<IAssemblySymbol> privateAssemblies)
{
Expand Down Expand Up @@ -322,6 +317,8 @@ HashSet<IAssemblySymbol> privateAssemblies
.AddMembers(privateMembers.ToArray());
}

private const string IReflectionAssemblySelector = nameof(IReflectionAssemblySelector);

public record CollectRequest
(
Compilation Compilation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ private static TypeFilterData LoadTypeFilterData(CompiledAssemblyFilter assembly
.ToImmutableArray(),
typeFilter
.TypeFilterDescriptors.OfType<NameFilterDescriptor>()
.Select(z => new NameFilterData(z.Filter, z.Names.OrderBy(z => z).ToImmutableArray()))
.Select(z => new NameFilterData(z.Include, z.Filter, z.Names.OrderBy(z => z).ToImmutableArray()))
.OrderBy(z => string.Join(",", z.Names.OrderBy(static z => z)))
.ThenBy(z => z.Filter)
.ToImmutableArray(),
Expand Down Expand Up @@ -405,7 +405,7 @@ ImmutableDictionary<string, IAssemblySymbol> assemblySymbols

foreach (var item in data.NameFilters)
{
descriptors.Add(new NameFilterDescriptor(item.Filter, item.Names.ToImmutableHashSet()));
descriptors.Add(new NameFilterDescriptor(item.Include, item.Filter, item.Names.ToImmutableHashSet()));
}

foreach (var item in data.TypeKindFilters)
Expand Down Expand Up @@ -747,6 +747,8 @@ internal record NamespaceFilterData

internal record NameFilterData
(
[property: JsonPropertyName("i")]
bool Include,
[property: JsonPropertyName("f")]
TextDirectionFilter Filter,
[property: JsonPropertyName("n")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ bool GetFilterDescriptor(ITypeFilterDescriptor filterDescriptor)
targetType.GetAttributes().All(z => Helpers.GetFullMetadataName(z.AttributeClass) != attribute),
NamespaceFilterDescriptor { Filter: var filterName, Namespaces: var filterNamespaces } =>
handleNamespaceFilter(filterName, filterNamespaces, targetType),
NameFilterDescriptor { Filter: var filterName, Names: var filterNames } =>
handleNameFilter(filterName, filterNames, targetType),
NameFilterDescriptor { Include: var include, Filter: var filterName, Names: var filterNames } =>
handleNameFilter(include, filterName, filterNames, targetType),
TypeKindFilterDescriptor { Include: var include, TypeKinds: var typeKinds } =>
handleKindFilter(include, typeKinds, targetType),
TypeInfoFilterDescriptor { Include: var include, TypeInfos: var typeInfos } =>
Expand Down Expand Up @@ -84,14 +84,17 @@ static bool handleNamespaceFilter(NamespaceFilter filterName, ImmutableHashSet<s
};
}

static bool handleNameFilter(TextDirectionFilter filterName, ImmutableHashSet<string> filterNames, INamedTypeSymbol type)
static bool handleNameFilter(bool include, TextDirectionFilter filterName, ImmutableHashSet<string> filterNames, INamedTypeSymbol type)
{
return filterName switch
return ( include, filterName ) switch
{
TextDirectionFilter.Contains => filterNames.Any(name => type.Name.Contains(name)),
TextDirectionFilter.StartsWith => filterNames.Any(name => type.Name.StartsWith(name)),
TextDirectionFilter.EndsWith => filterNames.Any(name => type.Name.EndsWith(name)),
_ => throw new NotImplementedException(),
(true, TextDirectionFilter.Contains) => filterNames.Any(name => type.Name.Contains(name)),
(false, TextDirectionFilter.Contains) => !filterNames.Any(name => type.Name.Contains(name)),
(true, TextDirectionFilter.EndsWith) => filterNames.Any(name => type.Name.EndsWith(name)),
(false, TextDirectionFilter.EndsWith) => !filterNames.Any(name => type.Name.EndsWith(name)),
(true, TextDirectionFilter.StartsWith) => filterNames.Any(name => type.Name.StartsWith(name)),
(false, TextDirectionFilter.StartsWith) => !filterNames.Any(name => type.Name.StartsWith(name)),
_ => throw new NotImplementedException(),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ SimpleNameSyntax name
createNamespaceTypeFilterDescriptor(context, name, expression, semanticModel),
({ Identifier.Text: "InExactNamespaces" or "InNamespaces" or "NotInNamespaces" }, _) =>
createNamespaceStringFilterDescriptor(context, name, expression, semanticModel),
({ Identifier.Text: "EndsWith" or "StartsWith" or "Contains" }, _) =>
({ Identifier.Text: "EndsWith" or "StartsWith" or "Contains" or "NotEndsWith" or "NotStartsWith" or "NotContains" }, _) =>
createNameFilterDescriptor(context, name, expression),
({ Identifier.Text: "KindOf" or "NotKindOf" }, _) =>
createTypeKindFilterDescriptor(context, name, expression),
Expand Down Expand Up @@ -276,17 +276,17 @@ SemanticModel semanticModel
: new AssignableToAnyTypeFilterDescriptor(arguments);
}

static NameFilterDescriptor createNameFilterDescriptor(
static ITypeFilterDescriptor createNameFilterDescriptor(
SourceProductionContext context,
SimpleNameSyntax name,
InvocationExpressionSyntax expression
)
{
var filter = name.Identifier.Text switch
{
"EndsWith" => TextDirectionFilter.EndsWith,
"StartsWith" => TextDirectionFilter.StartsWith,
"Contains" => TextDirectionFilter.Contains,
"EndsWith" or "NotEndsWith" => TextDirectionFilter.EndsWith,
"StartsWith" or "NotStartsWith" => TextDirectionFilter.StartsWith,
"Contains" or "NotContains" => TextDirectionFilter.Contains,
_ => throw new NotSupportedException(
$"Not supported name filter. Method: {name.ToFullString()} {expression.ToFullString()} method."
),
Expand All @@ -303,7 +303,7 @@ InvocationExpressionSyntax expression
stringValues.Add(item);
}

return new(filter, stringValues.ToImmutable());
return new NameFilterDescriptor(!name.Identifier.Text.StartsWith("Not"), filter, stringValues.ToImmutable());
}

static NamespaceFilterDescriptor createNamespaceTypeFilterDescriptor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
namespace Rocket.Surgery.DependencyInjection.Analyzers.Descriptors;

[DebuggerDisplay("{ToString()}")]
internal readonly record struct NameFilterDescriptor(TextDirectionFilter Filter, ImmutableHashSet<string> Names) : ITypeFilterDescriptor;
internal readonly record struct NameFilterDescriptor(bool Include, TextDirectionFilter Filter, ImmutableHashSet<string> Names) : ITypeFilterDescriptor;
9 changes: 4 additions & 5 deletions src/DependencyInjection.Analyzers/ReflectionCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ IncrementalValueProvider<bool> hasAssemblyLoadContext
return valueProvider
.CreateSyntaxProvider((node, _) => IsValidMethod(node), (syntaxContext, _) => GetTypesMethod(syntaxContext))
.Combine(hasAssemblyLoadContext)
.Where(z => z is { Right: true, Left: { method: { }, selector: { }, }, })
.Where(z => z is { Right: true, Left: { method: { }, selector: { } } })
.Select((tuple, _) => tuple.Left)
.Collect();
}
Expand Down Expand Up @@ -49,7 +49,7 @@ public static (InvocationExpressionSyntax method, ExpressionSyntax selector, Sem
|| baseData.selector is null
|| context.SemanticModel.GetTypeInfo(baseData.selector).ConvertedType is not INamedTypeSymbol
{
TypeArguments: [{ Name: IReflectionAssemblySelector, }, ..,],
TypeArguments: [{ Name: IReflectionTypeSelector }, ..],
})
return default;

Expand All @@ -63,7 +63,7 @@ node is InvocationExpressionSyntax
{
Name.Identifier.Text: "GetTypes",
},
ArgumentList.Arguments: [.., { Expression: { } expression, },],
ArgumentList.Arguments: [.., { Expression: { } expression }],
} invocationExpressionSyntax
? ( invocationExpressionSyntax, expression )
: default;
Expand Down Expand Up @@ -116,7 +116,7 @@ internal static ImmutableArray<Item> GetTypeDetails(
return items.ToImmutable();
}

private static bool IsValidMethod(SyntaxNode node) => GetTypesMethod(node) is { method: { }, selector: { }, };
private static bool IsValidMethod(SyntaxNode node) => GetTypesMethod(node) is { method: { }, selector: { } };

private static BlockSyntax GenerateDescriptors(Compilation compilation, IEnumerable<INamedTypeSymbol> types, HashSet<IAssemblySymbol> privateAssemblies)
{
Expand Down Expand Up @@ -169,7 +169,6 @@ private static BlockSyntax GenerateDescriptors(Compilation compilation, IEnumera
Block(SingletonList<StatementSyntax>(YieldStatement(SyntaxKind.YieldBreakStatement)))
);

private const string IReflectionAssemblySelector = nameof(IReflectionAssemblySelector);
private const string IReflectionTypeSelector = nameof(IReflectionTypeSelector);

public record Request(SourceProductionContext Context, Compilation Compilation, ImmutableArray<Item> Items, HashSet<IAssemblySymbol> PrivateAssemblies);
Expand Down
24 changes: 24 additions & 0 deletions src/DependencyInjection.Extensions/Compiled/ITypeFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,24 @@ public interface ITypeFilter
/// <param name="value"></param>
/// <param name="values"></param>
/// <returns></returns>
ITypeFilter NotEndsWith(string value, params string[] values);

/// <summary>
/// Will match all types that start with
/// </summary>
/// <param name="value"></param>
/// <param name="values"></param>
/// <returns></returns>
ITypeFilter StartsWith(string value, params string[] values);

/// <summary>
/// Will match all types that start with
/// </summary>
/// <param name="value"></param>
/// <param name="values"></param>
/// <returns></returns>
ITypeFilter NotStartsWith(string value, params string[] values);

/// <summary>
/// Will match all types that contain the given values
/// </summary>
Expand All @@ -68,6 +84,14 @@ public interface ITypeFilter
/// <returns></returns>
ITypeFilter Contains(string value, params string[] values);

/// <summary>
/// Will match all types that contain the given values
/// </summary>
/// <param name="value"></param>
/// <param name="values"></param>
/// <returns></returns>
ITypeFilter NotContains(string value, params string[] values);

/// <summary>
/// Will match all types in the exact same namespace as the type <typeparamref name="T" />
/// </summary>
Expand Down
Loading
Loading