closures and they can recurse!!!
Some checks failed
Build / build (push) Failing after 1m35s
Test / build (push) Has been cancelled

This commit is contained in:
Chuck Smith
2024-03-14 20:08:40 -04:00
parent 78b560e457
commit cc78fee3c8
10 changed files with 640 additions and 39 deletions

View File

@@ -2,6 +2,7 @@ package ast
import ( import (
"bytes" "bytes"
"fmt"
"monkey/token" "monkey/token"
"strings" "strings"
) )
@@ -220,6 +221,7 @@ type FunctionLiteral struct {
Token token.Token // The 'fn' token Token token.Token // The 'fn' token
Parameters []*Identifier Parameters []*Identifier
Body *BlockStatement Body *BlockStatement
Name string
} }
func (fl *FunctionLiteral) expressionNode() {} func (fl *FunctionLiteral) expressionNode() {}
@@ -233,6 +235,9 @@ func (fl *FunctionLiteral) String() string {
} }
out.WriteString(fl.TokenLiteral()) out.WriteString(fl.TokenLiteral())
if fl.Name != "" {
out.WriteString(fmt.Sprintf("<%s>", fl.Name))
}
out.WriteString("(") out.WriteString("(")
out.WriteString(strings.Join(params, ", ")) out.WriteString(strings.Join(params, ", "))
out.WriteString(") ") out.WriteString(") ")

View File

@@ -39,6 +39,8 @@ const (
OpSetLocal OpSetLocal
OpGetBuiltin OpGetBuiltin
OpClosure OpClosure
OpGetFree
OpCurrentClosure
) )
type Definition struct { type Definition struct {
@@ -75,6 +77,8 @@ var definitions = map[Opcode]*Definition{
OpSetLocal: {"OpSetLocal", []int{1}}, OpSetLocal: {"OpSetLocal", []int{1}},
OpGetBuiltin: {"OpGetBuiltin", []int{1}}, OpGetBuiltin: {"OpGetBuiltin", []int{1}},
OpClosure: {"OpClosure", []int{2, 1}}, OpClosure: {"OpClosure", []int{2, 1}},
OpGetFree: {"OpGetFree", []int{1}},
OpCurrentClosure: {"OpCurrentClosure", []int{}},
} }
func Lookup(op byte) (*Definition, error) { func Lookup(op byte) (*Definition, error) {

View File

@@ -196,12 +196,12 @@ func (c *Compiler) Compile(node ast.Node) error {
} }
case *ast.LetStatement: case *ast.LetStatement:
symbol := c.symbolTable.Define(node.Name.Value)
err := c.Compile(node.Value) err := c.Compile(node.Value)
if err != nil { if err != nil {
return err return err
} }
symbol := c.symbolTable.Define(node.Name.Value)
if symbol.Scope == GlobalScope { if symbol.Scope == GlobalScope {
c.emit(code.OpSetGlobal, symbol.Index) c.emit(code.OpSetGlobal, symbol.Index)
} else { } else {
@@ -268,6 +268,10 @@ func (c *Compiler) Compile(node ast.Node) error {
case *ast.FunctionLiteral: case *ast.FunctionLiteral:
c.enterScope() c.enterScope()
if node.Name != "" {
c.symbolTable.DefineFunctionName(node.Name)
}
for _, p := range node.Parameters { for _, p := range node.Parameters {
c.symbolTable.Define(p.Value) c.symbolTable.Define(p.Value)
} }
@@ -284,9 +288,14 @@ func (c *Compiler) Compile(node ast.Node) error {
c.emit(code.OpReturn) c.emit(code.OpReturn)
} }
freeSymbols := c.symbolTable.FreeSymbols
numLocals := c.symbolTable.numDefinitions numLocals := c.symbolTable.numDefinitions
instructions := c.leaveScope() instructions := c.leaveScope()
for _, s := range freeSymbols {
c.loadSymbol(s)
}
compiledFn := &object.CompiledFunction{ compiledFn := &object.CompiledFunction{
Instructions: instructions, Instructions: instructions,
NumLocals: numLocals, NumLocals: numLocals,
@@ -294,7 +303,7 @@ func (c *Compiler) Compile(node ast.Node) error {
} }
fnIndex := c.addConstant(compiledFn) fnIndex := c.addConstant(compiledFn)
c.emit(code.OpClosure, fnIndex, 0) c.emit(code.OpClosure, fnIndex, len(freeSymbols))
case *ast.ReturnStatement: case *ast.ReturnStatement:
err := c.Compile(node.ReturnValue) err := c.Compile(node.ReturnValue)
@@ -439,5 +448,9 @@ func (c *Compiler) loadSymbol(s Symbol) {
c.emit(code.OpGetLocal, s.Index) c.emit(code.OpGetLocal, s.Index)
case BuiltinScope: case BuiltinScope:
c.emit(code.OpGetBuiltin, s.Index) c.emit(code.OpGetBuiltin, s.Index)
case FreeScope:
c.emit(code.OpGetFree, s.Index)
case FunctionScope:
c.emit(code.OpCurrentClosure)
} }
} }

View File

@@ -503,6 +503,132 @@ func TestFunctionsWithoutReturnValue(t *testing.T) {
runCompilerTests(t, tests) runCompilerTests(t, tests)
} }
func TestClosures(t *testing.T) {
tests := []compilerTestCase{
{
input: `fn(a) {
fn(b) {
a + b
}
}
`,
expectedConstants: []interface{}{
[]code.Instructions{
code.Make(code.OpGetFree, 0),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpAdd),
code.Make(code.OpReturnValue),
},
[]code.Instructions{
code.Make(code.OpGetLocal, 0),
code.Make(code.OpClosure, 0, 1),
code.Make(code.OpReturnValue),
},
},
expectedInstructions: []code.Instructions{
code.Make(code.OpClosure, 1, 0),
code.Make(code.OpPop),
},
},
{
input: `
fn(a) {
fn(b) {
fn(c) {
a + b + c
}
}
};
`,
expectedConstants: []interface{}{
[]code.Instructions{
code.Make(code.OpGetFree, 0),
code.Make(code.OpGetFree, 1),
code.Make(code.OpAdd),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpAdd),
code.Make(code.OpReturnValue),
},
[]code.Instructions{
code.Make(code.OpGetFree, 0),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpClosure, 0, 2),
code.Make(code.OpReturnValue),
},
[]code.Instructions{
code.Make(code.OpGetLocal, 0),
code.Make(code.OpClosure, 1, 1),
code.Make(code.OpReturnValue),
},
},
expectedInstructions: []code.Instructions{
code.Make(code.OpClosure, 2, 0),
code.Make(code.OpPop),
},
},
{
input: `
let global = 55;
fn() {
let a = 66;
fn() {
let b = 77;
fn() {
let c = 88;
global + a + b + c;
}
}
}
`,
expectedConstants: []interface{}{
55,
66,
77,
88,
[]code.Instructions{
code.Make(code.OpConstant, 3),
code.Make(code.OpSetLocal, 0),
code.Make(code.OpGetGlobal, 0),
code.Make(code.OpGetFree, 0),
code.Make(code.OpAdd),
code.Make(code.OpGetFree, 1),
code.Make(code.OpAdd),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpAdd),
code.Make(code.OpReturnValue),
},
[]code.Instructions{
code.Make(code.OpConstant, 2),
code.Make(code.OpSetLocal, 0),
code.Make(code.OpGetFree, 0),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpClosure, 4, 2),
code.Make(code.OpReturnValue),
},
[]code.Instructions{
code.Make(code.OpConstant, 1),
code.Make(code.OpSetLocal, 0),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpClosure, 5, 1),
code.Make(code.OpReturnValue),
},
},
expectedInstructions: []code.Instructions{
code.Make(code.OpConstant, 0),
code.Make(code.OpSetGlobal, 0),
code.Make(code.OpClosure, 6, 0),
code.Make(code.OpPop),
},
},
}
runCompilerTests(t, tests)
}
func TestCompilerScopes(t *testing.T) { func TestCompilerScopes(t *testing.T) {
compiler := New() compiler := New()
if compiler.scopeIndex != 0 { if compiler.scopeIndex != 0 {
@@ -755,6 +881,75 @@ func TestBuiltins(t *testing.T) {
runCompilerTests(t, tests) runCompilerTests(t, tests)
} }
func TestRecursiveFunctions(t *testing.T) {
tests := []compilerTestCase{
{
input: `
let countDown = fn(x) { countDown(x - 1); };
countDown(1);
`,
expectedConstants: []interface{}{
1,
[]code.Instructions{
code.Make(code.OpCurrentClosure),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpConstant, 0),
code.Make(code.OpSub),
code.Make(code.OpCall, 1),
code.Make(code.OpReturnValue),
},
1,
},
expectedInstructions: []code.Instructions{
code.Make(code.OpClosure, 1, 0),
code.Make(code.OpSetGlobal, 0),
code.Make(code.OpGetGlobal, 0),
code.Make(code.OpConstant, 2),
code.Make(code.OpCall, 1),
code.Make(code.OpPop),
},
},
{
input: `
let wrapper = fn() {
let countDown = fn(x) { countDown(x - 1); };
countDown(1);
};
wrapper();
`,
expectedConstants: []interface{}{
1,
[]code.Instructions{
code.Make(code.OpCurrentClosure),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpConstant, 0),
code.Make(code.OpSub),
code.Make(code.OpCall, 1),
code.Make(code.OpReturnValue),
},
1,
[]code.Instructions{
code.Make(code.OpClosure, 1, 0),
code.Make(code.OpSetLocal, 0),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpConstant, 2),
code.Make(code.OpCall, 1),
code.Make(code.OpReturnValue),
},
},
expectedInstructions: []code.Instructions{
code.Make(code.OpClosure, 3, 0),
code.Make(code.OpSetGlobal, 0),
code.Make(code.OpGetGlobal, 0),
code.Make(code.OpCall, 0),
code.Make(code.OpPop),
},
},
}
runCompilerTests(t, tests)
}
func runCompilerTests(t *testing.T, tests []compilerTestCase) { func runCompilerTests(t *testing.T, tests []compilerTestCase) {
t.Helper() t.Helper()

View File

@@ -6,6 +6,8 @@ const (
LocalScope SymbolScope = "LOCAL" LocalScope SymbolScope = "LOCAL"
GlobalScope SymbolScope = "GLOBAL" GlobalScope SymbolScope = "GLOBAL"
BuiltinScope SymbolScope = "BUILTIN" BuiltinScope SymbolScope = "BUILTIN"
FreeScope SymbolScope = "FREE"
FunctionScope SymbolScope = "FUNCTION"
) )
type Symbol struct { type Symbol struct {
@@ -19,6 +21,8 @@ type SymbolTable struct {
store map[string]Symbol store map[string]Symbol
numDefinitions int numDefinitions int
FreeSymbols []Symbol
} }
func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable { func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
@@ -29,7 +33,8 @@ func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
func NewSymbolTable() *SymbolTable { func NewSymbolTable() *SymbolTable {
s := make(map[string]Symbol) s := make(map[string]Symbol)
return &SymbolTable{store: s} free := []Symbol{}
return &SymbolTable{store: s, FreeSymbols: free}
} }
func (s *SymbolTable) Define(name string) Symbol { func (s *SymbolTable) Define(name string) Symbol {
@@ -49,9 +54,19 @@ func (s *SymbolTable) Resolve(name string) (Symbol, bool) {
obj, ok := s.store[name] obj, ok := s.store[name]
if !ok && s.Outer != nil { if !ok && s.Outer != nil {
obj, ok = s.Outer.Resolve(name) obj, ok = s.Outer.Resolve(name)
if !ok {
return obj, ok return obj, ok
} }
if obj.Scope == GlobalScope || obj.Scope == BuiltinScope {
return obj, ok
}
free := s.DefineFree(obj)
return free, true
}
return obj, ok return obj, ok
} }
@@ -60,3 +75,19 @@ func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol {
s.store[name] = symbol s.store[name] = symbol
return symbol return symbol
} }
func (s *SymbolTable) DefineFree(original Symbol) Symbol {
s.FreeSymbols = append(s.FreeSymbols, original)
symbol := Symbol{Name: original.Name, Index: len(s.FreeSymbols) - 1}
symbol.Scope = FreeScope
s.store[original.Name] = symbol
return symbol
}
func (s *SymbolTable) DefineFunctionName(name string) Symbol {
symbol := Symbol{Name: name, Index: 0, Scope: FunctionScope}
s.store[name] = symbol
return symbol
}

View File

@@ -197,3 +197,154 @@ func TestDefineResolveBuiltins(t *testing.T) {
} }
} }
} }
func TestResolveFree(t *testing.T) {
global := NewSymbolTable()
global.Define("a")
global.Define("b")
firstLocal := NewEnclosedSymbolTable(global)
firstLocal.Define("c")
firstLocal.Define("d")
secondLocal := NewEnclosedSymbolTable(firstLocal)
secondLocal.Define("e")
secondLocal.Define("f")
tests := []struct {
table *SymbolTable
expectedSymbols []Symbol
expectedFreeSymbols []Symbol
}{
{
firstLocal,
[]Symbol{
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
Symbol{Name: "c", Scope: LocalScope, Index: 0},
Symbol{Name: "d", Scope: LocalScope, Index: 1},
},
[]Symbol{},
},
{
secondLocal,
[]Symbol{
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
Symbol{Name: "c", Scope: FreeScope, Index: 0},
Symbol{Name: "d", Scope: FreeScope, Index: 1},
Symbol{Name: "e", Scope: LocalScope, Index: 0},
Symbol{Name: "f", Scope: LocalScope, Index: 1},
},
[]Symbol{
Symbol{Name: "c", Scope: LocalScope, Index: 0},
Symbol{Name: "d", Scope: LocalScope, Index: 1},
},
},
}
for _, tt := range tests {
for _, sym := range tt.expectedSymbols {
result, ok := tt.table.Resolve(sym.Name)
if !ok {
t.Errorf("name %s not resolvable", sym.Name)
continue
}
if result != sym {
t.Errorf("expected %s to resolve to %+v, got=%+v",
sym.Name, sym, result)
}
}
if len(tt.table.FreeSymbols) != len(tt.expectedFreeSymbols) {
t.Errorf("wrong number of free symbols. got=%d, want=%d",
len(tt.table.FreeSymbols), len(tt.expectedFreeSymbols))
continue
}
for i, sym := range tt.expectedFreeSymbols {
result := tt.table.FreeSymbols[i]
if result != sym {
t.Errorf("wrong free symbol. got=%+v, want=%+v",
result, sym)
}
}
}
}
func TestResolveUnresolvableFree(t *testing.T) {
global := NewSymbolTable()
global.Define("a")
firstLocal := NewEnclosedSymbolTable(global)
firstLocal.Define("c")
secondLocal := NewEnclosedSymbolTable(firstLocal)
secondLocal.Define("e")
secondLocal.Define("f")
expected := []Symbol{
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
Symbol{Name: "c", Scope: FreeScope, Index: 0},
Symbol{Name: "e", Scope: LocalScope, Index: 0},
Symbol{Name: "f", Scope: LocalScope, Index: 1},
}
for _, sym := range expected {
result, ok := secondLocal.Resolve(sym.Name)
if !ok {
t.Errorf("name %s not resolvable", sym.Name)
continue
}
if result != sym {
t.Errorf("expected %s to resolve to %+v, got=%+v",
sym.Name, sym, result)
}
}
expectedUnresolvable := []string{
"b",
"d",
}
for _, name := range expectedUnresolvable {
_, ok := secondLocal.Resolve(name)
if ok {
t.Errorf("name %s resolved, but was expected not to", name)
}
}
}
func TestDefineAndResolveFunctionName(t *testing.T) {
global := NewSymbolTable()
global.DefineFunctionName("a")
expected := Symbol{Name: "a", Scope: FunctionScope, Index: 0}
result, ok := global.Resolve(expected.Name)
if !ok {
t.Fatalf("function name %s not resolvable", expected.Name)
}
if result != expected {
t.Errorf("expected %s to resolve to %+v, got=%+v",
expected.Name, expected, result)
}
}
func TestShadowingFunctionName(t *testing.T) {
global := NewSymbolTable()
global.DefineFunctionName("a")
global.Define("a")
expected := Symbol{Name: "a", Scope: GlobalScope, Index: 0}
result, ok := global.Resolve(expected.Name)
if !ok {
t.Fatalf("function name %s not resolvable", expected.Name)
}
if result != expected {
t.Errorf("expected %s to resolve to %+v, got=%+v",
expected.Name, expected, result)
}
}

View File

@@ -169,6 +169,10 @@ func (p *Parser) parseLetStatement() *ast.LetStatement {
stmt.Value = p.parseExpression(LOWEST) stmt.Value = p.parseExpression(LOWEST)
if fl, ok := stmt.Value.(*ast.FunctionLiteral); ok {
fl.Name = stmt.Name.Value
}
if p.peekTokenIs(token.SEMICOLON) { if p.peekTokenIs(token.SEMICOLON) {
p.nextToken() p.nextToken()
} }

View File

@@ -845,6 +845,37 @@ func TestParsingHashLiteralsWithExpressions(t *testing.T) {
} }
} }
func TestFunctionLiteralWithName(t *testing.T) {
input := `let myFunction = fn() { };`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program.Body does not contain %d statements. got=%d\n",
1, len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.LetStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.LetStatement. got=%T",
program.Statements[0])
}
function, ok := stmt.Value.(*ast.FunctionLiteral)
if !ok {
t.Fatalf("stmt.Value is not ast.FunctionLiteral. got=%T",
stmt.Value)
}
if function.Name != "myFunction" {
t.Fatalf("function literal name wrong. want 'myFunction', got=%q\n",
function.Name)
}
}
func testLetStatement(t *testing.T, s ast.Statement, name string) bool { func testLetStatement(t *testing.T, s ast.Statement, name string) bool {
if s.TokenLiteral() != "let" { if s.TokenLiteral() != "let" {
t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral()) t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral())

View File

@@ -264,10 +264,27 @@ func (vm *VM) Run() error {
case code.OpClosure: case code.OpClosure:
constIndex := code.ReadUint16(ins[ip+1:]) constIndex := code.ReadUint16(ins[ip+1:])
_ = code.ReadUint8(ins[ip+3:]) numFree := code.ReadUint8(ins[ip+3:])
vm.currentFrame().ip += 3 vm.currentFrame().ip += 3
err := vm.pushClosure(int(constIndex)) err := vm.pushClosure(int(constIndex), int(numFree))
if err != nil {
return err
}
case code.OpGetFree:
freeIndex := code.ReadUint8(ins[ip+1:])
vm.currentFrame().ip += 1
currentClosure := vm.currentFrame().cl
err := vm.push(currentClosure.Free[freeIndex])
if err != nil {
return err
}
case code.OpCurrentClosure:
currentClosure := vm.currentFrame().cl
err := vm.push(currentClosure)
if err != nil { if err != nil {
return err return err
} }
@@ -533,14 +550,20 @@ func (vm *VM) callBuiltin(builtin *object.Builtin, numArgs int) error {
return nil return nil
} }
func (vm *VM) pushClosure(constIndex int) error { func (vm *VM) pushClosure(constIndex, numFree int) error {
constant := vm.constants[constIndex] constant := vm.constants[constIndex]
function, ok := constant.(*object.CompiledFunction) function, ok := constant.(*object.CompiledFunction)
if !ok { if !ok {
return fmt.Errorf("not a function %+v", constant) return fmt.Errorf("not a function %+v", constant)
} }
closure := &object.Closure{Fn: function} free := make([]object.Object, numFree)
for i := 0; i < numFree; i++ {
free[i] = vm.stack[vm.sp-numFree+i]
}
vm.sp = vm.sp - numFree
closure := &object.Closure{Fn: function, Free: free}
return vm.push(closure) return vm.push(closure)
} }

View File

@@ -27,6 +27,19 @@ func runVmTests(t *testing.T, tests []vmTestCase) {
t.Fatalf("compiler error: %s", err) t.Fatalf("compiler error: %s", err)
} }
//for i, constant := range comp.Bytecode().Constants {
// fmt.Printf("CONSTANT %d %p (%T):\n", i, constant, constant)
//
// switch constant := constant.(type) {
// case *object.CompiledFunction:
// fmt.Printf(" Instructions:\n%s", constant.Instructions)
// case *object.Integer:
// fmt.Printf(" Value: %d\n", constant.Value)
// }
//
// fmt.Printf("\n")
//}
vm := New(comp.Bytecode()) vm := New(comp.Bytecode())
err = vm.Run() err = vm.Run()
if err != nil { if err != nil {
@@ -611,3 +624,134 @@ func TestBuiltinFunctions(t *testing.T) {
runVmTests(t, tests) runVmTests(t, tests)
} }
func TestClosures(t *testing.T) {
tests := []vmTestCase{
{
input: `
let newClosure = fn(a) {
fn() { a; };
};
let closure = newClosure(99);
closure();
`,
expected: 99,
},
{
input: `
let newAdder = fn(a, b) {
fn(c) { a + b + c };
};
let adder = newAdder(1, 2);
adder(8);
`,
expected: 11,
},
{
input: `
let newAdder = fn(a, b) {
let c = a + b;
fn(d) { c + d };
};
let adder = newAdder(1, 2);
adder(8);
`,
expected: 11,
},
{
input: `
let newAdderOuter = fn(a, b) {
let c = a + b;
fn(d) {
let e = d + c;
fn(f) { e + f; };
};
};
let newAdderInner = newAdderOuter(1, 2)
let adder = newAdderInner(3);
adder(8);
`,
expected: 14,
},
{
input: `
let a = 1;
let newAdderOuter = fn(b) {
fn(c) {
fn(d) { a + b + c + d };
};
};
let newAdderInner = newAdderOuter(2)
let adder = newAdderInner(3);
adder(8);
`,
expected: 14,
},
{
input: `
let newClosure = fn(a, b) {
let one = fn() { a; };
let two = fn() { b; };
fn() { one() + two(); };
};
let closure = newClosure(9, 90);
closure();
`,
expected: 99,
},
}
runVmTests(t, tests)
}
func TestRecursiveFunctions(t *testing.T) {
tests := []vmTestCase{
{
input: `
let countDown = fn(x) {
if (x == 0) {
return 0;
} else {
countDown(x - 1);
}
};
countDown(1);
`,
expected: 0,
},
{
input: `
let countDown = fn(x) {
if (x == 0) {
return 0;
} else {
countDown(x - 1);
}
};
let wrapper = fn() {
countDown(1);
};
wrapper();
`,
expected: 0,
},
{
input: `
let wrapper = fn() {
let countDown = fn(x) {
if (x == 0) {
return 0;
} else {
countDown(x - 1);
}
};
countDown(1);
};
wrapper();
`,
expected: 0,
},
}
runVmTests(t, tests)
}