Skip to content

Commit

Permalink
Fix to #18147 - Where bool column needs to convert to equality when v…
Browse files Browse the repository at this point in the history
…alue converter is applied

Fix is to detect bool columns with value converters (upon initial translation) and apply comparison with constant true (with the same mapping). Also, we need to recognize this pattern during SqlExpression optimization, so that it' doesn't get simplified from 'a == true' to 'a'

Resolves #18147
  • Loading branch information
maumar committed Jan 24, 2020
1 parent 7e97f36 commit 6b3dbdd
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,18 @@ protected override Expression VisitMember(MemberExpression memberExpression)
{
Check.NotNull(memberExpression, nameof(memberExpression));

return TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
var boundResult = TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type);

return boundResult != null
&& boundResult.Type == typeof(bool)
&& boundResult is KeyAccessExpression keyAccessExpression
&& keyAccessExpression.TypeMapping.Converter != null
? _sqlExpressionFactory.Equal(keyAccessExpression, _sqlExpressionFactory.Constant(true, keyAccessExpression.TypeMapping))
: boundResult;
}

private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,12 @@ private SqlExpression OptimizeComparison(
}

if (IsTrueOrFalse(right) is bool rightTrueFalseValue
&& !leftNullable)
&& !leftNullable
&& left.TypeMapping.Converter == null)
{
_nullable = leftNullable;

// only correct in 2-value logic
// only correct in 2-value logic and only if 'a' doesn't have value converter applied to it
// a == true -> a
// a == false -> !a
// a != true -> !a
Expand All @@ -706,11 +707,12 @@ private SqlExpression OptimizeComparison(
}

if (IsTrueOrFalse(left) is bool leftTrueFalseValue
&& !rightNullable)
&& !rightNullable
&& right.TypeMapping.Converter == null)
{
_nullable = rightNullable;

// only correct in 2-value logic
// only correct in 2-value logic and only if 'a' doesn't have value converter applied to it
// true == a -> a
// false == a -> !a
// true != a -> !a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,18 @@ protected override Expression VisitMember(MemberExpression memberExpression)
{
Check.NotNull(memberExpression, nameof(memberExpression));

return TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
var boundResult = TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type);

return boundResult != null
&& boundResult.Type == typeof(bool)
&& boundResult is ColumnExpression columnExpression
&& columnExpression.TypeMapping.Converter != null
? SqlExpressionFactory.Equal(columnExpression, SqlExpressionFactory.Constant(true, columnExpression.TypeMapping))
: boundResult;
}

private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression)
Expand Down
13 changes: 11 additions & 2 deletions test/EFCore.Cosmos.FunctionalTests/CustomConvertersCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class CustomConvertersCosmosTest : CustomConvertersTestBase<CustomConvert
public CustomConvertersCosmosTest(CustomConvertersCosmosFixture fixture)
: base(fixture)
{
Fixture.TestSqlLoggerFactory.Clear();
}

public override void Can_perform_query_with_max_length()
Expand Down Expand Up @@ -137,16 +138,23 @@ public override void Value_conversion_is_appropriately_used_for_left_join_condit
base.Value_conversion_is_appropriately_used_for_left_join_condition();
}

[ConditionalFact(Skip = "Issue #18147")]
[ConditionalFact]
public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used();

AssertSql(
@"SELECT c
FROM root c
WHERE (c[""Discriminator""] IN (""Blog"", ""RssBlog"") AND (c[""IsVisible""] = ""Y""))");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

public class CustomConvertersCosmosFixture : CustomConvertersFixtureBase
{
protected override ITestStoreFactory TestStoreFactory => CosmosTestStoreFactory.Instance;

public override bool StrictEquality => true;

public override int IntegerPrecision => 53;
Expand All @@ -162,6 +170,7 @@ public class CustomConvertersCosmosFixture : CustomConvertersFixtureBase
public override bool SupportsDecimalComparisons => true;

public override DateTime DefaultDateTime => new DateTime();
public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;

protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public class CustomConvertersSqlServerTest : CustomConvertersTestBase<CustomConv
public CustomConvertersSqlServerTest(CustomConvertersSqlServerFixture fixture)
: base(fixture)
{
Fixture.TestSqlLoggerFactory.Clear();
}

[ConditionalFact]
Expand Down Expand Up @@ -194,6 +195,48 @@ public virtual void Columns_have_expected_data_types()
Assert.Equal(expected, actual, ignoreLineEndingDifferences: true);
}

[ConditionalFact]
public override void Value_conversion_is_appropriately_used_for_join_condition()
{
base.Value_conversion_is_appropriately_used_for_join_condition();

AssertSql(
@"@__blogId_0='1'
SELECT [b].[Url]
FROM [Blog] AS [b]
INNER JOIN [Post] AS [p] ON (([b].[BlogId] = [p].[BlogId]) AND ([b].[IsVisible] = N'Y')) AND ([b].[BlogId] = @__blogId_0)
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

[ConditionalFact]
public override void Value_conversion_is_appropriately_used_for_left_join_condition()
{
base.Value_conversion_is_appropriately_used_for_left_join_condition();

AssertSql(
@"@__blogId_0='1'
SELECT [b].[Url]
FROM [Blog] AS [b]
LEFT JOIN [Post] AS [p] ON (([b].[BlogId] = [p].[BlogId]) AND ([b].[IsVisible] = N'Y')) AND ([b].[BlogId] = @__blogId_0)
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

[ConditionalFact]
public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used();

AssertSql(
@"SELECT [b].[BlogId], [b].[Discriminator], [b].[IsVisible], [b].[Url], [b].[RssUrl]
FROM [Blog] AS [b]
WHERE [b].[Discriminator] IN (N'Blog', N'RssBlog') AND ([b].[IsVisible] = N'Y')");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

public class CustomConvertersSqlServerFixture : CustomConvertersFixtureBase
{
public override bool StrictEquality => true;
Expand All @@ -205,7 +248,7 @@ public class CustomConvertersSqlServerFixture : CustomConvertersFixtureBase
public override bool SupportsLargeStringComparisons => true;

protected override ITestStoreFactory TestStoreFactory => SqlServerTestStoreFactory.Instance;

public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;
public override bool SupportsBinaryKeys => true;

public override bool SupportsDecimalComparisons => true;
Expand Down
31 changes: 28 additions & 3 deletions test/EFCore.Sqlite.FunctionalTests/CustomConvertersSqliteTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,56 @@ public class CustomConvertersSqliteTest : CustomConvertersTestBase<CustomConvert
public CustomConvertersSqliteTest(CustomConvertersSqliteFixture fixture)
: base(fixture)
{
Fixture.TestSqlLoggerFactory.Clear();
}

// Disabled: SQLite database is case-sensitive
public override void Can_insert_and_read_back_with_case_insensitive_string_key()
{
}

[ConditionalFact(Skip = "Issue#18147")]
[ConditionalFact]
public override void Value_conversion_is_appropriately_used_for_join_condition()
{
base.Value_conversion_is_appropriately_used_for_join_condition();

AssertSql(
@"@__blogId_0='1' (DbType = String)
SELECT ""b"".""Url""
FROM ""Blog"" AS ""b""
INNER JOIN ""Post"" AS ""p"" ON ((""b"".""BlogId"" = ""p"".""BlogId"") AND (""b"".""IsVisible"" = 'Y')) AND (""b"".""BlogId"" = @__blogId_0)
WHERE ""b"".""Discriminator"" IN ('Blog', 'RssBlog') AND (""b"".""IsVisible"" = 'Y')");
}

[ConditionalFact(Skip = "Issue#18147")]
[ConditionalFact]
public override void Value_conversion_is_appropriately_used_for_left_join_condition()
{
base.Value_conversion_is_appropriately_used_for_left_join_condition();

AssertSql(
@"@__blogId_0='1' (DbType = String)
SELECT ""b"".""Url""
FROM ""Blog"" AS ""b""
LEFT JOIN ""Post"" AS ""p"" ON ((""b"".""BlogId"" = ""p"".""BlogId"") AND (""b"".""IsVisible"" = 'Y')) AND (""b"".""BlogId"" = @__blogId_0)
WHERE ""b"".""Discriminator"" IN ('Blog', 'RssBlog') AND (""b"".""IsVisible"" = 'Y')");
}

[ConditionalFact(Skip = "Issue#18147")]
[ConditionalFact]
public override void Where_bool_gets_converted_to_equality_when_value_conversion_is_used()
{
base.Where_bool_gets_converted_to_equality_when_value_conversion_is_used();

AssertSql(
@"SELECT ""b"".""BlogId"", ""b"".""Discriminator"", ""b"".""IsVisible"", ""b"".""Url"", ""b"".""RssUrl""
FROM ""Blog"" AS ""b""
WHERE ""b"".""Discriminator"" IN ('Blog', 'RssBlog') AND (""b"".""IsVisible"" = 'Y')");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

public class CustomConvertersSqliteFixture : CustomConvertersFixtureBase
{
public override bool StrictEquality => false;
Expand Down

0 comments on commit 6b3dbdd

Please sign in to comment.