functions with bindings
Some checks failed
Build / build (push) Failing after 1m42s
Test / build (push) Successful in 2m27s

This commit is contained in:
Chuck Smith
2024-03-08 14:19:20 -05:00
parent 9d06c90e41
commit ec9a586f7f
10 changed files with 350 additions and 30 deletions

View File

@@ -35,6 +35,8 @@ const (
OpCall OpCall
OpReturnValue OpReturnValue
OpReturn OpReturn
OpGetLocal
OpSetLocal
) )
type Definition struct { type Definition struct {
@@ -67,6 +69,8 @@ var definitions = map[Opcode]*Definition{
OpCall: {"OpCall", []int{}}, OpCall: {"OpCall", []int{}},
OpReturnValue: {"OpReturnValue", []int{}}, OpReturnValue: {"OpReturnValue", []int{}},
OpReturn: {"OpReturn", []int{}}, OpReturn: {"OpReturn", []int{}},
OpGetLocal: {"OpGetLocal", []int{1}},
OpSetLocal: {"OpSetLocal", []int{1}},
} }
func Lookup(op byte) (*Definition, error) { func Lookup(op byte) (*Definition, error) {
@@ -98,6 +102,8 @@ func Make(op Opcode, operands ...int) []byte {
switch width { switch width {
case 2: case 2:
binary.BigEndian.PutUint16(instruction[offset:], uint16(o)) binary.BigEndian.PutUint16(instruction[offset:], uint16(o))
case 1:
instruction[offset] = byte(o)
} }
offset += width offset += width
} }
@@ -151,6 +157,9 @@ func ReadOperands(def *Definition, ins Instructions) ([]int, int) {
switch width { switch width {
case 2: case 2:
operands[i] = int(ReadUint16(ins[offset:])) operands[i] = int(ReadUint16(ins[offset:]))
case 1:
operands[i] = int(ReadUint8(ins[offset:]))
} }
offset += width offset += width
@@ -162,3 +171,7 @@ func ReadOperands(def *Definition, ins Instructions) ([]int, int) {
func ReadUint16(ins Instructions) uint16 { func ReadUint16(ins Instructions) uint16 {
return binary.BigEndian.Uint16(ins) return binary.BigEndian.Uint16(ins)
} }
func ReadUint8(ins Instructions) uint8 {
return ins[0]
}

View File

@@ -10,6 +10,7 @@ func TestMake(t *testing.T) {
}{ }{
{OpConstant, []int{65534}, []byte{byte(OpConstant), 255, 254}}, {OpConstant, []int{65534}, []byte{byte(OpConstant), 255, 254}},
{OpAdd, []int{}, []byte{byte(OpAdd)}}, {OpAdd, []int{}, []byte{byte(OpAdd)}},
{OpGetLocal, []int{255}, []byte{byte(OpGetLocal), 255}},
} }
for _, tt := range test { for _, tt := range test {
@@ -30,13 +31,15 @@ func TestMake(t *testing.T) {
func TestInstructions(t *testing.T) { func TestInstructions(t *testing.T) {
instructions := []Instructions{ instructions := []Instructions{
Make(OpAdd), Make(OpAdd),
Make(OpGetLocal, 1),
Make(OpConstant, 2), Make(OpConstant, 2),
Make(OpConstant, 65535), Make(OpConstant, 65535),
} }
expected := `0000 OpAdd expected := `0000 OpAdd
0001 OpConstant 2 0001 OpGetLocal 1
0004 OpConstant 65535 0003 OpConstant 2
0006 OpConstant 65535
` `
concatted := Instructions{} concatted := Instructions{}
@@ -49,13 +52,14 @@ func TestInstructions(t *testing.T) {
} }
} }
func TestOperands(t *testing.T) { func TestReadOperands(t *testing.T) {
tests := []struct { tests := []struct {
op Opcode op Opcode
operands []int operands []int
bytesRead int bytesRead int
}{ }{
{OpConstant, []int{65535}, 2}, {OpConstant, []int{65535}, 2},
{OpGetLocal, []int{255}, 1},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@@ -193,8 +193,13 @@ func (c *Compiler) Compile(node ast.Node) error {
if err != nil { if err != nil {
return err return err
} }
symbol := c.symbolTable.Define(node.Name.Value) symbol := c.symbolTable.Define(node.Name.Value)
if symbol.Scope == GlobalScope {
c.emit(code.OpSetGlobal, symbol.Index) c.emit(code.OpSetGlobal, symbol.Index)
} else {
c.emit(code.OpSetLocal, symbol.Index)
}
case *ast.Identifier: case *ast.Identifier:
symbol, ok := c.symbolTable.Resolve(node.Value) symbol, ok := c.symbolTable.Resolve(node.Value)
@@ -202,7 +207,11 @@ func (c *Compiler) Compile(node ast.Node) error {
return fmt.Errorf("undefined varible %s", node.Value) return fmt.Errorf("undefined varible %s", node.Value)
} }
if symbol.Scope == GlobalScope {
c.emit(code.OpGetGlobal, symbol.Index) c.emit(code.OpGetGlobal, symbol.Index)
} else {
c.emit(code.OpGetLocal, symbol.Index)
}
case *ast.StringLiteral: case *ast.StringLiteral:
str := &object.String{Value: node.Value} str := &object.String{Value: node.Value}
@@ -268,9 +277,10 @@ func (c *Compiler) Compile(node ast.Node) error {
c.emit(code.OpReturn) c.emit(code.OpReturn)
} }
numLocals := c.symbolTable.numDefinitions
instructions := c.leaveScope() instructions := c.leaveScope()
compiledFn := &object.CompiledFunction{Instructions: instructions} compiledFn := &object.CompiledFunction{Instructions: instructions, NumLocals: numLocals}
c.emit(code.OpConstant, c.addConstant(compiledFn)) c.emit(code.OpConstant, c.addConstant(compiledFn))
case *ast.ReturnStatement: case *ast.ReturnStatement:
@@ -374,6 +384,8 @@ func (c *Compiler) enterScope() {
} }
c.scopes = append(c.scopes, scope) c.scopes = append(c.scopes, scope)
c.scopeIndex++ c.scopeIndex++
c.symbolTable = NewEnclosedSymbolTable(c.symbolTable)
} }
func (c *Compiler) leaveScope() code.Instructions { func (c *Compiler) leaveScope() code.Instructions {
@@ -382,6 +394,8 @@ func (c *Compiler) leaveScope() code.Instructions {
c.scopes = c.scopes[:len(c.scopes)-1] c.scopes = c.scopes[:len(c.scopes)-1]
c.scopeIndex-- c.scopeIndex--
c.symbolTable = c.symbolTable.Outer
return instructions return instructions
} }

View File

@@ -591,6 +591,80 @@ func TestFunctionCalls(t *testing.T) {
runCompilerTests(t, tests) runCompilerTests(t, tests)
} }
func TestLetStatementScopes(t *testing.T) {
tests := []compilerTestCase{
{
input: `
let num = 55;
fn() { num }
`,
expectedConstants: []interface{}{
55,
[]code.Instructions{
code.Make(code.OpGetGlobal, 0),
code.Make(code.OpReturnValue),
},
},
expectedInstructions: []code.Instructions{
code.Make(code.OpConstant, 0),
code.Make(code.OpSetGlobal, 0),
code.Make(code.OpConstant, 1),
code.Make(code.OpPop),
},
},
{
input: `
fn() {
let num = 55;
num
}
`,
expectedConstants: []interface{}{
55,
[]code.Instructions{
code.Make(code.OpConstant, 0),
code.Make(code.OpSetLocal, 0),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpReturnValue),
},
},
expectedInstructions: []code.Instructions{
code.Make(code.OpConstant, 1),
code.Make(code.OpPop),
},
},
{
input: `
fn() {
let a = 55;
let b = 77;
a + b
}
`,
expectedConstants: []interface{}{
55,
77,
[]code.Instructions{
code.Make(code.OpConstant, 0),
code.Make(code.OpSetLocal, 0),
code.Make(code.OpConstant, 1),
code.Make(code.OpSetLocal, 1),
code.Make(code.OpGetLocal, 0),
code.Make(code.OpGetLocal, 1),
code.Make(code.OpAdd),
code.Make(code.OpReturnValue),
},
},
expectedInstructions: []code.Instructions{
code.Make(code.OpConstant, 2),
code.Make(code.OpPop),
},
},
}
runCompilerTests(t, tests)
}
func runCompilerTests(t *testing.T, tests []compilerTestCase) { func runCompilerTests(t *testing.T, tests []compilerTestCase) {
t.Helper() t.Helper()

View File

@@ -3,7 +3,8 @@ package compiler
type SymbolScope string type SymbolScope string
const ( const (
GlobalScope SymbolScope = "Global" LocalScope SymbolScope = "LOCAL"
GlobalScope SymbolScope = "GLOBAL"
) )
type Symbol struct { type Symbol struct {
@@ -13,21 +14,31 @@ type Symbol struct {
} }
type SymbolTable struct { type SymbolTable struct {
Outer *SymbolTable
store map[string]Symbol store map[string]Symbol
numDefinitions int numDefinitions int
} }
func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
s := NewSymbolTable()
s.Outer = outer
return s
}
func NewSymbolTable() *SymbolTable { func NewSymbolTable() *SymbolTable {
s := make(map[string]Symbol) s := make(map[string]Symbol)
return &SymbolTable{store: s} return &SymbolTable{store: s}
} }
func (s *SymbolTable) Define(name string) Symbol { func (s *SymbolTable) Define(name string) Symbol {
symbol := Symbol{ symbol := Symbol{Name: name, Index: s.numDefinitions}
Name: name, if s.Outer == nil {
Scope: GlobalScope, symbol.Scope = GlobalScope
Index: s.numDefinitions, } else {
symbol.Scope = LocalScope
} }
s.store[name] = symbol s.store[name] = symbol
s.numDefinitions++ s.numDefinitions++
return symbol return symbol
@@ -35,5 +46,10 @@ func (s *SymbolTable) Define(name string) Symbol {
func (s *SymbolTable) Resolve(name string) (Symbol, bool) { func (s *SymbolTable) Resolve(name string) (Symbol, bool) {
obj, ok := s.store[name] obj, ok := s.store[name]
if !ok && s.Outer != nil {
obj, ok = s.Outer.Resolve(name)
return obj, ok
}
return obj, ok return obj, ok
} }

View File

@@ -14,6 +14,10 @@ func TestDefine(t *testing.T) {
Scope: GlobalScope, Scope: GlobalScope,
Index: 1, Index: 1,
}, },
"c": {Name: "c", Scope: LocalScope, Index: 0},
"d": {Name: "d", Scope: LocalScope, Index: 1},
"e": {Name: "e", Scope: LocalScope, Index: 0},
"f": {Name: "f", Scope: LocalScope, Index: 1},
} }
global := NewSymbolTable() global := NewSymbolTable()
@@ -27,6 +31,30 @@ func TestDefine(t *testing.T) {
if b != expected["b"] { if b != expected["b"] {
t.Errorf("expected b=%+v, got=%+v", expected["b"], b) t.Errorf("expected b=%+v, got=%+v", expected["b"], b)
} }
firstLocal := NewEnclosedSymbolTable(global)
c := firstLocal.Define("c")
if c != expected["c"] {
t.Errorf("expected c=%+v, got=%+v", expected["c"], c)
}
d := firstLocal.Define("d")
if d != expected["d"] {
t.Errorf("expected d=%+v, got=%+v", expected["d"], d)
}
secondLocal := NewEnclosedSymbolTable(firstLocal)
e := secondLocal.Define("e")
if e != expected["e"] {
t.Errorf("expected e=%+v, got=%+v", expected["e"], e)
}
f := secondLocal.Define("f")
if f != expected["f"] {
t.Errorf("expected f=%+v, got=%+v", expected["f"], f)
}
} }
func TestResolveGlobal(t *testing.T) { func TestResolveGlobal(t *testing.T) {
@@ -57,3 +85,84 @@ func TestResolveGlobal(t *testing.T) {
} }
} }
} }
func TestResolveLocal(t *testing.T) {
global := NewSymbolTable()
global.Define("a")
global.Define("b")
local := NewEnclosedSymbolTable(global)
local.Define("c")
local.Define("d")
expected := []Symbol{
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
Symbol{Name: "c", Scope: LocalScope, Index: 0},
Symbol{Name: "d", Scope: LocalScope, Index: 1},
}
for _, sym := range expected {
result, ok := local.Resolve(sym.Name)
if !ok {
t.Errorf("name %s not resolvable", sym.Name)
continue
}
if result != sym {
t.Errorf("expected %s to resolve to %+v, got=%+v",
sym.Name, sym, result)
}
}
}
func TestResolveNestedLocal(t *testing.T) {
global := NewSymbolTable()
global.Define("a")
global.Define("b")
firstLocal := NewEnclosedSymbolTable(global)
firstLocal.Define("c")
firstLocal.Define("d")
secondLocal := NewEnclosedSymbolTable(global)
secondLocal.Define("e")
secondLocal.Define("f")
tests := []struct {
table *SymbolTable
expectedSymbols []Symbol
}{
{
firstLocal,
[]Symbol{
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
Symbol{Name: "c", Scope: LocalScope, Index: 0},
Symbol{Name: "d", Scope: LocalScope, Index: 1},
},
},
{
secondLocal,
[]Symbol{
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
Symbol{Name: "e", Scope: LocalScope, Index: 0},
Symbol{Name: "f", Scope: LocalScope, Index: 1},
},
},
}
for _, tt := range tests {
for _, sym := range tt.expectedSymbols {
result, ok := tt.table.Resolve(sym.Name)
if !ok {
t.Errorf("name %s not resolvable", sym.Name)
continue
}
if result != sym {
t.Errorf("expected %s to resolve to %+v, got=%+v",
sym.Name, sym, result)
}
}
}
}

View File

@@ -36,6 +36,7 @@ type Integer struct {
type CompiledFunction struct { type CompiledFunction struct {
Instructions code.Instructions Instructions code.Instructions
NumLocals int
} }
func (i *Integer) Type() ObjectType { func (i *Integer) Type() ObjectType {

20
vm/frame.go Normal file
View File

@@ -0,0 +1,20 @@
package vm
import (
"monkey/code"
"monkey/object"
)
type Frame struct {
fn *object.CompiledFunction
ip int
basePointer int
}
func NewFrame(fn *object.CompiledFunction, basePointer int) *Frame {
return &Frame{fn: fn, ip: -1, basePointer: basePointer}
}
func (f *Frame) Instructions() code.Instructions {
return f.fn.Instructions
}

View File

@@ -15,19 +15,6 @@ var Null = &object.Null{}
var True = &object.Boolean{Value: true} var True = &object.Boolean{Value: true}
var False = &object.Boolean{Value: false} var False = &object.Boolean{Value: false}
type Frame struct {
fn *object.CompiledFunction
ip int
}
func NewFrame(fn *object.CompiledFunction) *Frame {
return &Frame{fn: fn, ip: -1}
}
func (f *Frame) Instructions() code.Instructions {
return f.fn.Instructions
}
type VM struct { type VM struct {
constants []object.Object constants []object.Object
@@ -42,7 +29,7 @@ type VM struct {
func New(bytecode *compiler.Bytecode) *VM { func New(bytecode *compiler.Bytecode) *VM {
mainFn := &object.CompiledFunction{Instructions: bytecode.Instructions} mainFn := &object.CompiledFunction{Instructions: bytecode.Instructions}
mainFrame := NewFrame(mainFn) mainFrame := NewFrame(mainFn, 0)
frames := make([]*Frame, MaxFrames) frames := make([]*Frame, MaxFrames)
frames[0] = mainFrame frames[0] = mainFrame
@@ -220,13 +207,15 @@ func (vm *VM) Run() error {
if !ok { if !ok {
return fmt.Errorf("calling non-function") return fmt.Errorf("calling non-function")
} }
frame := NewFrame(fn) frame := NewFrame(fn, vm.sp)
vm.pushFrame(frame) vm.pushFrame(frame)
vm.sp = frame.basePointer + fn.NumLocals
case code.OpReturnValue: case code.OpReturnValue:
returnValue := vm.pop() returnValue := vm.pop()
vm.popFrame()
vm.pop() frame := vm.popFrame()
vm.sp = frame.basePointer - 1
err := vm.push(returnValue) err := vm.push(returnValue)
if err != nil { if err != nil {
@@ -234,14 +223,32 @@ func (vm *VM) Run() error {
} }
case code.OpReturn: case code.OpReturn:
vm.popFrame() frame := vm.popFrame()
vm.pop() vm.sp = frame.basePointer - 1
err := vm.push(Null) err := vm.push(Null)
if err != nil { if err != nil {
return err return err
} }
case code.OpSetLocal:
localIndex := code.ReadUint8(ins[ip+1:])
vm.currentFrame().ip += 1
frame := vm.currentFrame()
vm.stack[frame.basePointer+int(localIndex)] = vm.pop()
case code.OpGetLocal:
localIndex := code.ReadUint8(ins[ip+1:])
vm.currentFrame().ip += 1
frame := vm.currentFrame()
err := vm.push(vm.stack[frame.basePointer+int(localIndex)])
if err != nil {
return err
}
} }
} }

View File

@@ -379,6 +379,16 @@ func TestFirstClassFunctions(t *testing.T) {
let returnsOne = fn() { 1; }; let returnsOne = fn() { 1; };
let returnsOneReturner = fn() { returnsOne; }; let returnsOneReturner = fn() { returnsOne; };
returnsOneReturner()(); returnsOneReturner()();
`,
expected: 1,
},
{
input: `
let returnsOneReturner = fn() {
let returnsOne = fn() { 1; };
returnsOne;
};
returnsOneReturner()();
`, `,
expected: 1, expected: 1,
}, },
@@ -386,3 +396,55 @@ func TestFirstClassFunctions(t *testing.T) {
runVmTests(t, tests) runVmTests(t, tests)
} }
func TestCallingFunctionsWithBindings(t *testing.T) {
tests := []vmTestCase{
{
input: `
let one = fn() { let one = 1; one };
one();
`,
expected: 1,
},
{
input: `
let oneAndTwo = fn() { let one = 1; let two = 2; one + two; };
oneAndTwo();
`,
expected: 3,
},
{
input: `
let oneAndTwo = fn() { let one = 1; let two = 2; one + two; };
let threeAndFour = fn() { let three = 3; let four = 4; three + four; };
oneAndTwo() + threeAndFour();
`,
expected: 10,
},
{
input: `
let firstFoobar = fn() { let foobar = 50; foobar; };
let secondFoobar = fn() { let foobar = 100; foobar; };
firstFoobar() + secondFoobar();
`,
expected: 150,
},
{
input: `
let globalSeed = 50;
let minusOne = fn() {
let num = 1;
globalSeed - num;
}
let minusTwo = fn() {
let num = 2;
globalSeed - num;
}
minusOne() + minusTwo();
`,
expected: 97,
},
}
runVmTests(t, tests)
}