diff --git a/code/code.go b/code/code.go index 5c72efe..ce72499 100644 --- a/code/code.go +++ b/code/code.go @@ -13,6 +13,10 @@ type Opcode byte const ( OpConstant Opcode = iota OpAdd + OpPop + OpSub + OpMul + OpDiv ) type Definition struct { @@ -23,6 +27,10 @@ type Definition struct { var definitions = map[Opcode]*Definition{ OpConstant: {"OpConstant", []int{2}}, OpAdd: {"OpAdd", []int{}}, + OpPop: {"OpPop", []int{}}, + OpSub: {"OpSub", []int{}}, + OpMul: {"OpMul", []int{}}, + OpDiv: {"OpDiv", []int{}}, } func Lookup(op byte) (*Definition, error) { diff --git a/compiler/compiler.go b/compiler/compiler.go index b47d10f..9d45ffc 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -34,6 +34,7 @@ func (c *Compiler) Compile(node ast.Node) error { if err != nil { return err } + c.emit(code.OpPop) case *ast.InfixExpression: err := c.Compile(node.Left) @@ -49,6 +50,12 @@ func (c *Compiler) Compile(node ast.Node) error { switch node.Operator { case "+": c.emit(code.OpAdd) + case "-": + c.emit(code.OpSub) + case "*": + c.emit(code.OpMul) + case "/": + c.emit(code.OpDiv) default: return fmt.Errorf("unknown operator %s", node.Operator) } diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 8ce7939..95635bf 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -25,6 +25,47 @@ func TestIntegerArithmetic(t *testing.T) { code.Make(code.OpConstant, 0), code.Make(code.OpConstant, 1), code.Make(code.OpAdd), + code.Make(code.OpPop), + }, + }, + { + input: "1; 2", + expectedConstants: []interface{}{1, 2}, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpPop), + code.Make(code.OpConstant, 1), + code.Make(code.OpPop), + }, + }, + { + input: "1 - 2", + expectedConstants: []interface{}{1, 2}, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpSub), + code.Make(code.OpPop), + }, + }, + { + input: "1 * 2", + expectedConstants: []interface{}{1, 2}, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpMul), + code.Make(code.OpPop), + }, + }, + { + input: "2 / 1", + expectedConstants: []interface{}{2, 1}, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpDiv), + code.Make(code.OpPop), }, }, } diff --git a/repl/repl.go b/repl/repl.go index 261087d..2c821ff 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -46,7 +46,7 @@ func Start(in io.Reader, out io.Writer) { continue } - stackTop := machine.StackTop() + stackTop := machine.LastPoppedStackElem() io.WriteString(out, stackTop.Inspect()) io.WriteString(out, "\n") } diff --git a/vm/vm.go b/vm/vm.go index b3e1a5b..2e1998b 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -27,11 +27,8 @@ func New(bytecode *compiler.Bytecode) *VM { } } -func (vm *VM) StackTop() object.Object { - if vm.sp == 0 { - return nil - } - return vm.stack[vm.sp-1] +func (vm *VM) LastPoppedStackElem() object.Object { + return vm.stack[vm.sp] } func (vm *VM) Run() error { @@ -48,14 +45,14 @@ func (vm *VM) Run() error { return err } - case code.OpAdd: - right := vm.pop() - left := vm.pop() - leftValue := left.(*object.Integer).Value - rightValue := right.(*object.Integer).Value + case code.OpAdd, code.OpSub, code.OpMul, code.OpDiv: + err := vm.executeBinaryOperation(op) + if err != nil { + return err + } - result := leftValue + rightValue - vm.push(&object.Integer{Value: result}) + case code.OpPop: + vm.pop() } } @@ -78,3 +75,39 @@ func (vm *VM) pop() object.Object { vm.sp-- return o } + +func (vm *VM) executeBinaryOperation(op code.Opcode) error { + right := vm.pop() + left := vm.pop() + + leftType := left.Type() + rightRight := right.Type() + + if leftType == object.INTEGER_OBJ && rightRight == object.INTEGER_OBJ { + return vm.executeBinaryIntegerOperation(op, left, right) + } + + return fmt.Errorf("unsupported types for binary operation: %s %s", leftType, rightRight) +} + +func (vm *VM) executeBinaryIntegerOperation(op code.Opcode, left, right object.Object) error { + leftValue := left.(*object.Integer).Value + rightValue := right.(*object.Integer).Value + + var result int64 + + switch op { + case code.OpAdd: + result = leftValue + rightValue + case code.OpSub: + result = leftValue - rightValue + case code.OpMul: + result = leftValue * rightValue + case code.OpDiv: + result = leftValue / rightValue + default: + return fmt.Errorf("unknown integer operator: %d", op) + } + + return vm.push(&object.Integer{Value: result}) +} diff --git a/vm/vm_test.go b/vm/vm_test.go index 7ef09e1..2495a2b 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -33,7 +33,7 @@ func runVmTests(t *testing.T, tests []vmTestCase) { t.Fatalf("vm error: %s", err) } - stackElem := vm.StackTop() + stackElem := vm.LastPoppedStackElem() testExpectedObject(t, tt.expected, stackElem) } @@ -75,6 +75,14 @@ func TestIntegerArithmetic(t *testing.T) { {"1", 1}, {"2", 2}, {"1 + 2", 3}, + {"1 * 2", 2}, + {"4 / 2", 2}, + {"50 / 2 * 2 + 10 - 5", 55}, + {"5 + 5 + 5 + 5 - 10", 10}, + {"2 * 2 * 2 * 2 * 2", 32}, + {"5 * 2 + 10", 20}, + {"5 + 2 * 10", 25}, + {"5 * (2 + 10)", 60}, } runVmTests(t, tests)