From 1c99d2198b9131fcc3a44a4178bd612037a6e105 Mon Sep 17 00:00:00 2001 From: Chuck Smith Date: Sun, 24 Mar 2024 17:11:48 -0400 Subject: [PATCH] Simplified Equality --- builtins/find.go | 20 +++++++------- evaluator/evaluator.go | 34 +++++++++++------------- evaluator/evaluator_test.go | 5 ++-- object/array.go | 31 ++++++++++++---------- object/bool.go | 18 ++++++------- object/hash.go | 18 +++++++------ object/int.go | 23 +++++++++------- object/null.go | 14 +++++----- object/object.go | 3 +-- object/str.go | 23 +++++++++------- vm/vm.go | 53 +++++-------------------------------- 11 files changed, 109 insertions(+), 133 deletions(-) diff --git a/builtins/find.go b/builtins/find.go index 7664989..29b3780 100644 --- a/builtins/find.go +++ b/builtins/find.go @@ -1,6 +1,9 @@ package builtins -import "monkey/object" +import ( + "monkey/object" + "sort" +) import ( "strings" @@ -21,15 +24,14 @@ func Find(args ...object.Object) object.Object { return newError("expected arg #2 to be `str` got got=%T", args[1]) } } else if haystack, ok := args[0].(*object.Array); ok { - needle := args[1] - index := -1 - for i, el := range haystack.Elements { - if cmp, ok := el.(object.Comparable); ok && cmp.Equal(needle) { - index = i - break - } + needle := args[1].(object.Comparable) + i := sort.Search(len(haystack.Elements), func(i int) bool { + return needle.Compare(haystack.Elements[i]) == 0 + }) + if i < len(haystack.Elements) && needle.Compare(haystack.Elements[i]) == 0 { + return &object.Integer{Value: int64(i)} } - return &object.Integer{Value: int64(index)} + return &object.Integer{Value: -1} } else { return newError("expected arg #1 to be `str` or `array` got got=%T", args[0]) } diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index a31f7fe..59ffd43 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -337,6 +337,7 @@ func evalInfixExpression(operator string, left, right object.Object) object.Obje elements = append(elements, leftVal...) } return &object.Array{Elements: elements} + // 3 * [1] case operator == "*" && left.Type() == object.INTEGER_OBJ && right.Type() == object.ARRAY_OBJ: leftVal := int(left.(*object.Integer).Value) @@ -352,24 +353,33 @@ func evalInfixExpression(operator string, left, right object.Object) object.Obje leftVal := left.(*object.String).Value rightVal := right.(*object.Integer).Value return &object.String{Value: strings.Repeat(leftVal, int(rightVal))} + // 4 * " " case operator == "*" && left.Type() == object.INTEGER_OBJ && right.Type() == object.STRING_OBJ: leftVal := left.(*object.Integer).Value rightVal := right.(*object.String).Value return &object.String{Value: strings.Repeat(rightVal, int(leftVal))} + case operator == "==": + return nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) == 0) + case operator == "!=": + return nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) != 0) + case operator == "<=": + return nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) < 1) + case operator == ">=": + return nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) > -1) + case operator == "<": + return nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) == -1) + case operator == ">": + return nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) == 1) + case left.Type() == object.BOOLEAN_OBJ && right.Type() == object.BOOLEAN_OBJ: return evalBooleanInfixExpression(operator, left, right) case left.Type() == object.INTEGER_OBJ && right.Type() == object.INTEGER_OBJ: return evalIntegerInfixExpression(operator, left, right) case left.Type() == object.STRING_OBJ && right.Type() == object.STRING_OBJ: return evalStringInfixExpression(operator, left, right) - case operator == "==": - return nativeBoolToBooleanObject(left == right) - case operator == "!=": - return nativeBoolToBooleanObject(left != right) - case left.Type() != right.Type(): - return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type()) + default: return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) } @@ -380,10 +390,6 @@ func evalBooleanInfixExpression(operator string, left, right object.Object) obje rightVal := right.(*object.Boolean).Value switch operator { - case "==": - return nativeBoolToBooleanObject(left == right) - case "!=": - return nativeBoolToBooleanObject(leftVal != rightVal) case "&&": return &object.Boolean{Value: leftVal && rightVal} case "||": @@ -400,14 +406,6 @@ func evalStringInfixExpression(operator string, left object.Object, right object switch operator { case "+": return &object.String{Value: leftVal + rightVal} - case "<": - return nativeBoolToBooleanObject(leftVal < rightVal) - case ">": - return nativeBoolToBooleanObject(leftVal > rightVal) - case "==": - return nativeBoolToBooleanObject(leftVal == rightVal) - case "!=": - return nativeBoolToBooleanObject(leftVal != rightVal) default: return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) } diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 40837ef..e57cbb4 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -118,6 +118,7 @@ func TestIfElseExpression(t *testing.T) { } for _, tt := range tests { + t.Log(tt.input) evaluated := testEval(tt.input) integer, ok := tt.expected.(int) if ok { @@ -184,11 +185,11 @@ func TestErrorHandling(t *testing.T) { }{ { "5 + true;", - "type mismatch: int + bool", + "unknown operator: int + bool", }, { "5 + true; 5;", - "type mismatch: int + bool", + "unknown operator: int + bool", }, { "-true", diff --git a/object/array.go b/object/array.go index 4256369..b47bb37 100644 --- a/object/array.go +++ b/object/array.go @@ -13,6 +13,7 @@ type Array struct { func (ao *Array) Type() ObjectType { return ARRAY_OBJ } + func (ao *Array) Inspect() string { var out bytes.Buffer @@ -30,25 +31,34 @@ func (ao *Array) Inspect() string { func (ao *Array) String() string { return ao.Inspect() } -func (ao *Array) Equal(other Object) bool { + +func (a *Array) Less(i, j int) bool { + if cmp, ok := a.Elements[i].(Comparable); ok { + return cmp.Compare(a.Elements[j]) == -1 + } + return false +} + +func (ao *Array) Compare(other Object) int { if obj, ok := other.(*Array); ok { if len(ao.Elements) != len(obj.Elements) { - return false + return -1 } for i, el := range ao.Elements { cmp, ok := el.(Comparable) if !ok { - return false + return -1 } - if !cmp.Equal(obj.Elements[i]) { - return false + if cmp.Compare(obj.Elements[i]) != 0 { + return cmp.Compare(obj.Elements[i]) } } - return true + return 0 } - return false + return -1 } + func (ao *Array) Copy() *Array { elements := make([]Object, len(ao.Elements)) for i, e := range ao.Elements { @@ -70,10 +80,3 @@ func (ao *Array) Len() int { func (ao *Array) Swap(i, j int) { ao.Elements[i], ao.Elements[j] = ao.Elements[j], ao.Elements[i] } - -func (ao *Array) Less(i, j int) bool { - if cmp, ok := ao.Elements[i].(Comparable); ok { - return cmp.Less(ao.Elements[j]) - } - return false -} diff --git a/object/bool.go b/object/bool.go index 558a171..f7565ec 100644 --- a/object/bool.go +++ b/object/bool.go @@ -11,31 +11,29 @@ type Boolean struct { func (b *Boolean) Type() ObjectType { return BOOLEAN_OBJ } + func (b *Boolean) Inspect() string { return fmt.Sprintf("%t", b.Value) } + func (b *Boolean) Clone() Object { return &Boolean{Value: b.Value} } + func (b *Boolean) String() string { return b.Inspect() } -func (b *Boolean) Equal(other Object) bool { - if obj, ok := other.(*Boolean); ok { - return b.Value == obj.Value - } - return false -} -func (b *Boolean) Int() int64 { + +func (b *Boolean) Int() int { if b.Value { return 1 } return 0 } -func (b *Boolean) Less(other Object) bool { +func (b *Boolean) Compare(other Object) int { if obj, ok := other.(*Boolean); ok { - return b.Int() < obj.Int() + return b.Int() - obj.Int() } - return false + return 1 } diff --git a/object/hash.go b/object/hash.go index b0b420c..d0d50db 100644 --- a/object/hash.go +++ b/object/hash.go @@ -47,6 +47,7 @@ type Hash struct { func (h *Hash) Type() ObjectType { return HASH_OBJ } + func (h *Hash) Inspect() string { var out bytes.Buffer @@ -64,28 +65,29 @@ func (h *Hash) Inspect() string { func (h *Hash) String() string { return h.Inspect() } -func (h *Hash) Equal(other Object) bool { + +func (h *Hash) Compare(other Object) int { if obj, ok := other.(*Hash); ok { if len(h.Pairs) != len(obj.Pairs) { - return false + return -1 } for _, pair := range h.Pairs { left := pair.Value hashed := left.(Hashable) right, ok := obj.Pairs[hashed.HashKey()] if !ok { - return false + return -1 } cmp, ok := left.(Comparable) if !ok { - return false + return -1 } - if !cmp.Equal(right.Value) { - return false + if cmp.Compare(right.Value) != 0 { + return cmp.Compare(right.Value) } } - return true + return 0 } - return false + return -1 } diff --git a/object/int.go b/object/int.go index a9453c6..ae6d300 100644 --- a/object/int.go +++ b/object/int.go @@ -9,24 +9,29 @@ type Integer struct { func (i *Integer) Type() ObjectType { return INTEGER_OBJ } + func (i *Integer) Inspect() string { return fmt.Sprintf("%d", i.Value) } + func (i *Integer) Clone() Object { return &Integer{Value: i.Value} } + func (i *Integer) String() string { return i.Inspect() } -func (i *Integer) Equal(other Object) bool { + +func (i *Integer) Compare(other Object) int { if obj, ok := other.(*Integer); ok { - return i.Value == obj.Value + switch { + case i.Value < obj.Value: + return -1 + case i.Value > obj.Value: + return 1 + default: + return 0 + } } - return false -} -func (i *Integer) Less(other Object) bool { - if obj, ok := other.(*Integer); ok { - return i.Value < obj.Value - } - return true + return -1 } diff --git a/object/null.go b/object/null.go index ac2c162..9440e6c 100644 --- a/object/null.go +++ b/object/null.go @@ -5,16 +5,18 @@ type Null struct{} func (n *Null) Type() ObjectType { return NULL_OBJ } + func (n *Null) Inspect() string { return "null" } + func (n *Null) String() string { return n.Inspect() } -func (n *Null) Equal(other Object) bool { - _, ok := other.(*Null) - return ok -} -func (n *Null) Less(other Object) bool { - return false + +func (n *Null) Compare(other Object) int { + if _, ok := other.(*Null); ok { + return 0 + } + return 1 } diff --git a/object/object.go b/object/object.go index 3b6e9bf..2ee6b74 100644 --- a/object/object.go +++ b/object/object.go @@ -22,8 +22,7 @@ const ( // values. It is the responsibility of the caller (left) to check for types. // Returns `true` iif the types and values are identical, `false` otherwise. type Comparable interface { - Equal(other Object) bool - Less(other Object) bool + Compare(other Object) int } // Immutable is the interface for all immutable objects which must implement diff --git a/object/str.go b/object/str.go index a5202f4..e3a10e0 100644 --- a/object/str.go +++ b/object/str.go @@ -9,24 +9,29 @@ type String struct { func (s *String) Type() ObjectType { return STRING_OBJ } + func (s *String) Inspect() string { return fmt.Sprintf("%#v", s.Value) } + func (s *String) Clone() Object { return &String{Value: s.Value} } + func (s *String) String() string { return s.Value } -func (s *String) Equal(other Object) bool { + +func (s *String) Compare(other Object) int { if obj, ok := other.(*String); ok { - return s.Value == obj.Value + switch { + case s.Value < obj.Value: + return -1 + case s.Value > obj.Value: + return 1 + default: + return 0 + } } - return false -} -func (s *String) Less(other Object) bool { - if obj, ok := other.(*String); ok { - return s.Value < obj.Value - } - return false + return 1 } diff --git a/vm/vm.go b/vm/vm.go index b501980..a59e51a 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -659,39 +659,18 @@ func (vm *VM) executeComparison(op code.Opcode) error { right := vm.pop() left := vm.pop() - if left.Type() == object.INTEGER_OBJ && right.Type() == object.INTEGER_OBJ { - return vm.executeIntegerComparison(op, left, right) - } - - if left.Type() == object.STRING_OBJ && right.Type() == object.STRING_OBJ { - return vm.executeStringComparison(op, left, right) - } - switch op { case code.OpEqual: - return vm.push(nativeBoolToBooleanObject(right == left)) + return vm.push(nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) == 0)) case code.OpNotEqual: - return vm.push(nativeBoolToBooleanObject(right != left)) - default: - return fmt.Errorf("unknown operator: %d (%s %s)", op, left.Type(), right.Type()) - } -} - -func (vm *VM) executeIntegerComparison(op code.Opcode, left, right object.Object) error { - leftValue := left.(*object.Integer).Value - rightValue := right.(*object.Integer).Value - - switch op { - case code.OpEqual: - return vm.push(nativeBoolToBooleanObject(rightValue == leftValue)) - case code.OpNotEqual: - return vm.push(nativeBoolToBooleanObject(rightValue != leftValue)) - case code.OpGreaterThan: - return vm.push(nativeBoolToBooleanObject(leftValue > rightValue)) + return vm.push(nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) != 0)) case code.OpGreaterThanEqual: - return vm.push(nativeBoolToBooleanObject(leftValue >= rightValue)) + return vm.push(nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) > -1)) + case code.OpGreaterThan: + return vm.push(nativeBoolToBooleanObject(left.(object.Comparable).Compare(right) == 1)) default: - return fmt.Errorf("unknown operator: %d", op) + return fmt.Errorf("unknown operator: %d (%s %s)", + op, left.Type(), right.Type()) } } @@ -826,24 +805,6 @@ func (vm *VM) executeStringIndex(str, index object.Object) error { ) } -func (vm *VM) executeStringComparison(op code.Opcode, left, right object.Object) error { - leftValue := left.(*object.String).Value - rightValue := right.(*object.String).Value - - switch op { - case code.OpEqual: - return vm.push(nativeBoolToBooleanObject(rightValue == leftValue)) - case code.OpNotEqual: - return vm.push(nativeBoolToBooleanObject(rightValue != leftValue)) - case code.OpGreaterThan: - return vm.push(nativeBoolToBooleanObject(leftValue > rightValue)) - case code.OpGreaterThanEqual: - return vm.push(nativeBoolToBooleanObject(leftValue >= rightValue)) - default: - return fmt.Errorf("unknown operator: %d", op) - } -} - func nativeBoolToBooleanObject(input bool) *object.Boolean { if input { return True