package vm import ( "fmt" "log" "monkey/internal/builtins" "monkey/internal/code" "monkey/internal/compiler" "monkey/internal/lexer" "monkey/internal/object" "monkey/internal/parser" "monkey/internal/utils" "os" "path/filepath" "strings" "time" "unicode" ) const StackSize = 2048 const GlobalsSize = 65536 const MaxFrames = 1024 var Null = &object.Null{} var True = &object.Boolean{Value: true} var False = &object.Boolean{Value: false} func isTruthy(obj object.Object) bool { switch obj := obj.(type) { case *object.Boolean: return obj.Value case *object.Null: return false default: return true } } func nativeBoolToBooleanObject(input bool) *object.Boolean { if input { return True } return False } // executeModule compiles the named module and returns a *object.Module object func executeModule(name string, state *VMState) (object.Object, error) { filename := utils.FindModule(name) if filename == "" { return nil, fmt.Errorf("ImportError: no module named '%s'", name) } b, err := os.ReadFile(filename) if err != nil { return nil, fmt.Errorf("IOError: error reading module '%s': %s", name, err) } l := lexer.New(string(b)) p := parser.New(fmt.Sprintf("", name), l) module := p.ParseProgram() if len(p.Errors()) != 0 { return nil, fmt.Errorf("ParseError: %s", p.Errors()) } c := compiler.NewWithState(state.Symbols, &state.Constants) err = c.Compile(module) if err != nil { return nil, fmt.Errorf("CompileError: %s", err) } code := c.Bytecode() state.Constants = code.Constants machine := NewWithState(fmt.Sprintf("", name), code, state) err = machine.Run() if err != nil { return nil, fmt.Errorf("RuntimeError: error loading module '%s'", err) } return state.ExportedHash(), nil } type VMState struct { Constants []object.Object Globals []object.Object Symbols *compiler.SymbolTable } func NewVMState() *VMState { symbolTable := compiler.NewSymbolTable() for i, builtin := range builtins.BuiltinsIndex { symbolTable.DefineBuiltin(i, builtin.Name) } return &VMState{ Constants: []object.Object{}, Globals: make([]object.Object, GlobalsSize), Symbols: symbolTable, } } // exported binding in the vm state. That is every binding that starts with a // capital letter. This is used by the module import system to wrap up the // compiled and evaulated module into an object. func (s *VMState) ExportedHash() *object.Hash { pairs := make(map[object.HashKey]object.HashPair) for name, symbol := range s.Symbols.Store { if unicode.IsUpper(rune(name[0])) { if symbol.Scope == compiler.GlobalScope { obj := s.Globals[symbol.Index] s := &object.String{Value: name} pairs[s.HashKey()] = object.HashPair{Key: s, Value: obj} } } } return &object.Hash{Pairs: pairs} } type VM struct { Debug bool state *VMState dir string file string stack []object.Object sp int // Always points to the next value. Top of stack is stack[sp-1] frames []*Frame frame *Frame // Current frame or nil fp int // Always points to the current frame. Current frame is frames[fp-1] } func (vm *VM) pushFrame(f *Frame) { vm.frame = f vm.frames[vm.fp] = f vm.fp++ } func (vm *VM) popFrame() *Frame { vm.fp-- vm.frame = vm.frames[vm.fp-1] return vm.frames[vm.fp] } // New constructs a new monkey-lang bytecode virtual machine func New(fn string, bytecode *compiler.Bytecode) *VM { mainFn := &object.CompiledFunction{Instructions: bytecode.Instructions} mainClosure := &object.Closure{Fn: mainFn} mainFrame := NewFrame(mainClosure, 0) frames := make([]*Frame, MaxFrames) frames[0] = mainFrame state := NewVMState() state.Constants = bytecode.Constants vm := &VM{ state: state, stack: make([]object.Object, StackSize), sp: 0, frames: frames, frame: mainFrame, fp: 1, } vm.dir, vm.file = filepath.Split(fn) return vm } func NewWithState(fn string, bytecode *compiler.Bytecode, state *VMState) *VM { mainFn := &object.CompiledFunction{Instructions: bytecode.Instructions} mainClosure := &object.Closure{Fn: mainFn} mainFrame := NewFrame(mainClosure, 0) frames := make([]*Frame, MaxFrames) frames[0] = mainFrame vm := &VM{ state: state, frames: frames, frame: mainFrame, fp: 1, stack: make([]object.Object, StackSize), sp: 0, } vm.dir, vm.file = filepath.Split(fn) return vm } func (vm *VM) push(o object.Object) error { if vm.sp >= StackSize { return fmt.Errorf("stack overflow") } vm.stack[vm.sp] = o vm.sp++ return nil } func (vm *VM) pop() object.Object { o := vm.stack[vm.sp-1] vm.sp-- return o } func (vm *VM) executeLoadModule() error { name := vm.pop() return vm.loadModule(name) } func (vm *VM) buildHash(startIndex, endIndex int) (object.Object, error) { hashedPairs := make(map[object.HashKey]object.HashPair) for i := startIndex; i < endIndex; i += 2 { key := vm.stack[i] value := vm.stack[i+1] pair := object.HashPair{Key: key, Value: value} hashKey, ok := key.(object.Hashable) if !ok { return nil, fmt.Errorf("unusable as hash key: %s", key.Type()) } hashedPairs[hashKey.HashKey()] = pair } return &object.Hash{Pairs: hashedPairs}, nil } func (vm *VM) executeMakeHash(startIndex, endIndex int) error { hash, err := vm.buildHash(startIndex, endIndex) if err != nil { return err } vm.sp = startIndex return vm.push(hash) } func (vm *VM) buildArray(startIndex, endIndex int) (object.Object, error) { elements := make([]object.Object, endIndex-startIndex) for i := startIndex; i < endIndex; i++ { elements[i-startIndex] = vm.stack[i] } return &object.Array{Elements: elements}, nil } func (vm *VM) executeMakeArray(startIndex, endIndex int) error { hash, err := vm.buildArray(startIndex, endIndex) if err != nil { return err } vm.sp = startIndex return vm.push(hash) } func (vm *VM) executeAdd() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.Add); ok { val, err := obj.Add(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "+") } func (vm *VM) executeSub() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.Sub); ok { val, err := obj.Sub(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "+") } func (vm *VM) executeMul() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.Mul); ok { val, err := obj.Mul(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "*") } func (vm *VM) executeDiv() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.Div); ok { val, err := obj.Div(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "+") } func (vm *VM) executeMod() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.Mod); ok { val, err := obj.Mod(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "%") } func (vm *VM) executeOr() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.LogicalOr); ok { val, err := obj.LogicalOr(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "||") } func (vm *VM) executeAnd() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.LogicalAnd); ok { val, err := obj.LogicalAnd(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "&&") } func (vm *VM) executeBitwiseOr() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.BitwiseOr); ok { val, err := obj.BitwiseOr(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "|") } func (vm *VM) executeBitwiseXor() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.BitwiseXor); ok { val, err := obj.BitwiseXor(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "^") } func (vm *VM) executeBitwiseAnd() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.BitwiseAnd); ok { val, err := obj.BitwiseAnd(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "&") } func (vm *VM) executeBitwiseNot() error { left := vm.pop() if obj, ok := left.(object.BitwiseNot); ok { return vm.push(obj.BitwiseNot()) } return fmt.Errorf("unsupported types for unary operation: ~%s", left.Type()) } func (vm *VM) executeLeftShift() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.LeftShift); ok { val, err := obj.LeftShift(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, "<<") } func (vm *VM) executeRightShift() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.RightShift); ok { val, err := obj.RightShift(right) if err != nil { return err } return vm.push(val) } return object.NewBinaryOpError(left, right, ">>") } func (vm *VM) executeEqual() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.Comparable); ok { val := obj.Compare(right) if val == 0 { return vm.push(True) } return vm.push(False) } return object.NewBinaryOpError(left, right, "==") } func (vm *VM) executeNotEqual() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.Comparable); ok { val := obj.Compare(right) if val != 0 { return vm.push(True) } return vm.push(False) } return object.NewBinaryOpError(left, right, "!=") } func (vm *VM) executeGreaterThan() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.Comparable); ok { val := obj.Compare(right) if val == 1 { return vm.push(True) } return vm.push(False) } return object.NewBinaryOpError(left, right, ">") } func (vm *VM) executeGreaterThanOrEqual() error { right := vm.pop() left := vm.pop() if obj, ok := left.(object.Comparable); ok { val := obj.Compare(right) if val >= 0 { return vm.push(True) } return vm.push(False) } return object.NewBinaryOpError(left, right, ">") } func (vm *VM) executeNot() error { left := vm.pop() if obj, ok := left.(object.LogicalNot); ok { return vm.push(obj.LogicalNot()) } return fmt.Errorf("unsupported types for unary operation: !%s", left.Type()) } func (vm *VM) executeMinus() error { left := vm.pop() if obj, ok := left.(object.Negate); ok { return vm.push(obj.Negate()) } return fmt.Errorf("unsupported types for unary operation: -%s", left.Type()) } func (vm *VM) executeSetItem() error { right := vm.pop() index := vm.pop() left := vm.pop() if obj, ok := left.(object.Setter); ok { err := obj.Set(index, right) if err != nil { return err } return vm.push(Null) } return fmt.Errorf( "set item operation not supported: left=%s index=%s", left.Type(), index.Type(), ) } func (vm *VM) executeGetItem() error { index := vm.pop() left := vm.pop() if obj, ok := left.(object.Getter); ok { val, err := obj.Get(index) if err != nil { return err } return vm.push(val) } return fmt.Errorf( "index operator not supported: left=%s index=%s", left.Type(), index.Type(), ) } func (vm *VM) executeCall(args int) error { callee := vm.stack[vm.sp-1-args] switch callee := callee.(type) { case *object.Closure: return vm.callClosure(callee, args) case *object.Builtin: return vm.callBuiltin(callee, args) default: return fmt.Errorf( "calling non-closure and non-builtin: %T %v", callee, callee, ) } } func (vm *VM) executeReturn() error { returnValue := vm.pop() frame := vm.popFrame() vm.sp = frame.basePointer - 1 return vm.push(returnValue) } func (vm *VM) callClosure(cl *object.Closure, numArgs int) error { if numArgs != cl.Fn.NumParameters { return fmt.Errorf("wrong number of arguments: want=%d, got=%d", cl.Fn.NumParameters, numArgs) } // Optimize tail calls and avoid a new frame if cl.Fn == vm.frame.cl.Fn { nextOP := vm.frame.NextOp() if nextOP == code.OpReturn { for p := 0; p < numArgs; p++ { vm.stack[vm.frame.basePointer+p] = vm.stack[vm.sp-numArgs+p] } vm.sp -= numArgs + 1 vm.frame.ip = -1 // reset IP to the beginning of the frame return nil } } frame := NewFrame(cl, vm.sp-numArgs) vm.pushFrame(frame) vm.sp = frame.basePointer + cl.Fn.NumLocals return nil } func (vm *VM) callBuiltin(builtin *object.Builtin, numArgs int) error { args := vm.stack[vm.sp-numArgs : vm.sp] result := builtin.Fn(args...) vm.sp = vm.sp - numArgs - 1 if result != nil { err := vm.push(result) if err != nil { return err } } else { err := vm.push(Null) if err != nil { return err } } return nil } func (vm *VM) pushClosure(constIndex, numFree int) error { constant := vm.state.Constants[constIndex] function, ok := constant.(*object.CompiledFunction) if !ok { return fmt.Errorf("not a function %+v", constant) } free := make([]object.Object, numFree) for i := 0; i < numFree; i++ { free[i] = vm.stack[vm.sp-numFree+i] } vm.sp = vm.sp - numFree closure := &object.Closure{Fn: function, Free: free} return vm.push(closure) } func (vm *VM) loadModule(name object.Object) error { s, ok := name.(*object.String) if !ok { return fmt.Errorf( "TypeError: import() expected argument #1 to be `str` got `%s`", name.Type(), ) } attrs, err := executeModule(s.Value, vm.state) if err != nil { return err } module := &object.Module{Name: s.Value, Attrs: attrs} return vm.push(module) } func (vm *VM) LastPoppedStackElem() object.Object { return vm.stack[vm.sp] } func (vm *VM) Run() error { var n int var ip int var ins code.Instructions var op code.Opcode if vm.Debug { start := time.Now() defer func() { log.Printf("%d instructions executeuted in %s", n, time.Now().Sub(start)) }() } var err error for vm.frame.ip < len(vm.frame.Instructions())-1 && err == nil { vm.frame.ip++ ip = vm.frame.ip ins = vm.frame.Instructions() op = code.Opcode(ins[ip]) if vm.Debug { n++ log.Printf( "%-25s %-20s\n", fmt.Sprintf( "%04d %s", ip, strings.Split(ins[ip:].String(), "\n")[0][4:], ), fmt.Sprintf( "[ip=%02d fp=%02d, sp=%02d]", ip, vm.fp-1, vm.sp, ), ) } switch op { case code.OpConstant: constIndex := code.ReadUint16(ins[ip+1:]) vm.frame.ip += 2 err = vm.push(vm.state.Constants[constIndex]) case code.OpPop: vm.pop() case code.OpTrue: err = vm.push(True) case code.OpFalse: err = vm.push(False) case code.OpJump: pos := int(code.ReadUint16(ins[ip+1:])) vm.frame.ip = pos - 1 case code.OpJumpNotTruthy: pos := int(code.ReadUint16(ins[ip+1:])) vm.frame.ip += 2 if !isTruthy(vm.pop()) { vm.frame.ip = pos - 1 } case code.OpNull: err = vm.push(Null) case code.OpSetGlobal: globalIndex := code.ReadUint16(ins[ip+1:]) vm.frame.ip += 2 ref := vm.pop() if immutable, ok := ref.(object.Immutable); ok { vm.state.Globals[globalIndex] = immutable.Clone() } else { vm.state.Globals[globalIndex] = ref } err = vm.push(Null) case code.OpAssignGlobal: globalIndex := code.ReadUint16(ins[ip+1:]) vm.frame.ip += 2 vm.state.Globals[globalIndex] = vm.pop() err = vm.push(Null) case code.OpAssignLocal: localIndex := code.ReadUint8(ins[ip+1:]) vm.frame.ip += 1 vm.stack[vm.frame.basePointer+int(localIndex)] = vm.pop() err = vm.push(Null) case code.OpGetGlobal: globalIndex := code.ReadUint16(ins[ip+1:]) vm.frame.ip += 2 err = vm.push(vm.state.Globals[globalIndex]) case code.OpArray: numElements := int(code.ReadUint16(ins[ip+1:])) vm.frame.ip += 2 err = vm.executeMakeArray(vm.sp-numElements, vm.sp) case code.OpHash: numElements := int(code.ReadUint16(ins[ip+1:])) vm.frame.ip += 2 err = vm.executeMakeHash(vm.sp-numElements, vm.sp) case code.OpSetItem: err = vm.executeSetItem() case code.OpGetItem: err = vm.executeGetItem() case code.OpCall: args := int(code.ReadUint8(ins[ip+1:])) vm.frame.ip++ err = vm.executeCall(args) case code.OpReturn: err = vm.executeReturn() case code.OpSetLocal: localIndex := code.ReadUint8(ins[ip+1:]) vm.frame.ip += 1 ref := vm.pop() if immutable, ok := ref.(object.Immutable); ok { vm.stack[vm.frame.basePointer+int(localIndex)] = immutable.Clone() } else { vm.stack[vm.frame.basePointer+int(localIndex)] = ref } err = vm.push(Null) case code.OpGetLocal: localIndex := code.ReadUint8(ins[ip+1:]) vm.frame.ip += 1 err = vm.push(vm.stack[vm.frame.basePointer+int(localIndex)]) case code.OpGetBuiltin: builtinIndex := code.ReadUint8(ins[ip+1:]) vm.frame.ip++ builtin := builtins.BuiltinsIndex[builtinIndex] err = vm.push(builtin) case code.OpClosure: constIndex := code.ReadUint16(ins[ip+1:]) numFree := code.ReadUint8(ins[ip+3:]) vm.frame.ip += 3 err = vm.pushClosure(int(constIndex), int(numFree)) case code.OpGetFree: freeIndex := code.ReadUint8(ins[ip+1:]) vm.frame.ip += 1 err = vm.push(vm.frame.cl.Free[freeIndex]) case code.OpCurrentClosure: currentClosure := vm.frame.cl err = vm.push(currentClosure) case code.OpLoadModule: err = vm.executeLoadModule() case code.OpAdd: err = vm.executeAdd() case code.OpSub: err = vm.executeSub() case code.OpMul: err = vm.executeMul() case code.OpDiv: err = vm.executeDiv() case code.OpMod: err = vm.executeMod() case code.OpOr: err = vm.executeOr() case code.OpAnd: err = vm.executeAnd() case code.OpBitwiseOR: err = vm.executeBitwiseOr() case code.OpBitwiseXOR: err = vm.executeBitwiseXor() case code.OpBitwiseAND: err = vm.executeBitwiseAnd() case code.OpLeftShift: err = vm.executeLeftShift() case code.OpRightShift: err = vm.executeRightShift() case code.OpEqual: err = vm.executeEqual() case code.OpNotEqual: err = vm.executeNotEqual() case code.OpGreaterThan: err = vm.executeGreaterThan() case code.OpGreaterThanEqual: err = vm.executeGreaterThanOrEqual() case code.OpNot: err = vm.executeNot() case code.OpBitwiseNOT: err = vm.executeBitwiseNot() case code.OpMinus: err = vm.executeMinus() default: err = fmt.Errorf("unhandled opcode: %s", op) } if vm.Debug { log.Printf( "%-25s [ip=%02d fp=%02d, sp=%02d]", "", ip, vm.fp-1, vm.sp, ) } } return err }