From 77401260a2c4e9506042eecf566005c75123d19f Mon Sep 17 00:00:00 2001 From: Chuck Smith Date: Wed, 7 Feb 2024 15:46:45 -0500 Subject: [PATCH] conditionals --- code/code.go | 30 ++++++------ compiler/compiler.go | 97 ++++++++++++++++++++++++++++++++++++++- compiler/compiler_test.go | 49 ++++++++++++++++++++ vm/vm.go | 23 ++++++++++ vm/vm_test.go | 14 ++++++ 5 files changed, 198 insertions(+), 15 deletions(-) diff --git a/code/code.go b/code/code.go index 5e83d83..c45c15f 100644 --- a/code/code.go +++ b/code/code.go @@ -24,6 +24,8 @@ const ( OpGreaterThan OpMinus OpBang + OpJumpNotTruthy + OpJump ) type Definition struct { @@ -32,19 +34,21 @@ type Definition struct { } var definitions = map[Opcode]*Definition{ - OpConstant: {"OpConstant", []int{2}}, - OpAdd: {"OpAdd", []int{}}, - OpPop: {"OpPop", []int{}}, - OpSub: {"OpSub", []int{}}, - OpMul: {"OpMul", []int{}}, - OpDiv: {"OpDiv", []int{}}, - OpTrue: {"OpTrue", []int{}}, - OpFalse: {"OpFalse", []int{}}, - OpEqual: {"OpEqual", []int{}}, - OpNotEqual: {"OpNotEqual", []int{}}, - OpGreaterThan: {"OpGreaterThan", []int{}}, - OpMinus: {"OpMinus", []int{}}, - OpBang: {"OpBang", []int{}}, + OpConstant: {"OpConstant", []int{2}}, + OpAdd: {"OpAdd", []int{}}, + OpPop: {"OpPop", []int{}}, + OpSub: {"OpSub", []int{}}, + OpMul: {"OpMul", []int{}}, + OpDiv: {"OpDiv", []int{}}, + OpTrue: {"OpTrue", []int{}}, + OpFalse: {"OpFalse", []int{}}, + OpEqual: {"OpEqual", []int{}}, + OpNotEqual: {"OpNotEqual", []int{}}, + OpGreaterThan: {"OpGreaterThan", []int{}}, + OpMinus: {"OpMinus", []int{}}, + OpBang: {"OpBang", []int{}}, + OpJumpNotTruthy: {"OpJumpNotTruthy", []int{2}}, + OpJump: {"OpJump", []int{2}}, } func Lookup(op byte) (*Definition, error) { diff --git a/compiler/compiler.go b/compiler/compiler.go index 0100f97..0d41081 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -7,15 +7,25 @@ import ( "monkey/object" ) +type EmittedInstruction struct { + Opcode code.Opcode + Position int +} + type Compiler struct { instructions code.Instructions constants []object.Object + + lastInstruction EmittedInstruction + previousInstruction EmittedInstruction } func New() *Compiler { return &Compiler{ - instructions: code.Instructions{}, - constants: []object.Object{}, + instructions: code.Instructions{}, + constants: []object.Object{}, + lastInstruction: EmittedInstruction{}, + previousInstruction: EmittedInstruction{}, } } @@ -106,6 +116,56 @@ func (c *Compiler) Compile(node ast.Node) error { return fmt.Errorf("unknown operator %s", node.Operator) } + case *ast.IfExpression: + err := c.Compile(node.Condition) + if err != nil { + return err + } + + // Emit an `OpJumpNotTruthy` with a bogus value + jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999) + + err = c.Compile(node.Consequence) + if err != nil { + return err + } + + if c.lastInstructionIsPop() { + c.removeLastPop() + } + + if node.Alternative == nil { + afterConsequencePos := len(c.instructions) + c.changeOperand(jumpNotTruthyPos, afterConsequencePos) + } else { + // Emit an `OnJump` with a bogus value + jumpPos := c.emit(code.OpJump, 9999) + + afterConsequencePos := len(c.instructions) + c.changeOperand(jumpNotTruthyPos, afterConsequencePos) + + err = c.Compile(node.Alternative) + if err != nil { + return err + } + + if c.lastInstructionIsPop() { + c.removeLastPop() + } + + afterAlternativePos := len(c.instructions) + c.changeOperand(jumpPos, afterAlternativePos) + + } + + case *ast.BlockStatement: + for _, s := range node.Statements { + err := c.Compile(s) + if err != nil { + return err + } + } + } return nil @@ -119,9 +179,20 @@ func (c *Compiler) addConstant(obj object.Object) int { func (c *Compiler) emit(op code.Opcode, operands ...int) int { ins := code.Make(op, operands...) pos := c.addInstruction(ins) + + c.setLastInstruction(op, pos) + return pos } +func (c *Compiler) setLastInstruction(op code.Opcode, pos int) { + previous := c.lastInstruction + last := EmittedInstruction{Opcode: op, Position: pos} + + c.previousInstruction = previous + c.lastInstruction = last +} + func (c *Compiler) Bytecode() *Bytecode { return &Bytecode{ Instructions: c.instructions, @@ -135,6 +206,28 @@ func (c *Compiler) addInstruction(ins []byte) int { return postNewInstruction } +func (c *Compiler) lastInstructionIsPop() bool { + return c.lastInstruction.Opcode == code.OpPop +} + +func (c *Compiler) removeLastPop() { + c.instructions = c.instructions[:c.lastInstruction.Position] + c.lastInstruction = c.previousInstruction +} + +func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) { + for i := 0; i < len(newInstruction); i++ { + c.instructions[pos+i] = newInstruction[i] + } +} + +func (c *Compiler) changeOperand(opPos int, operand int) { + op := code.Opcode(c.instructions[opPos]) + newInstruction := code.Make(op, operand) + + c.replaceInstruction(opPos, newInstruction) +} + type Bytecode struct { Instructions code.Instructions Constants []object.Object diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 8318a99..5a92360 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -174,6 +174,55 @@ func TestBooleanExpressions(t *testing.T) { runCompilerTests(t, tests) } +func TestConditionals(t *testing.T) { + tests := []compilerTestCase{ + { + input: ` + if (true) { 10 }; 3333; + `, + expectedConstants: []interface{}{10, 3333}, + expectedInstructions: []code.Instructions{ + code.Make(code.OpTrue), + // 0001 + code.Make(code.OpJumpNotTruthy, 7), + // 0004 + code.Make(code.OpConstant, 0), + // 0007 + code.Make(code.OpPop), + // 0008 + code.Make(code.OpConstant, 1), + // 0011 + code.Make(code.OpPop), + }, + }, { + input: ` + if (true) { 10 } else { 20 }; 3333; + `, + expectedConstants: []interface{}{10, 20, 3333}, + expectedInstructions: []code.Instructions{ + // 0000 + code.Make(code.OpTrue), + // 0001 + code.Make(code.OpJumpNotTruthy, 10), + // 0004 + code.Make(code.OpConstant, 0), + // 0007 + code.Make(code.OpJump, 13), + // 0010 + code.Make(code.OpConstant, 1), + // 0013 + code.Make(code.OpPop), + // 0014 + code.Make(code.OpConstant, 2), + // 0017 + code.Make(code.OpPop), + }, + }, + } + + runCompilerTests(t, tests) +} + func runCompilerTests(t *testing.T, tests []compilerTestCase) { t.Helper() diff --git a/vm/vm.go b/vm/vm.go index ff391c4..c1ae856 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -87,12 +87,35 @@ func (vm *VM) Run() error { return err } + case code.OpJump: + pos := int(code.ReadUint16(vm.instructions[ip+1:])) + ip = pos - 1 + + case code.OpJumpNotTruthy: + pos := int(code.ReadUint16(vm.instructions[ip+1:])) + ip += 2 + + condition := vm.pop() + if !isTruthy(condition) { + ip = pos - 1 + } } } return nil } +func isTruthy(obj object.Object) bool { + switch obj := obj.(type) { + + case *object.Boolean: + return obj.Value + + default: + return true + } +} + func (vm *VM) push(o object.Object) error { if vm.sp >= StackSize { return fmt.Errorf("stack overflow") diff --git a/vm/vm_test.go b/vm/vm_test.go index ec3d6ec..4af2505 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -142,3 +142,17 @@ func TestBooleanExpressions(t *testing.T) { runVmTests(t, tests) } + +func TestConditionals(t *testing.T) { + tests := []vmTestCase{ + {"if (true) { 10 }", 10}, + {"if (true) { 10 } else { 20 }", 10}, + {"if (false) { 10 } else { 20 } ", 20}, + {"if (1) { 10 }", 10}, + {"if (1 < 2) { 10 }", 10}, + {"if (1 < 2) { 10 } else { 20 }", 10}, + {"if (1 > 2) { 10 } else { 20 }", 20}, + } + + runVmTests(t, tests) +}