Skip to content

Commit

Permalink
Refactor and add more unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
txaty committed Aug 23, 2022
1 parent d6074ec commit ca0afd0
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 27 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ tests:
tests_race:
go test -v -race -covermode atomic -coverprofile coverage.out && go tool cover -html coverage.out -o coverage.html && go tool cover -func coverage.out -o coverage.out

tests_with_mock:
go test -v -race -gcflags=all=-l -covermode atomic -coverprofile coverage.out && go tool cover -html coverage.out -o coverage.html && go tool cover -func coverage.out -o coverage.out

format:
go fmt .

Expand Down
3 changes: 3 additions & 0 deletions merkle_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ func (m *MerkleTree) Verify(dataBlock DataBlock, proof *Proof) (bool, error) {

// Verify verifies the data block with the Merkle Tree proof and Merkle root hash
func Verify(dataBlock DataBlock, proof *Proof, root []byte, hashFunc HashFuncType) (bool, error) {
if dataBlock == nil {
return false, errors.New("data block is nil")
}
if proof == nil {
return false, errors.New("proof is nil")
}
Expand Down
168 changes: 141 additions & 27 deletions merkle_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ package merkletree

import (
"bytes"
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"github.com/agiledragon/gomonkey/v2"
"math/rand"
"reflect"
"runtime"
"testing"

"github.com/agiledragon/gomonkey/v2"
)

const benchSize = 10000
Expand Down Expand Up @@ -732,19 +733,6 @@ func TestMerkleTree_Verify(t *testing.T) {
}
}

func TestVerify(t *testing.T) {
m, blocks, _ := verifySetup(2)
// hashFunc is nil
got, err := Verify(blocks[0], m.Proofs[0], []byte{}, nil)
if err != nil {
t.Errorf("Verify() error = %v, wantErr %v", err, nil)
return
}
if got {
t.Errorf("Verify() got = %v, want %v", got, false)
}
}

func BenchmarkMerkleTreeNew(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := New(nil, genTestDataBlocks(benchSize))
Expand All @@ -768,9 +756,12 @@ func BenchmarkMerkleTreeNewParallel(b *testing.B) {
}

func TestMerkleTree_GenerateProof(t *testing.T) {
patches := gomonkey.NewPatches()
defer patches.Reset()
tests := []struct {
name string
config *Config
mock func()
blocks []DataBlock
proofBlocks []DataBlock
wantErr bool
Expand All @@ -780,11 +771,6 @@ func TestMerkleTree_GenerateProof(t *testing.T) {
config: &Config{Mode: ModeTreeBuild},
blocks: genTestDataBlocks(2),
},
{
name: "test_3",
config: &Config{Mode: ModeTreeBuild},
blocks: genTestDataBlocks(3),
},
{
name: "test_4",
config: &Config{Mode: ModeTreeBuild},
Expand Down Expand Up @@ -812,6 +798,18 @@ func TestMerkleTree_GenerateProof(t *testing.T) {
},
wantErr: true,
},
{
name: "test_data_block_serialize_error",
config: &Config{Mode: ModeTreeBuild},
mock: func() {
patches.ApplyMethod(reflect.TypeOf(&mockDataBlock{}), "Serialize",
func(*mockDataBlock) ([]byte, error) {
return nil, errors.New("data block serialize error")
})
},
blocks: genTestDataBlocks(5),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -828,6 +826,10 @@ func TestMerkleTree_GenerateProof(t *testing.T) {
if tt.proofBlocks == nil {
tt.proofBlocks = tt.blocks
}
if tt.mock != nil {
tt.mock()
}
defer patches.Reset()
for idx, block := range tt.proofBlocks {
got, err := m2.GenerateProof(block)
if (err != nil) != tt.wantErr {
Expand All @@ -847,6 +849,8 @@ func TestMerkleTree_GenerateProof(t *testing.T) {
}

func TestMerkleTree_proofGen(t *testing.T) {
patches := gomonkey.NewPatches()
defer patches.Reset()
type args struct {
config *Config
blocks []DataBlock
Expand All @@ -866,28 +870,138 @@ func TestMerkleTree_proofGen(t *testing.T) {
blocks: genTestDataBlocks(5),
},
mock: func() {
patches := gomonkey.ApplyFunc(rand.Read,
func(b []byte) (n int, err error) {
return 0, errors.New("test_rand_read_err")
patches.ApplyFunc(getDummyHash,
func() ([]byte, error) {
return nil, errors.New("test_get_dummy_hash_err")
})
defer patches.Reset()
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.mock != nil {
tt.mock()
}
m, err := New(tt.args.config, tt.args.blocks)
if err != nil {
t.Errorf("New() error = %v", err)
return
}
if tt.mock != nil {
tt.mock()
}
defer patches.Reset()
if err := m.proofGen(); (err != nil) != tt.wantErr {
t.Errorf("proofGen() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func TestVerify(t *testing.T) {
blocks := genTestDataBlocks(5)
m, err := New(nil, blocks)
if err != nil {
t.Errorf("New() error = %v", err)
return
}
patches := gomonkey.NewPatches()
defer patches.Reset()
type args struct {
dataBlock DataBlock
proof *Proof
root []byte
hashFunc HashFuncType
}
tests := []struct {
name string
args args
mock func()
want bool
wantErr bool
}{
{
name: "test_ok",
args: args{
dataBlock: blocks[0],
proof: m.Proofs[0],
root: m.Root,
hashFunc: m.HashFunc,
},
want: true,
},
{
name: "test_wrong_root",
args: args{
dataBlock: blocks[0],
proof: m.Proofs[0],
root: []byte("test_wrong_root"),
hashFunc: m.HashFunc,
},
want: false,
},
{
name: "test_wrong_hash_func",
args: args{
dataBlock: blocks[0],
proof: m.Proofs[0],
root: m.Root,
hashFunc: func([]byte) ([]byte, error) { return []byte("test_wrong_hash_hash"), nil },
},
want: false,
},
{
name: "test_proof_nil",
args: args{
dataBlock: blocks[0],
proof: nil,
root: m.Root,
hashFunc: m.HashFunc,
},
want: false,
wantErr: true,
},
{
name: "test_data_block_nil",
args: args{
dataBlock: nil,
proof: m.Proofs[0],
root: m.Root,
hashFunc: m.HashFunc,
},
want: false,
wantErr: true,
},
{
name: "data_block_serialize_err",
args: args{
dataBlock: blocks[0],
proof: m.Proofs[0],
root: m.Root,
hashFunc: m.HashFunc,
},
mock: func() {
patches.ApplyMethod(reflect.TypeOf(&mockDataBlock{}), "Serialize",
func(m *mockDataBlock) ([]byte, error) {
return nil, errors.New("test_data_block_serialize_err")
})
},
want: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.mock != nil {
tt.mock()
}
defer patches.Reset()
got, err := Verify(tt.args.dataBlock, tt.args.proof, tt.args.root, tt.args.hashFunc)
if (err != nil) != tt.wantErr {
t.Errorf("Verify() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Verify() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit ca0afd0

Please sign in to comment.