Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to #18147 - Where bool column needs to convert to equality when value converter is applied #19689

Merged
merged 1 commit into from
Feb 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,22 @@ protected override Expression VisitMember(MemberExpression memberExpression)
{
Check.NotNull(memberExpression, nameof(memberExpression));

return TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type);
return CompensateForValueConverter(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should wrap only result of TryBindMember. MemberTranslatorProvider should take care of itself.

TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type));
}

private Expression CompensateForValueConverter(Expression result)
=> result != null
&& result.Type == typeof(bool)
&& result is KeyAccessExpression keyAccessExpression
&& keyAccessExpression.TypeMapping.Converter != null
? _sqlExpressionFactory.Equal(keyAccessExpression, _sqlExpressionFactory.Constant(true, keyAccessExpression.TypeMapping))
: result;

private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression)
{
source = source.UnwrapTypeConversion(out var convertedType);
Expand Down Expand Up @@ -162,15 +171,19 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
return TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null;
return CompensateForValueConverter(
TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null);
}

// EF Indexer property
if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName))
{
return TryBindMember(source, MemberIdentity.Create(propertyName), out var result) ? result : null;
return CompensateForValueConverter(
TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null);
}

if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,12 @@ private SqlExpression OptimizeComparison(
}

if (IsTrueOrFalse(right) is bool rightTrueFalseValue
&& !leftNullable)
&& !leftNullable
&& left.TypeMapping.Converter == null)
{
_nullable = leftNullable;

// only correct in 2-value logic
// only correct in 2-value logic and only if 'a' doesn't have value converter applied to it
// a == true -> a
// a == false -> !a
// a != true -> !a
Expand All @@ -706,11 +707,12 @@ private SqlExpression OptimizeComparison(
}

if (IsTrueOrFalse(left) is bool leftTrueFalseValue
&& !rightNullable)
&& !rightNullable
&& right.TypeMapping.Converter == null)
{
_nullable = rightNullable;

// only correct in 2-value logic
// only correct in 2-value logic and only if 'a' doesn't have value converter applied to it
// true == a -> a
// false == a -> !a
// true != a -> !a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,21 @@ protected override Expression VisitMember(MemberExpression memberExpression)
{
Check.NotNull(memberExpression, nameof(memberExpression));

return TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type);
}
return CompensateForValueConverter(
TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.

? result
: TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type));
}

private Expression CompensateForValueConverter(Expression result)
=> result != null
&& result.Type == typeof(bool)
&& result is ColumnExpression columnExpression
&& columnExpression.TypeMapping.Converter != null
? SqlExpressionFactory.Equal(columnExpression, SqlExpressionFactory.Constant(true, columnExpression.TypeMapping))
: result;

private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression)
{
Expand Down Expand Up @@ -369,7 +378,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
if (TryBindMember(source, MemberIdentity.Create(propertyName), out var result))
{
return result;
return CompensateForValueConverter(result);
}

throw new InvalidOperationException("EF.Property called with wrong property name.");
Expand All @@ -378,7 +387,10 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// EF Indexer property
if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName))
{
return TryBindMember(source, MemberIdentity.Create(propertyName), out var result) ? result : null;
return CompensateForValueConverter(
TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null);
}

// GroupBy Aggregate case
Expand Down
33 changes: 31 additions & 2 deletions test/EFCore.Cosmos.FunctionalTests/CustomConvertersCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class CustomConvertersCosmosTest : CustomConvertersTestBase<CustomConvert
public CustomConvertersCosmosTest(CustomConvertersCosmosFixture fixture)
: base(fixture)
{
Fixture.TestSqlLoggerFactory.Clear();
}

public override void Can_perform_query_with_max_length()
Expand Down Expand Up @@ -137,16 +138,43 @@ public override void Value_conversion_is_appropriately_used_for_left_join_condit
base.Value_conversion_is_appropriately_used_for_left_join_condition();
}

[ConditionalFact(Skip = "Issue #18147")]
[ConditionalFact]
public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used();

AssertSql(
@"SELECT c
FROM root c
WHERE (c[""Discriminator""] IN (""Blog"", ""RssBlog"") AND (c[""IsVisible""] = ""Y""))");
}

public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty();

AssertSql(
@"SELECT c
FROM root c
WHERE (c[""Discriminator""] IN (""Blog"", ""RssBlog"") AND (c[""IsVisible""] = ""Y""))");
}

public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer();

AssertSql(
@"SELECT c
FROM root c
WHERE (c[""Discriminator""] IN (""Blog"", ""RssBlog"") AND NOT((c[""IndexerVisible""] = ""Aye"")))");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

public class CustomConvertersCosmosFixture : CustomConvertersFixtureBase
{
protected override ITestStoreFactory TestStoreFactory => CosmosTestStoreFactory.Instance;

public override bool StrictEquality => true;

public override int IntegerPrecision => 53;
Expand All @@ -162,6 +190,7 @@ public class CustomConvertersCosmosFixture : CustomConvertersFixtureBase
public override bool SupportsDecimalComparisons => true;

public override DateTime DefaultDateTime => new DateTime();
public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;

protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext context)
{
Expand Down
57 changes: 53 additions & 4 deletions test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,26 @@ public virtual void Where_bool_gets_converted_to_equality_when_value_conversion_
Assert.Equal("http://blog.com", result.Url);
}

[ConditionalFact]
public virtual void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty()
{
using var context = CreateContext();
var query = context.Set<Blog>().Where(b => EF.Property<bool>(b, "IsVisible")).ToList();

var result = Assert.Single(query);
Assert.Equal("http://blog.com", result.Url);
}

[ConditionalFact]
public virtual void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer()
{
using var context = CreateContext();
var query = context.Set<Blog>().Where(b => !(bool)b["IndexerVisible"]).ToList();

var result = Assert.Single(query);
Assert.Equal("http://blog.com", result.Url);
}

[ConditionalFact]
public virtual void Value_conversion_with_property_named_value()
{
Expand All @@ -491,10 +511,35 @@ public virtual void Value_conversion_with_property_named_value()

protected class Blog
{
private bool _indexerVisible;

public int BlogId { get; set; }
public string Url { get; set; }
public bool IsVisible { get; set; }
public List<Post> Posts { get; set; }

public object this[string name]
{
get
{
if (!string.Equals(name, "IndexerVisible", StringComparison.Ordinal))
{
throw new InvalidOperationException($"Indexed property with key {name} is not defined on {nameof(Blog)}.");
}

return _indexerVisible;
}

set
{
if (!string.Equals(name, "IndexerVisible", StringComparison.Ordinal))
{
throw new InvalidOperationException($"Indexed property with key {name} is not defined on {nameof(Blog)}.");
}

_indexerVisible = (bool)value;
}
}
}

protected class RssBlog : Blog
Expand Down Expand Up @@ -952,12 +997,15 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
{
b.Property(e => e.Url).HasConversion(urlConverter);
b.Property(e => e.IsVisible).HasConversion(new BoolToStringConverter("N", "Y"));
b.IndexedProperty(typeof(bool), "IndexerVisible").HasConversion(new BoolToStringConverter("Nay", "Aye"));

b.HasData(
new Blog
new
{
BlogId = 1,
Url = "http://blog.com",
IsVisible = true
IsVisible = true,
IndexerVisible = false,
});
});

Expand All @@ -966,12 +1014,13 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
{
b.Property(e => e.RssUrl).HasConversion(urlConverter);
b.HasData(
new RssBlog
new
{
BlogId = 2,
Url = "http://rssblog.com",
RssUrl = "http://rssblog.com/rss",
IsVisible = false
IsVisible = false,
IndexerVisible = true,
});
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public class CustomConvertersSqlServerTest : CustomConvertersTestBase<CustomConv
public CustomConvertersSqlServerTest(CustomConvertersSqlServerFixture fixture)
: base(fixture)
{
Fixture.TestSqlLoggerFactory.Clear();
}

[ConditionalFact]
Expand All @@ -37,6 +38,7 @@ public virtual void Columns_have_expected_data_types()
BinaryKeyDataType.Id ---> [varbinary] [MaxLength = 900]
Blog.BlogId ---> [int] [Precision = 10 Scale = 0]
Blog.Discriminator ---> [nvarchar] [MaxLength = -1]
Blog.IndexerVisible ---> [nvarchar] [MaxLength = 3]
Blog.IsVisible ---> [nvarchar] [MaxLength = 1]
Blog.RssUrl ---> [nullable nvarchar] [MaxLength = -1]
Blog.Url ---> [nullable nvarchar] [MaxLength = -1]
Expand Down Expand Up @@ -194,6 +196,68 @@ public virtual void Columns_have_expected_data_types()
Assert.Equal(expected, actual, ignoreLineEndingDifferences: true);
}

[ConditionalFact]
public override void Value_conversion_is_appropriately_used_for_join_condition()
{
base.Value_conversion_is_appropriately_used_for_join_condition();

AssertSql(
@"@__blogId_0='1'

SELECT [b].[Url]
FROM [Blog] AS [b]
INNER JOIN [Post] AS [p] ON (([b].[BlogId] = [p].[BlogId]) AND ([b].[IsVisible] = N'Y')) AND ([b].[BlogId] = @__blogId_0)
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

[ConditionalFact]
public override void Value_conversion_is_appropriately_used_for_left_join_condition()
{
base.Value_conversion_is_appropriately_used_for_left_join_condition();

AssertSql(
@"@__blogId_0='1'

SELECT [b].[Url]
FROM [Blog] AS [b]
LEFT JOIN [Post] AS [p] ON (([b].[BlogId] = [p].[BlogId]) AND ([b].[IsVisible] = N'Y')) AND ([b].[BlogId] = @__blogId_0)
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

[ConditionalFact]
public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used();

AssertSql(
@"SELECT [b].[BlogId], [b].[Discriminator], [b].[IndexerVisible], [b].[IsVisible], [b].[Url], [b].[RssUrl]
FROM [Blog] AS [b]
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_EFProperty();

AssertSql(
@"SELECT [b].[BlogId], [b].[Discriminator], [b].[IndexerVisible], [b].[IsVisible], [b].[Url], [b].[RssUrl]
FROM [Blog] AS [b]
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used_using_indexer();

AssertSql(
@"SELECT [b].[BlogId], [b].[Discriminator], [b].[IndexerVisible], [b].[IsVisible], [b].[Url], [b].[RssUrl]
FROM [Blog] AS [b]
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IndexerVisible] <> N'Aye')");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

public class CustomConvertersSqlServerFixture : CustomConvertersFixtureBase
{
public override bool StrictEquality => true;
Expand All @@ -205,7 +269,7 @@ public class CustomConvertersSqlServerFixture : CustomConvertersFixtureBase
public override bool SupportsLargeStringComparisons => true;

protected override ITestStoreFactory TestStoreFactory => SqlServerTestStoreFactory.Instance;

public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;
public override bool SupportsBinaryKeys => true;

public override bool SupportsDecimalComparisons => true;
Expand Down
Loading