diff --git a/ast/ast.go b/ast/ast.go index 0035e27..9bcf0ca 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -428,3 +428,25 @@ func (c *Comment) String() string { return out.String() } + +// ImportExpression represents an `import` expression and holds the name +// of the module being imported. +type ImportExpression struct { + Token token.Token // The 'import' token + Name Expression +} + +func (ie *ImportExpression) TokenLiteral() string { + return ie.Token.Literal +} +func (ie *ImportExpression) String() string { + var out bytes.Buffer + + out.WriteString(ie.TokenLiteral()) + out.WriteString("(") + out.WriteString(fmt.Sprintf("\"%s\"", ie.Name)) + out.WriteString(")") + + return out.String() +} +func (ie *ImportExpression) expressionNode() {} diff --git a/builtins/builtins.go b/builtins/builtins.go index 0e4186a..331eb5b 100644 --- a/builtins/builtins.go +++ b/builtins/builtins.go @@ -8,43 +8,43 @@ import ( // Builtins ... var Builtins = map[string]*object.Builtin{ - "len": {Name: "len", Fn: Len}, - "input": {Name: "input", Fn: Input}, - "print": {Name: "print", Fn: Print}, - "first": {Name: "first", Fn: First}, - "last": {Name: "last", Fn: Last}, - "rest": {Name: "rest", Fn: Rest}, - "push": {Name: "push", Fn: Push}, - "pop": {Name: "pop", Fn: Pop}, - "exit": {Name: "exit", Fn: Exit}, - "assert": {Name: "assert", Fn: Assert}, - "bool": {Name: "bool", Fn: Bool}, - "int": {Name: "int", Fn: Int}, - "str": {Name: "str", Fn: Str}, - "type": {Name: "type", Fn: TypeOf}, - "args": {Name: "args", Fn: Args}, - "lower": {Name: "lower", Fn: Lower}, - "upper": {Name: "upper", Fn: Upper}, - "join": {Name: "join", Fn: Join}, - "split": {Name: "split", Fn: Split}, - "find": {Name: "find", Fn: Find}, - "read": {Name: "read", Fn: Read}, - "write": {Name: "write", Fn: Write}, - "ffi": {Name: "ffi", Fn: FFI}, - "abs": {Name: "abs", Fn: Abs}, - "bin": {Name: "bin", Fn: Bin}, - "hex": {Name: "hex", Fn: Hex}, - "ord": {Name: "ord", Fn: Ord}, - "chr": {Name: "chr", Fn: Chr}, - "divmod": {Name: "divmod", Fn: Divmod}, - "hash": {Name: "hash", Fn: HashOf}, - "id": {Name: "id", Fn: IdOf}, - "oct": {Name: "oct", Fn: Oct}, - "pow": {Name: "pow", Fn: Pow}, - "min": {Name: "min", Fn: Min}, - "max": {Name: "max", Fn: Max}, - "sorted": {Name: "sorted", Fn: Sorted}, - "reversed": {Name: "reversed", Fn: Reversed}, + "len": {Name: "len", Fn: Len}, + "input": {Name: "input", Fn: Input}, + "print": {Name: "print", Fn: Print}, + "first": {Name: "first", Fn: First}, + "last": {Name: "last", Fn: Last}, + "rest": {Name: "rest", Fn: Rest}, + "push": {Name: "push", Fn: Push}, + "pop": {Name: "pop", Fn: Pop}, + "exit": {Name: "exit", Fn: Exit}, + "assert": {Name: "assert", Fn: Assert}, + "bool": {Name: "bool", Fn: Bool}, + "int": {Name: "int", Fn: Int}, + "str": {Name: "str", Fn: Str}, + "type": {Name: "type", Fn: TypeOf}, + "args": {Name: "args", Fn: Args}, + "lower": {Name: "lower", Fn: Lower}, + "upper": {Name: "upper", Fn: Upper}, + "join": {Name: "join", Fn: Join}, + "split": {Name: "split", Fn: Split}, + "find": {Name: "find", Fn: Find}, + "readfile": {Name: "readfile", Fn: ReadFile}, + "writefile": {Name: "writefile", Fn: WriteFile}, + "ffi": {Name: "ffi", Fn: FFI}, + "abs": {Name: "abs", Fn: Abs}, + "bin": {Name: "bin", Fn: Bin}, + "hex": {Name: "hex", Fn: Hex}, + "ord": {Name: "ord", Fn: Ord}, + "chr": {Name: "chr", Fn: Chr}, + "divmod": {Name: "divmod", Fn: Divmod}, + "hash": {Name: "hash", Fn: HashOf}, + "id": {Name: "id", Fn: IdOf}, + "oct": {Name: "oct", Fn: Oct}, + "pow": {Name: "pow", Fn: Pow}, + "min": {Name: "min", Fn: Min}, + "max": {Name: "max", Fn: Max}, + "sorted": {Name: "sorted", Fn: Sorted}, + "reversed": {Name: "reversed", Fn: Reversed}, } // BuiltinsIndex ... diff --git a/builtins/read.go b/builtins/readfile.go similarity index 83% rename from builtins/read.go rename to builtins/readfile.go index 759bba9..7a0cddd 100644 --- a/builtins/read.go +++ b/builtins/readfile.go @@ -6,10 +6,10 @@ import ( "monkey/typing" ) -// Read ... -func Read(args ...object.Object) object.Object { +// ReadFile ... +func ReadFile(args ...object.Object) object.Object { if err := typing.Check( - "read", args, + "readfile", args, typing.ExactArgs(1), typing.WithTypes(object.STRING_OBJ), ); err != nil { diff --git a/builtins/write.go b/builtins/writefile.go similarity index 84% rename from builtins/write.go rename to builtins/writefile.go index 0f907bb..66d0cda 100644 --- a/builtins/write.go +++ b/builtins/writefile.go @@ -6,10 +6,10 @@ import ( "monkey/typing" ) -// Write ... -func Write(args ...object.Object) object.Object { +// WriteFile ... +func WriteFile(args ...object.Object) object.Object { if err := typing.Check( - "write", args, + "writefile", args, typing.ExactArgs(2), typing.WithTypes(object.STRING_OBJ, object.STRING_OBJ), ); err != nil { diff --git a/code/code.go b/code/code.go index 8cc0229..af47be0 100644 --- a/code/code.go +++ b/code/code.go @@ -61,7 +61,7 @@ const ( OpClosure OpGetFree OpCurrentClosure - OpNoop + OpLoadModule ) type Definition struct { @@ -112,7 +112,7 @@ var definitions = map[Opcode]*Definition{ OpClosure: {"OpClosure", []int{2, 1}}, OpGetFree: {"OpGetFree", []int{1}}, OpCurrentClosure: {"OpCurrentClosure", []int{}}, - OpNoop: {"OpNoop", []int{}}, + OpLoadModule: {"OpLoadModule", []int{}}, } func Lookup(op byte) (*Definition, error) { diff --git a/compiler/compiler.go b/compiler/compiler.go index dbdedd0..9833e62 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -513,6 +513,16 @@ func (c *Compiler) Compile(node ast.Node) error { afterConsequencePos := c.emit(code.OpNull) c.changeOperand(jumpIfFalsePos, afterConsequencePos) + case *ast.ImportExpression: + c.l++ + err := c.Compile(node.Name) + if err != nil { + return err + } + c.l-- + + c.emit(code.OpLoadModule) + } return nil diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 8f874d4..cb12c27 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -1031,6 +1031,18 @@ func TestIteration(t *testing.T) { runCompilerTests(t, tests) } +func TestImportExpressions(t *testing.T) { + tests := []compilerTestCase2{ + { + input: `import("foo")`, + constants: []interface{}{"foo"}, + instructions: "0000 OpConstant 0\n0003 OpLoadModule\n0004 OpPop\n", + }, + } + + runCompilerTests2(t, tests) +} + func runCompilerTests(t *testing.T, tests []compilerTestCase) { t.Helper() diff --git a/compiler/symbol_table.go b/compiler/symbol_table.go index 3aabc81..659537b 100644 --- a/compiler/symbol_table.go +++ b/compiler/symbol_table.go @@ -19,7 +19,7 @@ type Symbol struct { type SymbolTable struct { Outer *SymbolTable - store map[string]Symbol + Store map[string]Symbol numDefinitions int FreeSymbols []Symbol @@ -34,7 +34,7 @@ func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable { func NewSymbolTable() *SymbolTable { s := make(map[string]Symbol) free := []Symbol{} - return &SymbolTable{store: s, FreeSymbols: free} + return &SymbolTable{Store: s, FreeSymbols: free} } func (s *SymbolTable) Define(name string) Symbol { @@ -45,13 +45,13 @@ func (s *SymbolTable) Define(name string) Symbol { symbol.Scope = LocalScope } - s.store[name] = symbol + s.Store[name] = symbol s.numDefinitions++ return symbol } func (s *SymbolTable) Resolve(name string) (Symbol, bool) { - obj, ok := s.store[name] + obj, ok := s.Store[name] if !ok && s.Outer != nil { obj, ok = s.Outer.Resolve(name) if !ok { @@ -72,7 +72,7 @@ func (s *SymbolTable) Resolve(name string) (Symbol, bool) { func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol { symbol := Symbol{Name: name, Index: index, Scope: BuiltinScope} - s.store[name] = symbol + s.Store[name] = symbol return symbol } @@ -82,12 +82,12 @@ func (s *SymbolTable) DefineFree(original Symbol) Symbol { symbol := Symbol{Name: original.Name, Index: len(s.FreeSymbols) - 1} symbol.Scope = FreeScope - s.store[original.Name] = symbol + s.Store[original.Name] = symbol return symbol } func (s *SymbolTable) DefineFunctionName(name string) Symbol { symbol := Symbol{Name: name, Index: 0, Scope: FunctionScope} - s.store[name] = symbol + s.Store[name] = symbol return symbol } diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 59ffd43..cde90b4 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -4,7 +4,11 @@ import ( "fmt" "monkey/ast" "monkey/builtins" + "monkey/lexer" "monkey/object" + "monkey/parser" + "monkey/utils" + "os" "strings" ) @@ -40,6 +44,9 @@ func Eval(node ast.Node, env *object.Environment) object.Object { case *ast.WhileExpression: return evalWhileExpression(node, env) + case *ast.ImportExpression: + return evalImportExpression(node, env) + case *ast.ReturnStatement: val := Eval(node.ReturnValue, env) if isError(val) { @@ -190,6 +197,22 @@ func Eval(node ast.Node, env *object.Environment) object.Object { return nil } +func evalImportExpression(ie *ast.ImportExpression, env *object.Environment) object.Object { + name := Eval(ie.Name, env) + if isError(name) { + return name + } + + if s, ok := name.(*object.String); ok { + attrs := EvalModule(s.Value) + if isError(attrs) { + return attrs + } + return &object.Module{Name: s.Value, Attrs: attrs} + } + return newError("ImportError: invalid import path '%s'", name) +} + func evalWhileExpression(we *ast.WhileExpression, env *object.Environment) object.Object { var result object.Object @@ -208,9 +231,9 @@ func evalWhileExpression(we *ast.WhileExpression, env *object.Environment) objec if result != nil { return result - } else { - return NULL } + + return NULL } func evalProgram(program *ast.Program, env *object.Environment) object.Object { @@ -485,6 +508,29 @@ func newError(format string, a ...interface{}) *object.Error { return &object.Error{Message: fmt.Sprintf(format, a...)} } +// EvalModule evaluates the named module and returns a *object.Module objec +func EvalModule(name string) object.Object { + filename := utils.FindModule(name) + + b, err := os.ReadFile(filename) + if err != nil { + return newError("IOError: error reading module '%s': %s", name, err) + } + + l := lexer.New(string(b)) + p := parser.New(l) + + module := p.ParseProgram() + if len(p.Errors()) != 0 { + return newError("ParseError: %s", p.Errors()) + } + + env := object.NewEnvironment() + Eval(module, env) + + return env.ExportedHash() +} + func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object { if val, ok := env.Get(node.Value); ok { return val @@ -557,11 +603,18 @@ func evalIndexExpression(left, index object.Object) object.Object { return evalArrayIndexExpression(left, index) case left.Type() == object.HASH_OBJ: return evalHashIndexExpression(left, index) + case left.Type() == object.MODULE_OBJ: + return EvalModuleIndexExpression(left, index) default: return newError("index operator not supported: %s", left.Type()) } } +func EvalModuleIndexExpression(module, index object.Object) object.Object { + moduleObject := module.(*object.Module) + return evalHashIndexExpression(moduleObject.Attrs, index) +} + func evalStringIndexExpression(str, index object.Object) object.Object { stringObject := str.(*object.String) idx := index.(*object.Integer).Value diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index b759c85..f05589b 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -2,14 +2,51 @@ package evaluator import ( "errors" + "github.com/stretchr/testify/assert" "monkey/lexer" "monkey/object" "monkey/parser" + "monkey/utils" "os" "path/filepath" "testing" ) +func assertEvaluated(t *testing.T, expected interface{}, actual object.Object) { + t.Helper() + + assert := assert.New(t) + + switch expected.(type) { + case nil: + if _, ok := actual.(*object.Null); ok { + assert.True(ok) + } else { + assert.Equal(expected, actual) + } + case int: + if i, ok := actual.(*object.Integer); ok { + assert.Equal(int64(expected.(int)), i.Value) + } else { + assert.Equal(expected, actual) + } + case error: + if e, ok := actual.(*object.Integer); ok { + assert.Equal(expected.(error).Error(), e.Value) + } else { + assert.Equal(expected, actual) + } + case string: + if s, ok := actual.(*object.String); ok { + assert.Equal(expected.(string), s.Value) + } else { + assert.Equal(expected, actual) + } + default: + t.Fatalf("unsupported type for expected got=%T", expected) + } +} + func TestEvalExpressions(t *testing.T) { tests := []struct { input string @@ -814,6 +851,37 @@ func testStringObject(t *testing.T, obj object.Object, expected string) bool { } return true } +func TestImportExpressions(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + {`mod := import("../testdata/mod"); mod.A`, 5}, + {`mod := import("../testdata/mod"); mod.Sum(2, 3)`, 5}, + {`mod := import("../testdata/mod"); mod.a`, nil}, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + assertEvaluated(t, tt.expected, evaluated) + } +} + +func TestImportSearchPaths(t *testing.T) { + utils.AddPath("../testdata") + + tests := []struct { + input string + expected interface{} + }{ + {`mod := import("mod"); mod.A`, 5}, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + assertEvaluated(t, tt.expected, evaluated) + } +} func TestExamples(t *testing.T) { matches, err := filepath.Glob("../examples/*.monkey") diff --git a/object/environment.go b/object/environment.go index 6f31070..83fcc1d 100644 --- a/object/environment.go +++ b/object/environment.go @@ -1,5 +1,7 @@ package object +import "unicode" + func NewEnclosedEnvironment(outer *Environment) *Environment { env := NewEnvironment() env.outer = outer @@ -28,3 +30,18 @@ func (e *Environment) Set(name string, val Object) Object { e.store[name] = val return val } + +// ExportedHash returns a new Hash with the names and values of every publicly +// exported binding in the environment. That is every binding that starts with a +// capital letter. This is used by the module import system to wrap up the +// evaluated module into an object. +func (e *Environment) ExportedHash() *Hash { + pairs := make(map[HashKey]HashPair) + for k, v := range e.store { + if unicode.IsUpper(rune(k[0])) { + s := &String{Value: k} + pairs[s.HashKey()] = HashPair{Key: s, Value: v} + } + } + return &Hash{Pairs: pairs} +} diff --git a/object/module.go b/object/module.go new file mode 100644 index 0000000..a787b86 --- /dev/null +++ b/object/module.go @@ -0,0 +1,29 @@ +package object + +import "fmt" + +// Module is the module type used to represent a collection of variabels. +type Module struct { + Name string + Attrs Object +} + +func (m Module) String() string { + return m.Inspect() +} + +func (m Module) Type() ObjectType { + return MODULE_OBJ +} + +func (m Module) Bool() bool { + return true +} + +func (m Module) Inspect() string { + return fmt.Sprintf("", m.Name) +} + +func (m Module) Compare(other Object) int { + return 1 +} diff --git a/object/object.go b/object/object.go index 0f25d5f..b50c68f 100644 --- a/object/object.go +++ b/object/object.go @@ -18,6 +18,7 @@ const ( HASH_OBJ = "hash" COMPILED_FUNCTION_OBJ = "COMPILED_FUNCTION" CLOSURE_OBJ = "closure" + MODULE_OBJ = "module" ) // Comparable is the interface for comparing two Object and their underlying diff --git a/parser/parser.go b/parser/parser.go index 1fd0b64..20702de 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -87,6 +87,7 @@ func New(l *lexer.Lexer) *Parser { p.registerPrefix(token.LPAREN, p.parseGroupedExpression) p.registerPrefix(token.IF, p.parseIfExpression) p.registerPrefix(token.FUNCTION, p.parseFunctionLiteral) + p.registerPrefix(token.IMPORT, p.parseImportExpression) p.registerPrefix(token.WHILE, p.parseWhileExpression) p.registerPrefix(token.STRING, p.parseStringLiteral) p.registerPrefix(token.LBRACKET, p.parseArrayLiteral) @@ -601,3 +602,20 @@ func (p *Parser) parseBindExpression(expression ast.Expression) ast.Expression { return be } + +func (p *Parser) parseImportExpression() ast.Expression { + expression := &ast.ImportExpression{Token: p.curToken} + + if !p.expectPeek(token.LPAREN) { + return nil + } + + p.nextToken() + expression.Name = p.parseExpression(LOWEST) + + if !p.expectPeek(token.RPAREN) { + return nil + } + + return expression +} diff --git a/parser/parser_test.go b/parser/parser_test.go index 9e5f60c..831d41e 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1227,3 +1227,23 @@ func testComment(t *testing.T, s ast.Statement, expected string) bool { return true } + +func TestParsingImportExpressions(t *testing.T) { + assert := assert.New(t) + + tests := []struct { + input string + expected string + }{ + {`import("mod")`, `import("mod")`}, + } + + for _, tt := range tests { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + assert.Equal(tt.expected, program.String()) + } +} diff --git a/repl/repl.go b/repl/repl.go index e3876e8..7f7a475 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "log" - "monkey/builtins" "monkey/compiler" "monkey/evaluator" "monkey/lexer" @@ -42,25 +41,6 @@ type Options struct { Interactive bool } -type VMState struct { - constants []object.Object - globals []object.Object - symbols *compiler.SymbolTable -} - -func NewVMState() *VMState { - symbolTable := compiler.NewSymbolTable() - for i, builtin := range builtins.BuiltinsIndex { - symbolTable.DefineBuiltin(i, builtin.Name) - } - - return &VMState{ - constants: []object.Object{}, - globals: make([]object.Object, vm.GlobalsSize), - symbols: symbolTable, - } -} - type REPL struct { user string args []string @@ -101,14 +81,14 @@ func (r *REPL) Eval(f io.Reader) (env *object.Environment) { // Exec parses, compiles and executes the program given by f and returns // the resulting virtual machine, any errors are printed to stderr -func (r *REPL) Exec(f io.Reader) (state *VMState) { +func (r *REPL) Exec(f io.Reader) (state *vm.VMState) { b, err := io.ReadAll(f) if err != nil { fmt.Fprintf(os.Stderr, "error reading source file: %s", err) return } - state = NewVMState() + state = vm.NewVMState() l := lexer.New(string(b)) p := parser.New(l) @@ -119,7 +99,7 @@ func (r *REPL) Exec(f io.Reader) (state *VMState) { return } - c := compiler.NewWithState(state.symbols, state.constants) + c := compiler.NewWithState(state.Symbols, state.Constants) c.Debug = r.opts.Debug err = c.Compile(program) if err != nil { @@ -128,9 +108,9 @@ func (r *REPL) Exec(f io.Reader) (state *VMState) { } code := c.Bytecode() - state.constants = code.Constants + state.Constants = code.Constants - machine := vm.NewWithGlobalState(code, state.globals) + machine := vm.NewWithState(code, state) machine.Debug = r.opts.Debug err = machine.Run() if err != nil { @@ -176,11 +156,11 @@ func (r *REPL) StartEvalLoop(in io.Reader, out io.Writer, env *object.Environmen } // StartExecLoop starts the REPL in a continious exec loop -func (r *REPL) StartExecLoop(in io.Reader, out io.Writer, state *VMState) { +func (r *REPL) StartExecLoop(in io.Reader, out io.Writer, state *vm.VMState) { scanner := bufio.NewScanner(in) if state == nil { - state = NewVMState() + state = vm.NewVMState() } for { @@ -201,7 +181,7 @@ func (r *REPL) StartExecLoop(in io.Reader, out io.Writer, state *VMState) { continue } - c := compiler.NewWithState(state.symbols, state.constants) + c := compiler.NewWithState(state.Symbols, state.Constants) c.Debug = r.opts.Debug err := c.Compile(program) if err != nil { @@ -210,9 +190,9 @@ func (r *REPL) StartExecLoop(in io.Reader, out io.Writer, state *VMState) { } code := c.Bytecode() - state.constants = code.Constants + state.Constants = code.Constants - machine := vm.NewWithGlobalState(code, state.globals) + machine := vm.NewWithState(code, state) machine.Debug = r.opts.Debug err = machine.Run() if err != nil { diff --git a/testdata/mod.monkey b/testdata/mod.monkey new file mode 100644 index 0000000..d0ddcc3 --- /dev/null +++ b/testdata/mod.monkey @@ -0,0 +1,3 @@ +a := 1 +A := 5 +Sum := fn(a, b) { return a + b } \ No newline at end of file diff --git a/token/token.go b/token/token.go index 66ad737..da6ff0e 100644 --- a/token/token.go +++ b/token/token.go @@ -84,6 +84,8 @@ const ( ELSE = "ELSE" RETURN = "RETURN" WHILE = "WHILE" + // IMPORT the `import` keyword (import) + IMPORT = "IMPORT" ) var keywords = map[string]TokenType{ @@ -95,6 +97,7 @@ var keywords = map[string]TokenType{ "else": ELSE, "return": RETURN, "while": WHILE, + "import": IMPORT, } func LookupIdent(ident string) TokenType { diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..32f4cc7 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,53 @@ +package utils + +import ( + "fmt" + "log" + "os" + "path/filepath" + "strings" +) + +var SearchPaths []string + +func init() { + cwd, err := os.Getwd() + if err != nil { + log.Fatalf("error getting cwd: %s", err) + } + + if e := os.Getenv("MONKEYPATH"); e != "" { + tokens := strings.Split(e, ":") + for _, token := range tokens { + AddPath(token) // ignore errors + } + } else { + SearchPaths = append(SearchPaths, cwd) + } +} + +func AddPath(path string) error { + path = os.ExpandEnv(filepath.Clean(path)) + absPath, err := filepath.Abs(path) + if err != nil { + return err + } + SearchPaths = append(SearchPaths, absPath) + return nil +} + +func Exists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func FindModule(name string) string { + basename := fmt.Sprintf("%s.monkey", name) + for _, p := range SearchPaths { + filename := filepath.Join(p, basename) + if Exists(filename) { + return filename + } + } + return "" +} diff --git a/vm/vm.go b/vm/vm.go index a59e51a..4a96ce2 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -2,12 +2,17 @@ package vm import ( "fmt" + "io/ioutil" "log" "monkey/builtins" "monkey/code" "monkey/compiler" + "monkey/lexer" "monkey/object" + "monkey/parser" + "monkey/utils" "strings" + "unicode" ) const StackSize = 2048 @@ -18,16 +23,88 @@ var Null = &object.Null{} var True = &object.Boolean{Value: true} var False = &object.Boolean{Value: false} +// ExecModule compiles the named module and returns a *object.Module object +func ExecModule(name string, state *VMState) (object.Object, error) { + filename := utils.FindModule(name) + if filename == "" { + return nil, fmt.Errorf("ImportError: no module named '%s'", name) + } + + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("IOError: error reading module '%s': %s", name, err) + } + + l := lexer.New(string(b)) + p := parser.New(l) + + module := p.ParseProgram() + if len(p.Errors()) != 0 { + return nil, fmt.Errorf("ParseError: %s", p.Errors()) + } + + c := compiler.NewWithState(state.Symbols, state.Constants) + err = c.Compile(module) + if err != nil { + return nil, fmt.Errorf("CompileError: %s", err) + } + + code := c.Bytecode() + state.Constants = code.Constants + + machine := NewWithState(code, state) + err = machine.Run() + if err != nil { + return nil, fmt.Errorf("RuntimeError: error loading module '%s'", err) + } + + return state.ExportedHash(), nil +} + +type VMState struct { + Constants []object.Object + Globals []object.Object + Symbols *compiler.SymbolTable +} + +func NewVMState() *VMState { + symbolTable := compiler.NewSymbolTable() + for i, builtin := range builtins.BuiltinsIndex { + symbolTable.DefineBuiltin(i, builtin.Name) + } + + return &VMState{ + Constants: []object.Object{}, + Globals: make([]object.Object, GlobalsSize), + Symbols: symbolTable, + } +} + +// exported binding in the vm state. That is every binding that starts with a +// capital letter. This is used by the module import system to wrap up the +// compiled and evaulated module into an object. +func (s *VMState) ExportedHash() *object.Hash { + pairs := make(map[object.HashKey]object.HashPair) + for name, symbol := range s.Symbols.Store { + if unicode.IsUpper(rune(name[0])) { + if symbol.Scope == compiler.GlobalScope { + obj := s.Globals[symbol.Index] + s := &object.String{Value: name} + pairs[s.HashKey()] = object.HashPair{Key: s, Value: obj} + } + } + } + return &object.Hash{Pairs: pairs} +} + type VM struct { Debug bool - constants []object.Object + state *VMState stack []object.Object sp int // Always points to the next value. Top of stack is stack[sp-1] - globals []object.Object - frames []*Frame framesIndex int } @@ -40,23 +117,37 @@ func New(bytecode *compiler.Bytecode) *VM { frames := make([]*Frame, MaxFrames) frames[0] = mainFrame + state := NewVMState() + state.Constants = bytecode.Constants + return &VM{ - constants: bytecode.Constants, + state: state, stack: make([]object.Object, StackSize), sp: 0, - globals: make([]object.Object, GlobalsSize), - frames: frames, framesIndex: 1, } } -func NewWithGlobalState(bytecode *compiler.Bytecode, s []object.Object) *VM { - vm := New(bytecode) - vm.globals = s - return vm +func NewWithState(bytecode *compiler.Bytecode, state *VMState) *VM { + mainFn := &object.CompiledFunction{Instructions: bytecode.Instructions} + mainClosure := &object.Closure{Fn: mainFn} + mainFrame := NewFrame(mainClosure, 0) + + frames := make([]*Frame, MaxFrames) + frames[0] = mainFrame + + return &VM{ + state: state, + + frames: frames, + framesIndex: 1, + + stack: make([]object.Object, StackSize), + sp: 0, + } } func (vm *VM) currentFrame() *Frame { @@ -73,6 +164,24 @@ func (vm *VM) popFrame() *Frame { return vm.frames[vm.framesIndex] } +func (vm *VM) loadModule(name object.Object) error { + s, ok := name.(*object.String) + if !ok { + return fmt.Errorf( + "TypeError: import() expected argument #1 to be `str` got `%s`", + name.Type(), + ) + } + + attrs, err := ExecModule(s.Value, vm.state) + if err != nil { + return err + } + + module := &object.Module{Name: s.Value, Attrs: attrs} + return vm.push(module) +} + func (vm *VM) LastPoppedStackElem() object.Object { return vm.stack[vm.sp] } @@ -108,7 +217,7 @@ func (vm *VM) Run() error { constIndex := code.ReadUint16(ins[ip+1:]) vm.currentFrame().ip += 2 - err := vm.push(vm.constants[constIndex]) + err := vm.push(vm.state.Constants[constIndex]) if err != nil { return err } @@ -185,9 +294,9 @@ func (vm *VM) Run() error { ref := vm.pop() if immutable, ok := ref.(object.Immutable); ok { - vm.globals[globalIndex] = immutable.Clone() + vm.state.Globals[globalIndex] = immutable.Clone() } else { - vm.globals[globalIndex] = ref + vm.state.Globals[globalIndex] = ref } err := vm.push(Null) @@ -198,7 +307,7 @@ func (vm *VM) Run() error { case code.OpAssignGlobal: globalIndex := code.ReadUint16(ins[ip+1:]) vm.currentFrame().ip += 2 - vm.globals[globalIndex] = vm.pop() + vm.state.Globals[globalIndex] = vm.pop() err := vm.push(Null) if err != nil { @@ -221,7 +330,7 @@ func (vm *VM) Run() error { globalIndex := code.ReadUint16(ins[ip+1:]) vm.currentFrame().ip += 2 - err := vm.push(vm.globals[globalIndex]) + err := vm.push(vm.state.Globals[globalIndex]) if err != nil { return err } @@ -359,6 +468,14 @@ func (vm *VM) Run() error { return err } + case code.OpLoadModule: + name := vm.pop() + + err := vm.loadModule(name) + if err != nil { + return err + } + } if vm.Debug { @@ -396,6 +513,8 @@ func (vm *VM) executeGetItem(left, index object.Object) error { return vm.executeArrayGetItem(left, index) case left.Type() == object.HASH_OBJ: return vm.executeHashGetItem(left, index) + case left.Type() == object.MODULE_OBJ: + return vm.executeHashGetItem(left.(*object.Module).Attrs, index) default: return fmt.Errorf( "index operator not supported: left=%s index=%s", @@ -766,7 +885,7 @@ func (vm *VM) callBuiltin(builtin *object.Builtin, numArgs int) error { } func (vm *VM) pushClosure(constIndex, numFree int) error { - constant := vm.constants[constIndex] + constant := vm.state.Constants[constIndex] function, ok := constant.(*object.CompiledFunction) if !ok { return fmt.Errorf("not a function %+v", constant) diff --git a/vm/vm_test.go b/vm/vm_test.go index c54137b..a87f7b7 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -7,6 +7,7 @@ import ( "monkey/lexer" "monkey/object" "monkey/parser" + "monkey/utils" "os" "path" "path/filepath" @@ -1127,3 +1128,35 @@ func BenchmarkFibonacci(b *testing.B) { }) } } + +func TestImportExpressions(t *testing.T) { + tests := []vmTestCase{ + { + input: `mod := import("../testdata/mod"); mod.A`, + expected: 5, + }, + { + input: `mod := import("../testdata/mod"); mod.Sum(2, 3)`, + expected: 5, + }, + { + input: `mod := import("../testdata/mod"); mod.a`, + expected: nil, + }, + } + + runVmTests(t, tests) +} + +func TestImportSearchPaths(t *testing.T) { + utils.AddPath("../testdata") + + tests := []vmTestCase{ + { + input: `mod := import("../testdata/mod"); mod.A`, + expected: 5, + }, + } + + runVmTests(t, tests) +}