closures and they can recurse!!!
This commit is contained in:
@@ -196,12 +196,12 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
}
|
||||
|
||||
case *ast.LetStatement:
|
||||
symbol := c.symbolTable.Define(node.Name.Value)
|
||||
err := c.Compile(node.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
symbol := c.symbolTable.Define(node.Name.Value)
|
||||
if symbol.Scope == GlobalScope {
|
||||
c.emit(code.OpSetGlobal, symbol.Index)
|
||||
} else {
|
||||
@@ -268,6 +268,10 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
case *ast.FunctionLiteral:
|
||||
c.enterScope()
|
||||
|
||||
if node.Name != "" {
|
||||
c.symbolTable.DefineFunctionName(node.Name)
|
||||
}
|
||||
|
||||
for _, p := range node.Parameters {
|
||||
c.symbolTable.Define(p.Value)
|
||||
}
|
||||
@@ -284,9 +288,14 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
c.emit(code.OpReturn)
|
||||
}
|
||||
|
||||
freeSymbols := c.symbolTable.FreeSymbols
|
||||
numLocals := c.symbolTable.numDefinitions
|
||||
instructions := c.leaveScope()
|
||||
|
||||
for _, s := range freeSymbols {
|
||||
c.loadSymbol(s)
|
||||
}
|
||||
|
||||
compiledFn := &object.CompiledFunction{
|
||||
Instructions: instructions,
|
||||
NumLocals: numLocals,
|
||||
@@ -294,7 +303,7 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
}
|
||||
|
||||
fnIndex := c.addConstant(compiledFn)
|
||||
c.emit(code.OpClosure, fnIndex, 0)
|
||||
c.emit(code.OpClosure, fnIndex, len(freeSymbols))
|
||||
|
||||
case *ast.ReturnStatement:
|
||||
err := c.Compile(node.ReturnValue)
|
||||
@@ -439,5 +448,9 @@ func (c *Compiler) loadSymbol(s Symbol) {
|
||||
c.emit(code.OpGetLocal, s.Index)
|
||||
case BuiltinScope:
|
||||
c.emit(code.OpGetBuiltin, s.Index)
|
||||
case FreeScope:
|
||||
c.emit(code.OpGetFree, s.Index)
|
||||
case FunctionScope:
|
||||
c.emit(code.OpCurrentClosure)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -503,6 +503,132 @@ func TestFunctionsWithoutReturnValue(t *testing.T) {
|
||||
runCompilerTests(t, tests)
|
||||
}
|
||||
|
||||
func TestClosures(t *testing.T) {
|
||||
tests := []compilerTestCase{
|
||||
{
|
||||
input: `fn(a) {
|
||||
fn(b) {
|
||||
a + b
|
||||
}
|
||||
}
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpGetFree, 0),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpAdd),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpClosure, 0, 1),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpClosure, 1, 0),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `
|
||||
fn(a) {
|
||||
fn(b) {
|
||||
fn(c) {
|
||||
a + b + c
|
||||
}
|
||||
}
|
||||
};
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpGetFree, 0),
|
||||
code.Make(code.OpGetFree, 1),
|
||||
code.Make(code.OpAdd),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpAdd),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpGetFree, 0),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpClosure, 0, 2),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpClosure, 1, 1),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpClosure, 2, 0),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `
|
||||
let global = 55;
|
||||
|
||||
fn() {
|
||||
let a = 66;
|
||||
|
||||
fn() {
|
||||
let b = 77;
|
||||
|
||||
fn() {
|
||||
let c = 88;
|
||||
|
||||
global + a + b + c;
|
||||
}
|
||||
}
|
||||
}
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
55,
|
||||
66,
|
||||
77,
|
||||
88,
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpConstant, 3),
|
||||
code.Make(code.OpSetLocal, 0),
|
||||
code.Make(code.OpGetGlobal, 0),
|
||||
code.Make(code.OpGetFree, 0),
|
||||
code.Make(code.OpAdd),
|
||||
code.Make(code.OpGetFree, 1),
|
||||
code.Make(code.OpAdd),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpAdd),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpConstant, 2),
|
||||
code.Make(code.OpSetLocal, 0),
|
||||
code.Make(code.OpGetFree, 0),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpClosure, 4, 2),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpConstant, 1),
|
||||
code.Make(code.OpSetLocal, 0),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpClosure, 5, 1),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpConstant, 0),
|
||||
code.Make(code.OpSetGlobal, 0),
|
||||
code.Make(code.OpClosure, 6, 0),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runCompilerTests(t, tests)
|
||||
}
|
||||
|
||||
func TestCompilerScopes(t *testing.T) {
|
||||
compiler := New()
|
||||
if compiler.scopeIndex != 0 {
|
||||
@@ -755,6 +881,75 @@ func TestBuiltins(t *testing.T) {
|
||||
runCompilerTests(t, tests)
|
||||
}
|
||||
|
||||
func TestRecursiveFunctions(t *testing.T) {
|
||||
tests := []compilerTestCase{
|
||||
{
|
||||
input: `
|
||||
let countDown = fn(x) { countDown(x - 1); };
|
||||
countDown(1);
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
1,
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpCurrentClosure),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpConstant, 0),
|
||||
code.Make(code.OpSub),
|
||||
code.Make(code.OpCall, 1),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
1,
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpClosure, 1, 0),
|
||||
code.Make(code.OpSetGlobal, 0),
|
||||
code.Make(code.OpGetGlobal, 0),
|
||||
code.Make(code.OpConstant, 2),
|
||||
code.Make(code.OpCall, 1),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `
|
||||
let wrapper = fn() {
|
||||
let countDown = fn(x) { countDown(x - 1); };
|
||||
countDown(1);
|
||||
};
|
||||
wrapper();
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
1,
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpCurrentClosure),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpConstant, 0),
|
||||
code.Make(code.OpSub),
|
||||
code.Make(code.OpCall, 1),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
1,
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpClosure, 1, 0),
|
||||
code.Make(code.OpSetLocal, 0),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpConstant, 2),
|
||||
code.Make(code.OpCall, 1),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpClosure, 3, 0),
|
||||
code.Make(code.OpSetGlobal, 0),
|
||||
code.Make(code.OpGetGlobal, 0),
|
||||
code.Make(code.OpCall, 0),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runCompilerTests(t, tests)
|
||||
}
|
||||
|
||||
func runCompilerTests(t *testing.T, tests []compilerTestCase) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -3,9 +3,11 @@ package compiler
|
||||
type SymbolScope string
|
||||
|
||||
const (
|
||||
LocalScope SymbolScope = "LOCAL"
|
||||
GlobalScope SymbolScope = "GLOBAL"
|
||||
BuiltinScope SymbolScope = "BUILTIN"
|
||||
LocalScope SymbolScope = "LOCAL"
|
||||
GlobalScope SymbolScope = "GLOBAL"
|
||||
BuiltinScope SymbolScope = "BUILTIN"
|
||||
FreeScope SymbolScope = "FREE"
|
||||
FunctionScope SymbolScope = "FUNCTION"
|
||||
)
|
||||
|
||||
type Symbol struct {
|
||||
@@ -19,6 +21,8 @@ type SymbolTable struct {
|
||||
|
||||
store map[string]Symbol
|
||||
numDefinitions int
|
||||
|
||||
FreeSymbols []Symbol
|
||||
}
|
||||
|
||||
func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
|
||||
@@ -29,7 +33,8 @@ func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
|
||||
|
||||
func NewSymbolTable() *SymbolTable {
|
||||
s := make(map[string]Symbol)
|
||||
return &SymbolTable{store: s}
|
||||
free := []Symbol{}
|
||||
return &SymbolTable{store: s, FreeSymbols: free}
|
||||
}
|
||||
|
||||
func (s *SymbolTable) Define(name string) Symbol {
|
||||
@@ -49,7 +54,17 @@ func (s *SymbolTable) Resolve(name string) (Symbol, bool) {
|
||||
obj, ok := s.store[name]
|
||||
if !ok && s.Outer != nil {
|
||||
obj, ok = s.Outer.Resolve(name)
|
||||
return obj, ok
|
||||
if !ok {
|
||||
return obj, ok
|
||||
}
|
||||
|
||||
if obj.Scope == GlobalScope || obj.Scope == BuiltinScope {
|
||||
return obj, ok
|
||||
}
|
||||
|
||||
free := s.DefineFree(obj)
|
||||
return free, true
|
||||
|
||||
}
|
||||
|
||||
return obj, ok
|
||||
@@ -60,3 +75,19 @@ func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol {
|
||||
s.store[name] = symbol
|
||||
return symbol
|
||||
}
|
||||
|
||||
func (s *SymbolTable) DefineFree(original Symbol) Symbol {
|
||||
s.FreeSymbols = append(s.FreeSymbols, original)
|
||||
|
||||
symbol := Symbol{Name: original.Name, Index: len(s.FreeSymbols) - 1}
|
||||
symbol.Scope = FreeScope
|
||||
|
||||
s.store[original.Name] = symbol
|
||||
return symbol
|
||||
}
|
||||
|
||||
func (s *SymbolTable) DefineFunctionName(name string) Symbol {
|
||||
symbol := Symbol{Name: name, Index: 0, Scope: FunctionScope}
|
||||
s.store[name] = symbol
|
||||
return symbol
|
||||
}
|
||||
|
||||
@@ -197,3 +197,154 @@ func TestDefineResolveBuiltins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveFree(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.Define("a")
|
||||
global.Define("b")
|
||||
|
||||
firstLocal := NewEnclosedSymbolTable(global)
|
||||
firstLocal.Define("c")
|
||||
firstLocal.Define("d")
|
||||
|
||||
secondLocal := NewEnclosedSymbolTable(firstLocal)
|
||||
secondLocal.Define("e")
|
||||
secondLocal.Define("f")
|
||||
|
||||
tests := []struct {
|
||||
table *SymbolTable
|
||||
expectedSymbols []Symbol
|
||||
expectedFreeSymbols []Symbol
|
||||
}{
|
||||
{
|
||||
firstLocal,
|
||||
[]Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
[]Symbol{},
|
||||
},
|
||||
{
|
||||
secondLocal,
|
||||
[]Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "c", Scope: FreeScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: FreeScope, Index: 1},
|
||||
Symbol{Name: "e", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "f", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
[]Symbol{
|
||||
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
for _, sym := range tt.expectedSymbols {
|
||||
result, ok := tt.table.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
continue
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
|
||||
if len(tt.table.FreeSymbols) != len(tt.expectedFreeSymbols) {
|
||||
t.Errorf("wrong number of free symbols. got=%d, want=%d",
|
||||
len(tt.table.FreeSymbols), len(tt.expectedFreeSymbols))
|
||||
continue
|
||||
}
|
||||
|
||||
for i, sym := range tt.expectedFreeSymbols {
|
||||
result := tt.table.FreeSymbols[i]
|
||||
if result != sym {
|
||||
t.Errorf("wrong free symbol. got=%+v, want=%+v",
|
||||
result, sym)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnresolvableFree(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.Define("a")
|
||||
|
||||
firstLocal := NewEnclosedSymbolTable(global)
|
||||
firstLocal.Define("c")
|
||||
|
||||
secondLocal := NewEnclosedSymbolTable(firstLocal)
|
||||
secondLocal.Define("e")
|
||||
secondLocal.Define("f")
|
||||
|
||||
expected := []Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "c", Scope: FreeScope, Index: 0},
|
||||
Symbol{Name: "e", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "f", Scope: LocalScope, Index: 1},
|
||||
}
|
||||
|
||||
for _, sym := range expected {
|
||||
result, ok := secondLocal.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
continue
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
|
||||
expectedUnresolvable := []string{
|
||||
"b",
|
||||
"d",
|
||||
}
|
||||
|
||||
for _, name := range expectedUnresolvable {
|
||||
_, ok := secondLocal.Resolve(name)
|
||||
if ok {
|
||||
t.Errorf("name %s resolved, but was expected not to", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefineAndResolveFunctionName(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.DefineFunctionName("a")
|
||||
|
||||
expected := Symbol{Name: "a", Scope: FunctionScope, Index: 0}
|
||||
|
||||
result, ok := global.Resolve(expected.Name)
|
||||
if !ok {
|
||||
t.Fatalf("function name %s not resolvable", expected.Name)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
expected.Name, expected, result)
|
||||
}
|
||||
}
|
||||
func TestShadowingFunctionName(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.DefineFunctionName("a")
|
||||
global.Define("a")
|
||||
|
||||
expected := Symbol{Name: "a", Scope: GlobalScope, Index: 0}
|
||||
|
||||
result, ok := global.Resolve(expected.Name)
|
||||
if !ok {
|
||||
t.Fatalf("function name %s not resolvable", expected.Name)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
expected.Name, expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user