Skip to content

Commit

Permalink
Merge pull request #1483 from riganti/custom-binary-op
Browse files Browse the repository at this point in the history
Custom handling of binary operators
  • Loading branch information
exyi authored Jul 30, 2023
2 parents 182d01d + 652e795 commit f059d47
Show file tree
Hide file tree
Showing 15 changed files with 490 additions and 256 deletions.
2 changes: 1 addition & 1 deletion src/Framework/Framework/Binding/BindingFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public static IBinding CreateBinding(this BindingCompilationService service, Typ
if (ctor == null) throw new NotSupportedException($"Could not find .ctor(BindingCompilationService service, object[] properties) on binding '{type.FullName}'.");
var bindingServiceParam = Expression.Parameter(typeof(BindingCompilationService));
var propertiesParam = Expression.Parameter(typeof(object?[]));
var expression = Expression.New(ctor, bindingServiceParam, TypeConversion.ImplicitConversion(propertiesParam, ctor.GetParameters()[1].ParameterType, throwException: true)!);
var expression = Expression.New(ctor, bindingServiceParam, TypeConversion.EnsureImplicitConversion(propertiesParam, ctor.GetParameters()[1].ParameterType));
return Expression.Lambda<Func<BindingCompilationService, object?[], IBinding>>(expression, bindingServiceParam, propertiesParam).CompileFast(flags: CompilerFlags.ThrowOnNotSupportedExpression);
})(service, properties);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ protected override Expression VisitInterpolatedStringExpression(InterpolatedStri
{
// Translate to a String.Format(...) call
var arguments = node.Arguments.Select((arg, index) => HandleErrors(node.Arguments[index], Visit)!).ToArray();
ThrowOnErrors();
return memberExpressionFactory.Call(target, new[] { Expression.Constant(node.Format) }.Concat(arguments).ToArray());
}
else
Expand Down Expand Up @@ -294,7 +295,7 @@ protected override Expression VisitAssemblyQualifiedName(AssemblyQualifiedNameBi

protected override Expression VisitConditionalExpression(ConditionalExpressionBindingParserNode node)
{
var condition = HandleErrors(node.ConditionExpression, n => TypeConversion.ImplicitConversion(Visit(n), typeof(bool), true));
var condition = HandleErrors(node.ConditionExpression, n => TypeConversion.EnsureImplicitConversion(Visit(n), typeof(bool)));
var trueExpr = HandleErrors(node.TrueExpression, Visit)!;
var falseExpr = HandleErrors(node.FalseExpression, Visit)!;
ThrowOnErrors();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ protected override Expression VisitConditional(ConditionalExpression node)
if (ifTrue.Type != ifFalse.Type)
{
var nullable = ifTrue.Type.IsNullable() ? ifTrue.Type : ifFalse.Type;
ifTrue = TypeConversion.ImplicitConversion(ifTrue, nullable, throwException: true)!;
ifFalse = TypeConversion.ImplicitConversion(ifFalse, nullable, throwException: true)!;
ifTrue = TypeConversion.EnsureImplicitConversion(ifTrue, nullable);
ifFalse = TypeConversion.EnsureImplicitConversion(ifFalse, nullable);
}
return Expression.Condition(test, ifTrue, ifFalse);
});
Expand Down Expand Up @@ -181,7 +181,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
{
return CheckForNull(Visit(node.Arguments.First()), index =>
{
var convertedIndex = TypeConversion.ImplicitConversion(index, node.Method.GetParameters().First().ParameterType, throwException: true)!;
var convertedIndex = TypeConversion.EnsureImplicitConversion(index, node.Method.GetParameters().First().ParameterType);
return Expression.Call(target, node.Method, new[] { convertedIndex }.Concat(node.Arguments.Skip(1)));
});
}, suppress: node.Object?.Type?.IsNullable() ?? true);
Expand Down Expand Up @@ -244,7 +244,7 @@ protected Expression CheckForNull(Expression? parameter, Func<Expression, Expres
parameter as ParameterExpression ??
Expression.Parameter(parameter.Type, "tmp" + tmpCounter++);
var eresult = callback(p2.Type.IsNullable() ? (Expression)Expression.Property(p2, "Value") : p2);
eresult = TypeConversion.ImplicitConversion(eresult, eresult.Type.MakeNullableType())!;
eresult = TypeConversion.EnsureImplicitConversion(eresult, eresult.Type.MakeNullableType());
var condition = parameter.Type.IsNullable() ? (Expression)Expression.Property(p2, "HasValue") : Expression.NotEqual(p2, Expression.Constant(null, p2.Type));
var handledResult =
Expression.Condition(condition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public CastedExpressionBindingProperty ConvertExpressionToType(ParsedExpressionB
// if the expression is of type object (i.e. null literal) try the lambda conversion.
convertedExpr != null && expr.Expression.Type != typeof(object) ? convertedExpr :
TypeConversion.MagicLambdaConversion(expr.Expression, destType) ?? convertedExpr ??
TypeConversion.ImplicitConversion(expr.Expression, destType, throwException: true, allowToString: true)!
TypeConversion.EnsureImplicitConversion(expr.Expression, destType, allowToString: true)!
);
}

Expand Down
220 changes: 95 additions & 125 deletions src/Framework/Framework/Compilation/Binding/MemberExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -241,17 +241,60 @@ public Expression CallMethod(Type target, BindingFlags flags, string name, Type[
return Expression.Call(method.Method, method.Arguments);
}

public Expression? TryCallCustomBinaryOperator(Expression a, Expression b, string operatorName, ExpressionType operatorType)
{
if (a is null) throw new ArgumentNullException(nameof(a));
if (b is null) throw new ArgumentNullException(nameof(b));
if (operatorName is null) throw new ArgumentNullException(nameof(b));

var searchTypes = new [] { a.Type, b.Type, a.Type.UnwrapNullableType(), b.Type.UnwrapNullableType() }.OfType<Type>().Distinct().ToArray();


// https://github.com/dotnet/csharpstandard/blob/standard-v6/standard/expressions.md#1145-binary-operator-overload-resolution
// The set of candidate user-defined operators provided by X and Y for the operation operator «op»(x, y) is determined. The set consists of the union of the candidate operators provided by X and the candidate operators provided by Y, each determined using the rules of §11.4.6.

var candidateMethods =
searchTypes
.SelectMany(t => t.GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy))
.Where(m => m.Name == operatorName && !m.IsGenericMethod && m.GetParameters().Length == 2)
.Distinct()
.ToArray();

// The overload resolution rules of §11.6.4 are applied to the set of candidate operators to select the best operator with respect to the argument list (x, y), and this operator becomes the result of the overload resolution process. If overload resolution fails to select a single best operator, a binding-time error occurs.

var matchingMethods = FindValidMethodOverloads(candidateMethods, operatorName, false, null, new[] { a, b }, null);
var liftToNull = matchingMethods.Count == 0 && (a.Type.IsNullable() || b.Type.IsNullable());
if (liftToNull)
{
matchingMethods = FindValidMethodOverloads(candidateMethods, operatorName, false, null, new[] { a.UnwrapNullable(), b.UnwrapNullable() }, null);
}

if (matchingMethods.Count == 0)
return null;
var overload = BestOverload(matchingMethods, searchTypes, operatorName);
var parameters = overload.Method.GetParameters();

return Expression.MakeBinary(
operatorType,
TypeConversion.EnsureImplicitConversion(a, parameters[0].ParameterType),
TypeConversion.EnsureImplicitConversion(b, parameters[1].ParameterType),
liftToNull: liftToNull,
method: overload.Method
);
}

private MethodRecognitionResult FindValidMethodOverloads(Expression? target, Type type, string name, BindingFlags flags, Type[]? typeArguments, Expression[] arguments, IDictionary<string, Expression>? namedArgs)
{
bool extensionMethods = false;
var methods = FindValidMethodOverloads(type.GetAllMethods(flags), name, false, typeArguments, arguments, namedArgs);

if (methods.Count == 1) return methods[0];
if (methods.Count == 0)
{
// We did not find any match in regular methods => try extension methods
if (target != null)
if (target != null && flags.HasFlag(BindingFlags.Instance))
{
extensionMethods = true;
// Change to a static call
var newArguments = new[] { target }.Concat(arguments).ToArray();
var extensions = FindValidMethodOverloads(GetAllExtensionMethods(), name, true, typeArguments, newArguments, namedArgs);
Expand All @@ -270,13 +313,25 @@ private MethodRecognitionResult FindValidMethodOverloads(Expression? target, Typ
}

// There are multiple method candidates
methods = methods.OrderBy(s => s.CastCount).ThenBy(s => s.AutomaticTypeArgCount).ThenBy(s => s.HasParamsAttribute).ToList();
return BestOverload(methods, extensionMethods ? Type.EmptyTypes : new[] { type }, name);
}

private MethodRecognitionResult BestOverload(List<MethodRecognitionResult> methods, Type[] callingOnType, string name)
{
if (methods.Count == 1)
return methods[0];

methods = methods
.OrderBy(s => GetNearestInheritanceDistance(s.Method.DeclaringType, callingOnType))
.ThenBy(s => s.CastCount)
.ThenBy(s => s.AutomaticTypeArgCount)
.ThenBy(s => s.HasParamsAttribute).ToList();
var method = methods.First();
var method2 = methods.Skip(1).First();
if (method.AutomaticTypeArgCount == method2.AutomaticTypeArgCount && method.CastCount == method2.CastCount && method.HasParamsAttribute == method2.HasParamsAttribute)
if (method.AutomaticTypeArgCount == method2.AutomaticTypeArgCount && method.CastCount == method2.CastCount && method.HasParamsAttribute == method2.HasParamsAttribute && GetNearestInheritanceDistance(method.Method.DeclaringType, callingOnType) == GetNearestInheritanceDistance(method2.Method.DeclaringType, callingOnType))
{
// TODO: this behavior is not completed. Implement the same behavior as in roslyn.
var foundOverloads = $"{method.Method}, {method2.Method}";
var foundOverloads = $"{ReflectionUtils.FormatMethodInfo(method.Method, stripNamespace: true)}, {ReflectionUtils.FormatMethodInfo(method2.Method, stripNamespace: true)}";
throw new InvalidOperationException($"Found ambiguous overloads of method '{name}'. The following overloads were found: {foundOverloads}.");
}
return method;
Expand Down Expand Up @@ -305,6 +360,40 @@ private List<MethodRecognitionResult> FindValidMethodOverloads(IEnumerable<Metho
return result;
}

private int? TryGetInheritanceDistance(Type baseType, Type? derivedType)
{
int distance = 0;
while (derivedType != baseType)
{
if (derivedType is null)
return null;

distance++;
derivedType = derivedType.BaseType;
}
return distance;
}
private int GetNearestInheritanceDistance(Type? baseType, Type[] derivedTypes)
{
if (baseType is null || derivedTypes.Length == 0)
// in extension method invocation this is irrelevant
return 0;
foreach (var derivedType in derivedTypes)
{
if (derivedType == baseType)
return 0;
}
int distance = int.MaxValue;
foreach (var derivedType in derivedTypes)
{
if (TryGetInheritanceDistance(baseType, derivedType) is {} d)
distance = Math.Min(distance, d);
}
if (distance == int.MaxValue)
throw new InvalidOperationException($"'{baseType.ToCode()}' is not a base type of any of '{string.Join(", ", derivedTypes.Select(t => t.ToCode()))}'.");
return distance;
}

sealed class MethodRecognitionResult
{
public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Expression[] arguments, MethodInfo method, int paramsArrayCount, bool isExtension, bool hasParamsAttribute)
Expand All @@ -329,6 +418,7 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express

private MethodRecognitionResult? TryCallMethod(MethodInfo method, Type[]? typeArguments, Expression[] positionalArguments, IDictionary<string, Expression>? namedArguments)
{
if (positionalArguments.Contains(null)) throw new ArgumentNullException("positionalArguments[]");
var parameters = method.GetParameters();

if (!TryPrepareArguments(parameters, positionalArguments, namedArguments, out var args, out var castCount))
Expand Down Expand Up @@ -406,7 +496,7 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express
if (args.Length == i + 1 && hasParamsArrayAttributes && !args[i].Type.IsArray)
{
var converted = positionalArguments.Skip(i)
.Select(a => TypeConversion.ImplicitConversion(a, elm, throwException: true)!)
.Select(a => TypeConversion.EnsureImplicitConversion(a, elm))
.ToArray();
args[i] = NewArrayExpression.NewArrayInit(elm, converted);
}
Expand Down Expand Up @@ -561,131 +651,11 @@ private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[]
return null;
}

public Expression EqualsMethod(Expression left, Expression right)
{
Expression? equatable = null;
Expression? theOther = null;
if (typeof(IEquatable<>).IsAssignableFrom(left.Type))
{
equatable = left;
theOther = right;
}
else if (typeof(IEquatable<>).IsAssignableFrom(right.Type))
{
equatable = right;
theOther = left;
}

if (equatable != null)
{
var m = CallMethod(equatable, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.FlattenHierarchy, "Equals", null, new[] { theOther! });
if (m != null) return m;
}

if (left.Type.IsValueType)
{
equatable = left;
theOther = right;
}
else if (left.Type.IsValueType)
{
equatable = right;
theOther = left;
}

if (equatable != null)
{
theOther = TypeConversion.ImplicitConversion(theOther!, equatable.Type);
if (theOther != null) return Expression.Equal(equatable, theOther);
}

return CallMethod(left, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.FlattenHierarchy, "Equals", null, new[] { right });
}

public Expression CompareMethod(Expression left, Expression right)
{
Type compareType = typeof(object);
Expression? equatable = null;
Expression? theOther = null;
if (typeof(IComparable<>).IsAssignableFrom(left.Type))
{
equatable = left;
theOther = right;
}
else if (typeof(IComparable<>).IsAssignableFrom(right.Type))
{
equatable = right;
theOther = left;
}
else if (typeof(IComparable).IsAssignableFrom(left.Type))
{
equatable = left;
theOther = right;
}
else if (typeof(IComparable).IsAssignableFrom(right.Type))
{
equatable = right;
theOther = left;
}

if (equatable != null)
{
return CallMethod(equatable, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.FlattenHierarchy, "Compare", null, new[] { theOther! });
}
throw new NotSupportedException("IComparable is not implemented on any of specified types");
}

public Expression GetUnaryOperator(Expression expr, ExpressionType operation)
{
var binder = (DynamicMetaObjectBinder)Microsoft.CSharp.RuntimeBinder.Binder.UnaryOperation(
CSharpBinderFlags.None, operation, typeof(object), ExpressionHelper.GetBinderArguments(1));
return ExpressionHelper.ApplyBinder(binder, true, expr)!;
}

public Expression GetBinaryOperator(Expression left, Expression right, ExpressionType operation)
{
if (operation == ExpressionType.Coalesce)
{
// in bindings, most expressions will be nullable due to automatic null-propagation
// the null propagation visitor however runs after this, so we need to convert left to nullable
// to make the validation in Expression.Coalesce happy
var leftNullable =
left.Type.IsValueType && !left.Type.IsNullable()
? Expression.Convert(left, typeof(Nullable<>).MakeGenericType(left.Type))
: left;
return Expression.Coalesce(leftNullable, right);
}
if (operation == ExpressionType.Assign)
{
return UpdateMember(left, TypeConversion.ImplicitConversion(right, left.Type, true, true)!)
.NotNull($"Expression '{right}' cannot be assigned into '{left}'.");
}

// TODO: type conversions
if (operation == ExpressionType.AndAlso) return Expression.AndAlso(left, right);
else if (operation == ExpressionType.OrElse) return Expression.OrElse(left, right);

var binder = (DynamicMetaObjectBinder)Microsoft.CSharp.RuntimeBinder.Binder.BinaryOperation(
CSharpBinderFlags.None, operation, typeof(object), ExpressionHelper.GetBinderArguments(2));
var result = ExpressionHelper.ApplyBinder(binder, false, left, right);
if (result != null) return result;
if (operation == ExpressionType.Equal) return EqualsMethod(left, right);
if (operation == ExpressionType.NotEqual) return Expression.Not(EqualsMethod(left, right));


// try converting left to right.Type and vice versa
// needed to enum with pseudo-string literal operations
// if (TypeConversion.ImplicitConversion(left, right.Type) is {} leftConverted)
// return GetBinaryOperator(leftConverted, right, operation);
// if (TypeConversion.ImplicitConversion(right, left.Type) is {} rightConverted)
// return GetBinaryOperator(left, rightConverted, operation);

// lift the operator
if (left.Type.IsNullable() || right.Type.IsNullable())
return GetBinaryOperator(left.UnwrapNullable(), right.UnwrapNullable(), operation);

throw new Exception($"could not apply { operation } binary operator to { left } and { right }");
// TODO: comparison operators
}
}
}
Loading

0 comments on commit f059d47

Please sign in to comment.