From 7eba7471c84436bc83d4c893599d96ce6acb180a Mon Sep 17 00:00:00 2001 From: Chuck Smith Date: Fri, 19 Jan 2024 17:50:02 -0500 Subject: [PATCH] Return statements --- evaluator/evaluator.go | 30 ++++++++++++++++++++++++++---- evaluator/evaluator_test.go | 28 ++++++++++++++++++++++++++++ object/object.go | 19 ++++++++++++++++--- 3 files changed, 70 insertions(+), 7 deletions(-) diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 8a2fbc4..6792b5b 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -16,17 +16,21 @@ func Eval(node ast.Node) object.Object { // Statements case *ast.Program: - return evalStatements(node.Statements) + return evalProgram(node) case *ast.ExpressionStatement: return Eval(node.Expression) case *ast.BlockStatement: - return evalStatements(node.Statements) + return evalBlockStatements(node) case *ast.IfExpression: return evalIfExpression(node) + case *ast.ReturnStatement: + val := Eval(node.ReturnValue) + return &object.ReturnValue{Value: val} + // Expressions case *ast.IntegerLiteral: return &object.Integer{Value: node.Value} @@ -48,11 +52,29 @@ func Eval(node ast.Node) object.Object { return nil } -func evalStatements(stmts []ast.Statement) object.Object { +func evalProgram(program *ast.Program) object.Object { var result object.Object - for _, statement := range stmts { + for _, statement := range program.Statements { result = Eval(statement) + + if returnValue, ok := result.(*object.ReturnValue); ok { + return returnValue.Value + } + } + + return result +} + +func evalBlockStatements(block *ast.BlockStatement) object.Object { + var result object.Object + + for _, statement := range block.Statements { + result = Eval(statement) + + if result != nil && result.Type() == object.RETURN_VALUE_OBJ { + return result + } } return result diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 33be284..4d58edf 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -111,6 +111,34 @@ func TestIfElseExpression(t *testing.T) { } } +func TestReturnStatements(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"return 10;", 10}, + {"return 10; 9;", 10}, + {"return 2 * 5; 9;", 10}, + {"9; return 2 * 5; 9;", 10}, + {` + if (10 > 1) { + if (10 > 1) { + return 10; + } + + return 1; + } + `, + 10, + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + testIntegerObject(t, evaluated, tt.expected) + } +} + func testEval(input string) object.Object { l := lexer.New(input) p := parser.New(l) diff --git a/object/object.go b/object/object.go index ea095bc..3f02ba5 100644 --- a/object/object.go +++ b/object/object.go @@ -5,9 +5,10 @@ import "fmt" type ObjectType string const ( - INTEGER_OBJ = "INTEGER" - BOOLEAN_OBJ = "BOOLEAN" - NULL_OBJ = "NULL" + INTEGER_OBJ = "INTEGER" + BOOLEAN_OBJ = "BOOLEAN" + NULL_OBJ = "NULL" + RETURN_VALUE_OBJ = "RETURN_VALUE" ) type Object interface { @@ -46,3 +47,15 @@ func (n *Null) Type() ObjectType { func (n *Null) Inspect() string { return "null" } + +type ReturnValue struct { + Value Object +} + +func (rv ReturnValue) Type() ObjectType { + return RETURN_VALUE_OBJ +} + +func (rv ReturnValue) Inspect() string { + return rv.Value.Inspect() +}