Skip to content

Commit

Permalink
Respect derived type preference in overload resolution
Browse files Browse the repository at this point in the history
according to the spec, this is the rule with priority second
only to "instance > extension".
The change is needed to fix problems with multiple times
overriden operators

Easy to try demo in C# that this is how we should behave:

public class C {
    public void M() {
        B.Method((int)1);
    }
}

public class A { public static void Method(int a) { } }
public class B : A { public static void Method(object a) { } }
  • Loading branch information
exyi committed Jul 30, 2023
1 parent a98d13a commit 652e795
Showing 1 changed file with 46 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ public Expression CallMethod(Type target, BindingFlags flags, string name, Type[

if (matchingMethods.Count == 0)
return null;
var overload = BestOverload(matchingMethods, operatorName);
var overload = BestOverload(matchingMethods, searchTypes, operatorName);
var parameters = overload.Method.GetParameters();

return Expression.MakeBinary(
Expand All @@ -285,6 +285,7 @@ public Expression CallMethod(Type target, BindingFlags flags, string name, Type[

private MethodRecognitionResult FindValidMethodOverloads(Expression? target, Type type, string name, BindingFlags flags, Type[]? typeArguments, Expression[] arguments, IDictionary<string, Expression>? namedArgs)
{
bool extensionMethods = false;
var methods = FindValidMethodOverloads(type.GetAllMethods(flags), name, false, typeArguments, arguments, namedArgs);

if (methods.Count == 1) return methods[0];
Expand All @@ -293,6 +294,7 @@ private MethodRecognitionResult FindValidMethodOverloads(Expression? target, Typ
// We did not find any match in regular methods => try extension methods
if (target != null && flags.HasFlag(BindingFlags.Instance))
{
extensionMethods = true;
// Change to a static call
var newArguments = new[] { target }.Concat(arguments).ToArray();
var extensions = FindValidMethodOverloads(GetAllExtensionMethods(), name, true, typeArguments, newArguments, namedArgs);
Expand All @@ -311,21 +313,25 @@ private MethodRecognitionResult FindValidMethodOverloads(Expression? target, Typ
}

// There are multiple method candidates
return BestOverload(methods, name);
return BestOverload(methods, extensionMethods ? Type.EmptyTypes : new[] { type }, name);
}

private MethodRecognitionResult BestOverload(List<MethodRecognitionResult> methods, string name)
private MethodRecognitionResult BestOverload(List<MethodRecognitionResult> methods, Type[] callingOnType, string name)
{
if (methods.Count == 1)
return methods[0];

methods = methods.OrderBy(s => s.CastCount).ThenBy(s => s.AutomaticTypeArgCount).ThenBy(s => s.HasParamsAttribute).ToList();
methods = methods
.OrderBy(s => GetNearestInheritanceDistance(s.Method.DeclaringType, callingOnType))
.ThenBy(s => s.CastCount)
.ThenBy(s => s.AutomaticTypeArgCount)
.ThenBy(s => s.HasParamsAttribute).ToList();
var method = methods.First();
var method2 = methods.Skip(1).First();
if (method.AutomaticTypeArgCount == method2.AutomaticTypeArgCount && method.CastCount == method2.CastCount && method.HasParamsAttribute == method2.HasParamsAttribute)
if (method.AutomaticTypeArgCount == method2.AutomaticTypeArgCount && method.CastCount == method2.CastCount && method.HasParamsAttribute == method2.HasParamsAttribute && GetNearestInheritanceDistance(method.Method.DeclaringType, callingOnType) == GetNearestInheritanceDistance(method2.Method.DeclaringType, callingOnType))
{
// TODO: this behavior is not completed. Implement the same behavior as in roslyn.
var foundOverloads = $"{method.Method}, {method2.Method}";
var foundOverloads = $"{ReflectionUtils.FormatMethodInfo(method.Method, stripNamespace: true)}, {ReflectionUtils.FormatMethodInfo(method2.Method, stripNamespace: true)}";
throw new InvalidOperationException($"Found ambiguous overloads of method '{name}'. The following overloads were found: {foundOverloads}.");
}
return method;
Expand Down Expand Up @@ -354,6 +360,40 @@ private List<MethodRecognitionResult> FindValidMethodOverloads(IEnumerable<Metho
return result;
}

private int? TryGetInheritanceDistance(Type baseType, Type? derivedType)
{
int distance = 0;
while (derivedType != baseType)
{
if (derivedType is null)
return null;

distance++;
derivedType = derivedType.BaseType;
}
return distance;
}
private int GetNearestInheritanceDistance(Type? baseType, Type[] derivedTypes)
{
if (baseType is null || derivedTypes.Length == 0)
// in extension method invocation this is irrelevant
return 0;
foreach (var derivedType in derivedTypes)
{
if (derivedType == baseType)
return 0;
}
int distance = int.MaxValue;
foreach (var derivedType in derivedTypes)
{
if (TryGetInheritanceDistance(baseType, derivedType) is {} d)
distance = Math.Min(distance, d);
}
if (distance == int.MaxValue)
throw new InvalidOperationException($"'{baseType.ToCode()}' is not a base type of any of '{string.Join(", ", derivedTypes.Select(t => t.ToCode()))}'.");
return distance;
}

sealed class MethodRecognitionResult
{
public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Expression[] arguments, MethodInfo method, int paramsArrayCount, bool isExtension, bool hasParamsAttribute)
Expand Down

0 comments on commit 652e795

Please sign in to comment.