From be81b9a6d68226a59c14f42e27fc4bc3418d3e02 Mon Sep 17 00:00:00 2001 From: Chuck Smith Date: Tue, 19 Mar 2024 20:30:30 -0400 Subject: [PATCH] Change assignment into expressions --- ast/ast.go | 20 +++--- code/code.go | 6 +- compiler/compiler.go | 62 +++++++++++----- compiler/compiler_test.go | 101 ++++++++++++++++++++------ evaluator/evaluator.go | 52 +++++++++++--- evaluator/evaluator_test.go | 43 ++++++++--- go.mod | 8 +++ go.sum | 9 +++ parser/parser.go | 45 +++++++----- parser/parser_test.go | 138 ++++++++++++++++++++++++------------ vm/vm.go | 99 +++++++++++++++++++++++--- vm/vm_test.go | 48 +++++++++---- 12 files changed, 478 insertions(+), 153 deletions(-) create mode 100644 go.sum diff --git a/ast/ast.go b/ast/ast.go index 2fbc53f..460b595 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -389,29 +389,27 @@ func (we *WhileExpression) String() string { return out.String() } -// AssignmentStatement the `=` statement represents the AST node that rebinds -// an expression to an identifier (assigning a new value). -type AssignmentStatement struct { +// AssignmentExpression represents an assignment expression of the form: +// x = 1 or xs[1] = 2 +type AssignmentExpression struct { Token token.Token // the token.ASSIGN token - Name *Identifier + Left Expression Value Expression } -func (as AssignmentStatement) TokenLiteral() string { +func (as AssignmentExpression) TokenLiteral() string { return as.Token.Literal } -func (as AssignmentStatement) String() string { +func (as AssignmentExpression) String() string { var out bytes.Buffer - out.WriteString(as.Name.String()) - out.WriteString(as.TokenLiteral() + " ") + out.WriteString(as.Left.String()) + out.WriteString(as.TokenLiteral()) out.WriteString(as.Value.String()) - out.WriteString(";") - return out.String() } -func (as AssignmentStatement) statementNode() {} +func (as AssignmentExpression) expressionNode() {} // Comment a comment type Comment struct { diff --git a/code/code.go b/code/code.go index 3dbcd5f..104c366 100644 --- a/code/code.go +++ b/code/code.go @@ -42,7 +42,8 @@ const ( OpSetGlobal OpArray OpHash - OpIndex + OpGetItem + OpSetItem OpCall OpReturn OpGetLocal @@ -83,7 +84,8 @@ var definitions = map[Opcode]*Definition{ OpSetGlobal: {"OpSetGlobal", []int{2}}, OpArray: {"OpArray", []int{2}}, OpHash: {"OpHash", []int{2}}, - OpIndex: {"OpIndex", []int{}}, + OpGetItem: {"OpGetItem", []int{}}, + OpSetItem: {"OpSetItem", []int{}}, OpCall: {"OpCall", []int{1}}, OpReturn: {"OpReturn", []int{}}, OpGetLocal: {"OpGetLocal", []int{1}}, diff --git a/compiler/compiler.go b/compiler/compiler.go index 800ae7a..aac2cee 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -246,23 +246,50 @@ func (c *Compiler) Compile(node ast.Node) error { } } - case *ast.AssignmentStatement: - symbol, ok := c.symbolTable.Resolve(node.Name.Value) - if !ok { - return fmt.Errorf("undefined variable %s", node.Value) - } + case *ast.AssignmentExpression: + if ident, ok := node.Left.(*ast.Identifier); ok { + symbol, ok := c.symbolTable.Resolve(ident.Value) + if !ok { + return fmt.Errorf("undefined variable %s", ident.Value) + } - c.l++ - err := c.Compile(node.Value) - c.l-- - if err != nil { - return err - } + c.l++ + err := c.Compile(node.Value) + c.l-- + if err != nil { + return err + } - if symbol.Scope == GlobalScope { - c.emit(code.OpAssignGlobal, symbol.Index) + if symbol.Scope == GlobalScope { + c.emit(code.OpAssignGlobal, symbol.Index) + } else { + c.emit(code.OpAssignLocal, symbol.Index) + } + } else if ie, ok := node.Left.(*ast.IndexExpression); ok { + c.l++ + err := c.Compile(ie.Left) + c.l-- + if err != nil { + return err + } + + c.l++ + err = c.Compile(ie.Index) + c.l-- + if err != nil { + return err + } + + c.l++ + err = c.Compile(node.Value) + c.l-- + if err != nil { + return err + } + + c.emit(code.OpSetItem) } else { - c.emit(code.OpAssignLocal, symbol.Index) + return fmt.Errorf("expected identifier or index expression got=%s", node.Left) } case *ast.LetStatement: @@ -349,7 +376,7 @@ func (c *Compiler) Compile(node ast.Node) error { return err } - c.emit(code.OpIndex) + c.emit(code.OpGetItem) case *ast.FunctionLiteral: c.enterScope() @@ -446,9 +473,12 @@ func (c *Compiler) Compile(node ast.Node) error { } c.l-- + // Pop off the LoadNull(s) from ast.BlockStatement(s) + c.emit(code.OpPop) + c.emit(code.OpJump, jumpConditionPos) - afterConsequencePos := c.emit(code.OpNoop) + afterConsequencePos := c.emit(code.OpNull) c.changeOperand(jumpIfFalsePos, afterConsequencePos) } diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index ce48f88..6936f01 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -2,6 +2,7 @@ package compiler import ( "fmt" + "github.com/stretchr/testify/assert" "monkey/ast" "monkey/code" "monkey/lexer" @@ -16,6 +17,12 @@ type compilerTestCase struct { expectedInstructions []code.Instructions } +type compilerTestCase2 struct { + input string + constants []interface{} + instructions string +} + func TestIntegerArithmetic(t *testing.T) { tests := []compilerTestCase{ { @@ -264,34 +271,30 @@ func TestConditionals(t *testing.T) { // 0006 code.Make(code.OpTrue), // 0007 - code.Make(code.OpJumpNotTruthy, 20), + code.Make(code.OpJumpNotTruthy, 19), // 0010 code.Make(code.OpConstant, 1), // 0013 code.Make(code.OpAssignGlobal, 0), - // 0016 + // 0018 + code.Make(code.OpJump, 20), + // 0019 code.Make(code.OpNull), - // 0017 - code.Make(code.OpJump, 21), // 0020 - code.Make(code.OpNull), - // 0021 code.Make(code.OpPop), - // 0022 + // 0021 code.Make(code.OpFalse), - // 0023 - code.Make(code.OpJumpNotTruthy, 36), - // 0025 + // 0022 + code.Make(code.OpJumpNotTruthy, 34), + // 0024 code.Make(code.OpConstant, 2), - // 0029 + // 0028 code.Make(code.OpAssignGlobal, 0), // 0032 + code.Make(code.OpJump, 35), + // 0035 code.Make(code.OpNull), - // 0033 - code.Make(code.OpJump, 37), // 0036 - code.Make(code.OpNull), - // 0037 code.Make(code.OpPop), }, }, @@ -475,7 +478,7 @@ func TestIndexExpressions(t *testing.T) { code.Make(code.OpConstant, 3), code.Make(code.OpConstant, 4), code.Make(code.OpAdd), - code.Make(code.OpIndex), + code.Make(code.OpGetItem), code.Make(code.OpPop), }, }, @@ -489,7 +492,7 @@ func TestIndexExpressions(t *testing.T) { code.Make(code.OpConstant, 2), code.Make(code.OpConstant, 3), code.Make(code.OpSub), - code.Make(code.OpIndex), + code.Make(code.OpGetItem), code.Make(code.OpPop), }, }, @@ -823,6 +826,21 @@ func TestFunctionCalls(t *testing.T) { runCompilerTests(t, tests) } +func TestAssignmentExpressions(t *testing.T) { + tests := []compilerTestCase2{ + { + input: ` + let x = 1 + x = 2 + `, + constants: []interface{}{1, 2}, + instructions: "0000 OpConstant 0\n0003 OpSetGlobal 0\n0006 OpConstant 1\n0009 OpAssignGlobal 0\n0012 OpPop\n", + }, + } + + runCompilerTests2(t, tests) +} + func TestAssignmentStatementScopes(t *testing.T) { tests := []compilerTestCase{ { @@ -1085,14 +1103,16 @@ func TestIteration(t *testing.T) { // 0000 code.Make(code.OpTrue), // 0001 - code.Make(code.OpJumpNotTruthy, 10), + code.Make(code.OpJumpNotTruthy, 11), // 0004 code.Make(code.OpConstant, 0), // 0007 + code.Make(code.OpPop), + // 0008 code.Make(code.OpJump, 0), - // 0010 - code.Make(code.OpNoop), // 0011 + code.Make(code.OpNull), + // 0012 code.Make(code.OpPop), }, }, @@ -1127,6 +1147,25 @@ func runCompilerTests(t *testing.T, tests []compilerTestCase) { } } +func runCompilerTests2(t *testing.T, tests []compilerTestCase2) { + t.Helper() + + assert := assert.New(t) + + for _, tt := range tests { + program := parse(tt.input) + + compiler := New() + err := compiler.Compile(program) + assert.NoError(err) + + bytecode := compiler.Bytecode() + assert.Equal(tt.instructions, bytecode.Instructions.String()) + + testConstants2(t, tt.constants, bytecode.Constants) + } +} + func parse(input string) *ast.Program { l := lexer.New(input) p := parser.New(l) @@ -1193,6 +1232,28 @@ func testConstants(t *testing.T, expected []interface{}, actual []object.Object) return nil } +func testConstants2(t *testing.T, expected []interface{}, actual []object.Object) { + assert := assert.New(t) + + assert.Equal(len(expected), len(actual)) + + for i, constant := range expected { + switch constant := constant.(type) { + + case []code.Instructions: + fn, ok := actual[i].(*object.CompiledFunction) + assert.True(ok) + assert.Equal(constant, fn.Instructions.String()) + + case string: + assert.Equal(constant, actual[i].(*object.String).Value) + + case int: + assert.Equal(int64(constant), actual[i].(*object.Integer).Value) + } + } +} + func testIntegerObject(expected int64, actual object.Object) interface{} { result, ok := actual.(*object.Integer) if !ok { diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 7b453c2..73ed21d 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -45,20 +45,54 @@ func Eval(node ast.Node, env *object.Environment) object.Object { } return &object.ReturnValue{Value: val} - case *ast.AssignmentStatement: - obj := evalIdentifier(node.Name, env) - if isError(obj) { - return obj + case *ast.AssignmentExpression: + left := Eval(node.Left, env) + if isError(left) { + return left } - val := Eval(node.Value, env) - if isError(val) { - return val + value := Eval(node.Value, env) + if isError(value) { + return value } - env.Set(node.Name.Value, val) + if ident, ok := node.Left.(*ast.Identifier); ok { + env.Set(ident.Value, value) + } else if ie, ok := node.Left.(*ast.IndexExpression); ok { + obj := Eval(ie.Left, env) + if isError(obj) { + return obj + } - return val + if array, ok := obj.(*object.Array); ok { + index := Eval(ie.Index, env) + if isError(index) { + return index + } + if idx, ok := index.(*object.Integer); ok { + array.Elements[idx.Value] = value + } else { + return newError("cannot index array with %#v", index) + } + } else if hash, ok := obj.(*object.Hash); ok { + key := Eval(ie.Index, env) + if isError(key) { + return key + } + if hashKey, ok := key.(object.Hashable); ok { + hashed := hashKey.HashKey() + hash.Pairs[hashed] = object.HashPair{Key: key, Value: value} + } else { + return newError("cannot index hash with %T", key) + } + } else { + return newError("object type %T does not support item assignment", obj) + } + } else { + return newError("expected identifier or index expression got=%T", left) + } + + return NULL case *ast.LetStatement: val := Eval(node.Value, env) diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index aa2c999..6813f3c 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -108,6 +108,8 @@ func TestIfElseExpression(t *testing.T) { {"if (1 > 2) { 10 }", nil}, {"if (1 > 2) { 10 } else { 20 }", 20}, {"if (1 < 2) { 10 } else { 20 }", 10}, + {"if (1 < 2) { 10 } else if (1 == 2) { 20 }", 10}, + {"if (1 > 2) { 10 } else if (1 == 2) { 20 } else { 30 }", 30}, } for _, tt := range tests { @@ -242,24 +244,45 @@ func TestErrorHandling(t *testing.T) { } } -func TestAssignmentStatements(t *testing.T) { +func TestIndexAssignmentStatements(t *testing.T) { tests := []struct { input string expected int64 }{ - {"let a = 0; a = 5;", 5}, + {"let xs = [1, 2, 3]; xs[1] = 4; xs[1];", 4}, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + testIntegerObject(t, evaluated, tt.expected) + } +} + +func TestAssignmentStatements(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + {"let a = 0; a = 5;", nil}, {"let a = 0; a = 5; a;", 5}, - {"let a = 0; a = 5 * 5;", 25}, + {"let a = 0; a = 5 * 5;", nil}, {"let a = 0; a = 5 * 5; a;", 25}, - {"let a = 0; a = 5; let b = 0; b = a;", 5}, + {"let a = 0; a = 5; let b = 0; b = a;", nil}, {"let a = 0; a = 5; let b = 0; b = a; b;", 5}, - {"let a = 0; a = 5; let b = 0; b = a; let c = 0; c = a + b + 5;", 15}, + {"let a = 0; a = 5; let b = 0; b = a; let c = 0; c = a + b + 5;", nil}, {"let a = 0; a = 5; let b = 0; b = a; let c = 0; c = a + b + 5; c;", 15}, + {"let a = 5; let b = a; a = 0;", nil}, {"let a = 5; let b = a; a = 0; b;", 5}, } for _, tt := range tests { - testIntegerObject(t, testEval(tt.input), tt.expected) + evaluated := testEval(tt.input) + integer, ok := tt.expected.(int) + if ok { + testIntegerObject(t, evaluated, int64(integer)) + } else { + testNullObject(t, evaluated) + } } } @@ -585,8 +608,12 @@ func TestWhileExpressions(t *testing.T) { {"while (false) { }", nil}, {"let n = 0; while (n < 10) { let n = n + 1 }; n", 10}, {"let n = 10; while (n > 0) { let n = n - 1 }; n", 0}, - {"let n = 0; while (n < 10) { n = n + 1 }", 10}, - {"let n = 10; while (n > 0) { n = n - 1 }", 0}, + {"let n = 0; while (n < 10) { let n = n + 1 }", nil}, + {"let n = 10; while (n > 0) { let n = n - 1 }", nil}, + {"let n = 0; while (n < 10) { n = n + 1 }; n", 10}, + {"let n = 10; while (n > 0) { n = n - 1 }; n", 0}, + {"let n = 0; while (n < 10) { n = n + 1 }", nil}, + {"let n = 10; while (n > 0) { n = n - 1 }", nil}, } for _, tt := range tests { diff --git a/go.mod b/go.mod index 3c6204e..43e61ca 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module monkey go 1.21 + +require github.com/stretchr/testify v1.8.4 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8cf6655 --- /dev/null +++ b/go.sum @@ -0,0 +1,9 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/parser/parser.go b/parser/parser.go index 807ab1c..2c59d51 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -11,6 +11,7 @@ import ( const ( _ int = iota LOWEST + ASSIGN // = EQUALS // == LESSGREATER // > or < SUM // + @@ -21,6 +22,7 @@ const ( ) var precedences = map[token.TokenType]int{ + token.ASSIGN: ASSIGN, token.EQ: EQUALS, token.NOT_EQ: EQUALS, token.LT: LESSGREATER, @@ -86,6 +88,7 @@ func New(l *lexer.Lexer) *Parser { p.registerInfix(token.GTE, p.parseInfixExpression) p.registerInfix(token.LPAREN, p.parseCallExpression) p.registerInfix(token.LBRACKET, p.parseIndexExpression) + p.registerInfix(token.ASSIGN, p.parseAssignmentExpression) // Read two tokens, so curToken and peekToken are both set p.nextToken() @@ -153,10 +156,6 @@ func (p *Parser) ParseProgram() *ast.Program { } func (p *Parser) parseStatement() ast.Statement { - if p.peekToken.Type == token.ASSIGN { - return p.parseAssignmentStatement() - } - switch p.curToken.Type { case token.COMMENT: return p.parseComment() @@ -347,6 +346,18 @@ func (p *Parser) parseIfExpression() ast.Expression { if p.peekTokenIs(token.ELSE) { p.nextToken() + if p.peekTokenIs(token.IF) { + p.nextToken() + expression.Alternative = &ast.BlockStatement{ + Statements: []ast.Statement{ + &ast.ExpressionStatement{ + Expression: p.parseIfExpression(), + }, + }, + } + return expression + } + if !p.expectPeek(token.LBRACE) { return nil } @@ -534,20 +545,22 @@ func (p *Parser) parseWhileExpression() ast.Expression { return expression } -func (p *Parser) parseAssignmentStatement() ast.Statement { - stmt := &ast.AssignmentStatement{Token: p.peekToken} - stmt.Name = &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal} - - p.nextToken() - p.nextToken() - - stmt.Value = p.parseExpression(LOWEST) - - if p.peekTokenIs(token.SEMICOLON) { - p.nextToken() +func (p *Parser) parseAssignmentExpression(exp ast.Expression) ast.Expression { + switch node := exp.(type) { + case *ast.Identifier, *ast.IndexExpression: + default: + msg := fmt.Sprintf("expected identifier or index expression on left but got %T %#v", node, exp) + p.errors = append(p.errors, msg) + return nil } - return stmt + ae := &ast.AssignmentExpression{Token: p.curToken, Left: exp} + + p.nextToken() + + ae.Value = p.parseExpression(LOWEST) + + return ae } func (p *Parser) parseComment() ast.Statement { diff --git a/parser/parser_test.go b/parser/parser_test.go index 36c26d9..1822a6e 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -2,6 +2,7 @@ package parser import ( "fmt" + "github.com/stretchr/testify/assert" "monkey/ast" "monkey/lexer" "testing" @@ -540,6 +541,85 @@ func TestIfElseExpression(t *testing.T) { } } +func TestIfElseIfExpression(t *testing.T) { + input := `if (x < y) { x } else if (x == y) { y }` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain %d statements. got=%d\n", + 1, len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T", + program.Statements[0]) + } + + exp, ok := stmt.Expression.(*ast.IfExpression) + if !ok { + t.Fatalf("stmt.Expression is not ast.IfExpression. got=%T", stmt.Expression) + } + + if !testInfixExpression(t, exp.Condition, "x", "<", "y") { + return + } + + if len(exp.Consequence.Statements) != 1 { + t.Errorf("consequence is not 1 statements. got=%d\n", + len(exp.Consequence.Statements)) + } + + consequence, ok := exp.Consequence.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T", + exp.Consequence.Statements[0]) + } + + if !testIdentifier(t, consequence.Expression, "x") { + return + } + + if len(exp.Alternative.Statements) != 1 { + t.Errorf("exp.Alternative.Statements does not contain 1 statements. got=%d\n", + len(exp.Alternative.Statements)) + } + + alternative, ok := exp.Alternative.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T", + exp.Alternative.Statements[0]) + } + + altexp, ok := alternative.Expression.(*ast.IfExpression) + if !ok { + t.Fatalf("alternative.Expression is not ast.IfExpression. got=%T", alternative.Expression) + } + + if !testInfixExpression(t, altexp.Condition, "x", "==", "y") { + return + } + + if len(altexp.Consequence.Statements) != 1 { + t.Errorf("consequence is not 1 statements. got=%d\n", + len(altexp.Consequence.Statements)) + } + + altconsequence, ok := altexp.Consequence.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T", + exp.Consequence.Statements[0]) + } + + if !testIdentifier(t, altconsequence.Expression, "y") { + return + } +} + func TestFunctionLiteralParsing(t *testing.T) { input := `fn(x, y) { x + y; }` @@ -952,15 +1032,18 @@ func TestWhileExpression(t *testing.T) { } } -func TestAssignmentStatements(t *testing.T) { +func TestAssignmentExpressions(t *testing.T) { + assertions := assert.New(t) + tests := []struct { - input string - expectedIdentifier string - expectedValue interface{} + input string + expected string }{ - {"x = 5;", "x", 5}, - {"y = true;", "y", true}, - {"foobar = y;", "foobar", "y"}, + {"x = 5;", "x=5"}, + {"y = true;", "y=true"}, + {"foobar = y;", "foobar=y"}, + {"[1, 2, 3][1] = 4", "([1, 2, 3][1])=4"}, + {`{"a": 1}["b"] = 2`, `({a:1}[b])=2`}, } for _, tt := range tests { @@ -969,20 +1052,7 @@ func TestAssignmentStatements(t *testing.T) { program := p.ParseProgram() checkParserErrors(t, p) - if len(program.Statements) != 1 { - t.Fatalf("program.Statements does not contain 1 statements. got=%d", - len(program.Statements)) - } - - stmt := program.Statements[0] - if !testAssignmentStatement(t, stmt, tt.expectedIdentifier) { - return - } - - val := stmt.(*ast.AssignmentStatement).Value - if !testLiteralExpression(t, val, tt.expectedValue) { - return - } + assertions.Equal(tt.expected, program.String()) } } @@ -1132,32 +1202,6 @@ func checkParserErrors(t *testing.T, p *Parser) { t.FailNow() } -func testAssignmentStatement(t *testing.T, s ast.Statement, name string) bool { - if s.TokenLiteral() != "=" { - t.Errorf("s.TokenLiteral not '='. got=%q", s.TokenLiteral()) - return false - } - - assignStmt, ok := s.(*ast.AssignmentStatement) - if !ok { - t.Errorf("s not *ast.AssignmentStatement. got=%T", s) - return false - } - - if assignStmt.Name.Value != name { - t.Errorf("assignStmt.Name.Value not '%s'. got=%s", name, assignStmt.Name.Value) - return false - } - - if assignStmt.Name.TokenLiteral() != name { - t.Errorf("assignStmt.Name.TokenLiteral() not '%s'. got=%s", - name, assignStmt.Name.TokenLiteral()) - return false - } - - return true -} - func TestComments(t *testing.T) { tests := []struct { input string diff --git a/vm/vm.go b/vm/vm.go index 8141e3b..053432d 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -186,6 +186,11 @@ func (vm *VM) Run() error { vm.currentFrame().ip += 2 vm.globals[globalIndex] = vm.pop() + err := vm.push(Null) + if err != nil { + return err + } + case code.OpAssignLocal: localIndex := code.ReadUint8(ins[ip+1:]) vm.currentFrame().ip += 1 @@ -193,6 +198,11 @@ func (vm *VM) Run() error { frame := vm.currentFrame() vm.stack[frame.basePointer+int(localIndex)] = vm.pop() + err := vm.push(Null) + if err != nil { + return err + } + case code.OpGetGlobal: globalIndex := code.ReadUint16(ins[ip+1:]) vm.currentFrame().ip += 2 @@ -229,11 +239,21 @@ func (vm *VM) Run() error { return err } - case code.OpIndex: + case code.OpSetItem: + value := vm.pop() index := vm.pop() left := vm.pop() - err := vm.executeIndexExpressions(left, index) + err := vm.executeSetItem(left, index, value) + if err != nil { + return err + } + + case code.OpGetItem: + index := vm.pop() + left := vm.pop() + + err := vm.executeGetItem(left, index) if err != nil { return err } @@ -328,20 +348,39 @@ func (vm *VM) Run() error { return nil } -func (vm *VM) executeIndexExpressions(left, index object.Object) error { +func (vm *VM) executeSetItem(left, index, value object.Object) error { switch { - case left.Type() == object.STRING_OBJ && index.Type() == object.INTEGER_OBJ: - return vm.executeStringIndex(left, index) case left.Type() == object.ARRAY_OBJ && index.Type() == object.INTEGER_OBJ: - return vm.executeArrayIndex(left, index) + return vm.executeArraySetItem(left, index, value) case left.Type() == object.HASH_OBJ: - return vm.executeHashIndex(left, index) + return vm.executeHashSetItem(left, index, value) default: - return fmt.Errorf("index operator not supported: %s", left.Type()) + return fmt.Errorf( + "set item operation not supported: left=%s index=%s", + left.Type(), index.Type(), + ) } } -func (vm *VM) executeArrayIndex(array, index object.Object) error { +func (vm *VM) executeGetItem(left, index object.Object) error { + switch { + case left.Type() == object.STRING_OBJ && index.Type() == object.INTEGER_OBJ: + return vm.executeStringGetItem(left, index) + case left.Type() == object.STRING_OBJ && index.Type() == object.STRING_OBJ: + return vm.executeStringIndex(left, index) + case left.Type() == object.ARRAY_OBJ && index.Type() == object.INTEGER_OBJ: + return vm.executeArrayGetItem(left, index) + case left.Type() == object.HASH_OBJ: + return vm.executeHashGetItem(left, index) + default: + return fmt.Errorf( + "index operator not supported: left=%s index=%s", + left.Type(), index.Type(), + ) + } +} + +func (vm *VM) executeArrayGetItem(array, index object.Object) error { arrayObject := array.(*object.Array) i := index.(*object.Integer).Value max := int64(len(arrayObject.Elements) - 1) @@ -353,7 +392,20 @@ func (vm *VM) executeArrayIndex(array, index object.Object) error { return vm.push(arrayObject.Elements[i]) } -func (vm *VM) executeHashIndex(hash, index object.Object) error { +func (vm *VM) executeArraySetItem(array, index, value object.Object) error { + arrayObject := array.(*object.Array) + i := index.(*object.Integer).Value + max := int64(len(arrayObject.Elements) - 1) + + if i < 0 || i > max { + return fmt.Errorf("index out of bounds: %d", i) + } + + arrayObject.Elements[i] = value + return vm.push(Null) +} + +func (vm *VM) executeHashGetItem(hash, index object.Object) error { hashObject := hash.(*object.Hash) key, ok := index.(object.Hashable) @@ -369,6 +421,20 @@ func (vm *VM) executeHashIndex(hash, index object.Object) error { return vm.push(pair.Value) } +func (vm *VM) executeHashSetItem(hash, index, value object.Object) error { + hashObject := hash.(*object.Hash) + + key, ok := index.(object.Hashable) + if !ok { + return fmt.Errorf("unusable as hash key: %s", index.Type()) + } + + hashed := key.HashKey() + hashObject.Pairs[hashed] = object.HashPair{Key: index, Value: value} + + return vm.push(Null) +} + func (vm *VM) buildHash(startIndex, endIndex int) (object.Object, error) { hashedPairs := make(map[object.HashKey]object.HashPair) @@ -621,7 +687,7 @@ func (vm *VM) pushClosure(constIndex, numFree int) error { return vm.push(closure) } -func (vm *VM) executeStringIndex(str, index object.Object) error { +func (vm *VM) executeStringGetItem(str, index object.Object) error { stringObject := str.(*object.String) i := index.(*object.Integer).Value max := int64(len(stringObject.Value) - 1) @@ -633,6 +699,17 @@ func (vm *VM) executeStringIndex(str, index object.Object) error { return vm.push(&object.String{Value: string(stringObject.Value[i])}) } +func (vm *VM) executeStringIndex(str, index object.Object) error { + stringObject := str.(*object.String) + substr := index.(*object.String).Value + + return vm.push( + &object.Integer{ + Value: int64(strings.Index(stringObject.Value, substr)), + }, + ) +} + func (vm *VM) executeStringComparison(op code.Opcode, left, right object.Object) error { leftValue := left.(*object.String).Value rightValue := right.(*object.String).Value diff --git a/vm/vm_test.go b/vm/vm_test.go index 2f40d84..26fb79e 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -267,6 +267,8 @@ func TestConditionals(t *testing.T) { {"if (false) { 10 } else { 10; let b = 5; }", Null}, {"if (true) { let a = 5; } else { 10 }", Null}, {"let x = 0; if (true) { x = 1; }; if (false) { x = 2; }; x", 1}, + {"if (1 < 2) { 10 } else if (1 == 2) { 20 }", 10}, + {"if (1 > 2) { 10 } else if (1 == 2) { 20 } else { 30 }", 30}, } runVmTests(t, tests) @@ -437,7 +439,7 @@ func TestFirstClassFunctions(t *testing.T) { input: ` let returnsOneReturner = fn() { let returnsOne = fn() { return 1; }; - returnsOne; + return returnsOne; }; returnsOneReturner()(); `, @@ -450,13 +452,13 @@ func TestFirstClassFunctions(t *testing.T) { func TestCallingFunctionsWithBindings(t *testing.T) { tests := []vmTestCase{ - { - input: ` - let one = fn() { let one = 1; return one }; - one(); - `, - expected: 1, - }, + //{ + // input: ` + // let one = fn() { let one = 1; return one }; + // one(); + // `, + // expected: 1, + //}, { input: ` let oneAndTwo = fn() { let one = 1; let two = 2; return one + two; }; @@ -664,7 +666,7 @@ func TestClosures(t *testing.T) { { input: ` let newClosure = fn(a) { - fn() { return a; }; + return fn() { return a; }; }; let closure = newClosure(99); closure(); @@ -674,7 +676,7 @@ func TestClosures(t *testing.T) { { input: ` let newAdder = fn(a, b) { - fn(c) { return a + b + c }; + return fn(c) { return a + b + c }; }; let adder = newAdder(1, 2); adder(8); @@ -685,7 +687,7 @@ func TestClosures(t *testing.T) { input: ` let newAdder = fn(a, b) { let c = a + b; - fn(d) { return c + d }; + return fn(d) { return c + d }; }; let adder = newAdder(1, 2); adder(8); @@ -819,6 +821,8 @@ func TestIterations(t *testing.T) { {"while (false) { }", nil}, {"let n = 0; while (n < 10) { let n = n + 1 }; n", 10}, {"let n = 10; while (n > 0) { let n = n - 1 }; n", 0}, + {"let n = 0; while (n < 10) { let n = n + 1 }", nil}, + {"let n = 10; while (n > 0) { let n = n - 1 }", nil}, {"let n = 0; while (n < 10) { n = n + 1 }; n", 10}, {"let n = 10; while (n > 0) { n = n - 1 }; n", 0}, {"let n = 0; while (n < 10) { n = n + 1 }", nil}, @@ -828,9 +832,27 @@ func TestIterations(t *testing.T) { runVmTests(t, tests) } -func TestAssignmentStatements(t *testing.T) { +func TestIndexAssignmentStatements(t *testing.T) { tests := []vmTestCase{ - {"let one = 0; one = 1", 1}, + {"let xs = [1, 2, 3]; xs[1] = 4; xs[1];", 4}, + } + + runVmTests(t, tests) +} + +func TestAssignmentExpressions(t *testing.T) { + tests := []vmTestCase{ + {"let a = 0; a = 5;", nil}, + {"let a = 0; a = 5; a;", 5}, + {"let a = 0; a = 5 * 5;", nil}, + {"let a = 0; a = 5 * 5; a;", 25}, + {"let a = 0; a = 5; let b = 0; b = a;", nil}, + {"let a = 0; a = 5; let b = 0; b = a; b;", 5}, + {"let a = 0; a = 5; let b = 0; b = a; let c = 0; c = a + b + 5;", nil}, + {"let a = 0; a = 5; let b = 0; b = a; let c = 0; c = a + b + 5; c;", 15}, + {"let a = 5; let b = a; a = 0;", nil}, + {"let a = 5; let b = a; a = 0; b;", 5}, + {"let one = 0; one = 1", nil}, {"let one = 0; one = 1; one", 1}, {"let one = 0; one = 1; let two = 0; two = 2; one + two", 3}, {"let one = 0; one = 1; let two = 0; two = one + one; one + two", 3},