closures and they can recurse!!!
This commit is contained in:
@@ -2,6 +2,7 @@ package ast
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"monkey/token"
|
"monkey/token"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -220,6 +221,7 @@ type FunctionLiteral struct {
|
|||||||
Token token.Token // The 'fn' token
|
Token token.Token // The 'fn' token
|
||||||
Parameters []*Identifier
|
Parameters []*Identifier
|
||||||
Body *BlockStatement
|
Body *BlockStatement
|
||||||
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fl *FunctionLiteral) expressionNode() {}
|
func (fl *FunctionLiteral) expressionNode() {}
|
||||||
@@ -233,6 +235,9 @@ func (fl *FunctionLiteral) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
out.WriteString(fl.TokenLiteral())
|
out.WriteString(fl.TokenLiteral())
|
||||||
|
if fl.Name != "" {
|
||||||
|
out.WriteString(fmt.Sprintf("<%s>", fl.Name))
|
||||||
|
}
|
||||||
out.WriteString("(")
|
out.WriteString("(")
|
||||||
out.WriteString(strings.Join(params, ", "))
|
out.WriteString(strings.Join(params, ", "))
|
||||||
out.WriteString(") ")
|
out.WriteString(") ")
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ const (
|
|||||||
OpSetLocal
|
OpSetLocal
|
||||||
OpGetBuiltin
|
OpGetBuiltin
|
||||||
OpClosure
|
OpClosure
|
||||||
|
OpGetFree
|
||||||
|
OpCurrentClosure
|
||||||
)
|
)
|
||||||
|
|
||||||
type Definition struct {
|
type Definition struct {
|
||||||
@@ -75,6 +77,8 @@ var definitions = map[Opcode]*Definition{
|
|||||||
OpSetLocal: {"OpSetLocal", []int{1}},
|
OpSetLocal: {"OpSetLocal", []int{1}},
|
||||||
OpGetBuiltin: {"OpGetBuiltin", []int{1}},
|
OpGetBuiltin: {"OpGetBuiltin", []int{1}},
|
||||||
OpClosure: {"OpClosure", []int{2, 1}},
|
OpClosure: {"OpClosure", []int{2, 1}},
|
||||||
|
OpGetFree: {"OpGetFree", []int{1}},
|
||||||
|
OpCurrentClosure: {"OpCurrentClosure", []int{}},
|
||||||
}
|
}
|
||||||
|
|
||||||
func Lookup(op byte) (*Definition, error) {
|
func Lookup(op byte) (*Definition, error) {
|
||||||
|
|||||||
@@ -196,12 +196,12 @@ func (c *Compiler) Compile(node ast.Node) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case *ast.LetStatement:
|
case *ast.LetStatement:
|
||||||
|
symbol := c.symbolTable.Define(node.Name.Value)
|
||||||
err := c.Compile(node.Value)
|
err := c.Compile(node.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
symbol := c.symbolTable.Define(node.Name.Value)
|
|
||||||
if symbol.Scope == GlobalScope {
|
if symbol.Scope == GlobalScope {
|
||||||
c.emit(code.OpSetGlobal, symbol.Index)
|
c.emit(code.OpSetGlobal, symbol.Index)
|
||||||
} else {
|
} else {
|
||||||
@@ -268,6 +268,10 @@ func (c *Compiler) Compile(node ast.Node) error {
|
|||||||
case *ast.FunctionLiteral:
|
case *ast.FunctionLiteral:
|
||||||
c.enterScope()
|
c.enterScope()
|
||||||
|
|
||||||
|
if node.Name != "" {
|
||||||
|
c.symbolTable.DefineFunctionName(node.Name)
|
||||||
|
}
|
||||||
|
|
||||||
for _, p := range node.Parameters {
|
for _, p := range node.Parameters {
|
||||||
c.symbolTable.Define(p.Value)
|
c.symbolTable.Define(p.Value)
|
||||||
}
|
}
|
||||||
@@ -284,9 +288,14 @@ func (c *Compiler) Compile(node ast.Node) error {
|
|||||||
c.emit(code.OpReturn)
|
c.emit(code.OpReturn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
freeSymbols := c.symbolTable.FreeSymbols
|
||||||
numLocals := c.symbolTable.numDefinitions
|
numLocals := c.symbolTable.numDefinitions
|
||||||
instructions := c.leaveScope()
|
instructions := c.leaveScope()
|
||||||
|
|
||||||
|
for _, s := range freeSymbols {
|
||||||
|
c.loadSymbol(s)
|
||||||
|
}
|
||||||
|
|
||||||
compiledFn := &object.CompiledFunction{
|
compiledFn := &object.CompiledFunction{
|
||||||
Instructions: instructions,
|
Instructions: instructions,
|
||||||
NumLocals: numLocals,
|
NumLocals: numLocals,
|
||||||
@@ -294,7 +303,7 @@ func (c *Compiler) Compile(node ast.Node) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fnIndex := c.addConstant(compiledFn)
|
fnIndex := c.addConstant(compiledFn)
|
||||||
c.emit(code.OpClosure, fnIndex, 0)
|
c.emit(code.OpClosure, fnIndex, len(freeSymbols))
|
||||||
|
|
||||||
case *ast.ReturnStatement:
|
case *ast.ReturnStatement:
|
||||||
err := c.Compile(node.ReturnValue)
|
err := c.Compile(node.ReturnValue)
|
||||||
@@ -439,5 +448,9 @@ func (c *Compiler) loadSymbol(s Symbol) {
|
|||||||
c.emit(code.OpGetLocal, s.Index)
|
c.emit(code.OpGetLocal, s.Index)
|
||||||
case BuiltinScope:
|
case BuiltinScope:
|
||||||
c.emit(code.OpGetBuiltin, s.Index)
|
c.emit(code.OpGetBuiltin, s.Index)
|
||||||
|
case FreeScope:
|
||||||
|
c.emit(code.OpGetFree, s.Index)
|
||||||
|
case FunctionScope:
|
||||||
|
c.emit(code.OpCurrentClosure)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -503,6 +503,132 @@ func TestFunctionsWithoutReturnValue(t *testing.T) {
|
|||||||
runCompilerTests(t, tests)
|
runCompilerTests(t, tests)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClosures(t *testing.T) {
|
||||||
|
tests := []compilerTestCase{
|
||||||
|
{
|
||||||
|
input: `fn(a) {
|
||||||
|
fn(b) {
|
||||||
|
a + b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
expectedConstants: []interface{}{
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpGetFree, 0),
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpAdd),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpClosure, 0, 1),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedInstructions: []code.Instructions{
|
||||||
|
code.Make(code.OpClosure, 1, 0),
|
||||||
|
code.Make(code.OpPop),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
fn(a) {
|
||||||
|
fn(b) {
|
||||||
|
fn(c) {
|
||||||
|
a + b + c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
`,
|
||||||
|
expectedConstants: []interface{}{
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpGetFree, 0),
|
||||||
|
code.Make(code.OpGetFree, 1),
|
||||||
|
code.Make(code.OpAdd),
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpAdd),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpGetFree, 0),
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpClosure, 0, 2),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpClosure, 1, 1),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedInstructions: []code.Instructions{
|
||||||
|
code.Make(code.OpClosure, 2, 0),
|
||||||
|
code.Make(code.OpPop),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let global = 55;
|
||||||
|
|
||||||
|
fn() {
|
||||||
|
let a = 66;
|
||||||
|
|
||||||
|
fn() {
|
||||||
|
let b = 77;
|
||||||
|
|
||||||
|
fn() {
|
||||||
|
let c = 88;
|
||||||
|
|
||||||
|
global + a + b + c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
expectedConstants: []interface{}{
|
||||||
|
55,
|
||||||
|
66,
|
||||||
|
77,
|
||||||
|
88,
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpConstant, 3),
|
||||||
|
code.Make(code.OpSetLocal, 0),
|
||||||
|
code.Make(code.OpGetGlobal, 0),
|
||||||
|
code.Make(code.OpGetFree, 0),
|
||||||
|
code.Make(code.OpAdd),
|
||||||
|
code.Make(code.OpGetFree, 1),
|
||||||
|
code.Make(code.OpAdd),
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpAdd),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpConstant, 2),
|
||||||
|
code.Make(code.OpSetLocal, 0),
|
||||||
|
code.Make(code.OpGetFree, 0),
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpClosure, 4, 2),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpConstant, 1),
|
||||||
|
code.Make(code.OpSetLocal, 0),
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpClosure, 5, 1),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedInstructions: []code.Instructions{
|
||||||
|
code.Make(code.OpConstant, 0),
|
||||||
|
code.Make(code.OpSetGlobal, 0),
|
||||||
|
code.Make(code.OpClosure, 6, 0),
|
||||||
|
code.Make(code.OpPop),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runCompilerTests(t, tests)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCompilerScopes(t *testing.T) {
|
func TestCompilerScopes(t *testing.T) {
|
||||||
compiler := New()
|
compiler := New()
|
||||||
if compiler.scopeIndex != 0 {
|
if compiler.scopeIndex != 0 {
|
||||||
@@ -755,6 +881,75 @@ func TestBuiltins(t *testing.T) {
|
|||||||
runCompilerTests(t, tests)
|
runCompilerTests(t, tests)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRecursiveFunctions(t *testing.T) {
|
||||||
|
tests := []compilerTestCase{
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let countDown = fn(x) { countDown(x - 1); };
|
||||||
|
countDown(1);
|
||||||
|
`,
|
||||||
|
expectedConstants: []interface{}{
|
||||||
|
1,
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpCurrentClosure),
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpConstant, 0),
|
||||||
|
code.Make(code.OpSub),
|
||||||
|
code.Make(code.OpCall, 1),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
1,
|
||||||
|
},
|
||||||
|
expectedInstructions: []code.Instructions{
|
||||||
|
code.Make(code.OpClosure, 1, 0),
|
||||||
|
code.Make(code.OpSetGlobal, 0),
|
||||||
|
code.Make(code.OpGetGlobal, 0),
|
||||||
|
code.Make(code.OpConstant, 2),
|
||||||
|
code.Make(code.OpCall, 1),
|
||||||
|
code.Make(code.OpPop),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let wrapper = fn() {
|
||||||
|
let countDown = fn(x) { countDown(x - 1); };
|
||||||
|
countDown(1);
|
||||||
|
};
|
||||||
|
wrapper();
|
||||||
|
`,
|
||||||
|
expectedConstants: []interface{}{
|
||||||
|
1,
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpCurrentClosure),
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpConstant, 0),
|
||||||
|
code.Make(code.OpSub),
|
||||||
|
code.Make(code.OpCall, 1),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
1,
|
||||||
|
[]code.Instructions{
|
||||||
|
code.Make(code.OpClosure, 1, 0),
|
||||||
|
code.Make(code.OpSetLocal, 0),
|
||||||
|
code.Make(code.OpGetLocal, 0),
|
||||||
|
code.Make(code.OpConstant, 2),
|
||||||
|
code.Make(code.OpCall, 1),
|
||||||
|
code.Make(code.OpReturnValue),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedInstructions: []code.Instructions{
|
||||||
|
code.Make(code.OpClosure, 3, 0),
|
||||||
|
code.Make(code.OpSetGlobal, 0),
|
||||||
|
code.Make(code.OpGetGlobal, 0),
|
||||||
|
code.Make(code.OpCall, 0),
|
||||||
|
code.Make(code.OpPop),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runCompilerTests(t, tests)
|
||||||
|
}
|
||||||
|
|
||||||
func runCompilerTests(t *testing.T, tests []compilerTestCase) {
|
func runCompilerTests(t *testing.T, tests []compilerTestCase) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ const (
|
|||||||
LocalScope SymbolScope = "LOCAL"
|
LocalScope SymbolScope = "LOCAL"
|
||||||
GlobalScope SymbolScope = "GLOBAL"
|
GlobalScope SymbolScope = "GLOBAL"
|
||||||
BuiltinScope SymbolScope = "BUILTIN"
|
BuiltinScope SymbolScope = "BUILTIN"
|
||||||
|
FreeScope SymbolScope = "FREE"
|
||||||
|
FunctionScope SymbolScope = "FUNCTION"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Symbol struct {
|
type Symbol struct {
|
||||||
@@ -19,6 +21,8 @@ type SymbolTable struct {
|
|||||||
|
|
||||||
store map[string]Symbol
|
store map[string]Symbol
|
||||||
numDefinitions int
|
numDefinitions int
|
||||||
|
|
||||||
|
FreeSymbols []Symbol
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
|
func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
|
||||||
@@ -29,7 +33,8 @@ func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
|
|||||||
|
|
||||||
func NewSymbolTable() *SymbolTable {
|
func NewSymbolTable() *SymbolTable {
|
||||||
s := make(map[string]Symbol)
|
s := make(map[string]Symbol)
|
||||||
return &SymbolTable{store: s}
|
free := []Symbol{}
|
||||||
|
return &SymbolTable{store: s, FreeSymbols: free}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SymbolTable) Define(name string) Symbol {
|
func (s *SymbolTable) Define(name string) Symbol {
|
||||||
@@ -49,9 +54,19 @@ func (s *SymbolTable) Resolve(name string) (Symbol, bool) {
|
|||||||
obj, ok := s.store[name]
|
obj, ok := s.store[name]
|
||||||
if !ok && s.Outer != nil {
|
if !ok && s.Outer != nil {
|
||||||
obj, ok = s.Outer.Resolve(name)
|
obj, ok = s.Outer.Resolve(name)
|
||||||
|
if !ok {
|
||||||
return obj, ok
|
return obj, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if obj.Scope == GlobalScope || obj.Scope == BuiltinScope {
|
||||||
|
return obj, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
free := s.DefineFree(obj)
|
||||||
|
return free, true
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
return obj, ok
|
return obj, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,3 +75,19 @@ func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol {
|
|||||||
s.store[name] = symbol
|
s.store[name] = symbol
|
||||||
return symbol
|
return symbol
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SymbolTable) DefineFree(original Symbol) Symbol {
|
||||||
|
s.FreeSymbols = append(s.FreeSymbols, original)
|
||||||
|
|
||||||
|
symbol := Symbol{Name: original.Name, Index: len(s.FreeSymbols) - 1}
|
||||||
|
symbol.Scope = FreeScope
|
||||||
|
|
||||||
|
s.store[original.Name] = symbol
|
||||||
|
return symbol
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SymbolTable) DefineFunctionName(name string) Symbol {
|
||||||
|
symbol := Symbol{Name: name, Index: 0, Scope: FunctionScope}
|
||||||
|
s.store[name] = symbol
|
||||||
|
return symbol
|
||||||
|
}
|
||||||
|
|||||||
@@ -197,3 +197,154 @@ func TestDefineResolveBuiltins(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResolveFree(t *testing.T) {
|
||||||
|
global := NewSymbolTable()
|
||||||
|
global.Define("a")
|
||||||
|
global.Define("b")
|
||||||
|
|
||||||
|
firstLocal := NewEnclosedSymbolTable(global)
|
||||||
|
firstLocal.Define("c")
|
||||||
|
firstLocal.Define("d")
|
||||||
|
|
||||||
|
secondLocal := NewEnclosedSymbolTable(firstLocal)
|
||||||
|
secondLocal.Define("e")
|
||||||
|
secondLocal.Define("f")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
table *SymbolTable
|
||||||
|
expectedSymbols []Symbol
|
||||||
|
expectedFreeSymbols []Symbol
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
firstLocal,
|
||||||
|
[]Symbol{
|
||||||
|
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||||
|
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||||
|
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||||
|
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||||
|
},
|
||||||
|
[]Symbol{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
secondLocal,
|
||||||
|
[]Symbol{
|
||||||
|
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||||
|
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||||
|
Symbol{Name: "c", Scope: FreeScope, Index: 0},
|
||||||
|
Symbol{Name: "d", Scope: FreeScope, Index: 1},
|
||||||
|
Symbol{Name: "e", Scope: LocalScope, Index: 0},
|
||||||
|
Symbol{Name: "f", Scope: LocalScope, Index: 1},
|
||||||
|
},
|
||||||
|
[]Symbol{
|
||||||
|
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||||
|
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
for _, sym := range tt.expectedSymbols {
|
||||||
|
result, ok := tt.table.Resolve(sym.Name)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("name %s not resolvable", sym.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if result != sym {
|
||||||
|
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||||
|
sym.Name, sym, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tt.table.FreeSymbols) != len(tt.expectedFreeSymbols) {
|
||||||
|
t.Errorf("wrong number of free symbols. got=%d, want=%d",
|
||||||
|
len(tt.table.FreeSymbols), len(tt.expectedFreeSymbols))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, sym := range tt.expectedFreeSymbols {
|
||||||
|
result := tt.table.FreeSymbols[i]
|
||||||
|
if result != sym {
|
||||||
|
t.Errorf("wrong free symbol. got=%+v, want=%+v",
|
||||||
|
result, sym)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnresolvableFree(t *testing.T) {
|
||||||
|
global := NewSymbolTable()
|
||||||
|
global.Define("a")
|
||||||
|
|
||||||
|
firstLocal := NewEnclosedSymbolTable(global)
|
||||||
|
firstLocal.Define("c")
|
||||||
|
|
||||||
|
secondLocal := NewEnclosedSymbolTable(firstLocal)
|
||||||
|
secondLocal.Define("e")
|
||||||
|
secondLocal.Define("f")
|
||||||
|
|
||||||
|
expected := []Symbol{
|
||||||
|
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||||
|
Symbol{Name: "c", Scope: FreeScope, Index: 0},
|
||||||
|
Symbol{Name: "e", Scope: LocalScope, Index: 0},
|
||||||
|
Symbol{Name: "f", Scope: LocalScope, Index: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sym := range expected {
|
||||||
|
result, ok := secondLocal.Resolve(sym.Name)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("name %s not resolvable", sym.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if result != sym {
|
||||||
|
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||||
|
sym.Name, sym, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedUnresolvable := []string{
|
||||||
|
"b",
|
||||||
|
"d",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range expectedUnresolvable {
|
||||||
|
_, ok := secondLocal.Resolve(name)
|
||||||
|
if ok {
|
||||||
|
t.Errorf("name %s resolved, but was expected not to", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefineAndResolveFunctionName(t *testing.T) {
|
||||||
|
global := NewSymbolTable()
|
||||||
|
global.DefineFunctionName("a")
|
||||||
|
|
||||||
|
expected := Symbol{Name: "a", Scope: FunctionScope, Index: 0}
|
||||||
|
|
||||||
|
result, ok := global.Resolve(expected.Name)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("function name %s not resolvable", expected.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||||
|
expected.Name, expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestShadowingFunctionName(t *testing.T) {
|
||||||
|
global := NewSymbolTable()
|
||||||
|
global.DefineFunctionName("a")
|
||||||
|
global.Define("a")
|
||||||
|
|
||||||
|
expected := Symbol{Name: "a", Scope: GlobalScope, Index: 0}
|
||||||
|
|
||||||
|
result, ok := global.Resolve(expected.Name)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("function name %s not resolvable", expected.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||||
|
expected.Name, expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -169,6 +169,10 @@ func (p *Parser) parseLetStatement() *ast.LetStatement {
|
|||||||
|
|
||||||
stmt.Value = p.parseExpression(LOWEST)
|
stmt.Value = p.parseExpression(LOWEST)
|
||||||
|
|
||||||
|
if fl, ok := stmt.Value.(*ast.FunctionLiteral); ok {
|
||||||
|
fl.Name = stmt.Name.Value
|
||||||
|
}
|
||||||
|
|
||||||
if p.peekTokenIs(token.SEMICOLON) {
|
if p.peekTokenIs(token.SEMICOLON) {
|
||||||
p.nextToken()
|
p.nextToken()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -845,6 +845,37 @@ func TestParsingHashLiteralsWithExpressions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFunctionLiteralWithName(t *testing.T) {
|
||||||
|
input := `let myFunction = fn() { };`
|
||||||
|
|
||||||
|
l := lexer.New(input)
|
||||||
|
p := New(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 1 {
|
||||||
|
t.Fatalf("program.Body does not contain %d statements. got=%d\n",
|
||||||
|
1, len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, ok := program.Statements[0].(*ast.LetStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("program.Statements[0] is not ast.LetStatement. got=%T",
|
||||||
|
program.Statements[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
function, ok := stmt.Value.(*ast.FunctionLiteral)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("stmt.Value is not ast.FunctionLiteral. got=%T",
|
||||||
|
stmt.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if function.Name != "myFunction" {
|
||||||
|
t.Fatalf("function literal name wrong. want 'myFunction', got=%q\n",
|
||||||
|
function.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func testLetStatement(t *testing.T, s ast.Statement, name string) bool {
|
func testLetStatement(t *testing.T, s ast.Statement, name string) bool {
|
||||||
if s.TokenLiteral() != "let" {
|
if s.TokenLiteral() != "let" {
|
||||||
t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral())
|
t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral())
|
||||||
|
|||||||
31
vm/vm.go
31
vm/vm.go
@@ -264,10 +264,27 @@ func (vm *VM) Run() error {
|
|||||||
|
|
||||||
case code.OpClosure:
|
case code.OpClosure:
|
||||||
constIndex := code.ReadUint16(ins[ip+1:])
|
constIndex := code.ReadUint16(ins[ip+1:])
|
||||||
_ = code.ReadUint8(ins[ip+3:])
|
numFree := code.ReadUint8(ins[ip+3:])
|
||||||
vm.currentFrame().ip += 3
|
vm.currentFrame().ip += 3
|
||||||
|
|
||||||
err := vm.pushClosure(int(constIndex))
|
err := vm.pushClosure(int(constIndex), int(numFree))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case code.OpGetFree:
|
||||||
|
freeIndex := code.ReadUint8(ins[ip+1:])
|
||||||
|
vm.currentFrame().ip += 1
|
||||||
|
|
||||||
|
currentClosure := vm.currentFrame().cl
|
||||||
|
err := vm.push(currentClosure.Free[freeIndex])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case code.OpCurrentClosure:
|
||||||
|
currentClosure := vm.currentFrame().cl
|
||||||
|
err := vm.push(currentClosure)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -533,14 +550,20 @@ func (vm *VM) callBuiltin(builtin *object.Builtin, numArgs int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vm *VM) pushClosure(constIndex int) error {
|
func (vm *VM) pushClosure(constIndex, numFree int) error {
|
||||||
constant := vm.constants[constIndex]
|
constant := vm.constants[constIndex]
|
||||||
function, ok := constant.(*object.CompiledFunction)
|
function, ok := constant.(*object.CompiledFunction)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("not a function %+v", constant)
|
return fmt.Errorf("not a function %+v", constant)
|
||||||
}
|
}
|
||||||
|
|
||||||
closure := &object.Closure{Fn: function}
|
free := make([]object.Object, numFree)
|
||||||
|
for i := 0; i < numFree; i++ {
|
||||||
|
free[i] = vm.stack[vm.sp-numFree+i]
|
||||||
|
}
|
||||||
|
vm.sp = vm.sp - numFree
|
||||||
|
|
||||||
|
closure := &object.Closure{Fn: function, Free: free}
|
||||||
return vm.push(closure)
|
return vm.push(closure)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
144
vm/vm_test.go
144
vm/vm_test.go
@@ -27,6 +27,19 @@ func runVmTests(t *testing.T, tests []vmTestCase) {
|
|||||||
t.Fatalf("compiler error: %s", err)
|
t.Fatalf("compiler error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//for i, constant := range comp.Bytecode().Constants {
|
||||||
|
// fmt.Printf("CONSTANT %d %p (%T):\n", i, constant, constant)
|
||||||
|
//
|
||||||
|
// switch constant := constant.(type) {
|
||||||
|
// case *object.CompiledFunction:
|
||||||
|
// fmt.Printf(" Instructions:\n%s", constant.Instructions)
|
||||||
|
// case *object.Integer:
|
||||||
|
// fmt.Printf(" Value: %d\n", constant.Value)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// fmt.Printf("\n")
|
||||||
|
//}
|
||||||
|
|
||||||
vm := New(comp.Bytecode())
|
vm := New(comp.Bytecode())
|
||||||
err = vm.Run()
|
err = vm.Run()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -611,3 +624,134 @@ func TestBuiltinFunctions(t *testing.T) {
|
|||||||
|
|
||||||
runVmTests(t, tests)
|
runVmTests(t, tests)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClosures(t *testing.T) {
|
||||||
|
tests := []vmTestCase{
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let newClosure = fn(a) {
|
||||||
|
fn() { a; };
|
||||||
|
};
|
||||||
|
let closure = newClosure(99);
|
||||||
|
closure();
|
||||||
|
`,
|
||||||
|
expected: 99,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let newAdder = fn(a, b) {
|
||||||
|
fn(c) { a + b + c };
|
||||||
|
};
|
||||||
|
let adder = newAdder(1, 2);
|
||||||
|
adder(8);
|
||||||
|
`,
|
||||||
|
expected: 11,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let newAdder = fn(a, b) {
|
||||||
|
let c = a + b;
|
||||||
|
fn(d) { c + d };
|
||||||
|
};
|
||||||
|
let adder = newAdder(1, 2);
|
||||||
|
adder(8);
|
||||||
|
`,
|
||||||
|
expected: 11,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let newAdderOuter = fn(a, b) {
|
||||||
|
let c = a + b;
|
||||||
|
fn(d) {
|
||||||
|
let e = d + c;
|
||||||
|
fn(f) { e + f; };
|
||||||
|
};
|
||||||
|
};
|
||||||
|
let newAdderInner = newAdderOuter(1, 2)
|
||||||
|
let adder = newAdderInner(3);
|
||||||
|
adder(8);
|
||||||
|
`,
|
||||||
|
expected: 14,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let a = 1;
|
||||||
|
let newAdderOuter = fn(b) {
|
||||||
|
fn(c) {
|
||||||
|
fn(d) { a + b + c + d };
|
||||||
|
};
|
||||||
|
};
|
||||||
|
let newAdderInner = newAdderOuter(2)
|
||||||
|
let adder = newAdderInner(3);
|
||||||
|
adder(8);
|
||||||
|
`,
|
||||||
|
expected: 14,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let newClosure = fn(a, b) {
|
||||||
|
let one = fn() { a; };
|
||||||
|
let two = fn() { b; };
|
||||||
|
fn() { one() + two(); };
|
||||||
|
};
|
||||||
|
let closure = newClosure(9, 90);
|
||||||
|
closure();
|
||||||
|
`,
|
||||||
|
expected: 99,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runVmTests(t, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecursiveFunctions(t *testing.T) {
|
||||||
|
tests := []vmTestCase{
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let countDown = fn(x) {
|
||||||
|
if (x == 0) {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
countDown(x - 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
countDown(1);
|
||||||
|
`,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let countDown = fn(x) {
|
||||||
|
if (x == 0) {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
countDown(x - 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let wrapper = fn() {
|
||||||
|
countDown(1);
|
||||||
|
};
|
||||||
|
wrapper();
|
||||||
|
`,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `
|
||||||
|
let wrapper = fn() {
|
||||||
|
let countDown = fn(x) {
|
||||||
|
if (x == 0) {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
countDown(x - 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
countDown(1);
|
||||||
|
};
|
||||||
|
wrapper();
|
||||||
|
`,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runVmTests(t, tests)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user