diff --git a/ast/ast.go b/ast/ast.go index f8298b3..037afb2 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -379,3 +379,29 @@ func (we *WhileExpression) String() string { return out.String() } + +// AssignmentStatement the `=` statement represents the AST node that rebinds +// an expression to an identifier (assigning a new value). +type AssignmentStatement struct { + Token token.Token // the token.ASSIGN token + Name *Identifier + Value Expression +} + +func (as AssignmentStatement) TokenLiteral() string { + return as.Token.Literal +} + +func (as AssignmentStatement) String() string { + var out bytes.Buffer + + out.WriteString(as.Name.String()) + out.WriteString(as.TokenLiteral() + " ") + out.WriteString(as.Value.String()) + + out.WriteString(";") + + return out.String() +} + +func (as AssignmentStatement) statementNode() {} diff --git a/code/code.go b/code/code.go index 69316b7..15a0f53 100644 --- a/code/code.go +++ b/code/code.go @@ -27,6 +27,7 @@ const ( OpJumpNotTruthy OpJump OpNull + OpAssign OpGetGlobal OpSetGlobal OpArray @@ -66,6 +67,7 @@ var definitions = map[Opcode]*Definition{ OpJumpNotTruthy: {"OpJumpNotTruthy", []int{2}}, OpJump: {"OpJump", []int{2}}, OpNull: {"OpNull", []int{}}, + OpAssign: {"OpAssign", []int{}}, OpGetGlobal: {"OpGetGlobal", []int{2}}, OpSetGlobal: {"OpSetGlobal", []int{2}}, OpArray: {"OpArray", []int{2}}, diff --git a/compiler/compiler.go b/compiler/compiler.go index 53aa540..2744f9a 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -198,6 +198,25 @@ func (c *Compiler) Compile(node ast.Node) error { } } + case *ast.AssignmentStatement: + symbol, ok := c.symbolTable.Resolve(node.Name.Value) + if !ok { + return fmt.Errorf("undefined variable %s", node.Value) + } + + err := c.Compile(node.Value) + if err != nil { + return err + } + + if symbol.Scope == GlobalScope { + c.emit(code.OpGetGlobal, symbol.Index) + } else { + c.emit(code.OpGetLocal, symbol.Index) + } + + c.emit(code.OpAssign) + case *ast.LetStatement: symbol, ok := c.symbolTable.Resolve(node.Name.Value) if !ok { diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 143ffb9..9e74ed6 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -767,6 +767,102 @@ func TestFunctionCalls(t *testing.T) { runCompilerTests(t, tests) } +func TestAssignmentStatementScopes(t *testing.T) { + tests := []compilerTestCase{ + { + input: ` + let num = 0; + fn() { num = 55; } + `, + expectedConstants: []interface{}{ + 0, + 55, + []code.Instructions{ + code.Make(code.OpConstant, 1), + code.Make(code.OpGetGlobal, 0), + code.Make(code.OpAssign), + code.Make(code.OpReturn), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetGlobal, 0), + code.Make(code.OpClosure, 2, 0), + code.Make(code.OpPop), + }, + }, + + { + input: ` + fn() { + let num = 55; + num + } + `, + expectedConstants: []interface{}{ + 55, + []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetLocal, 0), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpClosure, 1, 0), + code.Make(code.OpPop), + }, + }, + { + input: ` + fn() { + let a = 55; + let b = 77; + a + b + } + `, + expectedConstants: []interface{}{ + 55, + 77, + []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetLocal, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpSetLocal, 1), + code.Make(code.OpGetLocal, 0), + code.Make(code.OpGetLocal, 1), + code.Make(code.OpAdd), + code.Make(code.OpReturnValue), + }, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpClosure, 2, 0), + code.Make(code.OpPop), + }, + }, + { + input: ` + let a = 0; + let a = a + 1; + `, + expectedConstants: []interface{}{ + 0, + 1, + }, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpSetGlobal, 0), + code.Make(code.OpGetGlobal, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpAdd), + code.Make(code.OpSetGlobal, 0), + }, + }, + } + + runCompilerTests(t, tests) +} + func TestLetStatementScopes(t *testing.T) { tests := []compilerTestCase{ { diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index b1b75d5..df6210b 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -45,6 +45,26 @@ func Eval(node ast.Node, env *object.Environment) object.Object { } return &object.ReturnValue{Value: val} + case *ast.AssignmentStatement: + ident := evalIdentifier(node.Name, env) + if isError(ident) { + return ident + } + + val := Eval(node.Value, env) + if isError(val) { + return val + } + + obj, ok := ident.(object.Mutable) + if !ok { + return newError("cannot assign to %s", ident.Type()) + } + + obj.Set(val) + + return val + case *ast.LetStatement: val := Eval(node.Value, env) if isError(val) { diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index d26aee4..e9e713d 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -238,6 +238,26 @@ func TestErrorHandling(t *testing.T) { } } +func TestAssignmentStatements(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"let a = 0; a = 5;", 5}, + {"let a = 0; a = 5; a;", 5}, + {"let a = 0; a = 5 * 5;", 25}, + {"let a = 0; a = 5 * 5; a;", 25}, + {"let a = 0; a = 5; let b = 0; b = a;", 5}, + {"let a = 0; a = 5; let b = 0; b = a; b;", 5}, + {"let a = 0; a = 5; let b = 0; b = a; let c = 0; c = a + b + 5;", 15}, + {"let a = 0; a = 5; let b = 0; b = a; let c = 0; c = a + b + 5; c;", 15}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } +} + func TestLetStatements(t *testing.T) { tests := []struct { input string @@ -560,10 +580,8 @@ func TestWhileExpressions(t *testing.T) { {"while (false) { }", nil}, {"let n = 0; while (n < 10) { let n = n + 1 }; n", 10}, {"let n = 10; while (n > 0) { let n = n - 1 }; n", 0}, - // FIXME: let is an expression statement and bind new values - // there is currently no assignment expressions :/ - {"let n = 0; while (n < 10) { let n = n + 1 }", nil}, - {"let n = 10; while (n > 0) { let n = n - 1 }", nil}, + {"let n = 0; while (n < 10) { n = n + 1 }", 10}, + {"let n = 10; while (n > 0) { n = n - 1 }", 0}, } for _, tt := range tests { diff --git a/examples/foo.monkey b/examples/foo.monkey new file mode 100644 index 0000000..efbe5d4 --- /dev/null +++ b/examples/foo.monkey @@ -0,0 +1,2 @@ +let a = 0; +let a = a + 1; \ No newline at end of file diff --git a/object/object.go b/object/object.go index 89eb6a2..9b6f841 100644 --- a/object/object.go +++ b/object/object.go @@ -26,6 +26,12 @@ const ( CLOSURE_OBJ = "CLOSURE" ) +// Mutable is the interface for all mutable objects which must implement +// the Set() method which rebinds its internal value for assignment statements +type Mutable interface { + Set(obj Object) +} + type Object interface { Type() ObjectType Inspect() string @@ -47,6 +53,9 @@ func (i *Integer) Type() ObjectType { func (i *Integer) Inspect() string { return fmt.Sprintf("%d", i.Value) } +func (i *Integer) Set(obj Object) { + i.Value = obj.(*Integer).Value +} type Boolean struct { Value bool @@ -55,10 +64,12 @@ type Boolean struct { func (b *Boolean) Type() ObjectType { return BOOLEAN_OBJ } - func (b *Boolean) Inspect() string { return fmt.Sprintf("%t", b.Value) } +func (i *Boolean) Set(obj Object) { + i.Value = obj.(*Boolean).Value +} type Null struct{} @@ -88,10 +99,12 @@ type Error struct { func (e *Error) Type() ObjectType { return ERROR_OBJ } - func (e *Error) Inspect() string { return "Error: " + e.Message } +func (e *Error) Set(obj Object) { + e.Message = obj.(*Error).Message +} type Function struct { Parameters []*ast.Identifier @@ -128,10 +141,12 @@ type String struct { func (s *String) Type() ObjectType { return STRING_OBJ } - func (s *String) Inspect() string { return s.Value } +func (s *String) Set(obj Object) { + s.Value = obj.(*String).Value +} type BuiltinFunction func(args ...Object) Object @@ -150,11 +165,10 @@ type Array struct { Elements []Object } -func (ao Array) Type() ObjectType { +func (ao *Array) Type() ObjectType { return ARRAY_OBJ } - -func (ao Array) Inspect() string { +func (ao *Array) Inspect() string { var out bytes.Buffer elements := []string{} @@ -168,6 +182,9 @@ func (ao Array) Inspect() string { return out.String() } +func (ao *Array) Set(obj Object) { + ao.Elements = obj.(*Array).Elements +} type HashKey struct { Type ObjectType @@ -213,7 +230,6 @@ type Hash struct { func (h *Hash) Type() ObjectType { return HASH_OBJ } - func (h *Hash) Inspect() string { var out bytes.Buffer @@ -228,6 +244,9 @@ func (h *Hash) Inspect() string { return out.String() } +func (h *Hash) Set(obj Object) { + h.Pairs = obj.(*Hash).Pairs +} func (cf *CompiledFunction) Type() ObjectType { return COMPILED_FUNCTION_OBJ diff --git a/parser/parser.go b/parser/parser.go index 5ce4316..cebd535 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -127,6 +127,11 @@ func (p *Parser) noPrefixParseFnError(t token.TokenType) { p.errors = append(p.errors, msg) } +func (p *Parser) noInfixParseFnError(t token.TokenType) { + msg := fmt.Sprintf("no infix parse function for %s found", t) + p.errors = append(p.errors, msg) +} + func (p *Parser) ParseProgram() *ast.Program { program := &ast.Program{} program.Statements = []ast.Statement{} @@ -143,6 +148,10 @@ func (p *Parser) ParseProgram() *ast.Program { } func (p *Parser) parseStatement() ast.Statement { + if p.peekToken.Type == token.ASSIGN { + return p.parseAssignmentStatement() + } + switch p.curToken.Type { case token.LET: return p.parseLetStatement() @@ -218,7 +227,8 @@ func (p *Parser) parseExpression(precedence int) ast.Expression { for !p.peekTokenIs(token.SEMICOLON) && precedence < p.peekPrecedence() { infix := p.infixParseFns[p.peekToken.Type] if infix == nil { - return leftExp + p.noInfixParseFnError(p.peekToken.Type) + //return leftExp } p.nextToken() @@ -516,3 +526,19 @@ func (p *Parser) parseWhileExpression() ast.Expression { return expression } + +func (p *Parser) parseAssignmentStatement() ast.Statement { + stmt := &ast.AssignmentStatement{Token: p.peekToken} + stmt.Name = &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal} + + p.nextToken() + p.nextToken() + + stmt.Value = p.parseExpression(LOWEST) + + if p.peekTokenIs(token.SEMICOLON) { + p.nextToken() + } + + return stmt +} diff --git a/parser/parser_test.go b/parser/parser_test.go index 92cb250..f340e22 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -921,6 +921,40 @@ func TestWhileExpression(t *testing.T) { } } +func TestAssignmentStatements(t *testing.T) { + tests := []struct { + input string + expectedIdentifier string + expectedValue interface{} + }{ + {"x = 5;", "x", 5}, + {"y = true;", "y", true}, + {"foobar = y;", "foobar", "y"}, + } + + for _, tt := range tests { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain 1 statements. got=%d", + len(program.Statements)) + } + + stmt := program.Statements[0] + if !testAssignmentStatement(t, stmt, tt.expectedIdentifier) { + return + } + + val := stmt.(*ast.AssignmentStatement).Value + if !testLiteralExpression(t, val, tt.expectedValue) { + return + } + } +} + func testLetStatement(t *testing.T, s ast.Statement, name string) bool { if s.TokenLiteral() != "let" { t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral()) @@ -1066,3 +1100,29 @@ func checkParserErrors(t *testing.T, p *Parser) { } t.FailNow() } + +func testAssignmentStatement(t *testing.T, s ast.Statement, name string) bool { + if s.TokenLiteral() != "=" { + t.Errorf("s.TokenLiteral not '='. got=%q", s.TokenLiteral()) + return false + } + + assignStmt, ok := s.(*ast.AssignmentStatement) + if !ok { + t.Errorf("s not *ast.AssignmentStatement. got=%T", s) + return false + } + + if assignStmt.Name.Value != name { + t.Errorf("assignStmt.Name.Value not '%s'. got=%s", name, assignStmt.Name.Value) + return false + } + + if assignStmt.Name.TokenLiteral() != name { + t.Errorf("assignStmt.Name.TokenLiteral() not '%s'. got=%s", + name, assignStmt.Name.TokenLiteral()) + return false + } + + return true +} diff --git a/vim/monkey.vim b/vim/monkey.vim index 945e6b7..6684ba6 100644 --- a/vim/monkey.vim +++ b/vim/monkey.vim @@ -18,7 +18,7 @@ syntax keyword xKeyword let fn if else return while syntax keyword xFunction len input print first last rest push pop exit syntax keyword xOperator == != < > ! -syntax keyword xOperator + - * / +syntax keyword xOperator + - * / = syntax region xString start=/"/ skip=/\\./ end=/"/ diff --git a/vm/vm.go b/vm/vm.go index 59fe46e..1baa48b 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -158,6 +158,17 @@ func (vm *VM) Run() error { vm.globals[globalIndex] = vm.pop() + case code.OpAssign: + ident := vm.pop() + val := vm.pop() + + obj, ok := ident.(object.Mutable) + if !ok { + return fmt.Errorf("cannot assign to %s", ident.Type()) + } + + obj.Set(val) + case code.OpGetGlobal: globalIndex := code.ReadUint16(ins[ip+1:]) vm.currentFrame().ip += 2 diff --git a/vm/vm_test.go b/vm/vm_test.go index 6708c6c..9bdfcc1 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -801,10 +801,21 @@ func TestIterations(t *testing.T) { {"while (false) { }", nil}, {"let n = 0; while (n < 10) { let n = n + 1 }; n", 10}, {"let n = 10; while (n > 0) { let n = n - 1 }; n", 0}, - // FIXME: let is an expression statement and bind new values - // there is currently no assignment expressions :/ - {"let n = 0; while (n < 10) { let n = n + 1 }", nil}, - {"let n = 10; while (n > 0) { let n = n - 1 }", nil}, + {"let n = 0; while (n < 10) { n = n + 1 }; n", 10}, + {"let n = 10; while (n > 0) { n = n - 1 }; n", 0}, + {"let n = 0; while (n < 10) { n = n + 1 }", nil}, + {"let n = 10; while (n > 0) { n = n - 1 }", nil}, + } + + runVmTests(t, tests) +} + +func TestAssignmentStatements(t *testing.T) { + tests := []vmTestCase{ + {"let one = 0; one = 1", 1}, + {"let one = 0; one = 1; one", 1}, + {"let one = 0; one = 1; let two = 0; two = 2; one + two", 3}, + {"let one = 0; one = 1; let two = 0; two = one + one; one + two", 3}, } runVmTests(t, tests)