Skip to content

Commit

Permalink
Fix to #18147 - Where bool column needs to convert to equality when v…
Browse files Browse the repository at this point in the history
…alue converter is applied

Fix is to detect bool columns with value converters (upon initial translation) and apply comparison with constant true (with the same mapping). Also, we need to recognize this pattern during SqlExpression optimization, so that it' doesn't get simplified from 'a == true' to 'a'

Resolves #18147
  • Loading branch information
maumar committed Jan 24, 2020
1 parent 7e97f36 commit e33ea57
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,22 @@ protected override Expression VisitMember(MemberExpression memberExpression)
{
Check.NotNull(memberExpression, nameof(memberExpression));

return TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type);
return CompensateForValueConverter(
TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type));
}

private Expression CompensateForValueConverter(Expression result)
=> result != null
&& result.Type == typeof(bool)
&& result is KeyAccessExpression keyAccessExpression
&& keyAccessExpression.TypeMapping.Converter != null
? _sqlExpressionFactory.Equal(keyAccessExpression, _sqlExpressionFactory.Constant(true, keyAccessExpression.TypeMapping))
: result;

private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression)
{
source = source.UnwrapTypeConversion(out var convertedType);
Expand Down Expand Up @@ -162,15 +171,19 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
return TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null;
return CompensateForValueConverter(
TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null);
}

// EF Indexer property
if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName))
{
return TryBindMember(source, MemberIdentity.Create(propertyName), out var result) ? result : null;
return CompensateForValueConverter(
TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null);
}

if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,12 @@ private SqlExpression OptimizeComparison(
}

if (IsTrueOrFalse(right) is bool rightTrueFalseValue
&& !leftNullable)
&& !leftNullable
&& left.TypeMapping.Converter == null)
{
_nullable = leftNullable;

// only correct in 2-value logic
// only correct in 2-value logic and only if 'a' doesn't have value converter applied to it
// a == true -> a
// a == false -> !a
// a != true -> !a
Expand All @@ -706,11 +707,12 @@ private SqlExpression OptimizeComparison(
}

if (IsTrueOrFalse(left) is bool leftTrueFalseValue
&& !rightNullable)
&& !rightNullable
&& right.TypeMapping.Converter == null)
{
_nullable = rightNullable;

// only correct in 2-value logic
// only correct in 2-value logic and only if 'a' doesn't have value converter applied to it
// true == a -> a
// false == a -> !a
// true != a -> !a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,21 @@ protected override Expression VisitMember(MemberExpression memberExpression)
{
Check.NotNull(memberExpression, nameof(memberExpression));

return TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type);
}
return CompensateForValueConverter(
TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type));
}

private Expression CompensateForValueConverter(Expression result)
=> result != null
&& result.Type == typeof(bool)
&& result is ColumnExpression columnExpression
&& columnExpression.TypeMapping.Converter != null
? SqlExpressionFactory.Equal(columnExpression, SqlExpressionFactory.Constant(true, columnExpression.TypeMapping))
: result;

private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression)
{
Expand Down Expand Up @@ -369,7 +378,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
if (TryBindMember(source, MemberIdentity.Create(propertyName), out var result))
{
return result;
return CompensateForValueConverter(result);
}

throw new InvalidOperationException("EF.Property called with wrong property name.");
Expand All @@ -378,7 +387,10 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// EF Indexer property
if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName))
{
return TryBindMember(source, MemberIdentity.Create(propertyName), out var result) ? result : null;
return CompensateForValueConverter(
TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null);
}

// GroupBy Aggregate case
Expand Down
33 changes: 31 additions & 2 deletions test/EFCore.Cosmos.FunctionalTests/CustomConvertersCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class CustomConvertersCosmosTest : CustomConvertersTestBase<CustomConvert
public CustomConvertersCosmosTest(CustomConvertersCosmosFixture fixture)
: base(fixture)
{
Fixture.TestSqlLoggerFactory.Clear();
}

public override void Can_perform_query_with_max_length()
Expand Down Expand Up @@ -137,16 +138,43 @@ public override void Value_conversion_is_appropriately_used_for_left_join_condit
base.Value_conversion_is_appropriately_used_for_left_join_condition();
}

[ConditionalFact(Skip = "Issue #18147")]
[ConditionalFact]
public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used();

AssertSql(
@"SELECT c
FROM root c
WHERE (c[""Discriminator""] IN (""Blog"", ""RssBlog"") AND (c[""IsVisible""] = ""Y""))");
}

public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty();

AssertSql(
@"SELECT c
FROM root c
WHERE (c[""Discriminator""] IN (""Blog"", ""RssBlog"") AND (c[""IsVisible""] = ""Y""))");
}

public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer();

AssertSql(
@"SELECT c
FROM root c
WHERE (c[""Discriminator""] IN (""Blog"", ""RssBlog"") AND NOT((c[""IndexerVisible""] = ""Aye"")))");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

public class CustomConvertersCosmosFixture : CustomConvertersFixtureBase
{
protected override ITestStoreFactory TestStoreFactory => CosmosTestStoreFactory.Instance;

public override bool StrictEquality => true;

public override int IntegerPrecision => 53;
Expand All @@ -162,6 +190,7 @@ public class CustomConvertersCosmosFixture : CustomConvertersFixtureBase
public override bool SupportsDecimalComparisons => true;

public override DateTime DefaultDateTime => new DateTime();
public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;

protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext context)
{
Expand Down
57 changes: 53 additions & 4 deletions test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,26 @@ public virtual void Where_bool_gets_converted_to_equality_when_value_conversion_
Assert.Equal("http://blog.com", result.Url);
}

[ConditionalFact]
public virtual void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty()
{
using var context = CreateContext();
var query = context.Set<Blog>().Where(b => EF.Property<bool>(b, "IsVisible")).ToList();

var result = Assert.Single(query);
Assert.Equal("http://blog.com", result.Url);
}

[ConditionalFact]
public virtual void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer()
{
using var context = CreateContext();
var query = context.Set<Blog>().Where(b => !(bool)b["IndexerVisible"]).ToList();

var result = Assert.Single(query);
Assert.Equal("http://blog.com", result.Url);
}

[ConditionalFact]
public virtual void Value_conversion_with_property_named_value()
{
Expand All @@ -491,10 +511,35 @@ public virtual void Value_conversion_with_property_named_value()

protected class Blog
{
private bool _indexerVisible;

public int BlogId { get; set; }
public string Url { get; set; }
public bool IsVisible { get; set; }
public List<Post> Posts { get; set; }

public object this[string name]
{
get
{
if (!string.Equals(name, "IndexerVisible", StringComparison.Ordinal))
{
throw new InvalidOperationException($"Indexed property with key {name} is not defined on {nameof(Blog)}.");
}

return _indexerVisible;
}

set
{
if (!string.Equals(name, "IndexerVisible", StringComparison.Ordinal))
{
throw new InvalidOperationException($"Indexed property with key {name} is not defined on {nameof(Blog)}.");
}

_indexerVisible = (bool)value;
}
}
}

protected class RssBlog : Blog
Expand Down Expand Up @@ -952,12 +997,15 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
{
b.Property(e => e.Url).HasConversion(urlConverter);
b.Property(e => e.IsVisible).HasConversion(new BoolToStringConverter("N", "Y"));
b.IndexedProperty(typeof(bool), "IndexerVisible").HasConversion(new BoolToStringConverter("Nay", "Aye"));

b.HasData(
new Blog
new
{
BlogId = 1,
Url = "http://blog.com",
IsVisible = true
IsVisible = true,
IndexerVisible = false,
});
});

Expand All @@ -966,12 +1014,13 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
{
b.Property(e => e.RssUrl).HasConversion(urlConverter);
b.HasData(
new RssBlog
new
{
BlogId = 2,
Url = "http://rssblog.com",
RssUrl = "http://rssblog.com/rss",
IsVisible = false
IsVisible = false,
IndexerVisible = true,
});
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public class CustomConvertersSqlServerTest : CustomConvertersTestBase<CustomConv
public CustomConvertersSqlServerTest(CustomConvertersSqlServerFixture fixture)
: base(fixture)
{
Fixture.TestSqlLoggerFactory.Clear();
}

[ConditionalFact]
Expand All @@ -37,6 +38,7 @@ public virtual void Columns_have_expected_data_types()
BinaryKeyDataType.Id ---> [varbinary] [MaxLength = 900]
Blog.BlogId ---> [int] [Precision = 10 Scale = 0]
Blog.Discriminator ---> [nvarchar] [MaxLength = -1]
Blog.IndexerVisible ---> [nvarchar] [MaxLength = 3]
Blog.IsVisible ---> [nvarchar] [MaxLength = 1]
Blog.RssUrl ---> [nullable nvarchar] [MaxLength = -1]
Blog.Url ---> [nullable nvarchar] [MaxLength = -1]
Expand Down Expand Up @@ -194,6 +196,68 @@ public virtual void Columns_have_expected_data_types()
Assert.Equal(expected, actual, ignoreLineEndingDifferences: true);
}

[ConditionalFact]
public override void Value_conversion_is_appropriately_used_for_join_condition()
{
base.Value_conversion_is_appropriately_used_for_join_condition();

AssertSql(
@"@__blogId_0='1'
SELECT [b].[Url]
FROM [Blog] AS [b]
INNER JOIN [Post] AS [p] ON (([b].[BlogId] = [p].[BlogId]) AND ([b].[IsVisible] = N'Y')) AND ([b].[BlogId] = @__blogId_0)
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

[ConditionalFact]
public override void Value_conversion_is_appropriately_used_for_left_join_condition()
{
base.Value_conversion_is_appropriately_used_for_left_join_condition();

AssertSql(
@"@__blogId_0='1'
SELECT [b].[Url]
FROM [Blog] AS [b]
LEFT JOIN [Post] AS [p] ON (([b].[BlogId] = [p].[BlogId]) AND ([b].[IsVisible] = N'Y')) AND ([b].[BlogId] = @__blogId_0)
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

[ConditionalFact]
public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used();

AssertSql(
@"SELECT [b].[BlogId], [b].[Discriminator], [b].[IndexerVisible], [b].[IsVisible], [b].[Url], [b].[RssUrl]
FROM [Blog] AS [b]
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty();

AssertSql(
@"SELECT [b].[BlogId], [b].[Discriminator], [b].[IndexerVisible], [b].[IsVisible], [b].[Url], [b].[RssUrl]
FROM [Blog] AS [b]
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer();

AssertSql(
@"SELECT [b].[BlogId], [b].[Discriminator], [b].[IndexerVisible], [b].[IsVisible], [b].[Url], [b].[RssUrl]
FROM [Blog] AS [b]
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IndexerVisible] <> N'Aye')");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

public class CustomConvertersSqlServerFixture : CustomConvertersFixtureBase
{
public override bool StrictEquality => true;
Expand All @@ -205,7 +269,7 @@ public class CustomConvertersSqlServerFixture : CustomConvertersFixtureBase
public override bool SupportsLargeStringComparisons => true;

protected override ITestStoreFactory TestStoreFactory => SqlServerTestStoreFactory.Instance;

public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;
public override bool SupportsBinaryKeys => true;

public override bool SupportsDecimalComparisons => true;
Expand Down
Loading

0 comments on commit e33ea57

Please sign in to comment.