package vm import ( "fmt" "monkey/internal/builtins" "monkey/internal/code" "monkey/internal/compiler" "monkey/internal/context" "monkey/internal/lexer" "monkey/internal/object" "monkey/internal/parser" "monkey/internal/utils" "os" "path/filepath" "unicode" ) const ( maxStackSize = 2048 maxFrames = 1024 maxGlobals = 65536 ) func isTruthy(obj object.Object) bool { switch obj := obj.(type) { case object.Boolean: return obj.Value case object.Null: return false default: return true } } func executeModule(name string, state *State) (object.Object, error) { filename := utils.FindModule(name) if filename == "" { return nil, fmt.Errorf("ImportError: no module named '%s'", name) } b, err := os.ReadFile(filename) if err != nil { return nil, fmt.Errorf("IOError: error reading module '%s': %s", name, err) } l := lexer.New(string(b)) p := parser.New(fmt.Sprintf("", name), l) module := p.ParseProgram() if len(p.Errors()) != 0 { return nil, fmt.Errorf("ParseError: %s", p.Errors()) } c := compiler.NewWithState(state.Symbols, &state.Constants) err = c.Compile(module) if err != nil { return nil, fmt.Errorf("CompileError: %s", err) } code := c.Bytecode() state.Constants = code.Constants machine := New(fmt.Sprintf("", name), code, WithState(state)) err = machine.Run() if err != nil { return nil, fmt.Errorf("RuntimeError: error loading module '%s'", err) } return state.ExportedHash(), nil } // State is the state of the virtual machine. type State struct { Constants []object.Object Globals []object.Object Symbols *compiler.SymbolTable } func NewState() *State { symbolTable := compiler.NewSymbolTable() for i, builtin := range builtins.BuiltinsIndex { symbolTable.DefineBuiltin(i, builtin.Name) } return &State{ Constants: []object.Object{}, Globals: make([]object.Object, maxGlobals), Symbols: symbolTable, } } // exported binding in the vm state. That is every binding that starts with a // capital letter. This is used by the module import system to wrap up the // compiled and evaulated module into an object. func (s *State) ExportedHash() *object.Hash { pairs := make(map[object.HashKey]object.HashPair) for name, symbol := range s.Symbols.Store { if unicode.IsUpper(rune(name[0])) { if symbol.Scope == compiler.GlobalScope { obj := s.Globals[symbol.Index] s := object.String{Value: name} pairs[s.Hash()] = object.HashPair{Key: s, Value: obj} } } } return &object.Hash{Pairs: pairs} } type VM struct { debug bool trace bool ctx context.Context state *State dir string file string stack []object.Object sp int // Always points to the next value. Top of stack is stack[sp-1] frames []frame fp int // Always points to the current frame. Current frame is frames[fp-1] } func (vm *VM) currentFrame() *frame { return &vm.frames[vm.fp-1] } func (vm *VM) pushFrame(f frame) { if vm.fp >= maxFrames { panic("frame overflow") } vm.frames[vm.fp] = f vm.fp++ } func (vm *VM) popFrame() frame { if vm.fp == 0 { panic("fame underflow") } vm.fp-- return vm.frames[vm.fp] } // Option defines a function option for the virtual machine. type Option func(*VM) // WithContext defines an option to set the context for the virtual machine. func WithContext(ctx context.Context) Option { return func(vm *VM) { vm.ctx = ctx } } // WithDebug enables debug mode in the VM. func WithDebug(debug bool) Option { return func(vm *VM) { vm.debug = debug } } // WithState sets the state of the VM. func WithState(state *State) Option { return func(vm *VM) { vm.state = state } } // WithTrace sets the trace flag for the VM. func WithTrace(trace bool) Option { return func(vm *VM) { vm.trace = trace } } // New constructs a new monkey-lang bytecode virtual machine func New(fn string, bytecode *compiler.Bytecode, options ...Option) *VM { mainFn := object.CompiledFunction{Instructions: bytecode.Instructions} mainClosure := object.Closure{Fn: &mainFn} mainFrame := newFrame(&mainClosure, 0) frames := make([]frame, maxFrames) frames[0] = mainFrame ctx := context.New() state := NewState() state.Constants = bytecode.Constants vm := &VM{ ctx: ctx, state: state, stack: make([]object.Object, maxStackSize), sp: 0, frames: frames, fp: 1, } for _, option := range options { option(vm) } vm.dir, vm.file = filepath.Split(fn) return vm } func (vm *VM) push(o object.Object) error { if vm.sp >= maxStackSize { return fmt.Errorf("stack overflow") } vm.stack[vm.sp] = o vm.sp++ return nil } func (vm *VM) pop() object.Object { if vm.sp == 0 { panic("stack underflow") } o := vm.stack[vm.sp-1] vm.sp-- return o } func (vm *VM) pop2() (object.Object, object.Object) { if vm.sp == 1 { panic("stack underflow") } o1 := vm.stack[vm.sp-1] o2 := vm.stack[vm.sp-2] vm.sp -= 2 return o1, o2 } func (vm *VM) pop3() (object.Object, object.Object, object.Object) { if vm.sp == 2 { panic("stack underflow") } o1 := vm.stack[vm.sp-1] o2 := vm.stack[vm.sp-2] o3 := vm.stack[vm.sp-3] vm.sp -= 3 return o1, o2, o3 } func (vm *VM) executeSetGlobal() error { globalIndex := vm.currentFrame().ReadUint16() ref := vm.pop() if obj, ok := ref.(object.Copyable); ok { vm.state.Globals[globalIndex] = obj.Copy() } else { vm.state.Globals[globalIndex] = ref } return vm.push(object.NULL) } func (vm *VM) executeSetLocal() error { localIndex := vm.currentFrame().ReadUint8() ref := vm.pop() if obj, ok := ref.(object.Copyable); ok { vm.stack[vm.currentFrame().basePointer+int(localIndex)] = obj.Copy() } else { vm.stack[vm.currentFrame().basePointer+int(localIndex)] = ref } return vm.push(object.NULL) } func (vm *VM) buildHash(startIndex, endIndex int) (object.Object, error) { hashedPairs := make(map[object.HashKey]object.HashPair) for i := startIndex; i < endIndex; i += 2 { key := vm.stack[i] value := vm.stack[i+1] pair := object.HashPair{Key: key, Value: value} hashKey, ok := key.(object.Hasher) if !ok { return nil, fmt.Errorf("unusable as hash key: %s", key.Type()) } hashedPairs[hashKey.Hash()] = pair } return &object.Hash{Pairs: hashedPairs}, nil } func (vm *VM) executeMakeHash() error { startIndex := vm.sp - int(vm.currentFrame().ReadUint16()) hash, err := vm.buildHash(startIndex, vm.sp) if err != nil { return err } vm.sp = startIndex return vm.push(hash) } func (vm *VM) buildArray(startIndex, endIndex int) (object.Object, error) { elements := make([]object.Object, endIndex-startIndex) for i := startIndex; i < endIndex; i++ { elements[i-startIndex] = vm.stack[i] } return &object.Array{Elements: elements}, nil } func (vm *VM) executeMakeArray() error { startIndex := vm.sp - int(vm.currentFrame().ReadUint16()) hash, err := vm.buildArray(startIndex, vm.sp) if err != nil { return err } vm.sp = startIndex return vm.push(hash) } func (vm *VM) executeAdd() error { right, left := vm.pop2() switch obj := left.(type) { case object.Integer: val, err := obj.Add(right) if err != nil { return err } return vm.push(val) case object.String: val, err := obj.Add(right) if err != nil { return err } return vm.push(val) case *object.Array: val, err := obj.Add(right) if err != nil { return err } return vm.push(val) case *object.Hash: val, err := obj.Add(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "+") } } func (vm *VM) executeSub() error { right, left := vm.pop2() switch obj := left.(type) { case object.Integer: val, err := obj.Sub(right) if err != nil { return err } return vm.push(val) default: return fmt.Errorf("unsupported types for unary operation: -%s", left.Type()) } } func (vm *VM) executeMul() error { right, left := vm.pop2() switch obj := left.(type) { case *object.Array: val, err := obj.Mul(right) if err != nil { return err } return vm.push(val) case object.Integer: val, err := obj.Mul(right) if err != nil { return err } return vm.push(val) case object.String: val, err := obj.Mul(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "*") } } func (vm *VM) executeDiv() error { right, left := vm.pop2() switch obj := left.(type) { case object.Integer: val, err := obj.Div(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "/") } } func (vm *VM) executeMod() error { right, left := vm.pop2() switch obj := left.(type) { case object.Integer: val, err := obj.Mod(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "%") } } func (vm *VM) executeOr() error { right, left := vm.pop2() switch obj := left.(type) { case object.Boolean: val, err := obj.LogicalOr(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "||") } } func (vm *VM) executeAnd() error { right, left := vm.pop2() switch obj := left.(type) { case object.Boolean: val, err := obj.LogicalAnd(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "&&") } } func (vm *VM) executeBitwiseOr() error { right, left := vm.pop2() switch obj := left.(type) { case object.Integer: val, err := obj.BitwiseOr(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "|") } } func (vm *VM) executeBitwiseXor() error { right, left := vm.pop2() switch obj := left.(type) { case object.Integer: val, err := obj.BitwiseXor(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "^") } } func (vm *VM) executeBitwiseAnd() error { right, left := vm.pop2() switch obj := left.(type) { case object.Integer: val, err := obj.BitwiseAnd(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "&") } } func (vm *VM) executeBitwiseNot() error { left := vm.pop() switch obj := left.(type) { case object.Integer: return vm.push(obj.BitwiseNot()) default: return fmt.Errorf("unsupported types for unary operation: ~%s", left.Type()) } } func (vm *VM) executeLeftShift() error { right, left := vm.pop2() switch obj := left.(type) { case object.Integer: val, err := obj.LeftShift(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, "<<") } } func (vm *VM) executeRightShift() error { right, left := vm.pop2() switch obj := left.(type) { case object.Integer: val, err := obj.RightShift(right) if err != nil { return err } return vm.push(val) default: return object.NewBinaryOpError(left, right, ">>") } } func (vm *VM) executeEqual() error { right, left := vm.pop2() if left.Compare(right) == 0 { return vm.push(object.TRUE) } return vm.push(object.FALSE) } func (vm *VM) executeNotEqual() error { right, left := vm.pop2() if left.Compare(right) != 0 { return vm.push(object.TRUE) } return vm.push(object.FALSE) } func (vm *VM) executeGreaterThan() error { right, left := vm.pop2() if left.Compare(right) == 1 { return vm.push(object.TRUE) } return vm.push(object.FALSE) } func (vm *VM) executeGreaterThanOrEqual() error { right, left := vm.pop2() if left.Compare(right) == 1 { return vm.push(object.TRUE) } return vm.push(object.FALSE) } func (vm *VM) executeNot() error { left := vm.pop() switch obj := left.(type) { case object.Boolean: return vm.push(obj.LogicalNot()) case object.Integer: return vm.push(obj.LogicalNot()) case object.Null: return vm.push(obj.LogicalNot()) default: return fmt.Errorf("unsupported types for unary operation: !%s", left.Type()) } } func (vm *VM) executeMinus() error { left := vm.pop() switch obj := left.(type) { case object.Integer: return vm.push(obj.Negate()) default: return fmt.Errorf("unsupported types for unary operation: -%s", left.Type()) } } func (vm *VM) executeSetItem() error { right, index, left := vm.pop3() switch obj := left.(type) { case *object.Array: err := obj.Set(index, right) if err != nil { return err } return vm.push(object.NULL) case *object.Hash: err := obj.Set(index, right) if err != nil { return err } return vm.push(object.NULL) default: return fmt.Errorf( "set item operation not supported: left=%s index=%s", left.Type(), index.Type(), ) } } func (vm *VM) executeGetItem() error { index, left := vm.pop2() switch obj := left.(type) { case object.String: val, err := obj.Get(index) if err != nil { return err } return vm.push(val) case *object.Array: val, err := obj.Get(index) if err != nil { return err } return vm.push(val) case *object.Hash: val, err := obj.Get(index) if err != nil { return err } return vm.push(val) case object.Module: val, err := obj.Get(index) if err != nil { return err } return vm.push(val) default: return fmt.Errorf( "get item operation not supported: left=%s index=%s", left.Type(), index.Type(), ) } } func (vm *VM) executeCall() error { args := int(vm.currentFrame().ReadUint8()) callee := vm.stack[vm.sp-1-args] switch callee := callee.(type) { case *object.Closure: return vm.callClosure(callee, args) case object.Builtin: return vm.callBuiltin(callee, args) default: return fmt.Errorf( "calling non-closure and non-builtin: %T %v", callee, callee, ) } } func (vm *VM) callClosure(cl *object.Closure, numArgs int) error { if numArgs != cl.Fn.NumParameters { return fmt.Errorf("wrong number of arguments: want=%d, got=%d", cl.Fn.NumParameters, numArgs) } // Optimize tail calls and avoid a new frame if cl.Fn == vm.currentFrame().cl.Fn { nextOP := vm.currentFrame().PeekNextOp() if nextOP == code.OpReturn { for p := 0; p < numArgs; p++ { vm.stack[vm.currentFrame().basePointer+p] = vm.stack[vm.sp-numArgs+p] } vm.sp -= numArgs + 1 vm.currentFrame().SetIP(0) // reset IP to the beginning of the frame return nil } } frame := newFrame(cl, vm.sp-numArgs) vm.pushFrame(frame) vm.sp = frame.basePointer + cl.Fn.NumLocals return nil } func (vm *VM) callBuiltin(builtin object.Builtin, numArgs int) error { args := vm.stack[vm.sp-numArgs : vm.sp] result := builtin.Fn(vm.ctx, args...) vm.sp = vm.sp - numArgs - 1 if result != nil { return vm.push(result) } return vm.push(object.NULL) } func (vm *VM) pushClosure(constIndex, numFree int) error { constant := vm.state.Constants[constIndex] function, ok := constant.(*object.CompiledFunction) if !ok { return fmt.Errorf("not a function %+v", constant) } free := make([]object.Object, numFree) for i := 0; i < numFree; i++ { free[i] = vm.stack[vm.sp-numFree+i] } vm.sp -= numFree closure := object.Closure{Fn: function, Free: free} return vm.push(&closure) } func (vm *VM) loadModule(name object.Object) error { s, ok := name.(object.String) if !ok { return fmt.Errorf( "TypeError: import() expected argument #1 to be `str` got `%s`", name.Type(), ) } attrs, err := executeModule(s.Value, vm.state) if err != nil { return err } module := object.Module{Name: s.Value, Attrs: attrs} return vm.push(module) } func (vm *VM) LastPoppedStackElem() object.Object { return vm.stack[vm.sp] } func (vm *VM) Run() (err error) { for err == nil { op := vm.currentFrame().ReadNextOp() switch op { case code.OpConstant: constIndex := vm.currentFrame().ReadUint16() err = vm.push(vm.state.Constants[constIndex]) case code.OpPop: vm.pop() case code.OpTrue: err = vm.push(object.TRUE) case code.OpFalse: err = vm.push(object.FALSE) case code.OpJump: pos := vm.currentFrame().ReadUint16() vm.currentFrame().SetIP(pos) case code.OpJumpNotTruthy: pos := vm.currentFrame().ReadUint16() if !isTruthy(vm.pop()) { vm.currentFrame().SetIP(pos) } case code.OpNull: err = vm.push(object.NULL) case code.OpSetGlobal: err = vm.executeSetGlobal() case code.OpAssignGlobal: globalIndex := vm.currentFrame().ReadUint16() vm.state.Globals[globalIndex] = vm.pop() err = vm.push(object.NULL) case code.OpAssignLocal: localIndex := vm.currentFrame().ReadUint8() vm.stack[vm.currentFrame().basePointer+int(localIndex)] = vm.pop() err = vm.push(object.NULL) case code.OpGetGlobal: globalIndex := vm.currentFrame().ReadUint16() err = vm.push(vm.state.Globals[globalIndex]) case code.OpArray: err = vm.executeMakeArray() case code.OpHash: err = vm.executeMakeHash() case code.OpSetItem: err = vm.executeSetItem() case code.OpGetItem: err = vm.executeGetItem() case code.OpCall: err = vm.executeCall() case code.OpReturn: returnValue := vm.pop() frame := vm.popFrame() vm.sp = frame.basePointer - 1 err = vm.push(returnValue) case code.OpSetLocal: err = vm.executeSetLocal() case code.OpGetLocal: localIndex := vm.currentFrame().ReadUint8() err = vm.push(vm.stack[vm.currentFrame().basePointer+int(localIndex)]) case code.OpGetBuiltin: builtinIndex := vm.currentFrame().ReadUint8() builtin := builtins.BuiltinsIndex[builtinIndex] err = vm.push(builtin) case code.OpClosure: constIndex := vm.currentFrame().ReadUint16() numFree := vm.currentFrame().ReadUint8() err = vm.pushClosure(int(constIndex), int(numFree)) case code.OpGetFree: freeIndex := vm.currentFrame().ReadUint8() err = vm.push(vm.currentFrame().GetFree(freeIndex)) case code.OpCurrentClosure: currentClosure := vm.currentFrame().cl err = vm.push(currentClosure) case code.OpLoadModule: name := vm.pop() err = vm.loadModule(name) case code.OpAdd: err = vm.executeAdd() case code.OpSub: err = vm.executeSub() case code.OpMul: err = vm.executeMul() case code.OpDiv: err = vm.executeDiv() case code.OpMod: err = vm.executeMod() case code.OpOr: err = vm.executeOr() case code.OpAnd: err = vm.executeAnd() case code.OpBitwiseOR: err = vm.executeBitwiseOr() case code.OpBitwiseXOR: err = vm.executeBitwiseXor() case code.OpBitwiseAND: err = vm.executeBitwiseAnd() case code.OpLeftShift: err = vm.executeLeftShift() case code.OpRightShift: err = vm.executeRightShift() case code.OpEqual: err = vm.executeEqual() case code.OpNotEqual: err = vm.executeNotEqual() case code.OpGreaterThan: err = vm.executeGreaterThan() case code.OpGreaterThanEqual: err = vm.executeGreaterThanOrEqual() case code.OpNot: err = vm.executeNot() case code.OpBitwiseNOT: err = vm.executeBitwiseNot() case code.OpMinus: err = vm.executeMinus() case code.OpHalt: return default: err = fmt.Errorf("unhandled opcode: %s", op) } } return }