Skip to content

Commit

Permalink
Merge pull request #1077 from hylo-lang/members-in-trait-ext
Browse files Browse the repository at this point in the history
Fix name lookup via trait extensions
  • Loading branch information
kyouko-taiga authored Oct 11, 2023
2 parents 7b73654 + 274a3d3 commit 92f2b69
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 51 deletions.
58 changes: 40 additions & 18 deletions Sources/Core/Program.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ extension Program {
isContained(l, in: r) || isContained(r, in: l)
}

/// Returns `true` iff `l` is lexically enclosed in more scopes than `r`.
public func hasMoreAncestors(_ l: AnyDeclID, than r: AnyDeclID) -> Bool {
guard let s = nodeToScope[l] else { return false }
guard let t = nodeToScope[r] else { return true }

var a = scopes(from: s)
var b = scopes(from: t)
while a.next() != nil {
if b.next() == nil { return true }
}
return false
}

/// Returns the scope of `d`'s body, if any.
public func scopeContainingBody(of d: FunctionDecl.ID) -> AnyScopeID? {
switch ast[d].body {
Expand Down Expand Up @@ -256,9 +269,34 @@ extension Program {
}
}

/// Returns whether `d` is a requirement.
/// Returns `true` iff `d` is defined in an extension.
public func isDefinedInExtension<T: DeclID>(_ d: T) -> Bool {
switch d.kind {
case ModuleDecl.self:
return false
case MethodImpl.self:
return isDefinedInExtension(MethodDecl.ID(nodeToScope[d]!)!)
case SubscriptImpl.self:
return isDefinedInExtension(SubscriptDecl.ID(nodeToScope[d]!)!)
default:
return nodeToScope[d]!.kind == ExtensionDecl.self
}
}

/// Returns `true` iff `d` is a trait requirement.
public func isRequirement<T: DeclID>(_ d: T) -> Bool {
trait(defining: d) != nil
switch d.kind {
case AssociatedTypeDecl.self, AssociatedValueDecl.self:
return true
case FunctionDecl.self, InitializerDecl.self, MethodDecl.self, SubscriptDecl.self:
return nodeToScope[d]!.kind == TraitDecl.self
case MethodImpl.self:
return isRequirement(MethodDecl.ID(nodeToScope[d]!)!)
case SubscriptImpl.self:
return isRequirement(SubscriptDecl.ID(nodeToScope[d]!)!)
default:
return false
}
}

/// If `s` is in a member context, returns the innermost receiver declaration exposed to `s`.
Expand Down Expand Up @@ -313,22 +351,6 @@ extension Program {
scopes(from: scope).first(TranslationUnit.self)!
}

/// Returns the trait of which `d` is a member, or `nil` if `d` isn't member of a trait.
public func trait<T: DeclID>(defining d: T) -> TraitDecl.ID? {
switch d.kind {
case AssociatedTypeDecl.self, AssociatedValueDecl.self:
return TraitDecl.ID(nodeToScope[d]!)!
case FunctionDecl.self, InitializerDecl.self, MethodDecl.self, SubscriptDecl.self:
return TraitDecl.ID(nodeToScope[d]!)
case MethodImpl.self:
return trait(defining: MethodDecl.ID(nodeToScope[d]!)!)
case SubscriptImpl.self:
return trait(defining: SubscriptDecl.ID(nodeToScope[d]!)!)
default:
return nil
}
}

/// Returns the name of `d` if it introduces a single entity.
public func name(of d: AnyDeclID) -> Name? {
if let e = self.ast[d] as? SingleEntityDecl { return Name(stem: e.baseName) }
Expand Down
139 changes: 111 additions & 28 deletions Sources/FrontEnd/TypeChecking/TypeChecker.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ struct TypeChecker {
canonical(t, in: scopeOfUse) == canonical(u, in: scopeOfUse)
}

/// Returns `true` iff `t` is a refinement of `u` in `scopeOfUse`.
mutating func isRefinement(_ t: TraitType, of u: TraitType, in scopeOfUse: AnyScopeID) -> Bool {
(t != u) && conformedTraits(of: t, in: scopeOfUse).contains(u)
}

/// Returns the traits to which `t` is declared conforming in `scopeOfUse`.
mutating func conformedTraits(of t: AnyType, in scopeOfUse: AnyScopeID) -> Set<TraitType> {
let key = Cache.TypeLookupKey(t, in: scopeOfUse)
Expand Down Expand Up @@ -1160,11 +1165,13 @@ struct TypeChecker {
guard !t[.hasError] else { return nil }

let candidates = lookup(n.stem, memberOf: m, exposedTo: scopeOfExposition)
let viable: [AnyDeclID] = candidates.reduce(into: []) { (s, c) in
guard let d = D(c) else { return }
appendDefinitions(d, t, &s)
var viable: [AnyDeclID] = []
for c in candidates {
guard let d = D(c) else { continue }
appendDefinitions(d, t, &viable)
}

viable = viable.minimalElements(by: { (a, b) in compareDepth(a, b, in: scopeOfExposition) })
return viable.uniqueElement
}

Expand Down Expand Up @@ -2607,16 +2614,8 @@ struct TypeChecker {
// TODO: Read source of conformance to disambiguate associated names
let newMatches = lookup(stem, memberOf: ^t, exposedTo: scopeOfUse)

// Associated type and value declarations are not inherited by conformance. Traits do not
// inherit the generic parameters.
switch nominalScope.base {
case is AssociatedTypeType, is GenericTypeParameterType:
matches.formUnion(newMatches)
case is TraitType:
matches.formUnion(newMatches.filter({ $0.kind != GenericParameterDecl.self }))
default:
matches.formUnion(newMatches.filter(program.isRequirement(_:)))
}
// Generic parameters introduced by traits are not inherited.
matches.formUnion(newMatches.filter({ $0.kind != GenericParameterDecl.self }))
}

return matches
Expand Down Expand Up @@ -3049,8 +3048,8 @@ struct TypeChecker {
}

// If the match is a trait member looked up with qualification, specialize its receiver.
if let t = program.trait(defining: m) {
specialization[program[t].receiver] = context?.type
if let t = traitDefining(m) {
specialization[program[t.decl].receiver] = context?.type
}

// If the name resolves to an initializer, determine if it is used as a constructor.
Expand All @@ -3064,8 +3063,8 @@ struct TypeChecker {
// If the receiver is an existential, replace its receiver.
if let container = ExistentialType(context?.type) {
candidateType = candidateType.asMember(of: container)
if let t = program.trait(defining: m) {
specialization[program[t].receiver] = ^WitnessType(of: container)
if let t = traitDefining(m) {
specialization[program[t.decl].receiver] = ^WitnessType(of: container)
}
}

Expand Down Expand Up @@ -3473,6 +3472,27 @@ struct TypeChecker {
}
}

/// Returns the trait of which `d` is a member, or `nil` if `d` isn't member of a trait.
mutating func traitDefining<T: DeclID>(_ d: T) -> TraitType? {
guard let p = program.nodeToScope[d] else {
assert(d.kind == ModuleDecl.self)
return nil
}

switch p.kind {
case TraitDecl.self:
return TraitType(TraitDecl.ID(p)!, ast: program.ast)
case ExtensionDecl.self:
return TraitType(uncheckedType(of: ExtensionDecl.ID(p)!))
case MethodDecl.self:
return traitDefining(MethodDecl.ID(p)!)
case SubscriptDecl.self:
return traitDefining(SubscriptDecl.ID(p)!)
default:
return nil
}
}

// MARK: Quantifier elimination

/// A context in which a generic parameter can be instantiated.
Expand Down Expand Up @@ -4585,16 +4605,16 @@ struct TypeChecker {
var ranking: StrictPartialOrdering = .equal
var namesInCommon = 0

for (n, lhsDeclRef) in lhs.bindingAssumptions {
guard let rhsDeclRef = rhs.bindingAssumptions[n] else { continue }
for (n, lhs) in lhs.bindingAssumptions {
guard let rhs = rhs.bindingAssumptions[n] else { continue }
namesInCommon += 1

// Nothing to do if both functions have the binding.
if lhsDeclRef == rhsDeclRef { continue }
let lhs = uncheckedType(of: lhsDeclRef.decl!)
let rhs = uncheckedType(of: rhsDeclRef.decl!)
// Nothing to do if both functions have the same binding.
if lhs == rhs { continue }

switch compareSpecificity(lhs, rhs, in: program[n].scope, at: program[n].site) {
let o = compareBindingPrecedence(
lhs.decl!, rhs.decl!, in: program[n].scope, at: program[n].site)
switch o {
case .ascending:
if ranking == .descending { return nil }
ranking = .ascending
Expand Down Expand Up @@ -4625,11 +4645,74 @@ struct TypeChecker {
return namesInCommon == lhs.bindingAssumptions.count ? ranking : nil
}

/// Compares `lhs` and `rhs` and returns whether one is more specific than the other in
/// `scopeOfUse`, instantiating generic type constraints at `site`.
/// Compares `lhs` and `rhs` in `scopeOfUse` and returns whether one has either shadows or is
/// more specific than the other.
///
/// `lhs` and `rhs` are assumed to have compatible types.
private mutating func compareBindingPrecedence(
_ lhs: AnyDeclID, _ rhs: AnyDeclID, in scopeOfUse: AnyScopeID, at site: SourceRange
) -> StrictPartialOrdering {
if let o = compareDepth(lhs, rhs, in: scopeOfUse) {
return o
}

let t = uncheckedType(of: lhs)
let u = uncheckedType(of: rhs)
return compareSpecificity(t, u, in: scopeOfUse, at: site)
}

/// Compares `lhs` and `rhs` in `scopeOfUse` and returns whether one shadows the other.
///
/// `lhs` is deeper than `rhs` w.r.t. `scopeOfUse` if either of these statements hold:
/// - `lhs` and `rhs` are members of traits `t1` and `t2`, respectively, and `t1` refines `t2`
/// - `lhs` isn't member of a trait and `rhs` is.
/// - `lhs` is declared in the module containing `scopeOfUse` and `rhs` isn't.
/// - `lhs` and `rhs` are declared in module containing `scopeOfUse` and `lhs` has more ancestors
/// than `rhs`.
private mutating func compareDepth(
_ lhs: AnyDeclID, _ rhs: AnyDeclID, in scopeOfUse: AnyScopeID
) -> StrictPartialOrdering {
if let l = traitDefining(lhs) {
// If `lhs` is a trait member but `rhs` isn't, then `rhs` shadows `lhs`.
guard let r = traitDefining(rhs) else { return .descending }

// If `lhs` and `rhs` are members of traits `t1` and `t2`, respectively, then `lhs` shadows
// `rhs` iff `t1` refines `t2`.
if isRefinement(l, of: r, in: scopeOfUse) { return .ascending }
if isRefinement(r, of: l, in: scopeOfUse) { return .descending }
return nil
}

if traitDefining(rhs) != nil {
// If `rhs` is a trait member but `lhs` isn't, then `lhs` shadows `rhs`.
return .ascending
}

let m = program.module(containing: scopeOfUse)
if program.isContained(lhs, in: m) {
// If `lhs` is in the same module as `scopeOfUse` but `rhs` isn't, then `lhs` shadows `rhs`.
guard program.isContained(rhs, in: m) else { return .ascending }

// If `lhs` and `rhs` are in the same module as `scopeOfUse`, then `lhs` shadows `rhs` iff
// it has more ancestors than `rhs`.
if program.hasMoreAncestors(lhs, than: rhs) { return .ascending }
if program.hasMoreAncestors(rhs, than: lhs) { return .descending }
return nil
}

if program.isContained(rhs, in: m) {
// If `rhs` is in the same module as `scopeOfUse` but `lhs` isn't, then `rhs` shadows `lhs`.
return .descending
}

return nil
}

/// Compares `lhs` and `rhs` in `scopeOfUse` and returns whether one is more specific than the
/// other, instantiating generic type constraints at `site`.
///
/// `t1` is more specific than `t2` if both are callable types with the same labels and `t1`
/// accepts strictly less arguments than `t2`.
/// `lhs` is more specific than `rhs` iff both `lhs` and `rhs` are callable types with the same
/// labels and `lhs` accepts strictly less arguments than `rhs`.
private mutating func compareSpecificity(
_ lhs: AnyType, _ rhs: AnyType, in scopeOfUse: AnyScopeID, at site: SourceRange
) -> StrictPartialOrdering {
Expand Down
6 changes: 6 additions & 0 deletions Sources/FrontEnd/TypedProgram.swift
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,12 @@ public struct TypedProgram {
return checker.accumulatedGenericParameters(in: s)
}

/// Returns the trait of which `d` is a member, or `nil` if `d` isn't member of a trait.
public func traitDefining<T: DeclID>(_ d: T) -> TraitType? {
var checker = TypeChecker(asContextFor: self)
return checker.traitDefining(d)
}

/// Returns `true` iff `model` conforms to `concept` in `scopeOfUse`.
public func conforms(
_ model: AnyType, to concept: TraitType, in scopeOfUse: AnyScopeID
Expand Down
8 changes: 3 additions & 5 deletions Sources/IR/TypedProgram+Extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@ extension TypedProgram {

/// If `f` refers to a trait member, returns the declaration of that member along with the trait
/// in which it is defined. Otherwise, returns `nil`.
func traitMember(
referredBy f: Function.ID
) -> (declaration: AnyDeclID, trait: TraitType)? {
func traitMember(referredBy f: Function.ID) -> (declaration: AnyDeclID, trait: TraitType)? {
switch f.value {
case .lowered(let d):
guard let t = trait(defining: d) else { return nil }
return (declaration: d, trait: TraitType(t, ast: ast))
guard let t = traitDefining(d) else { return nil }
return (declaration: d, trait: t)

default:
return nil
Expand Down
47 changes: 47 additions & 0 deletions Sources/Utils/Collection+Extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,53 @@ extension Collection {
return l
}

/// Returns the minimal elements of `self` using `compare` to order them.
///
/// A minimal element of a set *S* with a strict partial order *R* is an element is not smaller
/// than any other element in *S*. If *S* is a finite set and *R* is a strict total order, the
/// notions of minimal element and minimum coincide.
///
/// - Complexity: O(*n*^2) where *n* is the length of `self`.
public func minimalElements(
by compare: (Element, Element) -> StrictPartialOrdering
) -> [Element] {
if let u = uniqueElement { return [u] }

// This algorithm successively eliminates elements that are not minimal until all candidates
// have been considered. All elements are candidates at the start. Then, each candidate is
// compared with others. Greater elements are moved beyond the end of the candidate list while
// incomparable ones are left in place. At each point, elements left of the current candidate
// are known to be incomparable with each others and smaller than eliminated elements.

var candidates = Array(indices)
var end = candidates.count
var i = 0
var j = 1

while i < end {
while j < end {
switch compare(self[candidates[i]], self[candidates[j]]) {
case .ascending, .equal:
candidates.swapAt(j, end - 1)
end -= 1

case .descending:
candidates.swapAt(i, end - 1)
end -= 1
j = i + 1

case nil:
j += 1
}
}

i += 1
j = i + 1
}

return candidates[0 ..< end].map({ self[$0] })
}

}

extension Collection where Index == Int {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//- typeCheck expecting: success

trait P { fun foo() }

extension P {
public fun foo() {}
}

conformance Int: P {}

conformance Bool: P {
fun foo() {}
}

public fun main() {
(1 + 1).foo()
(1 < 1).foo()
}
Loading

0 comments on commit 92f2b69

Please sign in to comment.