Return statements

This commit is contained in:
Chuck Smith
2024-01-19 17:50:02 -05:00
parent e6d5567681
commit 7eba7471c8
3 changed files with 70 additions and 7 deletions

View File

@@ -16,17 +16,21 @@ func Eval(node ast.Node) object.Object {
// Statements // Statements
case *ast.Program: case *ast.Program:
return evalStatements(node.Statements) return evalProgram(node)
case *ast.ExpressionStatement: case *ast.ExpressionStatement:
return Eval(node.Expression) return Eval(node.Expression)
case *ast.BlockStatement: case *ast.BlockStatement:
return evalStatements(node.Statements) return evalBlockStatements(node)
case *ast.IfExpression: case *ast.IfExpression:
return evalIfExpression(node) return evalIfExpression(node)
case *ast.ReturnStatement:
val := Eval(node.ReturnValue)
return &object.ReturnValue{Value: val}
// Expressions // Expressions
case *ast.IntegerLiteral: case *ast.IntegerLiteral:
return &object.Integer{Value: node.Value} return &object.Integer{Value: node.Value}
@@ -48,11 +52,29 @@ func Eval(node ast.Node) object.Object {
return nil return nil
} }
func evalStatements(stmts []ast.Statement) object.Object { func evalProgram(program *ast.Program) object.Object {
var result object.Object var result object.Object
for _, statement := range stmts { for _, statement := range program.Statements {
result = Eval(statement) result = Eval(statement)
if returnValue, ok := result.(*object.ReturnValue); ok {
return returnValue.Value
}
}
return result
}
func evalBlockStatements(block *ast.BlockStatement) object.Object {
var result object.Object
for _, statement := range block.Statements {
result = Eval(statement)
if result != nil && result.Type() == object.RETURN_VALUE_OBJ {
return result
}
} }
return result return result

View File

@@ -111,6 +111,34 @@ func TestIfElseExpression(t *testing.T) {
} }
} }
func TestReturnStatements(t *testing.T) {
tests := []struct {
input string
expected int64
}{
{"return 10;", 10},
{"return 10; 9;", 10},
{"return 2 * 5; 9;", 10},
{"9; return 2 * 5; 9;", 10},
{`
if (10 > 1) {
if (10 > 1) {
return 10;
}
return 1;
}
`,
10,
},
}
for _, tt := range tests {
evaluated := testEval(tt.input)
testIntegerObject(t, evaluated, tt.expected)
}
}
func testEval(input string) object.Object { func testEval(input string) object.Object {
l := lexer.New(input) l := lexer.New(input)
p := parser.New(l) p := parser.New(l)

View File

@@ -8,6 +8,7 @@ const (
INTEGER_OBJ = "INTEGER" INTEGER_OBJ = "INTEGER"
BOOLEAN_OBJ = "BOOLEAN" BOOLEAN_OBJ = "BOOLEAN"
NULL_OBJ = "NULL" NULL_OBJ = "NULL"
RETURN_VALUE_OBJ = "RETURN_VALUE"
) )
type Object interface { type Object interface {
@@ -46,3 +47,15 @@ func (n *Null) Type() ObjectType {
func (n *Null) Inspect() string { func (n *Null) Inspect() string {
return "null" return "null"
} }
type ReturnValue struct {
Value Object
}
func (rv ReturnValue) Type() ObjectType {
return RETURN_VALUE_OBJ
}
func (rv ReturnValue) Inspect() string {
return rv.Value.Inspect()
}