module support
Some checks failed
Build / build (push) Successful in 10m22s
Test / build (push) Failing after 15m54s

This commit is contained in:
Chuck Smith
2024-03-26 16:49:38 -04:00
parent 6d234099d1
commit 110152a139
21 changed files with 541 additions and 100 deletions

View File

@@ -428,3 +428,25 @@ func (c *Comment) String() string {
return out.String()
}
// ImportExpression represents an `import` expression and holds the name
// of the module being imported.
type ImportExpression struct {
Token token.Token // The 'import' token
Name Expression
}
func (ie *ImportExpression) TokenLiteral() string {
return ie.Token.Literal
}
func (ie *ImportExpression) String() string {
var out bytes.Buffer
out.WriteString(ie.TokenLiteral())
out.WriteString("(")
out.WriteString(fmt.Sprintf("\"%s\"", ie.Name))
out.WriteString(")")
return out.String()
}
func (ie *ImportExpression) expressionNode() {}

View File

@@ -8,43 +8,43 @@ import (
// Builtins ...
var Builtins = map[string]*object.Builtin{
"len": {Name: "len", Fn: Len},
"input": {Name: "input", Fn: Input},
"print": {Name: "print", Fn: Print},
"first": {Name: "first", Fn: First},
"last": {Name: "last", Fn: Last},
"rest": {Name: "rest", Fn: Rest},
"push": {Name: "push", Fn: Push},
"pop": {Name: "pop", Fn: Pop},
"exit": {Name: "exit", Fn: Exit},
"assert": {Name: "assert", Fn: Assert},
"bool": {Name: "bool", Fn: Bool},
"int": {Name: "int", Fn: Int},
"str": {Name: "str", Fn: Str},
"type": {Name: "type", Fn: TypeOf},
"args": {Name: "args", Fn: Args},
"lower": {Name: "lower", Fn: Lower},
"upper": {Name: "upper", Fn: Upper},
"join": {Name: "join", Fn: Join},
"split": {Name: "split", Fn: Split},
"find": {Name: "find", Fn: Find},
"read": {Name: "read", Fn: Read},
"write": {Name: "write", Fn: Write},
"ffi": {Name: "ffi", Fn: FFI},
"abs": {Name: "abs", Fn: Abs},
"bin": {Name: "bin", Fn: Bin},
"hex": {Name: "hex", Fn: Hex},
"ord": {Name: "ord", Fn: Ord},
"chr": {Name: "chr", Fn: Chr},
"divmod": {Name: "divmod", Fn: Divmod},
"hash": {Name: "hash", Fn: HashOf},
"id": {Name: "id", Fn: IdOf},
"oct": {Name: "oct", Fn: Oct},
"pow": {Name: "pow", Fn: Pow},
"min": {Name: "min", Fn: Min},
"max": {Name: "max", Fn: Max},
"sorted": {Name: "sorted", Fn: Sorted},
"reversed": {Name: "reversed", Fn: Reversed},
"len": {Name: "len", Fn: Len},
"input": {Name: "input", Fn: Input},
"print": {Name: "print", Fn: Print},
"first": {Name: "first", Fn: First},
"last": {Name: "last", Fn: Last},
"rest": {Name: "rest", Fn: Rest},
"push": {Name: "push", Fn: Push},
"pop": {Name: "pop", Fn: Pop},
"exit": {Name: "exit", Fn: Exit},
"assert": {Name: "assert", Fn: Assert},
"bool": {Name: "bool", Fn: Bool},
"int": {Name: "int", Fn: Int},
"str": {Name: "str", Fn: Str},
"type": {Name: "type", Fn: TypeOf},
"args": {Name: "args", Fn: Args},
"lower": {Name: "lower", Fn: Lower},
"upper": {Name: "upper", Fn: Upper},
"join": {Name: "join", Fn: Join},
"split": {Name: "split", Fn: Split},
"find": {Name: "find", Fn: Find},
"readfile": {Name: "readfile", Fn: ReadFile},
"writefile": {Name: "writefile", Fn: WriteFile},
"ffi": {Name: "ffi", Fn: FFI},
"abs": {Name: "abs", Fn: Abs},
"bin": {Name: "bin", Fn: Bin},
"hex": {Name: "hex", Fn: Hex},
"ord": {Name: "ord", Fn: Ord},
"chr": {Name: "chr", Fn: Chr},
"divmod": {Name: "divmod", Fn: Divmod},
"hash": {Name: "hash", Fn: HashOf},
"id": {Name: "id", Fn: IdOf},
"oct": {Name: "oct", Fn: Oct},
"pow": {Name: "pow", Fn: Pow},
"min": {Name: "min", Fn: Min},
"max": {Name: "max", Fn: Max},
"sorted": {Name: "sorted", Fn: Sorted},
"reversed": {Name: "reversed", Fn: Reversed},
}
// BuiltinsIndex ...

View File

@@ -6,10 +6,10 @@ import (
"monkey/typing"
)
// Read ...
func Read(args ...object.Object) object.Object {
// ReadFile ...
func ReadFile(args ...object.Object) object.Object {
if err := typing.Check(
"read", args,
"readfile", args,
typing.ExactArgs(1),
typing.WithTypes(object.STRING_OBJ),
); err != nil {

View File

@@ -6,10 +6,10 @@ import (
"monkey/typing"
)
// Write ...
func Write(args ...object.Object) object.Object {
// WriteFile ...
func WriteFile(args ...object.Object) object.Object {
if err := typing.Check(
"write", args,
"writefile", args,
typing.ExactArgs(2),
typing.WithTypes(object.STRING_OBJ, object.STRING_OBJ),
); err != nil {

View File

@@ -61,7 +61,7 @@ const (
OpClosure
OpGetFree
OpCurrentClosure
OpNoop
OpLoadModule
)
type Definition struct {
@@ -112,7 +112,7 @@ var definitions = map[Opcode]*Definition{
OpClosure: {"OpClosure", []int{2, 1}},
OpGetFree: {"OpGetFree", []int{1}},
OpCurrentClosure: {"OpCurrentClosure", []int{}},
OpNoop: {"OpNoop", []int{}},
OpLoadModule: {"OpLoadModule", []int{}},
}
func Lookup(op byte) (*Definition, error) {

View File

@@ -513,6 +513,16 @@ func (c *Compiler) Compile(node ast.Node) error {
afterConsequencePos := c.emit(code.OpNull)
c.changeOperand(jumpIfFalsePos, afterConsequencePos)
case *ast.ImportExpression:
c.l++
err := c.Compile(node.Name)
if err != nil {
return err
}
c.l--
c.emit(code.OpLoadModule)
}
return nil

View File

@@ -1031,6 +1031,18 @@ func TestIteration(t *testing.T) {
runCompilerTests(t, tests)
}
func TestImportExpressions(t *testing.T) {
tests := []compilerTestCase2{
{
input: `import("foo")`,
constants: []interface{}{"foo"},
instructions: "0000 OpConstant 0\n0003 OpLoadModule\n0004 OpPop\n",
},
}
runCompilerTests2(t, tests)
}
func runCompilerTests(t *testing.T, tests []compilerTestCase) {
t.Helper()

View File

@@ -19,7 +19,7 @@ type Symbol struct {
type SymbolTable struct {
Outer *SymbolTable
store map[string]Symbol
Store map[string]Symbol
numDefinitions int
FreeSymbols []Symbol
@@ -34,7 +34,7 @@ func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
func NewSymbolTable() *SymbolTable {
s := make(map[string]Symbol)
free := []Symbol{}
return &SymbolTable{store: s, FreeSymbols: free}
return &SymbolTable{Store: s, FreeSymbols: free}
}
func (s *SymbolTable) Define(name string) Symbol {
@@ -45,13 +45,13 @@ func (s *SymbolTable) Define(name string) Symbol {
symbol.Scope = LocalScope
}
s.store[name] = symbol
s.Store[name] = symbol
s.numDefinitions++
return symbol
}
func (s *SymbolTable) Resolve(name string) (Symbol, bool) {
obj, ok := s.store[name]
obj, ok := s.Store[name]
if !ok && s.Outer != nil {
obj, ok = s.Outer.Resolve(name)
if !ok {
@@ -72,7 +72,7 @@ func (s *SymbolTable) Resolve(name string) (Symbol, bool) {
func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol {
symbol := Symbol{Name: name, Index: index, Scope: BuiltinScope}
s.store[name] = symbol
s.Store[name] = symbol
return symbol
}
@@ -82,12 +82,12 @@ func (s *SymbolTable) DefineFree(original Symbol) Symbol {
symbol := Symbol{Name: original.Name, Index: len(s.FreeSymbols) - 1}
symbol.Scope = FreeScope
s.store[original.Name] = symbol
s.Store[original.Name] = symbol
return symbol
}
func (s *SymbolTable) DefineFunctionName(name string) Symbol {
symbol := Symbol{Name: name, Index: 0, Scope: FunctionScope}
s.store[name] = symbol
s.Store[name] = symbol
return symbol
}

View File

@@ -4,7 +4,11 @@ import (
"fmt"
"monkey/ast"
"monkey/builtins"
"monkey/lexer"
"monkey/object"
"monkey/parser"
"monkey/utils"
"os"
"strings"
)
@@ -40,6 +44,9 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
case *ast.WhileExpression:
return evalWhileExpression(node, env)
case *ast.ImportExpression:
return evalImportExpression(node, env)
case *ast.ReturnStatement:
val := Eval(node.ReturnValue, env)
if isError(val) {
@@ -190,6 +197,22 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
return nil
}
func evalImportExpression(ie *ast.ImportExpression, env *object.Environment) object.Object {
name := Eval(ie.Name, env)
if isError(name) {
return name
}
if s, ok := name.(*object.String); ok {
attrs := EvalModule(s.Value)
if isError(attrs) {
return attrs
}
return &object.Module{Name: s.Value, Attrs: attrs}
}
return newError("ImportError: invalid import path '%s'", name)
}
func evalWhileExpression(we *ast.WhileExpression, env *object.Environment) object.Object {
var result object.Object
@@ -208,9 +231,9 @@ func evalWhileExpression(we *ast.WhileExpression, env *object.Environment) objec
if result != nil {
return result
} else {
return NULL
}
return NULL
}
func evalProgram(program *ast.Program, env *object.Environment) object.Object {
@@ -485,6 +508,29 @@ func newError(format string, a ...interface{}) *object.Error {
return &object.Error{Message: fmt.Sprintf(format, a...)}
}
// EvalModule evaluates the named module and returns a *object.Module objec
func EvalModule(name string) object.Object {
filename := utils.FindModule(name)
b, err := os.ReadFile(filename)
if err != nil {
return newError("IOError: error reading module '%s': %s", name, err)
}
l := lexer.New(string(b))
p := parser.New(l)
module := p.ParseProgram()
if len(p.Errors()) != 0 {
return newError("ParseError: %s", p.Errors())
}
env := object.NewEnvironment()
Eval(module, env)
return env.ExportedHash()
}
func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object {
if val, ok := env.Get(node.Value); ok {
return val
@@ -557,11 +603,18 @@ func evalIndexExpression(left, index object.Object) object.Object {
return evalArrayIndexExpression(left, index)
case left.Type() == object.HASH_OBJ:
return evalHashIndexExpression(left, index)
case left.Type() == object.MODULE_OBJ:
return EvalModuleIndexExpression(left, index)
default:
return newError("index operator not supported: %s", left.Type())
}
}
func EvalModuleIndexExpression(module, index object.Object) object.Object {
moduleObject := module.(*object.Module)
return evalHashIndexExpression(moduleObject.Attrs, index)
}
func evalStringIndexExpression(str, index object.Object) object.Object {
stringObject := str.(*object.String)
idx := index.(*object.Integer).Value

View File

@@ -2,14 +2,51 @@ package evaluator
import (
"errors"
"github.com/stretchr/testify/assert"
"monkey/lexer"
"monkey/object"
"monkey/parser"
"monkey/utils"
"os"
"path/filepath"
"testing"
)
func assertEvaluated(t *testing.T, expected interface{}, actual object.Object) {
t.Helper()
assert := assert.New(t)
switch expected.(type) {
case nil:
if _, ok := actual.(*object.Null); ok {
assert.True(ok)
} else {
assert.Equal(expected, actual)
}
case int:
if i, ok := actual.(*object.Integer); ok {
assert.Equal(int64(expected.(int)), i.Value)
} else {
assert.Equal(expected, actual)
}
case error:
if e, ok := actual.(*object.Integer); ok {
assert.Equal(expected.(error).Error(), e.Value)
} else {
assert.Equal(expected, actual)
}
case string:
if s, ok := actual.(*object.String); ok {
assert.Equal(expected.(string), s.Value)
} else {
assert.Equal(expected, actual)
}
default:
t.Fatalf("unsupported type for expected got=%T", expected)
}
}
func TestEvalExpressions(t *testing.T) {
tests := []struct {
input string
@@ -814,6 +851,37 @@ func testStringObject(t *testing.T, obj object.Object, expected string) bool {
}
return true
}
func TestImportExpressions(t *testing.T) {
tests := []struct {
input string
expected interface{}
}{
{`mod := import("../testdata/mod"); mod.A`, 5},
{`mod := import("../testdata/mod"); mod.Sum(2, 3)`, 5},
{`mod := import("../testdata/mod"); mod.a`, nil},
}
for _, tt := range tests {
evaluated := testEval(tt.input)
assertEvaluated(t, tt.expected, evaluated)
}
}
func TestImportSearchPaths(t *testing.T) {
utils.AddPath("../testdata")
tests := []struct {
input string
expected interface{}
}{
{`mod := import("mod"); mod.A`, 5},
}
for _, tt := range tests {
evaluated := testEval(tt.input)
assertEvaluated(t, tt.expected, evaluated)
}
}
func TestExamples(t *testing.T) {
matches, err := filepath.Glob("../examples/*.monkey")

View File

@@ -1,5 +1,7 @@
package object
import "unicode"
func NewEnclosedEnvironment(outer *Environment) *Environment {
env := NewEnvironment()
env.outer = outer
@@ -28,3 +30,18 @@ func (e *Environment) Set(name string, val Object) Object {
e.store[name] = val
return val
}
// ExportedHash returns a new Hash with the names and values of every publicly
// exported binding in the environment. That is every binding that starts with a
// capital letter. This is used by the module import system to wrap up the
// evaluated module into an object.
func (e *Environment) ExportedHash() *Hash {
pairs := make(map[HashKey]HashPair)
for k, v := range e.store {
if unicode.IsUpper(rune(k[0])) {
s := &String{Value: k}
pairs[s.HashKey()] = HashPair{Key: s, Value: v}
}
}
return &Hash{Pairs: pairs}
}

29
object/module.go Normal file
View File

@@ -0,0 +1,29 @@
package object
import "fmt"
// Module is the module type used to represent a collection of variabels.
type Module struct {
Name string
Attrs Object
}
func (m Module) String() string {
return m.Inspect()
}
func (m Module) Type() ObjectType {
return MODULE_OBJ
}
func (m Module) Bool() bool {
return true
}
func (m Module) Inspect() string {
return fmt.Sprintf("<module '%s'>", m.Name)
}
func (m Module) Compare(other Object) int {
return 1
}

View File

@@ -18,6 +18,7 @@ const (
HASH_OBJ = "hash"
COMPILED_FUNCTION_OBJ = "COMPILED_FUNCTION"
CLOSURE_OBJ = "closure"
MODULE_OBJ = "module"
)
// Comparable is the interface for comparing two Object and their underlying

View File

@@ -87,6 +87,7 @@ func New(l *lexer.Lexer) *Parser {
p.registerPrefix(token.LPAREN, p.parseGroupedExpression)
p.registerPrefix(token.IF, p.parseIfExpression)
p.registerPrefix(token.FUNCTION, p.parseFunctionLiteral)
p.registerPrefix(token.IMPORT, p.parseImportExpression)
p.registerPrefix(token.WHILE, p.parseWhileExpression)
p.registerPrefix(token.STRING, p.parseStringLiteral)
p.registerPrefix(token.LBRACKET, p.parseArrayLiteral)
@@ -601,3 +602,20 @@ func (p *Parser) parseBindExpression(expression ast.Expression) ast.Expression {
return be
}
func (p *Parser) parseImportExpression() ast.Expression {
expression := &ast.ImportExpression{Token: p.curToken}
if !p.expectPeek(token.LPAREN) {
return nil
}
p.nextToken()
expression.Name = p.parseExpression(LOWEST)
if !p.expectPeek(token.RPAREN) {
return nil
}
return expression
}

View File

@@ -1227,3 +1227,23 @@ func testComment(t *testing.T, s ast.Statement, expected string) bool {
return true
}
func TestParsingImportExpressions(t *testing.T) {
assert := assert.New(t)
tests := []struct {
input string
expected string
}{
{`import("mod")`, `import("mod")`},
}
for _, tt := range tests {
l := lexer.New(tt.input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
assert.Equal(tt.expected, program.String())
}
}

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"io"
"log"
"monkey/builtins"
"monkey/compiler"
"monkey/evaluator"
"monkey/lexer"
@@ -42,25 +41,6 @@ type Options struct {
Interactive bool
}
type VMState struct {
constants []object.Object
globals []object.Object
symbols *compiler.SymbolTable
}
func NewVMState() *VMState {
symbolTable := compiler.NewSymbolTable()
for i, builtin := range builtins.BuiltinsIndex {
symbolTable.DefineBuiltin(i, builtin.Name)
}
return &VMState{
constants: []object.Object{},
globals: make([]object.Object, vm.GlobalsSize),
symbols: symbolTable,
}
}
type REPL struct {
user string
args []string
@@ -101,14 +81,14 @@ func (r *REPL) Eval(f io.Reader) (env *object.Environment) {
// Exec parses, compiles and executes the program given by f and returns
// the resulting virtual machine, any errors are printed to stderr
func (r *REPL) Exec(f io.Reader) (state *VMState) {
func (r *REPL) Exec(f io.Reader) (state *vm.VMState) {
b, err := io.ReadAll(f)
if err != nil {
fmt.Fprintf(os.Stderr, "error reading source file: %s", err)
return
}
state = NewVMState()
state = vm.NewVMState()
l := lexer.New(string(b))
p := parser.New(l)
@@ -119,7 +99,7 @@ func (r *REPL) Exec(f io.Reader) (state *VMState) {
return
}
c := compiler.NewWithState(state.symbols, state.constants)
c := compiler.NewWithState(state.Symbols, state.Constants)
c.Debug = r.opts.Debug
err = c.Compile(program)
if err != nil {
@@ -128,9 +108,9 @@ func (r *REPL) Exec(f io.Reader) (state *VMState) {
}
code := c.Bytecode()
state.constants = code.Constants
state.Constants = code.Constants
machine := vm.NewWithGlobalState(code, state.globals)
machine := vm.NewWithState(code, state)
machine.Debug = r.opts.Debug
err = machine.Run()
if err != nil {
@@ -176,11 +156,11 @@ func (r *REPL) StartEvalLoop(in io.Reader, out io.Writer, env *object.Environmen
}
// StartExecLoop starts the REPL in a continious exec loop
func (r *REPL) StartExecLoop(in io.Reader, out io.Writer, state *VMState) {
func (r *REPL) StartExecLoop(in io.Reader, out io.Writer, state *vm.VMState) {
scanner := bufio.NewScanner(in)
if state == nil {
state = NewVMState()
state = vm.NewVMState()
}
for {
@@ -201,7 +181,7 @@ func (r *REPL) StartExecLoop(in io.Reader, out io.Writer, state *VMState) {
continue
}
c := compiler.NewWithState(state.symbols, state.constants)
c := compiler.NewWithState(state.Symbols, state.Constants)
c.Debug = r.opts.Debug
err := c.Compile(program)
if err != nil {
@@ -210,9 +190,9 @@ func (r *REPL) StartExecLoop(in io.Reader, out io.Writer, state *VMState) {
}
code := c.Bytecode()
state.constants = code.Constants
state.Constants = code.Constants
machine := vm.NewWithGlobalState(code, state.globals)
machine := vm.NewWithState(code, state)
machine.Debug = r.opts.Debug
err = machine.Run()
if err != nil {

3
testdata/mod.monkey vendored Normal file
View File

@@ -0,0 +1,3 @@
a := 1
A := 5
Sum := fn(a, b) { return a + b }

View File

@@ -84,6 +84,8 @@ const (
ELSE = "ELSE"
RETURN = "RETURN"
WHILE = "WHILE"
// IMPORT the `import` keyword (import)
IMPORT = "IMPORT"
)
var keywords = map[string]TokenType{
@@ -95,6 +97,7 @@ var keywords = map[string]TokenType{
"else": ELSE,
"return": RETURN,
"while": WHILE,
"import": IMPORT,
}
func LookupIdent(ident string) TokenType {

53
utils/utils.go Normal file
View File

@@ -0,0 +1,53 @@
package utils
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
)
var SearchPaths []string
func init() {
cwd, err := os.Getwd()
if err != nil {
log.Fatalf("error getting cwd: %s", err)
}
if e := os.Getenv("MONKEYPATH"); e != "" {
tokens := strings.Split(e, ":")
for _, token := range tokens {
AddPath(token) // ignore errors
}
} else {
SearchPaths = append(SearchPaths, cwd)
}
}
func AddPath(path string) error {
path = os.ExpandEnv(filepath.Clean(path))
absPath, err := filepath.Abs(path)
if err != nil {
return err
}
SearchPaths = append(SearchPaths, absPath)
return nil
}
func Exists(path string) bool {
_, err := os.Stat(path)
return err == nil
}
func FindModule(name string) string {
basename := fmt.Sprintf("%s.monkey", name)
for _, p := range SearchPaths {
filename := filepath.Join(p, basename)
if Exists(filename) {
return filename
}
}
return ""
}

151
vm/vm.go
View File

@@ -2,12 +2,17 @@ package vm
import (
"fmt"
"io/ioutil"
"log"
"monkey/builtins"
"monkey/code"
"monkey/compiler"
"monkey/lexer"
"monkey/object"
"monkey/parser"
"monkey/utils"
"strings"
"unicode"
)
const StackSize = 2048
@@ -18,16 +23,88 @@ var Null = &object.Null{}
var True = &object.Boolean{Value: true}
var False = &object.Boolean{Value: false}
// ExecModule compiles the named module and returns a *object.Module object
func ExecModule(name string, state *VMState) (object.Object, error) {
filename := utils.FindModule(name)
if filename == "" {
return nil, fmt.Errorf("ImportError: no module named '%s'", name)
}
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("IOError: error reading module '%s': %s", name, err)
}
l := lexer.New(string(b))
p := parser.New(l)
module := p.ParseProgram()
if len(p.Errors()) != 0 {
return nil, fmt.Errorf("ParseError: %s", p.Errors())
}
c := compiler.NewWithState(state.Symbols, state.Constants)
err = c.Compile(module)
if err != nil {
return nil, fmt.Errorf("CompileError: %s", err)
}
code := c.Bytecode()
state.Constants = code.Constants
machine := NewWithState(code, state)
err = machine.Run()
if err != nil {
return nil, fmt.Errorf("RuntimeError: error loading module '%s'", err)
}
return state.ExportedHash(), nil
}
type VMState struct {
Constants []object.Object
Globals []object.Object
Symbols *compiler.SymbolTable
}
func NewVMState() *VMState {
symbolTable := compiler.NewSymbolTable()
for i, builtin := range builtins.BuiltinsIndex {
symbolTable.DefineBuiltin(i, builtin.Name)
}
return &VMState{
Constants: []object.Object{},
Globals: make([]object.Object, GlobalsSize),
Symbols: symbolTable,
}
}
// exported binding in the vm state. That is every binding that starts with a
// capital letter. This is used by the module import system to wrap up the
// compiled and evaulated module into an object.
func (s *VMState) ExportedHash() *object.Hash {
pairs := make(map[object.HashKey]object.HashPair)
for name, symbol := range s.Symbols.Store {
if unicode.IsUpper(rune(name[0])) {
if symbol.Scope == compiler.GlobalScope {
obj := s.Globals[symbol.Index]
s := &object.String{Value: name}
pairs[s.HashKey()] = object.HashPair{Key: s, Value: obj}
}
}
}
return &object.Hash{Pairs: pairs}
}
type VM struct {
Debug bool
constants []object.Object
state *VMState
stack []object.Object
sp int // Always points to the next value. Top of stack is stack[sp-1]
globals []object.Object
frames []*Frame
framesIndex int
}
@@ -40,23 +117,37 @@ func New(bytecode *compiler.Bytecode) *VM {
frames := make([]*Frame, MaxFrames)
frames[0] = mainFrame
state := NewVMState()
state.Constants = bytecode.Constants
return &VM{
constants: bytecode.Constants,
state: state,
stack: make([]object.Object, StackSize),
sp: 0,
globals: make([]object.Object, GlobalsSize),
frames: frames,
framesIndex: 1,
}
}
func NewWithGlobalState(bytecode *compiler.Bytecode, s []object.Object) *VM {
vm := New(bytecode)
vm.globals = s
return vm
func NewWithState(bytecode *compiler.Bytecode, state *VMState) *VM {
mainFn := &object.CompiledFunction{Instructions: bytecode.Instructions}
mainClosure := &object.Closure{Fn: mainFn}
mainFrame := NewFrame(mainClosure, 0)
frames := make([]*Frame, MaxFrames)
frames[0] = mainFrame
return &VM{
state: state,
frames: frames,
framesIndex: 1,
stack: make([]object.Object, StackSize),
sp: 0,
}
}
func (vm *VM) currentFrame() *Frame {
@@ -73,6 +164,24 @@ func (vm *VM) popFrame() *Frame {
return vm.frames[vm.framesIndex]
}
func (vm *VM) loadModule(name object.Object) error {
s, ok := name.(*object.String)
if !ok {
return fmt.Errorf(
"TypeError: import() expected argument #1 to be `str` got `%s`",
name.Type(),
)
}
attrs, err := ExecModule(s.Value, vm.state)
if err != nil {
return err
}
module := &object.Module{Name: s.Value, Attrs: attrs}
return vm.push(module)
}
func (vm *VM) LastPoppedStackElem() object.Object {
return vm.stack[vm.sp]
}
@@ -108,7 +217,7 @@ func (vm *VM) Run() error {
constIndex := code.ReadUint16(ins[ip+1:])
vm.currentFrame().ip += 2
err := vm.push(vm.constants[constIndex])
err := vm.push(vm.state.Constants[constIndex])
if err != nil {
return err
}
@@ -185,9 +294,9 @@ func (vm *VM) Run() error {
ref := vm.pop()
if immutable, ok := ref.(object.Immutable); ok {
vm.globals[globalIndex] = immutable.Clone()
vm.state.Globals[globalIndex] = immutable.Clone()
} else {
vm.globals[globalIndex] = ref
vm.state.Globals[globalIndex] = ref
}
err := vm.push(Null)
@@ -198,7 +307,7 @@ func (vm *VM) Run() error {
case code.OpAssignGlobal:
globalIndex := code.ReadUint16(ins[ip+1:])
vm.currentFrame().ip += 2
vm.globals[globalIndex] = vm.pop()
vm.state.Globals[globalIndex] = vm.pop()
err := vm.push(Null)
if err != nil {
@@ -221,7 +330,7 @@ func (vm *VM) Run() error {
globalIndex := code.ReadUint16(ins[ip+1:])
vm.currentFrame().ip += 2
err := vm.push(vm.globals[globalIndex])
err := vm.push(vm.state.Globals[globalIndex])
if err != nil {
return err
}
@@ -359,6 +468,14 @@ func (vm *VM) Run() error {
return err
}
case code.OpLoadModule:
name := vm.pop()
err := vm.loadModule(name)
if err != nil {
return err
}
}
if vm.Debug {
@@ -396,6 +513,8 @@ func (vm *VM) executeGetItem(left, index object.Object) error {
return vm.executeArrayGetItem(left, index)
case left.Type() == object.HASH_OBJ:
return vm.executeHashGetItem(left, index)
case left.Type() == object.MODULE_OBJ:
return vm.executeHashGetItem(left.(*object.Module).Attrs, index)
default:
return fmt.Errorf(
"index operator not supported: left=%s index=%s",
@@ -766,7 +885,7 @@ func (vm *VM) callBuiltin(builtin *object.Builtin, numArgs int) error {
}
func (vm *VM) pushClosure(constIndex, numFree int) error {
constant := vm.constants[constIndex]
constant := vm.state.Constants[constIndex]
function, ok := constant.(*object.CompiledFunction)
if !ok {
return fmt.Errorf("not a function %+v", constant)

View File

@@ -7,6 +7,7 @@ import (
"monkey/lexer"
"monkey/object"
"monkey/parser"
"monkey/utils"
"os"
"path"
"path/filepath"
@@ -1127,3 +1128,35 @@ func BenchmarkFibonacci(b *testing.B) {
})
}
}
func TestImportExpressions(t *testing.T) {
tests := []vmTestCase{
{
input: `mod := import("../testdata/mod"); mod.A`,
expected: 5,
},
{
input: `mod := import("../testdata/mod"); mod.Sum(2, 3)`,
expected: 5,
},
{
input: `mod := import("../testdata/mod"); mod.a`,
expected: nil,
},
}
runVmTests(t, tests)
}
func TestImportSearchPaths(t *testing.T) {
utils.AddPath("../testdata")
tests := []vmTestCase{
{
input: `mod := import("../testdata/mod"); mod.A`,
expected: 5,
},
}
runVmTests(t, tests)
}