From 717f82727cc318956a3814efb149ea9208993e9d Mon Sep 17 00:00:00 2001 From: AndriySvyryd Date: Mon, 1 Jul 2019 16:01:58 -0700 Subject: [PATCH] Add support for eager loaded navigations --- ...osShapedQueryCompilingExpressionVisitor.cs | 132 +++---- ...ionalProjectionBindingExpressionVisitor.cs | 77 ++-- ...aperExpressionDedupingExpressionVisitor.cs | 9 +- .../NavigationExpansion/IncludeHelpers.cs | 2 +- ...terializeCollectionNavigationExpression.cs | 2 - .../NavigationExpansion/NavigationExpander.cs | 31 -- .../NavigationExpansionHelpers.cs | 22 +- ...odeExpansionMode.cs => NavigationState.cs} | 17 +- .../NavigationExpansion/NavigationTreeNode.cs | 60 +-- .../NavigationTreeNodeIncludeMode.cs | 28 -- .../Visitors/IncludeApplyingVisitor.cs | 110 ------ .../Visitors/NavigationExpandingVisitor.cs | 40 +- .../NavigationExpandingVisitor_MethodCall.cs | 346 +++++++++--------- .../NavigationExpansionReducingVisitor.cs | 178 ++++----- .../NavigationPropertyBindingVisitor.cs | 27 +- .../Visitors/PendingIncludeFindingVisitor.cs | 110 ------ .../Visitors/PendingSelectorIncludeVisitor.cs | 160 ++++++++ ...ntityEqualityRewritingExpressionVisitor.cs | 4 +- .../QueryOptimizingExpressionVisitor.cs | 5 +- ...QueryOptimizingExpressionVisitorFactory.cs | 4 +- .../Query/OwnedQueryCosmosTest.cs | 17 +- .../TestUtilities/CosmosTestStore.cs | 4 +- .../TestUtilities/TestEnvironment.cs | 8 +- .../OptimisticConcurrencyTestBase.cs | 1 - 24 files changed, 646 insertions(+), 748 deletions(-) delete mode 100644 src/EFCore/Query/NavigationExpansion/NavigationExpander.cs rename src/EFCore/Query/NavigationExpansion/{NavigationTreeNodeExpansionMode.cs => NavigationState.cs} (50%) delete mode 100644 src/EFCore/Query/NavigationExpansion/NavigationTreeNodeIncludeMode.cs delete mode 100644 src/EFCore/Query/NavigationExpansion/Visitors/IncludeApplyingVisitor.cs delete mode 100644 src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs create mode 100644 src/EFCore/Query/NavigationExpansion/Visitors/PendingSelectorIncludeVisitor.cs diff --git a/src/EFCore.Cosmos/Query/Pipeline/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Pipeline/CosmosShapedQueryCompilingExpressionVisitor.cs index b98b0ab41cb..b401e1c04ff 100644 --- a/src/EFCore.Cosmos/Query/Pipeline/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Pipeline/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -18,6 +18,7 @@ using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Query.Internal; +using Microsoft.EntityFrameworkCore.Query.NavigationExpansion; using Microsoft.EntityFrameworkCore.Query.Pipeline; using Microsoft.EntityFrameworkCore.Storage; using Newtonsoft.Json.Linq; @@ -169,77 +170,80 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp protected override Expression VisitExtension(Expression extensionExpression) { - if (extensionExpression is ProjectionBindingExpression projectionBindingExpression) + switch (extensionExpression) { - var projectionIndex = (int)GetProjectionIndex(projectionBindingExpression); - var projection = _selectExpression.Projection[projectionIndex]; - - return CreateGetStoreValueExpression( - _jObjectParameter, - projection.Alias, - ((SqlExpression)projection.Expression).TypeMapping, - projectionBindingExpression.Type); - } - - if (extensionExpression is EntityShaperExpression shaperExpression) - { - _currentEntityIndex++; - - var jObjectVariable = Expression.Variable(typeof(JObject), - "jObject" + _currentEntityIndex); - var variables = new List { jObjectVariable }; + case ProjectionBindingExpression projectionBindingExpression: + { + var projectionIndex = (int)GetProjectionIndex(projectionBindingExpression); + var projection = _selectExpression.Projection[projectionIndex]; + + return CreateGetStoreValueExpression( + _jObjectParameter, + projection.Alias, + ((SqlExpression)projection.Expression).TypeMapping, + projectionBindingExpression.Type); + } - var expressions = new List(); + case EntityShaperExpression shaperExpression: + { + _currentEntityIndex++; - if (shaperExpression.ParentNavigation == null) - { - var projectionIndex = (int)GetProjectionIndex((ProjectionBindingExpression)shaperExpression.ValueBufferExpression); - var projection = _selectExpression.Projection[projectionIndex]; + var jObjectVariable = Expression.Variable(typeof(JObject), + "jObject" + _currentEntityIndex); + var variables = new List { jObjectVariable }; - expressions.Add( - Expression.Assign( - jObjectVariable, - Expression.TypeAs( - CreateReadJTokenExpression(_jObjectParameter, projection.Alias), - typeof(JObject)))); + var expressions = new List(); - shaperExpression = shaperExpression.Update( - shaperExpression.ValueBufferExpression, - GetNestedShapers(shaperExpression.EntityType, shaperExpression.ValueBufferExpression)); - } - else - { - var methodCallExpression = (MethodCallExpression)shaperExpression.ValueBufferExpression; - Debug.Assert(methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod); - - var navigation = (INavigation)((ConstantExpression)methodCallExpression.Arguments[2]).Value; - - expressions.Add( - Expression.Assign( - jObjectVariable, - Expression.TypeAs( - CreateReadJTokenExpression(_jObjectParameter, navigation.GetTargetType().GetCosmosContainingPropertyName()), - typeof(JObject)))); - } + if (shaperExpression.ParentNavigation == null) + { + var projectionIndex = (int)GetProjectionIndex((ProjectionBindingExpression)shaperExpression.ValueBufferExpression); + var projection = _selectExpression.Projection[projectionIndex]; + + expressions.Add( + Expression.Assign( + jObjectVariable, + Expression.TypeAs( + CreateReadJTokenExpression(_jObjectParameter, projection.Alias), + typeof(JObject)))); + + shaperExpression = shaperExpression.Update( + shaperExpression.ValueBufferExpression, + GetNestedShapers(shaperExpression.EntityType, shaperExpression.ValueBufferExpression)); + } + else + { + var methodCallExpression = (MethodCallExpression)shaperExpression.ValueBufferExpression; + Debug.Assert(methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod); + + var navigation = (INavigation)((ConstantExpression)methodCallExpression.Arguments[2]).Value; + + expressions.Add( + Expression.Assign( + jObjectVariable, + Expression.TypeAs( + CreateReadJTokenExpression(_jObjectParameter, navigation.GetTargetType().GetCosmosContainingPropertyName()), + typeof(JObject)))); + } - var parentJObject = _jObjectParameter; - _jObjectParameter = jObjectVariable; - expressions.Add(Expression.Condition( - Expression.Equal(jObjectVariable, Expression.Constant(null, jObjectVariable.Type)), - Expression.Constant(null, shaperExpression.Type), - Visit(_shapedQueryCompilingExpressionVisitor.InjectEntityMaterializer(shaperExpression)))); - _jObjectParameter = parentJObject; - - return Expression.Block( - shaperExpression.Type, - variables, - expressions); - } + var parentJObject = _jObjectParameter; + _jObjectParameter = jObjectVariable; + expressions.Add(Expression.Condition( + Expression.Equal(jObjectVariable, Expression.Constant(null, jObjectVariable.Type)), + Expression.Constant(null, shaperExpression.Type), + Visit(_shapedQueryCompilingExpressionVisitor.InjectEntityMaterializer(shaperExpression)))); + _jObjectParameter = parentJObject; + + return Expression.Block( + shaperExpression.Type, + variables, + expressions); + } - if (extensionExpression is CollectionShaperExpression collectionShaperExpression) - { - throw new NotImplementedException(); + case CollectionShaperExpression collectionShaperExpression: + throw new NotImplementedException(); + case IncludeExpression includeExpression: + return includeExpression; } return base.VisitExtension(extensionExpression); diff --git a/src/EFCore.Relational/Query/Pipeline/RelationalProjectionBindingExpressionVisitor.cs b/src/EFCore.Relational/Query/Pipeline/RelationalProjectionBindingExpressionVisitor.cs index afaa6e83959..2bf38b8d57d 100644 --- a/src/EFCore.Relational/Query/Pipeline/RelationalProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Pipeline/RelationalProjectionBindingExpressionVisitor.cs @@ -81,60 +81,51 @@ public override Expression Visit(Expression expression) if (_clientEval) { - if (expression is ConstantExpression) + switch (expression) { - return expression; - } - - if (expression is ParameterExpression parameterExpression) - { - return Expression.Call( - _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), - QueryCompilationContext.QueryContextParameter, - Expression.Constant(parameterExpression.Name)); - } - - if (expression is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression) - { - return _selectExpression.AddCollectionProjection( - _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( + case ConstantExpression _: + return expression; + case ParameterExpression parameterExpression: + return Expression.Call( + _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterExpression.Name)); + case MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression: + return _selectExpression.AddCollectionProjection( + _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( materializeCollectionNavigationExpression.Subquery), - materializeCollectionNavigationExpression.Navigation, null); - } + materializeCollectionNavigationExpression.Navigation, null); + case MethodCallExpression methodCallExpression: + { + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.DeclaringType == typeof(Enumerable) + && methodCallExpression.Method.Name == nameof(Enumerable.ToList)) + { + var elementType = methodCallExpression.Method.GetGenericArguments()[0]; - if (expression is MethodCallExpression methodCallExpression) - { - if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.DeclaringType == typeof(Enumerable) - && methodCallExpression.Method.Name == nameof(Enumerable.ToList)) - { - var elementType = methodCallExpression.Method.GetGenericArguments()[0]; + var result = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression.Arguments[0]); - var result = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression.Arguments[0]); + return _selectExpression.AddCollectionProjection(result, null, elementType); + } - return _selectExpression.AddCollectionProjection(result, null, elementType); - } + var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); - var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); + if (subquery != null) + { + if (subquery.ResultType == ResultType.Enumerable) + { + return _selectExpression.AddCollectionProjection(subquery, null, subquery.ShaperExpression.Type); + } + } - if (subquery != null) - { - if (subquery.ResultType == ResultType.Enumerable) - { - return _selectExpression.AddCollectionProjection(subquery, null, subquery.ShaperExpression.Type); + break; } - } } var translation = _sqlTranslator.Translate(expression); - if (translation == null) - { - return base.Visit(expression); - } - else - { - return new ProjectionBindingExpression(_selectExpression, _selectExpression.AddToProjection(translation), expression.Type); - } + return translation == null + ? base.Visit(expression) + : new ProjectionBindingExpression(_selectExpression, _selectExpression.AddToProjection(translation), expression.Type); } else { diff --git a/src/EFCore.Relational/Query/Pipeline/ShaperExpressionDedupingExpressionVisitor.cs b/src/EFCore.Relational/Query/Pipeline/ShaperExpressionDedupingExpressionVisitor.cs index 147e05bcaf0..a443f9286c2 100644 --- a/src/EFCore.Relational/Query/Pipeline/ShaperExpressionDedupingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Pipeline/ShaperExpressionDedupingExpressionVisitor.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; -using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Query.NavigationExpansion; @@ -71,8 +70,7 @@ public Expression Inject(Expression expression) } private LambdaExpression ConvertToLambda(Expression result, ParameterExpression resultParameter) - { - return _indexMapParameter != null + => _indexMapParameter != null ? Expression.Lambda( result, QueryCompilationContext.QueryContextParameter, @@ -86,7 +84,6 @@ private LambdaExpression ConvertToLambda(Expression result, ParameterExpression _dataReaderParameter, resultParameter, _resultCoordinatorParameter); - } protected override Expression VisitExtension(Expression extensionExpression) { @@ -186,10 +183,8 @@ protected override Expression VisitExtension(Expression extensionExpression) } private Expression GenerateKey(ProjectionBindingExpression projectionBindingExpression) - { - return projectionBindingExpression.ProjectionMember != null + => projectionBindingExpression.ProjectionMember != null ? _selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember) : projectionBindingExpression; - } } } diff --git a/src/EFCore/Query/NavigationExpansion/IncludeHelpers.cs b/src/EFCore/Query/NavigationExpansion/IncludeHelpers.cs index 9298b93f15a..5efff149b4f 100644 --- a/src/EFCore/Query/NavigationExpansion/IncludeHelpers.cs +++ b/src/EFCore/Query/NavigationExpansion/IncludeHelpers.cs @@ -11,7 +11,7 @@ public static class IncludeHelpers { public static void CopyIncludeInformation(NavigationTreeNode originalNavigationTree, NavigationTreeNode newNavigationTree, SourceMapping newSourceMapping) { - foreach (var child in originalNavigationTree.Children.Where(n => n.Included == NavigationTreeNodeIncludeMode.ReferencePending || n.Included == NavigationTreeNodeIncludeMode.Collection)) + foreach (var child in originalNavigationTree.Children.Where(n => n.IncludeState == NavigationState.ReferencePending || n.IncludeState == NavigationState.CollectionPending)) { var copy = NavigationTreeNode.Create(newSourceMapping, child.Navigation, newNavigationTree, true); CopyIncludeInformation(child, copy, newSourceMapping); diff --git a/src/EFCore/Query/NavigationExpansion/MaterializeCollectionNavigationExpression.cs b/src/EFCore/Query/NavigationExpansion/MaterializeCollectionNavigationExpression.cs index ac5e6cf77cc..ccb21d94086 100644 --- a/src/EFCore/Query/NavigationExpansion/MaterializeCollectionNavigationExpression.cs +++ b/src/EFCore/Query/NavigationExpansion/MaterializeCollectionNavigationExpression.cs @@ -39,7 +39,5 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) ? new MaterializeCollectionNavigationExpression(subquery, Navigation) : this; } - - } } diff --git a/src/EFCore/Query/NavigationExpansion/NavigationExpander.cs b/src/EFCore/Query/NavigationExpansion/NavigationExpander.cs deleted file mode 100644 index 6b4335deed4..00000000000 --- a/src/EFCore/Query/NavigationExpansion/NavigationExpander.cs +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.Linq.Expressions; -using JetBrains.Annotations; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors; -using Microsoft.EntityFrameworkCore.Utilities; - -namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion -{ - public class NavigationExpander - { - private IModel _model; - - public NavigationExpander([NotNull] IModel model) - { - Check.NotNull(model, nameof(model)); - - _model = model; - } - - public virtual Expression ExpandNavigations(Expression expression) - { - var newExpression = new NavigationExpandingVisitor(_model).Visit(expression); - newExpression = new NavigationExpansionReducingVisitor().Visit(newExpression); - - return newExpression; - } - } -} diff --git a/src/EFCore/Query/NavigationExpansion/NavigationExpansionHelpers.cs b/src/EFCore/Query/NavigationExpansion/NavigationExpansionHelpers.cs index 1e830dd8df1..c182cbd131b 100644 --- a/src/EFCore/Query/NavigationExpansion/NavigationExpansionHelpers.cs +++ b/src/EFCore/Query/NavigationExpansion/NavigationExpansionHelpers.cs @@ -56,7 +56,7 @@ public static NavigationExpansionExpression CreateNavigationExpansionRoot( private static readonly MethodInfo _leftJoinMethodInfo = typeof(QueryableExtensions).GetTypeInfo() .GetDeclaredMethods(nameof(QueryableExtensions.LeftJoin)).Single(mi => mi.GetParameters().Length == 5); - public static (Expression source, ParameterExpression parameter) AddNavigationJoin( + public static (Expression Source, ParameterExpression Parameter) AddNavigationJoin( Expression sourceExpression, ParameterExpression parameterExpression, SourceMapping sourceMapping, @@ -66,8 +66,8 @@ public static (Expression source, ParameterExpression parameter) AddNavigationJo bool include) { var joinNeeded = include - ? navigationTree.Included == NavigationTreeNodeIncludeMode.ReferencePending - : navigationTree.ExpansionMode == NavigationTreeNodeExpansionMode.ReferencePending; + ? navigationTree.IncludeState == NavigationState.ReferencePending + : navigationTree.ExpansionState == NavigationState.ReferencePending; if (joinNeeded) { @@ -176,11 +176,11 @@ var transparentIdentifierCtorInfo foreach (var mapping in state.SourceMappings) { var nodes = include - ? mapping.NavigationTree.Flatten().Where(n => (n.Included == NavigationTreeNodeIncludeMode.ReferenceComplete - || n.ExpansionMode == NavigationTreeNodeExpansionMode.ReferenceComplete + ? mapping.NavigationTree.Flatten().Where(n => (n.IncludeState == NavigationState.ReferenceComplete + || n.ExpansionState == NavigationState.ReferenceComplete || n.Navigation.ForeignKey.IsOwnership) && n != navigationTree) - : mapping.NavigationTree.Flatten().Where(n => (n.ExpansionMode == NavigationTreeNodeExpansionMode.ReferenceComplete + : mapping.NavigationTree.Flatten().Where(n => (n.ExpansionState == NavigationState.ReferenceComplete || n.Navigation.ForeignKey.IsOwnership) && n != navigationTree); @@ -197,11 +197,11 @@ var transparentIdentifierCtorInfo if (include) { - navigationTree.Included = NavigationTreeNodeIncludeMode.ReferenceComplete; + navigationTree.IncludeState = NavigationState.ReferenceComplete; } else { - navigationTree.ExpansionMode = NavigationTreeNodeExpansionMode.ReferenceComplete; + navigationTree.ExpansionState = NavigationState.ReferenceComplete; } navigationPath.Add(navigation); @@ -211,12 +211,12 @@ var transparentIdentifierCtorInfo navigationPath.Add(navigationTree.Navigation); } - var result = (source: sourceExpression, parameter: parameterExpression); + var result = (Source: sourceExpression, Parameter: parameterExpression); foreach (var child in navigationTree.Children.Where(n => !n.Navigation.IsCollection())) { result = AddNavigationJoin( - result.source, - result.parameter, + result.Source, + result.Parameter, sourceMapping, child, state, diff --git a/src/EFCore/Query/NavigationExpansion/NavigationTreeNodeExpansionMode.cs b/src/EFCore/Query/NavigationExpansion/NavigationState.cs similarity index 50% rename from src/EFCore/Query/NavigationExpansion/NavigationTreeNodeExpansionMode.cs rename to src/EFCore/Query/NavigationExpansion/NavigationState.cs index cce53329400..2e2f07150eb 100644 --- a/src/EFCore/Query/NavigationExpansion/NavigationTreeNodeExpansionMode.cs +++ b/src/EFCore/Query/NavigationExpansion/NavigationState.cs @@ -3,26 +3,31 @@ namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion { - public enum NavigationTreeNodeExpansionMode + public enum NavigationState { /// - /// Navigation doesn't need to be expanded + /// Navigation doesn't need to be processed /// NotNeeded, /// - /// Reference navigation needs to be expanded, but hasn't been expanded yet + /// Reference navigation needs to be processed, but hasn't been processed yet /// ReferencePending, /// - /// Reference navigation has already been expanded + /// Reference navigation has already been processed /// ReferenceComplete, /// - /// Collection navigation needs to be expanded + /// Collection navigation needs to be processed, but hasn't been processed yet /// - Collection, + CollectionPending, + + /// + /// Collection navigation has already been processed + /// + CollectionComplete, }; } diff --git a/src/EFCore/Query/NavigationExpansion/NavigationTreeNode.cs b/src/EFCore/Query/NavigationExpansion/NavigationTreeNode.cs index c18e0446e9c..8696f5449bc 100644 --- a/src/EFCore/Query/NavigationExpansion/NavigationTreeNode.cs +++ b/src/EFCore/Query/NavigationExpansion/NavigationTreeNode.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; +using System.Text; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Utilities; @@ -26,17 +27,17 @@ private NavigationTreeNode( ToMapping = new List(); if (include) { - ExpansionMode = NavigationTreeNodeExpansionMode.NotNeeded; - Included = navigation.IsCollection() - ? NavigationTreeNodeIncludeMode.Collection - : NavigationTreeNodeIncludeMode.ReferencePending; + ExpansionState = NavigationState.NotNeeded; + IncludeState = navigation.IsCollection() + ? NavigationState.CollectionPending + : NavigationState.ReferencePending; } else { - ExpansionMode = navigation.IsCollection() - ? NavigationTreeNodeExpansionMode.Collection - : NavigationTreeNodeExpansionMode.ReferencePending; - Included = NavigationTreeNodeIncludeMode.NotNeeded; + ExpansionState = navigation.IsCollection() + ? NavigationState.CollectionPending + : NavigationState.ReferencePending; + IncludeState = NavigationState.NotNeeded; } // for ownership don't mark for include or expansion @@ -44,8 +45,8 @@ private NavigationTreeNode( // they will be expanded/translated later in the pipeline if (navigation.ForeignKey.IsOwnership) { - ExpansionMode = NavigationTreeNodeExpansionMode.NotNeeded; - Included = NavigationTreeNodeIncludeMode.NotNeeded; + ExpansionState = NavigationState.NotNeeded; + IncludeState = NavigationState.NotNeeded; ToMapping = parent.ToMapping.ToList(); ToMapping.Add(navigation.Name); @@ -66,16 +67,16 @@ private NavigationTreeNode( Optional = optional; FromMappings.Add(fromMapping.ToList()); ToMapping = fromMapping.ToList(); - ExpansionMode = NavigationTreeNodeExpansionMode.ReferenceComplete; - Included = NavigationTreeNodeIncludeMode.NotNeeded; + ExpansionState = NavigationState.ReferenceComplete; + IncludeState = NavigationState.NotNeeded; } public INavigation Navigation { get; private set; } public bool Optional { get; private set; } public NavigationTreeNode Parent { get; private set; } public List Children { get; private set; } = new List(); - public NavigationTreeNodeExpansionMode ExpansionMode { get; set; } - public NavigationTreeNodeIncludeMode Included { get; set; } + public NavigationState ExpansionState { get; set; } + public NavigationState IncludeState { get; set; } public List> FromMappings { get; set; } = new List>(); public List ToMapping { get; set; } @@ -106,17 +107,17 @@ public static NavigationTreeNode Create( { if (!navigation.ForeignKey.IsOwnership) { - if (include && existingChild.Included == NavigationTreeNodeIncludeMode.NotNeeded) + if (include && existingChild.IncludeState == NavigationState.NotNeeded) { - existingChild.Included = navigation.IsCollection() - ? NavigationTreeNodeIncludeMode.Collection - : NavigationTreeNodeIncludeMode.ReferencePending; + existingChild.IncludeState = navigation.IsCollection() + ? NavigationState.CollectionPending + : NavigationState.ReferencePending; } - else if (!include && existingChild.ExpansionMode == NavigationTreeNodeExpansionMode.NotNeeded) + else if (!include && existingChild.ExpansionState == NavigationState.NotNeeded) { - existingChild.ExpansionMode = navigation.IsCollection() - ? NavigationTreeNodeExpansionMode.Collection - : NavigationTreeNodeExpansionMode.ReferencePending; + existingChild.ExpansionState = navigation.IsCollection() + ? NavigationState.CollectionPending + : NavigationState.ReferencePending; } } @@ -147,5 +148,20 @@ public void MakeOptional() { Optional = true; } + + public override string ToString() + { + var builder = new StringBuilder(); + builder.Append("'"); + builder.Append(Navigation?.Name ?? ""); + builder.Append("' Expand: '"); + builder.Append(ExpansionState); + builder.Append("' Include: '"); + builder.Append(IncludeState); + builder.Append("' Children: "); + builder.Append(Children.Count); + + return builder.ToString(); + } } } diff --git a/src/EFCore/Query/NavigationExpansion/NavigationTreeNodeIncludeMode.cs b/src/EFCore/Query/NavigationExpansion/NavigationTreeNodeIncludeMode.cs deleted file mode 100644 index 973c3271681..00000000000 --- a/src/EFCore/Query/NavigationExpansion/NavigationTreeNodeIncludeMode.cs +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion -{ - public enum NavigationTreeNodeIncludeMode - { - /// - /// Navigation doesn't need to be included - /// - NotNeeded, - - /// - /// Navigation needs to be included, but hasn't been included yet - /// - ReferencePending, - - /// - /// Navigation has already been included - /// - ReferenceComplete, - - /// - /// Collection navigation needs to be included - /// - Collection, - }; -} diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/IncludeApplyingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/IncludeApplyingVisitor.cs deleted file mode 100644 index d53afdf9544..00000000000 --- a/src/EFCore/Query/NavigationExpansion/Visitors/IncludeApplyingVisitor.cs +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.Linq; -using System.Linq.Expressions; -using Microsoft.EntityFrameworkCore.Internal; - -namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors -{ - public class PendingSelectorIncludeRewriter : ExpressionVisitor - { - protected override Expression VisitMember(MemberExpression memberExpression) - { - if (memberExpression.Expression is NavigationBindingExpression navigationBindingExpression - && navigationBindingExpression.EntityType.FindProperty(memberExpression.Member) != null) - { - return memberExpression; - } - - var newExpression = Visit(memberExpression.Expression); - - return newExpression != memberExpression.Expression - ? Expression.MakeMemberAccess(newExpression, memberExpression.Member) - : memberExpression; - } - - protected override Expression VisitInvocation(InvocationExpression invocationExpression) => invocationExpression; - protected override Expression VisitLambda(Expression lambdaExpression) => lambdaExpression; - protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) => typeBinaryExpression; - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - => methodCallExpression.IsEFProperty() - ? methodCallExpression - : base.VisitMethodCall(methodCallExpression); - - protected override Expression VisitConditional(ConditionalExpression conditionalExpression) - { - var newIfTrue = Visit(conditionalExpression.IfTrue); - var newIfFalse = Visit(conditionalExpression.IfFalse); - - return newIfTrue != conditionalExpression.IfTrue || newIfFalse != conditionalExpression.IfFalse - ? conditionalExpression.Update(conditionalExpression.Test, newIfTrue, newIfFalse) - : conditionalExpression; - } - - protected override Expression VisitBinary(BinaryExpression binaryExpression) - { - return binaryExpression.NodeType == ExpressionType.Coalesce - ? base.VisitBinary(binaryExpression) - : binaryExpression; - } - - protected override Expression VisitExtension(Expression extensionExpression) - { - if (extensionExpression is NavigationBindingExpression navigationBindingExpression) - { - var result = (Expression)navigationBindingExpression; - - foreach (var child in navigationBindingExpression.NavigationTreeNode.Children.Where(n => n.Included == NavigationTreeNodeIncludeMode.ReferencePending || n.Included == NavigationTreeNodeIncludeMode.Collection)) - { - result = CreateIncludeCall(result, child, navigationBindingExpression.RootParameter, navigationBindingExpression.SourceMapping); - } - - return result; - } - - if (extensionExpression is CustomRootExpression customRootExpression) - { - return customRootExpression; - } - - if (extensionExpression is NavigationExpansionRootExpression expansionRootExpression) - { - return expansionRootExpression; - } - - if (extensionExpression is NavigationExpansionExpression navigationExpansionExpression) - { - return navigationExpansionExpression; - } - - return base.VisitExtension(extensionExpression); - } - - private IncludeExpression CreateIncludeCall(Expression caller, NavigationTreeNode node, ParameterExpression rootParameter, SourceMapping sourceMapping) - => node.Navigation.IsCollection() - ? CreateIncludeCollectionCall(caller, node, rootParameter, sourceMapping) - : CreateIncludeReferenceCall(caller, node, rootParameter, sourceMapping); - - private IncludeExpression CreateIncludeReferenceCall(Expression caller, NavigationTreeNode node, ParameterExpression rootParameter, SourceMapping sourceMapping) - { - var entityType = node.Navigation.GetTargetType(); - var included = (Expression)new NavigationBindingExpression(rootParameter, node, entityType, sourceMapping, entityType.ClrType); - - foreach (var child in node.Children.Where(n => n.Included == NavigationTreeNodeIncludeMode.ReferencePending || n.Included == NavigationTreeNodeIncludeMode.Collection)) - { - included = CreateIncludeCall(included, child, rootParameter, sourceMapping); - } - - return new IncludeExpression(caller, included, node.Navigation); - } - - private IncludeExpression CreateIncludeCollectionCall(Expression caller, NavigationTreeNode node, ParameterExpression rootParameter, SourceMapping sourceMapping) - { - var included = CollectionNavigationRewritingVisitor.CreateCollectionNavigationExpression(node, rootParameter, sourceMapping); - - return new IncludeExpression(caller, included, node.Navigation); - } - } -} diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor.cs index d330aa239d6..18f8c3d8152 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query.Internal; @@ -14,38 +15,23 @@ namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors { public partial class NavigationExpandingVisitor : ExpressionVisitor { - private readonly IModel _model; + private readonly QueryCompilationContext _queryCompilationContext; - public NavigationExpandingVisitor(IModel model) + public NavigationExpandingVisitor([NotNull] QueryCompilationContext queryCompilationContext) { - _model = model; + _queryCompilationContext = queryCompilationContext; } protected override Expression VisitExtension(Expression extensionExpression) { - if (extensionExpression is NavigationBindingExpression navigationBindingExpression) + switch (extensionExpression) { - return navigationBindingExpression; - } - - if (extensionExpression is CustomRootExpression customRootExpression) - { - return customRootExpression; - } - - if (extensionExpression is NavigationExpansionRootExpression navigationExpansionRootExpression) - { - return navigationExpansionRootExpression; - } - - if (extensionExpression is IncludeExpression includeExpression) - { - return includeExpression; - } - - if (extensionExpression is NavigationExpansionExpression navigationExpansionExpression) - { - return navigationExpansionExpression; + case NavigationBindingExpression _: + case CustomRootExpression _: + case NavigationExpansionRootExpression _: + case NavigationExpansionExpression _: + case IncludeExpression _: + return extensionExpression; } return base.VisitExtension(extensionExpression); @@ -243,7 +229,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) if (binaryExpression.Left is MemberExpression leftMember && leftMember.Type.TryGetSequenceType() is Type leftSequenceType && leftSequenceType != null - && _model.FindEntityType(leftMember.Expression.Type) is IEntityType leftParentEntityType) + && _queryCompilationContext.Model.FindEntityType(leftMember.Expression.Type) is IEntityType leftParentEntityType) { leftNavigation = leftParentEntityType.FindNavigation(leftMember.Member.Name); if (leftNavigation != null) @@ -255,7 +241,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) if (binaryExpression.Right is MemberExpression rightMember && rightMember.Type.TryGetSequenceType() is Type rightSequenceType && rightSequenceType != null - && _model.FindEntityType(rightMember.Expression.Type) is IEntityType rightParentEntityType) + && _queryCompilationContext.Model.FindEntityType(rightMember.Expression.Type) is IEntityType rightParentEntityType) { rightNavigation = rightParentEntityType.FindNavigation(rightMember.Member.Name); if (rightNavigation != null) diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs index 202d3153f04..5a5b09d0511 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs @@ -105,8 +105,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.Except): return ProcessSetOperation(methodCallExpression); - case "Include": - case "ThenInclude": + case nameof(EntityFrameworkQueryableExtensions.Include): + case nameof(EntityFrameworkQueryableExtensions.ThenInclude): return ProcessInclude(methodCallExpression); default: @@ -136,12 +136,12 @@ private Expression ProcessUnknownMethod(MethodCallExpression methodCallExpressio var source = VisitSourceExpression(methodCallExpression.Arguments[0]); var preProcessResult = PreProcessTerminatingOperation(source); var newArguments = methodCallExpression.Arguments.Skip(1).Select(Visit).ToList(); - newArguments.Insert(0, preProcessResult.source); + newArguments.Insert(0, preProcessResult.Source); - var methodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod(preProcessResult.state.CurrentParameter.Type); + var methodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod(preProcessResult.State.CurrentParameter.Type); var rewritten = Expression.Call(methodInfo, newArguments); - return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type); + return new NavigationExpansionExpression(rewritten, preProcessResult.State, methodCallExpression.Type); } } @@ -187,17 +187,18 @@ private Expression ProcessWhere(MethodCallExpression methodCallExpression) var source = VisitSourceExpression(methodCallExpression.Arguments[0]); var predicate = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - var appliedNavigationsResult = FindAndApplyNavigations(source.Operand, predicate, source.State); - var newPredicateBody = new NavigationPropertyUnbindingVisitor(appliedNavigationsResult.state.CurrentParameter).Visit(appliedNavigationsResult.lambdaBody); - var newPredicateLambda = Expression.Lambda(newPredicateBody, appliedNavigationsResult.state.CurrentParameter); - var appliedOrderingsResult = ApplyPendingOrderings(appliedNavigationsResult.source, appliedNavigationsResult.state); + var (newSource, newLambdaBody, newState) = FindAndApplyNavigations(source.Operand, predicate, source.State); + var newPredicateBody = new NavigationPropertyUnbindingVisitor(newState.CurrentParameter) + .Visit(newLambdaBody); + var newPredicateLambda = Expression.Lambda(newPredicateBody, newState.CurrentParameter); + (newSource, newState) = ApplyPendingOrderings(newSource, newState); - var newMethodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod(appliedOrderingsResult.state.CurrentParameter.Type); - var rewritten = Expression.Call(newMethodInfo, appliedOrderingsResult.source, newPredicateLambda); + var newMethodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod(newState.CurrentParameter.Type); + var rewritten = Expression.Call(newMethodInfo, newSource, newPredicateLambda); return new NavigationExpansionExpression( rewritten, - appliedOrderingsResult.state, + newState, methodCallExpression.Type); } @@ -211,31 +212,31 @@ private Expression ProcessSelect(MethodCallExpression methodCallExpression) private Expression ProcessSelectCore(Expression source, NavigationExpansionExpressionState state, LambdaExpression selector, Type resultType) { - var appliedNavigationsResult = FindAndApplyNavigations(source, selector, state); - appliedNavigationsResult.state.PendingSelector = Expression.Lambda(appliedNavigationsResult.lambdaBody, appliedNavigationsResult.state.CurrentParameter); + var (newSource, newLambdaBody, newState) = FindAndApplyNavigations(source, selector, state); + newState.PendingSelector = Expression.Lambda(newLambdaBody, newState.CurrentParameter); // we could force apply pending selector only for non-identity projections // but then we lose information about variable names, e.g. ctx.Customers.Select(x => x) - appliedNavigationsResult.state.ApplyPendingSelector = true; + newState.ApplyPendingSelector = true; - var appliedOrderingsResult = ApplyPendingOrderings(appliedNavigationsResult.source, appliedNavigationsResult.state); + (newSource, newState) = ApplyPendingOrderings(newSource, newState); var resultElementType = resultType.TryGetSequenceType(); if (resultElementType != null) { - if (resultElementType != appliedOrderingsResult.state.PendingSelector.Body.Type) + if (resultElementType != newState.PendingSelector.Body.Type) { - resultType = resultType.GetGenericTypeDefinition().MakeGenericType(appliedOrderingsResult.state.PendingSelector.Body.Type); + resultType = resultType.GetGenericTypeDefinition().MakeGenericType(newState.PendingSelector.Body.Type); } } else { - resultType = appliedOrderingsResult.state.PendingSelector.Body.Type; + resultType = newState.PendingSelector.Body.Type; } return new NavigationExpansionExpression( - appliedOrderingsResult.source, - appliedOrderingsResult.state, + newSource, + newState, resultType); } @@ -244,15 +245,15 @@ private Expression ProcessOrderBy(MethodCallExpression methodCallExpression) var source = VisitSourceExpression(methodCallExpression.Arguments[0]); var keySelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - var appliedNavigationsResult = FindAndApplyNavigations(source.Operand, keySelector, source.State); - var pendingOrdering = (method: methodCallExpression.Method.GetGenericMethodDefinition(), keySelector: Expression.Lambda(appliedNavigationsResult.lambdaBody, appliedNavigationsResult.state.CurrentParameter)); - var appliedOrderingsResult = ApplyPendingOrderings(appliedNavigationsResult.source, appliedNavigationsResult.state); + var (newSource, newLambdaBody, newState) = FindAndApplyNavigations(source.Operand, keySelector, source.State); + var pendingOrdering = (method: methodCallExpression.Method.GetGenericMethodDefinition(), keySelector: Expression.Lambda(newLambdaBody, newState.CurrentParameter)); + (newSource, newState) = ApplyPendingOrderings(newSource, newState); - appliedOrderingsResult.state.PendingOrderings.Add(pendingOrdering); + newState.PendingOrderings.Add(pendingOrdering); return new NavigationExpansionExpression( - appliedOrderingsResult.source, - appliedOrderingsResult.state, + newSource, + newState, methodCallExpression.Type); } @@ -261,14 +262,14 @@ private Expression ProcessThenByBy(MethodCallExpression methodCallExpression) var source = VisitSourceExpression(methodCallExpression.Arguments[0]); var keySelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - var appliedNavigationsResult = FindAndApplyNavigations(source.Operand, keySelector, source.State); + var (newSource, newLambdaBody, newState) = FindAndApplyNavigations(source.Operand, keySelector, source.State); - var pendingOrdering = (method: methodCallExpression.Method.GetGenericMethodDefinition(), keySelector: Expression.Lambda(appliedNavigationsResult.lambdaBody, appliedNavigationsResult.state.CurrentParameter)); - appliedNavigationsResult.state.PendingOrderings.Add(pendingOrdering); + var pendingOrdering = (method: methodCallExpression.Method.GetGenericMethodDefinition(), keySelector: Expression.Lambda(newLambdaBody, newState.CurrentParameter)); + newState.PendingOrderings.Add(pendingOrdering); return new NavigationExpansionExpression( - appliedNavigationsResult.source, - appliedNavigationsResult.state, + newSource, + newState, methodCallExpression.Type); } @@ -277,14 +278,11 @@ private Expression ProcessSelectMany(MethodCallExpression methodCallExpression) var outerSourceNee = VisitSourceExpression(methodCallExpression.Arguments[0]); var collectionSelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - var applyNavigationsResult = FindAndApplyNavigations(outerSourceNee.Operand, collectionSelector, outerSourceNee.State); - var applyOrderingsResult = ApplyPendingOrderings(applyNavigationsResult.source, applyNavigationsResult.state); + var (outerSource, outerLambdaBody, outerState) = FindAndApplyNavigations(outerSourceNee.Operand, collectionSelector, outerSourceNee.State); + (outerSource, outerState) = ApplyPendingOrderings(outerSource, outerState); - var outerSource = applyOrderingsResult.source; - var outerState = applyOrderingsResult.state; - - var collectionSelectorNavigationExpansionExpression = applyNavigationsResult.lambdaBody as NavigationExpansionExpression - ?? (applyNavigationsResult.lambdaBody as NavigationExpansionRootExpression)?.Unwrap() as NavigationExpansionExpression; + var collectionSelectorNavigationExpansionExpression = outerLambdaBody as NavigationExpansionExpression + ?? (outerLambdaBody as NavigationExpansionRootExpression)?.Unwrap() as NavigationExpansionExpression; if (collectionSelectorNavigationExpansionExpression != null) { @@ -343,11 +341,11 @@ private Expression ProcessSelectMany(MethodCallExpression methodCallExpression) var pendingSelector = resultSelectorRemap.state.PendingSelector; resultSelectorRemap.state.PendingSelector = Expression.Lambda(resultSelectorRemap.state.PendingSelector.Parameters[0], resultSelectorRemap.state.PendingSelector.Parameters[0]); var result = FindAndApplyNavigations(rewritten, pendingSelector, resultSelectorRemap.state); - result.state.PendingSelector = Expression.Lambda(result.lambdaBody, result.state.CurrentParameter); + result.State.PendingSelector = Expression.Lambda(result.LambdaBody, result.State.CurrentParameter); return new NavigationExpansionExpression( - result.source, - result.state, + result.Source, + result.State, methodCallExpression.Type); } @@ -420,8 +418,8 @@ private void CopyNavigationTree( foreach (var child in originalNavigationTree.Children) { var copy = NavigationTreeNode.Create(newSourceMapping, child.Navigation, newNavigationTree, include: false); - copy.ExpansionMode = child.ExpansionMode; - copy.Included = child.Included; + copy.ExpansionState = child.ExpansionState; + copy.IncludeState = child.IncludeState; // TODO: simply copying ToMapping might not be correct for very complex cases where the child mapping is not purely Inner/Outer but has some properties from previous anonymous projections // we should recognize and filter those out, however this is theoretical at this point - scenario is not supported and likely won't be in the foreseeable future @@ -513,40 +511,44 @@ private Expression ProcessJoin(MethodCallExpression methodCallExpression) var innerKeySelector = methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(); var resultSelector = methodCallExpression.Arguments[4].UnwrapLambdaFromQuote(); - var outerApplyNavigationsResult = FindAndApplyNavigations(outerSource.Operand, outerKeySelector, outerSource.State); - var innerApplyNavigationsResult = FindAndApplyNavigations(innerSource.Operand, innerKeySelector, innerSource.State); + var (newOuterSource, newOuterLambdaBody, newOuterState) + = FindAndApplyNavigations(outerSource.Operand, outerKeySelector, outerSource.State); + var (newInnerSource, newInnerLambdaBody, newInnerState) + = FindAndApplyNavigations(innerSource.Operand, innerKeySelector, innerSource.State); - var newOuterKeySelectorBody = new NavigationPropertyUnbindingVisitor(outerApplyNavigationsResult.state.CurrentParameter).Visit(outerApplyNavigationsResult.lambdaBody); - var newInnerKeySelectorBody = new NavigationPropertyUnbindingVisitor(innerApplyNavigationsResult.state.CurrentParameter).Visit(innerApplyNavigationsResult.lambdaBody); + var newOuterKeySelectorBody = new NavigationPropertyUnbindingVisitor(newOuterState.CurrentParameter) + .Visit(newOuterLambdaBody); + var newInnerKeySelectorBody = new NavigationPropertyUnbindingVisitor(newInnerState.CurrentParameter) + .Visit(newInnerLambdaBody); - var outerApplyOrderingsResult = ApplyPendingOrderings(outerApplyNavigationsResult.source, outerApplyNavigationsResult.state); - var innerApplyOrderingsResult = ApplyPendingOrderings(innerApplyNavigationsResult.source, innerApplyNavigationsResult.state); + (newOuterSource, newOuterState) = ApplyPendingOrderings(newOuterSource, newOuterState); + (newInnerSource, newInnerState) = ApplyPendingOrderings(newInnerSource, newInnerState); - var resultSelectorRemap = RemapTwoArgumentResultSelector(resultSelector, outerApplyOrderingsResult.state, innerApplyOrderingsResult.state); + var resultSelectorRemap = RemapTwoArgumentResultSelector(resultSelector, newOuterState, newInnerState); var newMethodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod( - outerApplyOrderingsResult.state.CurrentParameter.Type, - innerApplyOrderingsResult.state.CurrentParameter.Type, - outerApplyNavigationsResult.lambdaBody.Type, + newOuterState.CurrentParameter.Type, + newInnerState.CurrentParameter.Type, + newOuterLambdaBody.Type, resultSelectorRemap.lambda.Body.Type); var rewritten = Expression.Call( newMethodInfo, - outerApplyOrderingsResult.source, - innerApplyOrderingsResult.source, - Expression.Lambda(newOuterKeySelectorBody, outerApplyOrderingsResult.state.CurrentParameter), - Expression.Lambda(newInnerKeySelectorBody, innerApplyOrderingsResult.state.CurrentParameter), - Expression.Lambda(resultSelectorRemap.lambda.Body, outerApplyOrderingsResult.state.CurrentParameter, innerApplyOrderingsResult.state.CurrentParameter)); + newOuterSource, + newInnerSource, + Expression.Lambda(newOuterKeySelectorBody, newOuterState.CurrentParameter), + Expression.Lambda(newInnerKeySelectorBody, newInnerState.CurrentParameter), + Expression.Lambda(resultSelectorRemap.lambda.Body, newOuterState.CurrentParameter, newInnerState.CurrentParameter)); // temporarily change selector to ti => ti for purpose of finding & expanding navigations in the pending selector lambda itself var pendingSelector = resultSelectorRemap.state.PendingSelector; resultSelectorRemap.state.PendingSelector = Expression.Lambda(resultSelectorRemap.state.PendingSelector.Parameters[0], resultSelectorRemap.state.PendingSelector.Parameters[0]); var result = FindAndApplyNavigations(rewritten, pendingSelector, resultSelectorRemap.state); - result.state.PendingSelector = Expression.Lambda(result.lambdaBody, result.state.CurrentParameter); + result.State.PendingSelector = Expression.Lambda(result.LambdaBody, result.State.CurrentParameter); return new NavigationExpansionExpression( - result.source, - result.state, + result.Source, + result.State, methodCallExpression.Type); } @@ -559,41 +561,45 @@ private Expression ProcessGroupJoin(MethodCallExpression methodCallExpression) var innerKeySelector = methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(); var resultSelector = methodCallExpression.Arguments[4].UnwrapLambdaFromQuote(); - var outerApplyNavigationsResult = FindAndApplyNavigations(outerSource.Operand, outerKeySelector, outerSource.State); - var innerApplyNavigationsResult = FindAndApplyNavigations(innerSource.Operand, innerKeySelector, innerSource.State); + var (newOuterSource, newOuterLambdaBody, newOuterState) + = FindAndApplyNavigations(outerSource.Operand, outerKeySelector, outerSource.State); + var (newInnerSource, newInnerLambdaBody, newInnerState) + = FindAndApplyNavigations(innerSource.Operand, innerKeySelector, innerSource.State); - var newOuterKeySelectorBody = new NavigationPropertyUnbindingVisitor(outerApplyNavigationsResult.state.CurrentParameter).Visit(outerApplyNavigationsResult.lambdaBody); - var newInnerKeySelectorBody = new NavigationPropertyUnbindingVisitor(innerApplyNavigationsResult.state.CurrentParameter).Visit(innerApplyNavigationsResult.lambdaBody); + var newOuterKeySelectorBody = new NavigationPropertyUnbindingVisitor(newOuterState.CurrentParameter) + .Visit(newOuterLambdaBody); + var newInnerKeySelectorBody = new NavigationPropertyUnbindingVisitor(newInnerState.CurrentParameter) + .Visit(newInnerLambdaBody); - var outerApplyOrderingsResult = ApplyPendingOrderings(outerApplyNavigationsResult.source, outerApplyNavigationsResult.state); - var innerApplyOrderingsResult = ApplyPendingOrderings(innerApplyNavigationsResult.source, innerApplyNavigationsResult.state); + (newOuterSource, newOuterState) = ApplyPendingOrderings(newOuterSource, newOuterState); + (newInnerSource, newInnerState) = ApplyPendingOrderings(newInnerSource, newInnerState); var resultSelectorBody = resultSelector.Body; var remappedResultSelectorBody = ReplacingExpressionVisitor.Replace( - resultSelector.Parameters[0], outerApplyOrderingsResult.state.PendingSelector.Body, resultSelector.Body); + resultSelector.Parameters[0], newOuterState.PendingSelector.Body, resultSelector.Body); var groupingParameter = resultSelector.Parameters[1]; - var newGroupingParameter = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(innerApplyOrderingsResult.state.CurrentParameter.Type), "new_" + groupingParameter.Name); + var newGroupingParameter = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(newInnerState.CurrentParameter.Type), "new_" + groupingParameter.Name); var groupingMapping = new List { nameof(TransparentIdentifier.Inner) }; // TODO: need to create the new state and copy includes from the old one, rather than simply copying it over to grouping // this shouldn't be a problem currently since we don't support queries that compose on the grouping // but when we do, state can't be shared - otherwise any nav expansion that affects the flattened part of the GroupJoin would be incorrectly propagated to the grouping as well - var newGrouping = new NavigationExpansionExpression(newGroupingParameter, innerApplyOrderingsResult.state, groupingParameter.Type); + var newGrouping = new NavigationExpansionExpression(newGroupingParameter, newInnerState, groupingParameter.Type); remappedResultSelectorBody = new ExpressionReplacingVisitor( groupingParameter, new NavigationExpansionRootExpression(newGrouping, groupingMapping)).Visit(remappedResultSelectorBody); - foreach (var outerCustomRootMapping in outerApplyOrderingsResult.state.CustomRootMappings) + foreach (var outerCustomRootMapping in newOuterState.CustomRootMappings) { outerCustomRootMapping.Insert(0, nameof(TransparentIdentifier.Outer)); } - foreach (var outerSourceMapping in outerApplyOrderingsResult.state.SourceMappings) + foreach (var outerSourceMapping in newOuterState.SourceMappings) { - foreach (var navigationTreeNode in outerSourceMapping.NavigationTree.Flatten().Where(n => n.ExpansionMode == NavigationTreeNodeExpansionMode.ReferenceComplete)) + foreach (var navigationTreeNode in outerSourceMapping.NavigationTree.Flatten().Where(n => n.ExpansionState == NavigationState.ReferenceComplete)) { navigationTreeNode.ToMapping.Insert(0, nameof(TransparentIdentifier.Outer)); foreach (var fromMapping in navigationTreeNode.FromMappings) @@ -603,23 +609,23 @@ private Expression ProcessGroupJoin(MethodCallExpression methodCallExpression) } } - var resultType = typeof(TransparentIdentifier<,>).MakeGenericType(outerApplyOrderingsResult.state.CurrentParameter.Type, newGroupingParameter.Type); + var resultType = typeof(TransparentIdentifier<,>).MakeGenericType(newOuterState.CurrentParameter.Type, newGroupingParameter.Type); var transparentIdentifierCtorInfo = resultType.GetTypeInfo().GetConstructors().Single(); var transparentIdentifierParameter = Expression.Parameter(resultType, "groupjoin"); - var newPendingSelectorBody = new ExpressionReplacingVisitor(outerApplyOrderingsResult.state.CurrentParameter, transparentIdentifierParameter).Visit(remappedResultSelectorBody); + var newPendingSelectorBody = new ExpressionReplacingVisitor(newOuterState.CurrentParameter, transparentIdentifierParameter).Visit(remappedResultSelectorBody); newPendingSelectorBody = new ExpressionReplacingVisitor(newGroupingParameter, transparentIdentifierParameter).Visit(newPendingSelectorBody); // for GroupJoin inner sources are not available, only the outer source mappings and the custom mappings for the grouping var newState = new NavigationExpansionExpressionState( transparentIdentifierParameter, - outerApplyOrderingsResult.state.SourceMappings, + newOuterState.SourceMappings, Expression.Lambda(newPendingSelectorBody, transparentIdentifierParameter), applyPendingSelector: true, - outerApplyOrderingsResult.state.PendingOrderings, - outerApplyOrderingsResult.state.PendingIncludeChain, - outerApplyOrderingsResult.state.PendingCardinalityReducingOperator, - outerApplyOrderingsResult.state.CustomRootMappings.Concat(new[] { groupingMapping }).ToList(), + newOuterState.PendingOrderings, + newOuterState.PendingIncludeChain, + newOuterState.PendingCardinalityReducingOperator, + newOuterState.CustomRootMappings.Concat(new[] { groupingMapping }).ToList(), materializeCollectionNavigation: null); var transparentIdentifierOuterMemberInfo = resultType.GetTypeInfo().GetDeclaredField("Outer"); @@ -628,34 +634,34 @@ private Expression ProcessGroupJoin(MethodCallExpression methodCallExpression) var lambda = Expression.Lambda( Expression.New( transparentIdentifierCtorInfo, - new[] { outerApplyOrderingsResult.state.CurrentParameter, newGroupingParameter }, + new[] { newOuterState.CurrentParameter, newGroupingParameter }, new[] { transparentIdentifierOuterMemberInfo, transparentIdentifierInnerMemberInfo }), - outerApplyOrderingsResult.state.CurrentParameter, + newOuterState.CurrentParameter, newGroupingParameter); var newMethodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod( - outerApplyOrderingsResult.state.CurrentParameter.Type, - innerApplyOrderingsResult.state.CurrentParameter.Type, - outerApplyNavigationsResult.lambdaBody.Type, + newOuterState.CurrentParameter.Type, + newInnerState.CurrentParameter.Type, + newOuterLambdaBody.Type, lambda.Body.Type); var rewritten = Expression.Call( newMethodInfo, - outerApplyOrderingsResult.source, - innerApplyOrderingsResult.source, - Expression.Lambda(newOuterKeySelectorBody, outerApplyOrderingsResult.state.CurrentParameter), - Expression.Lambda(newInnerKeySelectorBody, innerApplyOrderingsResult.state.CurrentParameter), + newOuterSource, + newInnerSource, + Expression.Lambda(newOuterKeySelectorBody, newOuterState.CurrentParameter), + Expression.Lambda(newInnerKeySelectorBody, newInnerState.CurrentParameter), lambda); // temporarily change selector to ti => ti for purpose of finding & expanding navigations in the pending selector lambda itself var pendingSelector = newState.PendingSelector; newState.PendingSelector = Expression.Lambda(newState.PendingSelector.Parameters[0], newState.PendingSelector.Parameters[0]); var result = FindAndApplyNavigations(rewritten, pendingSelector, newState); - result.state.PendingSelector = Expression.Lambda(result.lambdaBody, result.state.CurrentParameter); + result.State.PendingSelector = Expression.Lambda(result.LambdaBody, result.State.CurrentParameter); return new NavigationExpansionExpression( - result.source, - result.state, + result.Source, + result.State, methodCallExpression.Type); } @@ -665,17 +671,18 @@ private Expression ProcessAll(MethodCallExpression methodCallExpression) source = RemoveIncludesFromSource(source); var predicate = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - var applyNavigationsResult = FindAndApplyNavigations(source.Operand, predicate, source.State); - var newPredicateBody = new NavigationPropertyUnbindingVisitor(applyNavigationsResult.state.CurrentParameter).Visit(applyNavigationsResult.lambdaBody); - var applyOrderingsResult = ApplyPendingOrderings(applyNavigationsResult.source, applyNavigationsResult.state); + var (newSource, newLambdaBody, newState) = FindAndApplyNavigations(source.Operand, predicate, source.State); + var newPredicateBody = new NavigationPropertyUnbindingVisitor(newState.CurrentParameter) + .Visit(newLambdaBody); + (newSource, newState) = ApplyPendingOrderings(newSource, newState); - var newMethodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod(applyOrderingsResult.state.CurrentParameter.Type); + var newMethodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod(newState.CurrentParameter.Type); var rewritten = Expression.Call( newMethodInfo, - applyOrderingsResult.source, + newSource, Expression.Lambda( newPredicateBody, - applyOrderingsResult.state.CurrentParameter)); + newState.CurrentParameter)); return rewritten; } @@ -718,7 +725,7 @@ private NavigationExpansionExpression RemoveIncludesFromSource(NavigationExpansi private void RemoveIncludes(NavigationTreeNode navigationTreeNode) { - navigationTreeNode.Included = NavigationTreeNodeIncludeMode.NotNeeded; + navigationTreeNode.IncludeState = NavigationState.NotNeeded; foreach (var child in navigationTreeNode.Children) { RemoveIncludes(child); @@ -738,25 +745,25 @@ private Expression ProcessAverageSumMinMax(MethodCallExpression methodCallExpres if (methodCallExpression.Arguments.Count == 2) { var selector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - var applyNavigationsResult = FindAndApplyNavigations(source.Operand, selector, source.State); - var newSelectorBody = new NavigationPropertyUnbindingVisitor(applyNavigationsResult.state.CurrentParameter).Visit(applyNavigationsResult.lambdaBody); - var newSelector = Expression.Lambda(newSelectorBody, applyNavigationsResult.state.CurrentParameter); + var (newSource, newLambdaBody, newState) = FindAndApplyNavigations(source.Operand, selector, source.State); + var newSelectorBody = new NavigationPropertyUnbindingVisitor(newState.CurrentParameter).Visit(newLambdaBody); + var newSelector = Expression.Lambda(newSelectorBody, newState.CurrentParameter); - var applyOrderingsResult = ApplyPendingOrderings(applyNavigationsResult.source, applyNavigationsResult.state); + (newSource, newState) = ApplyPendingOrderings(newSource, newState); var newMethod = methodCallExpression.Method.GetGenericMethodDefinition(); // Enumerable Min/Max overloads have only one type argument, Queryable have 2 but no overloads explosion if ((methodCallExpression.Method.Name == nameof(Enumerable.Min) || methodCallExpression.Method.Name == nameof(Enumerable.Max)) && newMethod.GetGenericArguments().Count() == 2) { - newMethod = newMethod.MakeGenericMethod(applyNavigationsResult.state.CurrentParameter.Type, methodCallExpression.Type); + newMethod = newMethod.MakeGenericMethod(newState.CurrentParameter.Type, methodCallExpression.Type); } else { - newMethod = newMethod.MakeGenericMethod(applyNavigationsResult.state.CurrentParameter.Type); + newMethod = newMethod.MakeGenericMethod(newState.CurrentParameter.Type); } - return Expression.Call(newMethod, applyOrderingsResult.source, newSelector); + return Expression.Call(newMethod, newSource, newSelector); } return methodCallExpression.Update(methodCallExpression.Object, new[] { source }); @@ -766,9 +773,9 @@ private Expression ProcessDistinct(MethodCallExpression methodCallExpression) { var source = VisitSourceExpression(methodCallExpression.Arguments[0]); var preProcessResult = PreProcessTerminatingOperation(source); - var rewritten = methodCallExpression.Update(methodCallExpression.Object, new[] { preProcessResult.source }); + var rewritten = methodCallExpression.Update(methodCallExpression.Object, new[] { preProcessResult.Source }); - return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type); + return new NavigationExpansionExpression(rewritten, preProcessResult.State, methodCallExpression.Type); } private Expression ProcessDefaultIfEmpty(MethodCallExpression methodCallExpression) @@ -783,9 +790,9 @@ private Expression ProcessDefaultIfEmpty(MethodCallExpression methodCallExpressi if (methodCallExpression.Method.MethodIsClosedFormOf(LinqMethodHelpers.QueryableDefaultIfEmptyWithDefaultValue)) { var preProcessResult = PreProcessTerminatingOperation(source); - var rewritten = methodCallExpression.Update(methodCallExpression.Object, new[] { preProcessResult.source }); + var rewritten = methodCallExpression.Update(methodCallExpression.Object, new[] { preProcessResult.Source }); - return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type); + return new NavigationExpansionExpression(rewritten, preProcessResult.State, methodCallExpression.Type); } else { @@ -800,7 +807,7 @@ private Expression ProcessOfType(MethodCallExpression methodCallExpression) { var source = VisitSourceExpression(methodCallExpression.Arguments[0]); var preProcessResult = PreProcessTerminatingOperation(source); - var newEntityType = _model.FindEntityType(methodCallExpression.Method.GetGenericArguments()[0]); + var newEntityType = _queryCompilationContext.Model.FindEntityType(methodCallExpression.Method.GetGenericArguments()[0]); // TODO: possible small optimization - only apply this if newEntityType is different than the old one if (newEntityType != null) @@ -809,20 +816,20 @@ private Expression ProcessOfType(MethodCallExpression methodCallExpression) var newNavigationTreeRoot = NavigationTreeNode.CreateRoot(newSourceMapping, fromMapping: new List(), optional: false); newSourceMapping.NavigationTree = newNavigationTreeRoot; - preProcessResult.state.SourceMappings = new List { newSourceMapping }; + preProcessResult.State.SourceMappings = new List { newSourceMapping }; - var newPendingSelectorParameter = Expression.Parameter(newEntityType.ClrType, preProcessResult.state.CurrentParameter.Name); + var newPendingSelectorParameter = Expression.Parameter(newEntityType.ClrType, preProcessResult.State.CurrentParameter.Name); // since we just ran preprocessing and the method is OfType, pending selector is guaranteed to be simple e => e - var newPendingSelectorBody = new NavigationPropertyBindingVisitor(newPendingSelectorParameter, preProcessResult.state.SourceMappings).Visit(newPendingSelectorParameter); + var newPendingSelectorBody = new NavigationPropertyBindingVisitor(newPendingSelectorParameter, preProcessResult.State.SourceMappings).Visit(newPendingSelectorParameter); - preProcessResult.state.CurrentParameter = newPendingSelectorParameter; - preProcessResult.state.PendingSelector = Expression.Lambda(newPendingSelectorBody, newPendingSelectorParameter); + preProcessResult.State.CurrentParameter = newPendingSelectorParameter; + preProcessResult.State.PendingSelector = Expression.Lambda(newPendingSelectorBody, newPendingSelectorParameter); } - var rewritten = methodCallExpression.Update(methodCallExpression.Object, new[] { preProcessResult.source }); + var rewritten = methodCallExpression.Update(methodCallExpression.Object, new[] { preProcessResult.Source }); - return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type); + return new NavigationExpansionExpression(rewritten, preProcessResult.State, methodCallExpression.Type); } private Expression ProcessSkipTake(MethodCallExpression methodCallExpression) @@ -830,9 +837,9 @@ private Expression ProcessSkipTake(MethodCallExpression methodCallExpression) var source = VisitSourceExpression(methodCallExpression.Arguments[0]); var preProcessResult = PreProcessTerminatingOperation(source); var newArgument = Visit(methodCallExpression.Arguments[1]); - var rewritten = methodCallExpression.Update(methodCallExpression.Object, new[] { preProcessResult.source, newArgument }); + var rewritten = methodCallExpression.Update(methodCallExpression.Object, new[] { preProcessResult.Source, newArgument }); - return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type); + return new NavigationExpansionExpression(rewritten, preProcessResult.State, methodCallExpression.Type); } private Expression ProcessSetOperation(MethodCallExpression methodCallExpression) @@ -847,12 +854,12 @@ private Expression ProcessSetOperation(MethodCallExpression methodCallExpression // Extract the includes from each side and compare to make sure they're identical. // We don't allow set operations over operands with different includes. - var pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(skipCollectionNavigations: false); - pendingIncludeFindingVisitor.Visit(preProcessResult1.state.PendingSelector.Body); + var pendingIncludeFindingVisitor = new PendingSelectorIncludeVisitor(skipCollectionNavigations: false, rewriteIncludes: false); + pendingIncludeFindingVisitor.Visit(preProcessResult1.State.PendingSelector.Body); var pendingIncludes1 = pendingIncludeFindingVisitor.PendingIncludes; - pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(skipCollectionNavigations: false); - pendingIncludeFindingVisitor.Visit(preProcessResult2.state.PendingSelector.Body); + pendingIncludeFindingVisitor = new PendingSelectorIncludeVisitor(skipCollectionNavigations: false, rewriteIncludes: false); + pendingIncludeFindingVisitor.Visit(preProcessResult2.State.PendingSelector.Body); var pendingIncludes2 = pendingIncludeFindingVisitor.PendingIncludes; if (pendingIncludes1.Count != pendingIncludes2.Count) @@ -871,45 +878,45 @@ private Expression ProcessSetOperation(MethodCallExpression methodCallExpression // If the siblings are different types, one is derived from the other the set operation returns the less derived type. // Find that. - var clrType1 = preProcessResult1.state.CurrentParameter.Type; - var clrType2 = preProcessResult2.state.CurrentParameter.Type; - var parentState = clrType1.IsAssignableFrom(clrType2) ? preProcessResult1.state : preProcessResult2.state; + var clrType1 = preProcessResult1.State.CurrentParameter.Type; + var clrType2 = preProcessResult2.State.CurrentParameter.Type; + var parentState = clrType1.IsAssignableFrom(clrType2) ? preProcessResult1.State : preProcessResult2.State; - var rewritten = methodCallExpression.Update(null, new[] { preProcessResult1.source, preProcessResult2.source }); + var rewritten = methodCallExpression.Update(null, new[] { preProcessResult1.Source, preProcessResult2.Source }); return new NavigationExpansionExpression(rewritten, parentState, methodCallExpression.Type); } - private (Expression source, NavigationExpansionExpressionState state) PreProcessTerminatingOperation(NavigationExpansionExpression source) + private (Expression Source, NavigationExpansionExpressionState State) PreProcessTerminatingOperation(NavigationExpansionExpression source) { - var applyOrderingsResult = ApplyPendingOrderings(source.Operand, source.State); + var (newSource, newState) = ApplyPendingOrderings(source.Operand, source.State); - if (applyOrderingsResult.state.ApplyPendingSelector) + if (newState.ApplyPendingSelector) { - var unbinder = new NavigationPropertyUnbindingVisitor(applyOrderingsResult.state.CurrentParameter); - var newSelectorBody = unbinder.Visit(applyOrderingsResult.state.PendingSelector.Body); + var unbinder = new NavigationPropertyUnbindingVisitor(newState.CurrentParameter); + var newSelectorBody = unbinder.Visit(newState.PendingSelector.Body); - var pssmg = new PendingSelectorSourceMappingGenerator(applyOrderingsResult.state.PendingSelector.Parameters[0], null); - pssmg.Visit(applyOrderingsResult.state.PendingSelector.Body); + var pssmg = new PendingSelectorSourceMappingGenerator(newState.PendingSelector.Parameters[0], null); + pssmg.Visit(newState.PendingSelector.Body); - var selectorMethodInfo = applyOrderingsResult.source.Type.IsQueryableType() + var selectorMethodInfo = newSource.Type.IsQueryableType() ? LinqMethodHelpers.QueryableSelectMethodInfo : LinqMethodHelpers.EnumerableSelectMethodInfo; selectorMethodInfo = selectorMethodInfo.MakeGenericMethod( - applyOrderingsResult.state.CurrentParameter.Type, + newState.CurrentParameter.Type, newSelectorBody.Type); - var result = Expression.Call( + var resultSource = Expression.Call( selectorMethodInfo, - applyOrderingsResult.source, - Expression.Lambda(newSelectorBody, applyOrderingsResult.state.CurrentParameter)); + newSource, + Expression.Lambda(newSelectorBody, newState.CurrentParameter)); var newPendingSelectorParameter = Expression.Parameter(newSelectorBody.Type); var customRootMapping = new List(); Expression newPendingSelectorBody; - if (applyOrderingsResult.state.PendingSelector.Body is NavigationBindingExpression binding) + if (newState.PendingSelector.Body is NavigationBindingExpression binding) { newPendingSelectorBody = new NavigationBindingExpression( newPendingSelectorParameter, @@ -923,10 +930,10 @@ private Expression ProcessSetOperation(MethodCallExpression methodCallExpression // if there are any includes in the result we need to re-project the previous pending selector and re-create bindings based on new mappings // so that we retain include information in case this was the last operation in the query (i.e. bindings won't be generated by processing further nodes) var customRootExpression = new CustomRootExpression(newPendingSelectorParameter, customRootMapping, newPendingSelectorParameter.Type); - if (pssmg.SourceMappings.Where(sm => sm.NavigationTree.Flatten().Where(n => n.Included == NavigationTreeNodeIncludeMode.ReferencePending || n.Included == NavigationTreeNodeIncludeMode.Collection).Any()).Any()) + if (pssmg.SourceMappings.Where(sm => sm.NavigationTree.Flatten().Where(n => n.IncludeState == NavigationState.ReferencePending || n.IncludeState == NavigationState.CollectionPending).Any()).Any()) { var selectorReprojector = new PendingSelectorReprojector(customRootExpression); - newPendingSelectorBody = selectorReprojector.Visit(applyOrderingsResult.state.PendingSelector.Body); + newPendingSelectorBody = selectorReprojector.Visit(newState.PendingSelector.Body); var binder = new NavigationPropertyBindingVisitor(newPendingSelectorParameter, pssmg.SourceMappings); newPendingSelectorBody = binder.Visit(newPendingSelectorBody); @@ -937,7 +944,7 @@ private Expression ProcessSetOperation(MethodCallExpression methodCallExpression } } - var newState = new NavigationExpansionExpressionState( + var resultState = new NavigationExpansionExpressionState( newPendingSelectorParameter, pssmg.SourceMappings, Expression.Lambda(newPendingSelectorBody, newPendingSelectorParameter), @@ -948,11 +955,11 @@ private Expression ProcessSetOperation(MethodCallExpression methodCallExpression new List> { customRootMapping }, materializeCollectionNavigation: null); - return (source: result, state: newState); + return (resultSource, resultState); } else { - return (applyOrderingsResult.source, applyOrderingsResult.state); + return (newSource, newState); } } @@ -1119,15 +1126,15 @@ private Expression ProcessInclude(MethodCallExpression methodCallExpression) var source = VisitSourceExpression(methodCallExpression.Arguments[0]); var includeLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - var applyOrderingsResult = ApplyPendingOrderings(source.Operand, source.State); + var (newSource, newState) = ApplyPendingOrderings(source.Operand, source.State); // just bind to mark all the necessary navigation for include in the future // include need to be delayed, in case they are not needed, e.g. when there is a projection on top that only projects scalars Expression remappedIncludeLambdaBody; - if (methodCallExpression.Method.Name == "Include") + if (methodCallExpression.Method.Name == nameof(EntityFrameworkQueryableExtensions.Include)) { remappedIncludeLambdaBody = ReplacingExpressionVisitor.Replace( - includeLambda.Parameters[0], applyOrderingsResult.state.PendingSelector.Body, includeLambda.Body); + includeLambda.Parameters[0], newState.PendingSelector.Body, includeLambda.Body); } else { @@ -1135,38 +1142,38 @@ private Expression ProcessInclude(MethodCallExpression methodCallExpression) // because the type mismatch (trying to compose Navigation access on the ICollection from the first include // we manually construct navigation binding that should be a root of the new include, its EntityType being the element of the previously included collection // pendingIncludeLambda is only used for marking the includes - as long as the NavigationTreeNodes are correct it should be fine - if (applyOrderingsResult.state.PendingIncludeChain.NavigationTreeNode.Navigation.IsCollection()) + if (newState.PendingIncludeChain.NavigationTreeNode.Navigation.IsCollection()) { var newIncludeLambdaRoot = new NavigationBindingExpression( - applyOrderingsResult.state.CurrentParameter, - applyOrderingsResult.state.PendingIncludeChain.NavigationTreeNode, - applyOrderingsResult.state.PendingIncludeChain.EntityType, - applyOrderingsResult.state.PendingIncludeChain.SourceMapping, + newState.CurrentParameter, + newState.PendingIncludeChain.NavigationTreeNode, + newState.PendingIncludeChain.EntityType, + newState.PendingIncludeChain.SourceMapping, includeLambda.Parameters[0].Type); remappedIncludeLambdaBody = new ExpressionReplacingVisitor(includeLambda.Parameters[0], newIncludeLambdaRoot).Visit(includeLambda.Body); } else { - var pendingIncludeChainLambda = Expression.Lambda(applyOrderingsResult.state.PendingIncludeChain, applyOrderingsResult.state.CurrentParameter); + var pendingIncludeChainLambda = Expression.Lambda(newState.PendingIncludeChain, newState.CurrentParameter); remappedIncludeLambdaBody = ReplacingExpressionVisitor.Replace( includeLambda.Parameters[0], pendingIncludeChainLambda.Body, includeLambda.Body); } } - var binder = new NavigationPropertyBindingVisitor(applyOrderingsResult.state.PendingSelector.Parameters[0], applyOrderingsResult.state.SourceMappings, bindInclude: true); + var binder = new NavigationPropertyBindingVisitor(newState.PendingSelector.Parameters[0], newState.SourceMappings, bindInclude: true); var boundIncludeLambdaBody = binder.Visit(remappedIncludeLambdaBody); if (boundIncludeLambdaBody is NavigationBindingExpression navigationBindingExpression) { - applyOrderingsResult.state.PendingIncludeChain = navigationBindingExpression; + newState.PendingIncludeChain = navigationBindingExpression; } else { throw new InvalidOperationException("Incorrect include argument: " + includeLambda); } - return new NavigationExpansionExpression(applyOrderingsResult.source, applyOrderingsResult.state, methodCallExpression.Type); + return new NavigationExpansionExpression(newSource, newState, methodCallExpression.Type); } private MethodCallExpression SimplifyPredicateMethod(MethodCallExpression methodCallExpression, bool queryable) @@ -1262,10 +1269,10 @@ private Expression ProcessCardinalityReducingOperation(MethodCallExpression meth } var source = VisitSourceExpression(methodCallExpression.Arguments[0]); - var applyOrderingsResult = ApplyPendingOrderings(source.Operand, source.State); - applyOrderingsResult.state.PendingCardinalityReducingOperator = methodCallExpression.Method; + var (newSource, newState) = ApplyPendingOrderings(source.Operand, source.State); + newState.PendingCardinalityReducingOperator = methodCallExpression.Method; - return new NavigationExpansionExpression(applyOrderingsResult.source, applyOrderingsResult.state, methodCallExpression.Type); + return new NavigationExpansionExpression(newSource, newState, methodCallExpression.Type); } private Expression ProcessFromSql(MethodCallExpression methodCallExpression) @@ -1282,7 +1289,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio && constantExpression.Value.GetType().GetGenericTypeDefinition() == typeof(EntityQueryable<>)) { var elementType = constantExpression.Value.GetType().GetSequenceType(); - var entityType = _model.FindEntityType(elementType); + var entityType = _queryCompilationContext.Model.FindEntityType(elementType); return NavigationExpansionHelpers.CreateNavigationExpansionRoot(constantExpression, entityType, materializeCollectionNavigation: null); } @@ -1290,7 +1297,8 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio return base.VisitConstant(constantExpression); } - private (Expression source, NavigationExpansionExpressionState state) ApplyPendingOrderings(Expression source, NavigationExpansionExpressionState state) + private (Expression Source, NavigationExpansionExpressionState State) ApplyPendingOrderings( + Expression source, NavigationExpansionExpressionState state) { foreach (var pendingOrdering in state.PendingOrderings) { @@ -1306,7 +1314,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio return (source, state); } - private (Expression source, Expression lambdaBody, NavigationExpansionExpressionState state) FindAndApplyNavigations( + private (Expression Source, Expression LambdaBody, NavigationExpansionExpressionState State) FindAndApplyNavigations( Expression source, LambdaExpression lambda, NavigationExpansionExpressionState state) @@ -1333,7 +1341,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio foreach (var sourceMapping in state.SourceMappings) { - if (sourceMapping.NavigationTree.Flatten().Any(n => n.ExpansionMode == NavigationTreeNodeExpansionMode.ReferencePending)) + if (sourceMapping.NavigationTree.Flatten().Any(n => n.ExpansionState == NavigationState.ReferencePending)) { foreach (var navigationTree in sourceMapping.NavigationTree.Children.Where(n => !n.Navigation.IsCollection())) { @@ -1370,7 +1378,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio state.CustomRootMappings, state.MaterializeCollectionNavigation); - return (result.source, lambdaBody: boundLambdaBody, state: newState); + return (result.source, LambdaBody: boundLambdaBody, State: newState); } private (LambdaExpression lambda, NavigationExpansionExpressionState state) RemapTwoArgumentResultSelector( @@ -1402,7 +1410,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio foreach (var outerSourceMapping in outerState.SourceMappings) { - foreach (var navigationTreeNode in outerSourceMapping.NavigationTree.Flatten().Where(n => n.ExpansionMode == NavigationTreeNodeExpansionMode.ReferenceComplete)) + foreach (var navigationTreeNode in outerSourceMapping.NavigationTree.Flatten().Where(n => n.ExpansionState == NavigationState.ReferenceComplete)) { navigationTreeNode.ToMapping.Insert(0, nameof(TransparentIdentifier.Outer)); foreach (var fromMapping in navigationTreeNode.FromMappings) @@ -1419,7 +1427,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio foreach (var innerSourceMapping in innerState.SourceMappings) { - foreach (var navigationTreeNode in innerSourceMapping.NavigationTree.Flatten().Where(n => n.ExpansionMode == NavigationTreeNodeExpansionMode.ReferenceComplete)) + foreach (var navigationTreeNode in innerSourceMapping.NavigationTree.Flatten().Where(n => n.ExpansionState == NavigationState.ReferenceComplete)) { navigationTreeNode.ToMapping.Insert(0, nameof(TransparentIdentifier.Inner)); foreach (var fromMapping in navigationTreeNode.FromMappings) diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs index 0648f4d476f..d87a4b23665 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs @@ -14,97 +14,97 @@ namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors { public class NavigationExpansionReducingVisitor : ExpressionVisitor { - protected override Expression VisitExtension(Expression extensionExpression) + public NavigationExpansionReducingVisitor() { - if (extensionExpression is NavigationBindingExpression navigationBindingExpression) - { - var result = navigationBindingExpression.RootParameter.BuildPropertyAccess(navigationBindingExpression.NavigationTreeNode.ToMapping); - - return result; - } - - if (extensionExpression is NavigationExpansionRootExpression navigationExpansionRootExpression) - { - return Visit(navigationExpansionRootExpression.Unwrap()); - } + } - if (extensionExpression is NavigationExpansionExpression navigationExpansionExpression) + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) { - var includeResult = ApplyIncludes(navigationExpansionExpression); - var state = includeResult.state; - var result = Visit(includeResult.operand); - - if (!state.ApplyPendingSelector - && state.PendingOrderings.Count == 0 - && state.PendingCardinalityReducingOperator == null - && state.MaterializeCollectionNavigation == null) - { - return result; - } - - var parameter = Expression.Parameter(result.Type.GetSequenceType()); - - foreach (var pendingOrdering in state.PendingOrderings) - { - var remappedKeySelectorBody = new ExpressionReplacingVisitor(pendingOrdering.keySelector.Parameters[0], state.CurrentParameter).Visit(pendingOrdering.keySelector.Body); - var newSelectorBody = new NavigationPropertyUnbindingVisitor(state.CurrentParameter).Visit(remappedKeySelectorBody); - var newSelector = Expression.Lambda(newSelectorBody, state.CurrentParameter); - var orderingMethod = pendingOrdering.method.MakeGenericMethod(state.CurrentParameter.Type, newSelectorBody.Type); - result = Expression.Call(orderingMethod, result, newSelector); - } - - if (state.ApplyPendingSelector) - { - var pendingSelector = (LambdaExpression)new NavigationPropertyUnbindingVisitor(state.CurrentParameter).Visit(state.PendingSelector); - var pendingSelectorBodyType = pendingSelector.Type.GetGenericArguments()[1]; - - var pendingSelectMathod = result.Type.IsGenericType && (result.Type.GetGenericTypeDefinition() == typeof(IEnumerable<>) || result.Type.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>)) - ? LinqMethodHelpers.EnumerableSelectMethodInfo.MakeGenericMethod(parameter.Type, pendingSelectorBodyType) - : LinqMethodHelpers.QueryableSelectMethodInfo.MakeGenericMethod(parameter.Type, pendingSelectorBodyType); - - result = Expression.Call(pendingSelectMathod, result, pendingSelector); - parameter = Expression.Parameter(result.Type.GetSequenceType()); - } - - if (state.PendingCardinalityReducingOperator != null) - { - result = Expression.Call(state.PendingCardinalityReducingOperator, result); - } - - if (state.MaterializeCollectionNavigation != null) - { - result = new MaterializeCollectionNavigationExpression(result, state.MaterializeCollectionNavigation); - } - - if (navigationExpansionExpression.Type != result.Type && navigationExpansionExpression.Type.IsGenericType) - { - if (navigationExpansionExpression.Type.GetGenericTypeDefinition() == typeof(IOrderedQueryable<>)) + case NavigationBindingExpression navigationBindingExpression: { - var toOrderedQueryableMethodInfo = ToOrderedQueryableMethod.MakeGenericMethod(parameter.Type); - - return Expression.Call(toOrderedQueryableMethodInfo, result); + return navigationBindingExpression.RootParameter.BuildPropertyAccess(navigationBindingExpression.NavigationTreeNode.ToMapping); } - else if (navigationExpansionExpression.Type.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>)) - { - var toOrderedEnumerableMethodInfo = ToOrderedEnumerableMethod.MakeGenericMethod(parameter.Type); - return Expression.Call(toOrderedEnumerableMethodInfo, result); + case NavigationExpansionRootExpression navigationExpansionRootExpression: + return Visit(navigationExpansionRootExpression.Unwrap()); + case NavigationExpansionExpression navigationExpansionExpression: + { + var includeResult = ApplyIncludes(navigationExpansionExpression); + var state = includeResult.State; + var result = Visit(includeResult.Operand); + + if (!state.ApplyPendingSelector + && state.PendingOrderings.Count == 0 + && state.PendingCardinalityReducingOperator == null + && state.MaterializeCollectionNavigation == null) + { + return result; + } + + var parameter = Expression.Parameter(result.Type.GetSequenceType()); + + foreach (var pendingOrdering in state.PendingOrderings) + { + var remappedKeySelectorBody = new ExpressionReplacingVisitor(pendingOrdering.keySelector.Parameters[0], state.CurrentParameter).Visit(pendingOrdering.keySelector.Body); + var newSelectorBody = new NavigationPropertyUnbindingVisitor(state.CurrentParameter).Visit(remappedKeySelectorBody); + var newSelector = Expression.Lambda(newSelectorBody, state.CurrentParameter); + var orderingMethod = pendingOrdering.method.MakeGenericMethod(state.CurrentParameter.Type, newSelectorBody.Type); + result = Expression.Call(orderingMethod, result, newSelector); + } + + if (state.ApplyPendingSelector) + { + var pendingSelector = (LambdaExpression)new NavigationPropertyUnbindingVisitor(state.CurrentParameter).Visit(state.PendingSelector); + var pendingSelectorBodyType = pendingSelector.Type.GetGenericArguments()[1]; + + var pendingSelectMathod = result.Type.IsGenericType && (result.Type.GetGenericTypeDefinition() == typeof(IEnumerable<>) || result.Type.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>)) + ? LinqMethodHelpers.EnumerableSelectMethodInfo.MakeGenericMethod(parameter.Type, pendingSelectorBodyType) + : LinqMethodHelpers.QueryableSelectMethodInfo.MakeGenericMethod(parameter.Type, pendingSelectorBodyType); + + result = Expression.Call(pendingSelectMathod, result, pendingSelector); + parameter = Expression.Parameter(result.Type.GetSequenceType()); + } + + if (state.PendingCardinalityReducingOperator != null) + { + result = Expression.Call(state.PendingCardinalityReducingOperator, result); + } + + if (state.MaterializeCollectionNavigation != null) + { + result = new MaterializeCollectionNavigationExpression(result, state.MaterializeCollectionNavigation); + } + + if (navigationExpansionExpression.Type != result.Type && navigationExpansionExpression.Type.IsGenericType) + { + if (navigationExpansionExpression.Type.GetGenericTypeDefinition() == typeof(IOrderedQueryable<>)) + { + var toOrderedQueryableMethodInfo = ToOrderedQueryableMethod.MakeGenericMethod(parameter.Type); + + return Expression.Call(toOrderedQueryableMethodInfo, result); + } + else if (navigationExpansionExpression.Type.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>)) + { + var toOrderedEnumerableMethodInfo = ToOrderedEnumerableMethod.MakeGenericMethod(parameter.Type); + + return Expression.Call(toOrderedEnumerableMethodInfo, result); + } + } + + return result; } - } - - return result; } return base.VisitExtension(extensionExpression); } - private (Expression operand, NavigationExpansionExpressionState state) ApplyIncludes(NavigationExpansionExpression navigationExpansionExpression) + private (Expression Operand, NavigationExpansionExpressionState State) ApplyIncludes( + NavigationExpansionExpression navigationExpansionExpression) { - var includeFinder = new PendingIncludeFindingVisitor(); - includeFinder.Visit(navigationExpansionExpression.State.PendingSelector.Body); - - var includeRewriter = new PendingSelectorIncludeRewriter(); - var rewrittenBody = includeRewriter.Visit(navigationExpansionExpression.State.PendingSelector.Body); + var includeVisitor = new PendingSelectorIncludeVisitor(); + var rewrittenBody = includeVisitor.Visit(navigationExpansionExpression.State.PendingSelector.Body); if (navigationExpansionExpression.State.PendingSelector.Body != rewrittenBody) { @@ -112,14 +112,14 @@ protected override Expression VisitExtension(Expression extensionExpression) navigationExpansionExpression.State.ApplyPendingSelector = true; } - if (includeFinder.PendingIncludes.Count > 0) + if (includeVisitor.PendingIncludes.Count > 0) { - var result = (source: navigationExpansionExpression.Operand, parameter: navigationExpansionExpression.State.CurrentParameter); - foreach (var pendingIncludeNode in includeFinder.PendingIncludes) + var result = (Source: navigationExpansionExpression.Operand, Parameter: navigationExpansionExpression.State.CurrentParameter); + foreach (var pendingIncludeNode in includeVisitor.PendingIncludes) { result = NavigationExpansionHelpers.AddNavigationJoin( - result.source, - result.parameter, + result.Source, + result.Parameter, pendingIncludeNode.SourceMapping, pendingIncludeNode.NavTreeNode, navigationExpansionExpression.State, @@ -128,14 +128,14 @@ protected override Expression VisitExtension(Expression extensionExpression) } var pendingSelector = navigationExpansionExpression.State.PendingSelector; - if (navigationExpansionExpression.State.CurrentParameter != result.parameter) + if (navigationExpansionExpression.State.CurrentParameter != result.Parameter) { - var pendingSelectorBody = new ExpressionReplacingVisitor(navigationExpansionExpression.State.CurrentParameter, result.parameter).Visit(navigationExpansionExpression.State.PendingSelector.Body); - pendingSelector = Expression.Lambda(pendingSelectorBody, result.parameter); + var pendingSelectorBody = new ExpressionReplacingVisitor(navigationExpansionExpression.State.CurrentParameter, result.Parameter).Visit(navigationExpansionExpression.State.PendingSelector.Body); + pendingSelector = Expression.Lambda(pendingSelectorBody, result.Parameter); } var newState = new NavigationExpansionExpressionState( - result.parameter, + result.Parameter, navigationExpansionExpression.State.SourceMappings, pendingSelector, applyPendingSelector: true, @@ -145,10 +145,10 @@ protected override Expression VisitExtension(Expression extensionExpression) navigationExpansionExpression.State.CustomRootMappings, navigationExpansionExpression.State.MaterializeCollectionNavigation); - return (operand: result.source, state: newState); + return (Operand: result.Source, newState); } - return (operand: navigationExpansionExpression.Operand, state: navigationExpansionExpression.State); + return (navigationExpansionExpression.Operand, navigationExpansionExpression.State); } public static MethodInfo ToOrderedQueryableMethod = typeof(NavigationExpansionReducingVisitor).GetMethod(nameof(ToOrderedQueryable)); diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationPropertyBindingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationPropertyBindingVisitor.cs index fb6f8c76519..cde470327b2 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationPropertyBindingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationPropertyBindingVisitor.cs @@ -27,22 +27,15 @@ public NavigationPropertyBindingVisitor( protected override Expression VisitExtension(Expression extensionExpression) { - if (extensionExpression is NavigationBindingExpression navigationBindingExpression) + switch (extensionExpression) { - return navigationBindingExpression; + case NavigationBindingExpression _: + case CustomRootExpression _: + case NavigationExpansionRootExpression _: + return extensionExpression; + default: + return base.VisitExtension(extensionExpression); } - - if (extensionExpression is CustomRootExpression customRootExpression) - { - return customRootExpression; - } - - if (extensionExpression is NavigationExpansionRootExpression navigationExpansionRootExpression) - { - return navigationExpansionRootExpression; - } - - return base.VisitExtension(extensionExpression); } protected override Expression VisitLambda(Expression lambdaExpression) @@ -152,7 +145,8 @@ private Expression TryBindProperty(Expression originalExpression, Expression new var navigation = navigationBindingExpression.EntityType.FindNavigation(navigationMemberName); if (navigation != null) { - var navigationTreeNode = NavigationTreeNode.Create(navigationBindingExpression.SourceMapping, navigation, navigationBindingExpression.NavigationTreeNode, _bindInclude); + var navigationTreeNode = NavigationTreeNode.Create( + navigationBindingExpression.SourceMapping, navigation, navigationBindingExpression.NavigationTreeNode, _bindInclude); return new NavigationBindingExpression( navigationBindingExpression.RootParameter, @@ -167,7 +161,8 @@ private Expression TryBindProperty(Expression originalExpression, Expression new { foreach (var sourceMapping in _sourceMappings) { - var candidates = sourceMapping.NavigationTree.Flatten().SelectMany(n => n.FromMappings, (n, m) => (navigationTreeNode: n, path: m)).ToList(); + var candidates = sourceMapping.NavigationTree.Flatten() + .SelectMany(n => n.FromMappings, (n, m) => (navigationTreeNode: n, path: m)).ToList(); var match = TryFindMatchingNavigationTreeNode(originalExpression, candidates); if (match.navigationTreeNode != null) { diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs deleted file mode 100644 index 963432490d5..00000000000 --- a/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.Collections.Generic; -using System.Linq.Expressions; -using Microsoft.EntityFrameworkCore.Internal; - -namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors -{ - - public class PendingIncludeFindingVisitor : ExpressionVisitor - { - private bool _skipCollectionNavigations; - - public PendingIncludeFindingVisitor(bool skipCollectionNavigations = true) - { - _skipCollectionNavigations = skipCollectionNavigations; - } - - public virtual List<(NavigationTreeNode NavTreeNode, SourceMapping SourceMapping)> PendingIncludes { get; } = - new List<(NavigationTreeNode, SourceMapping)>(); - - protected override Expression VisitMember(MemberExpression memberExpression) - { - if (memberExpression.Expression is NavigationBindingExpression navigationBindingExpression - && navigationBindingExpression.EntityType.FindProperty(memberExpression.Member) != null) - { - return memberExpression; - } - - Visit(memberExpression.Expression); - - return memberExpression; - } - - protected override Expression VisitInvocation(InvocationExpression invocationExpression) => invocationExpression; - protected override Expression VisitLambda(Expression lambdaExpression) => lambdaExpression; - protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) => typeBinaryExpression; - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - => methodCallExpression.IsEFProperty() - ? methodCallExpression - : base.VisitMethodCall(methodCallExpression); - - protected override Expression VisitConditional(ConditionalExpression conditionalExpression) - { - Visit(conditionalExpression.IfTrue); - Visit(conditionalExpression.IfFalse); - - return conditionalExpression; - } - - protected override Expression VisitBinary(BinaryExpression binaryExpression) - { - return binaryExpression.NodeType == ExpressionType.Coalesce - ? base.VisitBinary(binaryExpression) - : binaryExpression; - } - - protected override Expression VisitExtension(Expression extensionExpression) - { - // TODO: what about nested scenarios i.e. NavigationExpansionExpression inside pending selector? - add tests - if (extensionExpression is NavigationBindingExpression navigationBindingExpression) - { - // find all nodes and children UNTIL you find a collection in that subtree - // collection navigations will be converted to their own NavigationExpansionExpressions and their child includes will be applied when those NavigationExpansionExpressions are processed - FindPendingReferenceIncludes(navigationBindingExpression.NavigationTreeNode, navigationBindingExpression.SourceMapping); - - return navigationBindingExpression; - } - - if (extensionExpression is CustomRootExpression customRootExpression) - { - return customRootExpression; - } - - if (extensionExpression is NavigationExpansionRootExpression expansionRootExpression) - { - return expansionRootExpression; - } - - if (extensionExpression is NavigationExpansionExpression navigationExpansionExpression) - { - return navigationExpansionExpression; - } - - return base.VisitExtension(extensionExpression); - } - - private void FindPendingReferenceIncludes(NavigationTreeNode node, SourceMapping sourceMapping) - { - if (_skipCollectionNavigations && node.Navigation != null && node.Navigation.IsCollection()) - { - return; - } - - if (node.ExpansionMode != NavigationTreeNodeExpansionMode.ReferenceComplete - && (node.Included == NavigationTreeNodeIncludeMode.ReferencePending - || !_skipCollectionNavigations && node.Included == NavigationTreeNodeIncludeMode.Collection)) - { - PendingIncludes.Add((node, sourceMapping)); - } - - foreach (var child in node.Children) - { - FindPendingReferenceIncludes(child, sourceMapping); - } - } - } -} diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/PendingSelectorIncludeVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/PendingSelectorIncludeVisitor.cs new file mode 100644 index 00000000000..bccddc17c77 --- /dev/null +++ b/src/EFCore/Query/NavigationExpansion/Visitors/PendingSelectorIncludeVisitor.cs @@ -0,0 +1,160 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Internal; + +namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors +{ + public class PendingSelectorIncludeVisitor : ExpressionVisitor + { + private bool _skipCollectionNavigations; + private bool _rewriteIncludes; + + public PendingSelectorIncludeVisitor(bool skipCollectionNavigations = true, bool rewriteIncludes = true) + { + _skipCollectionNavigations = skipCollectionNavigations; + _rewriteIncludes = rewriteIncludes; + } + + public virtual List<(NavigationTreeNode NavTreeNode, SourceMapping SourceMapping)> PendingIncludes { get; } = + new List<(NavigationTreeNode, SourceMapping)>(); + + protected override Expression VisitMember(MemberExpression memberExpression) + { + if (memberExpression.Expression is NavigationBindingExpression navigationBindingExpression + && navigationBindingExpression.EntityType.FindProperty(memberExpression.Member) != null) + { + return memberExpression; + } + + var newExpression = Visit(memberExpression.Expression); + + return newExpression != memberExpression.Expression + ? Expression.MakeMemberAccess(newExpression, memberExpression.Member) + : memberExpression; + } + + protected override Expression VisitInvocation(InvocationExpression invocationExpression) => invocationExpression; + protected override Expression VisitLambda(Expression lambdaExpression) => lambdaExpression; + protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) => typeBinaryExpression; + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + => methodCallExpression.IsEFProperty() + ? methodCallExpression + : base.VisitMethodCall(methodCallExpression); + + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) + { + var newIfTrue = Visit(conditionalExpression.IfTrue); + var newIfFalse = Visit(conditionalExpression.IfFalse); + + return newIfTrue != conditionalExpression.IfTrue || newIfFalse != conditionalExpression.IfFalse + ? conditionalExpression.Update(conditionalExpression.Test, newIfTrue, newIfFalse) + : conditionalExpression; + } + + protected override Expression VisitBinary(BinaryExpression binaryExpression) + => binaryExpression.NodeType == ExpressionType.Coalesce + ? base.VisitBinary(binaryExpression) + : binaryExpression; + + protected override Expression VisitExtension(Expression extensionExpression) + { + // TODO: what about nested scenarios i.e. NavigationExpansionExpression inside pending selector? - add tests + switch (extensionExpression) + { + case NavigationBindingExpression navigationBindingExpression: + ProcessEagerLoadedNavigations( + navigationBindingExpression.NavigationTreeNode, + navigationBindingExpression.EntityType, + navigationBindingExpression.SourceMapping); + + // find all nodes and children UNTIL you find a collection in that subtree + // collection navigations will be converted to their own NavigationExpansionExpressions and their child includes will be applied when those NavigationExpansionExpressions are processed + return ProcessIncludes( + navigationBindingExpression, + navigationBindingExpression.NavigationTreeNode, + navigationBindingExpression.RootParameter, + navigationBindingExpression.SourceMapping); + case CustomRootExpression _: + case NavigationExpansionRootExpression _: + case NavigationExpansionExpression _: + return extensionExpression; + } + + return base.VisitExtension(extensionExpression); + } + + private void ProcessEagerLoadedNavigations(NavigationTreeNode node, IEntityType entityType, SourceMapping sourceMapping) + { + foreach (var child in node.Children) + { + ProcessEagerLoadedNavigations(child, child.Navigation.GetTargetType(), sourceMapping); + } + + var outboundNavigations + = entityType.GetNavigations() + .Concat(entityType.GetDerivedNavigations()) + .Where(n => n.IsEagerLoaded()); + + foreach (var navigation in outboundNavigations) + { + var newNode = NavigationTreeNode.Create(sourceMapping, navigation, node, include: true); + ProcessEagerLoadedNavigations(newNode, navigation.GetTargetType(), sourceMapping); + } + } + + private Expression ProcessIncludes( + Expression caller, NavigationTreeNode node, ParameterExpression rootParameter, SourceMapping sourceMapping) + { + var included = caller; + var skipChildren = false; + if (node.Navigation != null) + { + if (node.Navigation.IsCollection()) + { + skipChildren = _skipCollectionNavigations; + } + + if (_rewriteIncludes) + { + if (node.Navigation.IsCollection()) + { + included = CollectionNavigationRewritingVisitor.CreateCollectionNavigationExpression(node, rootParameter, sourceMapping); + } + else + { + var entityType = node.Navigation.GetTargetType(); + included = new NavigationBindingExpression(rootParameter, node, entityType, sourceMapping, entityType.ClrType); + } + } + } + + if (!skipChildren) + { + foreach (var child in node.Children.Where(n => n.IncludeState == NavigationState.ReferencePending + || n.IncludeState == NavigationState.CollectionPending)) + { + if (node.ExpansionState != NavigationState.ReferenceComplete + && node.ExpansionState != NavigationState.CollectionComplete + && (node.IncludeState == NavigationState.ReferencePending + || (!_skipCollectionNavigations && node.IncludeState == NavigationState.CollectionPending))) + { + PendingIncludes.Add((node, sourceMapping)); + } + + included = ProcessIncludes(included, child, rootParameter, sourceMapping); + } + } + + return _rewriteIncludes && node.Navigation != null + ? new IncludeExpression(caller, included, node.Navigation) + : included; + } + } +} diff --git a/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs index 5892fcbd30f..b54a0c9bcd4 100644 --- a/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs @@ -17,10 +17,10 @@ namespace Microsoft.EntityFrameworkCore.Query.Pipeline { /// - /// Rewrites comparisons of entities (as opposed to comparisons of their properties) into comparison of their keys. + /// Rewrites comparisons of entities (as opposed to comparisons of their properties) into comparison of their keys. /// /// - /// For example, an expression such as cs.Where(c => c == something) would be rewritten to cs.Where(c => c.Id == something.Id). + /// For example, an expression such as cs.Where(c => c == something) would be rewritten to cs.Where(c => c.Id == something.Id). /// public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor { diff --git a/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs index ee245275c6a..e3e24ca12c1 100644 --- a/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs @@ -4,7 +4,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; -using Microsoft.EntityFrameworkCore.Query.NavigationExpansion; +using Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors; namespace Microsoft.EntityFrameworkCore.Query.Pipeline { @@ -24,7 +24,8 @@ public Expression Visit(Expression query) query = new GroupJoinFlatteningExpressionVisitor().Visit(query); query = new NullCheckRemovingExpressionVisitor().Visit(query); query = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext).Rewrite(query); - query = new NavigationExpander(_queryCompilationContext.Model).ExpandNavigations(query); + query = new NavigationExpandingVisitor(_queryCompilationContext).Visit(query); + query = new NavigationExpansionReducingVisitor().Visit(query); query = new EnumerableToQueryableReMappingExpressionVisitor().Visit(query); query = new NullCheckRemovingExpressionVisitor().Visit(query); query = new FunctionPreprocessingVisitor().Visit(query); diff --git a/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitorFactory.cs b/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitorFactory.cs index 01a3c3465ee..d42aebaf1e9 100644 --- a/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitorFactory.cs +++ b/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitorFactory.cs @@ -6,8 +6,6 @@ namespace Microsoft.EntityFrameworkCore.Query.Pipeline public class QueryOptimizerFactory : IQueryOptimizerFactory { public QueryOptimizer Create(QueryCompilationContext queryCompilationContext) - { - return new QueryOptimizer(queryCompilationContext); - } + => new QueryOptimizer(queryCompilationContext); } } diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/OwnedQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/OwnedQueryCosmosTest.cs index 1408cfb1c4e..1a05c3ac792 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/OwnedQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/OwnedQueryCosmosTest.cs @@ -43,7 +43,7 @@ public override void Navigation_rewrite_on_owned_collection_with_composition_com base.Navigation_rewrite_on_owned_collection_with_composition_complex(); } - [ConditionalFact(Skip = "Owned collection #12086")] + [ConditionalFact(Skip = "Owned projection #12086")] public override void Navigation_rewrite_on_owned_reference_projecting_entity() { base.Navigation_rewrite_on_owned_reference_projecting_entity(); @@ -58,18 +58,33 @@ FROM root c public override void Query_for_base_type_loads_all_owned_navs() { base.Query_for_base_type_loads_all_owned_navs(); + + AssertSql( + @"SELECT c +FROM root c +WHERE (((c[""Discriminator""] = ""LeafB"") OR ((c[""Discriminator""] = ""LeafA"") OR ((c[""Discriminator""] = ""Branch"") OR (c[""Discriminator""] = ""OwnedPerson"")))) AND (c[""PersonAddress""][""Country""][""Name""] = ""USA""))"); } [ConditionalFact(Skip = "Owned collection #12086")] public override void Query_for_branch_type_loads_all_owned_navs() { base.Query_for_branch_type_loads_all_owned_navs(); + + AssertSql( + @"SELECT c +FROM root c +WHERE (((c[""Discriminator""] = ""LeafB"") OR ((c[""Discriminator""] = ""LeafA"") OR ((c[""Discriminator""] = ""Branch"") OR (c[""Discriminator""] = ""OwnedPerson"")))) AND (c[""PersonAddress""][""Country""][""Name""] = ""USA""))"); } [ConditionalFact(Skip = "Owned collection #12086")] public override void Query_for_leaf_type_loads_all_owned_navs() { base.Query_for_leaf_type_loads_all_owned_navs(); + + AssertSql( + @"SELECT c +FROM root c +WHERE (((c[""Discriminator""] = ""LeafB"") OR ((c[""Discriminator""] = ""LeafA"") OR ((c[""Discriminator""] = ""Branch"") OR (c[""Discriminator""] = ""OwnedPerson"")))) AND (c[""PersonAddress""][""Country""][""Name""] = ""USA""))"); } [ConditionalFact(Skip = "LeftJoin #12086")] diff --git a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CosmosTestStore.cs b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CosmosTestStore.cs index 18eeabaec5d..3777dd6ff3b 100644 --- a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CosmosTestStore.cs +++ b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CosmosTestStore.cs @@ -34,7 +34,7 @@ public static CosmosTestStore GetOrCreate(string name, string dataFilePath) private CosmosTestStore( string name, bool shared = true, string dataFilePath = null, Action extensionConfiguration = null) - : base(name + _runId.ToString(), shared) + : base(CreateName(name), shared) { ConnectionUri = TestEnvironment.DefaultConnection; AuthToken = TestEnvironment.AuthToken; @@ -50,6 +50,8 @@ private CosmosTestStore( } } + private static string CreateName(string name) => name == "Northwind" ? name : (name + _runId.ToString()); + public string ConnectionUri { get; } public string AuthToken { get; } public Action ConfigureCosmos => _configureCosmos ?? (_ => { }); diff --git a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestEnvironment.cs b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestEnvironment.cs index a732712f3ee..a63bf469c18 100644 --- a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestEnvironment.cs +++ b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/TestEnvironment.cs @@ -16,8 +16,12 @@ public static class TestEnvironment .Build() .GetSection("Test:Cosmos"); - public static string DefaultConnection { get; } = Config["DefaultConnection"] ?? "https://localhost:8081"; + public static string DefaultConnection { get; } = string.IsNullOrEmpty(Config["DefaultConnection"]) + ? "https://localhost:8081" + : Config["DefaultConnection"]; - public static string AuthToken { get; } = Config["AuthToken"] ?? ""; + public static string AuthToken { get; } = string.IsNullOrEmpty(Config["AuthToken"]) + ? "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" + : Config["AuthToken"]; } } diff --git a/test/EFCore.Specification.Tests/OptimisticConcurrencyTestBase.cs b/test/EFCore.Specification.Tests/OptimisticConcurrencyTestBase.cs index 053c761b0ea..5bf8d29bc82 100644 --- a/test/EFCore.Specification.Tests/OptimisticConcurrencyTestBase.cs +++ b/test/EFCore.Specification.Tests/OptimisticConcurrencyTestBase.cs @@ -9,7 +9,6 @@ using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.TestModels.ConcurrencyModel; -using Microsoft.EntityFrameworkCore.TestUtilities.Xunit; using Microsoft.Extensions.Logging; using Xunit;