From f08b458325f4ea5f38c76ad5d0ffb3248a036580 Mon Sep 17 00:00:00 2001 From: Chuck Smith Date: Tue, 2 Apr 2024 14:54:08 -0400 Subject: [PATCH] optimizations --- internal/evaluator/evaluator.go | 270 +++++++-------------------- internal/evaluator/evaluator_test.go | 34 ++-- internal/object/errors.go | 4 +- internal/object/float.go | 25 ++- internal/object/int.go | 13 ++ internal/object/object.go | 18 ++ internal/ops/ops.go | 128 +++++++++++++ internal/vm/vm.go | 158 +++------------- internal/vm/vm_test.go | 18 ++ 9 files changed, 313 insertions(+), 355 deletions(-) create mode 100644 internal/ops/ops.go diff --git a/internal/evaluator/evaluator.go b/internal/evaluator/evaluator.go index 5d0edfe..dae2943 100644 --- a/internal/evaluator/evaluator.go +++ b/internal/evaluator/evaluator.go @@ -7,16 +7,10 @@ import ( "monkey/internal/context" "monkey/internal/lexer" "monkey/internal/object" + "monkey/internal/ops" "monkey/internal/parser" "monkey/internal/utils" "os" - "strings" -) - -var ( - NULL = object.Null{} - TRUE = object.Boolean{Value: true} - FALSE = object.Boolean{Value: false} ) func isError(obj object.Object) bool { @@ -68,7 +62,7 @@ func Eval(ctx context.Context, node ast.Node, env *object.Environment) object.Ob env.Set(ident.Value, value) } - return NULL + return object.NULL } return newError("expected identifier on left got=%T", node.Left) @@ -119,7 +113,7 @@ func Eval(ctx context.Context, node ast.Node, env *object.Environment) object.Ob return newError("expected identifier or index expression got=%T", left) } - return NULL + return object.NULL case *ast.Identifier: return evalIdentifier(node, env) @@ -137,10 +131,10 @@ func Eval(ctx context.Context, node ast.Node, env *object.Environment) object.Ob return object.Float{Value: node.Value} case *ast.Boolean: - return nativeBoolToBooleanObject(node.Value) + return object.FromNativeBoolean(node.Value) case *ast.Null: - return NULL + return object.NULL case *ast.PrefixExpression: right := Eval(ctx, node.Right, env) @@ -226,7 +220,7 @@ func evalWhileExpression(ctx context.Context, we *ast.WhileExpression, env *obje return condition } - if isTruthy(condition) { + if object.IsTruthy(condition) { result = Eval(ctx, we.Consequence, env) } else { break @@ -237,7 +231,7 @@ func evalWhileExpression(ctx context.Context, we *ast.WhileExpression, env *obje return result } - return NULL + return object.NULL } func evalProgram(ctx context.Context, program *ast.Program, env *object.Environment) object.Object { @@ -275,13 +269,6 @@ func evalBlockStatements(ctx context.Context, block *ast.BlockStatement, env *ob return result } -func nativeBoolToBooleanObject(input bool) object.Boolean { - if input { - return TRUE - } - return FALSE -} - func evalPrefixExpression(operator string, right object.Object) object.Object { switch operator { case "!": @@ -292,155 +279,97 @@ func evalPrefixExpression(operator string, right object.Object) object.Object { case "~", "-": return evalIntegerPrefixOperatorExpression(operator, right) default: - return newError("unknown operator: %s%s", operator, right.Type()) + return newError("unsupported operator: %s%s", operator, right.Type()) } } func evalBooleanPrefixOperatorExpression(operator string, right object.Object) object.Object { if right.Type() != object.BooleanType { - return newError("unknown operator: %s%s", operator, right.Type()) + return newError("unsupported operator: %s%s", operator, right.Type()) } switch right { - case TRUE: - return FALSE - case FALSE: - return TRUE - case NULL: - return TRUE + case object.TRUE: + return object.FALSE + case object.FALSE: + return object.TRUE + case object.NULL: + return object.TRUE default: - return FALSE + return object.FALSE } } func evalIntegerPrefixOperatorExpression(operator string, right object.Object) object.Object { if right.Type() != object.IntegerType { - return newError("unknown operator: -%s", right.Type()) + return newError("unsupported operator: -%s", right.Type()) } value := right.(object.Integer).Value switch operator { case "!": - return FALSE + return object.FALSE case "~": return object.Integer{Value: ^value} case "-": return object.Integer{Value: -value} default: - return newError("unknown operator: %s", operator) + return newError("unsupported operator: %s", operator) } } func evalInfixExpression(operator string, left, right object.Object) object.Object { switch { - // {"a": 1} + {"b": 2} - case operator == "+" && left.Type() == object.HashType && right.Type() == object.HashType: - leftVal := left.(*object.Hash).Pairs - rightVal := right.(*object.Hash).Pairs - pairs := make(map[object.HashKey]object.HashPair) - for k, v := range leftVal { - pairs[k] = v - } - for k, v := range rightVal { - pairs[k] = v - } - return &object.Hash{Pairs: pairs} - - // [1] + [2] - case operator == "+" && left.Type() == object.ArrayType && right.Type() == object.ArrayType: - leftVal := left.(*object.Array).Elements - rightVal := right.(*object.Array).Elements - elements := make([]object.Object, len(leftVal)+len(rightVal)) - elements = append(leftVal, rightVal...) - return &object.Array{Elements: elements} - - // [1] * 3 - case operator == "*" && left.Type() == object.ArrayType && right.Type() == object.IntegerType: - leftVal := left.(*object.Array).Elements - rightVal := int(right.(object.Integer).Value) - elements := leftVal - for i := rightVal; i > 1; i-- { - elements = append(elements, leftVal...) - } - return &object.Array{Elements: elements} - - // 3 * [1] - case operator == "*" && left.Type() == object.IntegerType && right.Type() == object.ArrayType: - leftVal := int(left.(object.Integer).Value) - rightVal := right.(*object.Array).Elements - elements := rightVal - for i := leftVal; i > 1; i-- { - elements = append(elements, rightVal...) - } - return &object.Array{Elements: elements} - - // " " * 4 - case operator == "*" && left.Type() == object.StringType && right.Type() == object.IntegerType: - leftVal := left.(object.String).Value - rightVal := right.(object.Integer).Value - return object.String{Value: strings.Repeat(leftVal, int(rightVal))} - - // 4 * " " - case operator == "*" && left.Type() == object.IntegerType && right.Type() == object.StringType: - leftVal := left.(object.Integer).Value - rightVal := right.(object.String).Value - return object.String{Value: strings.Repeat(rightVal, int(leftVal))} - case operator == "==": - return nativeBoolToBooleanObject(left.Compare(right) == 0) + return object.FromNativeBoolean(left.Compare(right) == 0) case operator == "!=": - return nativeBoolToBooleanObject(left.Compare(right) != 0) + return object.FromNativeBoolean(left.Compare(right) != 0) case operator == "<=": - return nativeBoolToBooleanObject(left.Compare(right) < 1) + return object.FromNativeBoolean(left.Compare(right) < 1) case operator == ">=": - return nativeBoolToBooleanObject(left.Compare(right) > -1) + return object.FromNativeBoolean(left.Compare(right) > -1) case operator == "<": - return nativeBoolToBooleanObject(left.Compare(right) == -1) + return object.FromNativeBoolean(left.Compare(right) == -1) case operator == ">": - return nativeBoolToBooleanObject(left.Compare(right) == 1) + return object.FromNativeBoolean(left.Compare(right) == 1) - case left.Type() == object.BooleanType && right.Type() == object.BooleanType: + case left.Type() == right.Type() && left.Type() == object.BooleanType: return evalBooleanInfixExpression(operator, left, right) - case left.Type() == object.IntegerType && right.Type() == object.IntegerType: - return evalIntegerInfixExpression(operator, left, right) - case left.Type() == right.Type() && left.Type() == object.FloatType: - return evalFloatInfixExpression(operator, left, right) - case left.Type() == object.StringType && right.Type() == object.StringType: - return evalStringInfixExpression(operator, left, right) + case operator == "+": + if val, err := ops.Add(left, right); err != nil { + return newError(err.Error()) + } else { + return val + } + case operator == "-": + if val, err := ops.Sub(left, right); err != nil { + return newError(err.Error()) + } else { + return val + } + case operator == "*": + if val, err := ops.Mul(left, right); err != nil { + return newError(err.Error()) + } else { + return val + } + case operator == "/": + if val, err := ops.Div(left, right); err != nil { + return newError(err.Error()) + } else { + return val + } + case operator == "%": + if val, err := ops.Mod(left, right); err != nil { + return newError(err.Error()) + } else { + return val + } default: - return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) - } -} - -func evalFloatInfixExpression(operator string, left, right object.Object) object.Object { - leftVal := left.(object.Float).Value - rightVal := right.(object.Float).Value - - switch operator { - case "+": - return object.Float{Value: leftVal + rightVal} - case "-": - return object.Float{Value: leftVal - rightVal} - case "*": - return object.Float{Value: leftVal * rightVal} - case "/": - return object.Float{Value: leftVal / rightVal} - case "<": - return nativeBoolToBooleanObject(leftVal < rightVal) - case "<=": - return nativeBoolToBooleanObject(leftVal <= rightVal) - case ">": - return nativeBoolToBooleanObject(leftVal > rightVal) - case ">=": - return nativeBoolToBooleanObject(leftVal >= rightVal) - case "==": - return nativeBoolToBooleanObject(leftVal == rightVal) - case "!=": - return nativeBoolToBooleanObject(leftVal != rightVal) - default: - return NULL + return newError("unsupported operator: %s %s %s", + left.Type(), operator, right.Type()) } } @@ -450,65 +379,11 @@ func evalBooleanInfixExpression(operator string, left, right object.Object) obje switch operator { case "&&": - return nativeBoolToBooleanObject(leftVal && rightVal) + return object.FromNativeBoolean(leftVal && rightVal) case "||": - return nativeBoolToBooleanObject(leftVal || rightVal) + return object.FromNativeBoolean(leftVal || rightVal) default: - return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) - } -} - -func evalStringInfixExpression(operator string, left object.Object, right object.Object) object.Object { - leftVal := left.(object.String).Value - rightVal := right.(object.String).Value - - switch operator { - case "+": - return object.String{Value: leftVal + rightVal} - default: - return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) - } -} - -func evalIntegerInfixExpression(operator string, left, right object.Object) object.Object { - leftVal := left.(object.Integer).Value - rightVal := right.(object.Integer).Value - - switch operator { - case "+": - return object.Integer{Value: leftVal + rightVal} - case "-": - return object.Integer{Value: leftVal - rightVal} - case "*": - return object.Integer{Value: leftVal * rightVal} - case "/": - return object.Integer{Value: leftVal / rightVal} - case "%": - return object.Integer{Value: leftVal % rightVal} - case "|": - return object.Integer{Value: leftVal | rightVal} - case "^": - return object.Integer{Value: leftVal ^ rightVal} - case "&": - return object.Integer{Value: leftVal & rightVal} - case "<<": - return object.Integer{Value: leftVal << uint64(rightVal)} - case ">>": - return object.Integer{Value: leftVal >> uint64(rightVal)} - case "<": - return nativeBoolToBooleanObject(leftVal < rightVal) - case "<=": - return nativeBoolToBooleanObject(leftVal <= rightVal) - case ">": - return nativeBoolToBooleanObject(leftVal > rightVal) - case ">=": - return nativeBoolToBooleanObject(leftVal >= rightVal) - case "==": - return nativeBoolToBooleanObject(leftVal == rightVal) - case "!=": - return nativeBoolToBooleanObject(leftVal != rightVal) - default: - return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) + return newError("unsupported operator: %s %s %s", left.Type(), operator, right.Type()) } } @@ -518,25 +393,12 @@ func evalIfExpression(ctx context.Context, ie *ast.IfExpression, env *object.Env return condition } - if isTruthy(condition) { + if object.IsTruthy(condition) { return Eval(ctx, ie.Consequence, env) } else if ie.Alternative != nil { return Eval(ctx, ie.Alternative, env) } else { - return NULL - } -} - -func isTruthy(obj object.Object) bool { - switch obj { - case NULL: - return false - case TRUE: - return true - case FALSE: - return false - default: - return true + return object.NULL } } @@ -605,7 +467,7 @@ func applyFunction(ctx context.Context, fn object.Object, args []object.Object) if result := fn.Fn(ctx, args...); result != nil { return result } - return NULL + return object.NULL default: return newError("not a function: %s", fn.Type()) @@ -673,7 +535,7 @@ func evalHashIndexExpression(hash, index object.Object) object.Object { pair, ok := hashObject.Pairs[key.Hash()] if !ok { - return NULL + return object.NULL } return pair.Value @@ -685,7 +547,7 @@ func evalArrayIndexExpression(array, index object.Object) object.Object { maxInx := int64(len(arrayObject.Elements) - 1) if idx < 0 || idx > maxInx { - return NULL + return object.NULL } return arrayObject.Elements[idx] diff --git a/internal/evaluator/evaluator_test.go b/internal/evaluator/evaluator_test.go index 1f873db..6c7d437 100644 --- a/internal/evaluator/evaluator_test.go +++ b/internal/evaluator/evaluator_test.go @@ -85,6 +85,18 @@ func TestEvalExpressions(t *testing.T) { {"1 << 2", 4}, {"4 >> 2", 1}, {"5.0 / 2.0", 2.5}, + + {"1 + 2.1", 3.1}, + {"2 - 0.5", 1.5}, + {"2 * 2.5", 5.0}, + {"5 / 2.0", 2.5}, + {"5 % 2.0", 1.0}, + + {"2.1 + 1", 3.1}, + {"2.5 - 2", 0.5}, + {"2.5 * 2", 5.0}, + {"5.0 / 2", 2.5}, + {"5.0 % 2", 1.0}, } for _, tt := range tests { @@ -231,27 +243,27 @@ func TestErrorHandling(t *testing.T) { }{ { "5 + true;", - "unknown operator: int + bool", + "unsupported operator: int + bool", }, { "5 + true; 5;", - "unknown operator: int + bool", + "unsupported operator: int + bool", }, { "-true", - "unknown operator: -bool", + "unsupported operator: -bool", }, { "true + false;", - "unknown operator: bool + bool", + "unsupported operator: bool + bool", }, { "5; true + false; 5", - "unknown operator: bool + bool", + "unsupported operator: bool + bool", }, { "if (10 > 1) { true + false; }", - "unknown operator: bool + bool", + "unsupported operator: bool + bool", }, { ` @@ -263,7 +275,7 @@ func TestErrorHandling(t *testing.T) { return 1; } `, - "unknown operator: bool + bool", + "unsupported operator: bool + bool", }, { "foobar", @@ -271,7 +283,7 @@ func TestErrorHandling(t *testing.T) { }, { `"Hello" - "World"`, - "unknown operator: str - str", + "unsupported operator: str - str", }, { `{"name": "Monkey"}[fn(x) { x }];`, @@ -676,8 +688,8 @@ func TestHashLiterals(t *testing.T) { (object.String{Value: "two"}).Hash(): 2, (object.String{Value: "three"}).Hash(): 3, (object.Integer{Value: 4}).Hash(): 4, - TRUE.Hash(): 5, - FALSE.Hash(): 6, + object.TRUE.Hash(): 5, + object.FALSE.Hash(): 6, } if len(result.Pairs) != len(expected) { @@ -882,7 +894,7 @@ func testBooleanObject(t *testing.T, obj object.Object, expected bool) bool { } func testNullObject(t *testing.T, obj object.Object) bool { - if obj != NULL { + if obj != object.NULL { t.Errorf("object is not NULL. got=%T (%+v)", obj, obj) return false } diff --git a/internal/object/errors.go b/internal/object/errors.go index a7295d0..3a066c8 100644 --- a/internal/object/errors.go +++ b/internal/object/errors.go @@ -11,7 +11,7 @@ type BinaryOpError struct { } func (e BinaryOpError) Error() string { - return fmt.Sprintf("unsupported types for binary operation: %s %s %s %s %s", e.left.Type(), e.left.Inspect(), e.op, e.right.Type(), e.right.Inspect()) + return fmt.Sprintf("unsupported operator: %s %s %s", e.left.Type(), e.op, e.right.Type()) } // NewBinaryOpError returns a new BinaryOpError @@ -25,7 +25,7 @@ type DivisionByZeroError struct { } func (e DivisionByZeroError) Error() string { - return fmt.Sprintf("cannot divide %s by zero", e.left) + return fmt.Sprintf("division by zero: %s", e.left) } // NewDivisionByZeroError returns a new DivisionByZeroError diff --git a/internal/object/float.go b/internal/object/float.go index 82964f2..c8b651f 100644 --- a/internal/object/float.go +++ b/internal/object/float.go @@ -17,6 +17,8 @@ func (f Float) Bool() bool { func (f Float) Add(other Object) (Object, error) { switch obj := other.(type) { + case Integer: + return Float{f.Value + float64(obj.Value)}, nil case Float: return Float{f.Value + obj.Value}, nil default: @@ -26,6 +28,8 @@ func (f Float) Add(other Object) (Object, error) { func (f Float) Sub(other Object) (Object, error) { switch obj := other.(type) { + case Integer: + return Float{f.Value - float64(obj.Value)}, nil case Float: return Float{f.Value - obj.Value}, nil default: @@ -35,6 +39,8 @@ func (f Float) Sub(other Object) (Object, error) { func (f Float) Mul(other Object) (Object, error) { switch obj := other.(type) { + case Integer: + return Float{f.Value * float64(obj.Value)}, nil case Float: return Float{f.Value * obj.Value}, nil case String: @@ -48,6 +54,11 @@ func (f Float) Mul(other Object) (Object, error) { func (f Float) Div(other Object) (Object, error) { switch obj := other.(type) { + case Integer: + if obj.Value == 0.0 { + return nil, NewDivisionByZeroError(f) + } + return Float{f.Value / float64(obj.Value)}, nil case Float: if obj.Value == 0 { return nil, NewDivisionByZeroError(f) @@ -60,6 +71,8 @@ func (f Float) Div(other Object) (Object, error) { func (f Float) Mod(other Object) (Object, error) { switch obj := other.(type) { + case Integer: + return Float{math.Mod(f.Value, float64(obj.Value))}, nil case Float: return Float{math.Mod(f.Value, obj.Value)}, nil default: @@ -95,17 +108,17 @@ func (f Float) String() string { } // Copy implements the Copyable interface -func (i Float) Copy() Object { - return Float{Value: i.Value} +func (f Float) Copy() Object { + return Float{Value: f.Value} } // Hash implements the Hasher interface -func (i Float) Hash() HashKey { - return HashKey{Type: i.Type(), Value: uint64(i.Value)} +func (f Float) Hash() HashKey { + return HashKey{Type: f.Type(), Value: uint64(f.Value)} } // Type returns the type of the object -func (i Float) Type() Type { return FloatType } +func (f Float) Type() Type { return FloatType } // Inspect returns a stringified version of the object for debugging -func (i Float) Inspect() string { return fmt.Sprintf("%f", i.Value) } +func (f Float) Inspect() string { return fmt.Sprintf("%f", f.Value) } diff --git a/internal/object/int.go b/internal/object/int.go index 7c6be32..b44ee2c 100644 --- a/internal/object/int.go +++ b/internal/object/int.go @@ -35,6 +35,8 @@ func (i Integer) Add(other Object) (Object, error) { switch obj := other.(type) { case Integer: return Integer{i.Value + obj.Value}, nil + case Float: + return Float{float64(i.Value) + obj.Value}, nil default: return nil, NewBinaryOpError(i, other, "+") } @@ -44,6 +46,8 @@ func (i Integer) Sub(other Object) (Object, error) { switch obj := other.(type) { case Integer: return Integer{i.Value - obj.Value}, nil + case Float: + return Float{float64(i.Value) - obj.Value}, nil default: return nil, NewBinaryOpError(i, other, "-") } @@ -53,6 +57,8 @@ func (i Integer) Mul(other Object) (Object, error) { switch obj := other.(type) { case Integer: return Integer{i.Value * obj.Value}, nil + case Float: + return Float{float64(i.Value) * obj.Value}, nil case String: return obj.Mul(i) case *Array: @@ -69,6 +75,11 @@ func (i Integer) Div(other Object) (Object, error) { return nil, NewDivisionByZeroError(i) } return Integer{i.Value / obj.Value}, nil + case Float: + if obj.Value == 0.0 { + return nil, NewDivisionByZeroError(i) + } + return Float{float64(i.Value) / obj.Value}, nil default: return nil, NewBinaryOpError(i, other, "/") } @@ -78,6 +89,8 @@ func (i Integer) Mod(other Object) (Object, error) { switch obj := other.(type) { case Integer: return Integer{i.Value % obj.Value}, nil + case Float: + return Float{float64(i.Value % int64(obj.Value))}, nil default: return nil, NewBinaryOpError(i, other, "%") } diff --git a/internal/object/object.go b/internal/object/object.go index 38bbf29..4231932 100644 --- a/internal/object/object.go +++ b/internal/object/object.go @@ -95,3 +95,21 @@ func AssertTypes(obj Object, types ...Type) bool { } return false } + +func FromNativeBoolean(b bool) Boolean { + if b { + return TRUE + } + return FALSE +} + +func IsTruthy(obj Object) bool { + switch obj := obj.(type) { + case Boolean: + return obj.Value + case Null: + return false + default: + return true + } +} diff --git a/internal/ops/ops.go b/internal/ops/ops.go new file mode 100644 index 0000000..aeb1610 --- /dev/null +++ b/internal/ops/ops.go @@ -0,0 +1,128 @@ +package ops + +import "monkey/internal/object" + +func Add(left, right object.Object) (object.Object, error) { + switch obj := left.(type) { + case object.Integer: + val, err := obj.Add(right) + if err != nil { + return nil, err + } + return val, nil + case object.Float: + val, err := obj.Add(right) + if err != nil { + return nil, err + } + return val, err + case object.String: + val, err := obj.Add(right) + if err != nil { + return nil, err + } + return val, nil + case *object.Array: + val, err := obj.Add(right) + if err != nil { + return nil, err + } + return val, nil + case *object.Hash: + val, err := obj.Add(right) + if err != nil { + return nil, err + } + return val, nil + default: + return nil, object.NewBinaryOpError(left, right, "+") + } +} + +func Sub(left, right object.Object) (object.Object, error) { + switch obj := left.(type) { + case object.Integer: + val, err := obj.Sub(right) + if err != nil { + return nil, err + } + return val, nil + case object.Float: + val, err := obj.Sub(right) + if err != nil { + return nil, err + } + return val, nil + default: + return nil, object.NewBinaryOpError(left, right, "-") + } +} + +func Mul(left, right object.Object) (object.Object, error) { + switch obj := left.(type) { + case *object.Array: + val, err := obj.Mul(right) + if err != nil { + return nil, err + } + return val, nil + case object.Integer: + val, err := obj.Mul(right) + if err != nil { + return nil, err + } + return val, nil + case object.Float: + val, err := obj.Mul(right) + if err != nil { + return nil, err + } + return val, nil + case object.String: + val, err := obj.Mul(right) + if err != nil { + return nil, err + } + return val, nil + default: + return nil, object.NewBinaryOpError(left, right, "*") + } +} + +func Div(left, right object.Object) (object.Object, error) { + switch obj := left.(type) { + case object.Integer: + val, err := obj.Div(right) + if err != nil { + return nil, err + } + return val, nil + case object.Float: + val, err := obj.Div(right) + if err != nil { + return nil, err + } + return val, nil + default: + return nil, object.NewBinaryOpError(left, right, "/") + } +} + +func Mod(left, right object.Object) (object.Object, error) { + switch obj := left.(type) { + case object.Integer: + val, err := obj.Mod(right) + if err != nil { + return nil, err + } + return val, nil + case object.Float: + val, err := obj.Mod(right) + if err != nil { + return nil, err + } + return val, nil + default: + return nil, object.NewBinaryOpError(left, right, "%") + } +} diff --git a/internal/vm/vm.go b/internal/vm/vm.go index a699443..700b250 100644 --- a/internal/vm/vm.go +++ b/internal/vm/vm.go @@ -9,6 +9,7 @@ import ( "monkey/internal/context" "monkey/internal/lexer" "monkey/internal/object" + "monkey/internal/ops" "monkey/internal/parser" "monkey/internal/utils" "os" @@ -24,20 +25,6 @@ const ( maxGlobals = 65536 ) -func isTruthy(obj object.Object) bool { - switch obj := obj.(type) { - - case object.Boolean: - return obj.Value - - case object.Null: - return false - - default: - return true - } -} - func executeModule(name string, state *State) (object.Object, error) { filename := utils.FindModule(name) if filename == "" { @@ -331,131 +318,51 @@ func (vm *VM) executeMakeArray() error { func (vm *VM) executeAdd() error { right, left := vm.pop2() - switch obj := left.(type) { - case object.Integer: - val, err := obj.Add(right) - if err != nil { - return err - } - return vm.push(val) - case object.Float: - val, err := obj.Add(right) - if err != nil { - return err - } - return vm.push(val) - case object.String: - val, err := obj.Add(right) - if err != nil { - return err - } - return vm.push(val) - case *object.Array: - val, err := obj.Add(right) - if err != nil { - return err - } - return vm.push(val) - case *object.Hash: - val, err := obj.Add(right) - if err != nil { - return err - } - return vm.push(val) - default: - return object.NewBinaryOpError(left, right, "+") + val, err := ops.Add(left, right) + if err != nil { + return err } + return vm.push(val) } func (vm *VM) executeSub() error { right, left := vm.pop2() - switch obj := left.(type) { - case object.Integer: - val, err := obj.Sub(right) - if err != nil { - return err - } - return vm.push(val) - case object.Float: - val, err := obj.Sub(right) - if err != nil { - return err - } - return vm.push(val) - default: - return fmt.Errorf("unsupported types for unary operation: -%s", left.Type()) + val, err := ops.Sub(left, right) + if err != nil { + return err } + return vm.push(val) } func (vm *VM) executeMul() error { right, left := vm.pop2() - switch obj := left.(type) { - case *object.Array: - val, err := obj.Mul(right) - if err != nil { - return err - } - return vm.push(val) - case object.Integer: - val, err := obj.Mul(right) - if err != nil { - return err - } - return vm.push(val) - case object.Float: - val, err := obj.Mul(right) - if err != nil { - return err - } - return vm.push(val) - case object.String: - val, err := obj.Mul(right) - if err != nil { - return err - } - return vm.push(val) - default: - return object.NewBinaryOpError(left, right, "*") + val, err := ops.Mul(left, right) + if err != nil { + return err } + return vm.push(val) } func (vm *VM) executeDiv() error { right, left := vm.pop2() - switch obj := left.(type) { - case object.Integer: - val, err := obj.Div(right) - if err != nil { - return err - } - return vm.push(val) - case object.Float: - val, err := obj.Div(right) - if err != nil { - return err - } - return vm.push(val) - default: - return object.NewBinaryOpError(left, right, "/") + val, err := ops.Div(left, right) + if err != nil { + return err } - + return vm.push(val) } func (vm *VM) executeMod() error { right, left := vm.pop2() - switch obj := left.(type) { - case object.Integer: - val, err := obj.Mod(right) - if err != nil { - return err - } - return vm.push(val) - default: - return object.NewBinaryOpError(left, right, "%") + val, err := ops.Mod(left, right) + if err != nil { + return err } + return vm.push(val) } func (vm *VM) executeOr() error { @@ -577,37 +484,25 @@ func (vm *VM) executeRightShift() error { func (vm *VM) executeEqual() error { right, left := vm.pop2() - if left.Compare(right) == 0 { - return vm.push(object.TRUE) - } - return vm.push(object.FALSE) + return vm.push(object.FromNativeBoolean(left.Compare(right) == 0)) } func (vm *VM) executeNotEqual() error { right, left := vm.pop2() - if left.Compare(right) != 0 { - return vm.push(object.TRUE) - } - return vm.push(object.FALSE) + return vm.push(object.FromNativeBoolean(left.Compare(right) != 0)) } func (vm *VM) executeGreaterThan() error { right, left := vm.pop2() - if left.Compare(right) == 1 { - return vm.push(object.TRUE) - } - return vm.push(object.FALSE) + return vm.push(object.FromNativeBoolean(left.Compare(right) == 1)) } func (vm *VM) executeGreaterThanOrEqual() error { right, left := vm.pop2() - if left.Compare(right) == 1 { - return vm.push(object.TRUE) - } - return vm.push(object.FALSE) + return vm.push(object.FromNativeBoolean(left.Compare(right) >= 0)) } func (vm *VM) executeNot() error { @@ -745,7 +640,6 @@ func (vm *VM) callClosure(cl *object.Closure, numArgs int) error { func (vm *VM) callBuiltin(builtin object.Builtin, numArgs int) error { args := vm.stack[vm.sp-numArgs : vm.sp] - log.Printf("args: %+v", args) result := builtin.Fn(vm.ctx, args...) vm.sp = vm.sp - numArgs - 1 @@ -839,7 +733,7 @@ func (vm *VM) Run() (err error) { case code.OpJumpNotTruthy: pos := vm.currentFrame().ReadUint16() - if !isTruthy(vm.pop()) { + if !object.IsTruthy(vm.pop()) { vm.currentFrame().SetIP(pos) } diff --git a/internal/vm/vm_test.go b/internal/vm/vm_test.go index 10459b6..88856f1 100644 --- a/internal/vm/vm_test.go +++ b/internal/vm/vm_test.go @@ -257,6 +257,24 @@ func TestFloatingPointArithmetic(t *testing.T) { runVmTests(t, tests) } +func TestMixedArithmetic(t *testing.T) { + tests := []vmTestCase{ + {"1 + 2.1", 3.1}, + {"2 - 0.5", 1.5}, + {"2 * 2.5", 5.0}, + {"5 / 2.0", 2.5}, + {"5 % 2.0", 1.0}, + + {"2.1 + 1", 3.1}, + {"2.5 - 2", 0.5}, + {"2.5 * 2", 5.0}, + {"5.0 / 2", 2.5}, + {"5.0 % 2", 1.0}, + } + + runVmTests(t, tests) +} + func TestBooleanExpressions(t *testing.T) { tests := []vmTestCase{ {"true", true},