Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Compiler] Improve default functions existence with conditions #3754

Merged
merged 6 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 147 additions & 104 deletions bbq/compiler/desugar.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func (d *Desugar) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration)
// Add the remaining statements that are defined in this function.
statements := funcBlock.Block.Statements
modifiedStatements = append(modifiedStatements, statements...)

} else if d.enclosingInterfaceType != nil {
// If this is an interface-method without a body,
// then do not generate a function for it.
return nil
}

// Before the post conditions are appended, we need to move the
Expand Down Expand Up @@ -186,6 +191,7 @@ func (d *Desugar) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration)
nil,
)

// TODO: Is the generated function needed to be desugared again?
return ast.NewFunctionDeclaration(
d.memoryGauge,
declaration.Access,
Expand Down Expand Up @@ -656,6 +662,8 @@ func (d *Desugar) generateConditionsFunction(
"",
)

// TODO: Is the generated function needed to be desugared?

d.modifiedDeclarations = append(d.modifiedDeclarations, conditionFunc)
}

Expand Down Expand Up @@ -836,8 +844,6 @@ func (d *Desugar) VisitAttachmentDeclaration(declaration *ast.AttachmentDeclarat
}

func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaration) ast.Declaration {
existingMembers := declaration.Members.Declarations()

compositeType := d.elaboration.CompositeDeclarationType(declaration)

// Recursively de-sugar nested declarations (functions, types, etc.)
Expand All @@ -850,36 +856,42 @@ func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaratio

var desugaredMembers []ast.Declaration
membersDesugared := false
existingMembers := declaration.Members.Declarations()

for _, member := range existingMembers {
desugaredMember := d.desugarDeclaration(member)
if desugaredMember == nil {
continue
}

membersDesugared = membersDesugared || (desugaredMember != member)
desugaredMembers = append(desugaredMembers, desugaredMember)
}

// Copy over inherited default functions.

inheritedDefaultFuncs := d.inheritedDefaultFunctions(compositeType, declaration)
// Add inherited default functions.
existingFunctions := declaration.Members.FunctionsByIdentifier()
inheritedDefaultFuncs := d.inheritedDefaultFunctions(
compositeType,
existingFunctions,
declaration.StartPos,
declaration.Range,
)

// Optimization: If none of the existing members got updated or,
// if there are no inherited members, then return the same declaration as-is.
if !membersDesugared && len(inheritedDefaultFuncs) == 0 {
return declaration
}

modifiedMembers := make([]ast.Declaration, len(desugaredMembers))
copy(modifiedMembers, desugaredMembers)

modifiedMembers = append(modifiedMembers, inheritedDefaultFuncs...)
desugaredMembers = append(desugaredMembers, inheritedDefaultFuncs...)

modifiedDecl := ast.NewCompositeDeclaration(
d.memoryGauge,
declaration.Access,
declaration.CompositeKind,
declaration.Identifier,
declaration.Conformances,
ast.NewMembers(d.memoryGauge, modifiedMembers),
ast.NewMembers(d.memoryGauge, desugaredMembers),
declaration.DocString,
declaration.Range,
)
Expand All @@ -890,11 +902,10 @@ func (d *Desugar) VisitCompositeDeclaration(declaration *ast.CompositeDeclaratio
return modifiedDecl
}

func (d *Desugar) inheritedFunctionsWithConditions(compositeType *sema.CompositeType) map[string][]*inheritedFunction {
func (d *Desugar) inheritedFunctionsWithConditions(compositeType sema.ConformingType) map[string][]*inheritedFunction {
inheritedFunctions := make(map[string][]*inheritedFunction)

for _, conformance := range compositeType.EffectiveInterfaceConformances() {
interfaceType := conformance.InterfaceType
compositeType.EffectiveInterfaceConformanceSet().ForEach(func(interfaceType *sema.InterfaceType) {

elaboration, err := d.config.ElaborationResolver(interfaceType.Location)
if err != nil {
Expand All @@ -915,116 +926,154 @@ func (d *Desugar) inheritedFunctionsWithConditions(compositeType *sema.Composite
})
inheritedFunctions[name] = funcs
}
}
})

return inheritedFunctions
}

func (d *Desugar) inheritedDefaultFunctions(compositeType *sema.CompositeType, decl *ast.CompositeDeclaration) []ast.Declaration {
directMembers := compositeType.Members
allMembers := compositeType.GetMembers()
func (d *Desugar) inheritedDefaultFunctions(
compositeType sema.ConformingType,
existingFunctions map[string]*ast.FunctionDeclaration,
pos ast.Position,
declRange ast.Range,
) []ast.Declaration {

pos := decl.StartPos
inheritedDefaultFunctions := make(map[string]struct{})

inheritedMembers := make([]ast.Declaration, 0)

for memberName, resolver := range allMembers { // nolint:maprange
if directMembers.Contains(memberName) {
continue
}

member := resolver.Resolve(
d.memoryGauge,
memberName,
ast.EmptyRange,
func(err error) {
if err != nil {
panic(err)
}
},
)

// Only interested in functions.
// Also filter out built-in functions.
if member.DeclarationKind != common.DeclarationKindFunction ||
member.Predeclared {
continue
}

// Inherited functions are always from interfaces
interfaceType := member.ContainerType.(*sema.InterfaceType)
for _, conformance := range compositeType.EffectiveInterfaceConformances() {
SupunS marked this conversation as resolved.
Show resolved Hide resolved
interfaceType := conformance.InterfaceType

elaboration, err := d.config.ElaborationResolver(interfaceType.Location)
if err != nil {
panic(err)
}

interfaceDecl := elaboration.InterfaceTypeDeclaration(interfaceType)

functions := interfaceDecl.Members.FunctionsByIdentifier()
inheritedFunc, ok := functions[memberName]
if !ok {
panic(errors.NewUnreachableError())
}

// for each inherited function, generate a delegator function,
// which calls the actual default implementation at the interface.
// i.e:
// FooImpl {
// fun defaultFunc(a1: T1, a2: T2): R {
// return FooInterface.defaultFunc(a1, a2)
// }
// }
for funcName, inheritedFunc := range functions { // nolint:maprange
if !inheritedFunc.FunctionBlock.HasStatements() {
continue
}

// Generate: `FooInterface.defaultFunc(a1, a2)`
// Pick the 'closest' default function.
// This is the same way how it is implemented in the interpreter.
_, ok := inheritedDefaultFunctions[funcName]
if ok {
continue
}
inheritedDefaultFunctions[funcName] = struct{}{}

inheritedFuncType := elaboration.FunctionDeclarationFunctionType(inheritedFunc)
// If the inherited function is overridden by the current type, then skip.
if d.isFunctionOverridden(compositeType, funcName, existingFunctions) {
continue
}

invocation := d.interfaceDelegationMethodCall(
interfaceType,
inheritedFuncType,
pos,
memberName,
member,
)
// For each inherited function, generate a delegator function,
// which calls the actual default implementation at the interface.
// i.e:
// FooImpl {
// fun defaultFunc(a1: T1, a2: T2): R {
// return FooInterface.defaultFunc(a1, a2)
// }
// }

// Generate: `fun defaultFunc(a1: T1, a2: T2) { ... }`
defaultFuncDelegator := ast.NewFunctionDeclaration(
d.memoryGauge,
inheritedFunc.Access,
inheritedFunc.Purity,
inheritedFunc.IsStatic(),
inheritedFunc.IsNative(),
ast.NewIdentifier(
d.memoryGauge,
memberName,
// Generate: `FooInterface.defaultFunc(a1, a2)`

inheritedFuncType := elaboration.FunctionDeclarationFunctionType(inheritedFunc)

member, ok := interfaceType.MemberMap().Get(funcName)
if !ok {
panic(errors.NewUnreachableError())
}

invocation := d.interfaceDelegationMethodCall(
interfaceType,
inheritedFuncType,
pos,
),
inheritedFunc.TypeParameterList,
inheritedFunc.ParameterList,
inheritedFunc.ReturnTypeAnnotation,
ast.NewFunctionBlock(
funcName,
member,
)

funcReturnType := inheritedFuncType.ReturnTypeAnnotation.Type
returnStmt := ast.NewReturnStatement(d.memoryGauge, invocation, declRange)
d.elaboration.SetReturnStatementTypes(
returnStmt,
sema.ReturnStatementTypes{
ValueType: funcReturnType,
ReturnType: funcReturnType,
},
)

// Generate: `fun defaultFunc(a1: T1, a2: T2) { ... }`
defaultFuncDelegator := ast.NewFunctionDeclaration(
d.memoryGauge,
ast.NewBlock(
inheritedFunc.Access,
inheritedFunc.Purity,
inheritedFunc.IsStatic(),
inheritedFunc.IsNative(),
ast.NewIdentifier(
d.memoryGauge,
[]ast.Statement{
ast.NewReturnStatement(d.memoryGauge, invocation, decl.Range),
},
decl.Range,
funcName,
pos,
),
nil,
nil,
),
inheritedFunc.StartPos,
inheritedFunc.DocString,
)
inheritedFunc.TypeParameterList,
inheritedFunc.ParameterList,
inheritedFunc.ReturnTypeAnnotation,
ast.NewFunctionBlock(
d.memoryGauge,
ast.NewBlock(
d.memoryGauge,
[]ast.Statement{
returnStmt,
},
declRange,
),
nil,
nil,
),
inheritedFunc.StartPos,
inheritedFunc.DocString,
)

d.elaboration.SetFunctionDeclarationFunctionType(defaultFuncDelegator, inheritedFuncType)

// Pass the generated default function again through the desugar phase,
// so that it will properly link/chain the function conditions
// that are inherited/available for this default function.
desugaredDelegator := d.desugarDeclaration(defaultFuncDelegator)

inheritedMembers = append(inheritedMembers, defaultFuncDelegator)
inheritedMembers = append(inheritedMembers, desugaredDelegator)

}
}

return inheritedMembers
}

func (d *Desugar) isFunctionOverridden(
enclosingType sema.ConformingType,
funcName string,
existingFunctions map[string]*ast.FunctionDeclaration,
) bool {
implementedFunc, isImplemented := existingFunctions[funcName]
if !isImplemented {
return false
}

_, isInterface := enclosingType.(*sema.InterfaceType)
if isInterface {
// If the currently visiting declaration is an interface type (i.e: This function is an interface method)
// then it is considered as a default implementation only if there are statements.
// This is because interface methods can define conditions, without overriding the function.
return implementedFunc.FunctionBlock.HasStatements()
}

return true
}

func (d *Desugar) interfaceDelegationMethodCall(
interfaceType *sema.InterfaceType,
inheritedFuncType *sema.FunctionType,
Expand Down Expand Up @@ -1162,6 +1211,7 @@ func (d *Desugar) VisitInterfaceDeclaration(declaration *ast.InterfaceDeclaratio

prevModifiedDecls := d.modifiedDeclarations
prevEnclosingInterfaceType := d.enclosingInterfaceType

d.modifiedDeclarations = nil
d.enclosingInterfaceType = interfaceType

Expand All @@ -1170,26 +1220,19 @@ func (d *Desugar) VisitInterfaceDeclaration(declaration *ast.InterfaceDeclaratio
d.enclosingInterfaceType = prevEnclosingInterfaceType
}()

existingMembers := declaration.Members.Declarations()

// Recursively de-sugar nested declarations (functions, types, etc.)

membersDesugared := false

existingMembers := declaration.Members.Declarations()
for _, member := range existingMembers {
desugaredMember := d.desugarDeclaration(member)
membersDesugared = membersDesugared || (desugaredMember != member)
if desugaredMember == nil {
continue
}
d.modifiedDeclarations = append(d.modifiedDeclarations, desugaredMember)
}

// Optimization: If none of the existing members got updated or,
// TODO: Optimize: If none of the existing members got updated or,
// if there are no inherited members, then return the same declaration as-is.
//if !membersDesugared && len(inheritedDefaultFuncs) == 0 {
// return declaration
//}
if !membersDesugared {
return declaration
}

modifiedDecl := ast.NewInterfaceDeclaration(
d.memoryGauge,
Expand Down
Loading