This commit is contained in:
Chuck Smith
2024-01-19 18:07:54 -05:00
parent 7eba7471c8
commit 44d20ba7a0
3 changed files with 119 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
package evaluator package evaluator
import ( import (
"fmt"
"monkey/ast" "monkey/ast"
"monkey/object" "monkey/object"
) )
@@ -11,6 +12,13 @@ var (
FALSE = &object.Boolean{Value: false} FALSE = &object.Boolean{Value: false}
) )
func isError(obj object.Object) bool {
if obj != nil {
return obj.Type() == object.ERROR_OBJ
}
return false
}
func Eval(node ast.Node) object.Object { func Eval(node ast.Node) object.Object {
switch node := node.(type) { switch node := node.(type) {
@@ -29,6 +37,9 @@ func Eval(node ast.Node) object.Object {
case *ast.ReturnStatement: case *ast.ReturnStatement:
val := Eval(node.ReturnValue) val := Eval(node.ReturnValue)
if isError(val) {
return val
}
return &object.ReturnValue{Value: val} return &object.ReturnValue{Value: val}
// Expressions // Expressions
@@ -40,11 +51,20 @@ func Eval(node ast.Node) object.Object {
case *ast.PrefixExpression: case *ast.PrefixExpression:
right := Eval(node.Right) right := Eval(node.Right)
if isError(right) {
return right
}
return evalPrefixExpression(node.Operator, right) return evalPrefixExpression(node.Operator, right)
case *ast.InfixExpression: case *ast.InfixExpression:
left := Eval(node.Left) left := Eval(node.Left)
if isError(left) {
return left
}
right := Eval(node.Right) right := Eval(node.Right)
if isError(right) {
return right
}
return evalInfixExpression(node.Operator, left, right) return evalInfixExpression(node.Operator, left, right)
} }
@@ -58,9 +78,13 @@ func evalProgram(program *ast.Program) object.Object {
for _, statement := range program.Statements { for _, statement := range program.Statements {
result = Eval(statement) result = Eval(statement)
if returnValue, ok := result.(*object.ReturnValue); ok { switch result := result.(type) {
return returnValue.Value case *object.ReturnValue:
return result.Value
case *object.Error:
return result
} }
} }
return result return result
@@ -72,10 +96,13 @@ func evalBlockStatements(block *ast.BlockStatement) object.Object {
for _, statement := range block.Statements { for _, statement := range block.Statements {
result = Eval(statement) result = Eval(statement)
if result != nil && result.Type() == object.RETURN_VALUE_OBJ { if result != nil {
rt := result.Type()
if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ {
return result return result
} }
} }
}
return result return result
} }
@@ -94,7 +121,7 @@ func evalPrefixExpression(operator string, right object.Object) object.Object {
case "-": case "-":
return evalMinusPrefixOperatorExpression(right) return evalMinusPrefixOperatorExpression(right)
default: default:
return NULL return newError("unknown operator: %s%s", operator, right.Type())
} }
} }
@@ -113,7 +140,7 @@ func evalBangOperatorExpression(right object.Object) object.Object {
func evalMinusPrefixOperatorExpression(right object.Object) object.Object { func evalMinusPrefixOperatorExpression(right object.Object) object.Object {
if right.Type() != object.INTEGER_OBJ { if right.Type() != object.INTEGER_OBJ {
return NULL return newError("unknown operator: -%s", right.Type())
} }
value := right.(*object.Integer).Value value := right.(*object.Integer).Value
@@ -128,8 +155,10 @@ func evalInfixExpression(operator string, left, right object.Object) object.Obje
return nativeBoolToBooleanObject(left == right) return nativeBoolToBooleanObject(left == right)
case operator == "!=": case operator == "!=":
return nativeBoolToBooleanObject(left != right) return nativeBoolToBooleanObject(left != right)
case left.Type() != right.Type():
return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type())
default: default:
return NULL return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
} }
} }
@@ -155,12 +184,15 @@ func evalIntegerInfixExpression(operator string, left, right object.Object) obje
case "!=": case "!=":
return nativeBoolToBooleanObject(leftVal != rightVal) return nativeBoolToBooleanObject(leftVal != rightVal)
default: default:
return NULL return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
} }
} }
func evalIfExpression(ie *ast.IfExpression) object.Object { func evalIfExpression(ie *ast.IfExpression) object.Object {
condition := Eval(ie.Condition) condition := Eval(ie.Condition)
if isError(condition) {
return condition
}
if isTruthy(condition) { if isTruthy(condition) {
return Eval(ie.Consequence) return Eval(ie.Consequence)
@@ -183,3 +215,7 @@ func isTruthy(obj object.Object) bool {
return true return true
} }
} }
func newError(format string, a ...interface{}) *object.Error {
return &object.Error{Message: fmt.Sprintf(format, a...)}
}

View File

@@ -139,6 +139,66 @@ func TestReturnStatements(t *testing.T) {
} }
} }
func TestErrorHandling(t *testing.T) {
tests := []struct {
input string
expectedMessage string
}{
{
"5 + true;",
"type mismatch: INTEGER + BOOLEAN",
},
{
"5 + true; 5;",
"type mismatch: INTEGER + BOOLEAN",
},
{
"-true",
"unknown operator: -BOOLEAN",
},
{
"true + false;",
"unknown operator: BOOLEAN + BOOLEAN",
},
{
"5; true + false; 5",
"unknown operator: BOOLEAN + BOOLEAN",
},
{
"if (10 > 1) { true + false; }",
"unknown operator: BOOLEAN + BOOLEAN",
},
{
`
if (10 > 1) {
if (10 > 1) {
return true + false;
}
return 1;
}
`,
"unknown operator: BOOLEAN + BOOLEAN",
},
}
for _, tt := range tests {
evaluated := testEval(tt.input)
errObj, ok := evaluated.(*object.Error)
if !ok {
t.Errorf("no error object returned. got=%T(%+v)",
evaluated, evaluated)
continue
}
if errObj.Message != tt.expectedMessage {
t.Errorf("wrong error message. expected=%q, got=%q",
tt.expectedMessage, errObj.Message)
}
}
}
func testEval(input string) object.Object { func testEval(input string) object.Object {
l := lexer.New(input) l := lexer.New(input)
p := parser.New(l) p := parser.New(l)

View File

@@ -9,6 +9,7 @@ const (
BOOLEAN_OBJ = "BOOLEAN" BOOLEAN_OBJ = "BOOLEAN"
NULL_OBJ = "NULL" NULL_OBJ = "NULL"
RETURN_VALUE_OBJ = "RETURN_VALUE" RETURN_VALUE_OBJ = "RETURN_VALUE"
ERROR_OBJ = "ERROR"
) )
type Object interface { type Object interface {
@@ -52,10 +53,22 @@ type ReturnValue struct {
Value Object Value Object
} }
func (rv ReturnValue) Type() ObjectType { func (rv *ReturnValue) Type() ObjectType {
return RETURN_VALUE_OBJ return RETURN_VALUE_OBJ
} }
func (rv ReturnValue) Inspect() string { func (rv *ReturnValue) Inspect() string {
return rv.Value.Inspect() return rv.Value.Inspect()
} }
type Error struct {
Message string
}
func (e *Error) Type() ObjectType {
return ERROR_OBJ
}
func (e *Error) Inspect() string {
return "Error: " + e.Message
}