diff --git a/ast/ast.go b/ast/ast.go
index fbbbf70..015c3ea 100644
--- a/ast/ast.go
+++ b/ast/ast.go
@@ -74,6 +74,15 @@ type PrefixExpression struct {
 	Right    Expression
 }
 
+// InfixExpression is a binary operator expression: Left Operator Right.
+// Token holds the operator token.
+type InfixExpression struct {
+	Token    token.Token
+	Left     Expression
+	Operator string
+	Right    Expression
+}
+
 func (ls *LetStatement) statementNode() {
 }
 func (ls *LetStatement) TokenLiteral() string {
@@ -157,3 +166,22 @@ func (pe *PrefixExpression) String() string {
 
 	return out.String()
 }
+
+// expressionNode marks InfixExpression as an expression node.
+func (ie *InfixExpression) expressionNode() {}
+// TokenLiteral returns the literal text of the expression's token.
+func (ie *InfixExpression) TokenLiteral() string {
+	return ie.Token.Literal
+}
+// String renders the expression fully parenthesized, e.g. "(a + b)".
+func (ie *InfixExpression) String() string {
+	var out bytes.Buffer
+
+	out.WriteString("(")
+	out.WriteString(ie.Left.String())
+	out.WriteString(" " + ie.Operator + " ")
+	out.WriteString(ie.Right.String())
+	out.WriteString(")")
+
+	return out.String()
+}
diff --git a/parser/parser.go b/parser/parser.go
index a42dd6a..2dbe71e 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -19,6 +19,19 @@ const (
 	CALL        // myFunction(X)
 )
 
+// precedences maps each infix operator token type to its binding power.
+// Token types without an entry are treated as LOWEST by the lookups.
+var precedences = map[token.TokenType]int{
+	token.EQ:       EQUALS,
+	token.NOT_EQ:   EQUALS,
+	token.LT:       LESSGREATER,
+	token.GT:       LESSGREATER,
+	token.PLUS:     SUM,
+	token.MINUS:    SUM,
+	token.SLASH:    PRODUCT,
+	token.ASTERISK: PRODUCT,
+}
+
 type Parser struct {
 	l      *lexer.Lexer
 	errors []string
@@ -52,6 +65,16 @@ func New(l *lexer.Lexer) *Parser {
 	p.registerPrefix(token.BANG, p.parsePrefixExpression)
 	p.registerPrefix(token.MINUS, p.parsePrefixExpression)
 
+	p.infixParseFns = make(map[token.TokenType]infixParseFn)
+	p.registerInfix(token.PLUS, p.parseInfixExpression)
+	p.registerInfix(token.MINUS, p.parseInfixExpression)
+	p.registerInfix(token.SLASH, p.parseInfixExpression)
+	p.registerInfix(token.ASTERISK, p.parseInfixExpression)
+	p.registerInfix(token.EQ, p.parseInfixExpression)
+	p.registerInfix(token.NOT_EQ, p.parseInfixExpression)
+	p.registerInfix(token.LT, p.parseInfixExpression)
+	p.registerInfix(token.GT, p.parseInfixExpression)
+
 	// Read two token, so curToken and peekToken are both set
 	p.nextToken()
 	p.nextToken()
@@ -170,6 +193,19 @@ func (p *Parser) parseExpression(precedence int) ast.Expression {
 	}
 	leftExp := prefix()
 
+	// Keep folding infix operators into leftExp while the upcoming operator
+	// binds tighter than the precedence this call was entered with.
+	for !p.peekTokenIs(token.SEMICOLON) && precedence < p.peekPrecedence() {
+		infix := p.infixParseFns[p.peekToken.Type]
+		if infix == nil {
+			return leftExp
+		}
+
+		p.nextToken()
+
+		leftExp = infix(leftExp)
+	}
+
 	return leftExp
 }
 
@@ -220,3 +256,41 @@ func (p *Parser) parsePrefixExpression() ast.Expression {
 
 	return expression
 }
+
+// peekPrecedence returns the binding power of the operator in peekToken, or
+// LOWEST when that token has no entry in precedences.
+func (p *Parser) peekPrecedence() int {
+	if prec, ok := precedences[p.peekToken.Type]; ok {
+		return prec
+	}
+
+	return LOWEST
+}
+
+// curPrecedence returns the binding power of the operator in curToken, or
+// LOWEST when that token has no entry in precedences.
+func (p *Parser) curPrecedence() int {
+	if prec, ok := precedences[p.curToken.Type]; ok {
+		return prec
+	}
+
+	return LOWEST
+}
+
+// parseInfixExpression builds an InfixExpression whose operator is curToken
+// and whose left operand is the already-parsed left expression.
+func (p *Parser) parseInfixExpression(left ast.Expression) ast.Expression {
+	expression := &ast.InfixExpression{
+		Token:    p.curToken,
+		Operator: p.curToken.Literal,
+		Left:     left,
+	}
+
+	// Capture the operator's precedence before advancing past it, so the
+	// right operand only absorbs operators that bind tighter.
+	precedence := p.curPrecedence()
+	p.nextToken()
+	expression.Right = p.parseExpression(precedence)
+
+	return expression
+}
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 69023ad..7c9ac89 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -8,49 +8,355 @@ import (
 )
 
 func TestLetStatements(t *testing.T) {
-	input := `
-return 5;
-return 10;
-return 993322;
-`
+	tests := []struct {
+		input              string
+		expectedIdentifier string
+		expectedValue      interface{}
+	}{
+		{"let x = 5;", "x", 5},
+		{"let y = true;", "y", true},
+		{"let foobar = y;", "foobar", "y"},
+	}
+
+	for _, tt := range tests {
+		l := lexer.New(tt.input)
+		p := New(l)
+		program := p.ParseProgram()
+		checkParserErrors(t, p)
+
+		if len(program.Statements) != 1 {
+			t.Fatalf("program.Statements does not contain 1 statements. got=%d",
+				len(program.Statements))
+		}
+
+		stmt := program.Statements[0]
+		if !testLetStatement(t, stmt, tt.expectedIdentifier) {
+			return
+		}
+
+		val := stmt.(*ast.LetStatement).Value
+		if !testLiteralExpression(t, val, tt.expectedValue) {
+			return
+		}
+	}
+}
+
+func TestReturnStatements(t *testing.T) {
+	tests := []struct {
+		input         string
+		expectedValue interface{}
+	}{
+		{"return 5;", 5},
+		{"return true;", true},
+		{"return foobar;", "foobar"},
+	}
+
+	for _, tt := range tests {
+		l := lexer.New(tt.input)
+		p := New(l)
+		program := p.ParseProgram()
+		checkParserErrors(t, p)
+
+		if len(program.Statements) != 1 {
+			t.Fatalf("program.Statements does not contain 1 statements. got=%d",
+				len(program.Statements))
+		}
+
+		stmt := program.Statements[0]
+		returnStmt, ok := stmt.(*ast.ReturnStatement)
+		if !ok {
+			t.Fatalf("stmt not *ast.ReturnStatement. got=%T", stmt)
+		}
+		if returnStmt.TokenLiteral() != "return" {
+			t.Fatalf("returnStmt.TokenLiteral not 'return', got %q",
+				returnStmt.TokenLiteral())
+		}
+		if !testLiteralExpression(t, returnStmt.ReturnValue, tt.expectedValue) {
+			return
+		}
+	}
+}
+
+func TestIdentifierExpression(t *testing.T) {
+	input := "foobar;"
+
 	l := lexer.New(input)
 	p := New(l)
-
 	program := p.ParseProgram()
 	checkParserErrors(t, p)
 
-	if len(program.Statements) != 3 {
-		t.Fatalf("program.Statements does not contain 3 statements. got=%d",
+	if len(program.Statements) != 1 {
+		t.Fatalf("program has not enough statements. got=%d",
 			len(program.Statements))
 	}
+	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
+	if !ok {
+		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
+			program.Statements[0])
+	}
 
-	for _, stmt := range program.Statements {
-		returnStmt, ok := stmt.(*ast.ReturnStatement)
+	ident, ok := stmt.Expression.(*ast.Identifier)
+	if !ok {
+		t.Fatalf("exp not *ast.Identifier. got=%T", stmt.Expression)
+	}
+	if ident.Value != "foobar" {
+		t.Errorf("ident.Value not %s. got=%s", "foobar", ident.Value)
+	}
+	if ident.TokenLiteral() != "foobar" {
+		t.Errorf("ident.TokenLiteral not %s. got=%s", "foobar",
+			ident.TokenLiteral())
+	}
+}
+
+func TestIntegerLiteralExpression(t *testing.T) {
+	input := "5;"
+
+	l := lexer.New(input)
+	p := New(l)
+	program := p.ParseProgram()
+	checkParserErrors(t, p)
+
+	if len(program.Statements) != 1 {
+		t.Fatalf("program has not enough statements. got=%d",
+			len(program.Statements))
+	}
+	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
+	if !ok {
+		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
+			program.Statements[0])
+	}
+
+	literal, ok := stmt.Expression.(*ast.IntegerLiteral)
+	if !ok {
+		t.Fatalf("exp not *ast.IntegerLiteral. got=%T", stmt.Expression)
+	}
+	if literal.Value != 5 {
+		t.Errorf("literal.Value not %d. got=%d", 5, literal.Value)
+	}
+	if literal.TokenLiteral() != "5" {
+		t.Errorf("literal.TokenLiteral not %s. got=%s", "5",
+			literal.TokenLiteral())
+	}
+}
+
+func TestParsingPrefixExpressions(t *testing.T) {
+	prefixTests := []struct {
+		input    string
+		operator string
+		value    interface{}
+	}{
+		{"!5;", "!", 5},
+		{"-15;", "-", 15},
+		{"!foobar;", "!", "foobar"},
+		{"-foobar;", "-", "foobar"},
+		{"!true;", "!", true},
+		{"!false;", "!", false},
+	}
+
+	for _, tt := range prefixTests {
+		l := lexer.New(tt.input)
+		p := New(l)
+		program := p.ParseProgram()
+		checkParserErrors(t, p)
+
+		if len(program.Statements) != 1 {
+			t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
+				1, len(program.Statements))
+		}
+
+		stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
 		if !ok {
-			t.Errorf("stmt not *ast.ReturnStatement. got=%T", stmt)
-			continue
+			t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
+				program.Statements[0])
 		}
-		if returnStmt.TokenLiteral() != "return" {
-			t.Errorf("returnStmt.TokenLiteral not 'return', got %q",
-				returnStmt.TokenLiteral())
+
+		exp, ok := stmt.Expression.(*ast.PrefixExpression)
+		if !ok {
+			t.Fatalf("stmt is not ast.PrefixExpression. got=%T", stmt.Expression)
+		}
+		if exp.Operator != tt.operator {
+			t.Fatalf("exp.Operator is not '%s'. got=%s",
+				tt.operator, exp.Operator)
+		}
+		if !testLiteralExpression(t, exp.Right, tt.value) {
+			return
 		}
 	}
 }
 
-func checkParserErrors(t *testing.T, p *Parser) {
-	errors := p.Errors()
-	if len(errors) == 0 {
-		return
+func TestParsingInfixExpressions(t *testing.T) {
+	infixTests := []struct {
+		input      string
+		leftValue  interface{}
+		operator   string
+		rightValue interface{}
+	}{
+		{"5 + 5;", 5, "+", 5},
+		{"5 - 5;", 5, "-", 5},
+		{"5 * 5;", 5, "*", 5},
+		{"5 / 5;", 5, "/", 5},
+		{"5 > 5;", 5, ">", 5},
+		{"5 < 5;", 5, "<", 5},
+		{"5 == 5;", 5, "==", 5},
+		{"5 != 5;", 5, "!=", 5},
+		{"foobar + barfoo;", "foobar", "+", "barfoo"},
+		{"foobar - barfoo;", "foobar", "-", "barfoo"},
+		{"foobar * barfoo;", "foobar", "*", "barfoo"},
+		{"foobar / barfoo;", "foobar", "/", "barfoo"},
+		{"foobar > barfoo;", "foobar", ">", "barfoo"},
+		{"foobar < barfoo;", "foobar", "<", "barfoo"},
+		{"foobar == barfoo;", "foobar", "==", "barfoo"},
+		{"foobar != barfoo;", "foobar", "!=", "barfoo"},
+		{"true == true", true, "==", true},
+		{"true != false", true, "!=", false},
+		{"false == false", false, "==", false},
 	}
 
-	t.Errorf("parser has %d errors", len(errors))
-	for _, msg := range errors {
-		t.Errorf("parser error %q", msg)
+	for _, tt := range infixTests {
+		l := lexer.New(tt.input)
+		p := New(l)
+		program := p.ParseProgram()
+		checkParserErrors(t, p)
+
+		if len(program.Statements) != 1 {
+			t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
+				1, len(program.Statements))
+		}
+
+		stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
+		if !ok {
+			t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
+				program.Statements[0])
+		}
+
+		if !testInfixExpression(t, stmt.Expression, tt.leftValue,
+			tt.operator, tt.rightValue) {
+			return
+		}
 	}
-	t.FailNow()
 }
 
-func testLetStatements(t *testing.T, s ast.Statement, name string) bool {
+func TestOperatorPrecedenceParsing(t *testing.T) {
+	tests := []struct {
+		input    string
+		expected string
+	}{
+		{
+			"-a * b",
+			"((-a) * b)",
+		},
+		{
+			"!-a",
+			"(!(-a))",
+		},
+		{
+			"a + b + c",
+			"((a + b) + c)",
+		},
+		{
+			"a + b - c",
+			"((a + b) - c)",
+		},
+		{
+			"a * b * c",
+			"((a * b) * c)",
+		},
+		{
+			"a * b / c",
+			"((a * b) / c)",
+		},
+		{
+			"a + b / c",
+			"(a + (b / c))",
+		},
+		{
+			"a + b * c + d / e - f",
+			"(((a + (b * c)) + (d / e)) - f)",
+		},
+		{
+			"3 + 4; -5 * 5",
+			"(3 + 4)((-5) * 5)",
+		},
+		{
+			"5 > 4 == 3 < 4",
+			"((5 > 4) == (3 < 4))",
+		},
+		{
+			"5 < 4 != 3 > 4",
+			"((5 < 4) != (3 > 4))",
+		},
+		{
+			"3 + 4 * 5 == 3 * 1 + 4 * 5",
+			"((3 + (4 * 5)) == ((3 * 1) + (4 * 5)))",
+		},
+		{
+			"true",
+			"true",
+		},
+		{
+			"false",
+			"false",
+		},
+		{
+			"3 > 5 == false",
+			"((3 > 5) == false)",
+		},
+		{
+			"3 < 5 == true",
+			"((3 < 5) == true)",
+		},
+		{
+			"1 + (2 + 3) + 4",
+			"((1 + (2 + 3)) + 4)",
+		},
+		{
+			"(5 + 5) * 2",
+			"((5 + 5) * 2)",
+		},
+		{
+			"2 / (5 + 5)",
+			"(2 / (5 + 5))",
+		},
+		{
+			"(5 + 5) * 2 * (5 + 5)",
+			"(((5 + 5) * 2) * (5 + 5))",
+		},
+		{
+			"-(5 + 5)",
+			"(-(5 + 5))",
+		},
+		{
+			"!(true == true)",
+			"(!(true == true))",
+		},
+		{
+			"a + add(b * c) + d",
+			"((a + add((b * c))) + d)",
+		},
+		{
+			"add(a, b, 1, 2 * 3, 4 + 5, add(6, 7 * 8))",
+			"add(a, b, 1, (2 * 3), (4 + 5), add(6, (7 * 8)))",
+		},
+		{
+			"add(a + b + c * d / f + g)",
+			"add((((a + b) + ((c * d) / f)) + g))",
+		},
+	}
+
+	for _, tt := range tests {
+		l := lexer.New(tt.input)
+		p := New(l)
+		program := p.ParseProgram()
+		checkParserErrors(t, p)
+
+		actual := program.String()
+		if actual != tt.expected {
+			t.Errorf("expected=%q, got=%q", tt.expected, actual)
+		}
+	}
+}
+
+func testLetStatement(t *testing.T, s ast.Statement, name string) bool {
 	if s.TokenLiteral() != "let" {
 		t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral())
 		return false
@@ -68,104 +374,54 @@ func testLetStatements(t *testing.T, s ast.Statement, name string) bool {
 	}
 
 	if letStmt.Name.TokenLiteral() != name {
-		t.Errorf("letStmt.Name.TokenLiteral not '%s'. got=%s", name, letStmt.Name.TokenLiteral())
+		t.Errorf("letStmt.Name.TokenLiteral() not '%s'. got=%s",
+			name, letStmt.Name.TokenLiteral())
 		return false
 	}
 
 	return true
 }
 
-func TestIdentifierExpressions(t *testing.T) {
-	input := "foobar;"
+func testInfixExpression(t *testing.T, exp ast.Expression, left interface{},
+	operator string, right interface{}) bool {
 
-	l := lexer.New(input)
-	p := New(l)
-	program := p.ParseProgram()
-	checkParserErrors(t, p)
-
-	if len(program.Statements) != 1 {
-		t.Fatalf("program has not enough statements. got=%d", len(program.Statements))
-	}
-	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
+	opExp, ok := exp.(*ast.InfixExpression)
 	if !ok {
-		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T", program.Statements[0])
+		t.Errorf("exp is not ast.InfixExpression. got=%T(%s)", exp, exp)
+		return false
 	}
 
-	ident, ok := stmt.Expression.(*ast.Identifier)
-	if !ok {
-		t.Fatalf("exp is not *ast.Identifier. got=%T", stmt.Expression)
+	if !testLiteralExpression(t, opExp.Left, left) {
+		return false
 	}
-	if ident.Value != "foobar" {
-		t.Errorf("ident.Value not %s. got=%s", "foobar", ident.Value)
+
+	if opExp.Operator != operator {
+		t.Errorf("exp.Operator is not '%s'. got=%q", operator, opExp.Operator)
+		return false
 	}
-	if ident.TokenLiteral() != "foobar" {
-		t.Errorf("ident.TokenLiteral not %s. got=%s", "foobar", ident.TokenLiteral())
+
+	if !testLiteralExpression(t, opExp.Right, right) {
+		return false
 	}
+
+	return true
 }
 
-func TestIntegerLiteralExpressions(t *testing.T) {
-	input := "5;"
-
-	l := lexer.New(input)
-	p := New(l)
-	program := p.ParseProgram()
-	checkParserErrors(t, p)
-
-	if len(program.Statements) != 1 {
-		t.Fatalf("program has not enough statements. got=%d", len(program.Statements))
-	}
-	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
-	if !ok {
-		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T", program.Statements[0])
-	}
-
-	literal, ok := stmt.Expression.(*ast.IntegerLiteral)
-	if !ok {
-		t.Fatalf("exp is not *ast.IntegerLiteral. got=%T", stmt.Expression)
-	}
-	if literal.Value != 5 {
-		t.Errorf("ident.Value not %d. got=%d", 5, literal.Value)
-	}
-	if literal.TokenLiteral() != "5" {
-		t.Errorf("ident.TokenLiteral not %s. got=%s", "5", literal.TokenLiteral())
-	}
-}
-
-func TestParsingPrefixExpressions(t *testing.T) {
-	prefixTests := []struct {
-		input        string
-		operator     string
-		integerValue int64
-	}{
-		{"!5;", "!", 5},
-		{"-15;", "-", 15},
-	}
-
-	for _, tt := range prefixTests {
-		l := lexer.New(tt.input)
-		p := New(l)
-		program := p.ParseProgram()
-		checkParserErrors(t, p)
-
-		if len(program.Statements) != 1 {
-			t.Fatalf("program.Statements does not contain %d statements. got=%d\n", 1, len(program.Statements))
-		}
-		stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
-		if !ok {
-			t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T", program.Statements[0])
-		}
-
-		exp, ok := stmt.Expression.(*ast.PrefixExpression)
-		if !ok {
-			t.Fatalf("exp is not *ast.PrefixExpression. got=%T", stmt.Expression)
-		}
-		if exp.Operator != tt.operator {
-			t.Fatalf("exp.Operator is not '%s'. got=%s", tt.operator, exp.Operator)
-		}
-		if !testIntegerLiteral(t, exp.Right, tt.integerValue) {
-			return
-		}
+func testLiteralExpression(
+	t *testing.T,
+	exp ast.Expression,
+	expected interface{},
+) bool {
+	switch v := expected.(type) {
+	case int:
+		return testIntegerLiteral(t, exp, int64(v))
+	case int64:
+		return testIntegerLiteral(t, exp, v)
+	case string:
+		return testIdentifier(t, exp, v)
 	}
+	t.Errorf("type of expected value not handled. got=%T", expected)
+	return false
 }
 
 func testIntegerLiteral(t *testing.T, il ast.Expression, value int64) bool {
@@ -181,9 +437,44 @@ func testIntegerLiteral(t *testing.T, il ast.Expression, value int64) bool {
 	}
 
 	if integ.TokenLiteral() != fmt.Sprintf("%d", value) {
-		t.Errorf("integ.TopkenLiteral not %d. got=%s", value, integ.TokenLiteral())
+		t.Errorf("integ.TokenLiteral not %d. got=%s", value,
+			integ.TokenLiteral())
 		return false
 	}
 
 	return true
 }
+
+func testIdentifier(t *testing.T, exp ast.Expression, value string) bool {
+	ident, ok := exp.(*ast.Identifier)
+	if !ok {
+		t.Errorf("exp not *ast.Identifier. got=%T", exp)
+		return false
+	}
+
+	if ident.Value != value {
+		t.Errorf("ident.Value not %s. got=%s", value, ident.Value)
+		return false
+	}
+
+	if ident.TokenLiteral() != value {
+		t.Errorf("ident.TokenLiteral not %s. got=%s", value,
+			ident.TokenLiteral())
+		return false
+	}
+
+	return true
+}
+
+func checkParserErrors(t *testing.T, p *Parser) {
+	errors := p.Errors()
+	if len(errors) == 0 {
+		return
+	}
+
+	t.Errorf("parser has %d errors", len(errors))
+	for _, msg := range errors {
+		t.Errorf("parser error: %q", msg)
+	}
+	t.FailNow()
+}
diff --git a/parser/parser_tracing.go b/parser/parser_tracing.go
new file mode 100644
index 0000000..5fc569b
--- /dev/null
+++ b/parser/parser_tracing.go
@@ -0,0 +1,46 @@
+package parser
+
+import (
+	"fmt"
+	"strings"
+)
+
+// traceLevel is the current nesting depth of trace/untrace calls.
+var traceLevel int = 0
+
+// traceIdentPlaceholder is the indentation unit emitted once per nesting level.
+const traceIdentPlaceholder string = "\t"
+
+// identLevel returns the indentation for the current trace depth. The
+// outermost traced call (depth 1) is not indented.
+func identLevel() string {
+	// strings.Repeat panics on a negative count, which would happen if
+	// tracePrint ran while no trace is active (traceLevel == 0).
+	if traceLevel < 1 {
+		return ""
+	}
+	return strings.Repeat(traceIdentPlaceholder, traceLevel-1)
+}
+
+// tracePrint writes fs to stdout, indented to the current trace depth.
+func tracePrint(fs string) {
+	fmt.Printf("%s%s\n", identLevel(), fs)
+}
+
+// incIdent and decIdent move the trace depth one level deeper or shallower.
+func incIdent() { traceLevel = traceLevel + 1 }
+func decIdent() { traceLevel = traceLevel - 1 }
+
+// trace increases the depth, prints a BEGIN line for msg, and returns
+// msg unchanged.
+func trace(msg string) string {
+	incIdent()
+	tracePrint("BEGIN " + msg)
+	return msg
+}
+
+// untrace prints the END line for msg and then decreases the depth.
+func untrace(msg string) {
+	tracePrint("END " + msg)
+	decIdent()
+}