// Package vm implements the stack-based virtual machine that executes the
// bytecode produced by the compiler package.
package vm

import (
	"fmt"
	"log"
	"strings"

	"monkey/code"
	"monkey/compiler"
	"monkey/object"
)

const (
	// StackSize is the number of slots in the VM's value stack.
	StackSize = 2048
	// GlobalsSize is the number of global slots allocated by New. Global
	// operands are read as 16-bit values, so every operand fits.
	GlobalsSize = 65536
	// MaxFrames is the number of call-frame slots allocated by New.
	MaxFrames = 1024
)

// Shared singleton objects pushed by the VM for null and boolean results.
var (
	Null  = &object.Null{}
	True  = &object.Boolean{Value: true}
	False = &object.Boolean{Value: false}
)

// VM holds the complete execution state for one bytecode program.
type VM struct {
	// Debug enables a per-instruction trace written through the log package.
	Debug bool

	constants []object.Object

	stack []object.Object
	sp    int // Always points to the next value. Top of stack is stack[sp-1]

	globals []object.Object

	frames      []*Frame
	framesIndex int // Index of the next free frame slot; the current frame is frames[framesIndex-1].
}

// New builds a VM for bytecode. The top-level instructions are wrapped in a
// closure and installed as frame 0; the stack and globals start out empty.
func New(bytecode *compiler.Bytecode) *VM {
	mainFn := &object.CompiledFunction{Instructions: bytecode.Instructions}
	mainClosure := &object.Closure{Fn: mainFn}
	mainFrame := NewFrame(mainClosure, 0)

	frames := make([]*Frame, MaxFrames)
	frames[0] = mainFrame

	return &VM{
		constants:   bytecode.Constants,
		stack:       make([]object.Object, StackSize),
		sp:          0,
		globals:     make([]object.Object, GlobalsSize),
		frames:      frames,
		framesIndex: 1,
	}
}

// NewWithGlobalState is like New but uses s as the globals store, so the
// caller and the VM share (and both observe writes to) the same slice.
func NewWithGlobalState(bytecode *compiler.Bytecode, s []object.Object) *VM {
	vm := New(bytecode)
	vm.globals = s
	return vm
}

// currentFrame returns the frame that is currently executing.
func (vm *VM) currentFrame() *Frame {
	return vm.frames[vm.framesIndex-1]
}

// pushFrame makes f the current frame. It does not check capacity; callers
// must ensure framesIndex < len(frames).
func (vm *VM) pushFrame(f *Frame) {
	vm.frames[vm.framesIndex] = f
	vm.framesIndex++
}

// popFrame removes the current frame and returns it.
func (vm *VM) popFrame() *Frame {
	vm.framesIndex--
	return vm.frames[vm.framesIndex]
}

// LastPoppedStackElem returns the object in the slot just above the top of
// the stack, i.e. the value most recently removed by pop.
func (vm *VM) LastPoppedStackElem() object.Object {
	return vm.stack[vm.sp]
}

// Run executes instructions until the current frame's instruction pointer
// reaches the end of its instructions. It returns the first error reported
// by an instruction handler, or nil on normal completion.
func (vm *VM) Run() error {
	for vm.currentFrame().ip < len(vm.currentFrame().Instructions())-1 {
		// The frame is fixed for the fetch/decode part of this iteration;
		// OpCall and OpReturn may switch frames, but only after the operand
		// bytes of the current instruction have been consumed.
		frame := vm.currentFrame()
		frame.ip++

		ip := frame.ip
		var ins code.Instructions = frame.Instructions()
		op := code.Opcode(ins[ip])

		if vm.Debug {
			// Trace the instruction about to run. This has to happen after
			// the fetch: before it there is no instruction to disassemble.
			vm.traceInstruction(ins, ip)
		}

		var err error
		switch op {
		case code.OpConstant:
			constIndex := code.ReadUint16(ins[ip+1:])
			frame.ip += 2
			err = vm.push(vm.constants[constIndex])

		case code.OpAdd, code.OpSub, code.OpMul, code.OpDiv:
			err = vm.executeBinaryOperation(op)

		case code.OpPop:
			vm.pop()

		case code.OpTrue:
			err = vm.push(True)

		case code.OpFalse:
			err = vm.push(False)

		case code.OpEqual, code.OpNotEqual, code.OpGreaterThan, code.OpGreaterThanEqual:
			err = vm.executeComparison(op)

		case code.OpBang:
			err = vm.executeBangOperator()

		case code.OpMinus:
			err = vm.executeMinusOperator()

		case code.OpJump:
			pos := int(code.ReadUint16(ins[ip+1:]))
			// The loop increments ip before fetching, hence pos-1.
			frame.ip = pos - 1

		case code.OpJumpNotTruthy:
			pos := int(code.ReadUint16(ins[ip+1:]))
			frame.ip += 2
			if !isTruthy(vm.pop()) {
				frame.ip = pos - 1
			}

		case code.OpNull:
			err = vm.push(Null)

		case code.OpSetGlobal:
			globalIndex := code.ReadUint16(ins[ip+1:])
			frame.ip += 2
			ref := vm.pop()
			vm.globals[globalIndex] = cloneIfImmutable(ref)
			err = vm.push(Null)

		case code.OpAssignGlobal:
			globalIndex := code.ReadUint16(ins[ip+1:])
			frame.ip += 2
			vm.globals[globalIndex] = vm.pop()
			err = vm.push(Null)

		case code.OpAssignLocal:
			localIndex := code.ReadUint8(ins[ip+1:])
			frame.ip += 1
			vm.stack[frame.basePointer+int(localIndex)] = vm.pop()
			err = vm.push(Null)

		case code.OpGetGlobal:
			globalIndex := code.ReadUint16(ins[ip+1:])
			frame.ip += 2
			err = vm.push(vm.globals[globalIndex])

		case code.OpArray:
			numElements := int(code.ReadUint16(ins[ip+1:]))
			frame.ip += 2
			array := vm.buildArray(vm.sp-numElements, vm.sp)
			vm.sp -= numElements
			err = vm.push(array)

		case code.OpHash:
			numElements := int(code.ReadUint16(ins[ip+1:]))
			frame.ip += 2
			var hash object.Object
			hash, err = vm.buildHash(vm.sp-numElements, vm.sp)
			if err != nil {
				return err
			}
			vm.sp -= numElements
			err = vm.push(hash)

		case code.OpSetItem:
			value := vm.pop()
			index := vm.pop()
			left := vm.pop()
			err = vm.executeSetItem(left, index, value)

		case code.OpGetItem:
			index := vm.pop()
			left := vm.pop()
			err = vm.executeGetItem(left, index)

		case code.OpCall:
			numArgs := code.ReadUint8(ins[ip+1:])
			frame.ip += 1
			err = vm.executeCall(int(numArgs))

		case code.OpReturn:
			returnValue := vm.pop()
			returned := vm.popFrame()
			// basePointer-1 is the slot that held the callee, so the return
			// value replaces the callee and its locals.
			vm.sp = returned.basePointer - 1
			err = vm.push(returnValue)

		case code.OpSetLocal:
			localIndex := code.ReadUint8(ins[ip+1:])
			frame.ip += 1
			ref := vm.pop()
			vm.stack[frame.basePointer+int(localIndex)] = cloneIfImmutable(ref)
			err = vm.push(Null)

		case code.OpGetLocal:
			localIndex := code.ReadUint8(ins[ip+1:])
			frame.ip += 1
			err = vm.push(vm.stack[frame.basePointer+int(localIndex)])

		case code.OpGetBuiltin:
			builtinIndex := code.ReadUint8(ins[ip+1:])
			frame.ip += 1
			err = vm.push(object.BuiltinsIndex[builtinIndex])

		case code.OpClosure:
			constIndex := code.ReadUint16(ins[ip+1:])
			numFree := code.ReadUint8(ins[ip+3:])
			frame.ip += 3
			err = vm.pushClosure(int(constIndex), int(numFree))

		case code.OpGetFree:
			freeIndex := code.ReadUint8(ins[ip+1:])
			frame.ip += 1
			err = vm.push(frame.cl.Free[freeIndex])

		case code.OpCurrentClosure:
			err = vm.push(frame.cl)
		}
		if err != nil {
			return err
		}

		if vm.Debug {
			// Second trace line: VM state after the instruction has run.
			log.Printf(
				"%-25s [ip=%02d fp=%02d, sp=%02d]",
				"", ip, vm.framesIndex-1, vm.sp,
			)
		}
	}

	return nil
}

// traceInstruction logs the instruction at ins[ip] together with the frame
// and stack pointers as they are before the instruction executes.
func (vm *VM) traceInstruction(ins code.Instructions, ip int) {
	text := strings.Split(ins[ip:].String(), "\n")[0]
	// Drop the first four characters of the disassembly line (presumably the
	// relative offset column printed by Instructions.String — confirm); the
	// absolute ip is printed instead. Guarded so a short line cannot panic.
	if len(text) >= 4 {
		text = text[4:]
	}
	log.Printf(
		"%-25s %-20s\n",
		fmt.Sprintf("%04d %s", ip, text),
		fmt.Sprintf("[ip=%02d fp=%02d, sp=%02d]", ip, vm.framesIndex-1, vm.sp),
	)
}

// cloneIfImmutable returns obj.Clone() when obj implements object.Immutable
// and obj itself otherwise. It is used when binding a value to a variable.
func cloneIfImmutable(obj object.Object) object.Object {
	if immutable, ok := obj.(object.Immutable); ok {
		return immutable.Clone()
	}
	return obj
}

// executeSetItem dispatches `left[index] = value` on the type of left and
// returns an error for unsupported operand types.
func (vm *VM) executeSetItem(left, index, value object.Object) error {
	switch {
	case left.Type() == object.ARRAY_OBJ && index.Type() == object.INTEGER_OBJ:
		return vm.executeArraySetItem(left, index, value)
	case left.Type() == object.HASH_OBJ:
		return vm.executeHashSetItem(left, index, value)
	default:
		return fmt.Errorf(
			"set item operation not supported: left=%s index=%s",
			left.Type(), index.Type(),
		)
	}
}

// executeGetItem dispatches `left[index]` on the operand types and returns
// an error for unsupported combinations.
func (vm *VM) executeGetItem(left, index object.Object) error {
	switch {
	case left.Type() == object.STRING_OBJ && index.Type() == object.INTEGER_OBJ:
		return vm.executeStringGetItem(left, index)
	case left.Type() == object.STRING_OBJ && index.Type() == object.STRING_OBJ:
		return vm.executeStringIndex(left, index)
	case left.Type() == object.ARRAY_OBJ && index.Type() == object.INTEGER_OBJ:
		return vm.executeArrayGetItem(left, index)
	case left.Type() == object.HASH_OBJ:
		return vm.executeHashGetItem(left, index)
	default:
		return fmt.Errorf(
			"index operator not supported: left=%s index=%s",
			left.Type(), index.Type(),
		)
	}
}

// executeArrayGetItem pushes array[index], or Null when index is out of
// range.
func (vm *VM) executeArrayGetItem(array, index object.Object) error {
	arrayObject := array.(*object.Array)
	i := index.(*object.Integer).Value
	last := int64(len(arrayObject.Elements) - 1)

	if i < 0 || i > last {
		return vm.push(Null)
	}
	return vm.push(arrayObject.Elements[i])
}

// executeArraySetItem stores value at array[index] and pushes Null. Unlike a
// read, an out-of-range index is an error.
func (vm *VM) executeArraySetItem(array, index, value object.Object) error {
	arrayObject := array.(*object.Array)
	i := index.(*object.Integer).Value
	last := int64(len(arrayObject.Elements) - 1)

	if i < 0 || i > last {
		return fmt.Errorf("index out of bounds: %d", i)
	}

	arrayObject.Elements[i] = value
	return vm.push(Null)
}

// executeHashGetItem pushes the value stored under index, or Null when the
// key is absent. It returns an error if index is not hashable.
func (vm *VM) executeHashGetItem(hash, index object.Object) error {
	hashObject := hash.(*object.Hash)

	key, ok := index.(object.Hashable)
	if !ok {
		return fmt.Errorf("unusable as hash key: %s", index.Type())
	}

	pair, ok := hashObject.Pairs[key.HashKey()]
	if !ok {
		return vm.push(Null)
	}
	return vm.push(pair.Value)
}

// executeHashSetItem stores value under index in hash and pushes Null. It
// returns an error if index is not hashable.
func (vm *VM) executeHashSetItem(hash, index, value object.Object) error {
	hashObject := hash.(*object.Hash)

	key, ok := index.(object.Hashable)
	if !ok {
		return fmt.Errorf("unusable as hash key: %s", index.Type())
	}

	hashObject.Pairs[key.HashKey()] = object.HashPair{Key: index, Value: value}
	return vm.push(Null)
}

// buildHash builds a hash from the stack slots [startIndex, endIndex), which
// hold alternating keys and values. It returns an error if a key is not
// hashable.
func (vm *VM) buildHash(startIndex, endIndex int) (object.Object, error) {
	hashedPairs := make(map[object.HashKey]object.HashPair)

	for i := startIndex; i < endIndex; i += 2 {
		key := vm.stack[i]
		value := vm.stack[i+1]

		hashKey, ok := key.(object.Hashable)
		if !ok {
			return nil, fmt.Errorf("unusable as hash key: %s", key.Type())
		}

		hashedPairs[hashKey.HashKey()] = object.HashPair{Key: key, Value: value}
	}

	return &object.Hash{Pairs: hashedPairs}, nil
}

// buildArray copies the stack slots [startIndex, endIndex) into a new array
// object.
func (vm *VM) buildArray(startIndex, endIndex int) object.Object {
	elements := make([]object.Object, endIndex-startIndex)
	copy(elements, vm.stack[startIndex:endIndex])
	return &object.Array{Elements: elements}
}

// isTruthy reports the truth value of obj: booleans are their own value,
// null is false, and every other object is true.
func isTruthy(obj object.Object) bool {
	switch obj := obj.(type) {
	case *object.Boolean:
		return obj.Value
	case *object.Null:
		return false
	default:
		return true
	}
}

// push places o on top of the stack, or returns an error if the stack is
// full.
func (vm *VM) push(o object.Object) error {
	if vm.sp >= StackSize {
		return fmt.Errorf("stack overflow")
	}

	vm.stack[vm.sp] = o
	vm.sp++
	return nil
}

// pop removes and returns the top of the stack. The slot is not cleared, so
// LastPoppedStackElem can still read it. It does not check for underflow.
func (vm *VM) pop() object.Object {
	o := vm.stack[vm.sp-1]
	vm.sp--
	return o
}

// executeBinaryOperation pops two operands and applies op to them. Only
// integer/integer and string/string operand pairs are supported.
func (vm *VM) executeBinaryOperation(op code.Opcode) error {
	right := vm.pop()
	left := vm.pop()

	leftType := left.Type()
	rightType := right.Type()

	switch {
	case leftType == object.INTEGER_OBJ && rightType == object.INTEGER_OBJ:
		return vm.executeBinaryIntegerOperation(op, left, right)
	case leftType == object.STRING_OBJ && rightType == object.STRING_OBJ:
		return vm.executeBinaryStringOperation(op, left, right)
	default:
		return fmt.Errorf("unsupported types for binary operation: %s %s",
			leftType, rightType)
	}
}

// executeBinaryIntegerOperation pushes the result of applying the arithmetic
// operator op to two integers. Division by zero is reported as an error.
func (vm *VM) executeBinaryIntegerOperation(op code.Opcode, left, right object.Object) error {
	leftValue := left.(*object.Integer).Value
	rightValue := right.(*object.Integer).Value

	var result int64
	switch op {
	case code.OpAdd:
		result = leftValue + rightValue
	case code.OpSub:
		result = leftValue - rightValue
	case code.OpMul:
		result = leftValue * rightValue
	case code.OpDiv:
		// Integer division by zero is a run-time panic in Go; surface it as
		// an ordinary VM error instead of crashing the host program.
		if rightValue == 0 {
			return fmt.Errorf("division by zero")
		}
		result = leftValue / rightValue
	default:
		return fmt.Errorf("unknown integer operator: %d", op)
	}

	return vm.push(&object.Integer{Value: result})
}

// executeBinaryStringOperation pushes the concatenation of two strings.
// OpAdd is the only operator supported for strings.
func (vm *VM) executeBinaryStringOperation(op code.Opcode, left, right object.Object) error {
	if op != code.OpAdd {
		return fmt.Errorf("unknown string operator: %d", op)
	}

	leftValue := left.(*object.String).Value
	rightValue := right.(*object.String).Value

	return vm.push(&object.String{Value: leftValue + rightValue})
}

// executeComparison pops two operands and pushes the boolean result of
// comparing them with op. Integer and string pairs support all comparison
// operators; every other pair supports only equality and inequality.
func (vm *VM) executeComparison(op code.Opcode) error {
	right := vm.pop()
	left := vm.pop()

	leftType, rightType := left.Type(), right.Type()
	switch {
	case leftType == object.INTEGER_OBJ && rightType == object.INTEGER_OBJ:
		return vm.executeIntegerComparison(op, left, right)
	case leftType == object.STRING_OBJ && rightType == object.STRING_OBJ:
		return vm.executeStringComparison(op, left, right)
	}

	// Other objects are compared by identity. Booleans are compared by value
	// so that a boolean which is not the True/False singleton (for example
	// one produced by Clone when a variable is bound, if Boolean implements
	// object.Immutable) still compares equal to its singleton counterpart.
	equal := right == left
	if l, ok := left.(*object.Boolean); ok {
		if r, ok := right.(*object.Boolean); ok {
			equal = l.Value == r.Value
		}
	}

	switch op {
	case code.OpEqual:
		return vm.push(nativeBoolToBooleanObject(equal))
	case code.OpNotEqual:
		return vm.push(nativeBoolToBooleanObject(!equal))
	default:
		return fmt.Errorf("unknown operator: %d (%s %s)",
			op, left.Type(), right.Type())
	}
}

// executeIntegerComparison pushes the boolean result of comparing two
// integers with op.
func (vm *VM) executeIntegerComparison(op code.Opcode, left, right object.Object) error {
	leftValue := left.(*object.Integer).Value
	rightValue := right.(*object.Integer).Value

	switch op {
	case code.OpEqual:
		return vm.push(nativeBoolToBooleanObject(rightValue == leftValue))
	case code.OpNotEqual:
		return vm.push(nativeBoolToBooleanObject(rightValue != leftValue))
	case code.OpGreaterThan:
		return vm.push(nativeBoolToBooleanObject(leftValue > rightValue))
	case code.OpGreaterThanEqual:
		return vm.push(nativeBoolToBooleanObject(leftValue >= rightValue))
	default:
		return fmt.Errorf("unknown operator: %d", op)
	}
}

// executeBangOperator pops one operand and pushes its logical negation:
// True for false and null, False for everything else.
func (vm *VM) executeBangOperator() error {
	operand := vm.pop()
	// Negate by truth value rather than by comparing against the True/False
	// pointers, so booleans that are not the singletons are handled too.
	return vm.push(nativeBoolToBooleanObject(!isTruthy(operand)))
}

// executeMinusOperator pops an integer and pushes its arithmetic negation.
// Any other operand type is an error.
func (vm *VM) executeMinusOperator() error {
	operand := vm.pop()

	if operand.Type() != object.INTEGER_OBJ {
		return fmt.Errorf("unsupported type for negation: %s", operand.Type())
	}

	value := operand.(*object.Integer).Value
	return vm.push(&object.Integer{Value: -value})
}

// executeCall calls the object sitting below the numArgs arguments on the
// stack, which must be a closure or a builtin.
func (vm *VM) executeCall(numArgs int) error {
	callee := vm.stack[vm.sp-1-numArgs]
	switch callee := callee.(type) {
	case *object.Closure:
		return vm.callClosure(callee, numArgs)
	case *object.Builtin:
		return vm.callBuiltin(callee, numArgs)
	default:
		return fmt.Errorf("calling non-function and non-built-in")
	}
}

// callClosure starts executing cl with the numArgs arguments on top of the
// stack. It returns an error when the argument count does not match, or when
// the call would exceed the frame or stack capacity.
func (vm *VM) callClosure(cl *object.Closure, numArgs int) error {
	if numArgs != cl.Fn.NumParameters {
		return fmt.Errorf("wrong number of arguments: want=%d, got=%d",
			cl.Fn.NumParameters, numArgs)
	}

	// Optimize tail calls and avoid a new frame: a self-call immediately
	// followed by OpReturn reuses the current frame.
	// NOTE(review): only Fn is compared, so a different closure over the same
	// function (with different Free values) also takes this path and keeps
	// running with the current frame's closure — confirm this is intended.
	current := vm.currentFrame()
	if cl.Fn == current.cl.Fn && current.NextOp() == code.OpReturn {
		// Move the new arguments into the parameter slots of this frame.
		for p := 0; p < numArgs; p++ {
			vm.stack[current.basePointer+p] = vm.stack[vm.sp-numArgs+p]
		}
		// Discard the arguments and the callee that were pushed for the call.
		vm.sp -= numArgs + 1
		current.ip = -1 // reset IP to the beginning of the frame
		return nil
	}

	// Report runaway recursion as an error instead of indexing past the end
	// of the frames slice or the stack.
	if vm.framesIndex >= len(vm.frames) {
		return fmt.Errorf("frame overflow")
	}
	if vm.sp-numArgs+cl.Fn.NumLocals > StackSize {
		return fmt.Errorf("stack overflow")
	}

	frame := NewFrame(cl, vm.sp-numArgs)
	vm.pushFrame(frame)

	// Reserve the slots above the base pointer for the function's locals.
	vm.sp = frame.basePointer + cl.Fn.NumLocals

	return nil
}

// callBuiltin invokes builtin with the numArgs arguments on top of the
// stack, then replaces the callee and arguments with the result (Null when
// the builtin returns nil).
func (vm *VM) callBuiltin(builtin *object.Builtin, numArgs int) error {
	args := vm.stack[vm.sp-numArgs : vm.sp]

	result := builtin.Fn(args...)
	vm.sp = vm.sp - numArgs - 1

	if result == nil {
		return vm.push(Null)
	}
	return vm.push(result)
}

// pushClosure wraps the compiled function stored at constants[constIndex] in
// a closure that captures the top numFree stack values, and pushes it. It
// returns an error if the constant is not a compiled function.
func (vm *VM) pushClosure(constIndex, numFree int) error {
	constant := vm.constants[constIndex]
	function, ok := constant.(*object.CompiledFunction)
	if !ok {
		return fmt.Errorf("not a function %+v", constant)
	}

	// Copy the free variables so the closure does not alias the stack.
	free := make([]object.Object, numFree)
	copy(free, vm.stack[vm.sp-numFree:vm.sp])
	vm.sp -= numFree

	return vm.push(&object.Closure{Fn: function, Free: free})
}

// executeStringGetItem pushes the one-character string at str[index], or an
// empty string when index is out of range.
// NOTE(review): indexing is by byte, and the byte is converted as a code
// point, so non-ASCII strings do not yield their original characters.
func (vm *VM) executeStringGetItem(str, index object.Object) error {
	stringObject := str.(*object.String)
	i := index.(*object.Integer).Value
	last := int64(len(stringObject.Value) - 1)

	if i < 0 || i > last {
		return vm.push(&object.String{Value: ""})
	}
	return vm.push(&object.String{Value: string(stringObject.Value[i])})
}

// executeStringIndex pushes the byte offset of the first occurrence of the
// string index within str, or -1 if it does not occur.
func (vm *VM) executeStringIndex(str, index object.Object) error {
	stringObject := str.(*object.String)
	substr := index.(*object.String).Value

	position := strings.Index(stringObject.Value, substr)
	return vm.push(&object.Integer{Value: int64(position)})
}

// executeStringComparison pushes the boolean result of comparing two strings
// with op, using Go's string ordering.
func (vm *VM) executeStringComparison(op code.Opcode, left, right object.Object) error {
	leftValue := left.(*object.String).Value
	rightValue := right.(*object.String).Value

	switch op {
	case code.OpEqual:
		return vm.push(nativeBoolToBooleanObject(rightValue == leftValue))
	case code.OpNotEqual:
		return vm.push(nativeBoolToBooleanObject(rightValue != leftValue))
	case code.OpGreaterThan:
		return vm.push(nativeBoolToBooleanObject(leftValue > rightValue))
	case code.OpGreaterThanEqual:
		return vm.push(nativeBoolToBooleanObject(leftValue >= rightValue))
	default:
		return fmt.Errorf("unknown operator: %d", op)
	}
}

// nativeBoolToBooleanObject maps a Go bool to the shared True or False
// object.
func nativeBoolToBooleanObject(input bool) *object.Boolean {
	if input {
		return True
	}
	return False
}