diff --git a/common/interface.go b/common/interface.go index 855d730a827..f9988f01ae3 100644 --- a/common/interface.go +++ b/common/interface.go @@ -375,8 +375,8 @@ type ExecutionOrderGetter interface { // TrieBatcher defines the methods needed for a trie batcher type TrieBatcher interface { BatchHandler - GetSortedDataForInsertion() ([]string, map[string]core.TrieData) - GetSortedDataForRemoval() []string + GetSortedDataForInsertion() []core.TrieData + GetSortedDataForRemoval() []core.TrieData IsInterfaceNil() bool } @@ -390,7 +390,7 @@ type TrieBatchManager interface { // BatchHandler is the interface for the batch handler type BatchHandler interface { - Add(key []byte, data core.TrieData) + Add(data core.TrieData) MarkForRemoval(key []byte) Get(key []byte) ([]byte, bool) } diff --git a/integrationTests/vm/txsFee/migrateDataTrie_test.go b/integrationTests/vm/txsFee/migrateDataTrie_test.go index 9c62a4f30fd..a33d57883d4 100644 --- a/integrationTests/vm/txsFee/migrateDataTrie_test.go +++ b/integrationTests/vm/txsFee/migrateDataTrie_test.go @@ -185,6 +185,8 @@ func TestMigrateDataTrieBuiltInFunc(t *testing.T) { err = testContext.Accounts.SaveAccount(acc) require.Nil(t, err) + _, err = testContext.Accounts.Commit() + require.Nil(t, err) acc = getAccount(t, testContext, sndAddr) diff --git a/trie/branchNode.go b/trie/branchNode.go index 39f8402d289..68b64a7af27 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -443,61 +443,139 @@ func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node return bn.children[childPos], key, nil } -func (bn *branchNode) insert(newData core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) { +func (bn *branchNode) insert(newData []core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := bn.isEmptyOrNil() if err != nil { return nil, emptyHashes, fmt.Errorf("insert error %w", err) } - if len(newData.Key) == 0 { - return nil, emptyHashes, ErrValueTooShort + dataForInsertion, err := splitDataForChildren(newData) + if err != nil { + return nil, emptyHashes, err } - childPos := newData.Key[firstByte] - if childPosOutOfRange(childPos) { - return nil, emptyHashes, ErrChildPosOutOfRange + modifiedHashes := make([][]byte, 0) + bnHasBeenModified := false + + for childPos := range dataForInsertion { + if len(dataForInsertion[childPos]) == 0 { + continue + } + err = resolveIfCollapsed(bn, byte(childPos), db) + if err != nil { + return nil, emptyHashes, err + } + + if bn.children[childPos] == nil { + newModifiedHashes, err := bn.insertOnNilChild(dataForInsertion[childPos], byte(childPos), db) + if err != nil { + return nil, emptyHashes, err + } + modifiedHashes = append(modifiedHashes, newModifiedHashes...) + bnHasBeenModified = true + + continue + } + + dirty, newModifiedHashes, err := bn.insertOnExistingChild(dataForInsertion[childPos], byte(childPos), db) + if err != nil { + return nil, emptyHashes, err + } + if dirty { + bnHasBeenModified = true + } + modifiedHashes = append(modifiedHashes, newModifiedHashes...) } - newData.Key = newData.Key[1:] - err = resolveIfCollapsed(bn, childPos, db) - if err != nil { - return nil, emptyHashes, err + if bnHasBeenModified { + return bn, modifiedHashes, nil } - if bn.children[childPos] == nil { - return bn.insertOnNilChild(newData, childPos) + return nil, emptyHashes, nil +} + +// the prerequisite for this to work is that the data is already sorted +func splitDataForChildren(newSortedData []core.TrieData) ([][]core.TrieData, error) { + if len(newSortedData) == 0 { + return nil, ErrValueTooShort + } + childrenData := make([][]core.TrieData, nrOfChildren) + + startIndex := 0 + childPos := byte(0) + prevChildPos := byte(0) + for i := range newSortedData { + if len(newSortedData[i].Key) == 0 { + return nil, ErrValueTooShort + } + childPos = newSortedData[i].Key[firstByte] + if childPosOutOfRange(childPos) { + return nil, ErrChildPosOutOfRange + } + newSortedData[i].Key = newSortedData[i].Key[1:] + + if i == 0 { + prevChildPos = childPos + continue + } + + if childPos == prevChildPos { + continue + } + + childrenData[prevChildPos] = newSortedData[startIndex:i] + startIndex = i + prevChildPos = childPos } - return bn.insertOnExistingChild(newData, childPos, db) + childrenData[childPos] = newSortedData[startIndex:] + return childrenData, nil } -func (bn *branchNode) insertOnNilChild(newData core.TrieData, childPos byte) (node, [][]byte, error) { - newLn, err := newLeafNode(newData, bn.marsh, bn.hasher) - if err != nil { - return nil, [][]byte{}, err +func (bn *branchNode) insertOnNilChild(newData []core.TrieData, childPos byte, db common.TrieStorageInteractor) ([][]byte, error) { + if len(newData) == 0 { + return [][]byte{}, ErrValueTooShort } + var newNode node modifiedHashes := make([][]byte, 0) - modifiedHashes, err = bn.modifyNodeAfterInsert(modifiedHashes, childPos, newLn) + + newNode, err := newLeafNode(newData[0], bn.marsh, bn.hasher) + if err != nil { + return [][]byte{}, err + } + + if len(newData) > 1 { + newNode, modifiedHashes, err = newNode.insert(newData[1:], db) + if check.IfNil(newNode) || err != nil { + return [][]byte{}, err + } + } + + modifiedHashes, err = bn.modifyNodeAfterInsert(modifiedHashes, childPos, newNode) if err != nil { - return nil, [][]byte{}, err + return [][]byte{}, err } - return bn, modifiedHashes, nil + return modifiedHashes, nil } -func (bn *branchNode) insertOnExistingChild(newData core.TrieData, childPos byte, db common.TrieStorageInteractor) (node, [][]byte, error) { +func (bn *branchNode) insertOnExistingChild(newData []core.TrieData, childPos byte, db common.TrieStorageInteractor) (bool, [][]byte, error) { newNode, modifiedHashes, err := bn.children[childPos].insert(newData, db) - if check.IfNil(newNode) || err != nil { - return nil, [][]byte{}, err + if err != nil { + return false, [][]byte{}, err + } + + if check.IfNil(newNode) { + return false, [][]byte{}, nil } modifiedHashes, err = bn.modifyNodeAfterInsert(modifiedHashes, childPos, newNode) if err != nil { - return nil, [][]byte{}, err + return false, [][]byte{}, err } - return bn, modifiedHashes, nil + return true, modifiedHashes, nil } func (bn *branchNode) modifyNodeAfterInsert(modifiedHashes [][]byte, childPos byte, newNode node) ([][]byte, error) { @@ -518,45 +596,65 @@ func (bn *branchNode) modifyNodeAfterInsert(modifiedHashes [][]byte, childPos by return modifiedHashes, nil } -func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { +func (bn *branchNode) delete(data []core.TrieData, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := bn.isEmptyOrNil() if err != nil { return false, nil, emptyHashes, fmt.Errorf("delete error %w", err) } - if len(key) == 0 { - return false, nil, emptyHashes, ErrValueTooShort - } - childPos := key[firstByte] - if childPosOutOfRange(childPos) { - return false, nil, emptyHashes, ErrChildPosOutOfRange - } - key = key[1:] - err = resolveIfCollapsed(bn, childPos, db) + + dataForRemoval, err := splitDataForChildren(data) if err != nil { return false, nil, emptyHashes, err } + modifiedHashes := make([][]byte, 0) + oldHash := make([]byte, len(bn.hash)) + copy(oldHash, bn.hash) + hasBeenModified := false - if bn.children[childPos] == nil { - return false, bn, emptyHashes, nil - } + for childPos := range dataForRemoval { + if len(dataForRemoval[childPos]) == 0 { + continue + } + err = resolveIfCollapsed(bn, byte(childPos), db) + if err != nil { + return false, nil, emptyHashes, err + } + + if bn.children[childPos] == nil { + continue + } + + dirty, newNode, oldHashes, err := bn.children[childPos].delete(dataForRemoval[childPos], db) + if err != nil { + return false, bn, emptyHashes, err + } + if !dirty { + continue + } + + hasBeenModified = true + err = bn.setNewChild(byte(childPos), newNode) + if err != nil { + return false, nil, emptyHashes, err + } - dirty, newNode, oldHashes, err := bn.children[childPos].delete(key, db) - if !dirty || err != nil { - return false, bn, emptyHashes, err + modifiedHashes = append(modifiedHashes, oldHashes...) } - if !bn.dirty { - oldHashes = append(oldHashes, bn.hash) + if !hasBeenModified { + return false, bn, emptyHashes, nil } - err = bn.setNewChild(childPos, newNode) - if err != nil { - return false, nil, emptyHashes, err + if len(oldHash) != 0 { + modifiedHashes = append(modifiedHashes, oldHash) } + bn.dirty = true numChildren, pos := getChildPosition(bn) - + if numChildren == 0 { + return true, nil, modifiedHashes, nil + } if numChildren == 1 { err = resolveIfCollapsed(bn, byte(pos), db) if err != nil { @@ -569,21 +667,19 @@ func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, } var newChildHash bool - newNode, newChildHash, err = bn.children[pos].reduceNode(pos) + newNode, newChildHash, err := bn.children[pos].reduceNode(pos) if err != nil { return false, nil, emptyHashes, err } if newChildHash && !bn.children[pos].isDirty() { - oldHashes = append(oldHashes, bn.children[pos].getHash()) + modifiedHashes = append(modifiedHashes, bn.children[pos].getHash()) } - return true, newNode, oldHashes, nil + return true, newNode, modifiedHashes, nil } - bn.dirty = dirty - - return true, bn, oldHashes, nil + return true, bn, modifiedHashes, nil } func (bn *branchNode) setNewChild(childPos byte, newNode node) error { diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index ba45634af24..dbf6e57026a 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -622,7 +622,8 @@ func TestBranchNode_insert(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) nodeKey := []byte{0, 2, 3} - newBn, _, err := bn.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(nodeKey), "dogs")} + newBn, _, err := bn.insert(data, nil) assert.NotNil(t, newBn) assert.Nil(t, err) @@ -636,7 +637,8 @@ func TestBranchNode_insertEmptyKey(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("", "dogs"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion("", "dogs")} + newBn, _, err := bn.insert(data, nil) assert.Equal(t, ErrValueTooShort, err) assert.Nil(t, newBn) } @@ -646,7 +648,8 @@ func TestBranchNode_insertChildPosOutOfRange(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion("dog", "dogs")} + newBn, _, err := bn.insert(data, nil) assert.Equal(t, ErrChildPosOutOfRange, err) assert.Nil(t, newBn) } @@ -662,7 +665,8 @@ func TestBranchNode_insertCollapsedNode(t *testing.T) { _ = bn.setHash() _ = bn.commitDirty(0, 5, db, db) - newBn, _, err := collapsedBn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(key), "dogs")} + newBn, _, err := collapsedBn.insert(data, db) assert.NotNil(t, newBn) assert.Nil(t, err) @@ -684,7 +688,8 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) { lnHash := ln.getHash() expectedHashes := [][]byte{lnHash, bnHash} - newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(key), "dogs")} + newNode, oldHashes, err := bn.insert(data, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -702,7 +707,8 @@ func TestBranchNode_insertInStoredBnOnNilPos(t *testing.T) { bnHash := bn.getHash() expectedHashes := [][]byte{bnHash} - newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(key), "dogs")} + newNode, oldHashes, err := bn.insert(data, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -715,7 +721,8 @@ func TestBranchNode_insertInDirtyBnOnNilPos(t *testing.T) { nilChildPos := byte(11) key := append([]byte{nilChildPos}, []byte("dog")...) - newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(key), "dogs")} + newNode, oldHashes, err := bn.insert(data, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -728,7 +735,8 @@ func TestBranchNode_insertInDirtyBnOnExistingPos(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(key), "dogs")} + newNode, oldHashes, err := bn.insert(data, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -739,7 +747,8 @@ func TestBranchNode_insertInNilNode(t *testing.T) { var bn *branchNode - newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("key", "dogs"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion("key", "dogs")} + newBn, _, err := bn.insert(data, nil) assert.True(t, errors.Is(err, ErrNilBranchNode)) assert.Nil(t, newBn) } @@ -756,8 +765,9 @@ func TestBranchNode_delete(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) + data := []core.TrieData{{Key: key}} - dirty, newBn, _, err := bn.delete(key, nil) + dirty, newBn, _, err := bn.delete(data, nil) assert.True(t, dirty) assert.Nil(t, err) @@ -780,7 +790,8 @@ func TestBranchNode_deleteFromStoredBn(t *testing.T) { lnHash := ln.getHash() expectedHashes := [][]byte{lnHash, bnHash} - dirty, _, oldHashes, err := bn.delete(lnKey, db) + data := []core.TrieData{{Key: lnKey}} + dirty, _, oldHashes, err := bn.delete(data, db) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -792,8 +803,9 @@ func TestBranchNode_deleteFromDirtyBn(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) childPos := byte(2) lnKey := append([]byte{childPos}, []byte("dog")...) + data := []core.TrieData{{Key: lnKey}} - dirty, _, oldHashes, err := bn.delete(lnKey, nil) + dirty, _, oldHashes, err := bn.delete(data, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -805,8 +817,9 @@ func TestBranchNode_deleteEmptyNode(t *testing.T) { bn := emptyDirtyBranchNode() childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) + data := []core.TrieData{{Key: key}} - dirty, newBn, _, err := bn.delete(key, nil) + dirty, newBn, _, err := bn.delete(data, nil) assert.False(t, dirty) assert.True(t, errors.Is(err, ErrEmptyBranchNode)) assert.Nil(t, newBn) @@ -818,8 +831,9 @@ func TestBranchNode_deleteNilNode(t *testing.T) { var bn *branchNode childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) + data := []core.TrieData{{Key: key}} - dirty, newBn, _, err := bn.delete(key, nil) + dirty, newBn, _, err := bn.delete(data, nil) assert.False(t, dirty) assert.True(t, errors.Is(err, ErrNilBranchNode)) assert.Nil(t, newBn) @@ -832,8 +846,9 @@ func TestBranchNode_deleteNonexistentNodeFromChild(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("butterfly")...) + data := []core.TrieData{{Key: key}} - dirty, newBn, _, err := bn.delete(key, nil) + dirty, newBn, _, err := bn.delete(data, nil) assert.False(t, dirty) assert.Nil(t, err) assert.Equal(t, bn, newBn) @@ -843,8 +858,9 @@ func TestBranchNode_deleteEmptykey(t *testing.T) { t.Parallel() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + data := []core.TrieData{{Key: []byte{}}} - dirty, newBn, _, err := bn.delete([]byte{}, nil) + dirty, newBn, _, err := bn.delete(data, nil) assert.False(t, dirty) assert.Equal(t, ErrValueTooShort, err) assert.Nil(t, newBn) @@ -860,8 +876,9 @@ func TestBranchNode_deleteCollapsedNode(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) + data := []core.TrieData{{Key: key}} - dirty, newBn, _, err := collapsedBn.delete(key, db) + dirty, newBn, _, err := collapsedBn.delete(data, db) assert.True(t, dirty) assert.Nil(t, err) @@ -885,7 +902,8 @@ func TestBranchNode_deleteAndReduceBn(t *testing.T) { ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string(key), "dog"), bn.marsh, bn.hasher) key = append([]byte{secondChildPos}, []byte("doe")...) - dirty, newBn, _, err := bn.delete(key, nil) + data := []core.TrieData{{Key: key}} + dirty, newBn, _, err := bn.delete(data, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, ln, newBn) @@ -1453,7 +1471,7 @@ func TestBranchNode_VerifyChildrenVersionIsSetCorrectlyAfterInsertAndDelete(t *t Value: []byte("value"), Version: 0, } - newBn, _, err := bn.insert(data, &testscommon.MemDbMock{}) + newBn, _, err := bn.insert([]core.TrieData{data}, &testscommon.MemDbMock{}) assert.Nil(t, err) assert.Nil(t, newBn.(*branchNode).ChildrenVersion) }) @@ -1465,8 +1483,9 @@ func TestBranchNode_VerifyChildrenVersionIsSetCorrectlyAfterInsertAndDelete(t *t bn.ChildrenVersion = make([]byte, nrOfChildren) bn.ChildrenVersion[2] = byte(core.AutoBalanceEnabled) childKey := []byte{2, 'd', 'o', 'g'} + data := []core.TrieData{{Key: childKey}} - _, newBn, _, err := bn.delete(childKey, &testscommon.MemDbMock{}) + _, newBn, _, err := bn.delete(data, &testscommon.MemDbMock{}) assert.Nil(t, err) assert.Nil(t, newBn.(*branchNode).ChildrenVersion) }) @@ -1511,3 +1530,474 @@ func TestBranchNode_revertChildrenVersionSliceIfNeeded(t *testing.T) { assert.Nil(t, bn.ChildrenVersion) }) } + +func TestBranchNode_splitDataForChildren(t *testing.T) { + t.Parallel() + + t.Run("empty array returns err", func(t *testing.T) { + t.Parallel() + + var newData []core.TrieData + data, err := splitDataForChildren(newData) + assert.Nil(t, data) + assert.Equal(t, ErrValueTooShort, err) + }) + t.Run("empty key returns err", func(t *testing.T) { + t.Parallel() + + newData := []core.TrieData{ + {Key: []byte{2, 3, 4}}, + {Key: []byte{}}, + } + + data, err := splitDataForChildren(newData) + assert.Nil(t, data) + assert.Equal(t, ErrValueTooShort, err) + }) + t.Run("child pos out of range returns err", func(t *testing.T) { + t.Parallel() + + newData := []core.TrieData{ + {Key: []byte{2, 3, 4}}, + {Key: []byte{17, 2, 3}}, + } + + data, err := splitDataForChildren(newData) + assert.Nil(t, data) + assert.Equal(t, ErrChildPosOutOfRange, err) + }) + t.Run("one child on last pos should work", func(t *testing.T) { + t.Parallel() + + childPos := byte(16) + newData := []core.TrieData{ + {Key: []byte{childPos}}, + } + + data, err := splitDataForChildren(newData) + assert.True(t, len(data) == nrOfChildren) + assert.Nil(t, err) + assert.Equal(t, newData, data[childPos]) + }) + t.Run("all children have same pos should work", func(t *testing.T) { + t.Parallel() + + childPos := byte(2) + newData := []core.TrieData{ + {Key: []byte{childPos, 3, 4}}, + {Key: []byte{childPos, 5, 6}}, + {Key: []byte{childPos, 7, 8}}, + } + + data, err := splitDataForChildren(newData) + assert.True(t, len(data) == nrOfChildren) + assert.Nil(t, err) + for i := range data { + if i != int(childPos) { + assert.Nil(t, data[i]) + continue + } + assert.Equal(t, newData, data[i]) + } + }) + t.Run("all children have different pos should work", func(t *testing.T) { + t.Parallel() + + childPos1 := byte(2) + childPos2 := byte(6) + childPos3 := byte(13) + newData := []core.TrieData{ + {Key: []byte{childPos1, 3, 4}}, + {Key: []byte{childPos2, 5, 6}}, + {Key: []byte{childPos3, 7, 8}}, + } + + data, err := splitDataForChildren(newData) + assert.True(t, len(data) == nrOfChildren) + assert.Nil(t, err) + for i := range data { + if i == int(childPos1) { + assert.Equal(t, []core.TrieData{newData[0]}, data[i]) + } else if i == int(childPos2) { + assert.Equal(t, []core.TrieData{newData[1]}, data[i]) + } else if i == int(childPos3) { + assert.Equal(t, []core.TrieData{newData[2]}, data[i]) + } else { + assert.Nil(t, data[i]) + } + } + }) + t.Run("some children have same pos should work", func(t *testing.T) { + t.Parallel() + + childPos1 := byte(2) + childPos2 := byte(6) + newData := []core.TrieData{ + {Key: []byte{childPos1, 3, 4}}, + {Key: []byte{childPos1, 5, 6}}, + {Key: []byte{childPos2, 7, 8}}, + } + + data, err := splitDataForChildren(newData) + assert.True(t, len(data) == nrOfChildren) + assert.Nil(t, err) + for i := range data { + if i == int(childPos1) { + assert.Equal(t, []core.TrieData{newData[0], newData[1]}, data[i]) + } else if i == int(childPos2) { + assert.Equal(t, []core.TrieData{newData[2]}, data[i]) + } else { + assert.Nil(t, data[i]) + } + } + }) + t.Run("child pos is removed from key", func(t *testing.T) { + t.Parallel() + + childPos := byte(2) + newData := []core.TrieData{ + {Key: []byte{childPos, 3, 4}}, + {Key: []byte{childPos, 5, 6}}, + } + + data, err := splitDataForChildren(newData) + assert.True(t, len(data) == nrOfChildren) + assert.Nil(t, err) + assert.Equal(t, 2, len(data[childPos])) + assert.Equal(t, []byte{3, 4}, newData[0].Key) + assert.Equal(t, []byte{5, 6}, newData[1].Key) + }) +} + +func TestBranchNode_insertOnNilChild(t *testing.T) { + t.Parallel() + + t.Run("empty data should err", func(t *testing.T) { + t.Parallel() + + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + var data []core.TrieData + modifiedHashes, err := bn.insertOnNilChild(data, 0, nil) + expectedNumTrieNodesChanged := 0 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + assert.Equal(t, ErrValueTooShort, err) + }) + t.Run("insert one child in !dirty node", func(t *testing.T) { + t.Parallel() + + db := testscommon.NewMemDbMock() + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + err := bn.commitDirty(0, 5, db, db) + assert.Nil(t, err) + assert.False(t, bn.dirty) + originalHash := bn.getHash() + assert.True(t, len(originalHash) > 0) + newData := []core.TrieData{ + { + Key: []byte{1, 2, 3}, + Value: []byte("value"), + Version: core.AutoBalanceEnabled, + }, + } + childPos := byte(0) + modifiedHashes, err := bn.insertOnNilChild(newData, childPos, db) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 1 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + assert.Equal(t, originalHash, modifiedHashes[0]) + assert.True(t, bn.dirty) + assert.NotNil(t, bn.children[childPos]) + assert.Equal(t, byte(core.AutoBalanceEnabled), bn.ChildrenVersion[childPos]) + }) + t.Run("insert one child in dirty node", func(t *testing.T) { + t.Parallel() + + db := testscommon.NewMemDbMock() + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + assert.True(t, bn.dirty) + newData := []core.TrieData{ + { + Key: []byte{1, 2, 3}, + Value: []byte("value"), + Version: core.AutoBalanceEnabled, + }, + } + childPos := byte(0) + modifiedHashes, err := bn.insertOnNilChild(newData, childPos, db) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 0 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + assert.True(t, bn.dirty) + assert.NotNil(t, bn.children[childPos]) + assert.Equal(t, byte(core.AutoBalanceEnabled), bn.ChildrenVersion[childPos]) + + }) + t.Run("insert multiple children", func(t *testing.T) { + t.Parallel() + + db := testscommon.NewMemDbMock() + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + newData := []core.TrieData{ + { + Key: []byte{1, 2, 3}, + Value: []byte("value"), + Version: core.AutoBalanceEnabled, + }, + { + Key: []byte{1, 2, 4}, + Value: []byte("value"), + Version: core.AutoBalanceEnabled, + }, + } + childPos := byte(0) + modifiedHashes, err := bn.insertOnNilChild(newData, childPos, db) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 0 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + assert.True(t, bn.dirty) + assert.NotNil(t, bn.children[childPos]) + assert.Equal(t, byte(core.AutoBalanceEnabled), bn.ChildrenVersion[childPos]) + _, ok := bn.children[0].(*extensionNode) + assert.True(t, ok) + }) +} + +func TestBranchNode_insertOnExistingChild(t *testing.T) { + t.Parallel() + + t.Run("insert on existing child multiple children", func(t *testing.T) { + t.Parallel() + + childPos := byte(2) + db := testscommon.NewMemDbMock() + var children [nrOfChildren]node + marshaller, hasher := getTestMarshalizerAndHasher() + children[2], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 5}), "dog"), marshaller, hasher) + children[6], _ = newLeafNode(getTrieDataWithDefaultVersion("doe", "doe"), marshaller, hasher) + bn, _ := newBranchNode(marshaller, hasher) + bn.children = children + newData := []core.TrieData{ + { + Key: []byte{1, 2, 3}, + Value: []byte("value"), + Version: core.AutoBalanceEnabled, + }, + { + Key: []byte{1, 2, 4}, + Value: []byte("value"), + Version: core.AutoBalanceEnabled, + }, + } + err := bn.commitDirty(0, 5, db, db) + assert.Nil(t, err) + assert.False(t, bn.dirty) + originalHash := bn.getHash() + assert.True(t, len(originalHash) > 0) + originalChildHash := bn.children[childPos].getHash() + assert.True(t, len(originalChildHash) > 0) + + dirty, modifiedHashes, err := bn.insertOnExistingChild(newData, childPos, db) + assert.Nil(t, err) + assert.True(t, dirty) + assert.True(t, bn.dirty) + expectedNumTrieNodesChanged := 2 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + assert.Equal(t, originalChildHash, modifiedHashes[0]) + assert.Equal(t, originalHash, modifiedHashes[1]) + _, ok := bn.children[childPos].(*extensionNode) + assert.True(t, ok) + }) + t.Run("insert on existing child same node", func(t *testing.T) { + t.Parallel() + + childPos := byte(2) + db := testscommon.NewMemDbMock() + var children [nrOfChildren]node + marshaller, hasher := getTestMarshalizerAndHasher() + key := []byte{1, 2, 5} + value := "dog" + children[2], _ = newLeafNode(getTrieDataWithDefaultVersion(string(key), value), marshaller, hasher) + children[6], _ = newLeafNode(getTrieDataWithDefaultVersion("doe", "doe"), marshaller, hasher) + bn, _ := newBranchNode(marshaller, hasher) + bn.children = children + newData := []core.TrieData{ + { + Key: key, + Value: []byte(value), + Version: core.NotSpecified, + }, + } + err := bn.commitDirty(0, 5, db, db) + assert.Nil(t, err) + assert.False(t, bn.dirty) + + dirty, modifiedHashes, err := bn.insertOnExistingChild(newData, childPos, db) + assert.Nil(t, err) + assert.False(t, dirty) + assert.False(t, bn.dirty) + expectedNumTrieNodesChanged := 0 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + _, ok := bn.children[childPos].(*leafNode) + assert.True(t, ok) + }) +} + +func TestBranchNode_insertBatch(t *testing.T) { + t.Parallel() + + db := testscommon.NewMemDbMock() + var children [nrOfChildren]node + marshaller, hasher := getTestMarshalizerAndHasher() + children[2], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{3, 4, 5}), "dog"), marshaller, hasher) + children[6], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{7, 8, 9}), "doe"), marshaller, hasher) + bn, _ := newBranchNode(marshaller, hasher) + bn.children = children + + newData := []core.TrieData{ + { + Key: []byte{1, 2, 3}, + Value: []byte("value1"), + }, + { + Key: []byte{6, 7, 8, 16}, + Value: []byte("value2"), + }, + { + Key: []byte{6, 10, 11, 16}, + Value: []byte("value3"), + }, + { + Key: []byte{16}, + Value: []byte("value4"), + }, + } + err := bn.commitDirty(0, 5, db, db) + assert.Nil(t, err) + assert.False(t, bn.dirty) + + newNode, modifiedHashes, err := bn.insert(newData, db) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 2 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + assert.True(t, newNode.isDirty()) + + bn, ok := newNode.(*branchNode) + assert.True(t, ok) + assert.True(t, bn.dirty) + assert.False(t, bn.children[2].isDirty()) + _, ok = bn.children[1].(*leafNode) + assert.True(t, ok) + _, ok = bn.children[6].(*branchNode) + assert.True(t, ok) + _, ok = bn.children[16].(*leafNode) + assert.True(t, ok) + +} + +func getNewBn() *branchNode { + marsh, hasher := getTestMarshalizerAndHasher() + var children [nrOfChildren]node + childBn, _ := newBranchNode(marsh, hasher) + childBn.children[1], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{3, 4, 5}), "dog"), marsh, hasher) + childBn.children[3], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{7, 8, 9}), "doe"), marsh, hasher) + + children[4], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{3, 4, 5}), "dog"), marsh, hasher) + children[7], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{7, 8, 9}), "doe"), marsh, hasher) + children[9] = childBn + + bn, _ := newBranchNode(marsh, hasher) + bn.children = children + _ = bn.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + return bn +} + +func TestBranchNode_deleteBatch(t *testing.T) { + t.Parallel() + + t.Run("delete multiple children", func(t *testing.T) { + t.Parallel() + + bn := getNewBn() + assert.False(t, bn.dirty) + + data := []core.TrieData{ + { + Key: []byte{4, 3, 4, 5}, + }, + { + Key: []byte{9, 1, 3, 4, 5}, + }, + } + + dirty, newNode, modifiedHashes, err := bn.delete(data, testscommon.NewMemDbMock()) + assert.Nil(t, err) + assert.True(t, dirty) + assert.True(t, newNode.isDirty()) + expectedNumTrieNodesChanged := 5 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + bn, ok := newNode.(*branchNode) + assert.True(t, ok) + + assert.Nil(t, bn.children[4]) + assert.False(t, bn.children[7].isDirty()) + _, ok = bn.children[9].(*leafNode) + assert.True(t, ok) + + }) + t.Run("reduce node after delete batch", func(t *testing.T) { + t.Parallel() + + bn := getNewBn() + assert.False(t, bn.dirty) + + data := []core.TrieData{ + { + Key: []byte{4, 3, 4, 5}, + }, + { + Key: []byte{7, 7, 8, 9}, + }, + { + Key: []byte{9, 1, 3, 4, 5}, + }, + } + + dirty, newNode, modifiedHashes, err := bn.delete(data, testscommon.NewMemDbMock()) + assert.Nil(t, err) + assert.True(t, dirty) + assert.True(t, newNode.isDirty()) + expectedNumTrieNodesChanged := 6 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + ln, ok := newNode.(*leafNode) + assert.True(t, ok) + assert.Equal(t, []byte{9, 3, 7, 8, 9}, ln.Key) + }) + t.Run("delete all children", func(t *testing.T) { + t.Parallel() + + bn := getNewBn() + assert.False(t, bn.dirty) + + data := []core.TrieData{ + { + Key: []byte{4, 3, 4, 5}, + }, + { + Key: []byte{7, 7, 8, 9}, + }, + { + Key: []byte{9, 1, 3, 4, 5}, + }, + { + Key: []byte{9, 3, 7, 8, 9}, + }, + } + + dirty, newNode, modifiedHashes, err := bn.delete(data, testscommon.NewMemDbMock()) + assert.Nil(t, err) + assert.True(t, dirty) + assert.Nil(t, newNode) + expectedNumTrieNodesChanged := 6 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + }) +} diff --git a/trie/extensionNode.go b/trie/extensionNode.go index 9c05caaeebe..a1a2f08d383 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -334,7 +334,7 @@ func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (n return en.child, key, nil } -func (en *extensionNode) insert(newData core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) { +func (en *extensionNode) insert(newData []core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := en.isEmptyOrNil() if err != nil { @@ -345,7 +345,7 @@ func (en *extensionNode) insert(newData core.TrieData, db common.TrieStorageInte return nil, emptyHashes, err } - keyMatchLen := prefixLen(newData.Key, en.Key) + keyMatchLen, index := getMinKeyMatchLen(newData, en.Key) // If the whole key matches, keep this extension node as is // and only update the value. @@ -354,16 +354,50 @@ func (en *extensionNode) insert(newData core.TrieData, db common.TrieStorageInte } // Otherwise branch out at the index where they differ. - return en.insertInNewBn(newData, keyMatchLen) + return en.insertInNewBn(newData, db, keyMatchLen, index) } -func (en *extensionNode) insertInSameEn(newData core.TrieData, keyMatchLen int, db common.TrieStorageInteractor) (node, [][]byte, error) { - newData.Key = newData.Key[keyMatchLen:] +func getMinKeyMatchLen(newData []core.TrieData, enKey []byte) (int, int) { + minKeyMatchLen := len(enKey) + index := 0 + for i, data := range newData { + if minKeyMatchLen == 0 { + return 0, index + } + matchLen := prefixLen(data.Key, enKey) + if matchLen < minKeyMatchLen { + minKeyMatchLen = matchLen + index = i + } + } + + return minKeyMatchLen, index +} + +func removeCommonPrefix(newData []core.TrieData, prefixLen int) error { + for i := range newData { + if len(newData[i].Key) < prefixLen { + return ErrValueTooShort + } + newData[i].Key = newData[i].Key[prefixLen:] + } + + return nil +} + +func (en *extensionNode) insertInSameEn(newData []core.TrieData, keyMatchLen int, db common.TrieStorageInteractor) (node, [][]byte, error) { + for i := range newData { + newData[i].Key = newData[i].Key[keyMatchLen:] + } newNode, oldHashes, err := en.child.insert(newData, db) - if check.IfNil(newNode) || err != nil { + if err != nil { return nil, [][]byte{}, err } + if check.IfNil(newNode) { + return nil, [][]byte{}, nil + } + if !en.dirty { oldHashes = append(oldHashes, en.hash) } @@ -376,7 +410,7 @@ func (en *extensionNode) insertInSameEn(newData core.TrieData, keyMatchLen int, return newEn, oldHashes, nil } -func (en *extensionNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, [][]byte, error) { +func (en *extensionNode) insertInNewBn(newData []core.TrieData, db common.TrieStorageInteractor, keyMatchLen int, index int) (node, [][]byte, error) { oldHash := make([][]byte, 0) if !en.dirty { oldHash = append(oldHash, en.hash) @@ -388,7 +422,7 @@ func (en *extensionNode) insertInNewBn(newData core.TrieData, keyMatchLen int) ( } oldChildPos := en.Key[keyMatchLen] - newChildPos := newData.Key[keyMatchLen] + newChildPos := newData[index].Key[keyMatchLen] if childPosOutOfRange(oldChildPos) || childPosOutOfRange(newChildPos) { return nil, [][]byte{}, ErrChildPosOutOfRange } @@ -398,16 +432,33 @@ func (en *extensionNode) insertInNewBn(newData core.TrieData, keyMatchLen int) ( return nil, [][]byte{}, err } - err = en.insertNewChildInBn(bn, newData, newChildPos, keyMatchLen) + newChild := newData[index] + newData = append(newData[:index], newData[index+1:]...) + + err = en.insertNewChildInBn(bn, newChild, newChildPos, keyMatchLen) + if err != nil { + return nil, [][]byte{}, err + } + + err = removeCommonPrefix(newData, keyMatchLen) if err != nil { return nil, [][]byte{}, err } + var newNode node + newNode = bn + if len(newData) != 0 { + newNode, _, err = bn.insert(newData, db) + if err != nil { + return nil, [][]byte{}, err + } + } + if keyMatchLen == 0 { - return bn, oldHash, nil + return newNode, oldHash, nil } - newEn, err := newExtensionNode(en.Key[:keyMatchLen], bn, en.marsh, en.hasher) + newEn, err := newExtensionNode(en.Key[:keyMatchLen], newNode, en.marsh, en.hasher) if err != nil { return nil, [][]byte{}, err } @@ -437,30 +488,40 @@ func (en *extensionNode) insertOldChildInBn(bn *branchNode, oldChildPos byte, ke return nil } -func (en *extensionNode) insertNewChildInBn(bn *branchNode, newData core.TrieData, newChildPos byte, keyMatchLen int) error { - newData.Key = newData.Key[keyMatchLen+1:] +func (en *extensionNode) insertNewChildInBn(bn *branchNode, newChild core.TrieData, newChildPos byte, keyMatchLen int) error { + newChild.Key = newChild.Key[keyMatchLen+1:] - newLeaf, err := newLeafNode(newData, en.marsh, en.hasher) + newLeaf, err := newLeafNode(newChild, en.marsh, en.hasher) if err != nil { return err } bn.children[newChildPos] = newLeaf - bn.setVersionForChild(newData.Version, newChildPos) + bn.setVersionForChild(newChild.Version, newChildPos) return nil } -func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { +func (en *extensionNode) getDataWithMatchingPrefix(data []core.TrieData) []core.TrieData { + dataWithMatchingKey := make([]core.TrieData, 0) + for _, d := range data { + if len(en.Key) == prefixLen(d.Key, en.Key) { + d.Key = d.Key[len(en.Key):] + dataWithMatchingKey = append(dataWithMatchingKey, d) + } + } + + return dataWithMatchingKey +} + +func (en *extensionNode) delete(data []core.TrieData, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := en.isEmptyOrNil() if err != nil { return false, nil, emptyHashes, fmt.Errorf("delete error %w", err) } - if len(key) == 0 { - return false, nil, emptyHashes, ErrValueTooShort - } - keyMatchLen := prefixLen(key, en.Key) - if keyMatchLen < len(en.Key) { + + dataWithMatchingKey := en.getDataWithMatchingPrefix(data) + if len(dataWithMatchingKey) == 0 { return false, en, emptyHashes, nil } err = resolveIfCollapsed(en, 0, db) @@ -468,11 +529,15 @@ func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bo return false, nil, emptyHashes, err } - dirty, newNode, oldHashes, err := en.child.delete(key[len(en.Key):], db) - if !dirty || err != nil { + dirty, newNode, oldHashes, err := en.child.delete(dataWithMatchingKey, db) + if err != nil { return false, en, emptyHashes, err } + if !dirty { + return false, en, emptyHashes, nil + } + if !en.dirty { oldHashes = append(oldHashes, en.hash) } @@ -505,7 +570,6 @@ func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bo return true, n, oldHashes, nil case nil: - log.Warn("nil child after deleting from extension node") return true, nil, oldHashes, nil default: return false, nil, oldHashes, ErrInvalidNode diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index 45a3adeffc4..1db5e0e076e 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -493,7 +493,8 @@ func TestExtensionNode_insert(t *testing.T) { en, _ := getEnAndCollapsedEn() key := []byte{100, 15, 5, 6} - newNode, _, err := en.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(key), "dogs")} + newNode, _, err := en.insert(data, nil) assert.NotNil(t, newNode) assert.Nil(t, err) @@ -511,7 +512,8 @@ func TestExtensionNode_insertCollapsedNode(t *testing.T) { _ = en.setHash() _ = en.commitDirty(0, 5, db, db) - newNode, _, err := collapsedEn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(key), "dogs")} + newNode, _, err := collapsedEn.insert(data, db) assert.NotNil(t, newNode) assert.Nil(t, err) @@ -533,7 +535,8 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) { bnHash := bn.getHash() expectedHashes := [][]byte{bnHash, enHash} - newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(key), "dogs")} + newNode, oldHashes, err := en.insert(data, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -551,7 +554,8 @@ func TestExtensionNode_insertInStoredEnDifferentKey(t *testing.T) { _ = en.commitDirty(0, 5, db, db) expectedHashes := [][]byte{en.getHash()} - newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), db) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(nodeKey), "dogs")} + newNode, oldHashes, err := en.insert(data, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -563,7 +567,8 @@ func TestExtensionNode_insertInDirtyEnSameKey(t *testing.T) { en, _ := getEnAndCollapsedEn() nodeKey := []byte{100, 11, 12} - newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(nodeKey), "dogs")} + newNode, oldHashes, err := en.insert(data, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -577,7 +582,8 @@ func TestExtensionNode_insertInDirtyEnDifferentKey(t *testing.T) { en, _ := newExtensionNode(enKey, bn, bn.marsh, bn.hasher) nodeKey := []byte{11, 12} - newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion(string(nodeKey), "dogs")} + newNode, oldHashes, err := en.insert(data, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -588,7 +594,8 @@ func TestExtensionNode_insertInNilNode(t *testing.T) { var en *extensionNode - newNode, _, err := en.insert(getTrieDataWithDefaultVersion("key", "val"), nil) + data := []core.TrieData{getTrieDataWithDefaultVersion("key", "val")} + newNode, _, err := en.insert(data, nil) assert.Nil(t, newNode) assert.True(t, errors.Is(err, ErrNilExtensionNode)) assert.Nil(t, newNode) @@ -608,8 +615,9 @@ func TestExtensionNode_delete(t *testing.T) { val, _, _ := en.tryGet(key, 0, nil) assert.Equal(t, dogBytes, val) + data := []core.TrieData{{Key: key}} - dirty, _, _, err := en.delete(key, nil) + dirty, _, _, err := en.delete(data, nil) assert.True(t, dirty) assert.Nil(t, err) val, _, _ = en.tryGet(key, 0, nil) @@ -632,8 +640,9 @@ func TestExtensionNode_deleteFromStoredEn(t *testing.T) { bn, key, _ := en.getNext(key, db) ln, _, _ := bn.getNext(key, db) expectedHashes := [][]byte{ln.getHash(), bn.getHash(), en.getHash()} + data := []core.TrieData{{Key: lnPathKey}} - dirty, _, oldHashes, err := en.delete(lnPathKey, db) + dirty, _, oldHashes, err := en.delete(data, db) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -644,8 +653,9 @@ func TestExtensionNode_deleteFromDirtyEn(t *testing.T) { en, _ := getEnAndCollapsedEn() lnKey := []byte{100, 2, 100, 111, 103} + data := []core.TrieData{{Key: lnKey}} - dirty, _, oldHashes, err := en.delete(lnKey, nil) + dirty, _, oldHashes, err := en.delete(data, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -655,8 +665,9 @@ func TestExtendedNode_deleteEmptyNode(t *testing.T) { t.Parallel() en := &extensionNode{} + data := []core.TrieData{{Key: []byte("dog")}} - dirty, newNode, _, err := en.delete([]byte("dog"), nil) + dirty, newNode, _, err := en.delete(data, nil) assert.False(t, dirty) assert.True(t, errors.Is(err, ErrEmptyExtensionNode)) assert.Nil(t, newNode) @@ -666,24 +677,14 @@ func TestExtensionNode_deleteNilNode(t *testing.T) { t.Parallel() var en *extensionNode + data := []core.TrieData{{Key: []byte("dog")}} - dirty, newNode, _, err := en.delete([]byte("dog"), nil) + dirty, newNode, _, err := en.delete(data, nil) assert.False(t, dirty) assert.True(t, errors.Is(err, ErrNilExtensionNode)) assert.Nil(t, newNode) } -func TestExtensionNode_deleteEmptykey(t *testing.T) { - t.Parallel() - - en, _ := getEnAndCollapsedEn() - - dirty, newNode, _, err := en.delete([]byte{}, nil) - assert.False(t, dirty) - assert.Equal(t, ErrValueTooShort, err) - assert.Nil(t, newNode) -} - func TestExtensionNode_deleteCollapsedNode(t *testing.T) { t.Parallel() @@ -700,8 +701,9 @@ func TestExtensionNode_deleteCollapsedNode(t *testing.T) { val, _, _ := en.tryGet(key, 0, db) assert.Equal(t, []byte("dog"), val) + data := []core.TrieData{{Key: key}} - dirty, newNode, _, err := collapsedEn.delete(key, db) + dirty, newNode, _, err := collapsedEn.delete(data, db) assert.True(t, dirty) assert.Nil(t, err) val, _, _ = newNode.tryGet(key, 0, db) @@ -1078,3 +1080,329 @@ func TestExtensionNode_getVersion(t *testing.T) { assert.Nil(t, err) }) } + +func Test_getMinKeyMatchLen(t *testing.T) { + t.Parallel() + + t.Run("same key", func(t *testing.T) { + t.Parallel() + + newData := []core.TrieData{ + { + Key: []byte("dog"), + }, + } + keyMatchLen, index := getMinKeyMatchLen(newData, []byte("dog")) + assert.Equal(t, 3, keyMatchLen) + assert.Equal(t, 0, index) + }) + t.Run("first key is min", func(t *testing.T) { + t.Parallel() + + newData := []core.TrieData{ + { + Key: []byte("dog"), + }, + { + Key: []byte("doge"), + }, + } + keyMatchLen, index := getMinKeyMatchLen(newData, []byte("doe")) + assert.Equal(t, 2, keyMatchLen) + assert.Equal(t, 0, index) + + }) + t.Run("last key is min", func(t *testing.T) { + t.Parallel() + + newData := []core.TrieData{ + { + Key: []byte("doge"), + }, + { + Key: []byte("dad"), + }, + } + keyMatchLen, index := getMinKeyMatchLen(newData, []byte("doe")) + assert.Equal(t, 1, keyMatchLen) + assert.Equal(t, 1, index) + + }) + t.Run("no match", func(t *testing.T) { + t.Parallel() + + newData := []core.TrieData{ + { + Key: []byte("doge"), + }, + { + Key: []byte("dog"), + }, + } + keyMatchLen, index := getMinKeyMatchLen(newData, []byte("cat")) + assert.Equal(t, 0, keyMatchLen) + assert.Equal(t, 0, index) + }) +} + +func Test_removeCommonPrefix(t *testing.T) { + t.Parallel() + + t.Run("no common prefix", func(t *testing.T) { + t.Parallel() + + newData := []core.TrieData{ + { + Key: []byte("doge"), + }, + { + Key: []byte("cat"), + }, + } + + err := removeCommonPrefix(newData, 0) + assert.Nil(t, err) + assert.Equal(t, []byte("doge"), newData[0].Key) + assert.Equal(t, []byte("cat"), newData[1].Key) + }) + t.Run("remove prefix from all", func(t *testing.T) { + t.Parallel() + + newData := []core.TrieData{ + { + Key: []byte("doge"), + }, + { + Key: []byte("doe"), + }, + } + + err := removeCommonPrefix(newData, 2) + assert.Nil(t, err) + assert.Equal(t, []byte("ge"), newData[0].Key) + assert.Equal(t, []byte("e"), newData[1].Key) + + }) + t.Run("one key is less than the prefix", func(t *testing.T) { + t.Parallel() + + newData := []core.TrieData{ + { + Key: []byte("doge"), + }, + { + Key: []byte("do"), + }, + } + + err := removeCommonPrefix(newData, 3) + assert.Equal(t, ErrValueTooShort, err) + }) +} + +func getEn() *extensionNode { + marsh, hasher := getTestMarshalizerAndHasher() + var children [nrOfChildren]node + children[4], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{3, 4, 5}), "dog"), marsh, hasher) + children[7], _ = newLeafNode(getTrieDataWithDefaultVersion(string([]byte{7, 8, 9}), "doe"), marsh, hasher) + bn, _ := newBranchNode(marsh, hasher) + bn.children = children + en, _ := newExtensionNode([]byte{1, 2}, bn, marsh, hasher) + return en +} + +func TestExtensionNode_insertInSameEn(t *testing.T) { + t.Parallel() + + t.Run("insert same data", func(t *testing.T) { + t.Parallel() + + en := getEn() + err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.Nil(t, err) + + data := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 2, 4, 3, 4, 5}), "dog"), + getTrieDataWithDefaultVersion(string([]byte{1, 2, 7, 7, 8, 9}), "doe"), + } + + newNode, modifiedHashes, err := en.insert(data, nil) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 0 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + assert.Nil(t, newNode) + assert.False(t, en.dirty) + }) + t.Run("insert different data", func(t *testing.T) { + t.Parallel() + + en := getEn() + err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.Nil(t, err) + + data := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 2, 6, 7, 16}), "dog"), + getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 5}), "doe"), + } + + newNode, modifiedHashes, err := en.insert(data, nil) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 2 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + en, ok := newNode.(*extensionNode) + assert.True(t, ok) + assert.True(t, en.dirty) + bn, ok := en.child.(*branchNode) + assert.True(t, ok) + assert.False(t, bn.children[4].isDirty()) + assert.False(t, bn.children[7].isDirty()) + assert.Equal(t, []byte{4, 5}, bn.children[3].(*leafNode).Key) + assert.Equal(t, []byte{7, 16}, bn.children[6].(*leafNode).Key) + }) +} + +func TestExtensionNode_insertInNewBn(t *testing.T) { + t.Parallel() + + t.Run("with a new en parent", func(t *testing.T) { + t.Parallel() + + en := getEn() + err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.Nil(t, err) + + data := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 3, 6, 7, 16}), "dog"), + getTrieDataWithDefaultVersion(string([]byte{1, 3, 3, 4, 5}), "doe"), + } + + newNode, modifiedHashes, err := en.insert(data, nil) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 1 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + en, ok := newNode.(*extensionNode) + assert.True(t, ok) + assert.True(t, en.dirty) + assert.Equal(t, []byte{1}, en.Key) + bn, ok := en.child.(*branchNode) + assert.True(t, ok) + assert.False(t, bn.children[2].isDirty()) + assert.True(t, bn.children[3].isDirty()) + _, ok = bn.children[2].(*branchNode) + assert.True(t, ok) + _, ok = bn.children[3].(*branchNode) + assert.True(t, ok) + }) + t.Run("branch at the beginning of the en", func(t *testing.T) { + t.Parallel() + + en := getEn() + err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.Nil(t, err) + + data := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{2, 3, 6, 7, 16}), "dog"), + getTrieDataWithDefaultVersion(string([]byte{3, 3, 3, 4, 5}), "doe"), + } + + newNode, modifiedHashes, err := en.insert(data, nil) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 1 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + bn, ok := newNode.(*branchNode) + assert.True(t, ok) + assert.True(t, bn.dirty) + assert.Equal(t, []byte{3, 6, 7, 16}, bn.children[2].(*leafNode).Key) + assert.Equal(t, []byte{3, 3, 4, 5}, bn.children[3].(*leafNode).Key) + assert.Equal(t, []byte{2}, bn.children[1].(*extensionNode).Key) + }) +} + +func TestExtensionNode_deleteBatch(t *testing.T) { + t.Parallel() + + t.Run("delete invalid node", func(t *testing.T) { + t.Parallel() + + en := getEn() + err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.Nil(t, err) + + data := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{2, 3, 6, 7, 16}), "dog"), + getTrieDataWithDefaultVersion(string([]byte{3, 3, 3, 4, 5}), "doe"), + } + + dirty, newNode, modifiedHashes, err := en.delete(data, nil) + assert.False(t, dirty) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 0 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + assert.Equal(t, en, newNode) + }) + t.Run("reduce to leaf after delete", func(t *testing.T) { + t.Parallel() + + en := getEn() + err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.Nil(t, err) + + data := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 2, 4, 3, 4, 5}), "dog"), + } + + dirty, newNode, modifiedHashes, err := en.delete(data, nil) + assert.True(t, dirty) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 4 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + ln, ok := newNode.(*leafNode) + assert.True(t, ok) + assert.Equal(t, []byte{1, 2, 7, 7, 8, 9}, ln.Key) + assert.True(t, ln.dirty) + }) + t.Run("reduce to extension node after delete", func(t *testing.T) { + t.Parallel() + + en := getEn() + data := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 2, 4, 4, 5, 6}), "dog"), + } + _, _, _ = en.insert(data, nil) + err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.Nil(t, err) + + dataForRemoval := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 2, 7, 7, 8, 9}), "dog"), + } + + dirty, newNode, modifiedHashes, err := en.delete(dataForRemoval, nil) + assert.True(t, dirty) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 3 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + en, ok := newNode.(*extensionNode) + assert.True(t, ok) + assert.Equal(t, []byte{1, 2, 4}, en.Key) + assert.True(t, en.dirty) + }) + t.Run("delete all children", func(t *testing.T) { + t.Parallel() + + en := getEn() + err := en.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.Nil(t, err) + + data := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 2, 4, 3, 4, 5}), "dog"), + getTrieDataWithDefaultVersion(string([]byte{1, 2, 7, 7, 8, 9}), "doe"), + } + + dirty, newNode, modifiedHashes, err := en.delete(data, nil) + assert.True(t, dirty) + assert.Nil(t, err) + expectedNumTrieNodesChanged := 4 + assert.Equal(t, expectedNumTrieNodesChanged, len(modifiedHashes)) + assert.Nil(t, newNode) + }) +} diff --git a/trie/interface.go b/trie/interface.go index 3bbc79119f2..cb623b1d686 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -29,8 +29,8 @@ type node interface { hashChildren() error tryGet(key []byte, depth uint32, db common.TrieStorageInteractor) ([]byte, uint32, error) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) - insert(newData core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) - delete(key []byte, db common.TrieStorageInteractor) (bool, node, [][]byte, error) + insert(newData []core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) + delete(data []core.TrieData, db common.TrieStorageInteractor) (bool, node, [][]byte, error) reduceNode(pos int) (node, bool, error) isEmptyOrNil() error print(writer io.Writer, index int, db common.TrieStorageInteractor) diff --git a/trie/leafNode.go b/trie/leafNode.go index 0b0ab6384d6..85959c28de2 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -235,7 +235,7 @@ func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (node, [ } return nil, nil, ErrNodeNotFound } -func (ln *leafNode) insert(newData core.TrieData, _ common.TrieStorageInteractor) (node, [][]byte, error) { +func (ln *leafNode) insert(newData []core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) { err := ln.isEmptyOrNil() if err != nil { return nil, [][]byte{}, fmt.Errorf("insert error %w", err) @@ -246,14 +246,12 @@ func (ln *leafNode) insert(newData core.TrieData, _ common.TrieStorageInteractor oldHash = append(oldHash, ln.hash) } - nodeKey := ln.Key - - if bytes.Equal(newData.Key, nodeKey) { - return ln.insertInSameLn(newData, oldHash) + if len(newData) == 1 && bytes.Equal(newData[0].Key, ln.Key) { + return ln.insertInSameLn(newData[0], oldHash) } - keyMatchLen := prefixLen(newData.Key, nodeKey) - bn, err := ln.insertInNewBn(newData, keyMatchLen) + keyMatchLen, _ := getMinKeyMatchLen(newData, ln.Key) + bn, err := ln.insertInNewBn(newData, keyMatchLen, db) if err != nil { return nil, [][]byte{}, err } @@ -262,7 +260,7 @@ func (ln *leafNode) insert(newData core.TrieData, _ common.TrieStorageInteractor return bn, oldHash, nil } - newEn, err := newExtensionNode(nodeKey[:keyMatchLen], bn, ln.marsh, ln.hasher) + newEn, err := newExtensionNode(ln.Key[:keyMatchLen], bn, ln.marsh, ln.hasher) if err != nil { return nil, [][]byte{}, err } @@ -282,54 +280,62 @@ func (ln *leafNode) insertInSameLn(newData core.TrieData, oldHashes [][]byte) (n return ln, oldHashes, nil } -func (ln *leafNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, error) { +func trimKeys(data []core.TrieData, keyMatchLen int) { + for i := range data { + data[i].Key = data[i].Key[keyMatchLen:] + } +} + +func (ln *leafNode) insertInNewBn(newData []core.TrieData, keyMatchLen int, db common.TrieStorageInteractor) (node, error) { bn, err := newBranchNode(ln.marsh, ln.hasher) if err != nil { return nil, err } - oldChildPos := ln.Key[keyMatchLen] - newChildPos := newData.Key[keyMatchLen] - if childPosOutOfRange(oldChildPos) || childPosOutOfRange(newChildPos) { - return nil, ErrChildPosOutOfRange - } - - oldLnVersion, err := ln.getVersion() + lnVersion, err := ln.getVersion() if err != nil { return nil, err } - oldLnData := core.TrieData{ - Key: ln.Key[keyMatchLen+1:], + var newKeyForOldLn []byte + posForOldLn := byte(hexTerminator) + if len(ln.Key) > keyMatchLen { + newKeyForOldLn = ln.Key[keyMatchLen+1:] + posForOldLn = ln.Key[keyMatchLen] + } + + lnData := core.TrieData{ + Key: newKeyForOldLn, Value: ln.Value, - Version: oldLnVersion, + Version: lnVersion, } - newLnOldChildPos, err := newLeafNode(oldLnData, ln.marsh, ln.hasher) + + oldLn, err := newLeafNode(lnData, ln.marsh, ln.hasher) if err != nil { return nil, err } - bn.children[oldChildPos] = newLnOldChildPos - bn.setVersionForChild(oldLnVersion, oldChildPos) + bn.children[posForOldLn] = oldLn + bn.setVersionForChild(lnVersion, posForOldLn) - newData.Key = newData.Key[keyMatchLen+1:] - newLnNewChildPos, err := newLeafNode(newData, ln.marsh, ln.hasher) + trimKeys(newData, keyMatchLen) + newNode, _, err := bn.insert(newData, db) if err != nil { return nil, err } - bn.children[newChildPos] = newLnNewChildPos - bn.setVersionForChild(newData.Version, newChildPos) - return bn, nil + return newNode, nil } -func (ln *leafNode) delete(key []byte, _ common.TrieStorageInteractor) (bool, node, [][]byte, error) { - if bytes.Equal(key, ln.Key) { - oldHash := make([][]byte, 0) - if !ln.dirty { - oldHash = append(oldHash, ln.hash) - } +func (ln *leafNode) delete(data []core.TrieData, _ common.TrieStorageInteractor) (bool, node, [][]byte, error) { + for _, d := range data { + if bytes.Equal(d.Key, ln.Key) { + oldHash := make([][]byte, 0) + if !ln.dirty { + oldHash = append(oldHash, ln.hash) + } - return true, nil, oldHash, nil + return true, nil, oldHash, nil + } } return false, ln, [][]byte{}, nil } diff --git a/trie/leafNode_test.go b/trie/leafNode_test.go index 9f7b27acbf7..2c138ab5373 100644 --- a/trie/leafNode_test.go +++ b/trie/leafNode_test.go @@ -321,8 +321,9 @@ func TestLeafNode_insertAtSameKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) key := "dog" expectedVal := "dogs" + data := []core.TrieData{getTrieDataWithDefaultVersion(key, expectedVal)} - newNode, _, err := ln.insert(getTrieDataWithDefaultVersion(key, expectedVal), nil) + newNode, _, err := ln.insert(data, nil) assert.NotNil(t, newNode) assert.Nil(t, err) @@ -340,8 +341,9 @@ func TestLeafNode_insertAtDifferentKey(t *testing.T) { nodeKey := []byte{3, 4, 5} nodeVal := []byte{3, 4, 5} + data := []core.TrieData{getTrieDataWithDefaultVersion(string(nodeKey), string(nodeVal))} - newNode, _, err := ln.insert(getTrieDataWithDefaultVersion(string(nodeKey), string(nodeVal)), nil) + newNode, _, err := ln.insert(data, nil) assert.NotNil(t, newNode) assert.Nil(t, err) @@ -357,8 +359,9 @@ func TestLeafNode_insertInStoredLnAtSameKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) _ = ln.commitDirty(0, 5, db, db) lnHash := ln.getHash() + data := []core.TrieData{getTrieDataWithDefaultVersion("dog", "dogs")} - newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), db) + newNode, oldHashes, err := ln.insert(data, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{lnHash}, oldHashes) @@ -372,8 +375,9 @@ func TestLeafNode_insertInStoredLnAtDifferentKey(t *testing.T) { ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3}), "dog"), marsh, hasher) _ = ln.commitDirty(0, 5, db, db) lnHash := ln.getHash() + data := []core.TrieData{getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs")} - newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs"), db) + newNode, oldHashes, err := ln.insert(data, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{lnHash}, oldHashes) @@ -383,8 +387,9 @@ func TestLeafNode_insertInDirtyLnAtSameKey(t *testing.T) { t.Parallel() ln := getLn(getTestMarshalizerAndHasher()) + data := []core.TrieData{getTrieDataWithDefaultVersion("dog", "dogs")} - newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil) + newNode, oldHashes, err := ln.insert(data, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -395,8 +400,9 @@ func TestLeafNode_insertInDirtyLnAtDifferentKey(t *testing.T) { marsh, hasher := getTestMarshalizerAndHasher() ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3}), "dog"), marsh, hasher) + data := []core.TrieData{getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs")} - newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs"), nil) + newNode, oldHashes, err := ln.insert(data, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -406,8 +412,9 @@ func TestLeafNode_insertInNilNode(t *testing.T) { t.Parallel() var ln *leafNode + data := []core.TrieData{getTrieDataWithDefaultVersion("dog", "dogs")} - newNode, _, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil) + newNode, _, err := ln.insert(data, nil) assert.Nil(t, newNode) assert.True(t, errors.Is(err, ErrNilLeafNode)) assert.Nil(t, newNode) @@ -417,8 +424,9 @@ func TestLeafNode_deletePresent(t *testing.T) { t.Parallel() ln := getLn(getTestMarshalizerAndHasher()) + data := []core.TrieData{{Key: []byte("dog")}} - dirty, newNode, _, err := ln.delete([]byte("dog"), nil) + dirty, newNode, _, err := ln.delete(data, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Nil(t, newNode) @@ -431,8 +439,9 @@ func TestLeafNode_deleteFromStoredLnAtSameKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) _ = ln.commitDirty(0, 5, db, db) lnHash := ln.getHash() + data := []core.TrieData{{Key: []byte("dog")}} - dirty, _, oldHashes, err := ln.delete([]byte("dog"), db) + dirty, _, oldHashes, err := ln.delete(data, db) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{lnHash}, oldHashes) @@ -445,8 +454,9 @@ func TestLeafNode_deleteFromLnAtDifferentKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) _ = ln.commitDirty(0, 5, db, db) wrongKey := []byte{1, 2, 3} + data := []core.TrieData{{Key: wrongKey}} - dirty, _, oldHashes, err := ln.delete(wrongKey, db) + dirty, _, oldHashes, err := ln.delete(data, db) assert.False(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -456,8 +466,9 @@ func TestLeafNode_deleteFromDirtyLnAtSameKey(t *testing.T) { t.Parallel() ln := getLn(getTestMarshalizerAndHasher()) + data := []core.TrieData{{Key: []byte("dog")}} - dirty, _, oldHashes, err := ln.delete([]byte("dog"), nil) + dirty, _, oldHashes, err := ln.delete(data, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -468,8 +479,9 @@ func TestLeafNode_deleteNotPresent(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) wrongKey := []byte{1, 2, 3} + data := []core.TrieData{{Key: wrongKey}} - dirty, newNode, _, err := ln.delete(wrongKey, nil) + dirty, newNode, _, err := ln.delete(data, nil) assert.False(t, dirty) assert.Nil(t, err) assert.Equal(t, ln, newNode) @@ -777,3 +789,131 @@ func TestLeafNode_getVersion(t *testing.T) { assert.Nil(t, err) }) } + +func TestLeafNode_insertBatch(t *testing.T) { + t.Parallel() + + t.Run("insert in same leaf node different val", func(t *testing.T) { + t.Parallel() + + marshaller, hasher := getTestMarshalizerAndHasher() + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dog"), marshaller, hasher) + + newData := []core.TrieData{getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dogs")} + _ = ln.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.False(t, ln.dirty) + originalHash := ln.getHash() + + newNode, modifiedHahses, err := ln.insert(newData, nil) + assert.Nil(t, err) + assert.True(t, newNode.isDirty()) + assert.Equal(t, [][]byte{originalHash}, modifiedHahses) + assert.Equal(t, []byte("dogs"), newNode.(*leafNode).Value) + }) + t.Run("insert in same leaf node same val", func(t *testing.T) { + t.Parallel() + + marshaller, hasher := getTestMarshalizerAndHasher() + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dog"), marshaller, hasher) + + newData := []core.TrieData{getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dog")} + _ = ln.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.False(t, ln.dirty) + + newNode, modifiedHahses, err := ln.insert(newData, nil) + assert.Nil(t, err) + assert.Nil(t, newNode) + assert.Equal(t, 0, len(modifiedHahses)) + }) + t.Run("branch at the beginning after insert", func(t *testing.T) { + t.Parallel() + + marshaller, hasher := getTestMarshalizerAndHasher() + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dog"), marshaller, hasher) + + newData := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dogs"), + getTrieDataWithDefaultVersion(string([]byte{2, 3, 4, 5, 16}), "dog"), + getTrieDataWithDefaultVersion(string([]byte{3, 4, 5, 6, 16}), "dog"), + } + _ = ln.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.False(t, ln.dirty) + originalHash := ln.getHash() + + newNode, modifiedHahses, err := ln.insert(newData, nil) + assert.Nil(t, err) + assert.Equal(t, [][]byte{originalHash}, modifiedHahses) + bn, ok := newNode.(*branchNode) + assert.True(t, ok) + assert.Equal(t, []byte("dogs"), bn.children[1].(*leafNode).Value) + assert.NotNil(t, []byte("dog"), bn.children[2].(*leafNode).Value) + assert.NotNil(t, []byte("dog"), bn.children[3].(*leafNode).Value) + assert.True(t, bn.dirty) + assert.Nil(t, bn.hash) + }) + t.Run("extension node at the beginning after insert ", func(t *testing.T) { + t.Parallel() + + marshaller, hasher := getTestMarshalizerAndHasher() + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dog"), marshaller, hasher) + + newData := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dogs"), + getTrieDataWithDefaultVersion(string([]byte{1, 2, 4, 5, 16}), "dog"), + getTrieDataWithDefaultVersion(string([]byte{1, 2, 5, 6, 16}), "dog"), + } + _ = ln.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + assert.False(t, ln.dirty) + originalHash := ln.getHash() + + newNode, modifiedHahses, err := ln.insert(newData, nil) + assert.Nil(t, err) + assert.Equal(t, [][]byte{originalHash}, modifiedHahses) + en, ok := newNode.(*extensionNode) + assert.True(t, ok) + assert.Equal(t, []byte{1, 2}, en.Key) + assert.True(t, en.dirty) + assert.Nil(t, en.hash) + }) +} + +func TestLeafNode_deleteBatch(t *testing.T) { + t.Parallel() + + t.Run("delete existing", func(t *testing.T) { + t.Parallel() + + marshaller, hasher := getTestMarshalizerAndHasher() + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dog"), marshaller, hasher) + newData := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), ""), + getTrieDataWithDefaultVersion(string([]byte{2, 2, 3, 4, 16}), ""), + getTrieDataWithDefaultVersion(string([]byte{3, 2, 3, 4, 16}), ""), + } + _ = ln.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + originalHash := ln.getHash() + + dirty, newNode, modifiedHashes, err := ln.delete(newData, nil) + assert.True(t, dirty) + assert.Nil(t, err) + assert.Nil(t, newNode) + assert.Equal(t, [][]byte{originalHash}, modifiedHashes) + }) + t.Run("delete not existing", func(t *testing.T) { + t.Parallel() + + marshaller, hasher := getTestMarshalizerAndHasher() + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3, 4, 16}), "dog"), marshaller, hasher) + newData := []core.TrieData{ + getTrieDataWithDefaultVersion(string([]byte{2, 2, 3, 4, 16}), ""), + getTrieDataWithDefaultVersion(string([]byte{3, 2, 3, 4, 16}), ""), + } + _ = ln.commitDirty(0, 5, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) + + dirty, newNode, modifiedHashes, err := ln.delete(newData, nil) + assert.False(t, dirty) + assert.Nil(t, err) + assert.Equal(t, ln, newNode) + assert.Equal(t, [][]byte{}, modifiedHashes) + }) +} diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index 54b27d6a1d9..07c0e06cb16 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -151,7 +151,7 @@ func (tr *patriciaMerkleTrie) updateBatch(key []byte, value []byte, version core Value: value, Version: version, } - tr.batchManager.Add(hexKey, newData) + tr.batchManager.Add(newData) return nil } @@ -172,56 +172,73 @@ func (tr *patriciaMerkleTrie) updateTrie() error { } defer tr.batchManager.MarkTrieUpdateCompleted() - keys, data := batch.GetSortedDataForInsertion() - for _, key := range keys { - newData := data[key] - if tr.root == nil { - newRoot, err := newLeafNode(newData, tr.marshalizer, tr.hasher) - if err != nil { - return err - } - - tr.root = newRoot - continue - } + err = tr.insertBatch(batch.GetSortedDataForInsertion()) + if err != nil { + return err + } - if !tr.root.isDirty() { - tr.oldRoot = tr.root.getHash() - } + return tr.deleteBatch(batch.GetSortedDataForRemoval()) +} - newRoot, oldHashes, err := tr.root.insert(newData, tr.trieStorage) +func (tr *patriciaMerkleTrie) insertBatch(sortedDataForInsertion []core.TrieData) error { + if len(sortedDataForInsertion) == 0 { + return nil + } + + if tr.root == nil { + newRoot, err := newLeafNode(sortedDataForInsertion[0], tr.marshalizer, tr.hasher) if err != nil { return err } - if check.IfNil(newRoot) { - continue + tr.root = newRoot + sortedDataForInsertion = sortedDataForInsertion[1:] + + if len(sortedDataForInsertion) == 0 { + return nil } + } - tr.root = newRoot - tr.oldHashes = append(tr.oldHashes, oldHashes...) + if !tr.root.isDirty() { + tr.oldRoot = tr.root.getHash() + } - logArrayWithTrace("oldHashes after insert", "hash", oldHashes) + newRoot, oldHashes, err := tr.root.insert(sortedDataForInsertion, tr.trieStorage) + if err != nil { + return err } - keysToBeRemoved := batch.GetSortedDataForRemoval() - for _, hexKey := range keysToBeRemoved { - if tr.root == nil { - return nil - } + if check.IfNil(newRoot) { + return nil + } - if !tr.root.isDirty() { - tr.oldRoot = tr.root.getHash() - } + tr.root = newRoot + tr.oldHashes = append(tr.oldHashes, oldHashes...) - _, newRoot, oldHashes, err := tr.root.delete([]byte(hexKey), tr.trieStorage) - if err != nil { - return err - } - tr.root = newRoot - tr.oldHashes = append(tr.oldHashes, oldHashes...) - logArrayWithTrace("oldHashes after delete", "hash", oldHashes) + logArrayWithTrace("oldHashes after insert", "hash", oldHashes) + return nil +} + +func (tr *patriciaMerkleTrie) deleteBatch(data []core.TrieData) error { + if len(data) == 0 { + return nil + } + + if tr.root == nil { + return nil + } + + if !tr.root.isDirty() { + tr.oldRoot = tr.root.getHash() + } + + _, newRoot, oldHashes, err := tr.root.delete(data, tr.trieStorage) + if err != nil { + return err } + tr.root = newRoot + tr.oldHashes = append(tr.oldHashes, oldHashes...) + logArrayWithTrace("oldHashes after delete", "hash", oldHashes) return nil } diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go index 6b14e406b96..b4d19739073 100644 --- a/trie/patriciaMerkleTrie_test.go +++ b/trie/patriciaMerkleTrie_test.go @@ -1536,6 +1536,21 @@ func TestPatriciaMerkleTrie_IsMigrated(t *testing.T) { }) } +func TestPatriciaMerkleTrie_InsertOneValInNilTrie(t *testing.T) { + t.Parallel() + + tr := emptyTrie() + key := []byte("dog") + value := []byte("cat") + _ = tr.Update(key, value) + trie.ExecuteUpdatesFromBatch(tr) + + val, depth, err := tr.Get(key) + assert.Nil(t, err) + assert.Equal(t, value, val) + assert.Equal(t, uint32(0), depth) +} + func BenchmarkPatriciaMerkleTree_Insert(b *testing.B) { tr := emptyTrie() hsh := keccak.NewKeccak() diff --git a/trie/trieBatchManager/trieBatchManager.go b/trie/trieBatchManager/trieBatchManager.go index ae74ba2e826..4ea3713eb76 100644 --- a/trie/trieBatchManager/trieBatchManager.go +++ b/trie/trieBatchManager/trieBatchManager.go @@ -56,11 +56,11 @@ func (t *trieBatchManager) MarkTrieUpdateCompleted() { } // Add adds a new key and data to the current batch -func (t *trieBatchManager) Add(key []byte, data core.TrieData) { +func (t *trieBatchManager) Add(data core.TrieData) { t.mutex.RLock() defer t.mutex.RUnlock() - t.currentBatch.Add(key, data) + t.currentBatch.Add(data) } // Get returns the data for the given key checking first the current batch and then the temp batch if the trie update is in progress diff --git a/trie/trieBatchManager/trieBatchManager_test.go b/trie/trieBatchManager/trieBatchManager_test.go index 6edb6219ad7..e253fda7b5b 100644 --- a/trie/trieBatchManager/trieBatchManager_test.go +++ b/trie/trieBatchManager/trieBatchManager_test.go @@ -23,8 +23,8 @@ func TestTrieBatchManager_TrieUpdateInProgress(t *testing.T) { t.Parallel() tbm := NewTrieBatchManager() - tbm.Add([]byte("key1"), core.TrieData{}) - tbm.Add([]byte("key2"), core.TrieData{}) + tbm.Add(core.TrieData{Key: []byte("key1")}) + tbm.Add(core.TrieData{Key: []byte("key2")}) assert.False(t, tbm.isUpdateInProgress) assert.True(t, check.IfNil(tbm.tempBatch)) @@ -50,12 +50,12 @@ func TestTrieBatchManager_AddUpdatesCurrentBatch(t *testing.T) { t.Parallel() tbm := NewTrieBatchManager() - tbm.Add([]byte("key1"), core.TrieData{}) + tbm.Add(core.TrieData{Key: []byte("key1")}) _, found := tbm.currentBatch.Get([]byte("key1")) assert.True(t, found) _, _ = tbm.MarkTrieUpdateInProgress() - tbm.Add([]byte("key2"), core.TrieData{}) + tbm.Add(core.TrieData{Key: []byte("key2")}) _, found = tbm.currentBatch.Get([]byte("key2")) assert.True(t, found) } @@ -70,7 +70,8 @@ func TestTrieBatchManager_Get(t *testing.T) { t.Parallel() tbm := NewTrieBatchManager() - tbm.currentBatch.Add(key, core.TrieData{ + tbm.currentBatch.Add(core.TrieData{ + Key: key, Value: value, }) @@ -82,7 +83,8 @@ func TestTrieBatchManager_Get(t *testing.T) { t.Parallel() tbm := NewTrieBatchManager() - tbm.currentBatch.Add(key, core.TrieData{ + tbm.currentBatch.Add(core.TrieData{ + Key: key, Value: value, }) _, _ = tbm.MarkTrieUpdateInProgress() diff --git a/trie/trieChangesBatch/trieChangesBatch.go b/trie/trieChangesBatch/trieChangesBatch.go index b73df553db6..2ea86e37ebb 100644 --- a/trie/trieChangesBatch/trieChangesBatch.go +++ b/trie/trieChangesBatch/trieChangesBatch.go @@ -1,6 +1,7 @@ package trieChangesBatch import ( + "bytes" "sort" "sync" @@ -23,16 +24,16 @@ func NewTrieChangesBatch() *trieChangesBatch { } // Add adds a new key and data to the batch -func (t *trieChangesBatch) Add(key []byte, data core.TrieData) { +func (t *trieChangesBatch) Add(data core.TrieData) { t.mutex.Lock() defer t.mutex.Unlock() - _, ok := t.deletedKeys[string(key)] + _, ok := t.deletedKeys[string(data.Key)] if ok { - delete(t.deletedKeys, string(key)) + delete(t.deletedKeys, string(data.Key)) } - t.insertedData[string(key)] = data + t.insertedData[string(data.Key)] = data } // MarkForRemoval marks the key for removal @@ -67,34 +68,40 @@ func (t *trieChangesBatch) Get(key []byte) ([]byte, bool) { } // GetSortedDataForInsertion returns the data sorted for insertion -func (t *trieChangesBatch) GetSortedDataForInsertion() ([]string, map[string]core.TrieData) { +func (t *trieChangesBatch) GetSortedDataForInsertion() []core.TrieData { t.mutex.RLock() defer t.mutex.RUnlock() - keys := make([]string, 0, len(t.insertedData)) + data := make([]core.TrieData, 0, len(t.insertedData)) for k := range t.insertedData { - keys = append(keys, k) + data = append(data, t.insertedData[k]) } - sort.Strings(keys) - return keys, t.insertedData + return getSortedData(data) } // GetSortedDataForRemoval returns the data sorted for removal -func (t *trieChangesBatch) GetSortedDataForRemoval() []string { +func (t *trieChangesBatch) GetSortedDataForRemoval() []core.TrieData { t.mutex.RLock() defer t.mutex.RUnlock() - keys := make([]string, 0, len(t.deletedKeys)) + data := make([]core.TrieData, 0, len(t.deletedKeys)) for k := range t.deletedKeys { - keys = append(keys, k) + data = append(data, core.TrieData{Key: []byte(k)}) } - sort.Strings(keys) - return keys + return getSortedData(data) } // IsInterfaceNil returns true if there is no value under the interface func (t *trieChangesBatch) IsInterfaceNil() bool { return t == nil } + +func getSortedData(data []core.TrieData) []core.TrieData { + sort.Slice(data, func(i, j int) bool { + return bytes.Compare(data[i].Key, data[j].Key) < 0 + }) + + return data +} diff --git a/trie/trieChangesBatch/trieChangesBatch_test.go b/trie/trieChangesBatch/trieChangesBatch_test.go index 1d0b64e910b..383c922fff6 100644 --- a/trie/trieChangesBatch/trieChangesBatch_test.go +++ b/trie/trieChangesBatch/trieChangesBatch_test.go @@ -20,7 +20,6 @@ func TestNewTrieChangesBatch(t *testing.T) { func TestTrieChangesBatch_Add(t *testing.T) { t.Parallel() - keyForInsertion := []byte("keyForInsertion") dataForInsertion := core.TrieData{ Key: []byte("trieKey"), Value: []byte("trieValue"), @@ -28,12 +27,12 @@ func TestTrieChangesBatch_Add(t *testing.T) { } tcb := NewTrieChangesBatch() - tcb.deletedKeys[string(keyForInsertion)] = struct{}{} + tcb.deletedKeys[string(dataForInsertion.Key)] = struct{}{} - tcb.Add(keyForInsertion, dataForInsertion) + tcb.Add(dataForInsertion) assert.Equal(t, 0, len(tcb.deletedKeys)) assert.Equal(t, 1, len(tcb.insertedData)) - assert.Equal(t, dataForInsertion, tcb.insertedData[string(keyForInsertion)]) + assert.Equal(t, dataForInsertion, tcb.insertedData[string(dataForInsertion.Key)]) } func TestTrieChangesBatch_MarkForRemoval(t *testing.T) { @@ -102,14 +101,15 @@ func TestTrieChangesBatch_GetSortedDataForInsertion(t *testing.T) { t.Parallel() tcb := NewTrieChangesBatch() + tcb.insertedData["key3"] = core.TrieData{Key: []byte("key3")} + tcb.insertedData["key1"] = core.TrieData{Key: []byte("key1")} + tcb.insertedData["key2"] = core.TrieData{Key: []byte("key2")} - tcb.insertedData["key3"] = core.TrieData{} - tcb.insertedData["key1"] = core.TrieData{} - tcb.insertedData["key2"] = core.TrieData{} - - keys, data := tcb.GetSortedDataForInsertion() - assert.Equal(t, []string{"key1", "key2", "key3"}, keys) + data := tcb.GetSortedDataForInsertion() assert.Equal(t, 3, len(data)) + assert.Equal(t, "key1", string(data[0].Key)) + assert.Equal(t, "key2", string(data[1].Key)) + assert.Equal(t, "key3", string(data[2].Key)) } func TestTrieChangesBatch_GetSortedDataForRemoval(t *testing.T) { @@ -121,6 +121,9 @@ func TestTrieChangesBatch_GetSortedDataForRemoval(t *testing.T) { tcb.deletedKeys["key1"] = struct{}{} tcb.deletedKeys["key2"] = struct{}{} - keys := tcb.GetSortedDataForRemoval() - assert.Equal(t, []string{"key1", "key2", "key3"}, keys) + data := tcb.GetSortedDataForRemoval() + assert.Equal(t, 3, len(data)) + assert.Equal(t, "key1", string(data[0].Key)) + assert.Equal(t, "key2", string(data[1].Key)) + assert.Equal(t, "key3", string(data[2].Key)) }