Skip to content

Commit

Permalink
Implement filter transformation to restrict the set of data to be agg…
Browse files Browse the repository at this point in the history
…regated
  • Loading branch information
John Gathogo committed Nov 25, 2020
1 parent f75c071 commit 868a0b2
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 24 deletions.
49 changes: 48 additions & 1 deletion src/Microsoft.OData.Client/ALinq/ApplyQueryOptionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@ namespace Microsoft.OData.Client
{
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Collections.ObjectModel;
using System.Linq.Expressions;
using Microsoft.OData.UriParser.Aggregation;

/// <summary>
/// A resource specific expression representing an apply query option.
/// </summary>
internal class ApplyQueryOptionExpression : QueryOptionExpression
{
/// <summary>
/// The filter expressions that make the filter predicate
/// </summary>
private readonly List<Expression> filterExpressions;

/// <summary>
/// Creates an <see cref="ApplyQueryOptionExpression"/> expression.
Expand All @@ -25,6 +30,7 @@ internal ApplyQueryOptionExpression(Type type)
: base(type)
{
this.Aggregations = new List<Aggregation>();
this.filterExpressions = new List<Expression>();
}

/// <summary>
Expand All @@ -39,6 +45,47 @@ public override ExpressionType NodeType
/// </summary>
internal List<Aggregation> Aggregations { get; private set; }

/// <summary>
/// Adds the conjuncts to the filter expressions
/// </summary>
internal void AddPredicateConjuncts(IEnumerable<Expression> predicates)
{
this.filterExpressions.AddRange(predicates);
}

internal ReadOnlyCollection<Expression> PredicateConjuncts
{
get
{
return new ReadOnlyCollection<Expression>(this.filterExpressions);
}
}

/// <summary>
/// Gets filter transformation predicate.
/// </summary>
/// <returns>A predicate with all conjuncts AND'd</returns>
internal Expression GetPredicate()
{
Expression combinedPredicate = null;
bool isFirst = true;

foreach (Expression expr in this.filterExpressions)
{
if (isFirst)
{
combinedPredicate = expr;
isFirst = false;
}
else
{
combinedPredicate = Expression.And(combinedPredicate, expr);
}
}

return combinedPredicate;
}

/// <summary>
/// Structure for an aggregation. Holds lambda expression plus enum indicating aggregation method
/// </summary>
Expand Down
50 changes: 50 additions & 0 deletions src/Microsoft.OData.Client/ALinq/QueryableResourceExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,56 @@ internal void AddFilter(IEnumerable<Expression> predicateConjuncts)
this.keyPredicateConjuncts.Clear();
}

internal void AddApply(Expression aggregateExpr, OData.UriParser.Aggregation.AggregationMethod aggregationMethod)
{
if (this.OrderBy != null)
{
throw new NotSupportedException(Strings.ALinq_QueryOptionOutOfOrder("apply", "orderby"));
}
else if (this.Skip != null)
{
// $skip and $top may be used together with rollup.
// However, support for rollup is currently not implemented in OData WebApi
// If $skip and/or $top appears before $apply, its currently ignored.
// Makes sense to throw an exception to avoid giving a false impression.
throw new NotSupportedException(Strings.ALinq_QueryOptionOutOfOrder("apply", "skip"));
}
else if (this.Take != null)
{
throw new NotSupportedException(Strings.ALinq_QueryOptionOutOfOrder("apply", "top"));
}

if (this.Apply == null)
{
AddSequenceQueryOption(new ApplyQueryOptionExpression(this.Type));
}

if (this.Filter != null && this.Filter.PredicateConjuncts.Count > 0)
{
// The $apply query option is evaluated first, then other query options ($filter, $orderby, $select) are evaluated,
// if applicable, on the result of $apply in their normal order.
// http://docs.oasis-open.org/odata/odata-data-aggregation-ext/v4.0/cs02/odata-data-aggregation-ext-v4.0-cs02.html#_Toc435016590

// If a Where appears before an aggregation method (e.g. Average, Sum, etc) or GroupBy,
// the conjuncts of the filter expression will be used to restrict the set of data to be aggregated.
// They will not appear on the $filter query option. Instead, we use them to construct a filter transformation.
// E.g. /Sales?$apply=filter(Amount gt 1)/aggregate(Amount with average as AverageAmount)

// If a Where appears after an aggregation method or GroupBy, the conjuncts should appear
// on a $filter query option after the $apply.
// E.g. /Sales?$apply=groupby((Product/Color),aggregate(Amount with average as AverageAmount))&$filter=Product/Color eq 'Brown'

// To separate the two sets of possible conjuncts, we store those that appear in the Where
// before aggregate method in the Apply query option object.
// We also don't concern ourselves with whether there's a key predicate or not since a ByKey query is not applicable
this.Apply.AddPredicateConjuncts(this.Filter.PredicateConjuncts);
this.keyPredicateConjuncts.Clear();
this.RemoveFilterExpression();
}

this.Apply.Aggregations.Add(new ApplyQueryOptionExpression.Aggregation(aggregateExpr, aggregationMethod));
}

/// <summary>
/// Add query option to resource expression
/// </summary>
Expand Down
21 changes: 2 additions & 19 deletions src/Microsoft.OData.Client/ALinq/ResourceBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -825,10 +825,7 @@ private static Expression AnalyzeAggregation(MethodCallExpression methodCallExpr
return methodCallExpr;
}

EnsureApplyInitialized(input);
Debug.Assert(input.Apply != null, "input.Apply != null");

input.Apply.Aggregations.Add(new ApplyQueryOptionExpression.Aggregation(selector, aggregationMethod));
input.AddApply(selector, aggregationMethod);

return input;
}
Expand Down Expand Up @@ -1619,20 +1616,6 @@ private static Expression StripCastMethodCalls(Expression expression)
return expression;
}

/// <summary>
/// Ensure apply query option for the resource set is initialized
/// </summary>
/// <param name="input">The resource expression</param>
private static void EnsureApplyInitialized(QueryableResourceExpression input)
{
Debug.Assert(input != null, "input != null");

if (input.Apply == null)
{
AddSequenceQueryOption(input, new ApplyQueryOptionExpression(input.Type));
}
}

/// <summary>Use this class to perform pattern-matching over expression trees.</summary>
/// <remarks>
/// Following these guidelines simplifies usage:
Expand Down Expand Up @@ -2991,7 +2974,7 @@ internal static void ValidateExpandPath(Expression input, DataServiceContext con

/// <summary>
/// Checks whether the specified <paramref name="expr"/> is a valid aggregate expression.
/// A valid aggregate expression must be translatable into a path to an aggregatable property.
/// An aggregate expression must evaluate to a single-valued property path to an aggregatable property.
/// </summary>
/// <param name="expr">The aggregate expression</param>
internal static void ValidateAggregateExpression(Expression expr)
Expand Down
42 changes: 38 additions & 4 deletions src/Microsoft.OData.Client/ALinq/UriWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,41 @@ internal void VisitQueryOptionExpression(ApplyQueryOptionExpression applyQueryOp
return;
}

// E.g. filter(Amount gt 1)
string filterTransformation = ConstructFilterTransformation(applyQueryOptionExpr);
// E.g. aggregate(Prop with sum as SumProp, Prop with average as AverageProp)
string aggregateTransformation = ConstructAggregateTransformation(applyQueryOptionExpr.Aggregations);

string applyExpression = string.IsNullOrWhiteSpace(filterTransformation) ? string.Empty : filterTransformation + "/";
applyExpression += aggregateTransformation;

this.AddAsCachedQueryOption(UriHelper.DOLLARSIGN + UriHelper.OPTIONAPPLY, applyExpression);
}

/// <summary>
/// Constructs a $apply filter transformation.
/// E.g. $apply=filter(Amount gt 1)
/// </summary>
/// <param name="applyQueryOptionExpr">ApplyQueryOptionExpression expression</param>
/// <returns>A filter transformation</returns>
private string ConstructFilterTransformation(ApplyQueryOptionExpression applyQueryOptionExpr)
{
if (applyQueryOptionExpr.PredicateConjuncts.Count == 0)
{
return string.Empty;
}

return "filter(" + this.ExpressionToString(applyQueryOptionExpr.GetPredicate(), /*inPath*/ false) + ")"; ;
}

/// <summary>
/// Constructs a $apply aggregate transformation.
/// E.g. $apply=aggregate(Prop with sum as SumProp, Prop with average as AverageProp)
/// </summary>
/// <param name="aggregations">List of aggregations.</param>
/// <returns>The aggregate tranformation.</returns>
private string ConstructAggregateTransformation(IList<ApplyQueryOptionExpression.Aggregation> aggregations)
{
StringBuilder aggregateBuilder = new StringBuilder();

aggregateBuilder.Append(UriHelper.AGGREGATE);
Expand All @@ -604,7 +639,7 @@ internal void VisitQueryOptionExpression(ApplyQueryOptionExpression applyQueryOp

while (true)
{
ApplyQueryOptionExpression.Aggregation aggregation = applyQueryOptionExpr.Aggregations[i];
ApplyQueryOptionExpression.Aggregation aggregation = aggregations[i];
AggregationMethod aggregationMethod = aggregation.AggregationMethod;
string aggregationAlias = aggregation.AggregationAlias;

Expand Down Expand Up @@ -644,7 +679,7 @@ internal void VisitQueryOptionExpression(ApplyQueryOptionExpression applyQueryOp

aggregateBuilder.Append(aggregationAlias);

if (++i == applyQueryOptionExpr.Aggregations.Count)
if (++i == aggregations.Count)
{
break;
}
Expand All @@ -654,8 +689,7 @@ internal void VisitQueryOptionExpression(ApplyQueryOptionExpression applyQueryOp

aggregateBuilder.Append(UriHelper.RIGHTPAREN);

// e.g. $apply=aggregate(Prop with sum as SumProp, Prop with average as AverageProp)
this.AddAsCachedQueryOption(UriHelper.DOLLARSIGN + UriHelper.OPTIONAPPLY, aggregateBuilder.ToString());
return aggregateBuilder.ToString();
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Xunit;
Expand Down Expand Up @@ -340,6 +341,97 @@ public void CountDistinct_NotSupportedException_ThrownForCollectionProperty()
Assert.Throws<NotSupportedException>(() => queryable.CountDistinct(d => d.Sales));
}

[Theory]
[InlineData("Average")]
[InlineData("Sum")]
[InlineData("Min")]
[InlineData("Max")]
public void Aggregation_OnFilteredInputSet(string aggregationMethodName)
{
// Arrange
var queryable = this.dsContext.CreateQuery<Sale>(salesEntitySetName);

PropertyInfo propertyInfo = queryable.ElementType.GetProperty("Amount");
var parameter1Expr = Expression.Parameter(queryable.ElementType, "d1");
// d1.Amount
var memberExpr = Expression.MakeMemberAccess(parameter1Expr, propertyInfo);
// d1.Amount > 1
var greaterThanExpr = Expression.GreaterThan(memberExpr, Expression.Constant((decimal)1));

// Get Where method
var whereMethod = GetWhereMethod();
// .Where(d1 => d1.Amount > 1)
var whereExpr = Expression.Call(
null,
whereMethod.MakeGenericMethod(new Type[] { queryable.ElementType }),
new[] {
queryable.Expression,
Expression.Lambda<Func<Sale, bool>>(greaterThanExpr, parameter1Expr)
});

var parameter2Expr = Expression.Parameter(queryable.ElementType, "d2");
// d2 => d2.Amount
var selectorExpr = Expression.Lambda(
Expression.MakeMemberAccess(parameter2Expr, propertyInfo),
parameter2Expr);

var propertyType = ((MemberExpression)selectorExpr.Body).Type;
// Get aggregation method
var aggregationMethod = GetAggregationMethod(aggregationMethodName, propertyType);

List<Type> genericArguments = new List<Type>();
genericArguments.Add(queryable.ElementType);
if (aggregationMethod.GetGenericArguments().Length > 1)
{
genericArguments.Add(propertyType);
}

// E.g .Where(d1 => d1.Amount > 1).Average(d2 => d2.Amount)
var aggregationMethodExpr = Expression.Call(
null,
aggregationMethod.MakeGenericMethod(genericArguments.ToArray()),
new Expression[] { whereExpr, Expression.Quote(selectorExpr) });

// Act
// Call factory method for creating DataServiceOrderedQuery based on expression
var query = new DataServiceQueryProvider(dsContext).CreateQuery(aggregationMethodExpr);

// Assert
var expectedAggregateUri = $"{serviceUri}/{salesEntitySetName}?$apply=filter(Amount gt 1)" +
$"/aggregate(Amount with {aggregationMethodName.ToLower()} as {aggregationMethodName}Amount)";
Assert.Equal(expectedAggregateUri, query.ToString());
}

[Fact]
public void Aggregation_PrecededByOrderBy_Throws_NotSupportedException()
{
// Arrange
var queryable = this.dsContext.CreateQuery<Sale>(salesEntitySetName);

// Act & Assert
Assert.Throws<NotSupportedException>(() => queryable.OrderBy(d => d.Id).Average(d => d.Amount));
}

[Fact]
public void Aggregation_PrecededBySkip_Throws_NotSupportedException()
{
// Arrange
var queryable = this.dsContext.CreateQuery<Sale>(salesEntitySetName);

// Act & Assert
Assert.Throws<NotSupportedException>(() => queryable.Skip(1).Sum(d => d.Amount));
}

[Fact]
public void Aggregation_PrecededByTake_Throws_NotSupportedException()
{
// Arrange
var queryable = this.dsContext.CreateQuery<Sale>(salesEntitySetName);

// Act & Assert
Assert.Throws<NotSupportedException>(() => queryable.Take(1).Min(d => d.Amount));
}

#region Mock Aggregation Responses

private void MockCountDistinct()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,26 @@ protected static MethodInfo GetCountDistinctMethod()
.Select(d6 => d6.Method).Single();
}

/// <summary>
/// Uses reflection to find the Where method
/// </summary>
protected static MethodInfo GetWhereMethod()
{
return typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(d1 => d1.Name.Equals("Where", StringComparison.Ordinal))
.Select(d2 => new { Method = d2, Parameters = d2.GetParameters() })
.Where(d3 => d3.Parameters.Length.Equals(2)
&& d3.Parameters[0].ParameterType.IsGenericType
&& d3.Parameters[0].ParameterType.GetGenericTypeDefinition().Equals(typeof(IQueryable<>))
&& d3.Parameters[1].ParameterType.IsGenericType
&& d3.Parameters[1].ParameterType.GetGenericTypeDefinition().Equals(typeof(Expression<>)))
.Select(d4 => new { d4.Method, SelectorArguments = d4.Parameters[1].ParameterType.GetGenericArguments() })
.Where(d5 => d5.SelectorArguments.Length.Equals(1)
&& d5.SelectorArguments[0].IsGenericType
&& d5.SelectorArguments[0].GetGenericTypeDefinition().Equals(typeof(Func<,>))) // Func<TSource, Boolean>
.Select(d6 => d6.Method).Single();
}

/// <summary>
/// Builds a method call expression dynamically.
/// </summary>
Expand Down

0 comments on commit 868a0b2

Please sign in to comment.