From 0a1201f1bcf4dfb1bfa741b0c34924015c6d0a6e Mon Sep 17 00:00:00 2001 From: Chuck Smith Date: Mon, 26 Feb 2024 16:59:37 -0500 Subject: [PATCH] hash --- code/code.go | 2 ++ compiler/compiler.go | 23 +++++++++++++++++++ compiler/compiler_test.go | 45 +++++++++++++++++++++++++++++++++++++ vm/vm.go | 35 +++++++++++++++++++++++++++++ vm/vm_test.go | 47 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 152 insertions(+) diff --git a/code/code.go b/code/code.go index 7b8c2ea..c9bf9f5 100644 --- a/code/code.go +++ b/code/code.go @@ -30,6 +30,7 @@ const ( OpGetGlobal OpSetGlobal OpArray + OpHash ) type Definition struct { @@ -57,6 +58,7 @@ var definitions = map[Opcode]*Definition{ OpGetGlobal: {"OpGetGlobal", []int{2}}, OpSetGlobal: {"OpSetGlobal", []int{2}}, OpArray: {"OpArray", []int{2}}, + OpHash: {"OpHash", []int{2}}, } func Lookup(op byte) (*Definition, error) { diff --git a/compiler/compiler.go b/compiler/compiler.go index 24dc0f6..44304c2 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -5,6 +5,7 @@ import ( "monkey/ast" "monkey/code" "monkey/object" + "sort" ) type EmittedInstruction struct { @@ -204,6 +205,28 @@ func (c *Compiler) Compile(node ast.Node) error { c.emit(code.OpArray, len(node.Elements)) + case *ast.HashLiteral: + keys := []ast.Expression{} + for k := range node.Pairs { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { + return keys[i].String() < keys[j].String() + }) + + for _, k := range keys { + err := c.Compile(k) + if err != nil { + return err + } + err = c.Compile(node.Pairs[k]) + if err != nil { + return err + } + } + + c.emit(code.OpHash, len(node.Pairs)*2) + } return nil diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 150b394..f5df3f8 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -345,6 +345,51 @@ func TestArrayLiterals(t *testing.T) { runCompilerTests(t, tests) } +func TestHashLiterals(t *testing.T) { + tests := []compilerTestCase{ + { + input: "{}", + expectedConstants: []interface{}{}, + expectedInstructions: []code.Instructions{ + code.Make(code.OpHash, 0), + code.Make(code.OpPop), + }, + }, + { + input: "{1: 2, 3: 4, 5: 6}", + expectedConstants: []interface{}{1, 2, 3, 4, 5, 6}, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpConstant, 2), + code.Make(code.OpConstant, 3), + code.Make(code.OpConstant, 4), + code.Make(code.OpConstant, 5), + code.Make(code.OpHash, 6), + code.Make(code.OpPop), + }, + }, + { + input: "{1: 2 + 3, 4: 5 * 6}", + expectedConstants: []interface{}{1, 2, 3, 4, 5, 6}, + expectedInstructions: []code.Instructions{ + code.Make(code.OpConstant, 0), + code.Make(code.OpConstant, 1), + code.Make(code.OpConstant, 2), + code.Make(code.OpAdd), + code.Make(code.OpConstant, 3), + code.Make(code.OpConstant, 4), + code.Make(code.OpConstant, 5), + code.Make(code.OpMul), + code.Make(code.OpHash, 4), + code.Make(code.OpPop), + }, + }, + } + + runCompilerTests(t, tests) +} + func runCompilerTests(t *testing.T, tests []compilerTestCase) { t.Helper() diff --git a/vm/vm.go b/vm/vm.go index 0e32f19..332c9ad 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -145,12 +145,47 @@ func (vm *VM) Run() error { return err } + case code.OpHash: + numElements := int(code.ReadUint16(vm.instructions[ip+1:])) + ip += 2 + + hash, err := vm.buildHash(vm.sp-numElements, vm.sp) + if err != nil { + return err + } + vm.sp = vm.sp - numElements + + err = vm.push(hash) + if err != nil { + return err + } + } } return nil } +func (vm *VM) buildHash(startIndex, endIndex int) (object.Object, error) { + hashedPairs := make(map[object.HashKey]object.HashPair) + + for i := startIndex; i < endIndex; i += 2 { + key := vm.stack[i] + value := vm.stack[i+1] + + pair := object.HashPair{Key: key, Value: value} + + hashKey, ok := key.(object.Hashable) + if !ok { + return nil, fmt.Errorf("unusable as hash key: %s", key.Type()) + } + + hashedPairs[hashKey.HashKey()] = pair + } + + return &object.Hash{Pairs: hashedPairs}, nil +} + func (vm *VM) buildArray(startIndex, endIndex int) object.Object { elements := make([]object.Object, endIndex-startIndex) diff --git a/vm/vm_test.go b/vm/vm_test.go index bb7783b..62c9f54 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -79,6 +79,29 @@ func testExpectedObject(t *testing.T, expected interface{}, actual object.Object } } + case map[object.HashKey]int64: + hash, ok := actual.(*object.Hash) + if !ok { + t.Errorf("object not Hash: %T (%+v)", actual, actual) + } + + if len(hash.Pairs) != len(expected) { + t.Errorf("wrong num of Pairs. want=%d, got=%d", len(expected), len(hash.Pairs)) + return + } + + for expectedKey, expectedValue := range expected { + pair, ok := hash.Pairs[expectedKey] + if !ok { + t.Errorf("no pair for given key in Pairs") + } + + err := testIntegerObject(expectedValue, pair.Value) + if err != nil { + t.Errorf("testIntgerObject failed: %s", err) + } + } + case *object.Null: if actual != Null { t.Errorf("object is not Null: %T (%+v)", actual, actual) @@ -232,3 +255,27 @@ func TestArrayLiterals(t *testing.T) { runVmTests(t, tests) } + +func TestHashLiterals(t *testing.T) { + tests := []vmTestCase{ + { + "{}", map[object.HashKey]int64{}, + }, + { + "{1: 2, 2: 3}", + map[object.HashKey]int64{ + (&object.Integer{Value: 1}).HashKey(): 2, + (&object.Integer{Value: 2}).HashKey(): 3, + }, + }, + { + "{1 + 1: 2 * 2, 3 + 3: 4 * 4}", + map[object.HashKey]int64{ + (&object.Integer{Value: 2}).HashKey(): 4, + (&object.Integer{Value: 6}).HashKey(): 16, + }, + }, + } + + runVmTests(t, tests) +}