conditionals

This commit is contained in:
Chuck Smith
2024-02-07 15:46:45 -05:00
parent cff4375649
commit 77401260a2
5 changed files with 198 additions and 15 deletions

View File

@@ -24,6 +24,8 @@ const (
OpGreaterThan
OpMinus
OpBang
OpJumpNotTruthy
OpJump
)
type Definition struct {
@@ -32,19 +34,21 @@ type Definition struct {
}
var definitions = map[Opcode]*Definition{
OpConstant: {"OpConstant", []int{2}},
OpAdd: {"OpAdd", []int{}},
OpPop: {"OpPop", []int{}},
OpSub: {"OpSub", []int{}},
OpMul: {"OpMul", []int{}},
OpDiv: {"OpDiv", []int{}},
OpTrue: {"OpTrue", []int{}},
OpFalse: {"OpFalse", []int{}},
OpEqual: {"OpEqual", []int{}},
OpNotEqual: {"OpNotEqual", []int{}},
OpGreaterThan: {"OpGreaterThan", []int{}},
OpMinus: {"OpMinus", []int{}},
OpBang: {"OpBang", []int{}},
OpConstant: {"OpConstant", []int{2}},
OpAdd: {"OpAdd", []int{}},
OpPop: {"OpPop", []int{}},
OpSub: {"OpSub", []int{}},
OpMul: {"OpMul", []int{}},
OpDiv: {"OpDiv", []int{}},
OpTrue: {"OpTrue", []int{}},
OpFalse: {"OpFalse", []int{}},
OpEqual: {"OpEqual", []int{}},
OpNotEqual: {"OpNotEqual", []int{}},
OpGreaterThan: {"OpGreaterThan", []int{}},
OpMinus: {"OpMinus", []int{}},
OpBang: {"OpBang", []int{}},
OpJumpNotTruthy: {"OpJumpNotTruthy", []int{2}},
OpJump: {"OpJump", []int{2}},
}
func Lookup(op byte) (*Definition, error) {

View File

@@ -7,15 +7,25 @@ import (
"monkey/object"
)
type EmittedInstruction struct {
Opcode code.Opcode
Position int
}
type Compiler struct {
instructions code.Instructions
constants []object.Object
lastInstruction EmittedInstruction
previousInstruction EmittedInstruction
}
func New() *Compiler {
return &Compiler{
instructions: code.Instructions{},
constants: []object.Object{},
instructions: code.Instructions{},
constants: []object.Object{},
lastInstruction: EmittedInstruction{},
previousInstruction: EmittedInstruction{},
}
}
@@ -106,6 +116,56 @@ func (c *Compiler) Compile(node ast.Node) error {
return fmt.Errorf("unknown operator %s", node.Operator)
}
case *ast.IfExpression:
err := c.Compile(node.Condition)
if err != nil {
return err
}
// Emit an `OpJumpNotTruthy` with a bogus value
jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999)
err = c.Compile(node.Consequence)
if err != nil {
return err
}
if c.lastInstructionIsPop() {
c.removeLastPop()
}
if node.Alternative == nil {
afterConsequencePos := len(c.instructions)
c.changeOperand(jumpNotTruthyPos, afterConsequencePos)
} else {
// Emit an `OnJump` with a bogus value
jumpPos := c.emit(code.OpJump, 9999)
afterConsequencePos := len(c.instructions)
c.changeOperand(jumpNotTruthyPos, afterConsequencePos)
err = c.Compile(node.Alternative)
if err != nil {
return err
}
if c.lastInstructionIsPop() {
c.removeLastPop()
}
afterAlternativePos := len(c.instructions)
c.changeOperand(jumpPos, afterAlternativePos)
}
case *ast.BlockStatement:
for _, s := range node.Statements {
err := c.Compile(s)
if err != nil {
return err
}
}
}
return nil
@@ -119,9 +179,20 @@ func (c *Compiler) addConstant(obj object.Object) int {
func (c *Compiler) emit(op code.Opcode, operands ...int) int {
ins := code.Make(op, operands...)
pos := c.addInstruction(ins)
c.setLastInstruction(op, pos)
return pos
}
func (c *Compiler) setLastInstruction(op code.Opcode, pos int) {
previous := c.lastInstruction
last := EmittedInstruction{Opcode: op, Position: pos}
c.previousInstruction = previous
c.lastInstruction = last
}
func (c *Compiler) Bytecode() *Bytecode {
return &Bytecode{
Instructions: c.instructions,
@@ -135,6 +206,28 @@ func (c *Compiler) addInstruction(ins []byte) int {
return postNewInstruction
}
func (c *Compiler) lastInstructionIsPop() bool {
return c.lastInstruction.Opcode == code.OpPop
}
func (c *Compiler) removeLastPop() {
c.instructions = c.instructions[:c.lastInstruction.Position]
c.lastInstruction = c.previousInstruction
}
func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) {
for i := 0; i < len(newInstruction); i++ {
c.instructions[pos+i] = newInstruction[i]
}
}
func (c *Compiler) changeOperand(opPos int, operand int) {
op := code.Opcode(c.instructions[opPos])
newInstruction := code.Make(op, operand)
c.replaceInstruction(opPos, newInstruction)
}
type Bytecode struct {
Instructions code.Instructions
Constants []object.Object

View File

@@ -174,6 +174,55 @@ func TestBooleanExpressions(t *testing.T) {
runCompilerTests(t, tests)
}
func TestConditionals(t *testing.T) {
tests := []compilerTestCase{
{
input: `
if (true) { 10 }; 3333;
`,
expectedConstants: []interface{}{10, 3333},
expectedInstructions: []code.Instructions{
code.Make(code.OpTrue),
// 0001
code.Make(code.OpJumpNotTruthy, 7),
// 0004
code.Make(code.OpConstant, 0),
// 0007
code.Make(code.OpPop),
// 0008
code.Make(code.OpConstant, 1),
// 0011
code.Make(code.OpPop),
},
}, {
input: `
if (true) { 10 } else { 20 }; 3333;
`,
expectedConstants: []interface{}{10, 20, 3333},
expectedInstructions: []code.Instructions{
// 0000
code.Make(code.OpTrue),
// 0001
code.Make(code.OpJumpNotTruthy, 10),
// 0004
code.Make(code.OpConstant, 0),
// 0007
code.Make(code.OpJump, 13),
// 0010
code.Make(code.OpConstant, 1),
// 0013
code.Make(code.OpPop),
// 0014
code.Make(code.OpConstant, 2),
// 0017
code.Make(code.OpPop),
},
},
}
runCompilerTests(t, tests)
}
func runCompilerTests(t *testing.T, tests []compilerTestCase) {
t.Helper()

View File

@@ -87,12 +87,35 @@ func (vm *VM) Run() error {
return err
}
case code.OpJump:
pos := int(code.ReadUint16(vm.instructions[ip+1:]))
ip = pos - 1
case code.OpJumpNotTruthy:
pos := int(code.ReadUint16(vm.instructions[ip+1:]))
ip += 2
condition := vm.pop()
if !isTruthy(condition) {
ip = pos - 1
}
}
}
return nil
}
func isTruthy(obj object.Object) bool {
switch obj := obj.(type) {
case *object.Boolean:
return obj.Value
default:
return true
}
}
func (vm *VM) push(o object.Object) error {
if vm.sp >= StackSize {
return fmt.Errorf("stack overflow")

View File

@@ -142,3 +142,17 @@ func TestBooleanExpressions(t *testing.T) {
runVmTests(t, tests)
}
func TestConditionals(t *testing.T) {
tests := []vmTestCase{
{"if (true) { 10 }", 10},
{"if (true) { 10 } else { 20 }", 10},
{"if (false) { 10 } else { 20 } ", 20},
{"if (1) { 10 }", 10},
{"if (1 < 2) { 10 }", 10},
{"if (1 < 2) { 10 } else { 20 }", 10},
{"if (1 > 2) { 10 } else { 20 }", 20},
}
runVmTests(t, tests)
}