// Package compiler turns monkey AST nodes into bytecode instructions plus a
// pool of constants.
package compiler

import (
	"fmt"
	"sort"

	"monkey/ast"
	"monkey/code"
	"monkey/object"
)

// EmittedInstruction records the opcode of an emitted instruction and the
// byte offset at which it starts inside its scope's instruction stream.
type EmittedInstruction struct {
	Opcode   code.Opcode
	Position int
}

// CompilationScope holds the instructions emitted for a single scope along
// with the last two emitted instructions, which the compiler uses to inspect
// and rewrite the tail of the stream (see removeLastPop and
// replaceLastPopWithReturn).
type CompilationScope struct {
	instructions        code.Instructions
	lastInstruction     EmittedInstruction
	previousInstruction EmittedInstruction
}

// Compiler walks an AST and emits bytecode. It keeps a constant pool, a
// symbol table, and a stack of compilation scopes; scopeIndex is the index of
// the scope that currently receives emitted instructions.
type Compiler struct {
	constants   []object.Object
	symbolTable *SymbolTable
	scopes      []CompilationScope
	scopeIndex  int
}

// New returns a Compiler with an empty constant pool, a fresh symbol table
// and a single (main) compilation scope.
func New() *Compiler {
	mainScope := CompilationScope{
		instructions:        code.Instructions{},
		lastInstruction:     EmittedInstruction{},
		previousInstruction: EmittedInstruction{},
	}

	return &Compiler{
		constants:   []object.Object{},
		symbolTable: NewSymbolTable(),
		scopes:      []CompilationScope{mainScope},
		scopeIndex:  0,
	}
}

// currentInstructions returns the instruction stream of the active scope.
func (c *Compiler) currentInstructions() code.Instructions {
	return c.scopes[c.scopeIndex].instructions
}

// NewWithState returns a Compiler like New does, but using the supplied
// symbol table and constant pool instead of fresh ones.
func NewWithState(s *SymbolTable, constants []object.Object) *Compiler {
	compiler := New()
	compiler.symbolTable = s
	compiler.constants = constants
	return compiler
}

// Compile emits bytecode for node into the active scope, recursing into child
// nodes as needed. It returns an error for an unknown infix/prefix operator
// or an identifier that the symbol table cannot resolve. Node types without a
// case below emit nothing and yield a nil error.
func (c *Compiler) Compile(node ast.Node) error {
	switch node := node.(type) {
	case *ast.Program:
		for _, s := range node.Statements {
			if err := c.Compile(s); err != nil {
				return err
			}
		}

	case *ast.ExpressionStatement:
		if err := c.Compile(node.Expression); err != nil {
			return err
		}
		c.emit(code.OpPop)

	case *ast.InfixExpression:
		return c.compileInfixExpression(node)

	case *ast.IntegerLiteral:
		integer := &object.Integer{Value: node.Value}
		c.emit(code.OpConstant, c.addConstant(integer))

	case *ast.Boolean:
		if node.Value {
			c.emit(code.OpTrue)
		} else {
			c.emit(code.OpFalse)
		}

	case *ast.PrefixExpression:
		if err := c.Compile(node.Right); err != nil {
			return err
		}

		switch node.Operator {
		case "!":
			c.emit(code.OpBang)
		case "-":
			c.emit(code.OpMinus)
		default:
			return fmt.Errorf("unknown operator %s", node.Operator)
		}

	case *ast.IfExpression:
		return c.compileIfExpression(node)

	case *ast.BlockStatement:
		for _, s := range node.Statements {
			if err := c.Compile(s); err != nil {
				return err
			}
		}

	case *ast.LetStatement:
		if err := c.Compile(node.Value); err != nil {
			return err
		}
		// The name is defined only after its value has been compiled.
		symbol := c.symbolTable.Define(node.Name.Value)
		c.emit(code.OpSetGlobal, symbol.Index)

	case *ast.Identifier:
		symbol, ok := c.symbolTable.Resolve(node.Value)
		if !ok {
			return fmt.Errorf("undefined variable %s", node.Value)
		}
		c.emit(code.OpGetGlobal, symbol.Index)

	case *ast.StringLiteral:
		str := &object.String{Value: node.Value}
		c.emit(code.OpConstant, c.addConstant(str))

	case *ast.ArrayLiteral:
		for _, el := range node.Elements {
			if err := c.Compile(el); err != nil {
				return err
			}
		}
		c.emit(code.OpArray, len(node.Elements))

	case *ast.HashLiteral:
		return c.compileHashLiteral(node)

	case *ast.IndexExpression:
		if err := c.Compile(node.Left); err != nil {
			return err
		}
		if err := c.Compile(node.Index); err != nil {
			return err
		}
		c.emit(code.OpIndex)

	case *ast.FunctionLiteral:
		return c.compileFunctionLiteral(node)

	case *ast.ReturnStatement:
		if err := c.Compile(node.ReturnValue); err != nil {
			return err
		}
		c.emit(code.OpReturnValue)

	case *ast.CallExpression:
		if err := c.Compile(node.Function); err != nil {
			return err
		}
		c.emit(code.OpCall)
	}

	return nil
}

// compileInfixExpression compiles both operands and then emits the opcode for
// the operator. "<" has no opcode of its own: its operands are compiled in
// reverse order and OpGreaterThan is emitted instead.
func (c *Compiler) compileInfixExpression(node *ast.InfixExpression) error {
	if node.Operator == "<" {
		if err := c.Compile(node.Right); err != nil {
			return err
		}
		if err := c.Compile(node.Left); err != nil {
			return err
		}
		c.emit(code.OpGreaterThan)
		return nil
	}

	if err := c.Compile(node.Left); err != nil {
		return err
	}
	if err := c.Compile(node.Right); err != nil {
		return err
	}

	switch node.Operator {
	case "+":
		c.emit(code.OpAdd)
	case "-":
		c.emit(code.OpSub)
	case "*":
		c.emit(code.OpMul)
	case "/":
		c.emit(code.OpDiv)
	case ">":
		c.emit(code.OpGreaterThan)
	case "==":
		c.emit(code.OpEqual)
	case "!=":
		c.emit(code.OpNotEqual)
	default:
		return fmt.Errorf("unknown operator %s", node.Operator)
	}

	return nil
}

// compileIfExpression compiles the condition, the consequence and the
// (optional) alternative, wiring them together with OpJumpNotTruthy and
// OpJump. Both jumps are first emitted with a placeholder operand and patched
// once their target offsets are known.
func (c *Compiler) compileIfExpression(node *ast.IfExpression) error {
	// Bogus jump target; always overwritten via changeOperand below.
	const placeholder = 9999

	if err := c.Compile(node.Condition); err != nil {
		return err
	}

	// Emit an `OpJumpNotTruthy` with a bogus value
	jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, placeholder)

	if err := c.Compile(node.Consequence); err != nil {
		return err
	}

	// Drop a trailing OpPop left by the block's final expression statement.
	if c.lastInstructionIs(code.OpPop) {
		c.removeLastPop()
	}

	// Emit an `OpJump` with a bogus value
	jumpPos := c.emit(code.OpJump, placeholder)

	// A falsy condition jumps to just past the consequence and its OpJump.
	afterConsequencePos := len(c.currentInstructions())
	c.changeOperand(jumpNotTruthyPos, afterConsequencePos)

	if node.Alternative == nil {
		c.emit(code.OpNull)
	} else {
		if err := c.Compile(node.Alternative); err != nil {
			return err
		}

		if c.lastInstructionIs(code.OpPop) {
			c.removeLastPop()
		}
	}

	// The consequence's OpJump skips over the alternative.
	afterAlternativePos := len(c.currentInstructions())
	c.changeOperand(jumpPos, afterAlternativePos)

	return nil
}

// compileHashLiteral compiles each key followed by its value and then emits
// OpHash with the total number of compiled keys and values. Keys are sorted
// by their String() form because Go map iteration order is not fixed.
func (c *Compiler) compileHashLiteral(node *ast.HashLiteral) error {
	keys := make([]ast.Expression, 0, len(node.Pairs))
	for k := range node.Pairs {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(i, j int) bool {
		return keys[i].String() < keys[j].String()
	})

	for _, k := range keys {
		if err := c.Compile(k); err != nil {
			return err
		}
		if err := c.Compile(node.Pairs[k]); err != nil {
			return err
		}
	}

	c.emit(code.OpHash, len(node.Pairs)*2)
	return nil
}

// compileFunctionLiteral compiles the function body in its own scope, makes
// sure the body ends in a return instruction, and emits an OpConstant that
// refers to the resulting object.CompiledFunction.
func (c *Compiler) compileFunctionLiteral(node *ast.FunctionLiteral) error {
	c.enterScope()

	if err := c.Compile(node.Body); err != nil {
		// Discard the half-compiled function scope so the compiler is left
		// pointing at the enclosing scope rather than the abandoned one.
		c.leaveScope()
		return err
	}

	// A body ending in an expression statement returns that expression's
	// value: its trailing OpPop becomes OpReturnValue.
	if c.lastInstructionIs(code.OpPop) {
		c.replaceLastPopWithReturn()
	}
	// Otherwise, unless the body already ends in OpReturnValue, add OpReturn.
	if !c.lastInstructionIs(code.OpReturnValue) {
		c.emit(code.OpReturn)
	}

	instructions := c.leaveScope()

	compiledFn := &object.CompiledFunction{Instructions: instructions}
	c.emit(code.OpConstant, c.addConstant(compiledFn))
	return nil
}

// addConstant appends obj to the constant pool and returns its index.
func (c *Compiler) addConstant(obj object.Object) int {
	c.constants = append(c.constants, obj)
	return len(c.constants) - 1
}

// emit encodes op with its operands, appends the result to the active scope,
// records it as the scope's last instruction, and returns the offset at which
// the instruction starts.
func (c *Compiler) emit(op code.Opcode, operands ...int) int {
	ins := code.Make(op, operands...)
	pos := c.addInstruction(ins)

	c.setLastInstruction(op, pos)

	return pos
}

// setLastInstruction shifts the active scope's lastInstruction into
// previousInstruction and records (op, pos) as the new lastInstruction.
func (c *Compiler) setLastInstruction(op code.Opcode, pos int) {
	scope := &c.scopes[c.scopeIndex]
	scope.previousInstruction = scope.lastInstruction
	scope.lastInstruction = EmittedInstruction{Opcode: op, Position: pos}
}

// Bytecode returns the instructions of the active scope together with the
// constant pool.
func (c *Compiler) Bytecode() *Bytecode {
	return &Bytecode{
		Instructions: c.currentInstructions(),
		Constants:    c.constants,
	}
}

// addInstruction appends ins to the active scope's instructions and returns
// the offset at which ins begins.
func (c *Compiler) addInstruction(ins []byte) int {
	posNewInstruction := len(c.currentInstructions())
	c.scopes[c.scopeIndex].instructions = append(c.currentInstructions(), ins...)
	return posNewInstruction
}

// lastInstructionIs reports whether the most recently emitted instruction in
// the active scope has opcode op. It is always false for an empty scope.
func (c *Compiler) lastInstructionIs(op code.Opcode) bool {
	if len(c.currentInstructions()) == 0 {
		return false
	}

	return c.scopes[c.scopeIndex].lastInstruction.Opcode == op
}

// removeLastPop truncates the active scope's instructions at the start of the
// last emitted instruction and restores previousInstruction as the last one.
// Callers check lastInstructionIs(code.OpPop) first.
func (c *Compiler) removeLastPop() {
	scope := &c.scopes[c.scopeIndex]
	scope.instructions = scope.instructions[:scope.lastInstruction.Position]
	scope.lastInstruction = scope.previousInstruction
}

// replaceInstruction overwrites the bytes starting at pos in the active scope
// with newInstruction. It does not grow the stream, so the bytes being
// replaced must already exist.
func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) {
	ins := c.currentInstructions()

	for i, b := range newInstruction {
		ins[pos+i] = b
	}
}

// changeOperand re-encodes the instruction at opPos with the same opcode and
// the given operand, overwriting it in place.
func (c *Compiler) changeOperand(opPos int, operand int) {
	op := code.Opcode(c.currentInstructions()[opPos])
	newInstruction := code.Make(op, operand)

	c.replaceInstruction(opPos, newInstruction)
}

// enterScope pushes a new, empty compilation scope and makes it active.
func (c *Compiler) enterScope() {
	scope := CompilationScope{
		instructions:        code.Instructions{},
		lastInstruction:     EmittedInstruction{},
		previousInstruction: EmittedInstruction{},
	}
	c.scopes = append(c.scopes, scope)
	c.scopeIndex++
}

// leaveScope pops the active compilation scope, reactivates the enclosing
// one, and returns the instructions that were emitted in the popped scope.
func (c *Compiler) leaveScope() code.Instructions {
	instructions := c.currentInstructions()

	c.scopes = c.scopes[:len(c.scopes)-1]
	c.scopeIndex--

	return instructions
}

// replaceLastPopWithReturn overwrites the last emitted instruction of the
// active scope with OpReturnValue and updates the recorded opcode to match.
// Callers check lastInstructionIs(code.OpPop) first.
// NOTE(review): presumably OpPop and OpReturnValue encode to the same width;
// confirm against the code package.
func (c *Compiler) replaceLastPopWithReturn() {
	lastPos := c.scopes[c.scopeIndex].lastInstruction.Position
	c.replaceInstruction(lastPos, code.Make(code.OpReturnValue))

	c.scopes[c.scopeIndex].lastInstruction.Opcode = code.OpReturnValue
}

// Bytecode is the compiler's output: the emitted instructions and the
// constants they reference.
type Bytecode struct {
	Instructions code.Instructions
	Constants    []object.Object
}