Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tie-break for method overloads in derived classes #1452

Merged
merged 3 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions Bonsai.Core.Tests/OverloadedCombinatorBuilderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,40 @@ public IObservable<TSource> Process<TSource>(IObservable<Timestamped<TSource>> s
}
}

class HidingOverloadedCombinatorMock : OverloadedCombinatorMock
{
public new IObservable<double> Process(IObservable<double> source)
{
return source.Select(x => double.NaN);
}
}

class HidingSpecializedGenericOverloadedCombinatorMock : SpecializedGenericOverloadedCombinatorMock
{
public new IObservable<TSource> Process<TSource>(IObservable<Timestamped<TSource>> source)
{
return source.Select(x => default(TSource));
}
}

[Combinator]
class BaseVirtualCombinatorMock
{
public virtual IObservable<string> Process(IObservable<string> source) => source;
}

class DerivedOverrideCombinatorMock : BaseVirtualCombinatorMock
{
public override IObservable<string> Process(IObservable<string> source) => Observable.Return(string.Empty);
}

class DerivedOverrideOverloadedCombinatorMock : BaseVirtualCombinatorMock
{
public override IObservable<string> Process(IObservable<string> source) => source;

public IObservable<object> Process(IObservable<object> _) => Observable.Return(default(object));
}

[TestMethod]
public void Build_DoubleOverloadedMethodCalledWithDouble_ReturnsDoubleValue()
{
Expand Down Expand Up @@ -187,5 +221,49 @@ public void Build_SpecializedGenericOverloadedMethod_ReturnsValue()
var result = Last(resultProvider).Result;
Assert.AreEqual(value, result);
}

[TestMethod]
public void Build_HidingDoubleOverloadedMethodCalledWithDouble_ReturnsDoubleValue()
{
var value = 5.0;
var combinator = new HidingOverloadedCombinatorMock();
var source = CreateObservableExpression(Observable.Return(value));
var resultProvider = TestCombinatorBuilder<double>(combinator, source);
var result = Last(resultProvider).Result;
Assert.AreNotEqual(value, result);
}

[TestMethod]
public void Build_HidingSpecializedGenericOverloadedMethod_ReturnsValue()
{
var value = 5;
var combinator = new HidingSpecializedGenericOverloadedCombinatorMock();
var source = CreateObservableExpression(Observable.Return(value).Timestamp());
var resultProvider = TestCombinatorBuilder<int>(combinator, source);
var result = Last(resultProvider).Result;
Assert.AreNotEqual(value, result);
}

[TestMethod]
public void Build_DerivedOverrideMethodCalledWithString_ReturnsOverrideValue()
{
var value = "5";
var combinator = new DerivedOverrideCombinatorMock();
var source = CreateObservableExpression(Observable.Return(value));
var resultProvider = TestCombinatorBuilder<object>(combinator, source);
var result = Last(resultProvider).Result;
Assert.AreNotEqual(value, result);
}

[TestMethod]
public void Build_DerivedOverrideOverloadedMethodCalledWithString_ReturnsObjectValue()
{
var value = "5";
var combinator = new DerivedOverrideOverloadedCombinatorMock();
var source = CreateObservableExpression(Observable.Return(value));
var resultProvider = TestCombinatorBuilder<object>(combinator, source);
var result = Last(resultProvider).Result;
Assert.AreNotEqual(value, result);
}
}
}
13 changes: 13 additions & 0 deletions Bonsai.Core/Expressions/ExpressionBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ class CallCandidate
internal static readonly CallCandidate Ambiguous = new CallCandidate();
internal static readonly CallCandidate None = new CallCandidate();
internal MethodBase method;
internal Type declaringType;
internal Expression[] arguments;
internal bool generic;
internal bool expansion;
Expand Down Expand Up @@ -818,6 +819,9 @@ static CallCandidate OverloadResolution(IEnumerable<MethodBase> methods, params
return new CallCandidate
{
method = method,
declaringType = method.IsVirtual
? ((MethodInfo)method).GetBaseDefinition().DeclaringType
: method.DeclaringType,
arguments = callArguments,
generic = method.IsGenericMethod,
expansion = ParamExpansionRequired(parameters, argumentTypes),
Expand Down Expand Up @@ -853,6 +857,15 @@ static CallCandidate OverloadResolution(IEnumerable<MethodBase> methods, params
// skip self-test
if (i == j) continue;

// exclude self if declaring type is base type of other; and vice-versa
if (candidates[i].declaringType != candidates[j].declaringType)
{
if (candidates[i].declaringType.IsAssignableFrom(candidates[j].declaringType))
candidates[i].excluded = true;
else candidates[j].excluded = true;
continue;
}

// compare implicit type conversion
var comparison = CompareFunctionMember(
candidateParameters[i],
Expand Down