Files
monkey/internal/vm/vm.go
Chuck Smith c8de195ac8
Some checks failed
Build / build (push) Failing after 6m4s
Publish Image / publish (push) Failing after 25s
Test / build (push) Failing after 5m49s
Halt to halt the VM
2024-03-30 14:17:20 -04:00

912 lines
18 KiB
Go

package vm
import (
"fmt"
"log"
"monkey/internal/builtins"
"monkey/internal/code"
"monkey/internal/compiler"
"monkey/internal/lexer"
"monkey/internal/object"
"monkey/internal/parser"
"monkey/internal/utils"
"os"
"path/filepath"
"strings"
"time"
"unicode"
)
// Fixed capacities of the VM's storage.
const (
	// StackSize is the number of slots in the operand stack.
	StackSize = 2048
	// GlobalsSize is the number of global binding slots; global indices are
	// read as 16-bit operands, so this covers every addressable global.
	GlobalsSize = 65536
	// MaxFrames is the maximum number of simultaneously active call frames.
	MaxFrames = 1024
)

// Shared singleton objects. The VM always pushes these exact pointers for
// null and boolean results instead of allocating new objects.
var (
	Null  = &object.Null{}
	True  = &object.Boolean{Value: true}
	False = &object.Boolean{Value: false}
)
// isTruthy reports whether obj counts as true in a conditional jump.
// A Boolean is its own value, Null is false, and any other object
// (including a nil interface) is true.
func isTruthy(obj object.Object) bool {
	if b, ok := obj.(*object.Boolean); ok {
		return b.Value
	}
	_, isNull := obj.(*object.Null)
	return !isNull
}
// nativeBoolToBooleanObject maps a Go bool onto the shared True/False
// singletons, so no new Boolean object is allocated.
func nativeBoolToBooleanObject(input bool) *object.Boolean {
	result := False
	if input {
		result = True
	}
	return result
}
// executeModule locates, parses, compiles and runs the module called name
// against the shared VM state, and returns the module's exported bindings
// as a hash (built by state.ExportedHash).
//
// Errors are prefixed with ImportError, IOError, ParseError, CompileError or
// RuntimeError depending on which stage failed; underlying errors are
// wrapped so callers can inspect them with errors.Is/As.
func executeModule(name string, state *VMState) (object.Object, error) {
	filename := utils.FindModule(name)
	if filename == "" {
		return nil, fmt.Errorf("ImportError: no module named '%s'", name)
	}

	src, err := os.ReadFile(filename)
	if err != nil {
		return nil, fmt.Errorf("IOError: error reading module '%s': %w", name, err)
	}

	l := lexer.New(string(src))
	p := parser.New(fmt.Sprintf("<module %s>", name), l)
	module := p.ParseProgram()
	if len(p.Errors()) != 0 {
		return nil, fmt.Errorf("ParseError: %s", p.Errors())
	}

	// Compile against the shared symbol table and constant pool so the
	// module's bindings are resolved within the same state as the importer.
	c := compiler.NewWithState(state.Symbols, &state.Constants)
	if err := c.Compile(module); err != nil {
		return nil, fmt.Errorf("CompileError: %w", err)
	}

	// Named bytecode (not code) so the imported code package is not shadowed.
	bytecode := c.Bytecode()
	state.Constants = bytecode.Constants

	machine := NewWithState(fmt.Sprintf("<module %s>", name), bytecode, state)
	if err := machine.Run(); err != nil {
		return nil, fmt.Errorf("RuntimeError: error loading module '%s': %w", name, err)
	}

	return state.ExportedHash(), nil
}
// VMState is the portion of VM state that can outlive a single VM and be
// shared between VMs; executeModule, for instance, runs a module in a new
// VM built over the importer's state.
type VMState struct {
	// Constants is the constant pool; Globals holds the global binding
	// slots (GlobalsSize of them when created by NewVMState).
	Constants, Globals []object.Object
	// Symbols is the symbol table code is compiled against for this state.
	Symbols *compiler.SymbolTable
}
// NewVMState returns a VMState with an empty constant pool, GlobalsSize
// nil global slots, and a symbol table with every entry of
// builtins.BuiltinsIndex registered as a builtin under its index.
func NewVMState() *VMState {
	state := &VMState{
		Constants: []object.Object{},
		Globals:   make([]object.Object, GlobalsSize),
		Symbols:   compiler.NewSymbolTable(),
	}
	for idx, b := range builtins.BuiltinsIndex {
		state.Symbols.DefineBuiltin(idx, b.Name)
	}
	return state
}
// ExportedHash returns a new hash object containing every exported global
// binding in the vm state, i.e. every global-scope binding whose name starts
// with an upper-case letter. Keys are the binding names as String objects
// and values are the current contents of the matching global slot. It is
// used by the module import system to wrap up the compiled and evaluated
// module into an object.
func (s *VMState) ExportedHash() *object.Hash {
	pairs := make(map[object.HashKey]object.HashPair)
	for name, symbol := range s.Symbols.Store {
		if symbol.Scope != compiler.GlobalScope {
			continue
		}
		// Judge the first *rune*, not the first byte, so multi-byte UTF-8
		// initials are classified correctly; an empty name yields -1 and is
		// never exported.
		if strings.IndexFunc(name, unicode.IsUpper) != 0 {
			continue
		}
		key := &object.String{Value: name}
		pairs[key.HashKey()] = object.HashPair{Key: key, Value: s.Globals[symbol.Index]}
	}
	return &object.Hash{Pairs: pairs}
}
// VM is a monkey-lang bytecode virtual machine.
type VM struct {
	// Debug, when true, makes Run trace each instruction and log the
	// instruction count and elapsed time when it returns.
	Debug bool

	state *VMState // constants, globals and symbols (possibly shared)

	dir, file string // filepath.Split of the name the VM was created with

	stack []object.Object
	sp    int // Always points to the next value. Top of stack is stack[sp-1]

	frames []*Frame
	frame  *Frame // Current frame or nil
	fp     int    // Always points to the current frame. Current frame is frames[fp-1]
}
// pushFrame makes next the current frame and records it at frames[fp],
// advancing fp. It does not check fp against MaxFrames.
func (vm *VM) pushFrame(next *Frame) {
	vm.frame = next
	slot := vm.fp
	vm.frames[slot] = next
	vm.fp = slot + 1
}
// popFrame discards the current frame, makes the one below it current, and
// returns the discarded frame. The slot in frames is not cleared.
func (vm *VM) popFrame() *Frame {
	vm.fp--
	popped := vm.frames[vm.fp]
	vm.frame = vm.frames[vm.fp-1]
	return popped
}
// New constructs a new monkey-lang bytecode virtual machine over a fresh
// VMState (see NewVMState) whose constant pool is taken from bytecode.
// fn names the program being run; it is split into the VM's dir and file
// with filepath.Split.
//
// Construction is delegated to NewWithState so the two constructors
// cannot drift apart.
func New(fn string, bytecode *compiler.Bytecode) *VM {
	state := NewVMState()
	state.Constants = bytecode.Constants
	return NewWithState(fn, bytecode, state)
}
// NewWithState constructs a VM that executes bytecode's instructions as its
// main frame while using the supplied state for constants, globals and
// symbols. fn is split into the VM's dir and file with filepath.Split.
func NewWithState(fn string, bytecode *compiler.Bytecode, state *VMState) *VM {
	entry := NewFrame(
		&object.Closure{Fn: &object.CompiledFunction{Instructions: bytecode.Instructions}},
		0,
	)

	vm := &VM{
		state:  state,
		stack:  make([]object.Object, StackSize),
		frames: make([]*Frame, MaxFrames),
		frame:  entry,
		fp:     1, // frames[0] holds the main frame
	}
	vm.frames[0] = entry
	vm.dir, vm.file = filepath.Split(fn)
	return vm
}
// push places o on top of the operand stack, or returns an error when the
// stack already holds StackSize values.
func (vm *VM) push(o object.Object) error {
	if vm.sp < StackSize {
		vm.stack[vm.sp] = o
		vm.sp++
		return nil
	}
	return fmt.Errorf("stack overflow")
}
// pop removes and returns the top of the operand stack. The slot itself is
// left intact, which LastPoppedStackElem relies on.
func (vm *VM) pop() object.Object {
	top := vm.stack[vm.sp-1]
	vm.sp--
	return top
}
// executeLoadModule pops the module name off the stack and loads it.
func (vm *VM) executeLoadModule() error {
	return vm.loadModule(vm.pop())
}
// buildHash builds a Hash from the alternating key/value stack slots in
// [startIndex, endIndex). Every key must implement object.Hashable. The
// stack itself is not modified.
func (vm *VM) buildHash(startIndex, endIndex int) (object.Object, error) {
	pairs := make(map[object.HashKey]object.HashPair)
	for i := startIndex; i < endIndex; i += 2 {
		k, v := vm.stack[i], vm.stack[i+1]
		hashable, ok := k.(object.Hashable)
		if !ok {
			return nil, fmt.Errorf("unusable as hash key: %s", k.Type())
		}
		pairs[hashable.HashKey()] = object.HashPair{Key: k, Value: v}
	}
	return &object.Hash{Pairs: pairs}, nil
}
// executeMakeHash replaces the key/value operands in [startIndex, endIndex)
// with a single Hash built from them.
func (vm *VM) executeMakeHash(startIndex, endIndex int) error {
	built, err := vm.buildHash(startIndex, endIndex)
	if err == nil {
		// Discard the operands, then push the hash in their place.
		vm.sp = startIndex
		err = vm.push(built)
	}
	return err
}
// buildArray returns an Array holding a copy of the stack slots in
// [startIndex, endIndex). The error result is always nil.
func (vm *VM) buildArray(startIndex, endIndex int) (object.Object, error) {
	elements := make([]object.Object, endIndex-startIndex)
	copy(elements, vm.stack[startIndex:endIndex])
	return &object.Array{Elements: elements}, nil
}
// executeMakeArray replaces the element operands in [startIndex, endIndex)
// with a single Array containing them.
func (vm *VM) executeMakeArray(startIndex, endIndex int) error {
	array, err := vm.buildArray(startIndex, endIndex)
	if err == nil {
		vm.sp = startIndex
		err = vm.push(array)
	}
	return err
}
// executeAdd pops right then left and pushes left.Add(right). A left
// operand that does not implement object.Add yields a binary-op error
// for "+"; errors from Add itself are returned unchanged.
func (vm *VM) executeAdd() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.Add)
	if !ok {
		return object.NewBinaryOpError(left, right, "+")
	}
	result, err := impl.Add(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeSub pops right then left and pushes left.Sub(right). A left
// operand that does not implement object.Sub yields a binary-op error
// for "-" (previously mislabelled "+"); errors from Sub are returned
// unchanged.
func (vm *VM) executeSub() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.Sub)
	if !ok {
		return object.NewBinaryOpError(left, right, "-")
	}
	result, err := impl.Sub(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeMul pops right then left and pushes left.Mul(right). A left
// operand that does not implement object.Mul yields a binary-op error
// for "*"; errors from Mul are returned unchanged.
func (vm *VM) executeMul() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.Mul)
	if !ok {
		return object.NewBinaryOpError(left, right, "*")
	}
	result, err := impl.Mul(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeDiv pops right then left and pushes left.Div(right). A left
// operand that does not implement object.Div yields a binary-op error
// for "/" (previously mislabelled "+"); errors from Div (e.g. whatever the
// implementation reports for division by zero) are returned unchanged.
func (vm *VM) executeDiv() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.Div)
	if !ok {
		return object.NewBinaryOpError(left, right, "/")
	}
	result, err := impl.Div(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeMod pops right then left and pushes left.Mod(right). A left
// operand that does not implement object.Mod yields a binary-op error
// for "%"; errors from Mod are returned unchanged.
func (vm *VM) executeMod() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.Mod)
	if !ok {
		return object.NewBinaryOpError(left, right, "%")
	}
	result, err := impl.Mod(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeOr pops right then left and pushes left.LogicalOr(right). Both
// operands are already on the stack, so no short-circuiting happens here.
// A left operand that does not implement object.LogicalOr yields a
// binary-op error for "||".
func (vm *VM) executeOr() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.LogicalOr)
	if !ok {
		return object.NewBinaryOpError(left, right, "||")
	}
	result, err := impl.LogicalOr(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeAnd pops right then left and pushes left.LogicalAnd(right). Both
// operands are already on the stack, so no short-circuiting happens here.
// A left operand that does not implement object.LogicalAnd yields a
// binary-op error for "&&".
func (vm *VM) executeAnd() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.LogicalAnd)
	if !ok {
		return object.NewBinaryOpError(left, right, "&&")
	}
	result, err := impl.LogicalAnd(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeBitwiseOr pops right then left and pushes left.BitwiseOr(right).
// A left operand that does not implement object.BitwiseOr yields a
// binary-op error for "|".
func (vm *VM) executeBitwiseOr() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.BitwiseOr)
	if !ok {
		return object.NewBinaryOpError(left, right, "|")
	}
	result, err := impl.BitwiseOr(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeBitwiseXor pops right then left and pushes left.BitwiseXor(right).
// A left operand that does not implement object.BitwiseXor yields a
// binary-op error for "^".
func (vm *VM) executeBitwiseXor() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.BitwiseXor)
	if !ok {
		return object.NewBinaryOpError(left, right, "^")
	}
	result, err := impl.BitwiseXor(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeBitwiseAnd pops right then left and pushes left.BitwiseAnd(right).
// A left operand that does not implement object.BitwiseAnd yields a
// binary-op error for "&".
func (vm *VM) executeBitwiseAnd() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.BitwiseAnd)
	if !ok {
		return object.NewBinaryOpError(left, right, "&")
	}
	result, err := impl.BitwiseAnd(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeBitwiseNot pops one operand and pushes its BitwiseNot, or returns
// an error if the operand does not implement object.BitwiseNot.
func (vm *VM) executeBitwiseNot() error {
	operand := vm.pop()
	impl, ok := operand.(object.BitwiseNot)
	if !ok {
		return fmt.Errorf("unsupported types for unary operation: ~%s", operand.Type())
	}
	return vm.push(impl.BitwiseNot())
}
// executeLeftShift pops right then left and pushes left.LeftShift(right).
// A left operand that does not implement object.LeftShift yields a
// binary-op error for "<<".
func (vm *VM) executeLeftShift() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.LeftShift)
	if !ok {
		return object.NewBinaryOpError(left, right, "<<")
	}
	result, err := impl.LeftShift(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeRightShift pops right then left and pushes left.RightShift(right).
// A left operand that does not implement object.RightShift yields a
// binary-op error for ">>".
func (vm *VM) executeRightShift() error {
	right := vm.pop()
	left := vm.pop()
	impl, ok := left.(object.RightShift)
	if !ok {
		return object.NewBinaryOpError(left, right, ">>")
	}
	result, err := impl.RightShift(right)
	if err != nil {
		return err
	}
	return vm.push(result)
}
// executeEqual pops right then left and pushes True when
// left.Compare(right) is 0, False otherwise. A left operand that is not
// object.Comparable yields a binary-op error for "==".
func (vm *VM) executeEqual() error {
	right := vm.pop()
	left := vm.pop()
	cmp, ok := left.(object.Comparable)
	if !ok {
		return object.NewBinaryOpError(left, right, "==")
	}
	return vm.push(nativeBoolToBooleanObject(cmp.Compare(right) == 0))
}
// executeNotEqual pops right then left and pushes True when
// left.Compare(right) is non-zero, False otherwise. A left operand that is
// not object.Comparable yields a binary-op error for "!=".
func (vm *VM) executeNotEqual() error {
	right := vm.pop()
	left := vm.pop()
	cmp, ok := left.(object.Comparable)
	if !ok {
		return object.NewBinaryOpError(left, right, "!=")
	}
	return vm.push(nativeBoolToBooleanObject(cmp.Compare(right) != 0))
}
// executeGreaterThan pops right then left and pushes True when
// left.Compare(right) is positive, False otherwise. Testing the sign
// (rather than == 1) matches executeGreaterThanOrEqual's >= 0 and stays
// correct for Compare implementations that return any positive value.
// A left operand that is not object.Comparable yields a binary-op error
// for ">".
func (vm *VM) executeGreaterThan() error {
	right := vm.pop()
	left := vm.pop()
	cmp, ok := left.(object.Comparable)
	if !ok {
		return object.NewBinaryOpError(left, right, ">")
	}
	if cmp.Compare(right) > 0 {
		return vm.push(True)
	}
	return vm.push(False)
}
// executeGreaterThanOrEqual pops right then left and pushes True when
// left.Compare(right) is zero or positive, False otherwise. A left operand
// that is not object.Comparable yields a binary-op error for ">="
// (previously mislabelled ">").
func (vm *VM) executeGreaterThanOrEqual() error {
	right := vm.pop()
	left := vm.pop()
	cmp, ok := left.(object.Comparable)
	if !ok {
		return object.NewBinaryOpError(left, right, ">=")
	}
	if cmp.Compare(right) >= 0 {
		return vm.push(True)
	}
	return vm.push(False)
}
// executeNot pops one operand and pushes its LogicalNot, or returns an
// error if the operand does not implement object.LogicalNot.
func (vm *VM) executeNot() error {
	operand := vm.pop()
	impl, ok := operand.(object.LogicalNot)
	if !ok {
		return fmt.Errorf("unsupported types for unary operation: !%s", operand.Type())
	}
	return vm.push(impl.LogicalNot())
}
// executeMinus pops one operand and pushes its Negate, or returns an error
// if the operand does not implement object.Negate.
func (vm *VM) executeMinus() error {
	operand := vm.pop()
	impl, ok := operand.(object.Negate)
	if !ok {
		return fmt.Errorf("unsupported types for unary operation: -%s", operand.Type())
	}
	return vm.push(impl.Negate())
}
// executeSetItem pops value, index and container (in that order), calls
// container.Set(index, value) and pushes Null. The container must
// implement object.Setter; errors from Set are returned unchanged.
func (vm *VM) executeSetItem() error {
	value := vm.pop()
	index := vm.pop()
	container := vm.pop()

	setter, ok := container.(object.Setter)
	if !ok {
		return fmt.Errorf(
			"set item operation not supported: left=%s index=%s",
			container.Type(), index.Type(),
		)
	}
	if err := setter.Set(index, value); err != nil {
		return err
	}
	return vm.push(Null)
}
// executeGetItem pops index then container and pushes container.Get(index).
// The container must implement object.Getter; errors from Get are returned
// unchanged.
func (vm *VM) executeGetItem() error {
	index := vm.pop()
	container := vm.pop()

	getter, ok := container.(object.Getter)
	if !ok {
		return fmt.Errorf(
			"index operator not supported: left=%s index=%s",
			container.Type(), index.Type(),
		)
	}
	item, err := getter.Get(index)
	if err != nil {
		return err
	}
	return vm.push(item)
}
// executeCall invokes the callee sitting just below its args arguments on
// the stack. Closures and builtins are supported; anything else is an
// error.
func (vm *VM) executeCall(args int) error {
	callee := vm.stack[vm.sp-1-args]
	if cl, ok := callee.(*object.Closure); ok {
		return vm.callClosure(cl, args)
	}
	if b, ok := callee.(*object.Builtin); ok {
		return vm.callBuiltin(b, args)
	}
	return fmt.Errorf(
		"calling non-closure and non-builtin: %T %v",
		callee, callee,
	)
}
// executeReturn pops the return value, leaves the current frame, drops the
// frame's locals together with the callee slot just below basePointer, and
// pushes the return value for the caller.
func (vm *VM) executeReturn() error {
	result := vm.pop()
	vm.sp = vm.popFrame().basePointer - 1
	return vm.push(result)
}
// callClosure enters cl with numArgs arguments already on the stack above
// the callee. It returns an error on an arity mismatch, or when entering
// the call would exceed MaxFrames frames or StackSize stack slots (instead
// of panicking with an index-out-of-range).
//
// A call to the same compiled function that is immediately followed by
// OpReturn is treated as a tail call: the current frame is reused rather
// than pushing a new one.
func (vm *VM) callClosure(cl *object.Closure, numArgs int) error {
	if numArgs != cl.Fn.NumParameters {
		return fmt.Errorf("wrong number of arguments: want=%d, got=%d", cl.Fn.NumParameters, numArgs)
	}

	// Optimize tail calls and avoid a new frame.
	if cl.Fn == vm.frame.cl.Fn && vm.frame.NextOp() == code.OpReturn {
		// Overwrite the current frame's parameters with the new arguments.
		for p := 0; p < numArgs; p++ {
			vm.stack[vm.frame.basePointer+p] = vm.stack[vm.sp-numArgs+p]
		}
		// Drop the arguments and the callee slot.
		vm.sp -= numArgs + 1
		// The callee may be a different closure over the same function
		// (with different free variables); switch to it so OpGetFree and
		// OpCurrentClosure see the right closure.
		vm.frame.cl = cl
		vm.frame.ip = -1 // reset IP to the beginning of the frame
		return nil
	}

	if vm.fp >= MaxFrames {
		return fmt.Errorf("stack overflow: maximum call depth of %d exceeded", MaxFrames)
	}
	basePointer := vm.sp - numArgs
	if basePointer+cl.Fn.NumLocals > StackSize {
		return fmt.Errorf("stack overflow")
	}

	frame := NewFrame(cl, basePointer)
	vm.pushFrame(frame)
	// Reserve slots for the callee's locals (parameters occupy the first ones).
	vm.sp = frame.basePointer + cl.Fn.NumLocals
	return nil
}
// callBuiltin invokes builtin with the numArgs values on top of the stack,
// removes the arguments and the callee slot, and pushes the result (Null
// when the builtin returns nil).
func (vm *VM) callBuiltin(builtin *object.Builtin, numArgs int) error {
	first := vm.sp - numArgs
	result := builtin.Fn(vm.stack[first:vm.sp]...)
	vm.sp = first - 1
	if result == nil {
		return vm.push(Null)
	}
	return vm.push(result)
}
// pushClosure wraps the CompiledFunction at constIndex in a Closure whose
// free variables are the numFree values on top of the stack (which are
// consumed), then pushes the closure.
func (vm *VM) pushClosure(constIndex, numFree int) error {
	constant := vm.state.Constants[constIndex]
	fn, ok := constant.(*object.CompiledFunction)
	if !ok {
		return fmt.Errorf("not a function %+v", constant)
	}

	start := vm.sp - numFree
	free := make([]object.Object, numFree)
	copy(free, vm.stack[start:vm.sp])
	vm.sp = start

	return vm.push(&object.Closure{Fn: fn, Free: free})
}
// loadModule executes the module whose name is given by the String object
// name and pushes a Module wrapping its exported bindings.
func (vm *VM) loadModule(name object.Object) error {
	str, ok := name.(*object.String)
	if !ok {
		return fmt.Errorf(
			"TypeError: import() expected argument #1 to be `str` got `%s`",
			name.Type(),
		)
	}
	attrs, err := executeModule(str.Value, vm.state)
	if err != nil {
		return err
	}
	return vm.push(&object.Module{Name: str.Value, Attrs: attrs})
}
// LastPoppedStackElem returns the value most recently popped off the stack.
// pop only moves sp down without clearing the slot, so that value still
// sits at stack[sp].
func (vm *VM) LastPoppedStackElem() object.Object {
	slot := vm.sp
	return vm.stack[slot]
}
// Run executes instructions starting from the current frame until an
// OpHalt is reached (returning nil) or an instruction produces an error
// (which is returned).
//
// When vm.Debug is set, each instruction is traced via the log package,
// and the number of instructions executed and the elapsed time are logged
// when Run returns.
func (vm *VM) Run() (err error) {
	var (
		n   int // instructions executed; only counted in debug mode
		ip  int
		ins code.Instructions
		op  code.Opcode
	)

	if vm.Debug {
		start := time.Now()
		defer func() {
			log.Printf("%d instructions executed in %s", n, time.Since(start))
		}()
	}

	for err == nil {
		// ip is pre-incremented, so jumps below store target-1.
		vm.frame.ip++
		ip = vm.frame.ip
		ins = vm.frame.Instructions()
		op = code.Opcode(ins[ip])

		if vm.Debug {
			n++
			log.Printf(
				"%-25s %-20s\n",
				fmt.Sprintf(
					"%04d %s", ip,
					strings.Split(ins[ip:].String(), "\n")[0][4:],
				),
				fmt.Sprintf(
					"[ip=%02d fp=%02d, sp=%02d]",
					ip, vm.fp-1, vm.sp,
				),
			)
		}

		switch op {
		case code.OpConstant:
			constIndex := code.ReadUint16(ins[ip+1:])
			vm.frame.ip += 2
			err = vm.push(vm.state.Constants[constIndex])

		case code.OpPop:
			vm.pop()

		case code.OpTrue:
			err = vm.push(True)

		case code.OpFalse:
			err = vm.push(False)

		case code.OpJump:
			pos := int(code.ReadUint16(ins[ip+1:]))
			vm.frame.ip = pos - 1

		case code.OpJumpNotTruthy:
			pos := int(code.ReadUint16(ins[ip+1:]))
			vm.frame.ip += 2
			if !isTruthy(vm.pop()) {
				vm.frame.ip = pos - 1
			}

		case code.OpNull:
			err = vm.push(Null)

		case code.OpSetGlobal:
			globalIndex := code.ReadUint16(ins[ip+1:])
			vm.frame.ip += 2
			ref := vm.pop()
			// Values implementing object.Immutable are stored as a clone
			// rather than shared with the popped reference.
			if immutable, ok := ref.(object.Immutable); ok {
				vm.state.Globals[globalIndex] = immutable.Clone()
			} else {
				vm.state.Globals[globalIndex] = ref
			}
			err = vm.push(Null)

		case code.OpAssignGlobal:
			globalIndex := code.ReadUint16(ins[ip+1:])
			vm.frame.ip += 2
			vm.state.Globals[globalIndex] = vm.pop()
			err = vm.push(Null)

		case code.OpAssignLocal:
			localIndex := code.ReadUint8(ins[ip+1:])
			vm.frame.ip++
			vm.stack[vm.frame.basePointer+int(localIndex)] = vm.pop()
			err = vm.push(Null)

		case code.OpGetGlobal:
			globalIndex := code.ReadUint16(ins[ip+1:])
			vm.frame.ip += 2
			err = vm.push(vm.state.Globals[globalIndex])

		case code.OpArray:
			numElements := int(code.ReadUint16(ins[ip+1:]))
			vm.frame.ip += 2
			err = vm.executeMakeArray(vm.sp-numElements, vm.sp)

		case code.OpHash:
			numElements := int(code.ReadUint16(ins[ip+1:]))
			vm.frame.ip += 2
			err = vm.executeMakeHash(vm.sp-numElements, vm.sp)

		case code.OpSetItem:
			err = vm.executeSetItem()

		case code.OpGetItem:
			err = vm.executeGetItem()

		case code.OpCall:
			args := int(code.ReadUint8(ins[ip+1:]))
			vm.frame.ip++
			err = vm.executeCall(args)

		case code.OpReturn:
			err = vm.executeReturn()

		case code.OpSetLocal:
			localIndex := code.ReadUint8(ins[ip+1:])
			vm.frame.ip++
			ref := vm.pop()
			// Same clone-on-bind treatment as OpSetGlobal.
			if immutable, ok := ref.(object.Immutable); ok {
				vm.stack[vm.frame.basePointer+int(localIndex)] = immutable.Clone()
			} else {
				vm.stack[vm.frame.basePointer+int(localIndex)] = ref
			}
			err = vm.push(Null)

		case code.OpGetLocal:
			localIndex := code.ReadUint8(ins[ip+1:])
			vm.frame.ip++
			err = vm.push(vm.stack[vm.frame.basePointer+int(localIndex)])

		case code.OpGetBuiltin:
			builtinIndex := code.ReadUint8(ins[ip+1:])
			vm.frame.ip++
			builtin := builtins.BuiltinsIndex[builtinIndex]
			err = vm.push(builtin)

		case code.OpClosure:
			constIndex := code.ReadUint16(ins[ip+1:])
			numFree := code.ReadUint8(ins[ip+3:])
			vm.frame.ip += 3
			err = vm.pushClosure(int(constIndex), int(numFree))

		case code.OpGetFree:
			freeIndex := code.ReadUint8(ins[ip+1:])
			vm.frame.ip++
			err = vm.push(vm.frame.cl.Free[freeIndex])

		case code.OpCurrentClosure:
			currentClosure := vm.frame.cl
			err = vm.push(currentClosure)

		case code.OpLoadModule:
			err = vm.executeLoadModule()

		case code.OpAdd:
			err = vm.executeAdd()
		case code.OpSub:
			err = vm.executeSub()
		case code.OpMul:
			err = vm.executeMul()
		case code.OpDiv:
			err = vm.executeDiv()
		case code.OpMod:
			err = vm.executeMod()
		case code.OpOr:
			err = vm.executeOr()
		case code.OpAnd:
			err = vm.executeAnd()
		case code.OpBitwiseOR:
			err = vm.executeBitwiseOr()
		case code.OpBitwiseXOR:
			err = vm.executeBitwiseXor()
		case code.OpBitwiseAND:
			err = vm.executeBitwiseAnd()
		case code.OpLeftShift:
			err = vm.executeLeftShift()
		case code.OpRightShift:
			err = vm.executeRightShift()
		case code.OpEqual:
			err = vm.executeEqual()
		case code.OpNotEqual:
			err = vm.executeNotEqual()
		case code.OpGreaterThan:
			err = vm.executeGreaterThan()
		case code.OpGreaterThanEqual:
			err = vm.executeGreaterThanOrEqual()
		case code.OpNot:
			err = vm.executeNot()
		case code.OpBitwiseNOT:
			err = vm.executeBitwiseNot()
		case code.OpMinus:
			err = vm.executeMinus()

		case code.OpHalt:
			// err is necessarily nil here (loop condition), so halting is a
			// clean stop.
			return nil

		default:
			err = fmt.Errorf("unhandled opcode: %s", op)
		}

		if vm.Debug {
			log.Printf(
				"%-25s [ip=%02d fp=%02d, sp=%02d]",
				"", ip, vm.fp-1, vm.sp,
			)
		}
	}
	return err
}