Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lambda inferrer: support custom delegates #1752

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ protected override Expression VisitLambda(LambdaBindingParserNode node)
for (var paramIndex = 0; paramIndex < typeInferenceData.Parameters!.Length; paramIndex++)
{
var currentParamType = typeInferenceData.Parameters[paramIndex];
if (currentParamType.ContainsGenericParameters)
throw new BindingCompilationException($"Internal bug: lambda parameter still contains generic arguments: parameters[{paramIndex}] = {currentParamType.ToCode()}", node);
node.ParameterExpressions[paramIndex].SetResolvedType(currentParamType);
}
}
Expand Down Expand Up @@ -506,26 +508,90 @@ protected override Expression VisitLambdaParameter(LambdaParameterBindingParserN

private Expression CreateLambdaExpression(Expression body, ParameterExpression[] parameters, Type? delegateType)
{
if (delegateType != null && delegateType.Namespace == "System")
if (delegateType is null || delegateType == typeof(object) || delegateType == typeof(Delegate))
// Assume delegate is a System.Func<...>
return Expression.Lambda(body, parameters);

if (!delegateType.IsDelegate(out var invokeMethod))
throw new DotvvmCompilationException($"Cannot create lambda function, type '{delegateType.ToCode()}' is not a delegate type.");

if (invokeMethod.ReturnType == typeof(void))
{
// We must validate that lambda body contains a valid statement
if ((body.NodeType != ExpressionType.Default) && (body.NodeType != ExpressionType.Block) && (body.NodeType != ExpressionType.Call) && (body.NodeType != ExpressionType.Assign))
throw new DotvvmCompilationException($"Only method invocations and assignments can be used as statements.");

// Make sure the result type will be void by adding an empty expression
body = Expression.Block(body, Expression.Empty());
}

// convert body result to the delegate return type
if (invokeMethod.ReturnType.ContainsGenericParameters)
{
if (delegateType.Name == "Action" || delegateType.Name == $"Action`{parameters.Length}")
if (invokeMethod.ReturnType.IsGenericType)
{
// We must validate that lambda body contains a valid statement
if ((body.NodeType != ExpressionType.Default) && (body.NodeType != ExpressionType.Block) && (body.NodeType != ExpressionType.Call) && (body.NodeType != ExpressionType.Assign))
throw new DotvvmCompilationException($"Only method invocations and assignments can be used as statements.");
// no fancy implicit conversions are supported, only inheritance
if (!ReflectionUtils.IsAssignableToGenericType(body.Type, invokeMethod.ReturnType.GetGenericTypeDefinition(), out var bodyReturnType))
{
throw new DotvvmCompilationException($"Cannot convert lambda function body of type '{body.Type.ToCode()}' to the delegate return type '{invokeMethod.ReturnType.ToCode()}'.");
}
else
{
body = Expression.Convert(body, bodyReturnType);
}
}
else
{
// fine, we will unify it in the next step

// Make sure the result type will be void by adding an empty expression
return Expression.Lambda(Expression.Block(body, Expression.Empty()), parameters);
// Some complex conversions like Tuple<T, List<object>> -> Tuple<T, IEnumerable<T2>>
// will fail, but we don't have to support everything
}
else if (delegateType.Name == "Predicate`1")
}
else
{
body = TypeConversion.EnsureImplicitConversion(body, invokeMethod.ReturnType);
}

if (delegateType.ContainsGenericParameters)
{
var delegateTypeDef = delegateType.GetGenericTypeDefinition();
// The delegate is either purely generic (Func<T, T>) or only some of the generic arguments are known (Func<T, bool>)
// initialize generic args with the already known types
var genericArgs =
delegateTypeDef.GetGenericArguments().Zip(
delegateType.GetGenericArguments(),
(param, argument) => new KeyValuePair<Type, Type>(param, argument)
)
.Where(p => p.Value != p.Key)
.ToDictionary(p => p.Key, p => p.Value);

var delegateParameters = invokeMethod.GetParameters();
for (int i = 0; i < parameters.Length; i++)
{
if (!ReflectionUtils.TryUnifyGenericTypes(delegateParameters[i].ParameterType, parameters[i].Type, genericArgs))
{
throw new DotvvmCompilationException($"Could not match lambda function parameter '{parameters[i].Type.ToCode()} {parameters[i].Name}' to delegate parameter '{delegateParameters[i].ParameterType.ToCode()} {delegateParameters[i].Name}'.");
}
}
if (!ReflectionUtils.TryUnifyGenericTypes(invokeMethod.ReturnType, body.Type, genericArgs))
{
var type = delegateType.GetGenericTypeDefinition().MakeGenericType(parameters.Single().Type);
return Expression.Lambda(type, body, parameters);
throw new DotvvmCompilationException($"Could not match lambda function return type '{body.Type.ToCode()}' to delegate return type '{invokeMethod.ReturnType.ToCode()}'.");
}
ReflectionUtils.ExpandUnifiedTypes(genericArgs);

if (!delegateTypeDef.GetGenericArguments().All(a => genericArgs.TryGetValue(a, out var v) && !v.ContainsGenericParameters))
{
var missingGenericArgs = delegateTypeDef.GetGenericArguments().Where(genericArg => !genericArgs.ContainsKey(genericArg) || genericArgs[genericArg].ContainsGenericParameters);
throw new DotvvmCompilationException($"Could not infer all generic arguments ({string.Join(", ", missingGenericArgs)}) of delegate type '{delegateType.ToCode()}' from lambda expression '({string.Join(", ", parameters.Select(p => $"{p.Type.ToCode()} {p.Name}"))}) => ...'.");
}

delegateType = delegateTypeDef.MakeGenericType(
delegateTypeDef.GetGenericArguments().Select(genericParam => genericArgs[genericParam]).ToArray()
);
}

// Assume delegate is a System.Func<...>
return Expression.Lambda(body, parameters);
return Expression.Lambda(delegateType, body, parameters);
}

protected override Expression VisitBlock(BlockBindingParserNode node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ internal class InfererContext
{
public MethodGroupExpression? Target { get; set; }
public Expression[] Arguments { get; set; }
public Dictionary<string, Type> Generics { get; set; }
public Dictionary<Type, Type> Generics { get; set; }
public int CurrentArgumentIndex { get; set; }
public bool IsExtensionCall { get; set; }

public InfererContext(MethodGroupExpression? target, int argsCount)
{
this.Target = target;
this.Arguments = new Expression[argsCount];
this.Generics = new Dictionary<string, Type>();
this.Generics = new();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,39 +94,32 @@ private bool TryMatchDelegate(InfererContext? context, int argsCount, Type deleg
if (delegateParameters.Length != argsCount)
return false;

var generics = (context != null) ? context.Generics : new Dictionary<string, Type>();
if (!TryInstantiateDelegateParameters(delegateType, argsCount, generics, out parameters))
var generics = (context != null) ? context.Generics : new Dictionary<Type, Type>();
if (!TryInstantiateDelegateParameters(delegateParameters.Select(p => p.ParameterType).ToArray(), argsCount, generics, out parameters))
return false;

return true;
}

private bool TryInstantiateDelegateParameters(Type generic, int argsCount, Dictionary<string, Type> generics, [NotNullWhen(true)] out Type[]? instantiation)
private bool TryInstantiateDelegateParameters(Type[] delegateParameters, int argsCount, Dictionary<Type, Type> generics, [NotNullWhen(true)] out Type[]? instantiation)
{
var genericArgs = generic.GetGenericArguments();
var substitutions = new Type[argsCount];

for (var argIndex = 0; argIndex < argsCount; argIndex++)
{
var currentArg = genericArgs[argIndex];
var currentArg = delegateParameters[argIndex];
var assignedArg = ReflectionUtils.AssignGenericParameters(currentArg, generics);

if (!currentArg.IsGenericParameter)
{
// This is a known type
substitutions[argIndex] = currentArg;
}
else if (currentArg.IsGenericParameter && generics.ContainsKey(currentArg.Name))
{
// This is a generic parameter
// But we already inferred its type
substitutions[argIndex] = generics[currentArg.Name];
}
else
if (assignedArg.ContainsGenericParameters)
{
// This is an unknown type
instantiation = null;
return false;
}
else
{
substitutions[argIndex] = assignedArg;
}
}

instantiation = substitutions;
Expand Down
12 changes: 6 additions & 6 deletions src/Framework/Framework/Compilation/Inference/TypeInferer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ private void RefineCandidates(int index)
return;

var newCandidates = new List<MethodInfo>();
var newInstantiations = new Dictionary<string, HashSet<Type>>();
var newInstantiations = new Dictionary<Type, HashSet<Type>>();

// Check if we can remove some candidates
// Also try to infer generics based on provided argument
var tempInstantiations = new Dictionary<string, Type>();
var tempInstantiations = new Dictionary<Type, Type>();
foreach (var candidate in context.Target.Candidates!.Where(c => c.GetParameters().Length > index))
{
tempInstantiations.Clear();
Expand All @@ -87,12 +87,12 @@ private void RefineCandidates(int index)

if (parameterType.IsGenericParameter)
{
tempInstantiations.Add(parameterType.Name, argumentType);
tempInstantiations.Add(parameterType, argumentType);
}
else if (parameterType.ContainsGenericParameters)
{
// Check if we already inferred instantiation for these generics
if (!parameterType.GetGenericArguments().Any(param => !context.Generics.ContainsKey(param.Name)))
if (!parameterType.GetGenericArguments().Any(param => !context.Generics.ContainsKey(param)))
continue;

// Try to infer instantiation based on given argument
Expand All @@ -119,15 +119,15 @@ private void RefineCandidates(int index)
context.Target.Candidates = newCandidates;
}

private bool TryInferInstantiation(Type generic, Type concrete, Dictionary<string, Type> generics)
private bool TryInferInstantiation(Type generic, Type concrete, Dictionary<Type, Type> generics)
{
if (generic == concrete)
return true;

if (generic.IsGenericParameter)
{
// We found the instantiation
generics.Add(generic.Name, concrete);
generics.Add(generic, concrete);
return true;
}
else if (ReflectionUtils.IsEnumerable(generic))
Expand Down
125 changes: 119 additions & 6 deletions src/Framework/Framework/Utils/ReflectionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,17 @@ public static IEnumerable<MethodInfo> GetAllMethods(this Type type, BindingFlags
/// </summary>
public static bool IsAssignableToGenericType(this Type givenType, Type genericType, [NotNullWhen(returnValue: true)] out Type? commonType)
{
var interfaceTypes = givenType.GetInterfaces();

foreach (var it in interfaceTypes)
if (genericType.IsInterface)
{
if (it.IsGenericType && it.GetGenericTypeDefinition() == genericType)
var interfaceTypes = givenType.GetInterfaces();

foreach (var it in interfaceTypes)
{
commonType = it;
return true;
if (it.IsGenericType && it.GetGenericTypeDefinition() == genericType)
{
commonType = it;
return true;
}
}
}

Expand Down Expand Up @@ -654,5 +657,115 @@ public static IEnumerable<Type> GetBaseTypesAndInterfaces(Type type)
type = baseType;
}
}


internal static bool TryUnifyGenericTypes(Type a, Type b, Dictionary<Type, Type> genericAssignment)
{
if (a == b)
return true;

if (a.IsGenericParameter)
{
if (genericAssignment.ContainsKey(a))
return TryUnifyGenericTypes(genericAssignment[a], b, genericAssignment);

genericAssignment.Add(a, b);
return true;
}
else if (b.IsGenericParameter)
{
if (genericAssignment.ContainsKey(b))
return TryUnifyGenericTypes(a, genericAssignment[b], genericAssignment);

genericAssignment.Add(b, a);
return true;
}
else if (a.IsGenericType && b.IsGenericType)
{
if (a.GetGenericTypeDefinition() != b.GetGenericTypeDefinition())
return false;

var aArgs = a.GetGenericArguments();
var bArgs = b.GetGenericArguments();
if (aArgs.Length != bArgs.Length)
return false;

for (var i = 0; i < aArgs.Length; i++)
{
if (!TryUnifyGenericTypes(aArgs[i], bArgs[i], genericAssignment))
return false;
}

return true;
}
else
{
return false;
}
}

internal static void ExpandUnifiedTypes(Dictionary<Type, Type> genericAssignment)
{
var iteration = 0;
bool dirty;
do
{
dirty = false;
iteration++;
if (iteration > 100)
throw new Exception("Too much recursion in ExpandUnifiedTypes");

foreach (var (key, value) in genericAssignment.ToArray())
{
var expanded = AssignGenericParameters(value, genericAssignment);
if (expanded != value)
{
genericAssignment[key] = expanded;
dirty = true;
}
}
}
while (dirty);
}

internal static Type AssignGenericParameters(Type t, IReadOnlyDictionary<Type, Type> genericAssignment)
{
if (!t.ContainsGenericParameters)
return t;

if (t.IsGenericParameter)
{
if (genericAssignment.TryGetValue(t, out var result))
return result;
else
return t;
}
else if (t.IsGenericType)
{
var args = t.GetGenericArguments();
for (var i = 0; i < args.Length; i++)
{
args[i] = AssignGenericParameters(args[i], genericAssignment);
}
if (args.SequenceEqual(t.GetGenericArguments()))
return t;
else
return t.GetGenericTypeDefinition().MakeGenericType(args);
}
else if (t.HasElementType)
{
var el = AssignGenericParameters(t.GetElementType()!, genericAssignment);
if (el == t.GetElementType())
return t;
else if (t.IsArray)
return el.MakeArrayType(t.GetArrayRank());
else
throw new NotSupportedException();
}
else
{
return t;
}
}
}
}
Loading
Loading