Skip to content

Commit

Permalink
Fix Average to support nullable types (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
StefH authored May 22, 2023
1 parent d8754c2 commit 4cc72c4
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 35 deletions.
6 changes: 6 additions & 0 deletions src/System.Linq.Dynamic.Core/Parser/TypeHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,12 @@ public static string GetTypeName(Type? type)
return name;
}

public static Type GetNullableType(Type type)
{
type = Nullable.GetUnderlyingType(type) ?? type;
return type.GetTypeInfo().IsValueType ? typeof(Nullable<>).MakeGenericType(type) : type;
}

public static Type GetNonNullableType(Type type)
{
Check.NotNull(type, nameof(type));
Expand Down
18 changes: 14 additions & 4 deletions src/System.Linq.Dynamic.Core/Util/QueryableMethodFinder.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Linq.Expressions;
using System.Collections.Generic;
using System.Linq.Dynamic.Core.Parser;
using System.Linq.Expressions;
using System.Reflection;

namespace System.Linq.Dynamic.Core.Util;
Expand All @@ -13,8 +15,16 @@ public static MethodInfo GetGenericMethod(string name)
public static MethodInfo GetMethod(string name, Type argumentType, Type returnType, int parameterCount = 0, Func<MethodInfo, bool>? predicate = null) =>
GetMethod(name, returnType, parameterCount, mi => mi.ToString().Contains(argumentType.ToString()) && ((predicate == null) || predicate(mi)));

public static MethodInfo GetMethod(string name, Type returnType, int parameterCount = 0, Func<MethodInfo, bool>? predicate = null) =>
GetMethod(name, parameterCount, mi => (mi.ReturnType == returnType) && ((predicate == null) || predicate(mi)));
public static MethodInfo GetMethod(string name, Type returnType, int parameterCount = 0, Func<MethodInfo, bool>? predicate = null)
{
var returnTypes = new List<Type> { returnType };
if (!TypeHelper.IsNullableType(returnType))
{
returnTypes.Add(TypeHelper.GetNullableType(returnType));
}

return GetMethod(name, parameterCount, mi => returnTypes.Contains(mi.ReturnType) && (predicate == null || predicate(mi)));
}

public static MethodInfo GetMethodWithExpressionParameter(string name) =>
GetMethod(name, 1, mi =>
Expand All @@ -35,7 +45,7 @@ public static MethodInfo GetMethod(string name, int parameterCount = 0, Func<Met
{
try
{
return typeof(Queryable).GetTypeInfo().GetDeclaredMethods(name).Single(mi =>
return typeof(Queryable).GetTypeInfo().GetDeclaredMethods(name).First(mi =>
mi.GetParameters().Length == parameterCount + 1 && (predicate == null || predicate(mi)));
}
catch (Exception ex)
Expand Down
83 changes: 52 additions & 31 deletions test/System.Linq.Dynamic.Core.Tests/QueryableTests.Average.cs
Original file line number Diff line number Diff line change
@@ -1,36 +1,57 @@
using System.Linq.Dynamic.Core.Tests.Helpers.Models;
using System.Collections.Generic;
using System.Linq.Dynamic.Core.Tests.Helpers.Models;
using Xunit;

namespace System.Linq.Dynamic.Core.Tests
namespace System.Linq.Dynamic.Core.Tests;

public partial class QueryableTests
{
public partial class QueryableTests
[Fact]
public void Average()
{
// Arrange
var incomes = User.GenerateSampleModels(100).Select(u => u.Income).ToArray();

// Act
var expected = incomes.Average();
var actual = incomes.AsQueryable().Average();

// Assert
Assert.Equal(expected, actual);
}

[Fact]
public void Average_Nullable()
{
// Arrange
var list = new List<Entity> { new Entity { Value = 1 }, new Entity { Value = 2 }, new Entity { Value = null } };
var queryable = list.AsQueryable();


// Act
var expected = queryable.Average(p => p.Value);
var actual = queryable.Average("Value");

// Assert
Assert.Equal(expected, actual);
}

[Fact]
public void Average_Selector()
{
// Arrange
var users = User.GenerateSampleModels(100);

// Act
var expected = users.Average(u => u.Income);
var result = users.AsQueryable().Average("Income");

// Assert
Assert.Equal(expected, result);
}

public class Entity
{
[Fact]
public void Average()
{
// Arrange
var incomes = User.GenerateSampleModels(100).Select(u => u.Income);

// Act
var expected = incomes.Average();
var actual = incomes.AsQueryable().Average();

// Assert
Assert.Equal(expected, actual);
}

[Fact]
public void Average_Selector()
{
// Arrange
var users = User.GenerateSampleModels(100);

// Act
var expected = users.Average(u => u.Income);
var result = users.AsQueryable().Average("Income");

// Assert
Assert.Equal(expected, result);
}
public int? Value { get; set; }
}
}
}

0 comments on commit 4cc72c4

Please sign in to comment.