restructure project
This commit is contained in:
651
internal/compiler/compiler.go
Normal file
651
internal/compiler/compiler.go
Normal file
@@ -0,0 +1,651 @@
|
||||
package compiler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"monkey/internal/ast"
|
||||
"monkey/internal/builtins"
|
||||
"monkey/internal/code"
|
||||
"monkey/internal/object"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EmittedInstruction struct {
|
||||
Opcode code.Opcode
|
||||
Position int
|
||||
}
|
||||
|
||||
type CompilationScope struct {
|
||||
instructions code.Instructions
|
||||
lastInstruction EmittedInstruction
|
||||
previousInstruction EmittedInstruction
|
||||
}
|
||||
|
||||
type Compiler struct {
|
||||
Debug bool
|
||||
|
||||
l int
|
||||
constants []object.Object
|
||||
|
||||
symbolTable *SymbolTable
|
||||
|
||||
scopes []CompilationScope
|
||||
scopeIndex int
|
||||
}
|
||||
|
||||
func New() *Compiler {
|
||||
mainScope := CompilationScope{
|
||||
instructions: code.Instructions{},
|
||||
lastInstruction: EmittedInstruction{},
|
||||
previousInstruction: EmittedInstruction{},
|
||||
}
|
||||
|
||||
symbolTable := NewSymbolTable()
|
||||
|
||||
for i, builtin := range builtins.BuiltinsIndex {
|
||||
symbolTable.DefineBuiltin(i, builtin.Name)
|
||||
}
|
||||
|
||||
return &Compiler{
|
||||
constants: []object.Object{},
|
||||
symbolTable: symbolTable,
|
||||
scopes: []CompilationScope{mainScope},
|
||||
scopeIndex: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Compiler) currentInstructions() code.Instructions {
|
||||
return c.scopes[c.scopeIndex].instructions
|
||||
}
|
||||
|
||||
func NewWithState(s *SymbolTable, constants []object.Object) *Compiler {
|
||||
compiler := New()
|
||||
compiler.symbolTable = s
|
||||
compiler.constants = constants
|
||||
return compiler
|
||||
}
|
||||
|
||||
func (c *Compiler) Compile(node ast.Node) error {
|
||||
if c.Debug {
|
||||
log.Printf(
|
||||
"%sCompiling %T: %s\n",
|
||||
strings.Repeat(" ", c.l), node, node.String(),
|
||||
)
|
||||
}
|
||||
|
||||
switch node := node.(type) {
|
||||
case *ast.Program:
|
||||
c.l++
|
||||
for _, s := range node.Statements {
|
||||
err := c.Compile(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
case *ast.ExpressionStatement:
|
||||
c.l++
|
||||
err := c.Compile(node.Expression)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.emit(code.OpPop)
|
||||
|
||||
case *ast.InfixExpression:
|
||||
if node.Operator == "<" || node.Operator == "<=" {
|
||||
c.l++
|
||||
err := c.Compile(node.Right)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.l++
|
||||
err = c.Compile(node.Left)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if node.Operator == "<=" {
|
||||
c.emit(code.OpGreaterThanEqual)
|
||||
} else {
|
||||
c.emit(code.OpGreaterThan)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
c.l++
|
||||
err := c.Compile(node.Left)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.l++
|
||||
err = c.Compile(node.Right)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch node.Operator {
|
||||
case "+":
|
||||
c.emit(code.OpAdd)
|
||||
case "-":
|
||||
c.emit(code.OpSub)
|
||||
case "*":
|
||||
c.emit(code.OpMul)
|
||||
case "/":
|
||||
c.emit(code.OpDiv)
|
||||
case "%":
|
||||
c.emit(code.OpMod)
|
||||
case "|":
|
||||
c.emit(code.OpBitwiseOR)
|
||||
case "^":
|
||||
c.emit(code.OpBitwiseXOR)
|
||||
case "&":
|
||||
c.emit(code.OpBitwiseAND)
|
||||
case "<<":
|
||||
c.emit(code.OpLeftShift)
|
||||
case ">>":
|
||||
c.emit(code.OpRightShift)
|
||||
case "||":
|
||||
c.emit(code.OpOr)
|
||||
case "&&":
|
||||
c.emit(code.OpAnd)
|
||||
case ">":
|
||||
c.emit(code.OpGreaterThan)
|
||||
case ">=":
|
||||
c.emit(code.OpGreaterThanEqual)
|
||||
case "==":
|
||||
c.emit(code.OpEqual)
|
||||
case "!=":
|
||||
c.emit(code.OpNotEqual)
|
||||
default:
|
||||
return fmt.Errorf("unknown operator %s", node.Operator)
|
||||
}
|
||||
|
||||
case *ast.IntegerLiteral:
|
||||
integer := &object.Integer{Value: node.Value}
|
||||
c.emit(code.OpConstant, c.addConstant(integer))
|
||||
|
||||
case *ast.Null:
|
||||
c.emit(code.OpNull)
|
||||
|
||||
case *ast.Boolean:
|
||||
if node.Value {
|
||||
c.emit(code.OpTrue)
|
||||
} else {
|
||||
c.emit(code.OpFalse)
|
||||
}
|
||||
|
||||
case *ast.PrefixExpression:
|
||||
c.l++
|
||||
err := c.Compile(node.Right)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch node.Operator {
|
||||
case "!":
|
||||
c.emit(code.OpNot)
|
||||
case "~":
|
||||
c.emit(code.OpBitwiseNOT)
|
||||
case "-":
|
||||
c.emit(code.OpMinus)
|
||||
default:
|
||||
return fmt.Errorf("unknown operator %s", node.Operator)
|
||||
}
|
||||
|
||||
case *ast.IfExpression:
|
||||
c.l++
|
||||
err := c.Compile(node.Condition)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Emit an `OpJumpNotTruthy` with a bogus value
|
||||
jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999)
|
||||
|
||||
c.l++
|
||||
err = c.Compile(node.Consequence)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.lastInstructionIs(code.OpPop) {
|
||||
c.removeLastPop()
|
||||
}
|
||||
|
||||
// Emit an `OnJump` with a bogus value
|
||||
jumpPos := c.emit(code.OpJump, 9999)
|
||||
|
||||
afterConsequencePos := len(c.currentInstructions())
|
||||
c.changeOperand(jumpNotTruthyPos, afterConsequencePos)
|
||||
|
||||
if node.Alternative == nil {
|
||||
c.emit(code.OpNull)
|
||||
} else {
|
||||
c.l++
|
||||
err := c.Compile(node.Alternative)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.lastInstructionIs(code.OpPop) {
|
||||
c.removeLastPop()
|
||||
}
|
||||
}
|
||||
|
||||
afterAlternativePos := len(c.currentInstructions())
|
||||
c.changeOperand(jumpPos, afterAlternativePos)
|
||||
|
||||
case *ast.BlockStatement:
|
||||
c.l++
|
||||
for _, s := range node.Statements {
|
||||
err := c.Compile(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.l--
|
||||
|
||||
if c.lastInstructionIs(code.OpPop) {
|
||||
c.removeLastPop()
|
||||
} else {
|
||||
if !c.lastInstructionIs(code.OpReturn) {
|
||||
c.emit(code.OpNull)
|
||||
}
|
||||
}
|
||||
|
||||
case *ast.AssignmentExpression:
|
||||
if ident, ok := node.Left.(*ast.Identifier); ok {
|
||||
symbol, ok := c.symbolTable.Resolve(ident.Value)
|
||||
if !ok {
|
||||
return fmt.Errorf("undefined variable %s", ident.Value)
|
||||
}
|
||||
|
||||
c.l++
|
||||
err := c.Compile(node.Value)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if symbol.Scope == GlobalScope {
|
||||
c.emit(code.OpAssignGlobal, symbol.Index)
|
||||
} else {
|
||||
c.emit(code.OpAssignLocal, symbol.Index)
|
||||
}
|
||||
} else if ie, ok := node.Left.(*ast.IndexExpression); ok {
|
||||
c.l++
|
||||
err := c.Compile(ie.Left)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.l++
|
||||
err = c.Compile(ie.Index)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.l++
|
||||
err = c.Compile(node.Value)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.emit(code.OpSetItem)
|
||||
} else {
|
||||
return fmt.Errorf("expected identifier or index expression got=%s", node.Left)
|
||||
}
|
||||
|
||||
case *ast.BindExpression:
|
||||
var symbol Symbol
|
||||
|
||||
if ident, ok := node.Left.(*ast.Identifier); ok {
|
||||
symbol, ok = c.symbolTable.Resolve(ident.Value)
|
||||
if !ok {
|
||||
symbol = c.symbolTable.Define(ident.Value)
|
||||
} else {
|
||||
// Local shadowing of previously defined "free" variable in a
|
||||
// function now being rebound to a locally scoped variable.
|
||||
if symbol.Scope == FreeScope {
|
||||
symbol = c.symbolTable.Define(ident.Value)
|
||||
}
|
||||
}
|
||||
|
||||
c.l++
|
||||
err := c.Compile(node.Value)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if symbol.Scope == GlobalScope {
|
||||
c.emit(code.OpSetGlobal, symbol.Index)
|
||||
} else {
|
||||
c.emit(code.OpSetLocal, symbol.Index)
|
||||
}
|
||||
|
||||
} else {
|
||||
return fmt.Errorf("expected identifier got=%s", node.Left)
|
||||
}
|
||||
|
||||
case *ast.Identifier:
|
||||
symbol, ok := c.symbolTable.Resolve(node.Value)
|
||||
if !ok {
|
||||
return fmt.Errorf("undefined varible %s", node.Value)
|
||||
}
|
||||
|
||||
c.loadSymbol(symbol)
|
||||
|
||||
case *ast.StringLiteral:
|
||||
str := &object.String{Value: node.Value}
|
||||
c.emit(code.OpConstant, c.addConstant(str))
|
||||
|
||||
case *ast.ArrayLiteral:
|
||||
for _, el := range node.Elements {
|
||||
c.l++
|
||||
err := c.Compile(el)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.emit(code.OpArray, len(node.Elements))
|
||||
|
||||
case *ast.HashLiteral:
|
||||
keys := []ast.Expression{}
|
||||
for k := range node.Pairs {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return keys[i].String() < keys[j].String()
|
||||
})
|
||||
|
||||
for _, k := range keys {
|
||||
c.l++
|
||||
err := c.Compile(k)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.l++
|
||||
err = c.Compile(node.Pairs[k])
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.emit(code.OpHash, len(node.Pairs)*2)
|
||||
|
||||
case *ast.IndexExpression:
|
||||
c.l++
|
||||
err := c.Compile(node.Left)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.l++
|
||||
err = c.Compile(node.Index)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.emit(code.OpGetItem)
|
||||
|
||||
case *ast.FunctionLiteral:
|
||||
c.enterScope()
|
||||
|
||||
if node.Name != "" {
|
||||
c.symbolTable.DefineFunctionName(node.Name)
|
||||
}
|
||||
|
||||
for _, p := range node.Parameters {
|
||||
c.symbolTable.Define(p.Value)
|
||||
}
|
||||
|
||||
c.l++
|
||||
err := c.Compile(node.Body)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.lastInstructionIs(code.OpPop) {
|
||||
c.replaceLastPopWithReturn()
|
||||
}
|
||||
|
||||
// If the function doesn't end with a return statement add one with a
|
||||
// `return null;` and also handle the edge-case of empty functions.
|
||||
if !c.lastInstructionIs(code.OpReturn) {
|
||||
if !c.lastInstructionIs(code.OpNull) {
|
||||
c.emit(code.OpNull)
|
||||
}
|
||||
c.emit(code.OpReturn)
|
||||
}
|
||||
|
||||
freeSymbols := c.symbolTable.FreeSymbols
|
||||
numLocals := c.symbolTable.numDefinitions
|
||||
instructions := c.leaveScope()
|
||||
|
||||
for _, s := range freeSymbols {
|
||||
c.loadSymbol(s)
|
||||
}
|
||||
|
||||
compiledFn := &object.CompiledFunction{
|
||||
Instructions: instructions,
|
||||
NumLocals: numLocals,
|
||||
NumParameters: len(node.Parameters),
|
||||
}
|
||||
|
||||
fnIndex := c.addConstant(compiledFn)
|
||||
c.emit(code.OpClosure, fnIndex, len(freeSymbols))
|
||||
|
||||
case *ast.ReturnStatement:
|
||||
c.l++
|
||||
err := c.Compile(node.ReturnValue)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.emit(code.OpReturn)
|
||||
|
||||
case *ast.CallExpression:
|
||||
c.l++
|
||||
err := c.Compile(node.Function)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, a := range node.Arguments {
|
||||
err := c.Compile(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.emit(code.OpCall, len(node.Arguments))
|
||||
|
||||
case *ast.WhileExpression:
|
||||
jumpConditionPos := len(c.currentInstructions())
|
||||
|
||||
c.l++
|
||||
err := c.Compile(node.Condition)
|
||||
c.l--
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Emit an `OpJump`with a bogus value
|
||||
jumpIfFalsePos := c.emit(code.OpJumpNotTruthy, 0xFFFF)
|
||||
|
||||
c.l++
|
||||
err = c.Compile(node.Consequence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.l--
|
||||
|
||||
// Pop off the LoadNull(s) from ast.BlockStatement(s)
|
||||
c.emit(code.OpPop)
|
||||
|
||||
c.emit(code.OpJump, jumpConditionPos)
|
||||
|
||||
afterConsequencePos := c.emit(code.OpNull)
|
||||
c.changeOperand(jumpIfFalsePos, afterConsequencePos)
|
||||
|
||||
case *ast.ImportExpression:
|
||||
c.l++
|
||||
err := c.Compile(node.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.l--
|
||||
|
||||
c.emit(code.OpLoadModule)
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Compiler) addConstant(obj object.Object) int {
|
||||
c.constants = append(c.constants, obj)
|
||||
return len(c.constants) - 1
|
||||
}
|
||||
|
||||
func (c *Compiler) emit(op code.Opcode, operands ...int) int {
|
||||
ins := code.Make(op, operands...)
|
||||
pos := c.addInstruction(ins)
|
||||
|
||||
c.setLastInstruction(op, pos)
|
||||
|
||||
return pos
|
||||
}
|
||||
|
||||
func (c *Compiler) setLastInstruction(op code.Opcode, pos int) {
|
||||
previous := c.scopes[c.scopeIndex].lastInstruction
|
||||
last := EmittedInstruction{Opcode: op, Position: pos}
|
||||
|
||||
c.scopes[c.scopeIndex].previousInstruction = previous
|
||||
c.scopes[c.scopeIndex].lastInstruction = last
|
||||
}
|
||||
|
||||
func (c *Compiler) Bytecode() *Bytecode {
|
||||
return &Bytecode{
|
||||
Instructions: c.currentInstructions(),
|
||||
Constants: c.constants,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Compiler) addInstruction(ins []byte) int {
|
||||
postNewInstruction := len(c.currentInstructions())
|
||||
updatedInstructions := append(c.currentInstructions(), ins...)
|
||||
|
||||
c.scopes[c.scopeIndex].instructions = updatedInstructions
|
||||
|
||||
return postNewInstruction
|
||||
}
|
||||
|
||||
func (c *Compiler) lastInstructionIs(op code.Opcode) bool {
|
||||
if len(c.currentInstructions()) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return c.scopes[c.scopeIndex].lastInstruction.Opcode == op
|
||||
}
|
||||
|
||||
func (c *Compiler) removeLastPop() {
|
||||
last := c.scopes[c.scopeIndex].lastInstruction
|
||||
previous := c.scopes[c.scopeIndex].previousInstruction
|
||||
|
||||
old := c.currentInstructions()
|
||||
new := old[:last.Position]
|
||||
|
||||
c.scopes[c.scopeIndex].instructions = new
|
||||
c.scopes[c.scopeIndex].lastInstruction = previous
|
||||
}
|
||||
|
||||
func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) {
|
||||
ins := c.currentInstructions()
|
||||
|
||||
for i := 0; i < len(newInstruction); i++ {
|
||||
ins[pos+i] = newInstruction[i]
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Compiler) changeOperand(opPos int, operand int) {
|
||||
op := code.Opcode(c.currentInstructions()[opPos])
|
||||
newInstruction := code.Make(op, operand)
|
||||
|
||||
c.replaceInstruction(opPos, newInstruction)
|
||||
}
|
||||
|
||||
func (c *Compiler) enterScope() {
|
||||
scope := CompilationScope{
|
||||
instructions: code.Instructions{},
|
||||
lastInstruction: EmittedInstruction{},
|
||||
previousInstruction: EmittedInstruction{},
|
||||
}
|
||||
c.scopes = append(c.scopes, scope)
|
||||
c.scopeIndex++
|
||||
|
||||
c.symbolTable = NewEnclosedSymbolTable(c.symbolTable)
|
||||
}
|
||||
|
||||
func (c *Compiler) leaveScope() code.Instructions {
|
||||
instructions := c.currentInstructions()
|
||||
|
||||
c.scopes = c.scopes[:len(c.scopes)-1]
|
||||
c.scopeIndex--
|
||||
|
||||
c.symbolTable = c.symbolTable.Outer
|
||||
|
||||
return instructions
|
||||
}
|
||||
|
||||
func (c *Compiler) replaceLastPopWithReturn() {
|
||||
lastPos := c.scopes[c.scopeIndex].lastInstruction.Position
|
||||
c.replaceInstruction(lastPos, code.Make(code.OpReturn))
|
||||
|
||||
c.scopes[c.scopeIndex].lastInstruction.Opcode = code.OpReturn
|
||||
}
|
||||
|
||||
type Bytecode struct {
|
||||
Instructions code.Instructions
|
||||
Constants []object.Object
|
||||
}
|
||||
|
||||
func (c *Compiler) loadSymbol(s Symbol) {
|
||||
switch s.Scope {
|
||||
case GlobalScope:
|
||||
c.emit(code.OpGetGlobal, s.Index)
|
||||
case LocalScope:
|
||||
c.emit(code.OpGetLocal, s.Index)
|
||||
case BuiltinScope:
|
||||
c.emit(code.OpGetBuiltin, s.Index)
|
||||
case FreeScope:
|
||||
c.emit(code.OpGetFree, s.Index)
|
||||
case FunctionScope:
|
||||
c.emit(code.OpCurrentClosure)
|
||||
}
|
||||
}
|
||||
1204
internal/compiler/compiler_test.go
Normal file
1204
internal/compiler/compiler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
93
internal/compiler/symbol_table.go
Normal file
93
internal/compiler/symbol_table.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package compiler
|
||||
|
||||
type SymbolScope string
|
||||
|
||||
const (
|
||||
LocalScope SymbolScope = "LOCAL"
|
||||
GlobalScope SymbolScope = "GLOBAL"
|
||||
BuiltinScope SymbolScope = "BUILTIN"
|
||||
FreeScope SymbolScope = "FREE"
|
||||
FunctionScope SymbolScope = "FUNCTION"
|
||||
)
|
||||
|
||||
type Symbol struct {
|
||||
Name string
|
||||
Scope SymbolScope
|
||||
Index int
|
||||
}
|
||||
|
||||
type SymbolTable struct {
|
||||
Outer *SymbolTable
|
||||
|
||||
Store map[string]Symbol
|
||||
numDefinitions int
|
||||
|
||||
FreeSymbols []Symbol
|
||||
}
|
||||
|
||||
func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
|
||||
s := NewSymbolTable()
|
||||
s.Outer = outer
|
||||
return s
|
||||
}
|
||||
|
||||
func NewSymbolTable() *SymbolTable {
|
||||
s := make(map[string]Symbol)
|
||||
free := []Symbol{}
|
||||
return &SymbolTable{Store: s, FreeSymbols: free}
|
||||
}
|
||||
|
||||
func (s *SymbolTable) Define(name string) Symbol {
|
||||
symbol := Symbol{Name: name, Index: s.numDefinitions}
|
||||
if s.Outer == nil {
|
||||
symbol.Scope = GlobalScope
|
||||
} else {
|
||||
symbol.Scope = LocalScope
|
||||
}
|
||||
|
||||
s.Store[name] = symbol
|
||||
s.numDefinitions++
|
||||
return symbol
|
||||
}
|
||||
|
||||
func (s *SymbolTable) Resolve(name string) (Symbol, bool) {
|
||||
obj, ok := s.Store[name]
|
||||
if !ok && s.Outer != nil {
|
||||
obj, ok = s.Outer.Resolve(name)
|
||||
if !ok {
|
||||
return obj, ok
|
||||
}
|
||||
|
||||
if obj.Scope == GlobalScope || obj.Scope == BuiltinScope {
|
||||
return obj, ok
|
||||
}
|
||||
|
||||
free := s.DefineFree(obj)
|
||||
return free, true
|
||||
|
||||
}
|
||||
|
||||
return obj, ok
|
||||
}
|
||||
|
||||
func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol {
|
||||
symbol := Symbol{Name: name, Index: index, Scope: BuiltinScope}
|
||||
s.Store[name] = symbol
|
||||
return symbol
|
||||
}
|
||||
|
||||
func (s *SymbolTable) DefineFree(original Symbol) Symbol {
|
||||
s.FreeSymbols = append(s.FreeSymbols, original)
|
||||
|
||||
symbol := Symbol{Name: original.Name, Index: len(s.FreeSymbols) - 1}
|
||||
symbol.Scope = FreeScope
|
||||
|
||||
s.Store[original.Name] = symbol
|
||||
return symbol
|
||||
}
|
||||
|
||||
func (s *SymbolTable) DefineFunctionName(name string) Symbol {
|
||||
symbol := Symbol{Name: name, Index: 0, Scope: FunctionScope}
|
||||
s.Store[name] = symbol
|
||||
return symbol
|
||||
}
|
||||
350
internal/compiler/symbol_table_test.go
Normal file
350
internal/compiler/symbol_table_test.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package compiler
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefine(t *testing.T) {
|
||||
expected := map[string]Symbol{
|
||||
"a": {
|
||||
Name: "a",
|
||||
Scope: GlobalScope,
|
||||
Index: 0,
|
||||
},
|
||||
"b": {
|
||||
Name: "b",
|
||||
Scope: GlobalScope,
|
||||
Index: 1,
|
||||
},
|
||||
"c": {Name: "c", Scope: LocalScope, Index: 0},
|
||||
"d": {Name: "d", Scope: LocalScope, Index: 1},
|
||||
"e": {Name: "e", Scope: LocalScope, Index: 0},
|
||||
"f": {Name: "f", Scope: LocalScope, Index: 1},
|
||||
}
|
||||
|
||||
global := NewSymbolTable()
|
||||
|
||||
a := global.Define("a")
|
||||
if a != expected["a"] {
|
||||
t.Errorf("expected a=%+v, got=%+v", expected["a"], a)
|
||||
}
|
||||
|
||||
b := global.Define("b")
|
||||
if b != expected["b"] {
|
||||
t.Errorf("expected b=%+v, got=%+v", expected["b"], b)
|
||||
}
|
||||
|
||||
firstLocal := NewEnclosedSymbolTable(global)
|
||||
|
||||
c := firstLocal.Define("c")
|
||||
if c != expected["c"] {
|
||||
t.Errorf("expected c=%+v, got=%+v", expected["c"], c)
|
||||
}
|
||||
|
||||
d := firstLocal.Define("d")
|
||||
if d != expected["d"] {
|
||||
t.Errorf("expected d=%+v, got=%+v", expected["d"], d)
|
||||
}
|
||||
|
||||
secondLocal := NewEnclosedSymbolTable(firstLocal)
|
||||
|
||||
e := secondLocal.Define("e")
|
||||
if e != expected["e"] {
|
||||
t.Errorf("expected e=%+v, got=%+v", expected["e"], e)
|
||||
}
|
||||
|
||||
f := secondLocal.Define("f")
|
||||
if f != expected["f"] {
|
||||
t.Errorf("expected f=%+v, got=%+v", expected["f"], f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveGlobal(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.Define("a")
|
||||
global.Define("b")
|
||||
|
||||
expected := []Symbol{
|
||||
Symbol{
|
||||
Name: "a",
|
||||
Scope: GlobalScope,
|
||||
Index: 0,
|
||||
},
|
||||
Symbol{
|
||||
Name: "b",
|
||||
Scope: GlobalScope,
|
||||
Index: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, sym := range expected {
|
||||
result, ok := global.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v", sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLocal(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.Define("a")
|
||||
global.Define("b")
|
||||
|
||||
local := NewEnclosedSymbolTable(global)
|
||||
local.Define("c")
|
||||
local.Define("d")
|
||||
|
||||
expected := []Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||
}
|
||||
|
||||
for _, sym := range expected {
|
||||
result, ok := local.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
continue
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveNestedLocal(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.Define("a")
|
||||
global.Define("b")
|
||||
|
||||
firstLocal := NewEnclosedSymbolTable(global)
|
||||
firstLocal.Define("c")
|
||||
firstLocal.Define("d")
|
||||
|
||||
secondLocal := NewEnclosedSymbolTable(global)
|
||||
secondLocal.Define("e")
|
||||
secondLocal.Define("f")
|
||||
|
||||
tests := []struct {
|
||||
table *SymbolTable
|
||||
expectedSymbols []Symbol
|
||||
}{
|
||||
{
|
||||
firstLocal,
|
||||
[]Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
secondLocal,
|
||||
[]Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "e", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "f", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
for _, sym := range tt.expectedSymbols {
|
||||
result, ok := tt.table.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
continue
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefineResolveBuiltins(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
firstLocal := NewEnclosedSymbolTable(global)
|
||||
secondLocal := NewEnclosedSymbolTable(firstLocal)
|
||||
|
||||
expected := []Symbol{
|
||||
Symbol{Name: "a", Scope: BuiltinScope, Index: 0},
|
||||
Symbol{Name: "c", Scope: BuiltinScope, Index: 1},
|
||||
Symbol{Name: "e", Scope: BuiltinScope, Index: 2},
|
||||
Symbol{Name: "f", Scope: BuiltinScope, Index: 3},
|
||||
}
|
||||
|
||||
for i, v := range expected {
|
||||
global.DefineBuiltin(i, v.Name)
|
||||
}
|
||||
|
||||
for _, table := range []*SymbolTable{global, firstLocal, secondLocal} {
|
||||
for _, sym := range expected {
|
||||
result, ok := table.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
continue
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveFree(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.Define("a")
|
||||
global.Define("b")
|
||||
|
||||
firstLocal := NewEnclosedSymbolTable(global)
|
||||
firstLocal.Define("c")
|
||||
firstLocal.Define("d")
|
||||
|
||||
secondLocal := NewEnclosedSymbolTable(firstLocal)
|
||||
secondLocal.Define("e")
|
||||
secondLocal.Define("f")
|
||||
|
||||
tests := []struct {
|
||||
table *SymbolTable
|
||||
expectedSymbols []Symbol
|
||||
expectedFreeSymbols []Symbol
|
||||
}{
|
||||
{
|
||||
firstLocal,
|
||||
[]Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
[]Symbol{},
|
||||
},
|
||||
{
|
||||
secondLocal,
|
||||
[]Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "b", Scope: GlobalScope, Index: 1},
|
||||
Symbol{Name: "c", Scope: FreeScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: FreeScope, Index: 1},
|
||||
Symbol{Name: "e", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "f", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
[]Symbol{
|
||||
Symbol{Name: "c", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "d", Scope: LocalScope, Index: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
for _, sym := range tt.expectedSymbols {
|
||||
result, ok := tt.table.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
continue
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
|
||||
if len(tt.table.FreeSymbols) != len(tt.expectedFreeSymbols) {
|
||||
t.Errorf("wrong number of free symbols. got=%d, want=%d",
|
||||
len(tt.table.FreeSymbols), len(tt.expectedFreeSymbols))
|
||||
continue
|
||||
}
|
||||
|
||||
for i, sym := range tt.expectedFreeSymbols {
|
||||
result := tt.table.FreeSymbols[i]
|
||||
if result != sym {
|
||||
t.Errorf("wrong free symbol. got=%+v, want=%+v",
|
||||
result, sym)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnresolvableFree(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.Define("a")
|
||||
|
||||
firstLocal := NewEnclosedSymbolTable(global)
|
||||
firstLocal.Define("c")
|
||||
|
||||
secondLocal := NewEnclosedSymbolTable(firstLocal)
|
||||
secondLocal.Define("e")
|
||||
secondLocal.Define("f")
|
||||
|
||||
expected := []Symbol{
|
||||
Symbol{Name: "a", Scope: GlobalScope, Index: 0},
|
||||
Symbol{Name: "c", Scope: FreeScope, Index: 0},
|
||||
Symbol{Name: "e", Scope: LocalScope, Index: 0},
|
||||
Symbol{Name: "f", Scope: LocalScope, Index: 1},
|
||||
}
|
||||
|
||||
for _, sym := range expected {
|
||||
result, ok := secondLocal.Resolve(sym.Name)
|
||||
if !ok {
|
||||
t.Errorf("name %s not resolvable", sym.Name)
|
||||
continue
|
||||
}
|
||||
if result != sym {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
sym.Name, sym, result)
|
||||
}
|
||||
}
|
||||
|
||||
expectedUnresolvable := []string{
|
||||
"b",
|
||||
"d",
|
||||
}
|
||||
|
||||
for _, name := range expectedUnresolvable {
|
||||
_, ok := secondLocal.Resolve(name)
|
||||
if ok {
|
||||
t.Errorf("name %s resolved, but was expected not to", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefineAndResolveFunctionName(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.DefineFunctionName("a")
|
||||
|
||||
expected := Symbol{Name: "a", Scope: FunctionScope, Index: 0}
|
||||
|
||||
result, ok := global.Resolve(expected.Name)
|
||||
if !ok {
|
||||
t.Fatalf("function name %s not resolvable", expected.Name)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
expected.Name, expected, result)
|
||||
}
|
||||
}
|
||||
func TestShadowingFunctionName(t *testing.T) {
|
||||
global := NewSymbolTable()
|
||||
global.DefineFunctionName("a")
|
||||
global.Define("a")
|
||||
|
||||
expected := Symbol{Name: "a", Scope: GlobalScope, Index: 0}
|
||||
|
||||
result, ok := global.Resolve(expected.Name)
|
||||
if !ok {
|
||||
t.Fatalf("function name %s not resolvable", expected.Name)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("expected %s to resolve to %+v, got=%+v",
|
||||
expected.Name, expected, result)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user