Eval complete

This commit is contained in:
Chuck Smith
2024-01-20 11:16:56 -05:00
parent 581573486c
commit 10821fc88a
4 changed files with 183 additions and 4 deletions

View File

@@ -52,6 +52,11 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
case *ast.Identifier: case *ast.Identifier:
return evalIdentifier(node, env) return evalIdentifier(node, env)
case *ast.FunctionLiteral:
params := node.Parameters
body := node.Body
return &object.Function{Parameters: params, Env: env, Body: body}
// Expressions // Expressions
case *ast.IntegerLiteral: case *ast.IntegerLiteral:
return &object.Integer{Value: node.Value} return &object.Integer{Value: node.Value}
@@ -77,6 +82,18 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
} }
return evalInfixExpression(node.Operator, left, right) return evalInfixExpression(node.Operator, left, right)
case *ast.CallExpression:
function := Eval(node.Function, env)
if isError(function) {
return function
}
args := evalExpressions(node.Arguments, env)
if len(args) == 1 && isError(args[0]) {
return args[0]
}
return applyFunction(function, args)
} }
return nil return nil
@@ -237,3 +254,46 @@ func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object
} }
return val return val
} }
func evalExpressions(exps []ast.Expression, env *object.Environment) []object.Object {
var result []object.Object
for _, e := range exps {
evaluated := Eval(e, env)
if isError(evaluated) {
return []object.Object{evaluated}
}
result = append(result, evaluated)
}
return result
}
func applyFunction(fn object.Object, args []object.Object) object.Object {
function, ok := fn.(*object.Function)
if !ok {
newError("not a function: %s", fn.Type())
}
extendedEnv := extendFunctionEnv(function, args)
evaluated := Eval(function.Body, extendedEnv)
return unwrapReturnValue(evaluated)
}
func extendFunctionEnv(fn *object.Function, args []object.Object) *object.Environment {
env := object.NewEnclosedEnvironment(fn.Env)
for paramIdx, param := range fn.Parameters {
env.Set(param.Value, args[paramIdx])
}
return env
}
func unwrapReturnValue(obj object.Object) object.Object {
if returnValue, ok := obj.(*object.ReturnValue); ok {
return returnValue.Value
}
return obj
}

View File

@@ -120,7 +120,9 @@ func TestReturnStatements(t *testing.T) {
{"return 10; 9;", 10}, {"return 10; 9;", 10},
{"return 2 * 5; 9;", 10}, {"return 2 * 5; 9;", 10},
{"9; return 2 * 5; 9;", 10}, {"9; return 2 * 5; 9;", 10},
{` {"if (10 > 1) { return 10; }", 10},
{
`
if (10 > 1) { if (10 > 1) {
if (10 > 1) { if (10 > 1) {
return 10; return 10;
@@ -131,6 +133,25 @@ func TestReturnStatements(t *testing.T) {
`, `,
10, 10,
}, },
{
`
let f = fn(x) {
return x;
x + 10;
};
f(10);`,
10,
},
{
`
let f = fn(x) {
let result = x + 10;
return result;
return 10;
};
f(10);`,
20,
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -219,6 +240,60 @@ func TestLetStatements(t *testing.T) {
} }
} }
func TestFunctionObject(t *testing.T) {
input := "fn(x) { x + 2; };"
evaluated := testEval(input)
fn, ok := evaluated.(*object.Function)
if !ok {
t.Fatalf("object is not Function. got=%T (%+v)", evaluated, evaluated)
}
if len(fn.Parameters) != 1 {
t.Fatalf("function has wrong parameters. Parameters=%+v", fn.Parameters)
}
if fn.Parameters[0].String() != "x" {
t.Fatalf("parameter is not 'x'. got=%q", fn.Parameters[0])
}
expectedBody := "(x + 2)"
if fn.Body.String() != expectedBody {
t.Fatalf("body is not %q. got=%q", expectedBody, fn.Body.String())
}
}
func TestFunctionApplication(t *testing.T) {
tests := []struct {
input string
expected int64
}{
{"let identity = fn(x) { x; }; identity(5);", 5},
{"let identity = fn(x) { return x; }; identity(5);", 5},
{"let double = fn(x) { x * 2; }; double(5);", 10},
{"let add = fn(x, y) { x + y; }; add(5, 5);", 10},
{"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20},
{"fn(x) { x; }(5)", 5},
}
for _, tt := range tests {
testIntegerObject(t, testEval(tt.input), tt.expected)
}
}
func TestClosures(t *testing.T) {
input := `
let newAdder = fn(x) {
fn(y) { x + y };
};
let addTwo = newAdder(2);
addTwo(2);`
testIntegerObject(t, testEval(input), 4)
}
func testEval(input string) object.Object { func testEval(input string) object.Object {
l := lexer.New(input) l := lexer.New(input)
p := parser.New(l) p := parser.New(l)

View File

@@ -1,16 +1,26 @@
package object package object
func NewEnclosedEnvironment(outer *Environment) *Environment {
env := NewEnvironment()
env.outer = outer
return env
}
func NewEnvironment() *Environment { func NewEnvironment() *Environment {
s := make(map[string]Object) s := make(map[string]Object)
return &Environment{store: s} return &Environment{store: s, outer: nil}
} }
type Environment struct { type Environment struct {
store map[string]Object store map[string]Object
outer *Environment
} }
func (e *Environment) Get(name string) (Object, bool) { func (e *Environment) Get(name string) (Object, bool) {
obj, ok := e.store[name] obj, ok := e.store[name]
if !ok && e.outer != nil {
obj, ok = e.outer.Get(name)
}
return obj, ok return obj, ok
} }

View File

@@ -1,6 +1,11 @@
package object package object
import "fmt" import (
"bytes"
"fmt"
"monkey/ast"
"strings"
)
type ObjectType string type ObjectType string
@@ -10,6 +15,7 @@ const (
NULL_OBJ = "NULL" NULL_OBJ = "NULL"
RETURN_VALUE_OBJ = "RETURN_VALUE" RETURN_VALUE_OBJ = "RETURN_VALUE"
ERROR_OBJ = "ERROR" ERROR_OBJ = "ERROR"
FUNCTION_OBJ = "FUNCTION"
) )
type Object interface { type Object interface {
@@ -72,3 +78,31 @@ func (e *Error) Type() ObjectType {
func (e *Error) Inspect() string { func (e *Error) Inspect() string {
return "Error: " + e.Message return "Error: " + e.Message
} }
type Function struct {
Parameters []*ast.Identifier
Body *ast.BlockStatement
Env *Environment
}
func (f Function) Type() ObjectType {
return FUNCTION_OBJ
}
func (f Function) Inspect() string {
var out bytes.Buffer
params := []string{}
for _, p := range f.Parameters {
params = append(params, p.String())
}
out.WriteString("fn")
out.WriteString("(")
out.WriteString(strings.Join(params, ", "))
out.WriteString(") {\n")
out.WriteString(f.Body.String())
out.WriteString("\n}")
return out.String()
}