diff --git a/code/code.go b/code/code.go index b199e5e..764e660 100644 --- a/code/code.go +++ b/code/code.go @@ -38,6 +38,7 @@ const ( OpGetLocal OpSetLocal OpGetBuiltin + OpClosure ) type Definition struct { @@ -73,6 +74,7 @@ var definitions = map[Opcode]*Definition{ OpGetLocal: {"OpGetLocal", []int{1}}, OpSetLocal: {"OpSetLocal", []int{1}}, OpGetBuiltin: {"OpGetBuiltin", []int{1}}, + OpClosure: {"OpClosure", []int{2, 1}}, } func Lookup(op byte) (*Definition, error) { @@ -146,6 +148,9 @@ func (ins Instructions) fmtInstruction(def *Definition, operands []int) string { return def.Name case 1: return fmt.Sprintf("%s %d", def.Name, operands[0]) + case 2: + return fmt.Sprintf("%s %d %d", def.Name, operands[0], operands[1]) + } return fmt.Sprintf("ERROR: unhandled operandCount for %s\n", def.Name) diff --git a/code/code_test.go b/code/code_test.go index aafd7cc..eeda1f2 100644 --- a/code/code_test.go +++ b/code/code_test.go @@ -11,6 +11,7 @@ func TestMake(t *testing.T) { {OpConstant, []int{65534}, []byte{byte(OpConstant), 255, 254}}, {OpAdd, []int{}, []byte{byte(OpAdd)}}, {OpGetLocal, []int{255}, []byte{byte(OpGetLocal), 255}}, + {OpClosure, []int{65534, 255}, []byte{byte(OpClosure), 255, 254, 255}}, } for _, tt := range test { @@ -34,12 +35,14 @@ func TestInstructions(t *testing.T) { Make(OpGetLocal, 1), Make(OpConstant, 2), Make(OpConstant, 65535), + Make(OpClosure, 65535, 255), } expected := `0000 OpAdd 0001 OpGetLocal 1 0003 OpConstant 2 0006 OpConstant 65535 +0009 OpClosure 65535 255 ` concatted := Instructions{} @@ -60,6 +63,7 @@ func TestReadOperands(t *testing.T) { }{ {OpConstant, []int{65535}, 2}, {OpGetLocal, []int{255}, 1}, + {OpClosure, []int{65535, 255}, 3}, } for _, tt := range tests { diff --git a/compiler/compiler.go b/compiler/compiler.go index a489cc4..9dc48af 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -292,7 +292,9 @@ func (c *Compiler) Compile(node ast.Node) error { NumLocals: numLocals, NumParameters: len(node.Parameters), } - c.emit(code.OpConstant, c.addConstant(compiledFn)) + + fnIndex := c.addConstant(compiledFn) + c.emit(code.OpClosure, fnIndex, 0) case *ast.ReturnStatement: err := c.Compile(node.ReturnValue) diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 35d0da8..bbd131b 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -421,23 +421,6 @@ func TestIndexExpressions(t *testing.T) { code.Make(code.OpPop), }, }, - { - input: `fn() { 1; 2 }`, - expectedConstants: []interface{}{ - 1, - 2, - []code.Instructions{ - code.Make(code.OpConstant, 0), - code.Make(code.OpPop), - code.Make(code.OpConstant, 1), - code.Make(code.OpReturnValue), - }, - }, - expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 2), - code.Make(code.OpPop), - }, - }, } runCompilerTests(t, tests) @@ -458,7 +441,7 @@ func TestFunctions(t *testing.T) { }, }, expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 2), + code.Make(code.OpClosure, 2, 0), code.Make(code.OpPop), }, }, @@ -475,7 +458,24 @@ func TestFunctions(t *testing.T) { }, }, expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 2), + code.Make(code.OpClosure, 2, 0), + code.Make(code.OpPop), + }, + }, + { + input: `fn() { 1; 2 }`, + expectedConstants: []interface{}{ + 1, + 2, + []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpPop), + code.Make(code.OpConstant, 1), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpClosure, 2, 0), code.Make(code.OpPop), }, }, @@ -494,7 +494,7 @@ func TestFunctionsWithoutReturnValue(t *testing.T) { }, }, expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 0), + code.Make(code.OpClosure, 0, 0), code.Make(code.OpPop), }, }, @@ -561,7 +561,7 @@ func TestFunctionCalls(t *testing.T) { }, }, expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 1), // The compiled function + code.Make(code.OpClosure, 1, 0), // The compiled function code.Make(code.OpCall, 0), code.Make(code.OpPop), }, @@ -579,7 +579,7 @@ func TestFunctionCalls(t *testing.T) { }, }, expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 1), // The compiled function + code.Make(code.OpClosure, 1, 0), // The compiled function code.Make(code.OpSetGlobal, 0), code.Make(code.OpGetGlobal, 0), code.Make(code.OpCall, 0), @@ -599,7 +599,7 @@ func TestFunctionCalls(t *testing.T) { 24, }, expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 0), + code.Make(code.OpClosure, 0, 0), code.Make(code.OpSetGlobal, 0), code.Make(code.OpGetGlobal, 0), code.Make(code.OpConstant, 1), @@ -626,7 +626,7 @@ func TestFunctionCalls(t *testing.T) { 26, }, expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 0), + code.Make(code.OpClosure, 0, 0), code.Make(code.OpSetGlobal, 0), code.Make(code.OpGetGlobal, 0), code.Make(code.OpConstant, 1), @@ -658,7 +658,7 @@ func TestLetStatementScopes(t *testing.T) { expectedInstructions: []code.Instructions{ code.Make(code.OpConstant, 0), code.Make(code.OpSetGlobal, 0), - code.Make(code.OpConstant, 1), + code.Make(code.OpClosure, 1, 0), code.Make(code.OpPop), }, }, @@ -679,7 +679,7 @@ func TestLetStatementScopes(t *testing.T) { }, }, expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 1), + code.Make(code.OpClosure, 1, 0), code.Make(code.OpPop), }, }, @@ -706,7 +706,47 @@ func TestLetStatementScopes(t *testing.T) { }, }, expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 2), + code.Make(code.OpClosure, 2, 0), + code.Make(code.OpPop), + }, + }, + } + + runCompilerTests(t, tests) +} + +func TestBuiltins(t *testing.T) { + tests := []compilerTestCase{ + { + input: ` + len([]); + push([], 1); + `, + expectedConstants: []interface{}{1}, + expectedInstructions: []code.Instructions{ + code.Make(code.OpGetBuiltin, 0), + code.Make(code.OpArray, 0), + code.Make(code.OpCall, 1), + code.Make(code.OpPop), + code.Make(code.OpGetBuiltin, 5), + code.Make(code.OpArray, 0), + code.Make(code.OpConstant, 0), + code.Make(code.OpCall, 2), + code.Make(code.OpPop), + }, + }, + { + input: `fn() { len([]) }`, + expectedConstants: []interface{}{ + []code.Instructions{ + code.Make(code.OpGetBuiltin, 0), + code.Make(code.OpArray, 0), + code.Make(code.OpCall, 1), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpClosure, 0, 0), code.Make(code.OpPop), }, }, @@ -832,43 +872,3 @@ func testStringObject(expected string, actual object.Object) error { return nil } - -func TestBuiltins(t *testing.T) { - tests := []compilerTestCase{ - { - input: ` - len([]); - push([], 1); - `, - expectedConstants: []interface{}{1}, - expectedInstructions: []code.Instructions{ - code.Make(code.OpGetBuiltin, 0), - code.Make(code.OpArray, 0), - code.Make(code.OpCall, 1), - code.Make(code.OpPop), - code.Make(code.OpGetBuiltin, 5), - code.Make(code.OpArray, 0), - code.Make(code.OpConstant, 0), - code.Make(code.OpCall, 2), - code.Make(code.OpPop), - }, - }, - { - input: `fn() { len([]) }`, - expectedConstants: []interface{}{ - []code.Instructions{ - code.Make(code.OpGetBuiltin, 0), - code.Make(code.OpArray, 0), - code.Make(code.OpCall, 1), - code.Make(code.OpReturnValue), - }, - }, - expectedInstructions: []code.Instructions{ - code.Make(code.OpConstant, 0), - code.Make(code.OpPop), - }, - }, - } - - runCompilerTests(t, tests) -} diff --git a/object/object.go b/object/object.go index e88750e..89eb6a2 100644 --- a/object/object.go +++ b/object/object.go @@ -22,7 +22,8 @@ const ( BUILTIN_OBJ = "BUILTIN" ARRAY_OBJ = "ARRAY" HASH_OBJ = "HASH" - COMPILED_FUNCTION_OBJ = "COMPILED_FUNCTION_OBJ " + COMPILED_FUNCTION_OBJ = "COMPILED_FUNCTION" + CLOSURE_OBJ = "CLOSURE" ) type Object interface { @@ -235,3 +236,16 @@ func (cf *CompiledFunction) Type() ObjectType { func (cf *CompiledFunction) Inspect() string { return fmt.Sprintf("CompiledFunction[%p]", cf) } + +type Closure struct { + Fn *CompiledFunction + Free []Object +} + +func (c *Closure) Type() ObjectType { + return CLOSURE_OBJ +} + +func (c *Closure) Inspect() string { + return fmt.Sprintf("Closure[%p]", c) +} diff --git a/vm/frame.go b/vm/frame.go index beffc30..6a44dbd 100644 --- a/vm/frame.go +++ b/vm/frame.go @@ -6,15 +6,18 @@ import ( ) type Frame struct { - fn *object.CompiledFunction + cl *object.Closure ip int basePointer int } -func NewFrame(fn *object.CompiledFunction, basePointer int) *Frame { - return &Frame{fn: fn, ip: -1, basePointer: basePointer} +func NewFrame(cl *object.Closure, basePointer int) *Frame { + return &Frame{ + cl: cl, + ip: -1, + basePointer: basePointer} } func (f *Frame) Instructions() code.Instructions { - return f.fn.Instructions + return f.cl.Fn.Instructions } diff --git a/vm/vm.go b/vm/vm.go index b86fdac..ec7df47 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -29,7 +29,8 @@ type VM struct { func New(bytecode *compiler.Bytecode) *VM { mainFn := &object.CompiledFunction{Instructions: bytecode.Instructions} - mainFrame := NewFrame(mainFn, 0) + mainClosure := &object.Closure{Fn: mainFn} + mainFrame := NewFrame(mainClosure, 0) frames := make([]*Frame, MaxFrames) frames[0] = mainFrame @@ -260,6 +261,17 @@ func (vm *VM) Run() error { if err != nil { return err } + + case code.OpClosure: + constIndex := code.ReadUint16(ins[ip+1:]) + _ = code.ReadUint8(ins[ip+3:]) + vm.currentFrame().ip += 3 + + err := vm.pushClosure(int(constIndex)) + if err != nil { + return err + } + } } @@ -479,8 +491,8 @@ func (vm *VM) executeMinusOperator() error { func (vm *VM) executeCall(numArgs int) error { callee := vm.stack[vm.sp-1-numArgs] switch callee := callee.(type) { - case *object.CompiledFunction: - return vm.callFunction(callee, numArgs) + case *object.Closure: + return vm.callClosure(callee, numArgs) case *object.Builtin: return vm.callBuiltin(callee, numArgs) default: @@ -488,14 +500,14 @@ func (vm *VM) executeCall(numArgs int) error { } } -func (vm *VM) callFunction(fn *object.CompiledFunction, numArgs int) error { - if numArgs != fn.NumParameters { - return fmt.Errorf("wrong number of arguments: want=%d, got=%d", fn.NumParameters, numArgs) +func (vm *VM) callClosure(cl *object.Closure, numArgs int) error { + if numArgs != cl.Fn.NumParameters { + return fmt.Errorf("wrong number of arguments: want=%d, got=%d", cl.Fn.NumParameters, numArgs) } - frame := NewFrame(fn, vm.sp-numArgs) + frame := NewFrame(cl, vm.sp-numArgs) vm.pushFrame(frame) - vm.sp = frame.basePointer + fn.NumLocals + vm.sp = frame.basePointer + cl.Fn.NumLocals return nil } @@ -521,6 +533,17 @@ func (vm *VM) callBuiltin(builtin *object.Builtin, numArgs int) error { return nil } +func (vm *VM) pushClosure(constIndex int) error { + constant := vm.constants[constIndex] + function, ok := constant.(*object.CompiledFunction) + if !ok { + return fmt.Errorf("not a function %+v", constant) + } + + closure := &object.Closure{Fn: function} + return vm.push(closure) +} + func nativeBoolToBooleanObject(input bool) *object.Boolean { if input { return True