diff --git a/ast/ast.go b/ast/ast.go index 8b4ac2a..2ae2b6c 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -2,6 +2,7 @@ package ast import ( "bytes" + "fmt" "monkey/token" "strings" ) @@ -220,6 +221,7 @@ type FunctionLiteral struct { Token token.Token // The 'fn' token Parameters []*Identifier Body *BlockStatement + Name string } func (fl *FunctionLiteral) expressionNode() {} @@ -233,6 +235,9 @@ func (fl *FunctionLiteral) String() string { } out.WriteString(fl.TokenLiteral()) + if fl.Name != "" { + out.WriteString(fmt.Sprintf("<%s>", fl.Name)) + } out.WriteString("(") out.WriteString(strings.Join(params, ", ")) out.WriteString(") ") diff --git a/code/code.go b/code/code.go index 764e660..2c730fd 100644 --- a/code/code.go +++ b/code/code.go @@ -39,6 +39,8 @@ const ( OpSetLocal OpGetBuiltin OpClosure + OpGetFree + OpCurrentClosure ) type Definition struct { @@ -47,34 +49,36 @@ type Definition struct { } var definitions = map[Opcode]*Definition{ - OpConstant: {"OpConstant", []int{2}}, - OpAdd: {"OpAdd", []int{}}, - OpPop: {"OpPop", []int{}}, - OpSub: {"OpSub", []int{}}, - OpMul: {"OpMul", []int{}}, - OpDiv: {"OpDiv", []int{}}, - OpTrue: {"OpTrue", []int{}}, - OpFalse: {"OpFalse", []int{}}, - OpEqual: {"OpEqual", []int{}}, - OpNotEqual: {"OpNotEqual", []int{}}, - OpGreaterThan: {"OpGreaterThan", []int{}}, - OpMinus: {"OpMinus", []int{}}, - OpBang: {"OpBang", []int{}}, - OpJumpNotTruthy: {"OpJumpNotTruthy", []int{2}}, - OpJump: {"OpJump", []int{2}}, - OpNull: {"OpNull", []int{}}, - OpGetGlobal: {"OpGetGlobal", []int{2}}, - OpSetGlobal: {"OpSetGlobal", []int{2}}, - OpArray: {"OpArray", []int{2}}, - OpHash: {"OpHash", []int{2}}, - OpIndex: {"OpIndex", []int{}}, - OpCall: {"OpCall", []int{1}}, - OpReturnValue: {"OpReturnValue", []int{}}, - OpReturn: {"OpReturn", []int{}}, - OpGetLocal: {"OpGetLocal", []int{1}}, - OpSetLocal: {"OpSetLocal", []int{1}}, - OpGetBuiltin: {"OpGetBuiltin", []int{1}}, - OpClosure: {"OpClosure", []int{2, 1}}, + OpConstant: {"OpConstant", []int{2}}, + OpAdd: {"OpAdd", []int{}}, + OpPop: {"OpPop", []int{}}, + OpSub: {"OpSub", []int{}}, + OpMul: {"OpMul", []int{}}, + OpDiv: {"OpDiv", []int{}}, + OpTrue: {"OpTrue", []int{}}, + OpFalse: {"OpFalse", []int{}}, + OpEqual: {"OpEqual", []int{}}, + OpNotEqual: {"OpNotEqual", []int{}}, + OpGreaterThan: {"OpGreaterThan", []int{}}, + OpMinus: {"OpMinus", []int{}}, + OpBang: {"OpBang", []int{}}, + OpJumpNotTruthy: {"OpJumpNotTruthy", []int{2}}, + OpJump: {"OpJump", []int{2}}, + OpNull: {"OpNull", []int{}}, + OpGetGlobal: {"OpGetGlobal", []int{2}}, + OpSetGlobal: {"OpSetGlobal", []int{2}}, + OpArray: {"OpArray", []int{2}}, + OpHash: {"OpHash", []int{2}}, + OpIndex: {"OpIndex", []int{}}, + OpCall: {"OpCall", []int{1}}, + OpReturnValue: {"OpReturnValue", []int{}}, + OpReturn: {"OpReturn", []int{}}, + OpGetLocal: {"OpGetLocal", []int{1}}, + OpSetLocal: {"OpSetLocal", []int{1}}, + OpGetBuiltin: {"OpGetBuiltin", []int{1}}, + OpClosure: {"OpClosure", []int{2, 1}}, + OpGetFree: {"OpGetFree", []int{1}}, + OpCurrentClosure: {"OpCurrentClosure", []int{}}, } func Lookup(op byte) (*Definition, error) { diff --git a/compiler/compiler.go b/compiler/compiler.go index 9dc48af..9657355 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -196,12 +196,12 @@ func (c *Compiler) Compile(node ast.Node) error { } case *ast.LetStatement: + symbol := c.symbolTable.Define(node.Name.Value) err := c.Compile(node.Value) if err != nil { return err } - symbol := c.symbolTable.Define(node.Name.Value) if symbol.Scope == GlobalScope { c.emit(code.OpSetGlobal, symbol.Index) } else { @@ -268,6 +268,10 @@ func (c *Compiler) Compile(node ast.Node) error { case *ast.FunctionLiteral: c.enterScope() + if node.Name != "" { + c.symbolTable.DefineFunctionName(node.Name) + } + for _, p := range node.Parameters { c.symbolTable.Define(p.Value) } @@ -284,9 +288,14 @@ func (c *Compiler) Compile(node ast.Node) error { c.emit(code.OpReturn) } + freeSymbols := c.symbolTable.FreeSymbols numLocals := c.symbolTable.numDefinitions instructions := c.leaveScope() + for _, s := range freeSymbols { + c.loadSymbol(s) + } + compiledFn := &object.CompiledFunction{ Instructions: instructions, NumLocals: numLocals, @@ -294,7 +303,7 @@ func (c *Compiler) Compile(node ast.Node) error { } fnIndex := c.addConstant(compiledFn) - c.emit(code.OpClosure, fnIndex, 0) + c.emit(code.OpClosure, fnIndex, len(freeSymbols)) case *ast.ReturnStatement: err := c.Compile(node.ReturnValue) @@ -439,5 +448,9 @@ func (c *Compiler) loadSymbol(s Symbol) { c.emit(code.OpGetLocal, s.Index) case BuiltinScope: c.emit(code.OpGetBuiltin, s.Index) + case FreeScope: + c.emit(code.OpGetFree, s.Index) + case FunctionScope: + c.emit(code.OpCurrentClosure) } } diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index bbd131b..9845ce9 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -503,6 +503,132 @@ func TestFunctionsWithoutReturnValue(t *testing.T) { runCompilerTests(t, tests) } +func TestClosures(t *testing.T) { + tests := []compilerTestCase{ + { + input: `fn(a) { + fn(b) { + a + b + } + } + `, + expectedConstants: []interface{}{ + []code.Instructions{ + code.Make(code.OpGetFree, 0), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpAdd), + code.Make(code.OpReturnValue), + }, + []code.Instructions{ + code.Make(code.OpGetLocal, 0), + code.Make(code.OpClosure, 0, 1), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpClosure, 1, 0), + code.Make(code.OpPop), + }, + }, + { + input: ` + fn(a) { + fn(b) { + fn(c) { + a + b + c + } + } + }; + `, + expectedConstants: []interface{}{ + []code.Instructions{ + code.Make(code.OpGetFree, 0), + code.Make(code.OpGetFree, 1), + code.Make(code.OpAdd), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpAdd), + code.Make(code.OpReturnValue), + }, + []code.Instructions{ + code.Make(code.OpGetFree, 0), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpClosure, 0, 2), + code.Make(code.OpReturnValue), + }, + []code.Instructions{ + code.Make(code.OpGetLocal, 0), + code.Make(code.OpClosure, 1, 1), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpClosure, 2, 0), + code.Make(code.OpPop), + }, + }, + { + input: ` + let global = 55; + + fn() { + let a = 66; + + fn() { + let b = 77; + + fn() { + let c = 88; + + global + a + b + c; + } + } + } + `, + expectedConstants: []interface{}{ + 55, + 66, + 77, + 88, + []code.Instructions{ + code.Make(code.OpConstant, 3), + code.Make(code.OpSetLocal, 0), + code.Make(code.OpGetGlobal, 0), + code.Make(code.OpGetFree, 0), + code.Make(code.OpAdd), + code.Make(code.OpGetFree, 1), + code.Make(code.OpAdd), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpAdd), + code.Make(code.OpReturnValue), + }, + []code.Instructions{ + code.Make(code.OpConstant, 2), + code.Make(code.OpSetLocal, 0), + code.Make(code.OpGetFree, 0), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpClosure, 4, 2), + code.Make(code.OpReturnValue), + }, + []code.Instructions{ + code.Make(code.OpConstant, 1), + code.Make(code.OpSetLocal, 0), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpClosure, 5, 1), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetGlobal, 0), + code.Make(code.OpClosure, 6, 0), + code.Make(code.OpPop), + }, + }, + } + + runCompilerTests(t, tests) +} + func TestCompilerScopes(t *testing.T) { compiler := New() if compiler.scopeIndex != 0 { @@ -755,6 +881,75 @@ func TestBuiltins(t *testing.T) { runCompilerTests(t, tests) } +func TestRecursiveFunctions(t *testing.T) { + tests := []compilerTestCase{ + { + input: ` + let countDown = fn(x) { countDown(x - 1); }; + countDown(1); + `, + expectedConstants: []interface{}{ + 1, + []code.Instructions{ + code.Make(code.OpCurrentClosure), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpConstant, 0), + code.Make(code.OpSub), + code.Make(code.OpCall, 1), + code.Make(code.OpReturnValue), + }, + 1, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpClosure, 1, 0), + code.Make(code.OpSetGlobal, 0), + code.Make(code.OpGetGlobal, 0), + code.Make(code.OpConstant, 2), + code.Make(code.OpCall, 1), + code.Make(code.OpPop), + }, + }, + { + input: ` + let wrapper = fn() { + let countDown = fn(x) { countDown(x - 1); }; + countDown(1); + }; + wrapper(); + `, + expectedConstants: []interface{}{ + 1, + []code.Instructions{ + code.Make(code.OpCurrentClosure), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpConstant, 0), + code.Make(code.OpSub), + code.Make(code.OpCall, 1), + code.Make(code.OpReturnValue), + }, + 1, + []code.Instructions{ + code.Make(code.OpClosure, 1, 0), + code.Make(code.OpSetLocal, 0), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpConstant, 2), + code.Make(code.OpCall, 1), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpClosure, 3, 0), + code.Make(code.OpSetGlobal, 0), + code.Make(code.OpGetGlobal, 0), + code.Make(code.OpCall, 0), + code.Make(code.OpPop), + }, + }, + } + + runCompilerTests(t, tests) +} + func runCompilerTests(t *testing.T, tests []compilerTestCase) { t.Helper() diff --git a/compiler/symbol_table.go b/compiler/symbol_table.go index 11d45af..3aabc81 100644 --- a/compiler/symbol_table.go +++ b/compiler/symbol_table.go @@ -3,9 +3,11 @@ package compiler type SymbolScope string const ( - LocalScope SymbolScope = "LOCAL" - GlobalScope SymbolScope = "GLOBAL" - BuiltinScope SymbolScope = "BUILTIN" + LocalScope SymbolScope = "LOCAL" + GlobalScope SymbolScope = "GLOBAL" + BuiltinScope SymbolScope = "BUILTIN" + FreeScope SymbolScope = "FREE" + FunctionScope SymbolScope = "FUNCTION" ) type Symbol struct { @@ -19,6 +21,8 @@ type SymbolTable struct { store map[string]Symbol numDefinitions int + + FreeSymbols []Symbol } func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable { @@ -29,7 +33,8 @@ func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable { func NewSymbolTable() *SymbolTable { s := make(map[string]Symbol) - return &SymbolTable{store: s} + free := []Symbol{} + return &SymbolTable{store: s, FreeSymbols: free} } func (s *SymbolTable) Define(name string) Symbol { @@ -49,7 +54,17 @@ func (s *SymbolTable) Resolve(name string) (Symbol, bool) { obj, ok := s.store[name] if !ok && s.Outer != nil { obj, ok = s.Outer.Resolve(name) - return obj, ok + if !ok { + return obj, ok + } + + if obj.Scope == GlobalScope || obj.Scope == BuiltinScope { + return obj, ok + } + + free := s.DefineFree(obj) + return free, true + } return obj, ok @@ -60,3 +75,19 @@ func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol { s.store[name] = symbol return symbol } + +func (s *SymbolTable) DefineFree(original Symbol) Symbol { + s.FreeSymbols = append(s.FreeSymbols, original) + + symbol := Symbol{Name: original.Name, Index: len(s.FreeSymbols) - 1} + symbol.Scope = FreeScope + + s.store[original.Name] = symbol + return symbol +} + +func (s *SymbolTable) DefineFunctionName(name string) Symbol { + symbol := Symbol{Name: name, Index: 0, Scope: FunctionScope} + s.store[name] = symbol + return symbol +} diff --git a/compiler/symbol_table_test.go b/compiler/symbol_table_test.go index 6b7f74a..d75dcd7 100644 --- a/compiler/symbol_table_test.go +++ b/compiler/symbol_table_test.go @@ -197,3 +197,154 @@ func TestDefineResolveBuiltins(t *testing.T) { } } } + +func TestResolveFree(t *testing.T) { + global := NewSymbolTable() + global.Define("a") + global.Define("b") + + firstLocal := NewEnclosedSymbolTable(global) + firstLocal.Define("c") + firstLocal.Define("d") + + secondLocal := NewEnclosedSymbolTable(firstLocal) + secondLocal.Define("e") + secondLocal.Define("f") + + tests := []struct { + table *SymbolTable + expectedSymbols []Symbol + expectedFreeSymbols []Symbol + }{ + { + firstLocal, + []Symbol{ + Symbol{Name: "a", Scope: GlobalScope, Index: 0}, + Symbol{Name: "b", Scope: GlobalScope, Index: 1}, + Symbol{Name: "c", Scope: LocalScope, Index: 0}, + Symbol{Name: "d", Scope: LocalScope, Index: 1}, + }, + []Symbol{}, + }, + { + secondLocal, + []Symbol{ + Symbol{Name: "a", Scope: GlobalScope, Index: 0}, + Symbol{Name: "b", Scope: GlobalScope, Index: 1}, + Symbol{Name: "c", Scope: FreeScope, Index: 0}, + Symbol{Name: "d", Scope: FreeScope, Index: 1}, + Symbol{Name: "e", Scope: LocalScope, Index: 0}, + Symbol{Name: "f", Scope: LocalScope, Index: 1}, + }, + []Symbol{ + Symbol{Name: "c", Scope: LocalScope, Index: 0}, + Symbol{Name: "d", Scope: LocalScope, Index: 1}, + }, + }, + } + + for _, tt := range tests { + for _, sym := range tt.expectedSymbols { + result, ok := tt.table.Resolve(sym.Name) + if !ok { + t.Errorf("name %s not resolvable", sym.Name) + continue + } + if result != sym { + t.Errorf("expected %s to resolve to %+v, got=%+v", + sym.Name, sym, result) + } + } + + if len(tt.table.FreeSymbols) != len(tt.expectedFreeSymbols) { + t.Errorf("wrong number of free symbols. got=%d, want=%d", + len(tt.table.FreeSymbols), len(tt.expectedFreeSymbols)) + continue + } + + for i, sym := range tt.expectedFreeSymbols { + result := tt.table.FreeSymbols[i] + if result != sym { + t.Errorf("wrong free symbol. got=%+v, want=%+v", + result, sym) + } + } + } +} + +func TestResolveUnresolvableFree(t *testing.T) { + global := NewSymbolTable() + global.Define("a") + + firstLocal := NewEnclosedSymbolTable(global) + firstLocal.Define("c") + + secondLocal := NewEnclosedSymbolTable(firstLocal) + secondLocal.Define("e") + secondLocal.Define("f") + + expected := []Symbol{ + Symbol{Name: "a", Scope: GlobalScope, Index: 0}, + Symbol{Name: "c", Scope: FreeScope, Index: 0}, + Symbol{Name: "e", Scope: LocalScope, Index: 0}, + Symbol{Name: "f", Scope: LocalScope, Index: 1}, + } + + for _, sym := range expected { + result, ok := secondLocal.Resolve(sym.Name) + if !ok { + t.Errorf("name %s not resolvable", sym.Name) + continue + } + if result != sym { + t.Errorf("expected %s to resolve to %+v, got=%+v", + sym.Name, sym, result) + } + } + + expectedUnresolvable := []string{ + "b", + "d", + } + + for _, name := range expectedUnresolvable { + _, ok := secondLocal.Resolve(name) + if ok { + t.Errorf("name %s resolved, but was expected not to", name) + } + } +} + +func TestDefineAndResolveFunctionName(t *testing.T) { + global := NewSymbolTable() + global.DefineFunctionName("a") + + expected := Symbol{Name: "a", Scope: FunctionScope, Index: 0} + + result, ok := global.Resolve(expected.Name) + if !ok { + t.Fatalf("function name %s not resolvable", expected.Name) + } + + if result != expected { + t.Errorf("expected %s to resolve to %+v, got=%+v", + expected.Name, expected, result) + } +} +func TestShadowingFunctionName(t *testing.T) { + global := NewSymbolTable() + global.DefineFunctionName("a") + global.Define("a") + + expected := Symbol{Name: "a", Scope: GlobalScope, Index: 0} + + result, ok := global.Resolve(expected.Name) + if !ok { + t.Fatalf("function name %s not resolvable", expected.Name) + } + + if result != expected { + t.Errorf("expected %s to resolve to %+v, got=%+v", + expected.Name, expected, result) + } +} diff --git a/parser/parser.go b/parser/parser.go index 5f03233..bdf9fae 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -169,6 +169,10 @@ func (p *Parser) parseLetStatement() *ast.LetStatement { stmt.Value = p.parseExpression(LOWEST) + if fl, ok := stmt.Value.(*ast.FunctionLiteral); ok { + fl.Name = stmt.Name.Value + } + if p.peekTokenIs(token.SEMICOLON) { p.nextToken() } diff --git a/parser/parser_test.go b/parser/parser_test.go index 71d2928..5b2f800 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -845,6 +845,37 @@ func TestParsingHashLiteralsWithExpressions(t *testing.T) { } } +func TestFunctionLiteralWithName(t *testing.T) { + input := `let myFunction = fn() { };` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Body does not contain %d statements. got=%d\n", + 1, len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*ast.LetStatement) + if !ok { + t.Fatalf("program.Statements[0] is not ast.LetStatement. got=%T", + program.Statements[0]) + } + + function, ok := stmt.Value.(*ast.FunctionLiteral) + if !ok { + t.Fatalf("stmt.Value is not ast.FunctionLiteral. got=%T", + stmt.Value) + } + + if function.Name != "myFunction" { + t.Fatalf("function literal name wrong. want 'myFunction', got=%q\n", + function.Name) + } +} + func testLetStatement(t *testing.T, s ast.Statement, name string) bool { if s.TokenLiteral() != "let" { t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral()) diff --git a/vm/vm.go b/vm/vm.go index ec7df47..0c093be 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -264,10 +264,27 @@ func (vm *VM) Run() error { case code.OpClosure: constIndex := code.ReadUint16(ins[ip+1:]) - _ = code.ReadUint8(ins[ip+3:]) + numFree := code.ReadUint8(ins[ip+3:]) vm.currentFrame().ip += 3 - err := vm.pushClosure(int(constIndex)) + err := vm.pushClosure(int(constIndex), int(numFree)) + if err != nil { + return err + } + + case code.OpGetFree: + freeIndex := code.ReadUint8(ins[ip+1:]) + vm.currentFrame().ip += 1 + + currentClosure := vm.currentFrame().cl + err := vm.push(currentClosure.Free[freeIndex]) + if err != nil { + return err + } + + case code.OpCurrentClosure: + currentClosure := vm.currentFrame().cl + err := vm.push(currentClosure) if err != nil { return err } @@ -533,14 +550,20 @@ func (vm *VM) callBuiltin(builtin *object.Builtin, numArgs int) error { return nil } -func (vm *VM) pushClosure(constIndex int) error { +func (vm *VM) pushClosure(constIndex, numFree int) error { constant := vm.constants[constIndex] function, ok := constant.(*object.CompiledFunction) if !ok { return fmt.Errorf("not a function %+v", constant) } - closure := &object.Closure{Fn: function} + free := make([]object.Object, numFree) + for i := 0; i < numFree; i++ { + free[i] = vm.stack[vm.sp-numFree+i] + } + vm.sp = vm.sp - numFree + + closure := &object.Closure{Fn: function, Free: free} return vm.push(closure) } diff --git a/vm/vm_test.go b/vm/vm_test.go index 0cbdb60..395eca4 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -27,6 +27,19 @@ func runVmTests(t *testing.T, tests []vmTestCase) { t.Fatalf("compiler error: %s", err) } + //for i, constant := range comp.Bytecode().Constants { + // fmt.Printf("CONSTANT %d %p (%T):\n", i, constant, constant) + // + // switch constant := constant.(type) { + // case *object.CompiledFunction: + // fmt.Printf(" Instructions:\n%s", constant.Instructions) + // case *object.Integer: + // fmt.Printf(" Value: %d\n", constant.Value) + // } + // + // fmt.Printf("\n") + //} + vm := New(comp.Bytecode()) err = vm.Run() if err != nil { @@ -611,3 +624,134 @@ func TestBuiltinFunctions(t *testing.T) { runVmTests(t, tests) } + +func TestClosures(t *testing.T) { + tests := []vmTestCase{ + { + input: ` + let newClosure = fn(a) { + fn() { a; }; + }; + let closure = newClosure(99); + closure(); + `, + expected: 99, + }, + { + input: ` + let newAdder = fn(a, b) { + fn(c) { a + b + c }; + }; + let adder = newAdder(1, 2); + adder(8); + `, + expected: 11, + }, + { + input: ` + let newAdder = fn(a, b) { + let c = a + b; + fn(d) { c + d }; + }; + let adder = newAdder(1, 2); + adder(8); + `, + expected: 11, + }, + { + input: ` + let newAdderOuter = fn(a, b) { + let c = a + b; + fn(d) { + let e = d + c; + fn(f) { e + f; }; + }; + }; + let newAdderInner = newAdderOuter(1, 2) + let adder = newAdderInner(3); + adder(8); + `, + expected: 14, + }, + { + input: ` + let a = 1; + let newAdderOuter = fn(b) { + fn(c) { + fn(d) { a + b + c + d }; + }; + }; + let newAdderInner = newAdderOuter(2) + let adder = newAdderInner(3); + adder(8); + `, + expected: 14, + }, + { + input: ` + let newClosure = fn(a, b) { + let one = fn() { a; }; + let two = fn() { b; }; + fn() { one() + two(); }; + }; + let closure = newClosure(9, 90); + closure(); + `, + expected: 99, + }, + } + + runVmTests(t, tests) +} + +func TestRecursiveFunctions(t *testing.T) { + tests := []vmTestCase{ + { + input: ` + let countDown = fn(x) { + if (x == 0) { + return 0; + } else { + countDown(x - 1); + } + }; + countDown(1); + `, + expected: 0, + }, + { + input: ` + let countDown = fn(x) { + if (x == 0) { + return 0; + } else { + countDown(x - 1); + } + }; + let wrapper = fn() { + countDown(1); + }; + wrapper(); + `, + expected: 0, + }, + { + input: ` + let wrapper = fn() { + let countDown = fn(x) { + if (x == 0) { + return 0; + } else { + countDown(x - 1); + } + }; + countDown(1); + }; + wrapper(); + `, + expected: 0, + }, + } + + runVmTests(t, tests) +}