From 10821fc88a7317bb0f601e7f14f278ad7ac23b5f Mon Sep 17 00:00:00 2001 From: Chuck Smith Date: Sat, 20 Jan 2024 11:16:56 -0500 Subject: [PATCH] Eval complete --- evaluator/evaluator.go | 60 ++++++++++++++++++++++++++++ evaluator/evaluator_test.go | 79 ++++++++++++++++++++++++++++++++++++- object/environment.go | 12 +++++- object/object.go | 36 ++++++++++++++++- 4 files changed, 183 insertions(+), 4 deletions(-) diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 1fcc816..679490d 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -52,6 +52,11 @@ func Eval(node ast.Node, env *object.Environment) object.Object { case *ast.Identifier: return evalIdentifier(node, env) + case *ast.FunctionLiteral: + params := node.Parameters + body := node.Body + return &object.Function{Parameters: params, Env: env, Body: body} + // Expressions case *ast.IntegerLiteral: return &object.Integer{Value: node.Value} @@ -77,6 +82,18 @@ func Eval(node ast.Node, env *object.Environment) object.Object { } return evalInfixExpression(node.Operator, left, right) + case *ast.CallExpression: + function := Eval(node.Function, env) + if isError(function) { + return function + } + args := evalExpressions(node.Arguments, env) + if len(args) == 1 && isError(args[0]) { + return args[0] + } + + return applyFunction(function, args) + } return nil @@ -237,3 +254,46 @@ func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object } return val } + +func evalExpressions(exps []ast.Expression, env *object.Environment) []object.Object { + var result []object.Object + + for _, e := range exps { + evaluated := Eval(e, env) + if isError(evaluated) { + return []object.Object{evaluated} + } + result = append(result, evaluated) + } + + return result +} + +func applyFunction(fn object.Object, args []object.Object) object.Object { + function, ok := fn.(*object.Function) + if !ok { + newError("not a function: %s", fn.Type()) + } + + extendedEnv := extendFunctionEnv(function, args) + evaluated := Eval(function.Body, extendedEnv) + return unwrapReturnValue(evaluated) +} + +func extendFunctionEnv(fn *object.Function, args []object.Object) *object.Environment { + env := object.NewEnclosedEnvironment(fn.Env) + + for paramIdx, param := range fn.Parameters { + env.Set(param.Value, args[paramIdx]) + } + + return env +} + +func unwrapReturnValue(obj object.Object) object.Object { + if returnValue, ok := obj.(*object.ReturnValue); ok { + return returnValue.Value + } + + return obj +} diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 569c32c..29baf42 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -120,7 +120,9 @@ func TestReturnStatements(t *testing.T) { {"return 10; 9;", 10}, {"return 2 * 5; 9;", 10}, {"9; return 2 * 5; 9;", 10}, - {` + {"if (10 > 1) { return 10; }", 10}, + { + ` if (10 > 1) { if (10 > 1) { return 10; @@ -128,9 +130,28 @@ func TestReturnStatements(t *testing.T) { return 1; } - `, + `, 10, }, + { + ` + let f = fn(x) { + return x; + x + 10; + }; + f(10);`, + 10, + }, + { + ` + let f = fn(x) { + let result = x + 10; + return result; + return 10; + }; + f(10);`, + 20, + }, } for _, tt := range tests { @@ -219,6 +240,60 @@ func TestLetStatements(t *testing.T) { } } +func TestFunctionObject(t *testing.T) { + input := "fn(x) { x + 2; };" + + evaluated := testEval(input) + fn, ok := evaluated.(*object.Function) + if !ok { + t.Fatalf("object is not Function. got=%T (%+v)", evaluated, evaluated) + } + + if len(fn.Parameters) != 1 { + t.Fatalf("function has wrong parameters. Parameters=%+v", fn.Parameters) + } + + if fn.Parameters[0].String() != "x" { + t.Fatalf("parameter is not 'x'. got=%q", fn.Parameters[0]) + } + + expectedBody := "(x + 2)" + + if fn.Body.String() != expectedBody { + t.Fatalf("body is not %q. got=%q", expectedBody, fn.Body.String()) + } +} + +func TestFunctionApplication(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"let identity = fn(x) { x; }; identity(5);", 5}, + {"let identity = fn(x) { return x; }; identity(5);", 5}, + {"let double = fn(x) { x * 2; }; double(5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5, 5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20}, + {"fn(x) { x; }(5)", 5}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } +} + +func TestClosures(t *testing.T) { + input := ` + let newAdder = fn(x) { + fn(y) { x + y }; + }; + + let addTwo = newAdder(2); + addTwo(2);` + + testIntegerObject(t, testEval(input), 4) +} + func testEval(input string) object.Object { l := lexer.New(input) p := parser.New(l) diff --git a/object/environment.go b/object/environment.go index bd89056..6f31070 100644 --- a/object/environment.go +++ b/object/environment.go @@ -1,16 +1,26 @@ package object +func NewEnclosedEnvironment(outer *Environment) *Environment { + env := NewEnvironment() + env.outer = outer + return env +} + func NewEnvironment() *Environment { s := make(map[string]Object) - return &Environment{store: s} + return &Environment{store: s, outer: nil} } type Environment struct { store map[string]Object + outer *Environment } func (e *Environment) Get(name string) (Object, bool) { obj, ok := e.store[name] + if !ok && e.outer != nil { + obj, ok = e.outer.Get(name) + } return obj, ok } diff --git a/object/object.go b/object/object.go index 9b8c6fd..32495d9 100644 --- a/object/object.go +++ b/object/object.go @@ -1,6 +1,11 @@ package object -import "fmt" +import ( + "bytes" + "fmt" + "monkey/ast" + "strings" +) type ObjectType string @@ -10,6 +15,7 @@ const ( NULL_OBJ = "NULL" RETURN_VALUE_OBJ = "RETURN_VALUE" ERROR_OBJ = "ERROR" + FUNCTION_OBJ = "FUNCTION" ) type Object interface { @@ -72,3 +78,31 @@ func (e *Error) Type() ObjectType { func (e *Error) Inspect() string { return "Error: " + e.Message } + +type Function struct { + Parameters []*ast.Identifier + Body *ast.BlockStatement + Env *Environment +} + +func (f Function) Type() ObjectType { + return FUNCTION_OBJ +} + +func (f Function) Inspect() string { + var out bytes.Buffer + + params := []string{} + for _, p := range f.Parameters { + params = append(params, p.String()) + } + + out.WriteString("fn") + out.WriteString("(") + out.WriteString(strings.Join(params, ", ")) + out.WriteString(") {\n") + out.WriteString(f.Body.String()) + out.WriteString("\n}") + + return out.String() +}