functions with bindings
This commit is contained in:
@@ -193,8 +193,13 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
symbol := c.symbolTable.Define(node.Name.Value)
|
||||
c.emit(code.OpSetGlobal, symbol.Index)
|
||||
if symbol.Scope == GlobalScope {
|
||||
c.emit(code.OpSetGlobal, symbol.Index)
|
||||
} else {
|
||||
c.emit(code.OpSetLocal, symbol.Index)
|
||||
}
|
||||
|
||||
case *ast.Identifier:
|
||||
symbol, ok := c.symbolTable.Resolve(node.Value)
|
||||
@@ -202,7 +207,11 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
return fmt.Errorf("undefined varible %s", node.Value)
|
||||
}
|
||||
|
||||
c.emit(code.OpGetGlobal, symbol.Index)
|
||||
if symbol.Scope == GlobalScope {
|
||||
c.emit(code.OpGetGlobal, symbol.Index)
|
||||
} else {
|
||||
c.emit(code.OpGetLocal, symbol.Index)
|
||||
}
|
||||
|
||||
case *ast.StringLiteral:
|
||||
str := &object.String{Value: node.Value}
|
||||
@@ -268,9 +277,10 @@ func (c *Compiler) Compile(node ast.Node) error {
|
||||
c.emit(code.OpReturn)
|
||||
}
|
||||
|
||||
numLocals := c.symbolTable.numDefinitions
|
||||
instructions := c.leaveScope()
|
||||
|
||||
compiledFn := &object.CompiledFunction{Instructions: instructions}
|
||||
compiledFn := &object.CompiledFunction{Instructions: instructions, NumLocals: numLocals}
|
||||
c.emit(code.OpConstant, c.addConstant(compiledFn))
|
||||
|
||||
case *ast.ReturnStatement:
|
||||
@@ -374,6 +384,8 @@ func (c *Compiler) enterScope() {
|
||||
}
|
||||
c.scopes = append(c.scopes, scope)
|
||||
c.scopeIndex++
|
||||
|
||||
c.symbolTable = NewEnclosedSymbolTable(c.symbolTable)
|
||||
}
|
||||
|
||||
func (c *Compiler) leaveScope() code.Instructions {
|
||||
@@ -382,6 +394,8 @@ func (c *Compiler) leaveScope() code.Instructions {
|
||||
c.scopes = c.scopes[:len(c.scopes)-1]
|
||||
c.scopeIndex--
|
||||
|
||||
c.symbolTable = c.symbolTable.Outer
|
||||
|
||||
return instructions
|
||||
}
|
||||
|
||||
|
||||
@@ -591,6 +591,80 @@ func TestFunctionCalls(t *testing.T) {
|
||||
runCompilerTests(t, tests)
|
||||
}
|
||||
|
||||
func TestLetStatementScopes(t *testing.T) {
|
||||
tests := []compilerTestCase{
|
||||
{
|
||||
input: `
|
||||
let num = 55;
|
||||
fn() { num }
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
55,
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpGetGlobal, 0),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpConstant, 0),
|
||||
code.Make(code.OpSetGlobal, 0),
|
||||
code.Make(code.OpConstant, 1),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `
|
||||
fn() {
|
||||
let num = 55;
|
||||
num
|
||||
}
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
55,
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpConstant, 0),
|
||||
code.Make(code.OpSetLocal, 0),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpConstant, 1),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `
|
||||
fn() {
|
||||
let a = 55;
|
||||
let b = 77;
|
||||
a + b
|
||||
}
|
||||
`,
|
||||
expectedConstants: []interface{}{
|
||||
55,
|
||||
77,
|
||||
[]code.Instructions{
|
||||
code.Make(code.OpConstant, 0),
|
||||
code.Make(code.OpSetLocal, 0),
|
||||
code.Make(code.OpConstant, 1),
|
||||
code.Make(code.OpSetLocal, 1),
|
||||
code.Make(code.OpGetLocal, 0),
|
||||
code.Make(code.OpGetLocal, 1),
|
||||
code.Make(code.OpAdd),
|
||||
code.Make(code.OpReturnValue),
|
||||
},
|
||||
},
|
||||
expectedInstructions: []code.Instructions{
|
||||
code.Make(code.OpConstant, 2),
|
||||
code.Make(code.OpPop),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runCompilerTests(t, tests)
|
||||
}
|
||||
|
||||
func runCompilerTests(t *testing.T, tests []compilerTestCase) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@ package compiler
|
||||
type SymbolScope string
|
||||
|
||||
const (
|
||||
GlobalScope SymbolScope = "Global"
|
||||
LocalScope SymbolScope = "LOCAL"
|
||||
GlobalScope SymbolScope = "GLOBAL"
|
||||
)
|
||||
|
||||
type Symbol struct {
|
||||
@@ -13,21 +14,31 @@ type Symbol struct {
|
||||
}
|
||||
|
||||
type SymbolTable struct {
|
||||
Outer *SymbolTable
|
||||
|
||||
store map[string]Symbol
|
||||
numDefinitions int
|
||||
}
|
||||
|
||||
func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
|
||||
s := NewSymbolTable()
|
||||
s.Outer = outer
|
||||
return s
|
||||
}
|
||||
|
||||
func NewSymbolTable() *SymbolTable {
|
||||
s := make(map[string]Symbol)
|
||||
return &SymbolTable{store: s}
|
||||
}
|
||||
|
||||
func (s *SymbolTable) Define(name string) Symbol {
|
||||
symbol := Symbol{
|
||||
Name: name,
|
||||
Scope: GlobalScope,
|
||||
Index: s.numDefinitions,
|
||||
symbol := Symbol{Name: name, Index: s.numDefinitions}
|
||||
if s.Outer == nil {
|
||||
symbol.Scope = GlobalScope
|
||||
} else {
|
||||
symbol.Scope = LocalScope
|
||||
}
|
||||
|
||||
s.store[name] = symbol
|
||||
s.numDefinitions++
|
||||
return symbol
|
||||
@@ -35,5 +46,10 @@ func (s *SymbolTable) Define(name string) Symbol {
|
||||
|
||||
func (s *SymbolTable) Resolve(name string) (Symbol, bool) {
|
||||
obj, ok := s.store[name]
|
||||
if !ok && s.Outer != nil {
|
||||
obj, ok = s.Outer.Resolve(name)
|
||||
return obj, ok
|
||||
}
|
||||
|
||||
return obj, ok
|
||||
}
|
||||
|
||||
@@ -14,6 +14,10 @@ func TestDefine(t *testing.T) {
|
||||
Scope: GlobalScope,
|
||||
Index: 1,
|
||||
},
|
||||
"c": {Name: "c", Scope: LocalScope, Index: 0},
|
||||
"d": {Name: "d", Scope: LocalScope, Index: 1},
|
||||
"e": {Name: "e", Scope: LocalScope, Index: 0},
|
||||
"f": {Name: "f", Scope: LocalScope, Index: 1},
|
||||
}
|
||||
|
||||
global := NewSymbolTable()
|
||||
@@ -27,6 +31,30 @@ func TestDefine(t *testing.T) {
|
||||
if b != expected["b"] {
|
||||
t.Errorf("expected b=%+v, got=%+v", expected["b"], b)
|
||||
}
|
||||
|
||||
firstLocal := NewEnclosedSymbolTable(global)
|
||||
|
||||
c := firstLocal.Define("c")
|
||||
if c != expected["c"] {
|
||||
t.Errorf("expected c=%+v, got=%+v", expected["c"], c)
|
||||
}
|
||||
|
||||
d := firstLocal.Define("d")
|
||||
if d != expected["d"] {
|
||||
t.Errorf("expected d=%+v, got=%+v", expected["d"], d)
|
||||
}
|
||||
|
||||
secondLocal := NewEnclosedSymbolTable(firstLocal)
|
||||
|
||||
e := secondLocal.Define("e")
|
||||
if e != expected["e"] {
|
||||
t.Errorf("expected e=%+v, got=%+v", expected["e"], e)
|
||||
}
|
||||
|
||||
f := secondLocal.Define("f")
|
||||
if f != expected["f"] {
|
||||
t.Errorf("expected f=%+v, got=%+v", expected["f"], f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveGlobal(t *testing.T) {
|
||||
@@ -57,3 +85,84 @@ func TestResolveGlobal(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLocal(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.Define("a")
|
||||
global.Define("b")
|
||||
|
||||
local := NewEnclosedSymbolTable(global)
|
||||
local.Define("c")
|
||||
local.Define("d")
|
||||
|
||||
expected := []Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||
}
|
||||
|
||||
for _, sym := range expected {
|
||||
result, ok := local.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
continue
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveNestedLocal(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.Define("a")
|
||||
global.Define("b")
|
||||
|
||||
firstLocal := NewEnclosedSymbolTable(global)
|
||||
firstLocal.Define("c")
|
||||
firstLocal.Define("d")
|
||||
|
||||
secondLocal := NewEnclosedSymbolTable(global)
|
||||
secondLocal.Define("e")
|
||||
secondLocal.Define("f")
|
||||
|
||||
tests := []struct {
|
||||
table *SymbolTable
|
||||
expectedSymbols []Symbol
|
||||
}{
|
||||
{
|
||||
firstLocal,
|
||||
[]Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
secondLocal,
|
||||
[]Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "e", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "f", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
for _, sym := range tt.expectedSymbols {
|
||||
result, ok := tt.table.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
continue
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user