diff --git a/code/code.go b/code/code.go index 96b2438..8cea249 100644 --- a/code/code.go +++ b/code/code.go @@ -66,7 +66,7 @@ var definitions = map[Opcode]*Definition{ OpArray: {"OpArray", []int{2}}, OpHash: {"OpHash", []int{2}}, OpIndex: {"OpIndex", []int{}}, - OpCall: {"OpCall", []int{}}, + OpCall: {"OpCall", []int{1}}, OpReturnValue: {"OpReturnValue", []int{}}, OpReturn: {"OpReturn", []int{}}, OpGetLocal: {"OpGetLocal", []int{1}}, diff --git a/compiler/compiler.go b/compiler/compiler.go index 12a726e..1ceb5a6 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -265,6 +265,10 @@ func (c *Compiler) Compile(node ast.Node) error { case *ast.FunctionLiteral: c.enterScope() + for _, p := range node.Parameters { + c.symbolTable.Define(p.Value) + } + err := c.Compile(node.Body) if err != nil { return err @@ -280,7 +284,11 @@ func (c *Compiler) Compile(node ast.Node) error { numLocals := c.symbolTable.numDefinitions instructions := c.leaveScope() - compiledFn := &object.CompiledFunction{Instructions: instructions, NumLocals: numLocals} + compiledFn := &object.CompiledFunction{ + Instructions: instructions, + NumLocals: numLocals, + NumParameters: len(node.Parameters), + } c.emit(code.OpConstant, c.addConstant(compiledFn)) case *ast.ReturnStatement: @@ -297,7 +305,14 @@ func (c *Compiler) Compile(node ast.Node) error { return err } - c.emit(code.OpCall) + for _, a := range node.Arguments { + err := c.Compile(a) + if err != nil { + return err + } + } + + c.emit(code.OpCall, len(node.Arguments)) } diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 1b45446..3ea3bbc 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -562,7 +562,7 @@ func TestFunctionCalls(t *testing.T) { }, expectedInstructions: []code.Instructions{ code.Make(code.OpConstant, 1), // The compiled function - code.Make(code.OpCall), + code.Make(code.OpCall, 0), code.Make(code.OpPop), }, }, @@ -582,7 +582,57 @@ func TestFunctionCalls(t *testing.T) { code.Make(code.OpConstant, 1), // The compiled function code.Make(code.OpSetGlobal, 0), code.Make(code.OpGetGlobal, 0), - code.Make(code.OpCall), + code.Make(code.OpCall, 0), + code.Make(code.OpPop), + }, + }, + { + input: ` + let oneArg = fn(a) { a }; + oneArg(24); + `, + expectedConstants: []interface{}{ + []code.Instructions{ + code.Make(code.OpGetLocal, 0), + code.Make(code.OpReturnValue), + }, + 24, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetGlobal, 0), + code.Make(code.OpGetGlobal, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpCall, 1), + code.Make(code.OpPop), + }, + }, + { + input: ` + let manyArg = fn(a, b, c) { a; b; c }; + manyArg(24, 25, 26); + `, + expectedConstants: []interface{}{ + []code.Instructions{ + code.Make(code.OpGetLocal, 0), + code.Make(code.OpPop), + code.Make(code.OpGetLocal, 1), + code.Make(code.OpPop), + code.Make(code.OpGetLocal, 2), + code.Make(code.OpReturnValue), + }, + 24, + 25, + 26, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetGlobal, 0), + code.Make(code.OpGetGlobal, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpConstant, 2), + code.Make(code.OpConstant, 3), + code.Make(code.OpCall, 3), code.Make(code.OpPop), }, }, diff --git a/object/object.go b/object/object.go index 0346230..e88750e 100644 --- a/object/object.go +++ b/object/object.go @@ -35,8 +35,9 @@ type Integer struct { } type CompiledFunction struct { - Instructions code.Instructions - NumLocals int + Instructions code.Instructions + NumLocals int + NumParameters int } func (i *Integer) Type() ObjectType { diff --git a/vm/vm.go b/vm/vm.go index fb74782..a6cb8a2 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -203,13 +203,13 @@ func (vm *VM) Run() error { } case code.OpCall: - fn, ok := vm.stack[vm.sp-1].(*object.CompiledFunction) - if !ok { - return fmt.Errorf("calling non-function") + numArgs := code.ReadUint8(ins[ip+1:]) + vm.currentFrame().ip += 1 + + err := vm.callFunction(int(numArgs)) + if err != nil { + return err } - frame := NewFrame(fn, vm.sp) - vm.pushFrame(frame) - vm.sp = frame.basePointer + fn.NumLocals case code.OpReturnValue: returnValue := vm.pop() @@ -465,6 +465,23 @@ func (vm *VM) executeMinusOperator() error { return vm.push(&object.Integer{Value: -value}) } +func (vm *VM) callFunction(numArgs int) error { + fn, ok := vm.stack[vm.sp-1-numArgs].(*object.CompiledFunction) + if !ok { + return fmt.Errorf("calling non-function") + } + + if numArgs != fn.NumParameters { + return fmt.Errorf("wrong number of arguments: want=%d, got=%d", fn.NumParameters, numArgs) + } + + frame := NewFrame(fn, vm.sp-numArgs) + vm.pushFrame(frame) + vm.sp = frame.basePointer + fn.NumLocals + + return nil +} + func nativeBoolToBooleanObject(input bool) *object.Boolean { if input { return True diff --git a/vm/vm_test.go b/vm/vm_test.go index a1b8328..13c005e 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -448,3 +448,110 @@ func TestCallingFunctionsWithBindings(t *testing.T) { runVmTests(t, tests) } + +func TestCallingFunctionsWithArgumentsAndBindings(t *testing.T) { + tests := []vmTestCase{ + { + input: ` + let identity = fn(a) { a; }; + identity(4); + `, + expected: 4, + }, + { + input: ` + let sum = fn(a, b) { a + b; }; + sum(1, 2); + `, + expected: 3, + }, + { + input: ` + let sum = fn(a, b) { + let c = a + b; + c; + }; + sum(1, 2); + `, + expected: 3, + }, + { + input: ` + let sum = fn(a, b) { + let c = a + b; + c; + }; + sum(1, 2) + sum(3, 4);`, + expected: 10, + }, + { + input: ` + let sum = fn(a, b) { + let c = a + b; + c; + }; + let outer = fn() { + sum(1, 2) + sum(3, 4); + }; + outer(); + `, + expected: 10, + }, + { + input: ` + let globalNum = 10; + + let sum = fn(a, b) { + let c = a + b; + c + globalNum; + }; + + let outer = fn() { + sum(1, 2) + sum(3, 4) + globalNum; + }; + + outer() + globalNum; + `, + expected: 50, + }, + } + + runVmTests(t, tests) +} + +func TestCallingFunctionsWithWrongArguments(t *testing.T) { + tests := []vmTestCase{ + { + input: `fn() { 1; }(1);`, + expected: `wrong number of arguments: want=0, got=1`, + }, + { + input: `fn(a) { a; }();`, + expected: `wrong number of arguments: want=1, got=0`, + }, + { + input: `fn(a, b) { a + b; }(1);`, + expected: `wrong number of arguments: want=2, got=1`, + }, + } + + for _, tt := range tests { + program := parse(tt.input) + + comp := compiler.New() + err := comp.Compile(program) + if err != nil { + t.Fatalf("compiler error: %s", err) + } + + vm := New(comp.Bytecode()) + err = vm.Run() + if err == nil { + t.Fatalf("expected VM error but resulted in none.") + } + + if err.Error() != tt.expected { + t.Fatalf("wrong VM error: want=%q, got=%q", tt.expected, err) + } + } +}