From fe78b7069b6b56a61779159a215e95db4e1f326e Mon Sep 17 00:00:00 2001
From: Chuck Smith
Date: Wed, 24 Jan 2024 19:35:32 -0500
Subject: [PATCH] compiler!

---
 code/code.go              | 142 ++++++++++++++++++++++++++++++++++++++
 code/code_test.go         |  96 ++++++++++++++++++++++++++
 compiler/compiler.go      |  99 ++++++++++++++++++++++++++
 compiler/compiler_test.go | 133 +++++++++++++++++++++++++++++++++++
 4 files changed, 470 insertions(+)
 create mode 100644 code/code.go
 create mode 100644 code/code_test.go
 create mode 100644 compiler/compiler.go
 create mode 100644 compiler/compiler_test.go

diff --git a/code/code.go b/code/code.go
new file mode 100644
index 0000000..3d6997b
--- /dev/null
+++ b/code/code.go
@@ -0,0 +1,142 @@
+// Package code defines the bytecode instruction format: opcodes, their
+// operand layouts, and helpers to encode and decode instructions.
+package code
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+)
+
+// Instructions is a flat sequence of encoded instructions, each consisting of
+// one opcode byte followed by that opcode's operands.
+type Instructions []byte
+
+// Opcode is the single byte that identifies an instruction.
+type Opcode byte
+
+const (
+	// OpConstant has one two-byte operand. The compiler uses it to carry
+	// the index of an entry in its constant pool.
+	OpConstant Opcode = iota
+)
+
+// Definition describes an opcode: its human-readable name and the width in
+// bytes of each of its operands.
+type Definition struct {
+	Name          string
+	OperandWidths []int
+}
+
+// definitions maps every known opcode to its Definition.
+var definitions = map[Opcode]*Definition{
+	OpConstant: {"OpConstant", []int{2}},
+}
+
+// Lookup returns the Definition for the opcode byte op, or an error if the
+// opcode is not defined.
+func Lookup(op byte) (*Definition, error) {
+	def, ok := definitions[Opcode(op)]
+	if !ok {
+		return nil, fmt.Errorf("opcode %d undefined", op)
+	}
+
+	return def, nil
+}
+
+// Make encodes op and its operands as a single instruction. Each operand is
+// written big-endian using the width given by op's Definition. An undefined
+// opcode yields an empty slice.
+func Make(op Opcode, operands ...int) []byte {
+	def, ok := definitions[op]
+	if !ok {
+		return []byte{}
+	}
+
+	instructionLen := 1
+	for _, w := range def.OperandWidths {
+		instructionLen += w
+	}
+
+	instruction := make([]byte, instructionLen)
+	instruction[0] = byte(op)
+
+	offset := 1
+	for i, o := range operands {
+		width := def.OperandWidths[i]
+		switch width {
+		case 2:
+			binary.BigEndian.PutUint16(instruction[offset:], uint16(o))
+		}
+		offset += width
+	}
+
+	return instruction
+}
+
+// String disassembles ins into one line per instruction: the zero-padded
+// byte offset followed by the opcode name and its operands. An undefined
+// opcode byte is reported on an ERROR line and skipped.
+func (ins Instructions) String() string {
+	var out bytes.Buffer
+
+	i := 0
+	for i < len(ins) {
+		def, err := Lookup(ins[i])
+		if err != nil {
+			fmt.Fprintf(&out, "ERROR: %s\n", err)
+			// Step past the unknown byte; without this i never
+			// advances and the loop spins forever.
+			i++
+			continue
+		}
+
+		operands, read := ReadOperands(def, ins[i+1:])
+
+		fmt.Fprintf(&out, "%04d %s\n", i, ins.fmtInstruction(def, operands))
+
+		i += 1 + read
+	}
+
+	return out.String()
+}
+
+// fmtInstruction renders a single decoded instruction. It returns an ERROR
+// string when operands does not match def or the operand count is unhandled.
+func (ins Instructions) fmtInstruction(def *Definition, operands []int) string {
+	operandCount := len(def.OperandWidths)
+
+	if len(operands) != operandCount {
+		return fmt.Sprintf("ERROR: operand len %d does not match defined %d\n", len(operands), operandCount)
+	}
+
+	switch operandCount {
+	case 1:
+		return fmt.Sprintf("%s %d", def.Name, operands[0])
+	}
+
+	return fmt.Sprintf("ERROR: unhandled operandCount for %s\n", def.Name)
+}
+
+// ReadOperands decodes the operands described by def from the start of ins.
+// It returns the decoded operands and the number of bytes consumed.
+func ReadOperands(def *Definition, ins Instructions) ([]int, int) {
+	operands := make([]int, len(def.OperandWidths))
+	offset := 0
+
+	for i, width := range def.OperandWidths {
+		switch width {
+		case 2:
+			operands[i] = int(ReadUint16(ins[offset:]))
+		}
+
+		offset += width
+	}
+
+	return operands, offset
+}
+
+// ReadUint16 decodes a big-endian uint16 from the first two bytes of ins.
+func ReadUint16(ins Instructions) uint16 {
+	return binary.BigEndian.Uint16(ins)
+}
diff --git a/code/code_test.go b/code/code_test.go
new file mode 100644
index 0000000..0d1a6e4
--- /dev/null
+++ b/code/code_test.go
@@ -0,0 +1,96 @@
+package code
+
+import "testing"
+
+// TestMake checks that Make encodes an opcode and its operands into the
+// expected bytes.
+func TestMake(t *testing.T) {
+	tests := []struct {
+		op       Opcode
+		operands []int
+		expected []byte
+	}{
+		{OpConstant, []int{65534}, []byte{byte(OpConstant), 255, 254}},
+	}
+
+	for _, tt := range tests {
+		instructions := Make(tt.op, tt.operands...)
+
+		// Fatal, not Error: indexing below would panic on a short result.
+		if len(instructions) != len(tt.expected) {
+			t.Fatalf("instruction has wrong length. want=%d, got=%d", len(tt.expected), len(instructions))
+		}
+
+		for i, b := range tt.expected {
+			if instructions[i] != b {
+				t.Errorf("wrong byte at pos %d. want=%d, got=%d", i, b, instructions[i])
+			}
+		}
+	}
+}
+
+// TestInstructions checks the disassembly produced by Instructions.String.
+func TestInstructions(t *testing.T) {
+	instructions := []Instructions{
+		Make(OpConstant, 1),
+		Make(OpConstant, 2),
+		Make(OpConstant, 65535),
+	}
+
+	expected := `0000 OpConstant 1
+0003 OpConstant 2
+0006 OpConstant 65535
+`
+
+	concatted := Instructions{}
+	for _, ins := range instructions {
+		concatted = append(concatted, ins...)
+	}
+
+	if concatted.String() != expected {
+		t.Errorf("instructions wrong formatted.\nwant=%q\ngot=%q", expected, concatted.String())
+	}
+}
+
+// TestInstructionsUndefinedOpcode checks that String reports an undefined
+// opcode byte and still terminates.
+func TestInstructionsUndefinedOpcode(t *testing.T) {
+	ins := Instructions{255}
+	expected := "ERROR: opcode 255 undefined\n"
+
+	if ins.String() != expected {
+		t.Errorf("instructions wrong formatted.\nwant=%q\ngot=%q", expected, ins.String())
+	}
+}
+
+// TestOperands checks that ReadOperands decodes what Make encoded and reports
+// the number of bytes read.
+func TestOperands(t *testing.T) {
+	tests := []struct {
+		op        Opcode
+		operands  []int
+		bytesRead int
+	}{
+		{OpConstant, []int{65535}, 2},
+	}
+
+	for _, tt := range tests {
+		instruction := Make(tt.op, tt.operands...)
+
+		def, err := Lookup(byte(tt.op))
+		if err != nil {
+			t.Fatalf("definition not found: %q\n", err)
+		}
+
+		operandsRead, n := ReadOperands(def, instruction[1:])
+		if n != tt.bytesRead {
+			t.Fatalf("n wrong. want=%d, got=%d", tt.bytesRead, n)
+		}
+
+		for i, want := range tt.operands {
+			if operandsRead[i] != want {
+				t.Errorf("operand wrong. want=%d, got=%d", want, operandsRead[i])
+			}
+		}
+	}
+}
diff --git a/compiler/compiler.go b/compiler/compiler.go
new file mode 100644
index 0000000..3475874
--- /dev/null
+++ b/compiler/compiler.go
@@ -0,0 +1,99 @@
+// Package compiler translates a Monkey AST into bytecode instructions and a
+// constant pool.
+package compiler
+
+import (
+	"monkey/ast"
+	"monkey/code"
+	"monkey/object"
+)
+
+// Compiler accumulates the instructions and constants produced by Compile.
+type Compiler struct {
+	instructions code.Instructions
+	constants    []object.Object
+}
+
+// New returns a Compiler with no instructions and no constants.
+func New() *Compiler {
+	return &Compiler{
+		instructions: code.Instructions{},
+		constants:    []object.Object{},
+	}
+}
+
+// Compile walks node recursively and appends the resulting instructions and
+// constants to c. Node types without a case are ignored. It returns the first
+// error reported by a nested Compile call.
+func (c *Compiler) Compile(node ast.Node) error {
+	switch node := node.(type) {
+	case *ast.Program:
+		for _, s := range node.Statements {
+			err := c.Compile(s)
+			if err != nil {
+				return err
+			}
+		}
+
+	case *ast.ExpressionStatement:
+		err := c.Compile(node.Expression)
+		if err != nil {
+			return err
+		}
+
+	case *ast.InfixExpression:
+		// Only the operands are compiled here; no instruction is emitted
+		// for the operator itself.
+		err := c.Compile(node.Left)
+		if err != nil {
+			return err
+		}
+
+		err = c.Compile(node.Right)
+		if err != nil {
+			return err
+		}
+
+	case *ast.IntegerLiteral:
+		integer := &object.Integer{Value: node.Value}
+		c.emit(code.OpConstant, c.addConstant(integer))
+	}
+
+	return nil
+}
+
+// addConstant appends obj to the constant pool and returns its index.
+func (c *Compiler) addConstant(obj object.Object) int {
+	c.constants = append(c.constants, obj)
+	return len(c.constants) - 1
+}
+
+// emit encodes op with its operands, appends the instruction, and returns the
+// position at which it starts.
+func (c *Compiler) emit(op code.Opcode, operands ...int) int {
+	ins := code.Make(op, operands...)
+	return c.addInstruction(ins)
+}
+
+// Bytecode returns the instructions and constants compiled so far.
+func (c *Compiler) Bytecode() *Bytecode {
+	return &Bytecode{
+		Instructions: c.instructions,
+		Constants:    c.constants,
+	}
+}
+
+// addInstruction appends ins to the instruction stream and returns the offset
+// of its first byte.
+func (c *Compiler) addInstruction(ins []byte) int {
+	posNewInstruction := len(c.instructions)
+	c.instructions = append(c.instructions, ins...)
+	return posNewInstruction
+}
+
+// Bytecode is the compiler's output: the encoded instructions and the
+// constants they refer to.
+type Bytecode struct {
+	Instructions code.Instructions
+	Constants    []object.Object
+}
diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go
new file mode 100644
index 0000000..3a50232
--- /dev/null
+++ b/compiler/compiler_test.go
@@ -0,0 +1,133 @@
+package compiler
+
+import (
+	"fmt"
+	"monkey/ast"
+	"monkey/code"
+	"monkey/lexer"
+	"monkey/object"
+	"monkey/parser"
+	"testing"
+)
+
+// compilerTestCase pairs a source snippet with the constants and instructions
+// the compiler is expected to produce for it.
+type compilerTestCase struct {
+	input                string
+	expectedConstants    []interface{}
+	expectedInstructions []code.Instructions
+}
+
+// TestIntegerArithmetic checks the bytecode compiled for integer expressions.
+func TestIntegerArithmetic(t *testing.T) {
+	tests := []compilerTestCase{
+		{
+			input:             "1 + 2",
+			expectedConstants: []interface{}{1, 2},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpConstant, 0),
+				code.Make(code.OpConstant, 1),
+			},
+		},
+	}
+
+	runCompilerTests(t, tests)
+}
+
+// runCompilerTests parses and compiles each test input, then compares the
+// resulting bytecode with the expected instructions and constants.
+func runCompilerTests(t *testing.T, tests []compilerTestCase) {
+	t.Helper()
+
+	for _, tt := range tests {
+		program := parse(tt.input)
+
+		compiler := New()
+		err := compiler.Compile(program)
+		if err != nil {
+			t.Fatalf("compiler error: %s", err)
+		}
+
+		bytecode := compiler.Bytecode()
+
+		err = testInstructions(tt.expectedInstructions, bytecode.Instructions)
+		if err != nil {
+			t.Fatalf("testInstructions failed: %s", err)
+		}
+
+		err = testConstants(t, tt.expectedConstants, bytecode.Constants)
+		if err != nil {
+			t.Fatalf("testConstants failed: %s", err)
+		}
+	}
+}
+
+// parse lexes and parses input into a program.
+func parse(input string) *ast.Program {
+	l := lexer.New(input)
+	p := parser.New(l)
+	return p.ParseProgram()
+}
+
+// testInstructions reports an error if actual differs from the concatenation
+// of expected in length or in any byte.
+func testInstructions(expected []code.Instructions, actual code.Instructions) error {
+	concatted := concatInstructions(expected)
+
+	if len(actual) != len(concatted) {
+		return fmt.Errorf("wrong instructions length.\nwant=%q\ngot =%q", concatted, actual)
+	}
+
+	for i, ins := range concatted {
+		if actual[i] != ins {
+			return fmt.Errorf("wrong instruction at %d.\nwant=%q\ngot =%q", i, concatted, actual)
+		}
+	}
+
+	return nil
+}
+
+// concatInstructions flattens s into a single instruction sequence.
+func concatInstructions(s []code.Instructions) code.Instructions {
+	out := code.Instructions{}
+
+	for _, ins := range s {
+		out = append(out, ins...)
+	}
+
+	return out
+}
+
+// testConstants reports an error if actual does not hold exactly the expected
+// constants. Only int expectations are checked; other types are skipped.
+func testConstants(t *testing.T, expected []interface{}, actual []object.Object) error {
+	if len(expected) != len(actual) {
+		return fmt.Errorf("wrong number of constants. got=%d, want=%d", len(actual), len(expected))
+	}
+
+	for i, constant := range expected {
+		switch constant := constant.(type) {
+		case int:
+			err := testIntegerObject(int64(constant), actual[i])
+			if err != nil {
+				return fmt.Errorf("constant %d = testIntegerObject failed : %s", i, err)
+			}
+		}
+	}
+	return nil
+}
+
+// testIntegerObject reports an error unless actual is an *object.Integer
+// whose value equals expected.
+func testIntegerObject(expected int64, actual object.Object) error {
+	result, ok := actual.(*object.Integer)
+	if !ok {
+		return fmt.Errorf("object is not Integer. got=%T (%+v)", actual, actual)
+	}
+
+	if result.Value != expected {
+		return fmt.Errorf("object has wrong value. got=%d, want=%d", result.Value, expected)
+	}
+
+	return nil
+}