From e56fb40f83e70c7ded289e0e0301a225d7fcaa5c Mon Sep 17 00:00:00 2001 From: Chuck Smith Date: Wed, 28 Feb 2024 16:57:01 -0500 Subject: [PATCH] compile functions --- code/code.go | 6 ++ compiler/compiler.go | 136 +++++++++++++++++++++----- compiler/compiler_test.go | 176 ++++++++++++++++++++++++++++++++++ compiler/symbol_table_test.go | 2 +- object/object.go | 34 +++++-- 5 files changed, 320 insertions(+), 34 deletions(-) diff --git a/code/code.go b/code/code.go index 0c1d16d..b417cab 100644 --- a/code/code.go +++ b/code/code.go @@ -32,6 +32,9 @@ const ( OpArray OpHash OpIndex + OpCall + OpReturnValue + OpReturn ) type Definition struct { @@ -61,6 +64,9 @@ var definitions = map[Opcode]*Definition{ OpArray: {"OpArray", []int{2}}, OpHash: {"OpHash", []int{2}}, OpIndex: {"OpIndex", []int{}}, + OpCall: {"OpCall", []int{}}, + OpReturnValue: {"OpReturnValue", []int{}}, + OpReturn: {"OpReturn", []int{}}, } func Lookup(op byte) (*Definition, error) { diff --git a/compiler/compiler.go b/compiler/compiler.go index 5bad298..5565154 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -13,24 +13,37 @@ type EmittedInstruction struct { Position int } -type Compiler struct { - instructions code.Instructions - constants []object.Object - +type CompilationScope struct { + instructions code.Instructions lastInstruction EmittedInstruction previousInstruction EmittedInstruction +} + +type Compiler struct { + constants []object.Object symbolTable *SymbolTable + + scopes []CompilationScope + scopeIndex int } func New() *Compiler { - return &Compiler{ + mainScope := CompilationScope{ instructions: code.Instructions{}, - constants: []object.Object{}, lastInstruction: EmittedInstruction{}, previousInstruction: EmittedInstruction{}, - symbolTable: NewSymbolTable(), } + return &Compiler{ + constants: []object.Object{}, + symbolTable: NewSymbolTable(), + scopes: []CompilationScope{mainScope}, + scopeIndex: 0, + } +} + +func (c *Compiler) currentInstructions() code.Instructions { + return c.scopes[c.scopeIndex].instructions } func NewWithState(s *SymbolTable, constants []object.Object) *Compiler { @@ -141,14 +154,14 @@ func (c *Compiler) Compile(node ast.Node) error { return err } - if c.lastInstructionIsPop() { + if c.lastInstructionIs(code.OpPop) { c.removeLastPop() } // Emit an `OnJump` with a bogus value jumpPos := c.emit(code.OpJump, 9999) - afterConsequencePos := len(c.instructions) + afterConsequencePos := len(c.currentInstructions()) c.changeOperand(jumpNotTruthyPos, afterConsequencePos) if node.Alternative == nil { @@ -159,12 +172,12 @@ func (c *Compiler) Compile(node ast.Node) error { return err } - if c.lastInstructionIsPop() { + if c.lastInstructionIs(code.OpPop) { c.removeLastPop() } } - afterAlternativePos := len(c.instructions) + afterAlternativePos := len(c.currentInstructions()) c.changeOperand(jumpPos, afterAlternativePos) case *ast.BlockStatement: @@ -240,6 +253,42 @@ func (c *Compiler) Compile(node ast.Node) error { c.emit(code.OpIndex) + case *ast.FunctionLiteral: + c.enterScope() + + err := c.Compile(node.Body) + if err != nil { + return err + } + + if c.lastInstructionIs(code.OpPop) { + c.replaceLastPopWithReturn() + } + if !c.lastInstructionIs(code.OpReturnValue) { + c.emit(code.OpReturn) + } + + instructions := c.leaveScope() + + compiledFn := &object.CompiledFunction{Instructions: instructions} + c.emit(code.OpConstant, c.addConstant(compiledFn)) + + case *ast.ReturnStatement: + err := c.Compile(node.ReturnValue) + if err != nil { + return err + } + + c.emit(code.OpReturnValue) + + case *ast.CallExpression: + err := c.Compile(node.Function) + if err != nil { + return err + } + + c.emit(code.OpCall) + } return nil @@ -260,48 +309,89 @@ func (c *Compiler) emit(op code.Opcode, operands ...int) int { } func (c *Compiler) setLastInstruction(op code.Opcode, pos int) { - previous := c.lastInstruction + previous := c.scopes[c.scopeIndex].lastInstruction last := EmittedInstruction{Opcode: op, Position: pos} - c.previousInstruction = previous - c.lastInstruction = last + c.scopes[c.scopeIndex].previousInstruction = previous + c.scopes[c.scopeIndex].lastInstruction = last } func (c *Compiler) Bytecode() *Bytecode { return &Bytecode{ - Instructions: c.instructions, + Instructions: c.currentInstructions(), Constants: c.constants, } } func (c *Compiler) addInstruction(ins []byte) int { - postNewInstruction := len(c.instructions) - c.instructions = append(c.instructions, ins...) + postNewInstruction := len(c.currentInstructions()) + updatedInstructions := append(c.currentInstructions(), ins...) + + c.scopes[c.scopeIndex].instructions = updatedInstructions + return postNewInstruction } -func (c *Compiler) lastInstructionIsPop() bool { - return c.lastInstruction.Opcode == code.OpPop +func (c *Compiler) lastInstructionIs(op code.Opcode) bool { + if len(c.currentInstructions()) == 0 { + return false + } + + return c.scopes[c.scopeIndex].lastInstruction.Opcode == op } func (c *Compiler) removeLastPop() { - c.instructions = c.instructions[:c.lastInstruction.Position] - c.lastInstruction = c.previousInstruction + last := c.scopes[c.scopeIndex].lastInstruction + previous := c.scopes[c.scopeIndex].previousInstruction + + old := c.currentInstructions() + new := old[:last.Position] + + c.scopes[c.scopeIndex].instructions = new + c.scopes[c.scopeIndex].lastInstruction = previous } func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) { + ins := c.currentInstructions() + for i := 0; i < len(newInstruction); i++ { - c.instructions[pos+i] = newInstruction[i] + ins[pos+i] = newInstruction[i] } } func (c *Compiler) changeOperand(opPos int, operand int) { - op := code.Opcode(c.instructions[opPos]) + op := code.Opcode(c.currentInstructions()[opPos]) newInstruction := code.Make(op, operand) c.replaceInstruction(opPos, newInstruction) } +func (c *Compiler) enterScope() { + scope := CompilationScope{ + instructions: code.Instructions{}, + lastInstruction: EmittedInstruction{}, + previousInstruction: EmittedInstruction{}, + } + c.scopes = append(c.scopes, scope) + c.scopeIndex++ +} + +func (c *Compiler) leaveScope() code.Instructions { + instructions := c.currentInstructions() + + c.scopes = c.scopes[:len(c.scopes)-1] + c.scopeIndex-- + + return instructions +} + +func (c *Compiler) replaceLastPopWithReturn() { + lastPos := c.scopes[c.scopeIndex].lastInstruction.Position + c.replaceInstruction(lastPos, code.Make(code.OpReturnValue)) + + c.scopes[c.scopeIndex].lastInstruction.Opcode = code.OpReturnValue +} + type Bytecode struct { Instructions code.Instructions Constants []object.Object diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 49929c3..8987ce1 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -421,6 +421,171 @@ func TestIndexExpressions(t *testing.T) { code.Make(code.OpPop), }, }, + { + input: `fn() { 1; 2 }`, + expectedConstants: []interface{}{ + 1, + 2, + []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpPop), + code.Make(code.OpConstant, 1), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 2), + code.Make(code.OpPop), + }, + }, + } + + runCompilerTests(t, tests) +} + +func TestFunctions(t *testing.T) { + tests := []compilerTestCase{ + { + input: "fn() { return 5 + 10 }", + expectedConstants: []interface{}{ + 5, + 10, + []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpAdd), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 2), + code.Make(code.OpPop), + }, + }, + { + input: `fn() { 5 + 10 }`, + expectedConstants: []interface{}{ + 5, + 10, + []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpAdd), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 2), + code.Make(code.OpPop), + }, + }, + } + + runCompilerTests(t, tests) +} + +func TestFunctionsWithoutReturnValue(t *testing.T) { + tests := []compilerTestCase{ + { + input: `fn() { }`, + expectedConstants: []interface{}{ + []code.Instructions{ + code.Make(code.OpReturn), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpPop), + }, + }, + } + + runCompilerTests(t, tests) +} + +func TestCompilerScopes(t *testing.T) { + compiler := New() + if compiler.scopeIndex != 0 { + t.Errorf("scopeIndex wrong. got=%d, want=%d", compiler.scopeIndex, 0) + } + + compiler.emit(code.OpMul) + + compiler.enterScope() + if compiler.scopeIndex != 1 { + t.Errorf("scopeIndex wrong. got=%d, want=%d", compiler.scopeIndex, 1) + } + + compiler.emit(code.OpSub) + + if len(compiler.scopes[compiler.scopeIndex].instructions) != 1 { + t.Errorf("instructions length is wrong. got=%d", len(compiler.scopes[compiler.scopeIndex].instructions)) + } + + last := compiler.scopes[compiler.scopeIndex].lastInstruction + if last.Opcode != code.OpSub { + t.Errorf("lastInstruction.OpCode wrong. got=%d, want=%d", last.Opcode, code.OpSub) + } + + compiler.leaveScope() + if compiler.scopeIndex != 0 { + t.Errorf("scopeIndex wrong. got=%d, want=%d", compiler.scopeIndex, 0) + } + + compiler.emit(code.OpAdd) + + if len(compiler.scopes[compiler.scopeIndex].instructions) != 2 { + t.Errorf("instructions length is wrong. got=%d", len(compiler.scopes[compiler.scopeIndex].instructions)) + } + + last = compiler.scopes[compiler.scopeIndex].lastInstruction + if last.Opcode != code.OpAdd { + t.Errorf("lastInstruction.OpCode wrong. got=%d, want=%d", last.Opcode, code.OpSub) + } + + previous := compiler.scopes[compiler.scopeIndex].previousInstruction + if previous.Opcode != code.OpMul { + t.Errorf("previousInstruction.OpCode wrong. got=%d, want=%d", last.Opcode, code.OpMul) + } +} + +func TestFunctionCalls(t *testing.T) { + tests := []compilerTestCase{ + { + input: `fn() { 24 }();`, + expectedConstants: []interface{}{ + 24, + []code.Instructions{ + code.Make(code.OpConstant, 0), // The literal "24" + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 1), // The compiled function + code.Make(code.OpCall), + code.Make(code.OpPop), + }, + }, + { + input: ` + let noArg = fn() { 24 }; + noArg(); + `, + expectedConstants: []interface{}{ + 24, + []code.Instructions{ + code.Make(code.OpConstant, 0), // The literal "24" + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 1), // The compiled function + code.Make(code.OpSetGlobal, 0), + code.Make(code.OpGetGlobal, 0), + code.Make(code.OpCall), + code.Make(code.OpPop), + }, + }, } runCompilerTests(t, tests) @@ -502,6 +667,17 @@ func testConstants(t *testing.T, expected []interface{}, actual []object.Object) if err != nil { return fmt.Errorf("constant %d = testStringObject failed : %s", i, err) } + + case []code.Instructions: + fn, ok := actual[i].(*object.CompiledFunction) + if !ok { + return fmt.Errorf("constant %d - not a function: %T", i, actual[i]) + } + + err := testInstructions(constant, fn.Instructions) + if err != nil { + return fmt.Errorf("constant %d = testInstructions failed: %s", i, err) + } } } return nil diff --git a/compiler/symbol_table_test.go b/compiler/symbol_table_test.go index b1656dc..6ca824c 100644 --- a/compiler/symbol_table_test.go +++ b/compiler/symbol_table_test.go @@ -24,7 +24,7 @@ func TestDefine(t *testing.T) { } b := global.Define("b") - if a != expected["b"] { + if b != expected["b"] { t.Errorf("expected b=%+v, got=%+v", expected["b"], b) } } diff --git a/object/object.go b/object/object.go index b3f69be..f6742bc 100644 --- a/object/object.go +++ b/object/object.go @@ -5,22 +5,24 @@ import ( "fmt" "hash/fnv" "monkey/ast" + "monkey/code" "strings" ) type ObjectType string const ( - INTEGER_OBJ = "INTEGER" - BOOLEAN_OBJ = "BOOLEAN" - NULL_OBJ = "NULL" - RETURN_VALUE_OBJ = "RETURN_VALUE" - ERROR_OBJ = "ERROR" - FUNCTION_OBJ = "FUNCTION" - STRING_OBJ = "STRING" - BUILTIN_OBJ = "BUILTIN" - ARRAY_OBJ = "ARRAY" - HASH_OBJ = "HASH" + INTEGER_OBJ = "INTEGER" + BOOLEAN_OBJ = "BOOLEAN" + NULL_OBJ = "NULL" + RETURN_VALUE_OBJ = "RETURN_VALUE" + ERROR_OBJ = "ERROR" + FUNCTION_OBJ = "FUNCTION" + STRING_OBJ = "STRING" + BUILTIN_OBJ = "BUILTIN" + ARRAY_OBJ = "ARRAY" + HASH_OBJ = "HASH" + COMPILED_FUNCTION_OBJ = "COMPILED_FUNCTION_OBJ " ) type Object interface { @@ -32,6 +34,10 @@ type Integer struct { Value int64 } +type CompiledFunction struct { + Instructions code.Instructions +} + func (i *Integer) Type() ObjectType { return INTEGER_OBJ } @@ -219,3 +225,11 @@ func (h *Hash) Inspect() string { return out.String() } + +func (cf *CompiledFunction) Type() ObjectType { + return COMPILED_FUNCTION_OBJ +} + +func (cf *CompiledFunction) Inspect() string { + return fmt.Sprintf("CompiledFunction[%p]", cf) +}