functions with arguments
This commit is contained in:
@@ -66,7 +66,7 @@ var definitions = map[Opcode]*Definition{
|
||||
OpArray: {"OpArray", []int{2}},
|
||||
OpHash: {"OpHash", []int{2}},
|
||||
OpIndex: {"OpIndex", []int{}},
|
||||
OpCall: {"OpCall", []int{}},
|
||||
OpCall: {"OpCall", []int{1}},
|
||||
OpReturnValue: {"OpReturnValue", []int{}},
|
||||
OpReturn: {"OpReturn", []int{}},
|
||||
OpGetLocal: {"OpGetLocal", []int{1}},
|
||||
|
||||
@@ -265,6 +265,10 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
case *ast.FunctionLiteral:
|
||||
c.enterScope()
|
||||
|
||||
for _, p := range node.Parameters {
|
||||
c.symbolTable.Define(p.Value)
|
||||
}
|
||||
|
||||
err := c.Compile(node.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -280,7 +284,11 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
numLocals := c.symbolTable.numDefinitions
|
||||
instructions := c.leaveScope()
|
||||
|
||||
compiledFn := &object.CompiledFunction{Instructions: instructions, NumLocals: numLocals}
|
||||
compiledFn := &object.CompiledFunction{
|
||||
Instructions: instructions,
|
||||
NumLocals: numLocals,
|
||||
NumParameters: len(node.Parameters),
|
||||
}
|
||||
c.emit(code.OpConstant, c.addConstant(compiledFn))
|
||||
|
||||
case *ast.ReturnStatement:
|
||||
@@ -297,7 +305,14 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
return err
|
||||
}
|
||||
|
||||
c.emit(code.OpCall)
|
||||
for _, a := range node.Arguments {
|
||||
err := c.Compile(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.emit(code.OpCall, len(node.Arguments))
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -562,7 +562,7 @@ func TestFunctionCalls(t *testing.T) {
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpConstant, 1), // The compiled function
|
||||
code.Make(code.OpCall),
|
||||
code.Make(code.OpCall, 0),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
@@ -582,7 +582,57 @@ func TestFunctionCalls(t *testing.T) {
|
||||
code.Make(code.OpConstant, 1), // The compiled function
|
||||
code.Make(code.OpSetGlobal, 0),
|
||||
code.Make(code.OpGetGlobal, 0),
|
||||
code.Make(code.OpCall),
|
||||
code.Make(code.OpCall, 0),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `
|
||||
let oneArg = fn(a) { a };
|
||||
oneArg(24);
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
24,
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpConstant, 0),
|
||||
code.Make(code.OpSetGlobal, 0),
|
||||
code.Make(code.OpGetGlobal, 0),
|
||||
code.Make(code.OpConstant, 1),
|
||||
code.Make(code.OpCall, 1),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `
|
||||
let manyArg = fn(a, b, c) { a; b; c };
|
||||
manyArg(24, 25, 26);
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpPop),
|
||||
code.Make(code.OpGetLocal, 1),
|
||||
code.Make(code.OpPop),
|
||||
code.Make(code.OpGetLocal, 2),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
24,
|
||||
25,
|
||||
26,
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpConstant, 0),
|
||||
code.Make(code.OpSetGlobal, 0),
|
||||
code.Make(code.OpGetGlobal, 0),
|
||||
code.Make(code.OpConstant, 1),
|
||||
code.Make(code.OpConstant, 2),
|
||||
code.Make(code.OpConstant, 3),
|
||||
code.Make(code.OpCall, 3),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
|
||||
@@ -35,8 +35,9 @@ type Integer struct {
|
||||
}
|
||||
|
||||
type CompiledFunction struct {
|
||||
Instructions code.Instructions
|
||||
NumLocals int
|
||||
Instructions code.Instructions
|
||||
NumLocals int
|
||||
NumParameters int
|
||||
}
|
||||
|
||||
func (i *Integer) Type() ObjectType {
|
||||
|
||||
29
vm/vm.go
29
vm/vm.go
@@ -203,13 +203,13 @@ func (vm *VM) Run() error {
|
||||
}
|
||||
|
||||
case code.OpCall:
|
||||
fn, ok := vm.stack[vm.sp-1].(*object.CompiledFunction)
|
||||
if !ok {
|
||||
return fmt.Errorf("calling non-function")
|
||||
numArgs := code.ReadUint8(ins[ip+1:])
|
||||
vm.currentFrame().ip += 1
|
||||
|
||||
err := vm.callFunction(int(numArgs))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
frame := NewFrame(fn, vm.sp)
|
||||
vm.pushFrame(frame)
|
||||
vm.sp = frame.basePointer + fn.NumLocals
|
||||
|
||||
case code.OpReturnValue:
|
||||
returnValue := vm.pop()
|
||||
@@ -465,6 +465,23 @@ func (vm *VM) executeMinusOperator() error {
|
||||
return vm.push(&object.Integer{Value: -value})
|
||||
}
|
||||
|
||||
func (vm *VM) callFunction(numArgs int) error {
|
||||
fn, ok := vm.stack[vm.sp-1-numArgs].(*object.CompiledFunction)
|
||||
if !ok {
|
||||
return fmt.Errorf("calling non-function")
|
||||
}
|
||||
|
||||
if numArgs != fn.NumParameters {
|
||||
return fmt.Errorf("wrong number of arguments: want=%d, got=%d", fn.NumParameters, numArgs)
|
||||
}
|
||||
|
||||
frame := NewFrame(fn, vm.sp-numArgs)
|
||||
vm.pushFrame(frame)
|
||||
vm.sp = frame.basePointer + fn.NumLocals
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func nativeBoolToBooleanObject(input bool) *object.Boolean {
|
||||
if input {
|
||||
return True
|
||||
|
||||
107
vm/vm_test.go
107
vm/vm_test.go
@@ -448,3 +448,110 @@ func TestCallingFunctionsWithBindings(t *testing.T) {
|
||||
|
||||
runVmTests(t, tests)
|
||||
}
|
||||
|
||||
func TestCallingFunctionsWithArgumentsAndBindings(t *testing.T) {
|
||||
tests := []vmTestCase{
|
||||
{
|
||||
input: `
|
||||
let identity = fn(a) { a; };
|
||||
identity(4);
|
||||
`,
|
||||
expected: 4,
|
||||
},
|
||||
{
|
||||
input: `
|
||||
let sum = fn(a, b) { a + b; };
|
||||
sum(1, 2);
|
||||
`,
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
input: `
|
||||
let sum = fn(a, b) {
|
||||
let c = a + b;
|
||||
c;
|
||||
};
|
||||
sum(1, 2);
|
||||
`,
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
input: `
|
||||
let sum = fn(a, b) {
|
||||
let c = a + b;
|
||||
c;
|
||||
};
|
||||
sum(1, 2) + sum(3, 4);`,
|
||||
expected: 10,
|
||||
},
|
||||
{
|
||||
input: `
|
||||
let sum = fn(a, b) {
|
||||
let c = a + b;
|
||||
c;
|
||||
};
|
||||
let outer = fn() {
|
||||
sum(1, 2) + sum(3, 4);
|
||||
};
|
||||
outer();
|
||||
`,
|
||||
expected: 10,
|
||||
},
|
||||
{
|
||||
input: `
|
||||
let globalNum = 10;
|
||||
|
||||
let sum = fn(a, b) {
|
||||
let c = a + b;
|
||||
c + globalNum;
|
||||
};
|
||||
|
||||
let outer = fn() {
|
||||
sum(1, 2) + sum(3, 4) + globalNum;
|
||||
};
|
||||
|
||||
outer() + globalNum;
|
||||
`,
|
||||
expected: 50,
|
||||
},
|
||||
}
|
||||
|
||||
runVmTests(t, tests)
|
||||
}
|
||||
|
||||
func TestCallingFunctionsWithWrongArguments(t *testing.T) {
|
||||
tests := []vmTestCase{
|
||||
{
|
||||
input: `fn() { 1; }(1);`,
|
||||
expected: `wrong number of arguments: want=0, got=1`,
|
||||
},
|
||||
{
|
||||
input: `fn(a) { a; }();`,
|
||||
expected: `wrong number of arguments: want=1, got=0`,
|
||||
},
|
||||
{
|
||||
input: `fn(a, b) { a + b; }(1);`,
|
||||
expected: `wrong number of arguments: want=2, got=1`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
program := parse(tt.input)
|
||||
|
||||
comp := compiler.New()
|
||||
err := comp.Compile(program)
|
||||
if err != nil {
|
||||
t.Fatalf("compiler error: %s", err)
|
||||
}
|
||||
|
||||
vm := New(comp.Bytecode())
|
||||
err = vm.Run()
|
||||
if err == nil {
|
||||
t.Fatalf("expected VM error but resulted in none.")
|
||||
}
|
||||
|
||||
if err.Error() != tt.expected {
|
||||
t.Fatalf("wrong VM error: want=%q, got=%q", tt.expected, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user