diff --git a/code/code.go b/code/code.go index b417cab..96b2438 100644 --- a/code/code.go +++ b/code/code.go @@ -35,6 +35,8 @@ const ( OpCall OpReturnValue OpReturn + OpGetLocal + OpSetLocal ) type Definition struct { @@ -67,6 +69,8 @@ var definitions = map[Opcode]*Definition{ OpCall: {"OpCall", []int{}}, OpReturnValue: {"OpReturnValue", []int{}}, OpReturn: {"OpReturn", []int{}}, + OpGetLocal: {"OpGetLocal", []int{1}}, + OpSetLocal: {"OpSetLocal", []int{1}}, } func Lookup(op byte) (*Definition, error) { @@ -98,6 +102,8 @@ func Make(op Opcode, operands ...int) []byte { switch width { case 2: binary.BigEndian.PutUint16(instruction[offset:], uint16(o)) + case 1: + instruction[offset] = byte(o) } offset += width } @@ -151,6 +157,9 @@ func ReadOperands(def *Definition, ins Instructions) ([]int, int) { switch width { case 2: operands[i] = int(ReadUint16(ins[offset:])) + case 1: + operands[i] = int(ReadUint8(ins[offset:])) + } offset += width @@ -162,3 +171,7 @@ func ReadOperands(def *Definition, ins Instructions) ([]int, int) { func ReadUint16(ins Instructions) uint16 { return binary.BigEndian.Uint16(ins) } + +func ReadUint8(ins Instructions) uint8 { + return ins[0] +} diff --git a/code/code_test.go b/code/code_test.go index 47d8246..aafd7cc 100644 --- a/code/code_test.go +++ b/code/code_test.go @@ -10,6 +10,7 @@ func TestMake(t *testing.T) { }{ {OpConstant, []int{65534}, []byte{byte(OpConstant), 255, 254}}, {OpAdd, []int{}, []byte{byte(OpAdd)}}, + {OpGetLocal, []int{255}, []byte{byte(OpGetLocal), 255}}, } for _, tt := range test { @@ -30,13 +31,15 @@ func TestMake(t *testing.T) { func TestInstructions(t *testing.T) { instructions := []Instructions{ Make(OpAdd), + Make(OpGetLocal, 1), Make(OpConstant, 2), Make(OpConstant, 65535), } expected := `0000 OpAdd -0001 OpConstant 2 -0004 OpConstant 65535 +0001 OpGetLocal 1 +0003 OpConstant 2 +0006 OpConstant 65535 ` concatted := Instructions{} @@ -49,13 +52,14 @@ func TestInstructions(t *testing.T) { } } -func TestOperands(t *testing.T) { +func TestReadOperands(t *testing.T) { tests := []struct { op Opcode operands []int bytesRead int }{ {OpConstant, []int{65535}, 2}, + {OpGetLocal, []int{255}, 1}, } for _, tt := range tests { diff --git a/compiler/compiler.go b/compiler/compiler.go index 5565154..12a726e 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -193,8 +193,13 @@ func (c *Compiler) Compile(node ast.Node) error { if err != nil { return err } + symbol := c.symbolTable.Define(node.Name.Value) - c.emit(code.OpSetGlobal, symbol.Index) + if symbol.Scope == GlobalScope { + c.emit(code.OpSetGlobal, symbol.Index) + } else { + c.emit(code.OpSetLocal, symbol.Index) + } case *ast.Identifier: symbol, ok := c.symbolTable.Resolve(node.Value) @@ -202,7 +207,11 @@ func (c *Compiler) Compile(node ast.Node) error { return fmt.Errorf("undefined varible %s", node.Value) } - c.emit(code.OpGetGlobal, symbol.Index) + if symbol.Scope == GlobalScope { + c.emit(code.OpGetGlobal, symbol.Index) + } else { + c.emit(code.OpGetLocal, symbol.Index) + } case *ast.StringLiteral: str := &object.String{Value: node.Value} @@ -268,9 +277,10 @@ func (c *Compiler) Compile(node ast.Node) error { c.emit(code.OpReturn) } + numLocals := c.symbolTable.numDefinitions instructions := c.leaveScope() - compiledFn := &object.CompiledFunction{Instructions: instructions} + compiledFn := &object.CompiledFunction{Instructions: instructions, NumLocals: numLocals} c.emit(code.OpConstant, c.addConstant(compiledFn)) case *ast.ReturnStatement: @@ -374,6 +384,8 @@ func (c *Compiler) enterScope() { } c.scopes = append(c.scopes, scope) c.scopeIndex++ + + c.symbolTable = NewEnclosedSymbolTable(c.symbolTable) } func (c *Compiler) leaveScope() code.Instructions { @@ -382,6 +394,8 @@ func (c *Compiler) leaveScope() code.Instructions { c.scopes = c.scopes[:len(c.scopes)-1] c.scopeIndex-- + c.symbolTable = c.symbolTable.Outer + return instructions } diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 8987ce1..1b45446 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -591,6 +591,80 @@ func TestFunctionCalls(t *testing.T) { runCompilerTests(t, tests) } +func TestLetStatementScopes(t *testing.T) { + tests := []compilerTestCase{ + { + input: ` + let num = 55; + fn() { num } + `, + expectedConstants: []interface{}{ + 55, + []code.Instructions{ + code.Make(code.OpGetGlobal, 0), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetGlobal, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpPop), + }, + }, + { + input: ` + fn() { + let num = 55; + num + } + `, + expectedConstants: []interface{}{ + 55, + []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetLocal, 0), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 1), + code.Make(code.OpPop), + }, + }, + { + input: ` + fn() { + let a = 55; + let b = 77; + a + b + } + `, + expectedConstants: []interface{}{ + 55, + 77, + []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetLocal, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpSetLocal, 1), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpGetLocal, 1), + code.Make(code.OpAdd), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 2), + code.Make(code.OpPop), + }, + }, + } + + runCompilerTests(t, tests) +} + func runCompilerTests(t *testing.T, tests []compilerTestCase) { t.Helper() diff --git a/compiler/symbol_table.go b/compiler/symbol_table.go index 2bd84d8..1835e78 100644 --- a/compiler/symbol_table.go +++ b/compiler/symbol_table.go @@ -3,7 +3,8 @@ package compiler type SymbolScope string const ( - GlobalScope SymbolScope = "Global" + LocalScope SymbolScope = "LOCAL" + GlobalScope SymbolScope = "GLOBAL" ) type Symbol struct { @@ -13,21 +14,31 @@ type Symbol struct { } type SymbolTable struct { + Outer *SymbolTable + store map[string]Symbol numDefinitions int } +func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable { + s := NewSymbolTable() + s.Outer = outer + return s +} + func NewSymbolTable() *SymbolTable { s := make(map[string]Symbol) return &SymbolTable{store: s} } func (s *SymbolTable) Define(name string) Symbol { - symbol := Symbol{ - Name: name, - Scope: GlobalScope, - Index: s.numDefinitions, + symbol := Symbol{Name: name, Index: s.numDefinitions} + if s.Outer == nil { + symbol.Scope = GlobalScope + } else { + symbol.Scope = LocalScope } + s.store[name] = symbol s.numDefinitions++ return symbol @@ -35,5 +46,10 @@ func (s *SymbolTable) Define(name string) Symbol { func (s *SymbolTable) Resolve(name string) (Symbol, bool) { obj, ok := s.store[name] + if !ok && s.Outer != nil { + obj, ok = s.Outer.Resolve(name) + return obj, ok + } + return obj, ok } diff --git a/compiler/symbol_table_test.go b/compiler/symbol_table_test.go index 6ca824c..bca4de2 100644 --- a/compiler/symbol_table_test.go +++ b/compiler/symbol_table_test.go @@ -14,6 +14,10 @@ func TestDefine(t *testing.T) { Scope: GlobalScope, Index: 1, }, + "c": {Name: "c", Scope: LocalScope, Index: 0}, + "d": {Name: "d", Scope: LocalScope, Index: 1}, + "e": {Name: "e", Scope: LocalScope, Index: 0}, + "f": {Name: "f", Scope: LocalScope, Index: 1}, } global := NewSymbolTable() @@ -27,6 +31,30 @@ func TestDefine(t *testing.T) { if b != expected["b"] { t.Errorf("expected b=%+v, got=%+v", expected["b"], b) } + + firstLocal := NewEnclosedSymbolTable(global) + + c := firstLocal.Define("c") + if c != expected["c"] { + t.Errorf("expected c=%+v, got=%+v", expected["c"], c) + } + + d := firstLocal.Define("d") + if d != expected["d"] { + t.Errorf("expected d=%+v, got=%+v", expected["d"], d) + } + + secondLocal := NewEnclosedSymbolTable(firstLocal) + + e := secondLocal.Define("e") + if e != expected["e"] { + t.Errorf("expected e=%+v, got=%+v", expected["e"], e) + } + + f := secondLocal.Define("f") + if f != expected["f"] { + t.Errorf("expected f=%+v, got=%+v", expected["f"], f) + } } func TestResolveGlobal(t *testing.T) { @@ -57,3 +85,84 @@ func TestResolveGlobal(t *testing.T) { } } } + +func TestResolveLocal(t *testing.T) { + global := NewSymbolTable() + global.Define("a") + global.Define("b") + + local := NewEnclosedSymbolTable(global) + local.Define("c") + local.Define("d") + + expected := []Symbol{ + Symbol{Name: "a", Scope: GlobalScope, Index: 0}, + Symbol{Name: "b", Scope: GlobalScope, Index: 1}, + Symbol{Name: "c", Scope: LocalScope, Index: 0}, + Symbol{Name: "d", Scope: LocalScope, Index: 1}, + } + + for _, sym := range expected { + result, ok := local.Resolve(sym.Name) + if !ok { + t.Errorf("name %s not resolvable", sym.Name) + continue + } + if result != sym { + t.Errorf("expected %s to resolve to %+v, got=%+v", + sym.Name, sym, result) + } + } +} + +func TestResolveNestedLocal(t *testing.T) { + global := NewSymbolTable() + global.Define("a") + global.Define("b") + + firstLocal := NewEnclosedSymbolTable(global) + firstLocal.Define("c") + firstLocal.Define("d") + + secondLocal := NewEnclosedSymbolTable(global) + secondLocal.Define("e") + secondLocal.Define("f") + + tests := []struct { + table *SymbolTable + expectedSymbols []Symbol + }{ + { + firstLocal, + []Symbol{ + Symbol{Name: "a", Scope: GlobalScope, Index: 0}, + Symbol{Name: "b", Scope: GlobalScope, Index: 1}, + Symbol{Name: "c", Scope: LocalScope, Index: 0}, + Symbol{Name: "d", Scope: LocalScope, Index: 1}, + }, + }, + { + secondLocal, + []Symbol{ + Symbol{Name: "a", Scope: GlobalScope, Index: 0}, + Symbol{Name: "b", Scope: GlobalScope, Index: 1}, + Symbol{Name: "e", Scope: LocalScope, Index: 0}, + Symbol{Name: "f", Scope: LocalScope, Index: 1}, + }, + }, + } + + for _, tt := range tests { + for _, sym := range tt.expectedSymbols { + result, ok := tt.table.Resolve(sym.Name) + if !ok { + t.Errorf("name %s not resolvable", sym.Name) + continue + } + if result != sym { + t.Errorf("expected %s to resolve to %+v, got=%+v", + sym.Name, sym, result) + } + } + } +} diff --git a/object/object.go b/object/object.go index f6742bc..0346230 100644 --- a/object/object.go +++ b/object/object.go @@ -36,6 +36,7 @@ type Integer struct { type CompiledFunction struct { Instructions code.Instructions + NumLocals int } func (i *Integer) Type() ObjectType { diff --git a/vm/frame.go b/vm/frame.go new file mode 100644 index 0000000..beffc30 --- /dev/null +++ b/vm/frame.go @@ -0,0 +1,20 @@ +package vm + +import ( + "monkey/code" + "monkey/object" +) + +type Frame struct { + fn *object.CompiledFunction + ip int + basePointer int +} + +func NewFrame(fn *object.CompiledFunction, basePointer int) *Frame { + return &Frame{fn: fn, ip: -1, basePointer: basePointer} +} + +func (f *Frame) Instructions() code.Instructions { + return f.fn.Instructions +} diff --git a/vm/vm.go b/vm/vm.go index 3b835b7..fb74782 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -15,19 +15,6 @@ var Null = &object.Null{} var True = &object.Boolean{Value: true} var False = &object.Boolean{Value: false} -type Frame struct { - fn *object.CompiledFunction - ip int -} - -func NewFrame(fn *object.CompiledFunction) *Frame { - return &Frame{fn: fn, ip: -1} -} - -func (f *Frame) Instructions() code.Instructions { - return f.fn.Instructions -} - type VM struct { constants []object.Object @@ -42,7 +29,7 @@ type VM struct { func New(bytecode *compiler.Bytecode) *VM { mainFn := &object.CompiledFunction{Instructions: bytecode.Instructions} - mainFrame := NewFrame(mainFn) + mainFrame := NewFrame(mainFn, 0) frames := make([]*Frame, MaxFrames) frames[0] = mainFrame @@ -220,13 +207,15 @@ func (vm *VM) Run() error { if !ok { return fmt.Errorf("calling non-function") } - frame := NewFrame(fn) + frame := NewFrame(fn, vm.sp) vm.pushFrame(frame) + vm.sp = frame.basePointer + fn.NumLocals case code.OpReturnValue: returnValue := vm.pop() - vm.popFrame() - vm.pop() + + frame := vm.popFrame() + vm.sp = frame.basePointer - 1 err := vm.push(returnValue) if err != nil { @@ -234,14 +223,32 @@ func (vm *VM) Run() error { } case code.OpReturn: - vm.popFrame() - vm.pop() + frame := vm.popFrame() + vm.sp = frame.basePointer - 1 err := vm.push(Null) if err != nil { return err } + case code.OpSetLocal: + localIndex := code.ReadUint8(ins[ip+1:]) + vm.currentFrame().ip += 1 + + frame := vm.currentFrame() + + vm.stack[frame.basePointer+int(localIndex)] = vm.pop() + + case code.OpGetLocal: + localIndex := code.ReadUint8(ins[ip+1:]) + vm.currentFrame().ip += 1 + + frame := vm.currentFrame() + + err := vm.push(vm.stack[frame.basePointer+int(localIndex)]) + if err != nil { + return err + } } } diff --git a/vm/vm_test.go b/vm/vm_test.go index adba853..a1b8328 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -379,6 +379,16 @@ func TestFirstClassFunctions(t *testing.T) { let returnsOne = fn() { 1; }; let returnsOneReturner = fn() { returnsOne; }; returnsOneReturner()(); + `, + expected: 1, + }, + { + input: ` + let returnsOneReturner = fn() { + let returnsOne = fn() { 1; }; + returnsOne; + }; + returnsOneReturner()(); `, expected: 1, }, @@ -386,3 +396,55 @@ func TestFirstClassFunctions(t *testing.T) { runVmTests(t, tests) } + +func TestCallingFunctionsWithBindings(t *testing.T) { + tests := []vmTestCase{ + { + input: ` + let one = fn() { let one = 1; one }; + one(); + `, + expected: 1, + }, + { + input: ` + let oneAndTwo = fn() { let one = 1; let two = 2; one + two; }; + oneAndTwo(); + `, + expected: 3, + }, + { + input: ` + let oneAndTwo = fn() { let one = 1; let two = 2; one + two; }; + let threeAndFour = fn() { let three = 3; let four = 4; three + four; }; + oneAndTwo() + threeAndFour(); + `, + expected: 10, + }, + { + input: ` + let firstFoobar = fn() { let foobar = 50; foobar; }; + let secondFoobar = fn() { let foobar = 100; foobar; }; + firstFoobar() + secondFoobar(); + `, + expected: 150, + }, + { + input: ` + let globalSeed = 50; + let minusOne = fn() { + let num = 1; + globalSeed - num; + } + let minusTwo = fn() { + let num = 2; + globalSeed - num; + } + minusOne() + minusTwo(); + `, + expected: 97, + }, + } + + runVmTests(t, tests) +}