From 05e3c0896478453e6e79926822102276066b93a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Thu, 5 Oct 2023 12:38:34 -0700 Subject: [PATCH] allow native functions to have type parameters --- runtime/sema/check_composite_declaration.go | 7 +- runtime/sema/check_function.go | 15 +- runtime/sema/checker.go | 90 +++++++- runtime/sema/errors.go | 17 ++ runtime/tests/checker/genericfunction_test.go | 213 +++++++++++++++++- 5 files changed, 336 insertions(+), 6 deletions(-) diff --git a/runtime/sema/check_composite_declaration.go b/runtime/sema/check_composite_declaration.go index 3ebb93dcbf..41914a9744 100644 --- a/runtime/sema/check_composite_declaration.go +++ b/runtime/sema/check_composite_declaration.go @@ -1783,7 +1783,12 @@ func (checker *Checker) defaultMembersAndOrigins( identifier := function.Identifier.Identifier - functionType := checker.functionType(function.ParameterList, function.ReturnTypeAnnotation) + functionType := checker.functionType( + function.IsNative(), + function.TypeParameterList, + function.ParameterList, + function.ReturnTypeAnnotation, + ) argumentLabels := function.ParameterList.EffectiveArgumentLabels() diff --git a/runtime/sema/check_function.go b/runtime/sema/check_function.go index 46d2d0392a..4d7452323c 100644 --- a/runtime/sema/check_function.go +++ b/runtime/sema/check_function.go @@ -95,7 +95,13 @@ func (checker *Checker) visitFunctionDeclaration( functionType := checker.Elaboration.FunctionDeclarationFunctionType(declaration) if functionType == nil { - functionType = checker.functionType(declaration.ParameterList, declaration.ReturnTypeAnnotation) + + functionType = checker.functionType( + declaration.IsNative(), + declaration.TypeParameterList, + declaration.ParameterList, + declaration.ReturnTypeAnnotation, + ) if options.declareFunction { checker.declareFunctionDeclaration(declaration, functionType) @@ -430,7 +436,12 @@ func (checker *Checker) declareBefore() { func (checker *Checker) VisitFunctionExpression(expression *ast.FunctionExpression) Type { // TODO: infer - functionType := checker.functionType(expression.ParameterList, expression.ReturnTypeAnnotation) + functionType := checker.functionType( + false, + nil, + expression.ParameterList, + expression.ReturnTypeAnnotation, + ) checker.Elaboration.SetFunctionExpressionFunctionType(expression, functionType) diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index 8706de2f69..ae7626a8e2 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -424,7 +424,12 @@ func (checker *Checker) checkTopLevelDeclarationValidity( } func (checker *Checker) declareGlobalFunctionDeclaration(declaration *ast.FunctionDeclaration) { - functionType := checker.functionType(declaration.ParameterList, declaration.ReturnTypeAnnotation) + functionType := checker.functionType( + declaration.IsNative(), + declaration.TypeParameterList, + declaration.ParameterList, + declaration.ReturnTypeAnnotation, + ) checker.Elaboration.SetFunctionDeclarationFunctionType(declaration, functionType) checker.declareFunctionDeclaration(declaration, functionType) } @@ -1240,11 +1245,65 @@ func (checker *Checker) ConvertTypeAnnotation(typeAnnotation *ast.TypeAnnotation } func (checker *Checker) functionType( + isNative bool, + typeParameterList *ast.TypeParameterList, parameterList *ast.ParameterList, returnTypeAnnotation *ast.TypeAnnotation, ) *FunctionType { + + // Convert type parameters (if any) + + var convertedTypeParameters []*TypeParameter + if typeParameterList != nil { + + // Only native functions may have type parameters at the moment + if !isNative && !typeParameterList.IsEmpty() { + checker.report(&InvalidTypeParameterizedNonNativeFunctionError{ + Range: ast.NewRangeFromPositioned( + checker.memoryGauge, + typeParameterList, + ), + }) + } + + checker.typeActivations.Enter() + defer checker.typeActivations.Leave(func(gauge common.MemoryGauge) ast.Position { + if returnTypeAnnotation != nil { + return returnTypeAnnotation.EndPosition(gauge) + } else { + return parameterList.EndPos + } + }) + + // All type parameters are converted at once, + // so type bounds may currently not refer to previous type parameters + + convertedTypeParameters = checker.typeParameters(typeParameterList) + + for typeParameterIndex, typeParameter := range typeParameterList.TypeParameters { + convertedTypeParameter := convertedTypeParameters[typeParameterIndex] + + genericType := &GenericType{ + TypeParameter: convertedTypeParameter, + } + + _, err := checker.typeActivations.declareType(typeDeclaration{ + identifier: typeParameter.Identifier, + ty: genericType, + declarationKind: common.DeclarationKindTypeParameter, + allowOuterScopeShadowing: false, + }) + checker.report(err) + + } + } + + // Convert parameters + convertedParameters := checker.parameters(parameterList) + // Convert return type + convertedReturnTypeAnnotation := VoidTypeAnnotation if returnTypeAnnotation != nil { convertedReturnTypeAnnotation = @@ -1252,11 +1311,40 @@ func (checker *Checker) functionType( } return &FunctionType{ + TypeParameters: convertedTypeParameters, Parameters: convertedParameters, ReturnTypeAnnotation: convertedReturnTypeAnnotation, } } +func (checker *Checker) typeParameters(typeParameterList *ast.TypeParameterList) []*TypeParameter { + + var typeParameters []*TypeParameter + + typeParameterCount := len(typeParameterList.TypeParameters) + if typeParameterCount > 0 { + typeParameters = make([]*TypeParameter, typeParameterCount) + + for i, typeParameter := range typeParameterList.TypeParameters { + + typeBoundAnnotation := typeParameter.TypeBound + var convertedTypeBound Type + if typeBoundAnnotation != nil { + convertedTypeBoundAnnotation := checker.ConvertTypeAnnotation(typeBoundAnnotation) + checker.checkTypeAnnotation(convertedTypeBoundAnnotation, typeBoundAnnotation) + convertedTypeBound = convertedTypeBoundAnnotation.Type + } + + typeParameters[i] = &TypeParameter{ + Name: typeParameter.Identifier.Identifier, + TypeBound: convertedTypeBound, + } + } + } + + return typeParameters +} + func (checker *Checker) parameters(parameterList *ast.ParameterList) []Parameter { var parameters []Parameter diff --git a/runtime/sema/errors.go b/runtime/sema/errors.go index 5fc1424792..ac53910145 100644 --- a/runtime/sema/errors.go +++ b/runtime/sema/errors.go @@ -4194,3 +4194,20 @@ func (*AttachmentsNotEnabledError) IsUserError() {} func (e *AttachmentsNotEnabledError) Error() string { return "attachments are not enabled and cannot be used in this environment" } + +// InvalidTypeParameterizedNonNativeFunctionError + +type InvalidTypeParameterizedNonNativeFunctionError struct { + ast.Range +} + +var _ SemanticError = &InvalidTypeParameterizedNonNativeFunctionError{} +var _ errors.UserError = &InvalidTypeParameterizedNonNativeFunctionError{} + +func (*InvalidTypeParameterizedNonNativeFunctionError) isSemanticError() {} + +func (*InvalidTypeParameterizedNonNativeFunctionError) IsUserError() {} + +func (e *InvalidTypeParameterizedNonNativeFunctionError) Error() string { + return "invalid type parameters in non-native function" +} diff --git a/runtime/tests/checker/genericfunction_test.go b/runtime/tests/checker/genericfunction_test.go index 326221b719..d45e61a775 100644 --- a/runtime/tests/checker/genericfunction_test.go +++ b/runtime/tests/checker/genericfunction_test.go @@ -27,6 +27,7 @@ import ( "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/parser" "github.com/onflow/cadence/runtime/sema" "github.com/onflow/cadence/runtime/stdlib" ) @@ -50,7 +51,7 @@ func parseAndCheckWithTestValue(t *testing.T, code string, ty sema.Type) (*sema. ) } -func TestCheckGenericFunction(t *testing.T) { +func TestCheckGenericFunctionInvocation(t *testing.T) { t.Parallel() @@ -896,7 +897,7 @@ func TestCheckBorrowOfCapabilityWithoutTypeArgument(t *testing.T) { require.NoError(t, err) } -func TestCheckUnparameterizedTypeInstantiationE(t *testing.T) { +func TestCheckInvalidUnparameterizedTypeInstantiation(t *testing.T) { t.Parallel() @@ -910,3 +911,211 @@ func TestCheckUnparameterizedTypeInstantiationE(t *testing.T) { assert.IsType(t, &sema.UnparameterizedTypeInstantiationError{}, errs[0]) } + +func TestCheckGenericFunctionDeclaration(t *testing.T) { + + t.Parallel() + + t.Run("global, non-native", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, ` + fun head(_ items: [T]): T? { return nil } + + let x: Int? = head([1, 2, 3]) + `, + ParseAndCheckOptions{ + Config: &sema.Config{ + AllowNativeDeclarations: false, + }, + ParseOptions: parser.Config{ + NativeModifierEnabled: false, + TypeParametersEnabled: true, + }, + }, + ) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidTypeParameterizedNonNativeFunctionError{}, errs[0]) + }) + + t.Run("global, native", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, ` + native fun head(_ items: [T]): T? {} + + let x: Int? = head([1, 2, 3]) + `, + ParseAndCheckOptions{ + Config: &sema.Config{ + AllowNativeDeclarations: true, + }, + ParseOptions: parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, + }, + ) + + require.NoError(t, err) + }) + + t.Run("composite function, non-native", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, ` + struct S { + fun head(_ items: [T]): T? { return nil } + } + + let x: Int? = S().head([1, 2, 3]) + `, + ParseAndCheckOptions{ + Config: &sema.Config{ + AllowNativeDeclarations: false, + }, + ParseOptions: parser.Config{ + NativeModifierEnabled: false, + TypeParametersEnabled: true, + }, + }, + ) + + errs := RequireCheckerErrors(t, err, 2) + + assert.IsType(t, &sema.InvalidTypeParameterizedNonNativeFunctionError{}, errs[0]) + assert.IsType(t, &sema.InvalidTypeParameterizedNonNativeFunctionError{}, errs[1]) + }) + + t.Run("composite function, non-native", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, ` + struct S { + native fun head(_ items: [T]): T? {} + } + + let x: Int? = S().head([1, 2, 3]) + `, + ParseAndCheckOptions{ + Config: &sema.Config{ + AllowNativeDeclarations: true, + }, + ParseOptions: parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, + }, + ) + + require.NoError(t, err) + }) + + t.Run("too many type arguments", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, + ` + native fun test() {} + + let x = test() + `, + ParseAndCheckOptions{ + Config: &sema.Config{ + AllowNativeDeclarations: true, + }, + ParseOptions: parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, + }, + ) + + errs := RequireCheckerErrors(t, err, 1) + + require.IsType(t, &sema.InvalidTypeArgumentCountError{}, errs[0]) + }) + + t.Run("too few type arguments", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, + ` + native fun test() {} + + let x = test() + `, + ParseAndCheckOptions{ + Config: &sema.Config{ + AllowNativeDeclarations: true, + }, + ParseOptions: parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, + }, + ) + + errs := RequireCheckerErrors(t, err, 1) + + require.IsType(t, &sema.TypeParameterTypeInferenceError{}, errs[0]) + }) + + t.Run("type parameter usage in following type parameter", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, + ` + native fun test(_ u: U): U {} + `, + ParseAndCheckOptions{ + Config: &sema.Config{ + AllowNativeDeclarations: true, + }, + ParseOptions: parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, + }, + ) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.NotDeclaredError{}, errs[0]) + }) + + t.Run("type bound is checked", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, + ` + native fun test() {} + + let x = test() + `, + ParseAndCheckOptions{ + Config: &sema.Config{ + AllowNativeDeclarations: true, + }, + ParseOptions: parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, + }, + ) + + errs := RequireCheckerErrors(t, err, 1) + + require.IsType(t, &sema.TypeMismatchError{}, errs[0]) + }) +}