functions with arguments
This commit is contained in:
@@ -66,7 +66,7 @@ var definitions = map[Opcode]*Definition{
|
|||||||
OpArray: {"OpArray", []int{2}},
|
OpArray: {"OpArray", []int{2}},
|
||||||
OpHash: {"OpHash", []int{2}},
|
OpHash: {"OpHash", []int{2}},
|
||||||
OpIndex: {"OpIndex", []int{}},
|
OpIndex: {"OpIndex", []int{}},
|
||||||
OpCall: {"OpCall", []int{}},
|
OpCall: {"OpCall", []int{1}},
|
||||||
OpReturnValue: {"OpReturnValue", []int{}},
|
OpReturnValue: {"OpReturnValue", []int{}},
|
||||||
OpReturn: {"OpReturn", []int{}},
|
OpReturn: {"OpReturn", []int{}},
|
||||||
OpGetLocal: {"OpGetLocal", []int{1}},
|
OpGetLocal: {"OpGetLocal", []int{1}},
|
||||||
|
|||||||
@@ -265,6 +265,10 @@ func (c *Compiler) Compile(node ast.Node) error {
|
|||||||
case *ast.FunctionLiteral:
|
case *ast.FunctionLiteral:
|
||||||
c.enterScope()
|
c.enterScope()
|
||||||
|
|
||||||
|
for _, p := range node.Parameters {
|
||||||
|
c.symbolTable.Define(p.Value)
|
||||||
|
}
|
||||||
|
|
||||||
err := c.Compile(node.Body)
|
err := c.Compile(node.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -280,7 +284,11 @@ func (c *Compiler) Compile(node ast.Node) error {
|
|||||||
numLocals := c.symbolTable.numDefinitions
|
numLocals := c.symbolTable.numDefinitions
|
||||||
instructions := c.leaveScope()
|
instructions := c.leaveScope()
|
||||||
|
|
||||||
compiledFn := &object.CompiledFunction{Instructions: instructions, NumLocals: numLocals}
|
compiledFn := &object.CompiledFunction{
|
||||||
|
Instructions: instructions,
|
||||||
|
NumLocals: numLocals,
|
||||||
|
NumParameters: len(node.Parameters),
|
||||||
|
}
|
||||||
c.emit(code.OpConstant, c.addConstant(compiledFn))
|
c.emit(code.OpConstant, c.addConstant(compiledFn))
|
||||||
|
|
||||||
case *ast.ReturnStatement:
|
case *ast.ReturnStatement:
|
||||||
@@ -297,7 +305,14 @@ func (c *Compiler) Compile(node ast.Node) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.emit(code.OpCall)
|
for _, a := range node.Arguments {
|
||||||
|
err := c.Compile(a)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.emit(code.OpCall, len(node.Arguments))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -562,7 +562,7 @@ func TestFunctionCalls(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedInstructions: []code.Instructions{
|
expectedInstructions: []code.Instructions{
|
||||||
code.Make(code.OpConstant, 1), // The compiled function
|
code.Make(code.OpConstant, 1), // The compiled function
|
||||||
code.Make(code.OpCall),
|
code.Make(code.OpCall, 0),
|
||||||
code.Make(code.OpPop),
|
code.Make(code.OpPop),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -582,7 +582,57 @@ func TestFunctionCalls(t *testing.T) {
|
|||||||
code.Make(code.OpConstant, 1), // The compiled function
|
code.Make(code.OpConstant, 1), // The compiled function
|
||||||
code.Make(code.OpSetGlobal, 0),
|
code.Make(code.OpSetGlobal, 0),
|
||||||
code.Make(code.OpGetGlobal, 0),
|
code.Make(code.OpGetGlobal, 0),
|
||||||
code.Make(code.OpCall),
|
code.Make(code.OpCall, 0),
|
||||||
|
code.Make(code.OpPop),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let oneArg = fn(a) { a };
|
||||||
|
oneArg(24);
|
||||||
|
`,
|
||||||
|
expectedConstants: []interface{}{
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
24,
|
||||||
|
},
|
||||||
|
expectedInstructions: []code.Instructions{
|
||||||
|
code.Make(code.OpConstant, 0),
|
||||||
|
code.Make(code.OpSetGlobal, 0),
|
||||||
|
code.Make(code.OpGetGlobal, 0),
|
||||||
|
code.Make(code.OpConstant, 1),
|
||||||
|
code.Make(code.OpCall, 1),
|
||||||
|
code.Make(code.OpPop),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let manyArg = fn(a, b, c) { a; b; c };
|
||||||
|
manyArg(24, 25, 26);
|
||||||
|
`,
|
||||||
|
expectedConstants: []interface{}{
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpPop),
|
||||||
|
code.Make(code.OpGetLocal, 1),
|
||||||
|
code.Make(code.OpPop),
|
||||||
|
code.Make(code.OpGetLocal, 2),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
24,
|
||||||
|
25,
|
||||||
|
26,
|
||||||
|
},
|
||||||
|
expectedInstructions: []code.Instructions{
|
||||||
|
code.Make(code.OpConstant, 0),
|
||||||
|
code.Make(code.OpSetGlobal, 0),
|
||||||
|
code.Make(code.OpGetGlobal, 0),
|
||||||
|
code.Make(code.OpConstant, 1),
|
||||||
|
code.Make(code.OpConstant, 2),
|
||||||
|
code.Make(code.OpConstant, 3),
|
||||||
|
code.Make(code.OpCall, 3),
|
||||||
code.Make(code.OpPop),
|
code.Make(code.OpPop),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -35,8 +35,9 @@ type Integer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CompiledFunction struct {
|
type CompiledFunction struct {
|
||||||
Instructions code.Instructions
|
Instructions code.Instructions
|
||||||
NumLocals int
|
NumLocals int
|
||||||
|
NumParameters int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Integer) Type() ObjectType {
|
func (i *Integer) Type() ObjectType {
|
||||||
|
|||||||
29
vm/vm.go
29
vm/vm.go
@@ -203,13 +203,13 @@ func (vm *VM) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case code.OpCall:
|
case code.OpCall:
|
||||||
fn, ok := vm.stack[vm.sp-1].(*object.CompiledFunction)
|
numArgs := code.ReadUint8(ins[ip+1:])
|
||||||
if !ok {
|
vm.currentFrame().ip += 1
|
||||||
return fmt.Errorf("calling non-function")
|
|
||||||
|
err := vm.callFunction(int(numArgs))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
frame := NewFrame(fn, vm.sp)
|
|
||||||
vm.pushFrame(frame)
|
|
||||||
vm.sp = frame.basePointer + fn.NumLocals
|
|
||||||
|
|
||||||
case code.OpReturnValue:
|
case code.OpReturnValue:
|
||||||
returnValue := vm.pop()
|
returnValue := vm.pop()
|
||||||
@@ -465,6 +465,23 @@ func (vm *VM) executeMinusOperator() error {
|
|||||||
return vm.push(&object.Integer{Value: -value})
|
return vm.push(&object.Integer{Value: -value})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (vm *VM) callFunction(numArgs int) error {
|
||||||
|
fn, ok := vm.stack[vm.sp-1-numArgs].(*object.CompiledFunction)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("calling non-function")
|
||||||
|
}
|
||||||
|
|
||||||
|
if numArgs != fn.NumParameters {
|
||||||
|
return fmt.Errorf("wrong number of arguments: want=%d, got=%d", fn.NumParameters, numArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
frame := NewFrame(fn, vm.sp-numArgs)
|
||||||
|
vm.pushFrame(frame)
|
||||||
|
vm.sp = frame.basePointer + fn.NumLocals
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func nativeBoolToBooleanObject(input bool) *object.Boolean {
|
func nativeBoolToBooleanObject(input bool) *object.Boolean {
|
||||||
if input {
|
if input {
|
||||||
return True
|
return True
|
||||||
|
|||||||
107
vm/vm_test.go
107
vm/vm_test.go
@@ -448,3 +448,110 @@ func TestCallingFunctionsWithBindings(t *testing.T) {
|
|||||||
|
|
||||||
runVmTests(t, tests)
|
runVmTests(t, tests)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCallingFunctionsWithArgumentsAndBindings(t *testing.T) {
|
||||||
|
tests := []vmTestCase{
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let identity = fn(a) { a; };
|
||||||
|
identity(4);
|
||||||
|
`,
|
||||||
|
expected: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let sum = fn(a, b) { a + b; };
|
||||||
|
sum(1, 2);
|
||||||
|
`,
|
||||||
|
expected: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let sum = fn(a, b) {
|
||||||
|
let c = a + b;
|
||||||
|
c;
|
||||||
|
};
|
||||||
|
sum(1, 2);
|
||||||
|
`,
|
||||||
|
expected: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let sum = fn(a, b) {
|
||||||
|
let c = a + b;
|
||||||
|
c;
|
||||||
|
};
|
||||||
|
sum(1, 2) + sum(3, 4);`,
|
||||||
|
expected: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let sum = fn(a, b) {
|
||||||
|
let c = a + b;
|
||||||
|
c;
|
||||||
|
};
|
||||||
|
let outer = fn() {
|
||||||
|
sum(1, 2) + sum(3, 4);
|
||||||
|
};
|
||||||
|
outer();
|
||||||
|
`,
|
||||||
|
expected: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let globalNum = 10;
|
||||||
|
|
||||||
|
let sum = fn(a, b) {
|
||||||
|
let c = a + b;
|
||||||
|
c + globalNum;
|
||||||
|
};
|
||||||
|
|
||||||
|
let outer = fn() {
|
||||||
|
sum(1, 2) + sum(3, 4) + globalNum;
|
||||||
|
};
|
||||||
|
|
||||||
|
outer() + globalNum;
|
||||||
|
`,
|
||||||
|
expected: 50,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runVmTests(t, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallingFunctionsWithWrongArguments(t *testing.T) {
|
||||||
|
tests := []vmTestCase{
|
||||||
|
{
|
||||||
|
input: `fn() { 1; }(1);`,
|
||||||
|
expected: `wrong number of arguments: want=0, got=1`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `fn(a) { a; }();`,
|
||||||
|
expected: `wrong number of arguments: want=1, got=0`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `fn(a, b) { a + b; }(1);`,
|
||||||
|
expected: `wrong number of arguments: want=2, got=1`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
program := parse(tt.input)
|
||||||
|
|
||||||
|
comp := compiler.New()
|
||||||
|
err := comp.Compile(program)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("compiler error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vm := New(comp.Bytecode())
|
||||||
|
err = vm.Run()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected VM error but resulted in none.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err.Error() != tt.expected {
|
||||||
|
t.Fatalf("wrong VM error: want=%q, got=%q", tt.expected, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user