From f33f77b05b728eb080242d6f9a47e1cc3d5b7b91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Mon, 3 Feb 2025 17:34:08 -0800 Subject: [PATCH 1/2] compile if-let statements, add JumpIfNil instruction --- bbq/compiler/codegen.go | 4 +++ bbq/compiler/compiler.go | 30 +++++++++++++++-- bbq/compiler/compiler_test.go | 62 +++++++++++++++++++++++++++++++++++ bbq/compiler/function.go | 9 +++-- bbq/opcode/instructions.go | 32 ++++++++++++++++++ bbq/opcode/instructions.yml | 16 +++++++++ bbq/opcode/opcode.go | 2 +- bbq/vm/test/vm_test.go | 48 +++++++++++++++++++++++++++ bbq/vm/vm.go | 9 +++++ 9 files changed, 207 insertions(+), 5 deletions(-) diff --git a/bbq/compiler/codegen.go b/bbq/compiler/codegen.go index 24bf7c357..0e93a7fd1 100644 --- a/bbq/compiler/codegen.go +++ b/bbq/compiler/codegen.go @@ -82,6 +82,10 @@ func (g *InstructionCodeGen) PatchJump(offset int, newTarget uint16) { ins.Target = newTarget (*g.target)[offset] = ins + case opcode.InstructionJumpIfNil: + ins.Target = newTarget + (*g.target)[offset] = ins + default: panic(errors.NewUnreachableError()) } diff --git a/bbq/compiler/compiler.go b/bbq/compiler/compiler.go index 651799b80..0f73bc1b8 100644 --- a/bbq/compiler/compiler.go +++ b/bbq/compiler/compiler.go @@ -250,6 +250,12 @@ func (c *Compiler[_]) emitUndefinedJumpIfFalse() int { return offset } +func (c *Compiler[_]) emitUndefinedJumpIfNil() int { + offset := c.codeGen.Offset() + c.codeGen.Emit(opcode.InstructionJumpIfNil{Target: math.MaxUint16}) + return offset +} + func (c *Compiler[_]) patchJump(opcodeOffset int) { count := c.codeGen.Offset() if count == 0 { @@ -570,14 +576,33 @@ func (c *Compiler[_]) VisitContinueStatement(_ *ast.ContinueStatement) (_ struct func (c *Compiler[_]) VisitIfStatement(statement *ast.IfStatement) (_ struct{}) { // TODO: scope + var elseJump int switch test := statement.Test.(type) { case ast.Expression: c.compileExpression(test) + elseJump = c.emitUndefinedJumpIfFalse() + + case *ast.VariableDeclaration: + // TODO: second value + c.compileExpression(test.Value) + + tempIndex := c.currentFunction.generateLocalIndex() + c.codeGen.Emit(opcode.InstructionSetLocal{LocalIndex: tempIndex}) + + c.codeGen.Emit(opcode.InstructionGetLocal{LocalIndex: tempIndex}) + elseJump = c.emitUndefinedJumpIfNil() + + c.codeGen.Emit(opcode.InstructionGetLocal{LocalIndex: tempIndex}) + c.codeGen.Emit(opcode.InstructionUnwrap{}) + varDeclTypes := c.ExtendedElaboration.VariableDeclarationTypes(test) + c.emitTransfer(varDeclTypes.TargetType) + local := c.currentFunction.declareLocal(test.Identifier.Identifier) + c.codeGen.Emit(opcode.InstructionSetLocal{LocalIndex: local.index}) + default: - // TODO: panic(errors.NewUnreachableError()) } - elseJump := c.emitUndefinedJumpIfFalse() + c.compileBlock(statement.Then) elseBlock := statement.Else if elseBlock != nil { @@ -588,6 +613,7 @@ func (c *Compiler[_]) VisitIfStatement(statement *ast.IfStatement) (_ struct{}) } else { c.patchJump(elseJump) } + return } diff --git a/bbq/compiler/compiler_test.go b/bbq/compiler/compiler_test.go index 9881804f0..67d08b073 100644 --- a/bbq/compiler/compiler_test.go +++ b/bbq/compiler/compiler_test.go @@ -21,6 +21,7 @@ package compiler import ( "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/onflow/cadence/bbq" @@ -457,3 +458,64 @@ func TestCompileDictionary(t *testing.T) { program.Constants, ) } + +func TestCompileIfLet(t *testing.T) { + + t.Parallel() + + checker, err := ParseAndCheck(t, ` + fun test(x: Int?): Int { + if let y = x { + return y + } else { + return 2 + } + } + `) + require.NoError(t, err) + + compiler := NewInstructionCompiler(checker) + program := compiler.Compile() + + require.Len(t, program.Functions, 1) + + assert.Equal(t, + []opcode.Instruction{ + // let y = x + opcode.InstructionGetLocal{LocalIndex: 0x0}, + opcode.InstructionSetLocal{LocalIndex: 0x1}, + + // if + opcode.InstructionGetLocal{LocalIndex: 0x1}, + opcode.InstructionJumpIfNil{Target: 11}, + + // let y = x + opcode.InstructionGetLocal{LocalIndex: 0x1}, + opcode.InstructionUnwrap{}, + opcode.InstructionTransfer{TypeIndex: 0x0}, + opcode.InstructionSetLocal{LocalIndex: 0x2}, + + // then { return y } + opcode.InstructionGetLocal{LocalIndex: 0x2}, + opcode.InstructionReturnValue{}, + opcode.InstructionJump{Target: 13}, + + // else { return 2 } + opcode.InstructionGetConstant{ConstantIndex: 0x0}, + opcode.InstructionReturnValue{}, + + opcode.InstructionReturn{}, + }, + compiler.ExportFunctions()[0].Code, + ) + + assert.Equal(t, + []*bbq.Constant{ + { + Data: []byte{0x2}, + Kind: constantkind.Int, + }, + }, + program.Constants, + ) +} diff --git a/bbq/compiler/function.go b/bbq/compiler/function.go index e430a9c42..39be3fe8a 100644 --- a/bbq/compiler/function.go +++ b/bbq/compiler/function.go @@ -43,12 +43,17 @@ func newFunction[E any](name string, parameterCount uint16, isCompositeFunction } } +func (f *function[E]) generateLocalIndex() uint16 { + index := f.localCount + f.localCount++ + return index +} + func (f *function[E]) declareLocal(name string) *local { if f.localCount == math.MaxUint16 { panic(errors.NewDefaultUserError("invalid local declaration")) } - index := f.localCount - f.localCount++ + index := f.generateLocalIndex() local := &local{index: index} f.locals.Set(name, local) return local diff --git a/bbq/opcode/instructions.go b/bbq/opcode/instructions.go index 6791cbf41..d849d31b6 100644 --- a/bbq/opcode/instructions.go +++ b/bbq/opcode/instructions.go @@ -785,6 +785,36 @@ func DecodeJumpIfFalse(ip *uint16, code []byte) (i InstructionJumpIfFalse) { return i } +// InstructionJumpIfNil +// +// Jumps to the given instruction, if the top value on the stack is `nil`. +type InstructionJumpIfNil struct { + Target uint16 +} + +var _ Instruction = InstructionJumpIfNil{} + +func (InstructionJumpIfNil) Opcode() Opcode { + return JumpIfNil +} + +func (i InstructionJumpIfNil) String() string { + var sb strings.Builder + sb.WriteString(i.Opcode().String()) + printfArgument(&sb, "target", i.Target) + return sb.String() +} + +func (i InstructionJumpIfNil) Encode(code *[]byte) { + emitOpcode(code, i.Opcode()) + emitUint16(code, i.Target) +} + +func DecodeJumpIfNil(ip *uint16, code []byte) (i InstructionJumpIfNil) { + i.Target = decodeUint16(ip, code) + return i +} + // InstructionReturn // // Returns from the current function, without a value. @@ -1123,6 +1153,8 @@ func DecodeInstruction(ip *uint16, code []byte) Instruction { return DecodeJump(ip, code) case JumpIfFalse: return DecodeJumpIfFalse(ip, code) + case JumpIfNil: + return DecodeJumpIfNil(ip, code) case Return: return InstructionReturn{} case ReturnValue: diff --git a/bbq/opcode/instructions.yml b/bbq/opcode/instructions.yml index e7c58ffb5..57924d2a6 100644 --- a/bbq/opcode/instructions.yml +++ b/bbq/opcode/instructions.yml @@ -334,6 +334,22 @@ type: "index" controlEffects: - jump: "target" + valueEffects: + pop: + - name: "value" + type: "value" + +- name: "jumpIfNil" + description: Jumps to the given instruction, if the top value on the stack is `nil`. + operands: + - name: "target" + type: "index" + controlEffects: + - jump: "target" + valueEffects: + pop: + - name: "value" + type: "value" - name: "return" description: Returns from the current function, without a value. diff --git a/bbq/opcode/opcode.go b/bbq/opcode/opcode.go index e701b5443..7f9109f93 100644 --- a/bbq/opcode/opcode.go +++ b/bbq/opcode/opcode.go @@ -31,7 +31,7 @@ const ( ReturnValue Jump JumpIfFalse - _ + JumpIfNil _ _ _ diff --git a/bbq/vm/test/vm_test.go b/bbq/vm/test/vm_test.go index 529215dfe..270debe48 100644 --- a/bbq/vm/test/vm_test.go +++ b/bbq/vm/test/vm_test.go @@ -3184,3 +3184,51 @@ func TestFunctionPostConditions(t *testing.T) { assert.Equal(t, []string{"A", "D", "F", "E", "C", "B"}, logs) }) } + +func TestIfLet(t *testing.T) { + + t.Parallel() + + t.Run("some", func(t *testing.T) { + + t.Parallel() + + result, err := compileAndInvoke(t, ` + fun main(x: Int?): Int { + if let y = x { + return y + } else { + return 2 + } + } + `, + "main", + vm.NewSomeValueNonCopying( + vm.NewIntValue(1), + ), + ) + require.NoError(t, err) + assert.Equal(t, vm.NewIntValue(1), result) + }) + + t.Run("nil", func(t *testing.T) { + + t.Parallel() + + result, err := compileAndInvoke(t, ` + fun main(x: Int?): Int { + if let y = x { + return y + } else { + return 2 + } + } + `, + "main", + vm.NilValue{}, + ) + + require.NoError(t, err) + assert.Equal(t, vm.NewIntValue(2), result) + }) +} diff --git a/bbq/vm/vm.go b/bbq/vm/vm.go index a03c53fc8..02abcc4e5 100644 --- a/bbq/vm/vm.go +++ b/bbq/vm/vm.go @@ -344,6 +344,13 @@ func opJumpIfFalse(vm *VM, ins opcode.InstructionJumpIfFalse) { } } +func opJumpIfNil(vm *VM, ins opcode.InstructionJumpIfNil) { + _, ok := vm.pop().(NilValue) + if ok { + vm.ip = ins.Target + } +} + func opBinaryIntAdd(vm *VM) { left, right := vm.peekPop() leftNumber := left.(IntValue) @@ -698,6 +705,8 @@ func (vm *VM) run() { opJump(vm, ins) case opcode.InstructionJumpIfFalse: opJumpIfFalse(vm, ins) + case opcode.InstructionJumpIfNil: + opJumpIfNil(vm, ins) case opcode.InstructionIntAdd: opBinaryIntAdd(vm) case opcode.InstructionIntSubtract: From 6c7faa63535998407ff09f108e88e6d19e840372 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 4 Feb 2025 09:23:07 -0800 Subject: [PATCH 2/2] go generate --- bbq/opcode/opcode_string.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bbq/opcode/opcode_string.go b/bbq/opcode/opcode_string.go index 62084e729..37e6bd7e1 100644 --- a/bbq/opcode/opcode_string.go +++ b/bbq/opcode/opcode_string.go @@ -13,6 +13,7 @@ func _() { _ = x[ReturnValue-2] _ = x[Jump-3] _ = x[JumpIfFalse-4] + _ = x[JumpIfNil-5] _ = x[IntAdd-11] _ = x[IntSubtract-12] _ = x[IntMultiply-13] @@ -53,7 +54,7 @@ func _() { } const ( - _Opcode_name_0 = "UnknownReturnReturnValueJumpJumpIfFalse" + _Opcode_name_0 = "UnknownReturnReturnValueJumpJumpIfFalseJumpIfNil" _Opcode_name_1 = "IntAddIntSubtractIntMultiplyIntDivideIntModIntLessIntGreaterIntLessOrEqualIntGreaterOrEqual" _Opcode_name_2 = "EqualNotEqualNot" _Opcode_name_3 = "UnwrapDestroyTransferCast" @@ -64,7 +65,7 @@ const ( ) var ( - _Opcode_index_0 = [...]uint8{0, 7, 13, 24, 28, 39} + _Opcode_index_0 = [...]uint8{0, 7, 13, 24, 28, 39, 48} _Opcode_index_1 = [...]uint8{0, 6, 17, 28, 37, 43, 50, 60, 74, 91} _Opcode_index_2 = [...]uint8{0, 5, 13, 16} _Opcode_index_3 = [...]uint8{0, 6, 13, 21, 25} @@ -76,7 +77,7 @@ var ( func (i Opcode) String() string { switch { - case i <= 4: + case i <= 5: return _Opcode_name_0[_Opcode_index_0[i]:_Opcode_index_0[i+1]] case 11 <= i && i <= 19: i -= 11