functions with arguments
Some checks failed
Build / build (push) Failing after 22s
Test / build (push) Failing after 22s

This commit is contained in:
Chuck Smith
2024-03-12 15:53:35 -04:00
parent ec9a586f7f
commit 1d2c7f0a51
6 changed files with 203 additions and 13 deletions

View File

@@ -66,7 +66,7 @@ var definitions = map[Opcode]*Definition{
OpArray: {"OpArray", []int{2}},
OpHash: {"OpHash", []int{2}},
OpIndex: {"OpIndex", []int{}},
OpCall: {"OpCall", []int{}},
OpCall: {"OpCall", []int{1}},
OpReturnValue: {"OpReturnValue", []int{}},
OpReturn: {"OpReturn", []int{}},
OpGetLocal: {"OpGetLocal", []int{1}},

View File

@@ -265,6 +265,10 @@ func (c *Compiler) Compile(node ast.Node) error {
case *ast.FunctionLiteral:
c.enterScope()
for _, p := range node.Parameters {
c.symbolTable.Define(p.Value)
}
err := c.Compile(node.Body)
if err != nil {
return err
@@ -280,7 +284,11 @@ func (c *Compiler) Compile(node ast.Node) error {
numLocals := c.symbolTable.numDefinitions
instructions := c.leaveScope()
compiledFn := &object.CompiledFunction{Instructions: instructions, NumLocals: numLocals}
compiledFn := &object.CompiledFunction{
Instructions: instructions,
NumLocals: numLocals,
NumParameters: len(node.Parameters),
}
c.emit(code.OpConstant, c.addConstant(compiledFn))
case *ast.ReturnStatement:
@@ -297,7 +305,14 @@ func (c *Compiler) Compile(node ast.Node) error {
return err
}
c.emit(code.OpCall)
for _, a := range node.Arguments {
err := c.Compile(a)
if err != nil {
return err
}
}
c.emit(code.OpCall, len(node.Arguments))
}

View File

@@ -562,7 +562,7 @@ func TestFunctionCalls(t *testing.T) {
},
expectedInstructions: []code.Instructions{
code.Make(code.OpConstant, 1), // The compiled function
code.Make(code.OpCall),
code.Make(code.OpCall, 0),
code.Make(code.OpPop),
},
},
@@ -582,7 +582,57 @@ func TestFunctionCalls(t *testing.T) {
code.Make(code.OpConstant, 1), // The compiled function
code.Make(code.OpSetGlobal, 0),
code.Make(code.OpGetGlobal, 0),
code.Make(code.OpCall),
code.Make(code.OpCall, 0),
code.Make(code.OpPop),
},
},
{
input: `
let oneArg = fn(a) { a };
oneArg(24);
`,
expectedConstants: []interface{}{
[]code.Instructions{
code.Make(code.OpGetLocal, 0),
code.Make(code.OpReturnValue),
},
24,
},
expectedInstructions: []code.Instructions{
code.Make(code.OpConstant, 0),
code.Make(code.OpSetGlobal, 0),
code.Make(code.OpGetGlobal, 0),
code.Make(code.OpConstant, 1),
code.Make(code.OpCall, 1),
code.Make(code.OpPop),
},
},
{
input: `
let manyArg = fn(a, b, c) { a; b; c };
manyArg(24, 25, 26);
`,
expectedConstants: []interface{}{
[]code.Instructions{
code.Make(code.OpGetLocal, 0),
code.Make(code.OpPop),
code.Make(code.OpGetLocal, 1),
code.Make(code.OpPop),
code.Make(code.OpGetLocal, 2),
code.Make(code.OpReturnValue),
},
24,
25,
26,
},
expectedInstructions: []code.Instructions{
code.Make(code.OpConstant, 0),
code.Make(code.OpSetGlobal, 0),
code.Make(code.OpGetGlobal, 0),
code.Make(code.OpConstant, 1),
code.Make(code.OpConstant, 2),
code.Make(code.OpConstant, 3),
code.Make(code.OpCall, 3),
code.Make(code.OpPop),
},
},

View File

@@ -37,6 +37,7 @@ type Integer struct {
type CompiledFunction struct {
Instructions code.Instructions
NumLocals int
NumParameters int
}
func (i *Integer) Type() ObjectType {

View File

@@ -203,13 +203,13 @@ func (vm *VM) Run() error {
}
case code.OpCall:
fn, ok := vm.stack[vm.sp-1].(*object.CompiledFunction)
if !ok {
return fmt.Errorf("calling non-function")
numArgs := code.ReadUint8(ins[ip+1:])
vm.currentFrame().ip += 1
err := vm.callFunction(int(numArgs))
if err != nil {
return err
}
frame := NewFrame(fn, vm.sp)
vm.pushFrame(frame)
vm.sp = frame.basePointer + fn.NumLocals
case code.OpReturnValue:
returnValue := vm.pop()
@@ -465,6 +465,23 @@ func (vm *VM) executeMinusOperator() error {
return vm.push(&object.Integer{Value: -value})
}
func (vm *VM) callFunction(numArgs int) error {
fn, ok := vm.stack[vm.sp-1-numArgs].(*object.CompiledFunction)
if !ok {
return fmt.Errorf("calling non-function")
}
if numArgs != fn.NumParameters {
return fmt.Errorf("wrong number of arguments: want=%d, got=%d", fn.NumParameters, numArgs)
}
frame := NewFrame(fn, vm.sp-numArgs)
vm.pushFrame(frame)
vm.sp = frame.basePointer + fn.NumLocals
return nil
}
func nativeBoolToBooleanObject(input bool) *object.Boolean {
if input {
return True

View File

@@ -448,3 +448,110 @@ func TestCallingFunctionsWithBindings(t *testing.T) {
runVmTests(t, tests)
}
func TestCallingFunctionsWithArgumentsAndBindings(t *testing.T) {
tests := []vmTestCase{
{
input: `
let identity = fn(a) { a; };
identity(4);
`,
expected: 4,
},
{
input: `
let sum = fn(a, b) { a + b; };
sum(1, 2);
`,
expected: 3,
},
{
input: `
let sum = fn(a, b) {
let c = a + b;
c;
};
sum(1, 2);
`,
expected: 3,
},
{
input: `
let sum = fn(a, b) {
let c = a + b;
c;
};
sum(1, 2) + sum(3, 4);`,
expected: 10,
},
{
input: `
let sum = fn(a, b) {
let c = a + b;
c;
};
let outer = fn() {
sum(1, 2) + sum(3, 4);
};
outer();
`,
expected: 10,
},
{
input: `
let globalNum = 10;
let sum = fn(a, b) {
let c = a + b;
c + globalNum;
};
let outer = fn() {
sum(1, 2) + sum(3, 4) + globalNum;
};
outer() + globalNum;
`,
expected: 50,
},
}
runVmTests(t, tests)
}
func TestCallingFunctionsWithWrongArguments(t *testing.T) {
tests := []vmTestCase{
{
input: `fn() { 1; }(1);`,
expected: `wrong number of arguments: want=0, got=1`,
},
{
input: `fn(a) { a; }();`,
expected: `wrong number of arguments: want=1, got=0`,
},
{
input: `fn(a, b) { a + b; }(1);`,
expected: `wrong number of arguments: want=2, got=1`,
},
}
for _, tt := range tests {
program := parse(tt.input)
comp := compiler.New()
err := comp.Compile(program)
if err != nil {
t.Fatalf("compiler error: %s", err)
}
vm := New(comp.Bytecode())
err = vm.Run()
if err == nil {
t.Fatalf("expected VM error but resulted in none.")
}
if err.Error() != tt.expected {
t.Fatalf("wrong VM error: want=%q, got=%q", tt.expected, err)
}
}
}