Skip to content

Commit

Permalink
Bug fixes, refactor, and add more unit tests (txaty#12)
Browse files Browse the repository at this point in the history
* Update README.

Signed-off-by: txaty <[email protected]>

* Update README.

Signed-off-by: txaty <[email protected]>

* Add more unit tests.

* Fix bugs.

Signed-off-by: txaty <[email protected]>
  • Loading branch information
txaty authored Aug 24, 2022
1 parent 827cb04 commit 11665bd
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 66 deletions.
89 changes: 30 additions & 59 deletions merkle_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ const (
// Default hash result length using SHA256.
defaultHashLen = 32
// Default job queue size of the goroutine pool for parallel executions.
jobQueueSize = 100
jobQueueSize = 64
)

// ModeType is the type in the Merkle Tree configuration indicating what operations are performed.
Expand Down Expand Up @@ -147,24 +147,25 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
}
if m.Mode == ModeTreeBuild {
if m.RunInParallel {
err = m.treeBuildParal(wp)
err = m.treeBuild(wp)
return
}
err = m.treeBuild()
err = m.treeBuild(nil)
return
}
if m.Mode == ModeProofGenAndTreeBuild {
if m.RunInParallel {
err = m.treeBuildParal(wp)
err = m.treeBuild(wp)
if err != nil {
return
}
m.initProofs()
for i := 0; i < len(m.tree); i++ {
m.updateProofsParal(m.tree[i], len(m.tree[i]), i, wp)
}
return
}
err = m.treeBuild()
err = m.treeBuild(nil)
if err != nil {
return
}
Expand Down Expand Up @@ -467,7 +468,7 @@ func (m *MerkleTree) leafGenParal(blocks []DataBlock, wp *gool.Pool) ([][]byte,
return leaves, nil
}

func (m *MerkleTree) treeBuild() (err error) {
func (m *MerkleTree) treeBuild(wp *gool.Pool) (err error) {
numLeaves := len(m.Leaves)
finishMap := make(chan struct{})
go func() {
Expand All @@ -486,10 +487,29 @@ func (m *MerkleTree) treeBuild() (err error) {
}
for i := uint32(0); i < m.Depth-1; i++ {
m.tree[i+1] = make([][]byte, prevLen>>1)
for j := 0; j < prevLen; j += 2 {
m.tree[i+1][j>>1], err = m.HashFunc(append(m.tree[i][j], m.tree[i][j+1]...))
if err != nil {
return
if m.RunInParallel {
argList := make([]interface{}, m.NumRoutines)
for j := 0; j < m.NumRoutines; j++ {
argList[j] = treeBuildArgs{
m: m,
depth: i,
start: j << 1,
prevLen: prevLen,
numRoutines: m.NumRoutines,
}
}
errList := wp.Map(treeBuildHandler, argList)
for _, err := range errList {
if err != nil {
return err.(error)
}
}
} else {
for j := 0; j < prevLen; j += 2 {
m.tree[i+1][j>>1], err = m.HashFunc(append(m.tree[i][j], m.tree[i][j+1]...))
if err != nil {
return
}
}
}
m.tree[i+1], prevLen, err = m.fixOdd(m.tree[i+1], len(m.tree[i+1]))
Expand Down Expand Up @@ -526,55 +546,6 @@ func treeBuildHandler(argInterface interface{}) interface{} {
return nil
}

func (m *MerkleTree) treeBuildParal(wp *gool.Pool) (err error) {
numRoutines := m.NumRoutines
numLeaves := len(m.Leaves)
finishMap := make(chan struct{})
go func() {
for i := 0; i < numLeaves; i++ {
m.leafMap.Store(string(m.Leaves[i]), i)
}
finishMap <- struct{}{}
}()
m.tree = make([][][]byte, m.Depth)
m.tree[0] = make([][]byte, numLeaves)
copy(m.tree[0], m.Leaves)
var prevLen int
m.tree[0], prevLen, err = m.fixOdd(m.tree[0], numLeaves)
if err != nil {
return
}
for i := uint32(0); i < m.Depth-1; i++ {
m.tree[i+1] = make([][]byte, prevLen>>1)
argList := make([]interface{}, numRoutines)
for j := 0; j < numRoutines; j++ {
argList[j] = treeBuildArgs{
m: m,
depth: i,
start: j << 1,
prevLen: prevLen,
numRoutines: numRoutines,
}
}
errList := wp.Map(treeBuildHandler, argList)
for _, err := range errList {
if err != nil {
return err.(error)
}
}
m.tree[i+1], prevLen, err = m.fixOdd(m.tree[i+1], len(m.tree[i+1]))
if err != nil {
return
}
}
m.Root, err = m.HashFunc(append(m.tree[m.Depth-1][0], m.tree[m.Depth-1][1]...))
if err != nil {
return
}
<-finishMap
return
}

// Verify verifies the data block with the Merkle Tree proof
func (m *MerkleTree) Verify(dataBlock DataBlock, proof *Proof) (bool, error) {
return Verify(dataBlock, proof, m.Root, m.HashFunc)
Expand Down
94 changes: 87 additions & 7 deletions merkle_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -683,13 +683,6 @@ func TestMerkleTree_Verify(t *testing.T) {
want: true,
wantErr: false,
},
{
name: "test_pseudo_random_10",
setupFunc: verifySetup,
blockSize: 10,
want: true,
wantErr: false,
},
{
name: "test_pseudo_random_1001",
setupFunc: verifySetup,
Expand Down Expand Up @@ -826,6 +819,12 @@ func TestMerkleTree_GenerateProof(t *testing.T) {
}
}

func testHashFunc(data []byte) ([]byte, error) {
sha256Func := sha256.New()
sha256Func.Write(data)
return sha256Func.Sum(nil), nil
}

func TestMerkleTree_proofGen(t *testing.T) {
patches := gomonkey.NewPatches()
defer patches.Reset()
Expand Down Expand Up @@ -855,6 +854,22 @@ func TestMerkleTree_proofGen(t *testing.T) {
},
wantErr: true,
},
{
name: "test_hash_func_err",
args: args{
config: &Config{
HashFunc: testHashFunc,
},
blocks: genTestDataBlocks(5),
},
mock: func() {
patches.ApplyFunc(testHashFunc,
func([]byte) ([]byte, error) {
return nil, errors.New("test_hash_func_err")
})
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -948,6 +963,30 @@ func TestVerify(t *testing.T) {
want: false,
wantErr: true,
},
{
name: "test_hash_func_nil",
args: args{
dataBlock: blocks[0],
proof: m.Proofs[0],
root: m.Root,
hashFunc: nil,
},
want: true,
wantErr: false,
},
{
name: "test_hash_func_err",
args: args{
dataBlock: blocks[0],
proof: m.Proofs[0],
root: m.Root,
hashFunc: func([]byte) ([]byte, error) {
return nil, errors.New("test_hash_func_err")
},
},
want: false,
wantErr: true,
},
{
name: "data_block_serialize_err",
args: args{
Expand Down Expand Up @@ -984,6 +1023,47 @@ func TestVerify(t *testing.T) {
}
}

func Test_proofGenHandler(t *testing.T) {
patches := gomonkey.NewPatches()
defer patches.Reset()
type args struct {
argInterface interface{}
}
tests := []struct {
name string
args args
mock func()
wantErr bool
}{
{
name: "test_hash_func_err",
args: args{
argInterface: &proofGenArgs{
hashFunc: func([]byte) ([]byte, error) {
return nil, errors.New("test_hash_func_err")
},
buf1: [][]byte{[]byte("test_buf1"), []byte("test_buf1")},
buf2: [][]byte{[]byte("test_buf2")},
prevLen: 2,
numRoutines: 2,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.mock != nil {
tt.mock()
}
defer patches.Reset()
if err := proofGenHandler(tt.args.argInterface); (err != nil) != tt.wantErr {
t.Errorf("proofGenHandler() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func BenchmarkMerkleTreeNew(b *testing.B) {
testCases := genTestDataBlocks(benchSize)
b.ResetTimer()
Expand Down

0 comments on commit 11665bd

Please sign in to comment.