Files
monkey/compiler/compiler.go
Chuck Smith e56fb40f83
Some checks failed
Build / build (push) Failing after 1m57s
Test / build (push) Successful in 2m15s
compile functions
2024-02-28 16:57:01 -05:00

399 lines
8.0 KiB
Go

package compiler
import (
"fmt"
"monkey/ast"
"monkey/code"
"monkey/object"
"sort"
)
// EmittedInstruction records which opcode the compiler emitted and where it
// was placed, so that a recently emitted instruction can later be inspected,
// removed, or overwritten.
type EmittedInstruction struct {
// Opcode is the opcode of the emitted instruction.
Opcode code.Opcode
// Position is the offset in the scope's instruction slice at which the
// instruction starts.
Position int
}
// CompilationScope holds the instructions being emitted for one scope (the
// main program or a function body) together with bookkeeping about the two
// most recently emitted instructions in that scope.
type CompilationScope struct {
// instructions is the bytecode emitted so far in this scope.
instructions code.Instructions
// lastInstruction is the most recently emitted instruction in this scope.
lastInstruction EmittedInstruction
// previousInstruction is the instruction emitted just before
// lastInstruction; it is restored as lastInstruction when the last one is
// removed.
previousInstruction EmittedInstruction
}
// Compiler translates AST nodes into bytecode instructions and a pool of
// constants.
type Compiler struct {
// constants is the constant pool; instructions refer to entries by index.
constants []object.Object
// symbolTable resolves identifier names to symbols.
symbolTable *SymbolTable
// scopes is the stack of compilation scopes; a new one is pushed for each
// function literal being compiled.
scopes []CompilationScope
// scopeIndex is the index in scopes of the scope currently being emitted to.
scopeIndex int
}
// New returns a Compiler with an empty constant pool, a fresh symbol table,
// and a single main compilation scope selected as the current scope.
func New() *Compiler {
	return &Compiler{
		constants:   []object.Object{},
		symbolTable: NewSymbolTable(),
		// The main scope starts with no instructions; its last/previous
		// instruction records and scopeIndex are left at their zero values.
		scopes: []CompilationScope{
			{instructions: code.Instructions{}},
		},
	}
}
// currentInstructions returns the instructions emitted so far in the scope
// selected by scopeIndex.
func (c *Compiler) currentInstructions() code.Instructions {
	scope := c.scopes[c.scopeIndex]
	return scope.instructions
}
// NewWithState returns a Compiler created by New whose symbol table and
// constant pool are replaced with the supplied s and constants.
func NewWithState(s *SymbolTable, constants []object.Object) *Compiler {
	c := New()
	c.symbolTable, c.constants = s, constants
	return c
}
// Compile recursively compiles node, appending bytecode to the current
// compilation scope and adding literals to the constant pool.
//
// It returns an error when a child node fails to compile, when an infix or
// prefix operator is not recognised, or when an identifier cannot be resolved
// in the symbol table. Node types without a case below are ignored and
// produce no instructions and no error.
func (c *Compiler) Compile(node ast.Node) error {
	switch node := node.(type) {
	case *ast.Program:
		for _, s := range node.Statements {
			if err := c.Compile(s); err != nil {
				return err
			}
		}

	case *ast.ExpressionStatement:
		if err := c.Compile(node.Expression); err != nil {
			return err
		}
		// Discard the value the expression statement produced.
		c.emit(code.OpPop)

	case *ast.InfixExpression:
		// There is no dedicated less-than opcode: `a < b` is compiled as
		// `b > a` by compiling the operands in reverse order.
		if node.Operator == "<" {
			if err := c.Compile(node.Right); err != nil {
				return err
			}
			if err := c.Compile(node.Left); err != nil {
				return err
			}
			c.emit(code.OpGreaterThan)
			return nil
		}

		if err := c.Compile(node.Left); err != nil {
			return err
		}
		if err := c.Compile(node.Right); err != nil {
			return err
		}

		switch node.Operator {
		case "+":
			c.emit(code.OpAdd)
		case "-":
			c.emit(code.OpSub)
		case "*":
			c.emit(code.OpMul)
		case "/":
			c.emit(code.OpDiv)
		case ">":
			c.emit(code.OpGreaterThan)
		case "==":
			c.emit(code.OpEqual)
		case "!=":
			c.emit(code.OpNotEqual)
		default:
			return fmt.Errorf("unknown operator %s", node.Operator)
		}

	case *ast.IntegerLiteral:
		integer := &object.Integer{Value: node.Value}
		c.emit(code.OpConstant, c.addConstant(integer))

	case *ast.Boolean:
		if node.Value {
			c.emit(code.OpTrue)
		} else {
			c.emit(code.OpFalse)
		}

	case *ast.PrefixExpression:
		if err := c.Compile(node.Right); err != nil {
			return err
		}

		switch node.Operator {
		case "!":
			c.emit(code.OpBang)
		case "-":
			c.emit(code.OpMinus)
		default:
			return fmt.Errorf("unknown operator %s", node.Operator)
		}

	case *ast.IfExpression:
		if err := c.Compile(node.Condition); err != nil {
			return err
		}

		// Emit an `OpJumpNotTruthy` with a bogus operand; it is patched once
		// the end of the consequence is known.
		jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999)

		if err := c.Compile(node.Consequence); err != nil {
			return err
		}
		// Drop a trailing pop so the consequence's value is not discarded.
		if c.lastInstructionIs(code.OpPop) {
			c.removeLastPop()
		}

		// Emit an `OpJump` with a bogus operand; it is patched once the end
		// of the alternative is known.
		jumpPos := c.emit(code.OpJump, 9999)

		afterConsequencePos := len(c.currentInstructions())
		c.changeOperand(jumpNotTruthyPos, afterConsequencePos)

		if node.Alternative == nil {
			// Without an else branch the alternative is a single OpNull.
			c.emit(code.OpNull)
		} else {
			if err := c.Compile(node.Alternative); err != nil {
				return err
			}
			if c.lastInstructionIs(code.OpPop) {
				c.removeLastPop()
			}
		}

		afterAlternativePos := len(c.currentInstructions())
		c.changeOperand(jumpPos, afterAlternativePos)

	case *ast.BlockStatement:
		for _, s := range node.Statements {
			if err := c.Compile(s); err != nil {
				return err
			}
		}

	case *ast.LetStatement:
		if err := c.Compile(node.Value); err != nil {
			return err
		}
		// The name is defined only after its value has been compiled.
		symbol := c.symbolTable.Define(node.Name.Value)
		c.emit(code.OpSetGlobal, symbol.Index)

	case *ast.Identifier:
		symbol, ok := c.symbolTable.Resolve(node.Value)
		if !ok {
			return fmt.Errorf("undefined variable %s", node.Value)
		}
		c.emit(code.OpGetGlobal, symbol.Index)

	case *ast.StringLiteral:
		str := &object.String{Value: node.Value}
		c.emit(code.OpConstant, c.addConstant(str))

	case *ast.ArrayLiteral:
		for _, el := range node.Elements {
			if err := c.Compile(el); err != nil {
				return err
			}
		}
		c.emit(code.OpArray, len(node.Elements))

	case *ast.HashLiteral:
		// Map iteration order is random; sort the keys by their string form
		// so the emitted instructions are deterministic.
		keys := make([]ast.Expression, 0, len(node.Pairs))
		for k := range node.Pairs {
			keys = append(keys, k)
		}
		sort.Slice(keys, func(i, j int) bool {
			return keys[i].String() < keys[j].String()
		})

		for _, k := range keys {
			if err := c.Compile(k); err != nil {
				return err
			}
			if err := c.Compile(node.Pairs[k]); err != nil {
				return err
			}
		}
		// The operand counts keys and values together.
		c.emit(code.OpHash, len(node.Pairs)*2)

	case *ast.IndexExpression:
		if err := c.Compile(node.Left); err != nil {
			return err
		}
		if err := c.Compile(node.Index); err != nil {
			return err
		}
		c.emit(code.OpIndex)

	case *ast.FunctionLiteral:
		// The body is compiled into its own scope so its instructions do not
		// mix with those of the enclosing scope.
		c.enterScope()

		if err := c.Compile(node.Body); err != nil {
			return err
		}

		// A trailing expression statement becomes the implicit return value.
		if c.lastInstructionIs(code.OpPop) {
			c.replaceLastPopWithReturn()
		}
		// A body that does not end in a value-returning instruction gets a
		// plain OpReturn.
		if !c.lastInstructionIs(code.OpReturnValue) {
			c.emit(code.OpReturn)
		}

		instructions := c.leaveScope()
		compiledFn := &object.CompiledFunction{Instructions: instructions}
		c.emit(code.OpConstant, c.addConstant(compiledFn))

	case *ast.ReturnStatement:
		if err := c.Compile(node.ReturnValue); err != nil {
			return err
		}
		c.emit(code.OpReturnValue)

	case *ast.CallExpression:
		// Only the callee expression is compiled here; nothing else from the
		// call node is emitted before OpCall.
		if err := c.Compile(node.Function); err != nil {
			return err
		}
		c.emit(code.OpCall)
	}

	return nil
}
// addConstant appends obj to the constant pool and returns the index at which
// it was stored.
func (c *Compiler) addConstant(obj object.Object) int {
	idx := len(c.constants)
	c.constants = append(c.constants, obj)
	return idx
}
// emit encodes op with its operands, appends the result to the current
// scope's instructions, records it as the last emitted instruction, and
// returns the position at which it starts.
func (c *Compiler) emit(op code.Opcode, operands ...int) int {
	pos := c.addInstruction(code.Make(op, operands...))
	c.setLastInstruction(op, pos)
	return pos
}
// setLastInstruction records (op, pos) as the current scope's last emitted
// instruction, shifting the former last instruction into previousInstruction.
func (c *Compiler) setLastInstruction(op code.Opcode, pos int) {
	scope := &c.scopes[c.scopeIndex]
	scope.previousInstruction = scope.lastInstruction
	scope.lastInstruction = EmittedInstruction{Opcode: op, Position: pos}
}
// Bytecode returns the instructions of the current scope together with the
// compiler's constant pool.
func (c *Compiler) Bytecode() *Bytecode {
	bc := new(Bytecode)
	bc.Instructions = c.currentInstructions()
	bc.Constants = c.constants
	return bc
}
// addInstruction appends the encoded instruction ins to the current scope and
// returns the position at which it begins.
func (c *Compiler) addInstruction(ins []byte) int {
	scope := &c.scopes[c.scopeIndex]
	start := len(scope.instructions)
	scope.instructions = append(scope.instructions, ins...)
	return start
}
// lastInstructionIs reports whether the current scope has emitted at least
// one instruction and the most recent one has opcode op.
func (c *Compiler) lastInstructionIs(op code.Opcode) bool {
	scope := &c.scopes[c.scopeIndex]
	return len(scope.instructions) != 0 && scope.lastInstruction.Opcode == op
}
// removeLastPop truncates the current scope's instructions at the position of
// the last emitted instruction and makes the previous instruction the last
// one again. It does not itself verify that the removed instruction is an
// OpPop; callers check that with lastInstructionIs first.
func (c *Compiler) removeLastPop() {
	scope := &c.scopes[c.scopeIndex]
	scope.instructions = scope.instructions[:scope.lastInstruction.Position]
	scope.lastInstruction = scope.previousInstruction
}
// replaceInstruction overwrites the bytes of the current scope's instructions
// starting at pos with newInstruction. The instruction slice is modified in
// place and is not grown.
func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) {
	target := c.currentInstructions()
	for offset, b := range newInstruction {
		target[pos+offset] = b
	}
}
// changeOperand re-encodes the instruction at opPos with the same opcode and
// the new operand, overwriting the existing bytes.
func (c *Compiler) changeOperand(opPos int, operand int) {
	opcode := code.Opcode(c.currentInstructions()[opPos])
	c.replaceInstruction(opPos, code.Make(opcode, operand))
}
// enterScope pushes a new, empty compilation scope and makes it the current
// one.
func (c *Compiler) enterScope() {
	// The last/previous instruction records start at their zero values.
	c.scopes = append(c.scopes, CompilationScope{instructions: code.Instructions{}})
	c.scopeIndex++
}
// leaveScope pops the current compilation scope, makes the enclosing scope
// current again, and returns the instructions that were emitted in the popped
// scope.
func (c *Compiler) leaveScope() code.Instructions {
	finished := c.scopes[c.scopeIndex].instructions
	c.scopeIndex--
	c.scopes = c.scopes[:len(c.scopes)-1]
	return finished
}
// replaceLastPopWithReturn overwrites the last emitted instruction in the
// current scope with OpReturnValue and updates the last-instruction record to
// match. Callers check that the last instruction is an OpPop before calling.
func (c *Compiler) replaceLastPopWithReturn() {
	scope := &c.scopes[c.scopeIndex]
	c.replaceInstruction(scope.lastInstruction.Position, code.Make(code.OpReturnValue))
	scope.lastInstruction.Opcode = code.OpReturnValue
}
// Bytecode is the result of compilation: the emitted instructions plus the
// constant pool they refer to.
type Bytecode struct {
// Instructions is the compiled instruction stream.
Instructions code.Instructions
// Constants is the constant pool built up during compilation.
Constants []object.Object
}