Skip to content

Commit

Permalink
Merge pull request #1568 from hylo-lang/non-consuming-narrowing
Browse files Browse the repository at this point in the history
Add support for non-consuming projections of union payloads
  • Loading branch information
kyouko-taiga authored Aug 23, 2024
2 parents d61d644 + dff9eca commit 610865d
Show file tree
Hide file tree
Showing 17 changed files with 174 additions and 147 deletions.
5 changes: 2 additions & 3 deletions Sources/CodeGen/LLVM/Transpilation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1211,14 +1211,13 @@ extension SwiftyLLVM.Module {
if let (_, b) = s.targets.elements.uniqueElement {
insertBr(to: block[b]!, at: insertionPoint)
} else {
let d = discriminator(s.scrutinee)
let t = UnionType(m.type(of: s.scrutinee).ast)!
let e = m.program.discriminatorToElement(in: t)
let e = m.program.discriminatorToElement(in: s.union)
let branches = s.targets.map { (t, b) in
(word().constant(e.firstIndex(of: t)!), block[b]!)
}

// The last branch is the "default".
let d = llvm(s.discriminator)
insertSwitch(
on: d, cases: branches.dropLast(), default: branches.last!.1,
at: insertionPoint)
Expand Down
2 changes: 1 addition & 1 deletion Sources/IR/Analysis/Module+AccessReification.swift
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ extension Module {

forEachClient(of: i) { (u) in
let rs = requests(u)
lower = max(rs.weakest!, lower)
if let w = rs.weakest { lower = max(w, lower) }
upper = rs.strongest(including: upper)
}

Expand Down
6 changes: 6 additions & 0 deletions Sources/IR/Analysis/Module+CloseBorrows.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ extension Module {
this.makeCloseCapture(.register(i), at: site)
}

case is OpenUnion:
let region = extendedLiveRange(of: .register(i))
insertClose(i, atBoundariesOf: region) { (this, site) in
this.makeCloseUnion(.register(i), at: site)
}

case is Project:
let region = extendedLiveRange(of: .register(i))
insertClose(i, atBoundariesOf: region) { (this, site) in
Expand Down
195 changes: 98 additions & 97 deletions Sources/IR/Emitter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -612,30 +612,27 @@ struct Emitter {
///
/// - Requires: `d` is a local `let` or `inout` binding.
private mutating func lower(projectedLocalBinding d: BindingDecl.ID) {
let access = AccessEffect(program[d].pattern.introducer.value)
precondition(access == .let || access == .inout)
precondition(program.isLocal(d))
let source = emitLValue(ast[d].initializer!)
assignProjections(of: source, to: program[d].pattern)
}

let initializer = ast[d].initializer!
let source = emitLValue(initializer)
let isSink = module.isSink(source)
/// Assigns the bindings declared in `d` to their corresponding projection of `rhs`.
private mutating func assignProjections(of rhs: Operand, to d: BindingPattern.ID) {
precondition(!program[d].introducer.value.isConsuming)
let k = AccessEffect(program[d].introducer.value)
let request: AccessEffectSet = module.isSink(rhs) ? [k, .sink] : [k]

for (path, name) in ast.names(in: program[d].pattern.subpattern) {
var part = emitSubfieldView(source, at: path, at: program[name].decl.site)
for (path, name) in ast.names(in: program[d].subpattern) {
var part = emitSubfieldView(rhs, at: path, at: program[name].decl.site)
let partDecl = ast[name].decl

let t = canonical(program[partDecl].type)
part = emitCoerce(part, to: t, at: ast[partDecl].site)
let bindingType = canonical(program[partDecl].type)
part = emitCoerce(part, to: bindingType, at: ast[partDecl].site)

if isSink {
let b = module.makeAccess(
[.sink, access], from: part, correspondingTo: partDecl, at: ast[partDecl].site)
frames[partDecl] = insert(b)!
} else {
let b = module.makeAccess(
access, from: part, correspondingTo: partDecl, at: ast[partDecl].site)
frames[partDecl] = insert(b)!
}
let b = module.makeAccess(
request, from: part, correspondingTo: partDecl, at: ast[partDecl].site)
frames[partDecl] = insert(b)!
}
}

Expand Down Expand Up @@ -740,7 +737,7 @@ struct Emitter {
let targets = UnionSwitch.Targets(
t.elements.map({ (e) in (key: e, value: appendBlock()) }),
uniquingKeysWith: { (a, _) in a })
insert(module.makeUnionSwitch(on: receiver, toOneOf: targets, at: site))
emitUnionSwitch(on: receiver, toOneOf: targets, at: site)

let tail = appendBlock()
for (u, b) in targets {
Expand Down Expand Up @@ -896,7 +893,7 @@ struct Emitter {
let targets = UnionSwitch.Targets(
t.elements.map({ (e) in (key: e, value: appendBlock()) }),
uniquingKeysWith: { (a, _) in a })
insert(module.makeUnionSwitch(on: source, toOneOf: targets, at: site))
emitUnionSwitch(on: source, toOneOf: targets, at: site)

let tail = appendBlock()
for (u, b) in targets {
Expand Down Expand Up @@ -1661,8 +1658,8 @@ struct Emitter {
let calleeType = ArrowType(t)!.lifted

// Emit the operands, starting with RHS.
let r = emit(infixOperand: rhs, passed: ParameterType(calleeType.inputs[1].type)!.access)
let l = emit(infixOperand: lhs, passed: ParameterType(calleeType.inputs[0].type)!.access)
let r = emit(infixOperand: rhs, passedTo: ParameterType(calleeType.inputs[1].type)!)
let l = emit(infixOperand: lhs, passedTo: ParameterType(calleeType.inputs[0].type)!)

// The callee must be a reference to member function.
guard case .member(let d, let a, _) = program[callee.expr].referredDecl else {
Expand Down Expand Up @@ -2073,24 +2070,21 @@ struct Emitter {
return (callee, captures + arguments)
}

/// Inserts the IR for infix operand `e` passed with convention `access`.
/// Inserts the IR for infix operand `e` passed to a parameter of type `p`.
private mutating func emit(
infixOperand e: FoldedSequenceExpr, passed access: AccessEffect
infixOperand e: FoldedSequenceExpr, passedTo p: ParameterType
) -> Operand {
let storage: Operand

switch e {
case .infix(let callee, _, _):
let t = ArrowType(canonical(program[callee.expr].type))!.lifted
storage = emitAllocStack(for: t.output, at: ast.site(of: e))
emitStore(e, to: storage)
case .infix(let f, _, _):
let t = ArrowType(canonical(program[f.expr].type))!.lifted
let s = emitAllocStack(for: t.output, at: ast.site(of: e))
emitStore(e, to: s)
let u = emitCoerce(s, to: p.bareType, at: ast.site(of: e))
return insert(module.makeAccess(p.access, from: u, at: ast.site(of: e)))!

case .leaf(let e):
let x0 = emitLValue(e)
storage = unwrapCapture(x0, at: program[e].site)
return emitArgument(e, to: p, at: program[e].site)
}

return insert(module.makeAccess(access, from: storage, at: ast.site(of: e)))!
}

/// Emits the IR of a call to `f` with given `arguments` at `site`.
Expand Down Expand Up @@ -2308,106 +2302,103 @@ struct Emitter {
return (entityToCall, [c])
}


/// Returns `(success: a, failure: b)` where `a` is the basic block reached if all items in
/// `condition` hold and `b` is the basic block reached otherwise, creating new basic blocks
/// in `scope`.
private mutating func emitTest(
condition: [ConditionItem], in scope: AnyScopeID
) -> (success: Block.ID, failure: Block.ID) {
let f = insertionFunction!

// Allocate storage for all the declarations in the condition before branching so that all
// `dealloc_stack` are to dominated by their corresponding `alloc_stack`.
var allocs: [Operand] = []
// `dealloc_stack` are dominated by their corresponding `alloc_stack`.
var allocations: [Operand?] = []
for case .decl(let d) in condition {
let a = insert(module.makeAllocStack(program[d].type, at: ast[d].site))!
allocs.append(a)
if program[d].pattern.introducer.value.isConsuming {
allocations.append(insert(module.makeAllocStack(program[d].type, at: ast[d].site)))
} else {
allocations.append(nil)
}
}

let failure = module.appendBlock(in: scope, to: f)
let failure = module.appendBlock(in: scope, to: insertionFunction!)
for (i, item) in condition.enumerated() {
switch item {
case .expr(let e):
let test = pushing(Frame(), { $0.emit(branchCondition: e) })
let next = module.appendBlock(in: scope, to: f)
let next = appendBlock(in: scope)
insert(module.makeCondBranch(if: test, then: next, else: failure, at: ast[e].site))
insertionPoint = .end(of: next)

case .decl(let d):
let subject = emitLValue(ast[d].initializer!)
let patternType = canonical(program[d].type)
let next = emitConditionalNarrowing(
subject, as: ast[d].pattern, typed: patternType, to: allocs[i],
else: failure, in: scope)
d, movingConsumedValuesTo: allocations[i],
branchingOnFailureTo: failure, in: scope)
insertionPoint = .end(of: next)
}
}

return (success: insertionBlock!, failure: failure)
}

/// Returns a basic block in which the names in `pattern` have been declared and initialized.
/// Returns a basic block in which the names in `d` have been declared and initialized.
///
/// This method emits IR to:
///
/// - check whether the value in `subject` is an instance of `patternType`;
/// - evaluate the `d`'s initializer as value *v*,
/// - check whether the value in *v* is an instance of `d`'s type;
/// - if it isn't, jump to `failure`;
/// - if it is, jump to a new basic block *b*, coerce the contents of `subject` into `storage`,
/// applying consuming coercions as necessary, and define the bindings declared in `pattern`.
/// - if it is, jump to a new basic block and define and initialize the bindings declared in `d`.
///
/// If `subject` always matches `patternType`, the narrowing is irrefutable and `failure` is
/// unreachable in the generated IR.
/// If `d` has a consuming introducer (e.g., `var`), the value of `d`'s initializer is moved to
/// `storage`, which denotes a memory location with `d`'s type. Otherwise, `storage` is `nil` and
/// the bindings in `d` are defined as new projections. In either case, the emitter's context is
/// is updated to associate each binding to its value.
///
/// The return value is the new basic block *b*, which is defined in `scope`. The emitter context
/// is updated to associate the bindings declared in `pattern` to their address in `storage`.
/// The return value of the method is a basic block, defined in `scope`. If *v* has the same type
/// as `d`, the narrowing is irrefutable and `failure` is unreachable in the generated IR.
private mutating func emitConditionalNarrowing(
_ subject: Operand,
as pattern: BindingPattern.ID, typed patternType: AnyType,
to storage: Operand,
else failure: Block.ID, in scope: AnyScopeID
_ d: BindingDecl.ID,
movingConsumedValuesTo storage: Operand?,
branchingOnFailureTo failure: Block.ID,
in scope: AnyScopeID
) -> Block.ID {
switch module.type(of: subject).ast.base {
case let t as UnionType:
return emitConditionalNarrowing(
subject, typed: t, as: pattern, typed: patternType, to: storage,
else: failure, in: scope)
default:
break
}
let lhsType = canonical(program[d].type)
let rhs = emitLValue(ast[d].initializer!)
let lhs = ast[d].pattern

UNIMPLEMENTED()
}
assert(program[lhs].introducer.value.isConsuming || (storage == nil))

/// Returns a basic block in which the names in `pattern` have been declared and initialized.
///
/// This method method implements conditional narrowing for union types.
private mutating func emitConditionalNarrowing(
_ subject: Operand, typed union: UnionType,
as pattern: BindingPattern.ID, typed patternType: AnyType,
to storage: Operand,
else failure: Block.ID, in scope: AnyScopeID
) -> Block.ID {
// TODO: Implement narrowing to an arbitrary subtype.
guard union.elements.contains(patternType) else { UNIMPLEMENTED() }
let site = ast[pattern].site
if let rhsType = UnionType(module.type(of: rhs).ast) {
guard rhsType.elements.contains(lhsType) else { UNIMPLEMENTED("recursive narrowing") }

let next = appendBlock(in: scope)
var targets = UnionSwitch.Targets(
union.elements.map({ (e) in (key: e, value: failure) }),
uniquingKeysWith: { (a, _) in a })
targets[patternType] = next
let next = appendBlock(in: scope)
let site = program[lhs].site

insert(module.makeUnionSwitch(on: subject, toOneOf: targets, at: site))
insertionPoint = .end(of: next)
let x0 = insert(module.makeOpenUnion(subject, as: patternType, at: site))!
pushing(Frame()) { (this) in
this.emitMove([.set], x0, to: storage, at: site)
}
insert(module.makeCloseUnion(x0, at: site))
var targets = UnionSwitch.Targets(
rhsType.elements.map({ (e) in (key: e, value: failure) }),
uniquingKeysWith: { (a, _) in a })
targets[lhsType] = next
emitUnionSwitch(on: rhs, toOneOf: targets, at: site)

emitLocalDeclarations(introducedBy: pattern, referringTo: [], relativeTo: storage)
insertionPoint = .end(of: next)

return next
if let target = storage {
let x0 = insert(module.makeAccess(.sink, from: rhs, at: site))!
let x1 = insert(module.makeOpenUnion(x0, as: lhsType, at: site))!
emitMove([.set], x1, to: target, at: site)
emitLocalDeclarations(introducedBy: lhs, referringTo: [], relativeTo: target)
insert(module.makeCloseUnion(x1, at: site))
insert(module.makeEndAccess(x0, at: site))
} else {
let k = AccessEffect(program[lhs].introducer.value)
let x0 = insert(module.makeAccess(k, from: rhs, at: site))!
let x1 = insert(module.makeOpenUnion(x0, as: lhsType, at: site))!
assignProjections(of: x1, to: program[d].pattern)
}

return next
} else {
UNIMPLEMENTED()
}
}

/// Inserts the IR for branch condition `e`.
Expand Down Expand Up @@ -3077,7 +3068,7 @@ struct Emitter {
let targets = UnionSwitch.Targets(
t.elements.map({ (e) in (key: e, value: appendBlock()) }),
uniquingKeysWith: { (a, _) in a })
insert(module.makeUnionSwitch(on: storage, toOneOf: targets, at: site))
emitUnionSwitch(on: storage, toOneOf: targets, at: site)

let tail = appendBlock()
for (u, b) in targets {
Expand Down Expand Up @@ -3187,7 +3178,7 @@ struct Emitter {
insert(module.makeCondBranch(if: x0, then: same, else: fail, at: site))

insertionPoint = .end(of: same)
insert(module.makeUnionSwitch(on: lhs, toOneOf: targets, at: site))
emitUnionSwitch(on: lhs, toOneOf: targets, at: site)
for (u, b) in targets {
insertionPoint = .end(of: b)
let y0 = insert(module.makeOpenUnion(lhs, as: u, at: site))!
Expand Down Expand Up @@ -3299,6 +3290,16 @@ struct Emitter {
return x1
}

/// Appends the IR for jumping to the block assigned to the type of `scrutinee`'s payload in
/// `targets`.
private mutating func emitUnionSwitch(
on scrutinee: Operand, toOneOf targets: UnionSwitch.Targets, at site: SourceRange
) {
let u = UnionType(module.type(of: scrutinee).ast)!
let i = emitUnionDiscriminator(scrutinee, at: site)
insert(module.makeUnionSwitch(over: i, of: u, toOneOf: targets, at: site))
}

/// Returns the result of calling `action` on a copy of `self` in which a `newFrame` is the top
/// frame.
///
Expand Down
7 changes: 4 additions & 3 deletions Sources/IR/InstructionTransformer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,13 @@ extension IR.Program {
}

case let s as UnionSwitch:
let x0 = t.transform(s.scrutinee, in: &self)
let x1 = s.targets.reduce(into: UnionSwitch.Targets()) { (d, kv) in
let x0 = t.transform(s.discriminator, in: &self)
let x1 = UnionType(t.transform(^s.union, in: &self))!
let x2 = s.targets.reduce(into: UnionSwitch.Targets()) { (d, kv) in
_ = d[t.transform(kv.key, in: &self)].setIfNil(t.transform(kv.value, in: &self))
}
return insert(at: p, in:n) { (target) in
target.makeUnionSwitch(on: x0, toOneOf: x1, at: s.site)
target.makeUnionSwitch(over: x0, of: x1, toOneOf: x2, at: s.site)
}

case let s as Unreachable:
Expand Down
Loading

0 comments on commit 610865d

Please sign in to comment.