Skip to content

Commit

Permalink
Handle recursive type graphs in config binding generator (#83644)
Browse files Browse the repository at this point in the history
  • Loading branch information
layomia authored Mar 20, 2023
1 parent 1448de8 commit 4bd7819
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ private static class GlobalName
public const string Enum = "global::System.Enum";
public const string FromBase64String = "global::System.Convert.FromBase64String";
public const string IConfiguration = "global::Microsoft.Extensions.Configuration.IConfiguration";
public const string IConfigurationSection = "global::Microsoft.Extensions.Configuration.IConfigurationSection";
public const string Int32 = "int";
public const string IServiceCollection = "global::Microsoft.Extensions.DependencyInjection.IServiceCollection";
public const string Object = "object";
Expand Down Expand Up @@ -75,6 +76,8 @@ public void Emit()

EmitBindMethods();

EmitIConfigurationHasChildrenHelperMethod();

_writer.WriteBlockEnd();

SourceText source = SourceText.From(_writer.GetSource(), Encoding.UTF8);
Expand Down Expand Up @@ -363,7 +366,7 @@ private void EmitBindCoreImplForProperty(PropertySpec property, TypeSpec propert
string propertyParentReference = property.IsStatic ? parentType.DisplayString : Literal.obj;
string expressionForPropertyAccess = $"{propertyParentReference}.{property.Name}";

string expressionForConfigGetSection = $@"{Literal.configuration}.{Literal.GetSection}(""{configurationKeyName}"")";
string expressionForConfigSectionAccess = $@"{Literal.configuration}.{Literal.GetSection}(""{configurationKeyName}"")";
string expressionForConfigValueIndexer = $@"{Literal.configuration}[""{configurationKeyName}""]";

bool canGet = property.CanGet;
Expand Down Expand Up @@ -394,12 +397,12 @@ private void EmitBindCoreImplForProperty(PropertySpec property, TypeSpec propert
property,
propertyType,
expressionForPropertyAccess,
expressionForConfigArg: expressionForConfigGetSection);
expressionForConfigSectionAccess);
}
break;
case TypeSpecKind.IConfigurationSection:
{
EmitAssignment(expressionForPropertyAccess, expressionForConfigGetSection);
EmitAssignment(expressionForPropertyAccess, expressionForConfigSectionAccess);
}
break;
case TypeSpecKind.Nullable:
Expand All @@ -414,7 +417,7 @@ private void EmitBindCoreImplForProperty(PropertySpec property, TypeSpec propert
property,
propertyType,
expressionForPropertyAccess,
expressionForConfigArg: expressionForConfigGetSection);
expressionForConfigSectionAccess);
}
break;
}
Expand Down Expand Up @@ -491,8 +494,12 @@ private void EmitBindCoreCallForProperty(
PropertySpec property,
TypeSpec effectivePropertyType,
string expressionForPropertyAccess,
string expressionForConfigArg)
string expressionForConfigSectionAccess)
{
string bindCoreConfigArg = GetIncrementalVarName(Literal.section);
EmitAssignment($"{GlobalName.IConfigurationSection} {bindCoreConfigArg}", expressionForConfigSectionAccess);
_writer.WriteBlockStart($"if ({Literal.HasChildren}({bindCoreConfigArg}))");

bool canGet = property.CanGet;
bool canSet = property.CanSet;

Expand Down Expand Up @@ -523,7 +530,7 @@ private void EmitBindCoreCallForProperty(
EmitObjectInit(effectivePropertyType, tempVarName, InitializationKind.Declaration);
}

_writer.WriteLine($@"{Literal.BindCore}({expressionForConfigArg}, ref {tempVarName});");
_writer.WriteLine($@"{Literal.BindCore}({bindCoreConfigArg}, ref {tempVarName});");
EmitAssignment(expressionForPropertyAccess, tempVarName);
_privateBindCoreMethodGen_QueuedTypes.Enqueue(effectivePropertyType);
}
Expand All @@ -532,7 +539,7 @@ private void EmitBindCoreCallForProperty(
{
EmitAssignment($"{effectivePropertyType.DisplayString} {tempVarName}", $"{expressionForPropertyAccess}");
EmitObjectInit(effectivePropertyType, tempVarName, InitializationKind.AssignmentWithNullCheck);
_writer.WriteLine($@"{Literal.BindCore}({expressionForConfigArg}, ref {tempVarName});");
_writer.WriteLine($@"{Literal.BindCore}({bindCoreConfigArg}, ref {tempVarName});");

if (canSet)
{
Expand All @@ -543,10 +550,12 @@ private void EmitBindCoreCallForProperty(
{
Debug.Assert(canSet);
EmitObjectInit(effectivePropertyType, tempVarName, InitializationKind.Declaration);
_writer.WriteLine($@"{Literal.BindCore}({expressionForConfigArg}, ref {tempVarName});");
_writer.WriteLine($@"{Literal.BindCore}({bindCoreConfigArg}, ref {tempVarName});");
EmitAssignment(expressionForPropertyAccess, tempVarName);
}

_writer.WriteBlockEnd();

_privateBindCoreMethodGen_QueuedTypes.Enqueue(effectivePropertyType);
}

Expand Down Expand Up @@ -609,7 +618,7 @@ private void EmitObjectInit(TypeSpec type, string expressionForMemberAccess, Ini
{
return;
}
else if (type is CollectionSpec { ConcreteType: { } concreteType})
else if (type is CollectionSpec { ConcreteType: { } concreteType })
{
displayString = concreteType.DisplayString;
}
Expand Down Expand Up @@ -639,6 +648,16 @@ private void EmitCastToIConfigurationSection()
_writer.WriteBlockEnd();
}

private void EmitIConfigurationHasChildrenHelperMethod()
{
_writer.WriteBlockStart($"public static bool {Literal.HasChildren}({GlobalName.IConfiguration} {Literal.configuration})");
_writer.WriteBlockStart($"foreach ({GlobalName.IConfigurationSection} {Literal.section} in {Literal.configuration}.{Literal.GetChildren}())");
_writer.WriteLine($"return true;");
_writer.WriteBlockEnd();
_writer.WriteLine($"return false;");
_writer.WriteBlockEnd();
}

private void EmitVarDeclaration(TypeSpec type, string varName) => _writer.WriteLine($"{type.DisplayString} {varName};");

private void EmitAssignment(string lhsSource, string rhsSource) => _writer.WriteLine($"{lhsSource} = {rhsSource};");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;

Expand Down Expand Up @@ -61,7 +60,10 @@ private static class Literal
public const string Get = nameof(Get);
public const string GetChildren = nameof(GetChildren);
public const string GetSection = nameof(GetSection);
public const string HasChildren = nameof(HasChildren);
public const string HasValue = nameof(HasValue);
public const string IConfiguration = nameof(IConfiguration);
public const string IConfigurationSection = nameof(IConfigurationSection);
public const string Length = nameof(Length);
public const string Parse = nameof(Parse);
public const string Resize = nameof(Resize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,18 @@ private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, Location? loc

private ObjectSpec? CreateObjectSpec(INamedTypeSymbol type, Location? location)
{
Debug.Assert(!_createdSpecs.ContainsKey(type));

// Add spec to cache before traversing properties to avoid stack overflow.

if (!CanConstructObject(type, location))
{
_createdSpecs.Add(type, null);
return null;
}
ObjectSpec objectSpec = new(type) { Location = location, ConstructionStrategy = ConstructionStrategy.ParameterlessConstructor };
_createdSpecs.Add(type, objectSpec);

List<PropertySpec> properties = new();
INamedTypeSymbol current = type;
while (current != null)
{
Expand All @@ -385,7 +391,7 @@ private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, Location? loc
PropertySpec spec = new PropertySpec(property) { Type = propertyTypeSpec, ConfigurationKeyName = configKeyName };
if (spec.CanGet || spec.CanSet)
{
properties.Add(spec);
objectSpec.Properties.Add(spec);
}
}
}
Expand All @@ -394,7 +400,7 @@ private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, Location? loc
current = current.BaseType;
}

return new ObjectSpec(type) { Location = location, Properties = properties, ConstructionStrategy = ConstructionStrategy.ParameterlessConstructor };
return objectSpec;
}

private bool IsCandidateEnumerable(INamedTypeSymbol type, out ITypeSymbol? elementType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ internal sealed record ObjectSpec : TypeSpec
{
public ObjectSpec(INamedTypeSymbol type) : base(type) { }
public override TypeSpecKind SpecKind => TypeSpecKind.Object;
public required List<PropertySpec> Properties { get; init; }
public List<PropertySpec> Properties { get; } = new();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1434,5 +1434,75 @@ public void EnsureCallingThePropertySetter()
Assert.Equal(401, options.HttpStatusCode); // exists in configuration and properly sets the property
Assert.Equal(2, options.OtherCode); // doesn't exist in configuration. the setter sets default value '2'
}

[Fact]
public void RecursiveTypeGraphs_DirectRef()
{
var data = @"{
""MyString"":""Hello"",
""MyClass"": {
""MyString"": ""World"",
""MyClass"": {
""MyString"": ""World"",
""MyClass"": null
}
}
}";

var configuration = new ConfigurationBuilder()
.AddJsonStream(TestStreamHelpers.StringToStream(data))
.Build();

var obj = configuration.Get<ClassWithDirectSelfReference>();
Assert.Equal("Hello", obj.MyString);

var nested = obj.MyClass;
Assert.Equal("World", nested.MyString);

var deeplyNested = nested.MyClass;
Assert.Equal("World", deeplyNested.MyString);
Assert.Null(deeplyNested.MyClass);
}

public class ClassWithDirectSelfReference
{
public string MyString { get; set; }
public ClassWithDirectSelfReference MyClass { get; set; }
}

[Fact]
public void RecursiveTypeGraphs_IndirectRef()
{
var data = @"{
""MyString"":""Hello"",
""MyList"": [{
""MyString"": ""World"",
""MyList"": [{
""MyString"": ""World"",
""MyClass"": null
}]
}]
}";

var configuration = new ConfigurationBuilder()
.AddJsonStream(TestStreamHelpers.StringToStream(data))
.Build();

var obj = configuration.Get<ClassWithIndirectSelfReference>();
Assert.Equal("Hello", obj.MyString);

var nested = obj.MyList[0];
Assert.Equal("World", nested.MyString);

var deeplyNested = nested.MyList[0];
Assert.Equal("World", deeplyNested.MyString);
Assert.Null(deeplyNested.MyList);
}

public class ClassWithIndirectSelfReference
{
public string MyString { get; set; }
public List<ClassWithIndirectSelfReference> MyList { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,32 @@ internal static class GeneratedConfigurationBinder
obj.MyInt = int.Parse(stringValue1);
}

System.Collections.Generic.List<int> temp2 = obj.MyList;
temp2 ??= new System.Collections.Generic.List<int>();
BindCore(configuration.GetSection("MyList"), ref temp2);
obj.MyList = temp2;
global::Microsoft.Extensions.Configuration.IConfigurationSection section2 = configuration.GetSection("MyList");
if (HasChildren(section2))
{
System.Collections.Generic.List<int> temp3 = obj.MyList;
temp3 ??= new System.Collections.Generic.List<int>();
BindCore(section2, ref temp3);
obj.MyList = temp3;
}

System.Collections.Generic.Dictionary<string, string> temp3 = obj.MyDictionary;
temp3 ??= new System.Collections.Generic.Dictionary<string, string>();
BindCore(configuration.GetSection("MyDictionary"), ref temp3);
obj.MyDictionary = temp3;
global::Microsoft.Extensions.Configuration.IConfigurationSection section4 = configuration.GetSection("MyDictionary");
if (HasChildren(section4))
{
System.Collections.Generic.Dictionary<string, string> temp5 = obj.MyDictionary;
temp5 ??= new System.Collections.Generic.Dictionary<string, string>();
BindCore(section4, ref temp5);
obj.MyDictionary = temp5;
}

System.Collections.Generic.Dictionary<string, Program.MyClass2> temp4 = obj.MyComplexDictionary;
temp4 ??= new System.Collections.Generic.Dictionary<string, Program.MyClass2>();
BindCore(configuration.GetSection("MyComplexDictionary"), ref temp4);
obj.MyComplexDictionary = temp4;
global::Microsoft.Extensions.Configuration.IConfigurationSection section6 = configuration.GetSection("MyComplexDictionary");
if (HasChildren(section6))
{
System.Collections.Generic.Dictionary<string, Program.MyClass2> temp7 = obj.MyComplexDictionary;
temp7 ??= new System.Collections.Generic.Dictionary<string, Program.MyClass2>();
BindCore(section6, ref temp7);
obj.MyComplexDictionary = temp7;
}

}

Expand All @@ -51,9 +63,9 @@ internal static class GeneratedConfigurationBinder
int element;
foreach (Microsoft.Extensions.Configuration.IConfigurationSection section in configuration.GetChildren())
{
if (section.Value is string stringValue5)
if (section.Value is string stringValue8)
{
element = int.Parse(stringValue5);
element = int.Parse(stringValue8);
obj.Add(element);
}
}
Expand All @@ -69,13 +81,13 @@ internal static class GeneratedConfigurationBinder
string key;
foreach (Microsoft.Extensions.Configuration.IConfigurationSection section in configuration.GetChildren())
{
if (section.Key is string stringValue6)
if (section.Key is string stringValue9)
{
key = stringValue6;
key = stringValue9;
string element;
if (section.Value is string stringValue7)
if (section.Value is string stringValue10)
{
element = stringValue7;
element = stringValue10;
obj[key] = element;
}
}
Expand All @@ -92,9 +104,9 @@ internal static class GeneratedConfigurationBinder
string key;
foreach (Microsoft.Extensions.Configuration.IConfigurationSection section in configuration.GetChildren())
{
if (section.Key is string stringValue8)
if (section.Key is string stringValue11)
{
key = stringValue8;
key = stringValue11;
if (obj.TryGetValue(key, out Program.MyClass2? element) && element is not null)
{
BindCore(section, ref element);
Expand All @@ -119,4 +131,12 @@ internal static class GeneratedConfigurationBinder

}

public static bool HasChildren(global::Microsoft.Extensions.Configuration.IConfiguration configuration)
{
foreach (global::Microsoft.Extensions.Configuration.IConfigurationSection section in configuration.GetChildren())
{
return true;
}
return false;
}
}
Loading

0 comments on commit 4bd7819

Please sign in to comment.