Skip to content

Commit

Permalink
refactor: performance improvement
Browse files Browse the repository at this point in the history
1. replace node strcut with []byte
2. precompute tree depth for proof neighbor slice allocation

Signed-off-by: txaty <[email protected]>
  • Loading branch information
txaty committed Aug 13, 2022
1 parent e39c9c5 commit 798d3e7
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 58 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func main() {
AllowDuplicates: true,
}
// build a new Merkle Tree
tree, err := mt.New(blocks, config)
tree, err := mt.New(config, blocks)
handleError(err)
// get the root hash of the Merkle Tree
rootHash := tree.Root
Expand Down
103 changes: 56 additions & 47 deletions merkle_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"crypto/rand"
"crypto/sha256"
"math"
"sync"

"golang.org/x/sync/errgroup"
Expand All @@ -23,6 +24,8 @@ type DataBlock interface {
type Config struct {
// customizable hash function used for tree generation
HashFunc func([]byte) ([]byte, error)
// default length of hash function output, e.g. 32 bytes for SHA256
HashResultLen int
// if true, the generation runs in parallel,
// this increase the performance for the calculation of large number of data blocks, e.g. over 10,000 blocks
RunInParallel bool
Expand All @@ -35,15 +38,11 @@ type Config struct {

// MerkleTree implements the Merkle Tree structure
type MerkleTree struct {
*Config // Merkle Tree configuration
Root []byte // Merkle root hash
Leaves []*node // Merkle Tree leaves, i.e. the hashes of the data blocks for tree generation
Proofs []*Proof // proofs to the data blocks generated during the tree building process
}

// node implements the Merkle Tree node
type node struct {
Hash []byte
*Config // Merkle Tree configuration
Root []byte // Merkle root hash
Leaves [][]byte // Merkle Tree leaves, i.e. the hashes of the data blocks for tree generation
Proofs []*Proof // proofs to the data blocks generated during the tree building process
treeDepth int // the Merkle Tree depth
}

// Proof implements the Merkle Tree proof
Expand All @@ -53,19 +52,24 @@ type Proof struct {
}

// New generates a new Merkle Tree with specified configuration
func New(blocks []DataBlock, config *Config) (m *MerkleTree, err error) {
func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
if len(blocks) <= 1 {
return nil, nil
}
if config == nil {
config = &Config{}
}
if config.HashFunc == nil {
config.HashFunc = defaultHashFunc
config.HashResultLen = defaultHashLen
}
if config.HashResultLen == 0 {
config.HashResultLen = defaultHashLen
}
m = &MerkleTree{
Config: config,
}
if len(blocks) <= 1 {
return nil, nil
}
m.treeDepth = calTreeDepth(len(blocks))
if m.RunInParallel {
m.Leaves, err = generateLeavesParallel(blocks, m.HashFunc, m.Config.NumRoutines)
if err != nil {
Expand All @@ -82,31 +86,40 @@ func New(blocks []DataBlock, config *Config) (m *MerkleTree, err error) {
return
}

func calTreeDepth(blockLen int) int {
log2BlockLen := math.Log2(float64(blockLen))
return int(math.Round(log2BlockLen) + 0.499)
}

func (m *MerkleTree) buildTree() (root []byte, err error) {
numLeaves := len(m.Leaves)
m.Proofs = make([]*Proof, numLeaves)
for i := 0; i < numLeaves; i++ {
m.Proofs[i] = new(Proof)
m.Proofs[i].Neighbors = make([][]byte, 0, m.treeDepth)
}
var (
step = 1
prevLen int
)
buf := make([]*node, numLeaves)
buf := make([][]byte, numLeaves)
copy(buf, m.Leaves)
buf, prevLen, err = m.fixOdd(buf, numLeaves)
if err != nil {
return nil, err
}
m.assignProves(buf, numLeaves, 0)
m.assignProofs(buf, numLeaves, 0)
for {
buf, prevLen, err = m.fixOdd(buf, prevLen)
if err != nil {
return nil, err
}
for idx := 0; idx < prevLen; idx += 2 {
appendHash := append(buf[idx].Hash, buf[idx+1].Hash...)
buf[idx/2].Hash, err = m.HashFunc(appendHash)
appendHash := append(buf[idx], buf[idx+1]...)
buf[idx/2], err = m.HashFunc(appendHash)
if err != nil {
return nil, err
}
}
prevLen /= 2
if prevLen == 1 {
Expand All @@ -117,30 +130,28 @@ func (m *MerkleTree) buildTree() (root []byte, err error) {
return nil, err
}
}
m.assignProves(buf, prevLen, step)
m.assignProofs(buf, prevLen, step)
step++
}
root = buf[0].Hash
root = buf[0]
m.Root = root
return
}

// if the length of the buffer calculating the Merkle Tree is odd, then append a node to the buffer
// if AllowDuplicates is true, append a node by duplicating the previous node
// otherwise, append a node by random
func (m *MerkleTree) fixOdd(buf []*node, prevLen int) ([]*node, int, error) {
func (m *MerkleTree) fixOdd(buf [][]byte, prevLen int) ([][]byte, int, error) {
if prevLen%2 == 1 {
var appendNode *node
var appendNode []byte
if m.AllowDuplicates {
appendNode = buf[prevLen-1]
} else {
dummyHash, err := getDummyHash()
var err error
appendNode, err = getDummyHash()
if err != nil {
return nil, 0, err
}
appendNode = &node{
Hash: dummyHash,
}
}
if len(buf) <= prevLen+1 {
buf = append(buf, appendNode)
Expand All @@ -152,7 +163,7 @@ func (m *MerkleTree) fixOdd(buf []*node, prevLen int) ([]*node, int, error) {
return buf, prevLen, nil
}

func (m *MerkleTree) assignProves(buf []*node, bufLen, step int) {
func (m *MerkleTree) assignProofs(buf [][]byte, bufLen, step int) {
if bufLen < 2 {
return
}
Expand All @@ -162,7 +173,7 @@ func (m *MerkleTree) assignProves(buf []*node, bufLen, step int) {
}
}

func (m *MerkleTree) assignProvesParallel(buf []*node, bufLen, step int) {
func (m *MerkleTree) assignProofsParallel(buf [][]byte, bufLen, step int) {
numRoutines := m.NumRoutines
if bufLen < 2 {
return
Expand All @@ -182,7 +193,7 @@ func (m *MerkleTree) assignProvesParallel(buf []*node, bufLen, step int) {
wg.Wait()
}

func (m *MerkleTree) assignPairProof(buf []*node, bufLen, idx, batch, step int) {
func (m *MerkleTree) assignPairProof(buf [][]byte, bufLen, idx, batch, step int) {
if bufLen < 2 {
return
}
Expand All @@ -193,15 +204,15 @@ func (m *MerkleTree) assignPairProof(buf []*node, bufLen, idx, batch, step int)
}
for j := start; j < end; j++ {
m.Proofs[j].Path += 1 << step
m.Proofs[j].Neighbors = append(m.Proofs[j].Neighbors, buf[idx+1].Hash)
m.Proofs[j].Neighbors = append(m.Proofs[j].Neighbors, buf[idx+1])
}
start = (idx + 1) * batch
end = start + batch
if end > len(m.Proofs) {
end = len(m.Proofs)
}
for j := start; j < end; j++ {
m.Proofs[j].Neighbors = append(m.Proofs[j].Neighbors, buf[idx].Hash)
m.Proofs[j].Neighbors = append(m.Proofs[j].Neighbors, buf[idx])
}
}

Expand All @@ -211,19 +222,20 @@ func (m *MerkleTree) buildTreeParallel() (root []byte, err error) {
m.Proofs = make([]*Proof, numLeaves)
for i := 0; i < numLeaves; i++ {
m.Proofs[i] = new(Proof)
m.Proofs[i].Neighbors = make([][]byte, 0, m.treeDepth)
}
var (
step = 1
prevLen int
)
buf1 := make([]*node, numLeaves)
buf1 := make([][]byte, numLeaves)
copy(buf1, m.Leaves)
buf1, prevLen, err = m.fixOdd(buf1, numLeaves)
if err != nil {
return nil, err
}
buf2 := make([]*node, prevLen/2)
m.assignProvesParallel(buf1, numLeaves, 0)
buf2 := make([][]byte, prevLen/2)
m.assignProofsParallel(buf1, numLeaves, 0)
for {
buf1, prevLen, err = m.fixOdd(buf1, prevLen)
if err != nil {
Expand All @@ -234,13 +246,11 @@ func (m *MerkleTree) buildTreeParallel() (root []byte, err error) {
idx := 2 * i
g.Go(func() error {
for j := idx; j < prevLen; j += 2 * numRoutines {
newHash, err := m.HashFunc(append(buf1[j].Hash, buf1[j+1].Hash...))
newHash, err := m.HashFunc(append(buf1[j], buf1[j+1]...))
if err != nil {
return err
}
buf2[j/2] = &node{
Hash: newHash,
}
buf2[j/2] = newHash
}
return nil
})
Expand All @@ -258,10 +268,10 @@ func (m *MerkleTree) buildTreeParallel() (root []byte, err error) {
return nil, err
}
}
m.assignProvesParallel(buf1, prevLen, step)
m.assignProofsParallel(buf1, prevLen, step)
step++
}
root = buf1[0].Hash
root = buf1[0]
m.Root = root
return
}
Expand All @@ -283,10 +293,10 @@ func defaultHashFunc(data []byte) ([]byte, error) {
return sha256Func.Sum(nil), nil
}

func generateLeaves(blocks []DataBlock, hashFunc func([]byte) ([]byte, error)) ([]*node, error) {
func generateLeaves(blocks []DataBlock, hashFunc func([]byte) ([]byte, error)) ([][]byte, error) {
var (
lenLeaves = len(blocks)
leaves = make([]*node, lenLeaves)
leaves = make([][]byte, lenLeaves)
)
for i := 0; i < lenLeaves; i++ {
data, err := blocks[i].Serialize()
Expand All @@ -297,16 +307,16 @@ func generateLeaves(blocks []DataBlock, hashFunc func([]byte) ([]byte, error)) (
if err != nil {
return nil, err
}
leaves[i] = &node{Hash: hash}
leaves[i] = hash
}
return leaves, nil
}

func generateLeavesParallel(blocks []DataBlock,
hashFunc func([]byte) ([]byte, error), numRoutines int) ([]*node, error) {
hashFunc func([]byte) ([]byte, error), numRoutines int) ([][]byte, error) {
var (
lenLeaves = len(blocks)
leaves = make([]*node, lenLeaves)
leaves = make([][]byte, lenLeaves)
)
g := new(errgroup.Group)
for i := 0; i < numRoutines; i++ {
Expand All @@ -322,7 +332,7 @@ func generateLeavesParallel(blocks []DataBlock,
if err != nil {
return err
}
leaves[j] = &node{Hash: hash}
leaves[j] = hash
}
return nil
})
Expand Down Expand Up @@ -357,8 +367,7 @@ func Verify(dataBlock DataBlock, proof *Proof, root []byte,
}
path := proof.Path
for _, n := range proof.Neighbors {
dir := path & 1
if dir == 1 {
if path&1 == 1 {
hash, err = defaultHashFunc(append(hash, n...))
} else {
hash, err = defaultHashFunc(append(n, hash...))
Expand Down
18 changes: 8 additions & 10 deletions merkle_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"testing"
)

const benchSize = 5000000
const benchSize = 1000

type mockDataBlock struct {
data []byte
Expand Down Expand Up @@ -100,7 +100,7 @@ func TestMerkleTreeNew(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := New(tt.args.blocks, tt.args.config); (err != nil) != tt.wantErr {
if _, err := New(tt.args.config, tt.args.blocks); (err != nil) != tt.wantErr {
t.Errorf("Build() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand All @@ -109,10 +109,10 @@ func TestMerkleTreeNew(t *testing.T) {

func verifySetup(size int) (*MerkleTree, []DataBlock, error) {
blocks := genTestDataBlocks(size)
m, err := New(blocks, &Config{
m, err := New(&Config{
HashFunc: defaultHashFunc,
AllowDuplicates: true,
})
}, blocks)
if err != nil {
return nil, nil, err
}
Expand All @@ -121,12 +121,12 @@ func verifySetup(size int) (*MerkleTree, []DataBlock, error) {

func verifySetupParallel(size int) (*MerkleTree, []DataBlock, error) {
blocks := genTestDataBlocks(size)
m, err := New(blocks, &Config{
m, err := New(&Config{
HashFunc: defaultHashFunc,
AllowDuplicates: true,
RunInParallel: true,
NumRoutines: 4,
})
}, blocks)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -250,11 +250,9 @@ func BenchmarkMerkleTreeNew(b *testing.B) {
config := &Config{
HashFunc: defaultHashFunc,
AllowDuplicates: true,
RunInParallel: true,
NumRoutines: 4,
}
for i := 0; i < b.N; i++ {
_, err := New(genTestDataBlocks(benchSize), config)
_, err := New(config, genTestDataBlocks(benchSize))
if err != nil {
b.Errorf("Build() error = %v", err)
}
Expand All @@ -269,7 +267,7 @@ func BenchmarkMerkleTreeBuildParallel(b *testing.B) {
NumRoutines: runtime.NumCPU(),
}
for i := 0; i < b.N; i++ {
_, err := New(genTestDataBlocks(benchSize), config)
_, err := New(config, genTestDataBlocks(benchSize))
if err != nil {
b.Errorf("Build() error = %v", err)
}
Expand Down

0 comments on commit 798d3e7

Please sign in to comment.