Files
monkey/internal/compiler/compiler.go
Chuck Smith 88e3330856
Some checks failed
Build / build (push) Successful in 9m47s
Publish Image / publish (push) Failing after 49s
Test / build (push) Failing after 6m19s
optimizations
2024-04-02 14:32:03 -04:00

660 lines
13 KiB
Go

package compiler
import (
"fmt"
"log"
"monkey/internal/ast"
"monkey/internal/builtins"
"monkey/internal/code"
"monkey/internal/object"
"sort"
"strings"
)
// EmittedInstruction records an instruction that has been emitted into a
// compilation scope: its opcode and the offset in the scope's instruction
// stream at which it starts.
type EmittedInstruction struct {
Opcode code.Opcode
Position int
}
// CompilationScope holds the instructions emitted so far for one scope (the
// main program or a function body) together with the two most recently
// emitted instructions, which the compiler inspects and rewrites when it
// needs to drop or replace a trailing instruction.
type CompilationScope struct {
// instructions is the instruction stream being built for this scope.
instructions code.Instructions
// lastInstruction is the most recently emitted instruction.
lastInstruction EmittedInstruction
// previousInstruction is the instruction emitted just before lastInstruction.
previousInstruction EmittedInstruction
}
// Compiler turns AST nodes into bytecode instructions and a constant pool.
// It keeps a stack of compilation scopes (one per function being compiled)
// and the symbol table for the scope currently being compiled.
type Compiler struct {
// Debug enables a log line for every node passed to Compile.
Debug bool
// fn and input are stored by SetFileInfo; nothing visible in this file
// reads them. NOTE(review): presumably the source file name and its
// contents, kept for diagnostics — confirm against callers.
fn string
input string
// l is the current nesting depth, used only to indent debug log output.
l int
// constants points at the constant pool; it is a pointer so that a pool
// supplied through NewWithState is appended to in place.
constants *[]object.Object
// symbolTable is the table for the scope currently being compiled.
symbolTable *SymbolTable
// scopes is the stack of compilation scopes; scopeIndex selects the
// current one.
scopes []CompilationScope
scopeIndex int
}
// New returns a Compiler with an empty constant pool, a single (main)
// compilation scope and a fresh symbol table in which every entry of
// builtins.BuiltinsIndex has been registered as a builtin.
func New() *Compiler {
	table := NewSymbolTable()
	for idx, b := range builtins.BuiltinsIndex {
		table.DefineBuiltin(idx, b.Name)
	}

	// The main scope starts with an empty, non-nil instruction stream; the
	// last/previous instruction records are left at their zero values.
	return &Compiler{
		constants:   &[]object.Object{},
		symbolTable: table,
		scopes: []CompilationScope{
			{instructions: code.Instructions{}},
		},
	}
}
// currentInstructions returns the instruction stream of the scope selected
// by scopeIndex.
func (c *Compiler) currentInstructions() code.Instructions {
	scope := &c.scopes[c.scopeIndex]
	return scope.instructions
}
// NewWithState returns a Compiler built by New whose symbol table and
// constant pool are then replaced by the ones supplied, so that state can be
// carried over from an earlier compilation.
func NewWithState(s *SymbolTable, constants *[]object.Object) *Compiler {
	c := New()
	c.symbolTable, c.constants = s, constants
	return c
}
// Compile emits bytecode for node into the current compilation scope,
// recursing into child nodes as required.
//
// It returns an error when an infix or prefix operator is unknown, when an
// identifier cannot be resolved, when the target of an assignment or bind
// expression has an unsupported form, or when compiling a child node fails.
// Node types that have no case below are ignored and nil is returned.
//
// When c.Debug is set, every visited node is logged, indented by its nesting
// depth.
func (c *Compiler) Compile(node ast.Node) error {
	if c.Debug {
		log.Printf(
			"%sCompiling %T: %s\n",
			strings.Repeat(" ", c.l), node, node.String(),
		)
	}

	switch node := node.(type) {
	case *ast.Program:
		for _, s := range node.Statements {
			if err := c.compileNested(s); err != nil {
				return err
			}
		}
		c.emit(code.OpHalt)

	case *ast.ExpressionStatement:
		if err := c.compileNested(node.Expression); err != nil {
			return err
		}
		c.emit(code.OpPop)

	case *ast.InfixExpression:
		// `<` and `<=` are compiled by swapping the operands and emitting
		// the corresponding greater-than opcode.
		if node.Operator == "<" || node.Operator == "<=" {
			if err := c.compileNested(node.Right); err != nil {
				return err
			}
			if err := c.compileNested(node.Left); err != nil {
				return err
			}
			if node.Operator == "<=" {
				c.emit(code.OpGreaterThanEqual)
			} else {
				c.emit(code.OpGreaterThan)
			}
			return nil
		}

		if err := c.compileNested(node.Left); err != nil {
			return err
		}
		if err := c.compileNested(node.Right); err != nil {
			return err
		}

		switch node.Operator {
		case "+":
			c.emit(code.OpAdd)
		case "-":
			c.emit(code.OpSub)
		case "*":
			c.emit(code.OpMul)
		case "/":
			c.emit(code.OpDiv)
		case "%":
			c.emit(code.OpMod)
		case "|":
			c.emit(code.OpBitwiseOR)
		case "^":
			c.emit(code.OpBitwiseXOR)
		case "&":
			c.emit(code.OpBitwiseAND)
		case "<<":
			c.emit(code.OpLeftShift)
		case ">>":
			c.emit(code.OpRightShift)
		case "||":
			c.emit(code.OpOr)
		case "&&":
			c.emit(code.OpAnd)
		case ">":
			c.emit(code.OpGreaterThan)
		case ">=":
			c.emit(code.OpGreaterThanEqual)
		case "==":
			c.emit(code.OpEqual)
		case "!=":
			c.emit(code.OpNotEqual)
		default:
			return fmt.Errorf("unknown operator %s", node.Operator)
		}

	case *ast.IntegerLiteral:
		i := object.Integer{Value: node.Value}
		c.emit(code.OpConstant, c.addConstant(i))

	case *ast.FloatLiteral:
		f := object.Float{Value: node.Value}
		c.emit(code.OpConstant, c.addConstant(f))

	case *ast.Null:
		c.emit(code.OpNull)

	case *ast.Boolean:
		if node.Value {
			c.emit(code.OpTrue)
		} else {
			c.emit(code.OpFalse)
		}

	case *ast.PrefixExpression:
		if err := c.compileNested(node.Right); err != nil {
			return err
		}
		switch node.Operator {
		case "!":
			c.emit(code.OpNot)
		case "~":
			c.emit(code.OpBitwiseNOT)
		case "-":
			c.emit(code.OpMinus)
		default:
			return fmt.Errorf("unknown operator %s", node.Operator)
		}

	case *ast.IfExpression:
		if err := c.compileNested(node.Condition); err != nil {
			return err
		}

		// Emit an `OpJumpNotTruthy` with a bogus value; it is patched below
		// once the end of the consequence is known.
		jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999)

		if err := c.compileNested(node.Consequence); err != nil {
			return err
		}
		if c.lastInstructionIs(code.OpPop) {
			c.removeLastPop()
		}

		// Emit an `OpJump` with a bogus value; it is patched below once the
		// end of the alternative is known.
		jumpPos := c.emit(code.OpJump, 9999)

		afterConsequencePos := len(c.currentInstructions())
		c.changeOperand(jumpNotTruthyPos, afterConsequencePos)

		if node.Alternative == nil {
			c.emit(code.OpNull)
		} else {
			if err := c.compileNested(node.Alternative); err != nil {
				return err
			}
			if c.lastInstructionIs(code.OpPop) {
				c.removeLastPop()
			}
		}

		afterAlternativePos := len(c.currentInstructions())
		c.changeOperand(jumpPos, afterAlternativePos)

	case *ast.BlockStatement:
		for _, s := range node.Statements {
			if err := c.compileNested(s); err != nil {
				return err
			}
		}
		// A trailing OpPop is dropped; otherwise, unless the block ended in
		// a return, an OpNull is emitted.
		if c.lastInstructionIs(code.OpPop) {
			c.removeLastPop()
		} else if !c.lastInstructionIs(code.OpReturn) {
			c.emit(code.OpNull)
		}

	case *ast.AssignmentExpression:
		switch left := node.Left.(type) {
		case *ast.Identifier:
			// The target is resolved before the value is compiled, so an
			// undefined variable is reported without emitting anything.
			symbol, ok := c.symbolTable.Resolve(left.Value)
			if !ok {
				return fmt.Errorf("undefined variable %s", left.Value)
			}
			if err := c.compileNested(node.Value); err != nil {
				return err
			}
			if symbol.Scope == GlobalScope {
				c.emit(code.OpAssignGlobal, symbol.Index)
			} else {
				c.emit(code.OpAssignLocal, symbol.Index)
			}

		case *ast.IndexExpression:
			// Operands are emitted in the order: container, index, value.
			if err := c.compileNested(left.Left); err != nil {
				return err
			}
			if err := c.compileNested(left.Index); err != nil {
				return err
			}
			if err := c.compileNested(node.Value); err != nil {
				return err
			}
			c.emit(code.OpSetItem)

		default:
			return fmt.Errorf("expected identifier or index expression got=%s", node.Left)
		}

	case *ast.BindExpression:
		ident, ok := node.Left.(*ast.Identifier)
		if !ok {
			return fmt.Errorf("expected identifier got=%s", node.Left)
		}

		// Define the name if it is new. A name that resolves to a "free"
		// variable is also redefined: that is local shadowing of a
		// previously defined free variable in a function now being rebound
		// to a locally scoped variable.
		symbol, ok := c.symbolTable.Resolve(ident.Value)
		if !ok || symbol.Scope == FreeScope {
			symbol = c.symbolTable.Define(ident.Value)
		}

		if err := c.compileNested(node.Value); err != nil {
			return err
		}
		if symbol.Scope == GlobalScope {
			c.emit(code.OpSetGlobal, symbol.Index)
		} else {
			c.emit(code.OpSetLocal, symbol.Index)
		}

	case *ast.Identifier:
		symbol, ok := c.symbolTable.Resolve(node.Value)
		if !ok {
			return fmt.Errorf("undefined variable %s", node.Value)
		}
		c.loadSymbol(symbol)

	case *ast.StringLiteral:
		str := object.String{Value: node.Value}
		c.emit(code.OpConstant, c.addConstant(str))

	case *ast.ArrayLiteral:
		for _, el := range node.Elements {
			if err := c.compileNested(el); err != nil {
				return err
			}
		}
		c.emit(code.OpArray, len(node.Elements))

	case *ast.HashLiteral:
		// Map iteration order is random; sort the keys by their string form
		// so the emitted bytecode is deterministic.
		keys := make([]ast.Expression, 0, len(node.Pairs))
		for k := range node.Pairs {
			keys = append(keys, k)
		}
		sort.Slice(keys, func(i, j int) bool {
			return keys[i].String() < keys[j].String()
		})

		for _, k := range keys {
			if err := c.compileNested(k); err != nil {
				return err
			}
			if err := c.compileNested(node.Pairs[k]); err != nil {
				return err
			}
		}
		// The operand counts keys and values together.
		c.emit(code.OpHash, len(node.Pairs)*2)

	case *ast.IndexExpression:
		if err := c.compileNested(node.Left); err != nil {
			return err
		}
		if err := c.compileNested(node.Index); err != nil {
			return err
		}
		c.emit(code.OpGetItem)

	case *ast.FunctionLiteral:
		c.enterScope()

		if node.Name != "" {
			c.symbolTable.DefineFunctionName(node.Name)
		}
		for _, p := range node.Parameters {
			c.symbolTable.Define(p.Value)
		}

		if err := c.compileNested(node.Body); err != nil {
			return err
		}

		if c.lastInstructionIs(code.OpPop) {
			c.replaceLastPopWithReturn()
		}

		// If the function doesn't end with a return statement add one with a
		// `return null;` and also handle the edge-case of empty functions.
		if !c.lastInstructionIs(code.OpReturn) {
			if !c.lastInstructionIs(code.OpNull) {
				c.emit(code.OpNull)
			}
			c.emit(code.OpReturn)
		}

		// Capture the function scope's details before leaving it.
		freeSymbols := c.symbolTable.FreeSymbols
		numLocals := c.symbolTable.NumDefinitions
		instructions := c.leaveScope()

		// Back in the enclosing scope: load each free symbol, then emit the
		// closure with the number of free symbols as its second operand.
		for _, s := range freeSymbols {
			c.loadSymbol(s)
		}

		compiledFn := &object.CompiledFunction{
			Instructions:  instructions,
			NumLocals:     numLocals,
			NumParameters: len(node.Parameters),
		}
		fnIndex := c.addConstant(compiledFn)
		c.emit(code.OpClosure, fnIndex, len(freeSymbols))

	case *ast.ReturnStatement:
		if err := c.compileNested(node.ReturnValue); err != nil {
			return err
		}
		c.emit(code.OpReturn)

	case *ast.CallExpression:
		if err := c.compileNested(node.Function); err != nil {
			return err
		}
		for _, a := range node.Arguments {
			if err := c.compileNested(a); err != nil {
				return err
			}
		}
		c.emit(code.OpCall, len(node.Arguments))

	case *ast.WhileExpression:
		// Remember where the condition starts so the loop can jump back.
		jumpConditionPos := len(c.currentInstructions())

		if err := c.compileNested(node.Condition); err != nil {
			return err
		}

		// Emit an `OpJumpNotTruthy` with a bogus value; it is patched below
		// to point just past the loop.
		jumpIfFalsePos := c.emit(code.OpJumpNotTruthy, 0xFFFF)

		if err := c.compileNested(node.Consequence); err != nil {
			return err
		}

		// Pop off the LoadNull(s) from ast.BlockStatement(s)
		c.emit(code.OpPop)

		c.emit(code.OpJump, jumpConditionPos)

		// The OpNull emitted here is the exit target of the loop.
		afterConsequencePos := c.emit(code.OpNull)
		c.changeOperand(jumpIfFalsePos, afterConsequencePos)

	case *ast.ImportExpression:
		if err := c.compileNested(node.Name); err != nil {
			return err
		}
		c.emit(code.OpLoadModule)
	}

	return nil
}

// compileNested compiles node one level deeper in the debug trace. The depth
// counter is restored on every return path, including errors, so the
// indentation of the debug log stays balanced.
func (c *Compiler) compileNested(node ast.Node) error {
	c.l++
	defer func() { c.l-- }()
	return c.Compile(node)
}
// addConstant appends obj to the constant pool and returns the index at
// which it was stored.
func (c *Compiler) addConstant(obj object.Object) int {
	idx := len(*c.constants)
	*c.constants = append(*c.constants, obj)
	return idx
}
// emit encodes op with its operands, appends the result to the current
// scope, records it as the scope's last instruction, and returns the offset
// at which the instruction starts.
func (c *Compiler) emit(op code.Opcode, operands ...int) int {
	pos := c.addInstruction(code.Make(op, operands...))
	c.setLastInstruction(op, pos)
	return pos
}
// SetFileInfo stores fn and input on the compiler.
func (c *Compiler) SetFileInfo(fn, input string) {
	c.fn, c.input = fn, input
}
// setLastInstruction shifts the current scope's last instruction into the
// "previous" slot and records (op, pos) as the new last instruction.
func (c *Compiler) setLastInstruction(op code.Opcode, pos int) {
	scope := &c.scopes[c.scopeIndex]
	scope.previousInstruction = scope.lastInstruction
	scope.lastInstruction = EmittedInstruction{Opcode: op, Position: pos}
}
// Bytecode returns the instructions of the current scope together with the
// constant pool accumulated so far.
func (c *Compiler) Bytecode() *Bytecode {
	bc := new(Bytecode)
	bc.Instructions = c.currentInstructions()
	bc.Constants = *c.constants
	return bc
}
// addInstruction appends the encoded instruction ins to the current scope
// and returns the offset at which it was written.
func (c *Compiler) addInstruction(ins []byte) int {
	scope := &c.scopes[c.scopeIndex]
	start := len(scope.instructions)
	scope.instructions = append(scope.instructions, ins...)
	return start
}
// lastInstructionIs reports whether the current scope has emitted at least
// one instruction and the most recent one has opcode op.
func (c *Compiler) lastInstructionIs(op code.Opcode) bool {
	scope := &c.scopes[c.scopeIndex]
	return len(scope.instructions) > 0 && scope.lastInstruction.Opcode == op
}
// removeLastPop truncates the current scope's instructions at the start of
// the last emitted instruction and makes the previous instruction the last
// one again. Callers check that the last instruction is an OpPop first.
func (c *Compiler) removeLastPop() {
	scope := &c.scopes[c.scopeIndex]
	scope.instructions = scope.instructions[:scope.lastInstruction.Position]
	scope.lastInstruction = scope.previousInstruction
}
// replaceInstruction overwrites the bytes of the current scope's
// instructions starting at pos with newInstruction, in place.
func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) {
	target := c.currentInstructions()
	for offset, b := range newInstruction {
		target[pos+offset] = b
	}
}
// changeOperand re-encodes the instruction at opPos with the same opcode and
// the given operand, overwriting the old encoding in place.
func (c *Compiler) changeOperand(opPos int, operand int) {
	opcode := code.Opcode(c.currentInstructions()[opPos])
	c.replaceInstruction(opPos, code.Make(opcode, operand))
}
// enterScope pushes a fresh, empty compilation scope, makes it current, and
// replaces the symbol table with a new one enclosed by the current table.
func (c *Compiler) enterScope() {
	c.scopes = append(c.scopes, CompilationScope{
		instructions: code.Instructions{},
	})
	c.scopeIndex++
	c.symbolTable = NewEnclosedSymbolTable(c.symbolTable)
}
// leaveScope pops the current compilation scope, restores the outer symbol
// table, and returns the instructions that were emitted in the popped scope.
func (c *Compiler) leaveScope() code.Instructions {
	finished := c.scopes[c.scopeIndex].instructions

	top := len(c.scopes) - 1
	c.scopes = c.scopes[:top]
	c.scopeIndex--

	c.symbolTable = c.symbolTable.Outer
	return finished
}
// replaceLastPopWithReturn overwrites the last emitted instruction with an
// OpReturn and updates the recorded opcode to match. Callers check that the
// last instruction is an OpPop first.
func (c *Compiler) replaceLastPopWithReturn() {
	scope := &c.scopes[c.scopeIndex]
	c.replaceInstruction(scope.lastInstruction.Position, code.Make(code.OpReturn))
	scope.lastInstruction.Opcode = code.OpReturn
}
// loadSymbol emits the instruction that loads s, chosen by the symbol's
// scope. A function-scope symbol loads the current closure and takes no
// operand; a symbol with any other, unlisted scope emits nothing.
func (c *Compiler) loadSymbol(s Symbol) {
	if s.Scope == FunctionScope {
		c.emit(code.OpCurrentClosure)
		return
	}

	var op code.Opcode
	switch s.Scope {
	case GlobalScope:
		op = code.OpGetGlobal
	case LocalScope:
		op = code.OpGetLocal
	case BuiltinScope:
		op = code.OpGetBuiltin
	case FreeScope:
		op = code.OpGetFree
	default:
		return
	}
	c.emit(op, s.Index)
}