package compiler

import (
	"fmt"
	"log"
	"sort"
	"strings"

	"monkey/ast"
	"monkey/code"
	"monkey/object"
)

// EmittedInstruction records the opcode and byte offset of an instruction
// that has been written to a scope's instruction stream.
type EmittedInstruction struct {
	Opcode   code.Opcode
	Position int
}

// CompilationScope holds the instruction stream being built for one
// function body (or the top-level program), together with the last two
// instructions emitted into it so they can be inspected or rewritten.
type CompilationScope struct {
	instructions        code.Instructions
	lastInstruction     EmittedInstruction
	previousInstruction EmittedInstruction
}

// Compiler walks an AST and emits bytecode instructions plus a constant
// pool.
type Compiler struct {
	// Debug enables logging of every node as it is compiled.
	Debug bool
	// l is the current nesting depth; it is only used to indent Debug output.
	l int

	constants   []object.Object
	symbolTable *SymbolTable

	scopes     []CompilationScope
	scopeIndex int
}

// New returns a Compiler with an empty top-level scope, an empty constant
// pool and a fresh symbol table pre-populated with every builtin listed in
// object.BuiltinsIndex.
func New() *Compiler {
	mainScope := CompilationScope{
		instructions:        code.Instructions{},
		lastInstruction:     EmittedInstruction{},
		previousInstruction: EmittedInstruction{},
	}

	symbolTable := NewSymbolTable()
	for i, builtin := range object.BuiltinsIndex {
		symbolTable.DefineBuiltin(i, builtin.Name)
	}

	return &Compiler{
		constants:   []object.Object{},
		symbolTable: symbolTable,
		scopes:      []CompilationScope{mainScope},
		scopeIndex:  0,
	}
}

// currentInstructions returns the instruction stream of the active scope.
func (c *Compiler) currentInstructions() code.Instructions {
	return c.scopes[c.scopeIndex].instructions
}

// NewWithState returns a new Compiler that uses the given symbol table and
// constant pool instead of fresh ones, so state can be carried over between
// separate compilations.
func NewWithState(s *SymbolTable, constants []object.Object) *Compiler {
	compiler := New()
	compiler.symbolTable = s
	compiler.constants = constants
	return compiler
}

// Compile compiles node, and recursively its children. The resulting
// instructions are appended to the current compilation scope, and literal
// values are added to the constant pool. It returns the first error
// encountered, such as a reference to an undefined variable or an unknown
// operator.
func (c *Compiler) Compile(node ast.Node) error {
	if c.Debug {
		log.Printf(
			"%sCompiling %T: %s\n",
			strings.Repeat(" ", c.l), node, node.String(),
		)
	}

	switch node := node.(type) {
	case *ast.Program:
		for _, s := range node.Statements {
			if err := c.compileNested(s); err != nil {
				return err
			}
		}

	case *ast.ExpressionStatement:
		if err := c.compileNested(node.Expression); err != nil {
			return err
		}
		c.emit(code.OpPop)

	case *ast.InfixExpression:
		return c.compileInfixExpression(node)

	case *ast.IntegerLiteral:
		integer := &object.Integer{Value: node.Value}
		c.emit(code.OpConstant, c.addConstant(integer))

	case *ast.Null:
		c.emit(code.OpNull)

	case *ast.Boolean:
		if node.Value {
			c.emit(code.OpTrue)
		} else {
			c.emit(code.OpFalse)
		}

	case *ast.PrefixExpression:
		if err := c.compileNested(node.Right); err != nil {
			return err
		}
		switch node.Operator {
		case "!":
			c.emit(code.OpBang)
		case "-":
			c.emit(code.OpMinus)
		default:
			return fmt.Errorf("unknown operator %s", node.Operator)
		}

	case *ast.IfExpression:
		return c.compileIfExpression(node)

	case *ast.BlockStatement:
		for _, s := range node.Statements {
			if err := c.compileNested(s); err != nil {
				return err
			}
		}
		// If the last statement produced a value (signalled by a trailing
		// OpPop), drop the OpPop so the value stays on the stack as the
		// block's result. Otherwise push null, unless the block ends in
		// OpReturn.
		if c.lastInstructionIs(code.OpPop) {
			c.removeLastPop()
		} else if !c.lastInstructionIs(code.OpReturn) {
			c.emit(code.OpNull)
		}

	case *ast.AssignmentExpression:
		return c.compileAssignmentExpression(node)

	case *ast.BindExpression:
		return c.compileBindExpression(node)

	case *ast.Identifier:
		symbol, ok := c.symbolTable.Resolve(node.Value)
		if !ok {
			return fmt.Errorf("undefined variable %s", node.Value)
		}
		c.loadSymbol(symbol)

	case *ast.StringLiteral:
		str := &object.String{Value: node.Value}
		c.emit(code.OpConstant, c.addConstant(str))

	case *ast.ArrayLiteral:
		for _, el := range node.Elements {
			if err := c.compileNested(el); err != nil {
				return err
			}
		}
		c.emit(code.OpArray, len(node.Elements))

	case *ast.HashLiteral:
		return c.compileHashLiteral(node)

	case *ast.IndexExpression:
		if err := c.compileNested(node.Left); err != nil {
			return err
		}
		if err := c.compileNested(node.Index); err != nil {
			return err
		}
		c.emit(code.OpGetItem)

	case *ast.FunctionLiteral:
		return c.compileFunctionLiteral(node)

	case *ast.ReturnStatement:
		if err := c.compileNested(node.ReturnValue); err != nil {
			return err
		}
		c.emit(code.OpReturn)

	case *ast.CallExpression:
		if err := c.compileNested(node.Function); err != nil {
			return err
		}
		for _, a := range node.Arguments {
			if err := c.compileNested(a); err != nil {
				return err
			}
		}
		c.emit(code.OpCall, len(node.Arguments))

	case *ast.WhileExpression:
		return c.compileWhileExpression(node)
	}

	return nil
}

// compileNested compiles a child node one level deeper, so that Debug output
// is indented to reflect the tree structure. The depth is restored even when
// compilation fails.
func (c *Compiler) compileNested(node ast.Node) error {
	c.l++
	err := c.Compile(node)
	c.l--
	return err
}

// compileInfixExpression compiles a binary operator expression. "<" and "<="
// are compiled as ">" and ">=" with the operands swapped, which means the
// right operand is evaluated first for those two operators.
func (c *Compiler) compileInfixExpression(node *ast.InfixExpression) error {
	if node.Operator == "<" || node.Operator == "<=" {
		if err := c.compileNested(node.Right); err != nil {
			return err
		}
		if err := c.compileNested(node.Left); err != nil {
			return err
		}
		if node.Operator == "<=" {
			c.emit(code.OpGreaterThanEqual)
		} else {
			c.emit(code.OpGreaterThan)
		}
		return nil
	}

	if err := c.compileNested(node.Left); err != nil {
		return err
	}
	if err := c.compileNested(node.Right); err != nil {
		return err
	}

	switch node.Operator {
	case "+":
		c.emit(code.OpAdd)
	case "-":
		c.emit(code.OpSub)
	case "*":
		c.emit(code.OpMul)
	case "/":
		c.emit(code.OpDiv)
	case "%":
		c.emit(code.OpMod)
	case ">":
		c.emit(code.OpGreaterThan)
	case ">=":
		c.emit(code.OpGreaterThanEqual)
	case "==":
		c.emit(code.OpEqual)
	case "!=":
		c.emit(code.OpNotEqual)
	default:
		return fmt.Errorf("unknown operator %s", node.Operator)
	}
	return nil
}

// compileIfExpression compiles a conditional expression into:
//
//	<condition>
//	OpJumpNotTruthy <alternative>
//	<consequence>
//	OpJump <end>
//	<alternative>      (OpNull when there is no else branch)
//	<end>:
//
// Neither jump target is known when its jump is emitted, so a placeholder
// operand is written first and patched afterwards.
func (c *Compiler) compileIfExpression(node *ast.IfExpression) error {
	if err := c.compileNested(node.Condition); err != nil {
		return err
	}

	// Placeholder operand; patched once the consequence has been emitted.
	jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999)

	if err := c.compileNested(node.Consequence); err != nil {
		return err
	}
	if c.lastInstructionIs(code.OpPop) {
		c.removeLastPop()
	}

	// Placeholder operand; patched once the alternative has been emitted.
	jumpPos := c.emit(code.OpJump, 9999)

	afterConsequencePos := len(c.currentInstructions())
	c.changeOperand(jumpNotTruthyPos, afterConsequencePos)

	if node.Alternative == nil {
		c.emit(code.OpNull)
	} else {
		if err := c.compileNested(node.Alternative); err != nil {
			return err
		}
		if c.lastInstructionIs(code.OpPop) {
			c.removeLastPop()
		}
	}

	afterAlternativePos := len(c.currentInstructions())
	c.changeOperand(jumpPos, afterAlternativePos)
	return nil
}

// compileAssignmentExpression compiles `target = value`. The target is
// either an already-defined global or local variable, or an index
// expression such as a[i].
func (c *Compiler) compileAssignmentExpression(node *ast.AssignmentExpression) error {
	switch left := node.Left.(type) {
	case *ast.Identifier:
		symbol, ok := c.symbolTable.Resolve(left.Value)
		if !ok {
			return fmt.Errorf("undefined variable %s", left.Value)
		}
		// Only globals and locals have an assign opcode. Emitting
		// OpAssignLocal with a builtin/free/function index would write an
		// unrelated local slot, so reject those scopes explicitly.
		if symbol.Scope != GlobalScope && symbol.Scope != LocalScope {
			return fmt.Errorf("cannot assign to %s: unsupported symbol scope %v", left.Value, symbol.Scope)
		}

		if err := c.compileNested(node.Value); err != nil {
			return err
		}

		if symbol.Scope == GlobalScope {
			c.emit(code.OpAssignGlobal, symbol.Index)
		} else {
			c.emit(code.OpAssignLocal, symbol.Index)
		}

	case *ast.IndexExpression:
		if err := c.compileNested(left.Left); err != nil {
			return err
		}
		if err := c.compileNested(left.Index); err != nil {
			return err
		}
		if err := c.compileNested(node.Value); err != nil {
			return err
		}
		c.emit(code.OpSetItem)

	default:
		return fmt.Errorf("expected identifier or index expression got=%s", node.Left)
	}
	return nil
}

// compileBindExpression compiles `name := value`. If name already resolves
// to a global or local slot, that slot is reused. Otherwise a new symbol is
// defined in the current scope before the value is compiled.
func (c *Compiler) compileBindExpression(node *ast.BindExpression) error {
	ident, ok := node.Left.(*ast.Identifier)
	if !ok {
		return fmt.Errorf("expected identifier got=%s", node.Left)
	}

	symbol, ok := c.symbolTable.Resolve(ident.Value)
	// A name that is undefined, or that resolves to a free variable, a
	// builtin or the enclosing function's own name, is shadowed by a new
	// symbol in the current scope. Only global and local symbols have a slot
	// that OpSetGlobal/OpSetLocal can write.
	if !ok || (symbol.Scope != GlobalScope && symbol.Scope != LocalScope) {
		symbol = c.symbolTable.Define(ident.Value)
	}

	if err := c.compileNested(node.Value); err != nil {
		return err
	}

	if symbol.Scope == GlobalScope {
		c.emit(code.OpSetGlobal, symbol.Index)
	} else {
		c.emit(code.OpSetLocal, symbol.Index)
	}
	return nil
}

// compileHashLiteral pushes each key followed by its value, then emits
// OpHash with the total number of items pushed. Keys are sorted by their
// String() form because Go map iteration order is random and the emitted
// bytecode should be deterministic.
func (c *Compiler) compileHashLiteral(node *ast.HashLiteral) error {
	keys := make([]ast.Expression, 0, len(node.Pairs))
	for k := range node.Pairs {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(i, j int) bool {
		return keys[i].String() < keys[j].String()
	})

	for _, k := range keys {
		if err := c.compileNested(k); err != nil {
			return err
		}
		if err := c.compileNested(node.Pairs[k]); err != nil {
			return err
		}
	}

	c.emit(code.OpHash, len(node.Pairs)*2)
	return nil
}

// compileFunctionLiteral compiles a function body in its own scope and adds
// the resulting object.CompiledFunction to the constant pool. It then emits
// an OpClosure whose second operand is the number of free variables loaded
// just before it.
func (c *Compiler) compileFunctionLiteral(node *ast.FunctionLiteral) error {
	c.enterScope()

	if node.Name != "" {
		c.symbolTable.DefineFunctionName(node.Name)
	}
	for _, p := range node.Parameters {
		c.symbolTable.Define(p.Value)
	}

	if err := c.compileNested(node.Body); err != nil {
		// Restore the enclosing scope and symbol table so the compiler is
		// not left inside the half-compiled function.
		c.leaveScope()
		return err
	}

	// NOTE(review): the BlockStatement case already strips a trailing OpPop,
	// so this branch appears unreachable when Body is an *ast.BlockStatement.
	if c.lastInstructionIs(code.OpPop) {
		c.replaceLastPopWithReturn()
	}

	// If the function doesn't end with a return statement, add one with a
	// `return null;`. This also handles the edge case of empty functions.
	// NOTE(review): the OpNull check also matches a body whose final
	// expression happens to end in OpNull (e.g. an `if` without `else`). That
	// expression's value is then returned instead of an explicit null;
	// confirm this is intended.
	if !c.lastInstructionIs(code.OpReturn) {
		if !c.lastInstructionIs(code.OpNull) {
			c.emit(code.OpNull)
		}
		c.emit(code.OpReturn)
	}

	freeSymbols := c.symbolTable.FreeSymbols
	numLocals := c.symbolTable.numDefinitions
	instructions := c.leaveScope()

	// Load each captured variable in the enclosing scope ahead of OpClosure.
	for _, s := range freeSymbols {
		c.loadSymbol(s)
	}

	compiledFn := &object.CompiledFunction{
		Instructions:  instructions,
		NumLocals:     numLocals,
		NumParameters: len(node.Parameters),
	}
	fnIndex := c.addConstant(compiledFn)
	c.emit(code.OpClosure, fnIndex, len(freeSymbols))
	return nil
}

// compileWhileExpression compiles a loop into:
//
//	<start>:
//	<condition>
//	OpJumpNotTruthy <end>
//	<body>
//	OpPop
//	OpJump <start>
//	<end>:
//	OpNull
//
// The value the body block leaves on the stack is popped on every
// iteration, and the loop as a whole evaluates to null.
func (c *Compiler) compileWhileExpression(node *ast.WhileExpression) error {
	loopStartPos := len(c.currentInstructions())

	if err := c.compileNested(node.Condition); err != nil {
		return err
	}

	// Placeholder operand; patched to point at the trailing OpNull.
	jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 0xFFFF)

	if err := c.compileNested(node.Consequence); err != nil {
		return err
	}

	// Pop off the value (possibly OpNull) left by the body block.
	c.emit(code.OpPop)
	c.emit(code.OpJump, loopStartPos)

	afterLoopPos := c.emit(code.OpNull)
	c.changeOperand(jumpNotTruthyPos, afterLoopPos)
	return nil
}

// addConstant appends obj to the constant pool and returns its index.
func (c *Compiler) addConstant(obj object.Object) int {
	c.constants = append(c.constants, obj)
	return len(c.constants) - 1
}

// emit encodes op with its operands and appends it to the current scope. It
// records the instruction as the last one emitted and returns the byte
// offset at which it starts.
func (c *Compiler) emit(op code.Opcode, operands ...int) int {
	ins := code.Make(op, operands...)
	pos := c.addInstruction(ins)
	c.setLastInstruction(op, pos)
	return pos
}

// setLastInstruction moves the current last instruction into
// previousInstruction and records op at pos as the new last instruction.
func (c *Compiler) setLastInstruction(op code.Opcode, pos int) {
	scope := &c.scopes[c.scopeIndex]
	scope.previousInstruction = scope.lastInstruction
	scope.lastInstruction = EmittedInstruction{Opcode: op, Position: pos}
}

// Bytecode returns the instructions of the current scope together with the
// constant pool.
func (c *Compiler) Bytecode() *Bytecode {
	return &Bytecode{
		Instructions: c.currentInstructions(),
		Constants:    c.constants,
	}
}

// addInstruction appends ins to the current scope's instructions and
// returns the offset at which it was written.
func (c *Compiler) addInstruction(ins []byte) int {
	pos := len(c.currentInstructions())
	c.scopes[c.scopeIndex].instructions = append(c.currentInstructions(), ins...)
	return pos
}

// lastInstructionIs reports whether the most recently emitted instruction in
// the current scope is op. It is always false for an empty scope.
func (c *Compiler) lastInstructionIs(op code.Opcode) bool {
	if len(c.currentInstructions()) == 0 {
		return false
	}
	return c.scopes[c.scopeIndex].lastInstruction.Opcode == op
}

// removeLastPop truncates the current scope's instructions just before the
// last emitted instruction (an OpPop at every call site). It then makes the
// previously emitted instruction the last one again.
func (c *Compiler) removeLastPop() {
	scope := &c.scopes[c.scopeIndex]
	scope.instructions = scope.instructions[:scope.lastInstruction.Position]
	scope.lastInstruction = scope.previousInstruction
}

// replaceInstruction overwrites the bytes starting at pos with
// newInstruction, in place.
func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) {
	copy(c.currentInstructions()[pos:], newInstruction)
}

// changeOperand re-encodes the instruction at opPos with a new operand,
// keeping its opcode. It is used to back-patch placeholder jump targets.
func (c *Compiler) changeOperand(opPos int, operand int) {
	op := code.Opcode(c.currentInstructions()[opPos])
	c.replaceInstruction(opPos, code.Make(op, operand))
}

// enterScope starts a new, empty compilation scope whose symbol table is
// enclosed by the current one.
func (c *Compiler) enterScope() {
	c.scopes = append(c.scopes, CompilationScope{
		instructions:        code.Instructions{},
		lastInstruction:     EmittedInstruction{},
		previousInstruction: EmittedInstruction{},
	})
	c.scopeIndex++
	c.symbolTable = NewEnclosedSymbolTable(c.symbolTable)
}

// leaveScope discards the current compilation scope and restores the
// enclosing symbol table. It returns the instructions compiled in the
// discarded scope.
func (c *Compiler) leaveScope() code.Instructions {
	instructions := c.currentInstructions()
	c.scopes = c.scopes[:len(c.scopes)-1]
	c.scopeIndex--
	c.symbolTable = c.symbolTable.Outer
	return instructions
}

// replaceLastPopWithReturn overwrites the last emitted instruction (an
// OpPop) with OpReturn and records OpReturn as the last instruction.
func (c *Compiler) replaceLastPopWithReturn() {
	lastPos := c.scopes[c.scopeIndex].lastInstruction.Position
	c.replaceInstruction(lastPos, code.Make(code.OpReturn))
	c.scopes[c.scopeIndex].lastInstruction.Opcode = code.OpReturn
}

// Bytecode is the output of a Compiler: its instructions and the constant
// pool they refer to.
type Bytecode struct {
	Instructions code.Instructions
	Constants    []object.Object
}

// loadSymbol emits the opcode that loads the value bound to s, chosen by the
// symbol's scope.
func (c *Compiler) loadSymbol(s Symbol) {
	switch s.Scope {
	case GlobalScope:
		c.emit(code.OpGetGlobal, s.Index)
	case LocalScope:
		c.emit(code.OpGetLocal, s.Index)
	case BuiltinScope:
		c.emit(code.OpGetBuiltin, s.Index)
	case FreeScope:
		c.emit(code.OpGetFree, s.Index)
	case FunctionScope:
		// A function's own name (defined via DefineFunctionName) is loaded
		// with OpCurrentClosure, which takes no operand.
		c.emit(code.OpCurrentClosure)
	}
}