Skip to content

Commit

Permalink
Support registering generic types (#126)
Browse files Browse the repository at this point in the history
E.g

```csharp
[Register(typeof(A<>))]
[Register(typeof(B<,>), typeof(IB<,>))]
```
  • Loading branch information
YairHalberstadt authored Jun 8, 2021
1 parent aa48f8e commit f87231c
Show file tree
Hide file tree
Showing 10 changed files with 1,078 additions and 43 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,18 @@ If you do so, you will have to explicitly also register it as itself if that is

If there is a single public non-parameterless constructor, StrongInject will use that to construct the type. If there is no public non-parameterless constructor StrongInject will use the parameterless constructor if it exists and is public. Else it will report an error.

You can register generic types as well:

```csharp
public class A<T> {}
public interface IB<T1, T2> {}
public class B<T1, T2> : IB<T1, T2> {}

[Register(typeof(A<>))]
[Register(typeof(B<,>), typeof(IB<,>))]
public partial class Container : IContainer<A<int>>, IContainer<IB<string, object>> {}
```

#### Scope

The scope of a registration determines how often a new instance is created, how long it lives, and who uses it.
Expand Down
1 change: 0 additions & 1 deletion StrongInject.Generator/ContainerGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,6 @@ void EmitStatement(Operation operation)
case FactoryMethod { Method: var method }:
{
GenerateMethodCall(variableName, method, dependencies);

break;
}
case ArraySource { ArrayType: var arrayType }:
Expand Down
148 changes: 135 additions & 13 deletions StrongInject.Generator/GenericRegistrationsResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ private GenericRegistrationsResolver(Dictionary<INamedTypeSymbol, Bucket> namedT
_typeParameterBuckets = typeParameterBuckets;
}

public bool TryResolve(ITypeSymbol type, out FactoryMethod instanceSource, out bool isAmbiguous, out IEnumerable<FactoryMethod> sourcesNotMatchingConstraints)
public bool TryResolve(ITypeSymbol type, out InstanceSource instanceSource, out bool isAmbiguous, out IEnumerable<InstanceSource> sourcesNotMatchingConstraints)
{
instanceSource = null!;
sourcesNotMatchingConstraints = Array.Empty<FactoryMethod>();
Expand Down Expand Up @@ -100,10 +100,14 @@ public class Builder
private readonly List<Builder> _children = new();
private readonly List<FactoryMethod> _factoryMethods = new();
private readonly List<FactoryOfMethod> _factoryOfMethods = new();
private readonly List<Registration> _registrations = new();
private readonly List<ForwardedInstanceSource> _forwardedInstanceSources = new();

public void Add(Builder child) => _children.Add(child);
public void Add(FactoryMethod factoryMethod) => _factoryMethods.Add(factoryMethod);
public void Add(FactoryOfMethod factoryOfMethod) => _factoryOfMethods.Add(factoryOfMethod);
public void Add(Registration registration) => _registrations.Add(registration);
public void Add(ForwardedInstanceSource forwardedInstanceSource) => _forwardedInstanceSources.Add(forwardedInstanceSource);

public GenericRegistrationsResolver Build(Compilation compilation)
{
Expand Down Expand Up @@ -168,6 +172,20 @@ public GenericRegistrationsResolver Build(Compilation compilation)
type.OriginalDefinition,
static _ => new BucketBuilder()).Add(factoryOfMethod);
}

foreach (var registration in builder._registrations)
{
namedTypeBucketBuilders.GetOrCreate(
registration.Type.OriginalDefinition,
static _ => new BucketBuilder()).Add(registration);
}

foreach (var forwardedInstanceSource in builder._forwardedInstanceSources)
{
namedTypeBucketBuilders.GetOrCreate(
forwardedInstanceSource.AsType.OriginalDefinition,
static _ => new BucketBuilder()).Add(forwardedInstanceSource);
}

return
(
Expand All @@ -185,6 +203,8 @@ private class BucketBuilder
private List<Bucket>? _buckets;
private ImmutableArray<FactoryMethod>.Builder? _factoryMethods;
private ImmutableArray<FactoryOfMethod>.Builder? _factoryOfMethods;
private ImmutableArray<Registration>.Builder? _registrations;
private ImmutableArray<ForwardedInstanceSource>.Builder? _forwardedInstanceSources;

public void Add(Bucket bucket)
{
Expand All @@ -203,26 +223,46 @@ public void Add(FactoryOfMethod factoryOfMethod)
_factoryOfMethods ??= ImmutableArray.CreateBuilder<FactoryOfMethod>();
_factoryOfMethods.Add(factoryOfMethod);
}

public void Add(Registration registration)
{
_registrations ??= ImmutableArray.CreateBuilder<Registration>();
_registrations.Add(registration);
}

public void Add(ForwardedInstanceSource forwardedInstanceSource)
{
_forwardedInstanceSources ??= ImmutableArray.CreateBuilder<ForwardedInstanceSource>();
_forwardedInstanceSources.Add(forwardedInstanceSource);
}

public Bucket Build(Compilation compilation)
{
return new Bucket(
_buckets ?? Enumerable.Empty<Bucket>(),
_factoryMethods?.ToImmutable() ?? ImmutableArray<FactoryMethod>.Empty,
_factoryOfMethods?.ToImmutable() ?? ImmutableArray<FactoryOfMethod>.Empty,
_registrations?.ToImmutable() ?? ImmutableArray<Registration>.Empty,
_forwardedInstanceSources?.ToImmutable() ?? ImmutableArray<ForwardedInstanceSource>.Empty,
compilation);
}
}
}

private class Bucket
{
public Bucket(IEnumerable<Bucket> childResolvers, ImmutableArray<FactoryMethod> factoryMethods, ImmutableArray<FactoryOfMethod> factoryOfMethods, Compilation compilation)
public Bucket(
IEnumerable<Bucket> childResolvers,
ImmutableArray<FactoryMethod> factoryMethods,
ImmutableArray<FactoryOfMethod> factoryOfMethods,
ImmutableArray<Registration> registrations,
ImmutableArray<ForwardedInstanceSource> forwardedInstanceSources,
Compilation compilation)
{
var builder = ImmutableArray.CreateBuilder<Bucket>();
foreach (var childResolver in childResolvers)
{
if (childResolver._factoryMethods.IsDefaultOrEmpty)
if (childResolver.IsEmpty)
{
builder.AddRange(childResolver._childResolvers);
}
Expand All @@ -234,18 +274,27 @@ public Bucket(IEnumerable<Bucket> childResolvers, ImmutableArray<FactoryMethod>
_childResolvers = builder.ToImmutable();
_factoryMethods = factoryMethods;
_factoryOfMethods = factoryOfMethods;
_registrations = registrations;
_forwardedInstanceSources = forwardedInstanceSources;
_compilation = compilation;
}

private readonly ImmutableArray<Bucket> _childResolvers;
private readonly ImmutableArray<FactoryMethod> _factoryMethods;
private readonly ImmutableArray<FactoryOfMethod> _factoryOfMethods;
private readonly ImmutableArray<Registration> _registrations;
private readonly ImmutableArray<ForwardedInstanceSource> _forwardedInstanceSources;
private readonly Compilation _compilation;

public bool TryResolve(ITypeSymbol type, out FactoryMethod instanceSource, out bool isAmbiguous, out IEnumerable<FactoryMethod> sourcesNotMatchingConstraints)
private bool IsEmpty => _factoryMethods.IsDefaultOrEmpty
&& _factoryOfMethods.IsDefaultOrEmpty
&& _registrations.IsDefaultOrEmpty
&& _forwardedInstanceSources.IsDefaultOrEmpty;

public bool TryResolve(ITypeSymbol type, out InstanceSource instanceSource, out bool isAmbiguous, out IEnumerable<InstanceSource> sourcesNotMatchingConstraints)
{
instanceSource = null!;
List<FactoryMethod>? factoriesWhereConstraintsDoNotMatch = null;
List<InstanceSource>? sourcesNotMatchingConstraintsTemp = null;

foreach (var factoryMethod in GetAllRelevantFactoryMethods(type))
{
Expand All @@ -259,20 +308,93 @@ public bool TryResolve(ITypeSymbol type, out FactoryMethod instanceSource, out b
{
instanceSource = null!;
isAmbiguous = true;
sourcesNotMatchingConstraints = Enumerable.Empty<FactoryMethod>();
sourcesNotMatchingConstraints = Array.Empty<InstanceSource>();
return false;
}
}
else if (constraintsDoNotMatch)
{
(factoriesWhereConstraintsDoNotMatch ??= new()).Add(factoryMethod);
(sourcesNotMatchingConstraintsTemp ??= new()).Add(factoryMethod);
}
}

foreach (var registration in _registrations)
{
if (registration.Type.OriginalDefinition.Equals(type.OriginalDefinition, SymbolEqualityComparer.Default))
{
var originalConstructor = registration.Constructor.OriginalDefinition;
var constructor = ((INamedTypeSymbol)type).InstanceConstructors.First(
x => SymbolEqualityComparer.Default.Equals(x.OriginalDefinition, originalConstructor));

var updatedRegistration = registration with
{
Constructor = constructor, Type = ((INamedTypeSymbol)type)
};
if (instanceSource is null)
{
instanceSource = updatedRegistration;
}
else if (instanceSource != updatedRegistration)
{
instanceSource = null!;
isAmbiguous = true;
sourcesNotMatchingConstraints = Array.Empty<InstanceSource>();
return false;
}
}
}

foreach (var forwardedInstanceSource in _forwardedInstanceSources)
{
if (forwardedInstanceSource.AsType.OriginalDefinition.Equals(type.OriginalDefinition, SymbolEqualityComparer.Default))
{
if (forwardedInstanceSource.Underlying is Registration registration)
{
var typeArguments = ((INamedTypeSymbol)type).TypeArguments;
if (SatisfiesConstraints(registration.Type, typeArguments, _compilation))
{
var originalConstructor = registration.Constructor.OriginalDefinition;
var constructedRegistrationType =
registration.Type.OriginalDefinition.Construct(typeArguments.ToArray());
var constructor = constructedRegistrationType.InstanceConstructors.First(
x => SymbolEqualityComparer.Default.Equals(x.OriginalDefinition, originalConstructor));

var updatedRegistration = registration with
{
Constructor = constructor, Type = constructedRegistrationType
};

var updatedForwardedInstanceSource =
ForwardedInstanceSource.Create((INamedTypeSymbol)type, updatedRegistration);

if (instanceSource is null)
{
instanceSource = updatedForwardedInstanceSource;
}
else if (instanceSource != updatedForwardedInstanceSource)
{
instanceSource = null!;
isAmbiguous = true;
sourcesNotMatchingConstraints = Array.Empty<InstanceSource>();
return false;
}
}
else
{
(sourcesNotMatchingConstraintsTemp ??= new()).Add(forwardedInstanceSource);
}
}
else
{
throw new NotImplementedException(forwardedInstanceSource.Underlying.ToString());
}
}
}

if (instanceSource is not null)
{
isAmbiguous = false;
sourcesNotMatchingConstraints = Array.Empty<FactoryMethod>();
sourcesNotMatchingConstraints = Array.Empty<InstanceSource>();
return true;
}

Expand All @@ -288,30 +410,30 @@ public bool TryResolve(ITypeSymbol type, out FactoryMethod instanceSource, out b
{
instanceSource = null!;
isAmbiguous = true;
sourcesNotMatchingConstraints = Enumerable.Empty<FactoryMethod>();
sourcesNotMatchingConstraints = Array.Empty<InstanceSource>();
return false;
}
}
else if (isChildAmbiguous)
{
instanceSource = null!;
isAmbiguous = true;
sourcesNotMatchingConstraints = Enumerable.Empty<FactoryMethod>();
sourcesNotMatchingConstraints = Array.Empty<InstanceSource>();
return false;
}
(factoriesWhereConstraintsDoNotMatch ??= new()).AddRange(childSourcesNotMatchingConstraints);
(sourcesNotMatchingConstraintsTemp ??= new()).AddRange(childSourcesNotMatchingConstraints);
}

if (instanceSource is not null)
{
isAmbiguous = false;
sourcesNotMatchingConstraints = Array.Empty<FactoryMethod>();
sourcesNotMatchingConstraints = Array.Empty<InstanceSource>();
return true;
}

instanceSource = null!;
isAmbiguous = false;
sourcesNotMatchingConstraints = factoriesWhereConstraintsDoNotMatch ?? Enumerable.Empty<FactoryMethod>();
sourcesNotMatchingConstraints = sourcesNotMatchingConstraintsTemp ?? Enumerable.Empty<InstanceSource>();
return false;
}

Expand Down
11 changes: 6 additions & 5 deletions StrongInject.Generator/GenericResolutionHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using System;
using System.Collections.Generic;
using System.Linq;

namespace StrongInject.Generator
Expand Down Expand Up @@ -134,9 +135,9 @@ static bool CanConstructFrom(ITypeSymbol toConstruct, ITypeSymbol toConstructFro
}
}

private static bool SatisfiesConstraints(IMethodSymbol method, ITypeSymbol[] typeArguments, Compilation compilation)
public static bool SatisfiesConstraints(ISymbol symbol, IReadOnlyList<ITypeSymbol> typeArguments, Compilation compilation)
{
var typeParameters = method.TypeParameters;
var typeParameters = symbol.TypeParameters();
for (int i = 0; i < typeParameters.Length; i++)
{
var typeParameter = typeParameters[i];
Expand All @@ -157,7 +158,7 @@ private static bool SatisfiesConstraints(IMethodSymbol method, ITypeSymbol[] typ

foreach (var typeConstraint in typeParameter.ConstraintTypes)
{
var substitutedConstraintType = SubstituteType(compilation, typeConstraint, method, typeArguments);
var substitutedConstraintType = SubstituteType(compilation, typeConstraint, symbol, typeArguments);
var conversion = compilation.ClassifyConversion(typeArgument, substitutedConstraintType);
if (typeArgument.IsNullableType() || conversion is not ({ IsIdentity: true } or { IsImplicit: true, IsReference: true } or { IsBoxing: true }))
{
Expand Down Expand Up @@ -203,7 +204,7 @@ private static bool HasPublicParameterlessConstructor(INamedTypeSymbol type)
return false;
}

private static ITypeSymbol SubstituteType(Compilation compilation, ITypeSymbol type, IMethodSymbol method, ITypeSymbol[] typeArguments)
private static ITypeSymbol SubstituteType(Compilation compilation, ITypeSymbol type, ISymbol symbol, IReadOnlyList<ITypeSymbol> typeArguments)
{
return Visit(type);

Expand All @@ -212,7 +213,7 @@ ITypeSymbol Visit(ITypeSymbol type)
switch (type)
{
case ITypeParameterSymbol typeParameterSymbol:
return SymbolEqualityComparer.Default.Equals(typeParameterSymbol.DeclaringMethod, method)
return SymbolEqualityComparer.Default.Equals(typeParameterSymbol.DeclaringSymbol(), symbol)
? typeArguments[typeParameterSymbol.Ordinal]
: type;
case IArrayTypeSymbol { ElementType: var elementType, Rank: var rank } arrayTypeSymbol:
Expand Down
8 changes: 4 additions & 4 deletions StrongInject.Generator/InstanceSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,17 @@ public override void Visit<TState>(IVisitor<TState> visitor, TState state)
}
internal record ForwardedInstanceSource : InstanceSource
{
private ForwardedInstanceSource(ITypeSymbol asType, InstanceSource underlying) : base(underlying.Scope, IsAsync: false, underlying.CanDecorate)
private ForwardedInstanceSource(INamedTypeSymbol asType, InstanceSource underlying) : base(underlying.Scope, IsAsync: false, underlying.CanDecorate)
=> (AsType, Underlying) = (asType, underlying);

public void Deconstruct(out ITypeSymbol AsType, out InstanceSource Underlying) => (AsType, Underlying) = (this.AsType, this.Underlying);
public void Deconstruct(out INamedTypeSymbol AsType, out InstanceSource Underlying) => (AsType, Underlying) = (this.AsType, this.Underlying);

public ITypeSymbol AsType { get; init; }
public INamedTypeSymbol AsType { get; init; }
public InstanceSource Underlying { get; init; }

public override ITypeSymbol OfType => AsType;

public static InstanceSource Create(ITypeSymbol asType, InstanceSource underlying)
public static InstanceSource Create(INamedTypeSymbol asType, InstanceSource underlying)
=> SymbolEqualityComparer.Default.Equals(underlying.OfType, asType)
? underlying
: new ForwardedInstanceSource(asType, underlying is ForwardedInstanceSource forwardedUnderlying ? forwardedUnderlying.Underlying : underlying);
Expand Down
Loading

0 comments on commit f87231c

Please sign in to comment.