Skip to content

Commit

Permalink
Refactor and add unit tests.
Browse files Browse the repository at this point in the history
Signed-off-by: txaty <[email protected]>
  • Loading branch information
txaty committed Aug 21, 2022
1 parent 460d8ea commit 8d3efd6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 41 deletions.
29 changes: 10 additions & 19 deletions merkle_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
}
return
}
return
return nil, errors.New("invalid configuration mode")
}

// calTreeDepth calculates the tree depth,
Expand Down Expand Up @@ -324,20 +324,14 @@ func (m *MerkleTree) fixOdd(buf [][]byte, prevLen int) ([][]byte, int, error) {
}

func (m *MerkleTree) updateProofs(buf [][]byte, bufLen, step int) {
if bufLen < 2 {
return
}
batch := 1 << step
for i := 0; i < bufLen; i += 2 {
m.updatePairProof(buf, bufLen, i, batch, step)
m.updatePairProof(buf, i, batch, step)
}
}

func (m *MerkleTree) updateProofsParal(buf [][]byte, bufLen, step int) {
numRoutines := m.NumRoutines
if bufLen < 2 {
return
}
batch := 1 << step
wg := new(sync.WaitGroup)
for i := 0; i < numRoutines; i++ {
Expand All @@ -346,17 +340,14 @@ func (m *MerkleTree) updateProofsParal(buf [][]byte, bufLen, step int) {
go func() {
defer wg.Done()
for j := idx; j < bufLen; j += numRoutines << 1 {
m.updatePairProof(buf, bufLen, j, batch, step)
m.updatePairProof(buf, j, batch, step)
}
}()
}
wg.Wait()
}

func (m *MerkleTree) updatePairProof(buf [][]byte, bufLen, idx, batch, step int) {
if bufLen < 2 {
return
}
func (m *MerkleTree) updatePairProof(buf [][]byte, idx, batch, step int) {
start := idx * batch
end := start + batch
if end > len(m.Proofs) {
Expand Down Expand Up @@ -459,11 +450,11 @@ func (m *MerkleTree) treeBuild() (err error) {
copy(m.tree[0], m.Leaves)
var prevLen int
m.tree[0], prevLen, err = m.fixOdd(m.tree[0], numLeaves)
if err != nil {
return
}
for i := uint32(0); i < m.Depth-1; i++ {
m.tree[i+1] = make([][]byte, prevLen>>1)
if err != nil {
return
}
for j := 0; j < prevLen; j += 2 {
m.tree[i+1][j>>1], err = m.HashFunc(append(m.tree[i][j], m.tree[i][j+1]...))
if err != nil {
Expand Down Expand Up @@ -497,11 +488,11 @@ func (m *MerkleTree) treeBuildParal() (err error) {
copy(m.tree[0], m.Leaves)
var prevLen int
m.tree[0], prevLen, err = m.fixOdd(m.tree[0], numLeaves)
if err != nil {
return
}
for i := uint32(0); i < m.Depth-1; i++ {
m.tree[i+1] = make([][]byte, prevLen>>1)
if err != nil {
return
}
g := new(errgroup.Group)
for j := 0; j < numRoutines && j < prevLen; j++ {
idx := j << 1
Expand Down
75 changes: 53 additions & 22 deletions merkle_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package merkletree

import (
"bytes"
"crypto/sha256"
"fmt"
"math/rand"
"reflect"
Expand Down Expand Up @@ -87,13 +88,6 @@ func TestMerkleTreeNew_proofGen(t *testing.T) {
},
wantErr: false,
},
{
name: "test_4",
args: args{
blocks: genTestDataBlocks(4),
},
wantErr: false,
},
{
name: "test_8",
args: args{
Expand Down Expand Up @@ -173,6 +167,16 @@ func TestMerkleTreeNew_proofGen(t *testing.T) {
},
wantErr: true,
},
{
name: "bad_mode",
args: args{
blocks: genTestDataBlocks(100),
config: &Config{
Mode: 5,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -203,16 +207,6 @@ func TestMerkleTreeNew_buildTree(t *testing.T) {
},
wantErr: false,
},
{
name: "test_build_tree_4",
args: args{
blocks: genTestDataBlocks(4),
config: &Config{
Mode: ModeTreeBuild,
},
},
wantErr: false,
},
{
name: "test_build_tree_5",
args: args{
Expand Down Expand Up @@ -545,6 +539,25 @@ func TestMerkleTreeNew_proofGenAndTreeBuildParallel(t *testing.T) {
},
wantErr: true,
},
{
name: "test_tree_build_hash_func_error",
args: args{
blocks: genTestDataBlocks(100),
config: &Config{
HashFunc: func(block []byte) ([]byte, error) {
if len(block) == 64 {
return nil, fmt.Errorf("hash func error")
}
sha256Func := sha256.New()
sha256Func.Write(block)
return sha256Func.Sum(nil), nil
},
Mode: ModeProofGenAndTreeBuild,
RunInParallel: true,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -736,10 +749,11 @@ func BenchmarkMerkleTreeNewParallel(b *testing.B) {

func TestMerkleTree_GenerateProof(t *testing.T) {
tests := []struct {
name string
config *Config
blocks []DataBlock
wantErr bool
name string
config *Config
blocks []DataBlock
proofBlocks []DataBlock
wantErr bool
}{
{
name: "test_2",
Expand Down Expand Up @@ -767,6 +781,17 @@ func TestMerkleTree_GenerateProof(t *testing.T) {
blocks: genTestDataBlocks(5),
wantErr: true,
},
{
name: "test_wrong_blocks",
config: &Config{Mode: ModeTreeBuild},
blocks: genTestDataBlocks(5),
proofBlocks: []DataBlock{
&mockDataBlock{
[]byte("test_wrong_blocks"),
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -780,12 +805,18 @@ func TestMerkleTree_GenerateProof(t *testing.T) {
t.Errorf("m2 New() error = %v", err)
return
}
for idx, block := range tt.blocks {
if tt.proofBlocks == nil {
tt.proofBlocks = tt.blocks
}
for idx, block := range tt.proofBlocks {
got, err := m2.GenerateProof(block)
if (err != nil) != tt.wantErr {
t.Errorf("GenerateProof() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if !reflect.DeepEqual(got, m1.Proofs[idx]) && !tt.wantErr {
t.Errorf("GenerateProof() %d got = %v, want %v", idx, got, m1.Proofs[idx])
return
Expand Down

0 comments on commit 8d3efd6

Please sign in to comment.