From 069b5ba8cfa3ba7276067961bce56b226a5d12d9 Mon Sep 17 00:00:00 2001 From: Chuck Smith Date: Mon, 22 Jan 2024 12:47:16 -0500 Subject: [PATCH] arrays + builtins --- ast/ast.go | 48 ++++++++++++++++++++++ evaluator/builtins.go | 80 +++++++++++++++++++++++++++++++++++++ evaluator/evaluator.go | 38 ++++++++++++++++++ evaluator/evaluator_test.go | 77 +++++++++++++++++++++++++++++++++++ lexer/lexer.go | 4 ++ lexer/lexer_test.go | 9 +++++ object/object.go | 24 +++++++++++ parser/parser.go | 25 +++++++++++- parser/parser_test.go | 54 +++++++++++++++++++++++++ token/token.go | 10 +++-- 10 files changed, 364 insertions(+), 5 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index eb667bc..c44caa9 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -277,3 +277,51 @@ func (sl StringLiteral) String() string { return sl.Token.Literal } func (sl StringLiteral) expressionNode() {} + +type ArrayLiteral struct { + Token token.Token // the '[' token + Elements []Expression +} + +func (al ArrayLiteral) TokenLiteral() string { + return al.Token.Literal +} +func (al ArrayLiteral) String() string { + var out bytes.Buffer + + elements := []string{} + for _, el := range al.Elements { + elements = append(elements, el.String()) + } + + out.WriteString("[") + out.WriteString(strings.Join(elements, ", ")) + out.WriteString("]") + + return out.String() +} +func (al ArrayLiteral) expressionNode() {} + +type IndexExpression struct { + Token token.Token // The [ token + Left Expression + Index Expression +} + +func (ie IndexExpression) TokenLiteral() string { + return ie.Token.Literal +} + +func (ie IndexExpression) String() string { + var out bytes.Buffer + + out.WriteString("(") + out.WriteString(ie.Left.String()) + out.WriteString("[") + out.WriteString(ie.Index.String()) + out.WriteString("])") + + return out.String() +} + +func (ie IndexExpression) expressionNode() {} diff --git a/evaluator/builtins.go b/evaluator/builtins.go index 989e7c9..e61e16d 100644 --- a/evaluator/builtins.go +++ b/evaluator/builtins.go @@ -10,6 +10,8 @@ var builtins = map[string]*object.Builtin{ } switch arg := args[0].(type) { + case *object.Array: + return &object.Integer{Value: int64(len(arg.Elements))} case *object.String: return &object.Integer{Value: int64(len(arg.Value))} default: @@ -18,4 +20,82 @@ var builtins = map[string]*object.Builtin{ } }, }, + + "first": &object.Builtin{ + Fn: func(args ...object.Object) object.Object { + if len(args) != 1 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } + if args[0].Type() != object.ARRAY_OBJ { + return newError("argument to `first` must be Array, got %s", args[0].Type()) + } + + arr := args[0].(*object.Array) + if len(arr.Elements) > 0 { + return arr.Elements[0] + } + + return NULL + }, + }, + + "last": &object.Builtin{ + Fn: func(args ...object.Object) object.Object { + if len(args) != 1 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } + if args[0].Type() != object.ARRAY_OBJ { + return newError("argument to `last` must be Array, got %s", args[0].Type()) + } + + arr := args[0].(*object.Array) + length := len(arr.Elements) + if length > 0 { + return arr.Elements[length-1] + } + + return NULL + }, + }, + + "rest": &object.Builtin{ + Fn: func(args ...object.Object) object.Object { + if len(args) != 1 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } + if args[0].Type() != object.ARRAY_OBJ { + return newError("argument to `rest` must be Array, got %s", args[0].Type()) + } + + arr := args[0].(*object.Array) + length := len(arr.Elements) + + newElements := make([]object.Object, length+1) + copy(newElements, arr.Elements) + newElements[length] = args[1] + + return &object.Array{Elements: newElements} + }, + }, + + "push": &object.Builtin{ + Fn: func(args ...object.Object) object.Object { + if len(args) != 1 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } + if args[0].Type() != object.ARRAY_OBJ { + return newError("argument to `push` must be Array, got %s", args[0].Type()) + } + + arr := args[0].(*object.Array) + length := len(arr.Elements) + if length > 0 { + newElements := make([]object.Object, length-1) + copy(newElements, arr.Elements[1:length]) + return &object.Array{Elements: newElements} + } + + return NULL + }, + }, } diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index b2f9213..1dc6400 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -97,6 +97,23 @@ func Eval(node ast.Node, env *object.Environment) object.Object { case *ast.StringLiteral: return &object.String{Value: node.Value} + case *ast.ArrayLiteral: + elements := evalExpressions(node.Elements, env) + if len(elements) == 1 && isError(elements[0]) { + return elements[0] + } + return &object.Array{Elements: elements} + + case *ast.IndexExpression: + left := Eval(node.Left, env) + if isError(left) { + return left + } + index := Eval(node.Index, env) + if isError(index) { + return index + } + return evalIndexExpression(left, index) } return nil @@ -323,3 +340,24 @@ func unwrapReturnValue(obj object.Object) object.Object { return obj } + +func evalIndexExpression(left, index object.Object) object.Object { + switch { + case left.Type() == object.ARRAY_OBJ && index.Type() == object.INTEGER_OBJ: + return evalArrayIndexExpression(left, index) + default: + return newError("index operator not supported: %s", left.Type()) + } +} + +func evalArrayIndexExpression(array, index object.Object) object.Object { + arrayObject := array.(*object.Array) + idx := index.(*object.Integer).Value + maxInx := int64(len(arrayObject.Elements) - 1) + + if idx < 0 || idx > maxInx { + return NULL + } + + return arrayObject.Elements[idx] +} diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 160258b..3dbdf55 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -357,6 +357,83 @@ func TestBuiltinFunction(t *testing.T) { } } +func TestArrayLiterals(t *testing.T) { + input := "[1, 2 * 2, 3 + 3]" + + evaluated := testEval(input) + result, ok := evaluated.(*object.Array) + if !ok { + t.Fatalf("object is not Array. got=%T (+%v)", evaluated, evaluated) + } + + if len(result.Elements) != 3 { + t.Fatalf("array has wrong num of elements. got=%d", evaluated) + } + + testIntegerObject(t, result.Elements[0], 1) + testIntegerObject(t, result.Elements[1], 4) + testIntegerObject(t, result.Elements[2], 6) +} + +func TestArrayIndexExpressions(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + { + "[1, 2, 3][0]", + 1, + }, + { + "[1, 2, 3][1]", + 2, + }, + { + "[1, 2, 3][2]", + 3, + }, + { + "let i = 0; [1][i];", + 1, + }, + { + "[1, 2, 3][1 + 1];", + 3, + }, + { + "let myArray = [1, 2, 3]; myArray[2];", + 3, + }, + { + "let myArray = [1, 2, 3]; myArray[0] + myArray[1] + myArray[2];", + 6, + }, + { + "let myArray = [1, 2, 3]; let i = myArray[0]; myArray[i]", + 2, + }, + { + "[1, 2, 3][3]", + nil, + }, + { + "[1, 2, 3][-1]", + nil, + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + + integer, ok := tt.expected.(int) + if ok { + testIntegerObject(t, evaluated, int64(integer)) + } else { + testNullObject(t, evaluated) + } + } +} + func testEval(input string) object.Object { l := lexer.New(input) p := parser.New(l) diff --git a/lexer/lexer.go b/lexer/lexer.go index 33b9203..150920d 100644 --- a/lexer/lexer.go +++ b/lexer/lexer.go @@ -94,6 +94,10 @@ func (l *Lexer) NextToken() token.Token { case '"': tok.Type = token.STRING tok.Literal = l.readString() + case '[': + tok = newToken(token.LBRACKET, l.ch) + case ']': + tok = newToken(token.RBRACKET, l.ch) default: if isLetter(l.ch) { tok.Literal = l.readIdentifier() diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index aa04c77..3e28dfe 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -27,6 +27,7 @@ func TestNextToken(t *testing.T) { 10 != 9; "foobar" "foo bar" + [1, 2]; ` tests := []struct { @@ -117,6 +118,14 @@ func TestNextToken(t *testing.T) { {token.STRING, "foobar"}, {token.STRING, "foo bar"}, + + {token.LBRACKET, "["}, + {token.INT, "1"}, + {token.COMMA, ","}, + {token.INT, "2"}, + {token.RBRACKET, "]"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, } diff --git a/object/object.go b/object/object.go index 6b1aa65..6455f92 100644 --- a/object/object.go +++ b/object/object.go @@ -18,6 +18,7 @@ const ( FUNCTION_OBJ = "FUNCTION" STRING_OBJ = "STRING" BUILTIN_OBJ = "BUILTIN" + ARRAY_OBJ = "ARRAY" ) type Object interface { @@ -133,3 +134,26 @@ func (b Builtin) Type() ObjectType { func (b Builtin) Inspect() string { return "builtin function" } + +type Array struct { + Elements []Object +} + +func (ao Array) Type() ObjectType { + return ARRAY_OBJ +} + +func (ao Array) Inspect() string { + var out bytes.Buffer + + elements := []string{} + for _, e := range ao.Elements { + elements = append(elements, e.Inspect()) + } + + out.WriteString("[") + out.WriteString(strings.Join(elements, ", ")) + out.WriteString("]") + + return out.String() +} diff --git a/parser/parser.go b/parser/parser.go index 5f58356..5a1e8e4 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -30,6 +30,7 @@ var precedences = map[token.TokenType]int{ token.SLASH: PRODUCT, token.ASTERISK: PRODUCT, token.LPAREN: CALL, + token.LBRACKET: INDEX, } type ( @@ -65,6 +66,7 @@ func New(l *lexer.Lexer) *Parser { p.registerPrefix(token.IF, p.parseIfExpression) p.registerPrefix(token.FUNCTION, p.parseFunctionLiteral) p.registerPrefix(token.STRING, p.parseStringLiteral) + p.registerPrefix(token.LBRACKET, p.parseArrayLiteral) p.infixParseFns = make(map[token.TokenType]infixParseFn) p.registerInfix(token.PLUS, p.parseInfixExpression) @@ -75,8 +77,8 @@ func New(l *lexer.Lexer) *Parser { p.registerInfix(token.NOT_EQ, p.parseInfixExpression) p.registerInfix(token.LT, p.parseInfixExpression) p.registerInfix(token.GT, p.parseInfixExpression) - p.registerInfix(token.LPAREN, p.parseCallExpression) + p.registerInfix(token.LBRACKET, p.parseIndexExpression) // Read two tokens, so curToken and peekToken are both set p.nextToken() @@ -435,3 +437,24 @@ func (p *Parser) registerInfix(tokenType token.TokenType, fn infixParseFn) { func (p *Parser) parseStringLiteral() ast.Expression { return &ast.StringLiteral{Token: p.curToken, Value: p.curToken.Literal} } + +func (p *Parser) parseArrayLiteral() ast.Expression { + array := &ast.ArrayLiteral{Token: p.curToken} + + array.Elements = p.parseExpressionList(token.RBRACKET) + + return array +} + +func (p *Parser) parseIndexExpression(left ast.Expression) ast.Expression { + exp := &ast.IndexExpression{Token: p.curToken, Left: left} + + p.nextToken() + exp.Index = p.parseExpression(LOWEST) + + if !p.expectPeek(token.RBRACKET) { + return nil + } + + return exp +} diff --git a/parser/parser_test.go b/parser/parser_test.go index 3b872f9..22e231e 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -341,6 +341,14 @@ func TestOperatorPrecedenceParsing(t *testing.T) { "add(a + b + c * d / f + g)", "add((((a + b) + ((c * d) / f)) + g))", }, + { + "a * [1, 2, 3, 4][b * c] * d", + "((a * ([1, 2, 3, 4][(b * c)])) * d)", + }, + { + "add(a * b[2], b[1], 2 * [1, 2][1])", + "add((a * (b[2])), (b[1]), (2 * ([1, 2][1])))", + }, } for _, tt := range tests { @@ -689,6 +697,52 @@ func TestStringLiteralExpression(t *testing.T) { } } +func TestParsingArrayLiterals(t *testing.T) { + input := "[1, 2 * 2, 3 + 3]" + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt := program.Statements[0].(*ast.ExpressionStatement) + array, ok := stmt.Expression.(*ast.ArrayLiteral) + if !ok { + t.Fatalf("exp not *ast.ArrayLiteral. got=%T", stmt.Expression) + } + + if len(array.Elements) != 3 { + t.Fatalf("len(array.Elements) not 3. got=%d", len(array.Elements)) + } + + testIntegerLiteral(t, array.Elements[0], 1) + testInfixExpression(t, array.Elements[1], 2, "*", 2) + testInfixExpression(t, array.Elements[2], 3, "+", 3) +} + +func TestParsingIndexExpressions(t *testing.T) { + input := "myArray[1 + 1]" + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt := program.Statements[0].(*ast.ExpressionStatement) + indexExp, ok := stmt.Expression.(*ast.IndexExpression) + if !ok { + t.Fatalf("exp not *ast.IndexExpression. got=%T", stmt.Expression) + } + + if !testIdentifier(t, indexExp.Left, "myArray") { + return + } + + if !testInfixExpression(t, indexExp.Index, 1, "+", 1) { + return + } +} + func testLetStatement(t *testing.T, s ast.Statement, name string) bool { if s.TokenLiteral() != "let" { t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral()) diff --git a/token/token.go b/token/token.go index fa5e6dc..b3937cf 100644 --- a/token/token.go +++ b/token/token.go @@ -34,10 +34,12 @@ const ( COMMA = "," SEMICOLON = ";" - LPAREN = "(" - RPAREN = ")" - LBRACE = "{" - RBRACE = "}" + LPAREN = "(" + RPAREN = ")" + LBRACE = "{" + RBRACE = "}" + LBRACKET = "[" + RBRACKET = "]" // Keywords FUNCTION = "FUNCTION"