diff --git a/compiler/compiler.go b/compiler/compiler.go index ef0078d..bc8cfac 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -2,10 +2,12 @@ package compiler import ( "fmt" + "log" "monkey/ast" "monkey/code" "monkey/object" "sort" + "strings" ) type EmittedInstruction struct { @@ -20,6 +22,9 @@ type CompilationScope struct { } type Compiler struct { + Debug bool + + l int constants []object.Object symbolTable *SymbolTable @@ -61,8 +66,16 @@ func NewWithState(s *SymbolTable, constants []object.Object) *Compiler { } func (c *Compiler) Compile(node ast.Node) error { + if c.Debug { + log.Printf( + "%s Compiling %T: %s\n", + strings.Repeat(" ", c.l), node, node.String(), + ) + } + switch node := node.(type) { case *ast.Program: + c.l++ for _, s := range node.Statements { err := c.Compile(s) if err != nil { @@ -71,7 +84,9 @@ func (c *Compiler) Compile(node ast.Node) error { } case *ast.ExpressionStatement: + c.l++ err := c.Compile(node.Expression) + c.l-- if err != nil { return err } @@ -82,12 +97,16 @@ func (c *Compiler) Compile(node ast.Node) error { case *ast.InfixExpression: if node.Operator == "<" || node.Operator == "<=" { + c.l++ err := c.Compile(node.Right) + c.l-- if err != nil { return err } + c.l++ err = c.Compile(node.Left) + c.l-- if err != nil { return err } @@ -99,12 +118,16 @@ func (c *Compiler) Compile(node ast.Node) error { return nil } + c.l++ err := c.Compile(node.Left) + c.l-- if err != nil { return err } + c.l++ err = c.Compile(node.Right) + c.l-- if err != nil { return err } @@ -142,7 +165,9 @@ func (c *Compiler) Compile(node ast.Node) error { } case *ast.PrefixExpression: + c.l++ err := c.Compile(node.Right) + c.l-- if err != nil { return err } @@ -160,7 +185,9 @@ func (c *Compiler) Compile(node ast.Node) error { if c.lastInstructionIs(code.OpPop) { c.removeLastPop() } + c.l++ err := c.Compile(node.Condition) + c.l-- if err != nil { return err } @@ -168,7 +195,9 @@ func (c *Compiler) Compile(node ast.Node) error { // Emit an `OpJumpNotTruthy` with a bogus value jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999) + c.l++ err = c.Compile(node.Consequence) + c.l-- if err != nil { return err } @@ -186,7 +215,9 @@ func (c *Compiler) Compile(node ast.Node) error { if node.Alternative == nil { c.emit(code.OpNull) } else { - err = c.Compile(node.Alternative) + c.l++ + err := c.Compile(node.Alternative) + c.l-- if err != nil { return err } @@ -200,12 +231,14 @@ func (c *Compiler) Compile(node ast.Node) error { c.changeOperand(jumpPos, afterAlternativePos) case *ast.BlockStatement: + c.l++ for _, s := range node.Statements { err := c.Compile(s) if err != nil { return err } } + c.l-- case *ast.AssignmentStatement: symbol, ok := c.symbolTable.Resolve(node.Name.Value) @@ -213,7 +246,9 @@ func (c *Compiler) Compile(node ast.Node) error { return fmt.Errorf("undefined variable %s", node.Value) } + c.l++ err := c.Compile(node.Value) + c.l-- if err != nil { return err } @@ -230,7 +265,9 @@ func (c *Compiler) Compile(node ast.Node) error { symbol = c.symbolTable.Define(node.Name.Value) } + c.l++ err := c.Compile(node.Value) + c.l-- if err != nil { return err } @@ -255,7 +292,9 @@ func (c *Compiler) Compile(node ast.Node) error { case *ast.ArrayLiteral: for _, el := range node.Elements { + c.l++ err := c.Compile(el) + c.l-- if err != nil { return err } @@ -273,11 +312,15 @@ func (c *Compiler) Compile(node ast.Node) error { }) for _, k := range keys { + c.l++ err := c.Compile(k) + c.l-- if err != nil { return err } + c.l++ err = c.Compile(node.Pairs[k]) + c.l-- if err != nil { return err } @@ -286,12 +329,16 @@ func (c *Compiler) Compile(node ast.Node) error { c.emit(code.OpHash, len(node.Pairs)*2) case *ast.IndexExpression: + c.l++ err := c.Compile(node.Left) + c.l-- if err != nil { return err } + c.l++ err = c.Compile(node.Index) + c.l-- if err != nil { return err } @@ -309,7 +356,9 @@ func (c *Compiler) Compile(node ast.Node) error { c.symbolTable.Define(p.Value) } + c.l++ err := c.Compile(node.Body) + c.l-- if err != nil { return err } @@ -340,7 +389,9 @@ func (c *Compiler) Compile(node ast.Node) error { c.emit(code.OpClosure, fnIndex, len(freeSymbols)) case *ast.ReturnStatement: + c.l++ err := c.Compile(node.ReturnValue) + c.l-- if err != nil { return err } @@ -348,7 +399,9 @@ func (c *Compiler) Compile(node ast.Node) error { c.emit(code.OpReturn) case *ast.CallExpression: + c.l++ err := c.Compile(node.Function) + c.l-- if err != nil { return err } @@ -365,7 +418,9 @@ func (c *Compiler) Compile(node ast.Node) error { case *ast.WhileExpression: jumpConditionPos := len(c.currentInstructions()) + c.l++ err := c.Compile(node.Condition) + c.l-- if err != nil { return err } @@ -373,10 +428,12 @@ func (c *Compiler) Compile(node ast.Node) error { // Emit an `OpJump`with a bogus value jumpIfFalsePos := c.emit(code.OpJumpNotTruthy, 0xFFFF) + c.l++ err = c.Compile(node.Consequence) if err != nil { return err } + c.l-- c.emit(code.OpJump, jumpConditionPos) diff --git a/repl/repl.go b/repl/repl.go index 3e86796..7d86cd4 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -116,6 +116,7 @@ func (r *REPL) Exec(f io.Reader) (state *VMState) { } c := compiler.NewWithState(state.symbols, state.constants) + c.Debug = r.opts.Debug err = c.Compile(program) if err != nil { fmt.Fprintf(os.Stderr, "Woops! Compilation failed:\n %s\n", err) @@ -197,6 +198,7 @@ func (r *REPL) StartExecLoop(in io.Reader, out io.Writer, state *VMState) { } c := compiler.NewWithState(state.symbols, state.constants) + c.Debug = r.opts.Debug err := c.Compile(program) if err != nil { fmt.Fprintf(os.Stderr, "Woops! Compilation failed:\n %s\n", err)