From 3b445a0585ab0c1e45bf66ea2c5829d91b7b1664 Mon Sep 17 00:00:00 2001 From: Tommy TIAN Date: Thu, 23 Nov 2023 20:49:12 +0800 Subject: [PATCH] Refactor (#30) Refactor: 1. Removed goroutine pool (gool). Using goroutine pool causes overhead and increases the complexity of the code base. 2. Changed the proof generation method. Now proofGen parallelized algorithm doesn't require two buffers. 3. Added fuzzing test. 4. More unit tests were added to achieve 100% coverage. 5. Refactored many methods and functions. 6. Extracted methods and functions to separate files to increase readability. Signed-off-by: txaty --- .github/workflows/ci.yml | 85 +- .gitignore | 3 + Makefile | 5 +- data_block.go | 35 + errors.go | 44 + go.mod | 2 +- go.sum | 4 +- leaf.go | 89 ++ merkle_tree.go | 596 +------- merkle_tree_test.go | 1281 ++--------------- .../{mock_datablock.go => mock_data_block.go} | 0 proof.go | 75 + proof_gen.go | 189 +++ proof_gen_and_tree_build.go | 94 ++ proof_gen_and_tree_build_test.go | 264 ++++ proof_gen_test.go | 452 ++++++ proof_test.go | 126 ++ tree_build.go | 117 ++ tree_build_test.go | 320 ++++ verify.go | 82 ++ verify_test.go | 339 +++++ 21 files changed, 2460 insertions(+), 1742 deletions(-) create mode 100644 data_block.go create mode 100644 errors.go create mode 100644 leaf.go rename mock/{mock_datablock.go => mock_data_block.go} (100%) create mode 100644 proof.go create mode 100644 proof_gen.go create mode 100644 proof_gen_and_tree_build.go create mode 100644 proof_gen_and_tree_build_test.go create mode 100644 proof_gen_test.go create mode 100644 proof_test.go create mode 100644 tree_build.go create mode 100644 tree_build_test.go create mode 100644 verify.go create mode 100644 verify_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3bed93d..c825cb1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: Go Build +name: Go Build, Test, and Analyze Pipeline on: push: @@ -6,9 +6,13 @@ on: pull_request: branches: [ "main" ] +env: + GO_VERSION: 1.21 + jobs: build: + name: Build runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -16,22 +20,95 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version-file: './go.mod' + go-version: ${{ env.GO_VERSION }} + + - name: Cache Go modules + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- - name: Build run: make build - - name: Test - run: make test_race + test: + name: Unit Test + needs: build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Unit Test with Mocking + run: make test_with_mock + + lint: + name: Linting with golangci-lint + needs: test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: ${{ env.GO_VERSION }} - name: Run golangci-lint uses: golangci/golangci-lint-action@v3 + coverage: + name: Coverage and Codecov Upload + needs: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: ${{ env.GO_VERSION }} + - name: Run coverage run: make test_ci_coverage - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 + if: success() + + analysis: + name: Codacy Analysis + needs: coverage + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: ${{ env.GO_VERSION }} - name: Codacy Analysis CLI uses: codacy/codacy-analysis-cli-action@master + if: success() + + fuzz: + name: Fuzzing Test + needs: build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Run Fuzzing Tests + run: make test_fuzz \ No newline at end of file diff --git a/.gitignore b/.gitignore index a51fec3..7b94b91 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,9 @@ coverage.* # Dependency directories (remove the comment below to include it) vendor/ +# Go fuzzing data +testdata/ + # IDEs .idea/ .vscode/ diff --git a/Makefile b/Makefile index 4586f92..03a62d7 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test test_race test_with_mock test_ci_coverage format bench report_bench cpu_report mem_report build +.PHONY: test test_race test_with_mock test_fuzz test_ci_coverage format bench report_bench cpu_report mem_report build COVER_OUT := coverage.out COVER_HTML := coverage.html @@ -10,6 +10,9 @@ test_with_mock: COVER_OPTS = -race -gcflags=all=-l -covermode atomic test test_race test_with_mock: go test -v $(COVER_OPTS) -coverprofile=$(COVER_OUT) && go tool cover -html=$(COVER_OUT) -o $(COVER_HTML) && go tool cover -func=$(COVER_OUT) -o $(COVER_OUT) +test_fuzz: + go test -v -race -fuzz=FuzzMerkleTreeNew -fuzztime=30m -run ^FuzzMerkleTreeNew$ + test_ci_coverage: go test -race -gcflags=all=-l -coverprofile=coverage.txt -covermode=atomic diff --git a/data_block.go b/data_block.go new file mode 100644 index 0000000..fc5211e --- /dev/null +++ b/data_block.go @@ -0,0 +1,35 @@ +// MIT License +// +// Copyright (c) 2023 Tommy TIAN +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Package merkletree implements a high-performance Merkle Tree in Go. +// It supports parallel execution for enhanced performance and +// offers compatibility with OpenZeppelin through sorted sibling pairs. +package merkletree + +// DataBlock is the interface for input data blocks used to generate the Merkle Tree. +// Implementations of DataBlock should provide a serialization method +// that converts the data block into a byte slice for hashing purposes. +type DataBlock interface { + // Serialize converts the data block into a byte slice. + // It returns the serialized byte slice and an error, if any occurs during the serialization process. + Serialize() ([]byte, error) +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..83022f8 --- /dev/null +++ b/errors.go @@ -0,0 +1,44 @@ +// MIT License +// +// Copyright (c) 2023 Tommy TIAN +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Package merkletree implements a high-performance Merkle Tree in Go. +// It supports parallel execution for enhanced performance and +// offers compatibility with OpenZeppelin through sorted sibling pairs. +package merkletree + +import "errors" + +var ( + // ErrInvalidNumOfDataBlocks is the error for an invalid number of data blocks. + ErrInvalidNumOfDataBlocks = errors.New("the number of data blocks must be greater than 1") + // ErrInvalidConfigMode is the error for an invalid configuration mode. + ErrInvalidConfigMode = errors.New("invalid configuration mode") + // ErrProofIsNil is the error for a nil proof. + ErrProofIsNil = errors.New("proof is nil") + // ErrDataBlockIsNil is the error for a nil data block. + ErrDataBlockIsNil = errors.New("data block is nil") + // ErrProofInvalidModeTreeNotBuilt is the error for an invalid mode in Proof() function. + // Proof() function requires a built tree to generate the proof. + ErrProofInvalidModeTreeNotBuilt = errors.New("merkle tree is not in built, could not generate proof by this method") + // ErrProofInvalidDataBlock is the error for an invalid data block in Proof() function. + ErrProofInvalidDataBlock = errors.New("data block is not a member of the merkle tree") +) diff --git a/go.mod b/go.mod index 6ef2cbc..fde397a 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.21 require ( github.com/agiledragon/gomonkey/v2 v2.11.0 - github.com/txaty/gool v0.1.5 + golang.org/x/sync v0.5.0 ) diff --git a/go.sum b/go.sum index 2798d88..7da9f11 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/agiledragon/gomonkey/v2 v2.10.1 h1:FPJJNykD1957cZlGhr9X0zjr291/lbazoZ/dmc4mS4c= -github.com/agiledragon/gomonkey/v2 v2.10.1/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY= github.com/agiledragon/gomonkey/v2 v2.11.0 h1:5oxSgA+tC1xuGsrIorR+sYiziYltmJyEZ9qA25b6l5U= github.com/agiledragon/gomonkey/v2 v2.11.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -10,6 +8,8 @@ github.com/txaty/gool v0.1.5 h1:yjxie86J1kBBAAsP/xa2K4j1HJoB90RvjDyzuMjlK8k= github.com/txaty/gool v0.1.5/go.mod h1:zhUnrAMYUZXRYBq6dTofbCUn8OgA3OOKCFMeqGV2mu0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= diff --git a/leaf.go b/leaf.go new file mode 100644 index 0000000..f7c15d6 --- /dev/null +++ b/leaf.go @@ -0,0 +1,89 @@ +// MIT License +// +// Copyright (c) 2023 Tommy TIAN +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Package merkletree implements a high-performance Merkle Tree in Go. +// It supports parallel execution for enhanced performance and +// offers compatibility with OpenZeppelin through sorted sibling pairs. +package merkletree + +import "golang.org/x/sync/errgroup" + +// computeLeafNodes compute the leaf nodes from the data blocks. +func (m *MerkleTree) computeLeafNodes(blocks []DataBlock) ([][]byte, error) { + var ( + leaves = make([][]byte, m.NumLeaves) + hashFunc = m.HashFunc + disableLeafHashing = m.DisableLeafHashing + err error + ) + for i := 0; i < m.NumLeaves; i++ { + if leaves[i], err = dataBlockToLeaf(blocks[i], hashFunc, disableLeafHashing); err != nil { + return nil, err + } + } + return leaves, nil +} + +// computeLeafNodesParallel compute the leaf nodes from the data blocks in parallel. +func (m *MerkleTree) computeLeafNodesParallel(blocks []DataBlock) ([][]byte, error) { + var ( + lenLeaves = len(blocks) + leaves = make([][]byte, lenLeaves) + numRoutines = m.NumRoutines + hashFunc = m.HashFunc + disableLeafHashing = m.DisableLeafHashing + eg = new(errgroup.Group) + ) + numRoutines = min(numRoutines, lenLeaves) + for startIdx := 0; startIdx < numRoutines; startIdx++ { + startIdx := startIdx + eg.Go(func() error { + var err error + for i := startIdx; i < lenLeaves; i += numRoutines { + if leaves[i], err = dataBlockToLeaf(blocks[i], hashFunc, disableLeafHashing); err != nil { + return err + } + } + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, err + } + return leaves, nil +} + +// dataBlockToLeaf generates the leaf from the data block. +// If the leaf hashing is disabled, the data block is returned as the leaf. +func dataBlockToLeaf(block DataBlock, hashFunc TypeHashFunc, disableLeafHashing bool) ([]byte, error) { + blockBytes, err := block.Serialize() + if err != nil { + return nil, err + } + if disableLeafHashing { + // copy the value so that the original byte slice is not modified + leaf := make([]byte, len(blockBytes)) + copy(leaf, blockBytes) + return leaf, nil + } + return hashFunc(blockBytes) +} diff --git a/merkle_tree.go b/merkle_tree.go index d9ba4b2..fb08c57 100644 --- a/merkle_tree.go +++ b/merkle_tree.go @@ -27,12 +27,9 @@ package merkletree import ( "bytes" - "errors" "math/bits" "runtime" "sync" - - "github.com/txaty/gool" ) const ( @@ -44,41 +41,6 @@ const ( ModeProofGenAndTreeBuild ) -var ( - // ErrInvalidNumOfDataBlocks is the error for an invalid number of data blocks. - ErrInvalidNumOfDataBlocks = errors.New("the number of data blocks must be greater than 1") - // ErrInvalidConfigMode is the error for an invalid configuration mode. - ErrInvalidConfigMode = errors.New("invalid configuration mode") - // ErrProofIsNil is the error for a nil proof. - ErrProofIsNil = errors.New("proof is nil") - // ErrDataBlockIsNil is the error for a nil data block. - ErrDataBlockIsNil = errors.New("data block is nil") - // ErrProofInvalidModeTreeNotBuilt is the error for an invalid mode in Proof() function. - // Proof() function requires a built tree to generate the proof. - ErrProofInvalidModeTreeNotBuilt = errors.New("merkle tree is not in built, could not generate proof by this method") - // ErrProofInvalidDataBlock is the error for an invalid data block in Proof() function. - ErrProofInvalidDataBlock = errors.New("data block is not a member of the merkle tree") -) - -// DataBlock is the interface for input data blocks used to generate the Merkle Tree. -// Implementations of DataBlock should provide a serialization method -// that converts the data block into a byte slice for hashing purposes. -type DataBlock interface { - // Serialize converts the data block into a byte slice. - // It returns the serialized byte slice and an error, if any occurs during the serialization process. - Serialize() ([]byte, error) -} - -// workerArgs is used as the arguments for the worker functions when performing parallel computations. -// Each worker function has its own dedicated argument struct embedded within workerArgs, -// which eliminates the need for interface conversion overhead and provides clear separation of concerns. -type workerArgs struct { - generateProofs *workerArgsGenerateProofs - updateProofs *workerArgsUpdateProofs - generateLeaves *workerArgsGenerateLeaves - computeTreeNodes *workerArgsComputeTreeNodes -} - // TypeConfigMode is the type in the Merkle Tree configuration indicating what operations are performed. type TypeConfigMode int @@ -108,14 +70,12 @@ type Config struct { // MerkleTree implements the Merkle Tree data structure. type MerkleTree struct { - Config + *Config // leafMap maps the data (converted to string) of each leaf node to its index in the Tree slice. // It is only available when the configuration mode is set to ModeTreeBuild or ModeProofGenAndTreeBuild. leafMap map[string]int // leafMapMu is a mutex that protects concurrent access to the leafMap. leafMapMu sync.Mutex - // wp is the worker pool used for parallel computation in the tree building process. - wp *gool.Pool[workerArgs, error] // concatHashFunc is the function for concatenating two hashes. // If SortSiblingPairs in Config is true, then the sibling pairs are first sorted and then concatenated, // supporting the OpenZeppelin Merkle Tree protocol. @@ -139,12 +99,6 @@ type MerkleTree struct { NumLeaves int } -// Proof represents a Merkle Tree proof. -type Proof struct { - Siblings [][]byte // Sibling nodes to the Merkle Tree path of the data block. - Path uint32 // Path variable indicating whether the neighbor is on the left or right. -} - // New generates a new Merkle Tree with the specified configuration and data blocks. func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) { // Check if there are enough data blocks to build the tree. @@ -159,7 +113,7 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) { // Create a MerkleTree with the provided configuration. m = &MerkleTree{ - Config: *config, + Config: config, NumLeaves: len(blocks), Depth: bits.Len(uint(len(blocks) - 1)), } @@ -189,16 +143,12 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) { if m.NumRoutines <= 0 { m.NumRoutines = runtime.NumCPU() } - // Initialize a wait group for parallel computation and generate leaves. - // Task channel capacity is passed as 0, so use the default value: 2 * numWorkers. - m.wp = gool.NewPool[workerArgs, error](m.NumRoutines, 0) - defer m.wp.Close() - if m.Leaves, err = m.generateLeavesInParallel(blocks); err != nil { + if m.Leaves, err = m.computeLeafNodesParallel(blocks); err != nil { return nil, err } } else { // Generate leaves without parallelization. - if m.Leaves, err = m.generateLeaves(blocks); err != nil { + if m.Leaves, err = m.computeLeafNodes(blocks); err != nil { return nil, err } } @@ -211,6 +161,10 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) { // Generate proofs in ModeProofGen. if m.Mode == ModeProofGen { + if m.RunInParallel { + err = m.generateProofsParallel() + return + } err = m.generateProofs() return } @@ -219,25 +173,21 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) { // Build the tree in ModeTreeBuild. if m.Mode == ModeTreeBuild { - err = m.buildTree() + if m.RunInParallel { + err = m.treeBuildParallel() + return + } + err = m.treeBuild() return } // Build the tree and generate proofs in ModeProofGenAndTreeBuild. if m.Mode == ModeProofGenAndTreeBuild { - if err = m.buildTree(); err != nil { - return - } - m.initProofs() if m.RunInParallel { - for i := 0; i < len(m.nodes); i++ { - m.updateProofsInParallel(m.nodes[i], len(m.nodes[i]), i) - } + err = m.proofGenAndTreeBuildParallel() return } - for i := 0; i < len(m.nodes); i++ { - m.updateProofs(m.nodes[i], len(m.nodes[i]), i) - } + err = m.proofGenAndTreeBuild() return } @@ -245,7 +195,7 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) { return nil, ErrInvalidConfigMode } -// concatSortHash concatenates two byte slices, b1 and b2, in a sorted order. +// concatHash concatenates two byte slices, b1 and b2. func concatHash(b1 []byte, b2 []byte) []byte { result := make([]byte, len(b1)+len(b2)) copy(result, b1) @@ -263,517 +213,3 @@ func concatSortHash(b1 []byte, b2 []byte) []byte { } return concatHash(b2, b1) } - -// initProofs initializes the MerkleTree's Proofs with the appropriate size and depth. -func (m *MerkleTree) initProofs() { - m.Proofs = make([]*Proof, m.NumLeaves) - for i := 0; i < m.NumLeaves; i++ { - m.Proofs[i] = new(Proof) - m.Proofs[i].Siblings = make([][]byte, 0, m.Depth) - } -} - -// generateProofs constructs the Merkle Tree and generates the Merkle proofs for each leaf. -// It returns an error if there is an issue during the generation process. -func (m *MerkleTree) generateProofs() error { - m.initProofs() - buffer := make([][]byte, m.NumLeaves) - copy(buffer, m.Leaves) - var bufferLength int - buffer, bufferLength = m.fixOddLength(buffer, m.NumLeaves) - - if m.RunInParallel { - return m.generateProofsInParallel(buffer, bufferLength) - } - - m.updateProofs(buffer, m.NumLeaves, 0) - var err error - for step := 1; step < m.Depth; step++ { - for idx := 0; idx < bufferLength; idx += 2 { - buffer[idx>>1], err = m.HashFunc(m.concatHashFunc(buffer[idx], buffer[idx+1])) - if err != nil { - return err - } - } - bufferLength >>= 1 - buffer, bufferLength = m.fixOddLength(buffer, bufferLength) - m.updateProofs(buffer, bufferLength, step) - } - - m.Root, err = m.HashFunc(m.concatHashFunc(buffer[0], buffer[1])) - return err -} - -// workerArgsGenerateProofs contains the parameters required for workerGenerateProofs. -type workerArgsGenerateProofs struct { - hashFunc TypeHashFunc - concatHashFunc typeConcatHashFunc - buffer [][]byte - tempBuffer [][]byte - startIdx int - bufferLength int - numRoutines int -} - -// workerGenerateProofs is the worker function that generates Merkle proofs in parallel. -// It processes a portion of the buffer based on the provided worker arguments. -func workerGenerateProofs(args workerArgs) error { - chosenArgs := args.generateProofs - var ( - hashFunc = chosenArgs.hashFunc - concatFunc = chosenArgs.concatHashFunc - buffer = chosenArgs.buffer - tempBuffer = chosenArgs.tempBuffer - startIdx = chosenArgs.startIdx - bufferLength = chosenArgs.bufferLength - numRoutines = chosenArgs.numRoutines - ) - for i := startIdx; i < bufferLength; i += numRoutines << 1 { - newHash, err := hashFunc(concatFunc(buffer[i], buffer[i+1])) - if err != nil { - return err - } - tempBuffer[i>>1] = newHash - } - return nil -} - -// generateProofsInParallel generates proofs concurrently for the MerkleTree. -func (m *MerkleTree) generateProofsInParallel(buffer [][]byte, bufferLength int) (err error) { - tempBuffer := make([][]byte, bufferLength>>1) - m.updateProofsInParallel(buffer, m.NumLeaves, 0) - numRoutines := m.NumRoutines - for step := 1; step < m.Depth; step++ { - // Limit the number of workers to the previous level length. - if numRoutines > bufferLength { - numRoutines = bufferLength - } - - // Create the list of arguments for the worker pool. - argList := make([]workerArgs, numRoutines) - for i := 0; i < numRoutines; i++ { - argList[i] = workerArgs{ - generateProofs: &workerArgsGenerateProofs{ - hashFunc: m.HashFunc, - concatHashFunc: m.concatHashFunc, - buffer: buffer, - tempBuffer: tempBuffer, - startIdx: i << 1, - bufferLength: bufferLength, - numRoutines: numRoutines, - }, - } - } - - // Execute proof generation concurrently using the worker pool. - errList := m.wp.Map(workerGenerateProofs, argList) - for _, err = range errList { - if err != nil { - return - } - } - - // Swap the buffers for the next iteration. - buffer, tempBuffer = tempBuffer, buffer - bufferLength >>= 1 - - // Fix the buffer if it has an odd number of elements. - buffer, bufferLength = m.fixOddLength(buffer, bufferLength) - - // Update the proofs with the new buffer. - m.updateProofsInParallel(buffer, bufferLength, step) - } - - // Compute the root hash of the Merkle tree. - m.Root, err = m.HashFunc(m.concatHashFunc(buffer[0], buffer[1])) - return -} - -// fixOddLength adjusts the buffer for odd-length slices by appending a node. -func (m *MerkleTree) fixOddLength(buffer [][]byte, bufferLength int) ([][]byte, int) { - // If the buffer length is even, no adjustment is needed. - if bufferLength&1 == 0 { - return buffer, bufferLength - } - - // Determine the node to append. - appendNode := buffer[bufferLength-1] - bufferLength++ - - // Append the node to the buffer, either by extending the buffer or updating an existing entry. - if len(buffer) < bufferLength { - buffer = append(buffer, appendNode) - } else { - buffer[bufferLength-1] = appendNode - } - - return buffer, bufferLength -} - -func (m *MerkleTree) updateProofs(buffer [][]byte, bufferLength, step int) { - batch := 1 << step - for i := 0; i < bufferLength; i += 2 { - m.updateProofPairs(buffer, i, batch, step) - } -} - -// workerArgsUpdateProofs contains arguments for the workerUpdateProofs function. -type workerArgsUpdateProofs struct { - tree *MerkleTree - buffer [][]byte - startIdx int - batch int - step int - bufferLength int - numRoutines int -} - -// workerUpdateProofs is the worker function that updates Merkle proofs in parallel. -func workerUpdateProofs(args workerArgs) error { - chosenArgs := args.updateProofs - var ( - tree = chosenArgs.tree - buffer = chosenArgs.buffer - startIdx = chosenArgs.startIdx - batch = chosenArgs.batch - step = chosenArgs.step - bufferLength = chosenArgs.bufferLength - numRoutines = chosenArgs.numRoutines - ) - for i := startIdx; i < bufferLength; i += numRoutines << 1 { - tree.updateProofPairs(buffer, i, batch, step) - } - // return the nil error to be compatible with the worker type - return nil -} - -// updateProofsInParallel updates proofs concurrently for the Merkle Tree. -func (m *MerkleTree) updateProofsInParallel(buffer [][]byte, bufferLength, step int) { - batch := 1 << step - numRoutines := m.NumRoutines - if numRoutines > bufferLength { - numRoutines = bufferLength - } - argList := make([]workerArgs, numRoutines) - for i := 0; i < numRoutines; i++ { - argList[i] = workerArgs{ - updateProofs: &workerArgsUpdateProofs{ - tree: m, - buffer: buffer, - startIdx: i << 1, - batch: batch, - step: step, - bufferLength: bufferLength, - numRoutines: numRoutines, - }, - } - } - m.wp.Map(workerUpdateProofs, argList) -} - -// updateProofPairs updates the proofs in the Merkle Tree in pairs. -func (m *MerkleTree) updateProofPairs(buffer [][]byte, idx, batch, step int) { - start := idx * batch - end := min(start+batch, len(m.Proofs)) - for i := start; i < end; i++ { - m.Proofs[i].Path += 1 << step - m.Proofs[i].Siblings = append(m.Proofs[i].Siblings, buffer[idx+1]) - } - start += batch - end = min(start+batch, len(m.Proofs)) - for i := start; i < end; i++ { - m.Proofs[i].Siblings = append(m.Proofs[i].Siblings, buffer[idx]) - } -} - -// generateLeaves generates the leaves slice from the data blocks. -func (m *MerkleTree) generateLeaves(blocks []DataBlock) ([][]byte, error) { - var ( - leaves = make([][]byte, m.NumLeaves) - err error - ) - for i := 0; i < m.NumLeaves; i++ { - if leaves[i], err = dataBlockToLeaf(blocks[i], &m.Config); err != nil { - return nil, err - } - } - return leaves, nil -} - -// dataBlockToLeaf generates the leaf from the data block. -// If the leaf hashing is disabled, the data block is returned as the leaf. -func dataBlockToLeaf(block DataBlock, config *Config) ([]byte, error) { - blockBytes, err := block.Serialize() - if err != nil { - return nil, err - } - if config.DisableLeafHashing { - // copy the value so that the original byte slice is not modified - leaf := make([]byte, len(blockBytes)) - copy(leaf, blockBytes) - return leaf, nil - } - return config.HashFunc(blockBytes) -} - -// workerArgsGenerateLeaves contains arguments for the workerGenerateLeaves function. -type workerArgsGenerateLeaves struct { - config *Config - dataBlocks []DataBlock - leaves [][]byte - startIdx int - lenLeaves int - numRoutines int -} - -// workerGenerateLeaves is the worker function that generates Merkle leaves in parallel. -func workerGenerateLeaves(args workerArgs) error { - chosenArgs := args.generateLeaves - var ( - config = chosenArgs.config - blocks = chosenArgs.dataBlocks - leaves = chosenArgs.leaves - start = chosenArgs.startIdx - lenLeaves = chosenArgs.lenLeaves - numRoutines = chosenArgs.numRoutines - ) - var err error - for i := start; i < lenLeaves; i += numRoutines { - if leaves[i], err = dataBlockToLeaf(blocks[i], config); err != nil { - return err - } - } - return nil -} - -// generateLeavesInParallel generates the leaves slice from the data blocks in parallel. -func (m *MerkleTree) generateLeavesInParallel(blocks []DataBlock) ([][]byte, error) { - var ( - lenLeaves = len(blocks) - leaves = make([][]byte, lenLeaves) - numRoutines = m.NumRoutines - ) - if numRoutines > lenLeaves { - numRoutines = lenLeaves - } - argList := make([]workerArgs, numRoutines) - for i := 0; i < numRoutines; i++ { - argList[i] = workerArgs{ - generateLeaves: &workerArgsGenerateLeaves{ - config: &m.Config, - dataBlocks: blocks, - leaves: leaves, - startIdx: i, - lenLeaves: lenLeaves, - numRoutines: numRoutines, - }, - } - } - errList := m.wp.Map(workerGenerateLeaves, argList) - for _, err := range errList { - if err != nil { - return nil, err - } - } - return leaves, nil -} - -// buildTree builds the Merkle Tree. -func (m *MerkleTree) buildTree() (err error) { - finishMap := make(chan struct{}) - go func() { - m.leafMapMu.Lock() - defer m.leafMapMu.Unlock() - for i := 0; i < m.NumLeaves; i++ { - m.leafMap[string(m.Leaves[i])] = i - } - finishMap <- struct{}{} // empty channel to serve as a wait group for map generation - }() - m.nodes = make([][][]byte, m.Depth) - m.nodes[0] = make([][]byte, m.NumLeaves) - copy(m.nodes[0], m.Leaves) - var bufferLength int - m.nodes[0], bufferLength = m.fixOddLength(m.nodes[0], m.NumLeaves) - if m.RunInParallel { - if err := m.computeTreeNodesInParallel(bufferLength); err != nil { - return err - } - } - for i := 0; i < m.Depth-1; i++ { - m.nodes[i+1] = make([][]byte, bufferLength>>1) - for j := 0; j < bufferLength; j += 2 { - if m.nodes[i+1][j>>1], err = m.HashFunc( - m.concatHashFunc(m.nodes[i][j], m.nodes[i][j+1]), - ); err != nil { - return - } - } - m.nodes[i+1], bufferLength = m.fixOddLength(m.nodes[i+1], len(m.nodes[i+1])) - } - if m.Root, err = m.HashFunc(m.concatHashFunc( - m.nodes[m.Depth-1][0], m.nodes[m.Depth-1][1], - )); err != nil { - return - } - <-finishMap - return -} - -// workerArgsComputeTreeNodes contains arguments for the workerComputeTreeNodes function. -type workerArgsComputeTreeNodes struct { - tree *MerkleTree - startIdx int - bufferLength int - numRoutines int - depth int -} - -// workerBuildTree is the worker function that builds the Merkle tree in parallel. -func workerBuildTree(args workerArgs) error { - chosenArgs := args.computeTreeNodes - var ( - tree = chosenArgs.tree - start = chosenArgs.startIdx - bufferLength = chosenArgs.bufferLength - numRoutines = chosenArgs.numRoutines - depth = chosenArgs.depth - ) - for i := start; i < bufferLength; i += numRoutines << 1 { - newHash, err := tree.HashFunc(tree.concatHashFunc( - tree.nodes[depth][i], tree.nodes[depth][i+1], - )) - if err != nil { - return err - } - tree.nodes[depth+1][i>>1] = newHash - } - return nil -} - -// computeTreeNodesInParallel computes the tree nodes in parallel. -func (m *MerkleTree) computeTreeNodesInParallel(bufferLength int) error { - for i := 0; i < m.Depth-1; i++ { - m.nodes[i+1] = make([][]byte, bufferLength>>1) - numRoutines := m.NumRoutines - if numRoutines > bufferLength { - numRoutines = bufferLength - } - argList := make([]workerArgs, numRoutines) - for j := 0; j < numRoutines; j++ { - argList[j] = workerArgs{ - computeTreeNodes: &workerArgsComputeTreeNodes{ - tree: m, - startIdx: j << 1, - bufferLength: bufferLength, - numRoutines: m.NumRoutines, - depth: i, - }, - } - } - errList := m.wp.Map(workerBuildTree, argList) - for _, err := range errList { - if err != nil { - return err - } - } - m.nodes[i+1], bufferLength = m.fixOddLength(m.nodes[i+1], len(m.nodes[i+1])) - } - return nil -} - -// Verify checks if the data block is valid using the Merkle Tree proof and the cached Merkle root hash. -func (m *MerkleTree) Verify(dataBlock DataBlock, proof *Proof) (bool, error) { - return Verify(dataBlock, proof, m.Root, &m.Config) -} - -// Verify checks if the data block is valid using the Merkle Tree proof and the provided Merkle root hash. -// It returns true if the data block is valid, false otherwise. An error is returned in case of any issues -// during the verification process. -func Verify(dataBlock DataBlock, proof *Proof, root []byte, config *Config) (bool, error) { - // Validate input parameters. - if dataBlock == nil { - return false, ErrDataBlockIsNil - } - if proof == nil { - return false, ErrProofIsNil - } - if config == nil { - config = new(Config) - } - if config.HashFunc == nil { - config.HashFunc = DefaultHashFunc - } - - // Determine the concatenation function based on the configuration. - concatFunc := concatHash - if config.SortSiblingPairs { - concatFunc = concatSortHash - } - - // Convert the data block to a leaf. - leaf, err := dataBlockToLeaf(dataBlock, config) - if err != nil { - return false, err - } - - // Traverse the Merkle proof and compute the resulting hash. - // Copy the slice so that the original leaf won't be modified. - result := make([]byte, len(leaf)) - copy(result, leaf) - path := proof.Path - for _, sib := range proof.Siblings { - if path&1 == 1 { - result, err = config.HashFunc(concatFunc(result, sib)) - } else { - result, err = config.HashFunc(concatFunc(sib, result)) - } - if err != nil { - return false, err - } - path >>= 1 - } - return bytes.Equal(result, root), nil -} - -// Proof generates the Merkle proof for a data block using the previously generated Merkle Tree structure. -// This method is only available when the configuration mode is ModeTreeBuild or ModeProofGenAndTreeBuild. -// In ModeProofGen, proofs for all the data blocks are already generated, and the Merkle Tree structure -// is not cached. -func (m *MerkleTree) Proof(dataBlock DataBlock) (*Proof, error) { - if m.Mode != ModeTreeBuild && m.Mode != ModeProofGenAndTreeBuild { - return nil, ErrProofInvalidModeTreeNotBuilt - } - - // Convert the data block to a leaf. - leaf, err := dataBlockToLeaf(dataBlock, &m.Config) - if err != nil { - return nil, err - } - - // Retrieve the index of the leaf in the Merkle Tree. - m.leafMapMu.Lock() - idx, ok := m.leafMap[string(leaf)] - m.leafMapMu.Unlock() - if !ok { - return nil, ErrProofInvalidDataBlock - } - - // Compute the path and siblings for the proof. - var ( - path uint32 - siblings = make([][]byte, m.Depth) - ) - for i := 0; i < m.Depth; i++ { - if idx&1 == 1 { - siblings[i] = m.nodes[i][idx-1] - } else { - path += 1 << i - siblings[i] = m.nodes[i][idx+1] - } - idx >>= 1 - } - return &Proof{ - Path: path, - Siblings: siblings, - }, nil -} diff --git a/merkle_tree_test.go b/merkle_tree_test.go index 677bddb..cf12f8e 100644 --- a/merkle_tree_test.go +++ b/merkle_tree_test.go @@ -24,27 +24,21 @@ package merkletree import ( "bytes" - "crypto/rand" - "crypto/sha256" - "errors" - "fmt" - "reflect" + crand "crypto/rand" + "math/rand" "testing" - "github.com/agiledragon/gomonkey/v2" - "github.com/txaty/go-merkletree/mock" ) -const benchSize = 10000 - -func generatedTestDataBlocks(num int) []DataBlock { +func mockDataBlocks(num int) []DataBlock { blocks := make([]DataBlock, num) for i := 0; i < num; i++ { + byteLen := rand.Intn(1 << 15) block := &mock.DataBlock{ - Data: make([]byte, 100), + Data: make([]byte, byteLen), } - if _, err := rand.Read(block.Data); err != nil { + if _, err := crand.Read(block.Data); err != nil { panic(err) } blocks[i] = block @@ -52,1219 +46,198 @@ func generatedTestDataBlocks(num int) []DataBlock { return blocks } -func TestMerkleTreeNew_modeProofGen(t *testing.T) { - dummyDataBlocks := []DataBlock{ - &mock.DataBlock{ - Data: []byte("dummy_data_0"), - }, - &mock.DataBlock{ - Data: []byte("dummy_data_1"), - }, - &mock.DataBlock{ - Data: []byte("dummy_data_2"), - }, - } - dummyHashList := make([][]byte, 3) - var err error - for i := 0; i < 3; i++ { - dataByte, err := dummyDataBlocks[i].Serialize() - if err != nil { - t.Fatal(err) +func mockDataBlocksFixedSize(num int) []DataBlock { + blocks := make([]DataBlock, num) + for i := 0; i < num; i++ { + block := &mock.DataBlock{ + Data: make([]byte, 128), } - dummyHashList[i], err = DefaultHashFunc(dataByte) - if err != nil { - t.Fatal(err) + if _, err := crand.Read(block.Data); err != nil { + panic(err) } + blocks[i] = block } - twoDummyRoot, err := DefaultHashFunc( - append(dummyHashList[0], dummyHashList[1]...), - ) - if err != nil { - t.Fatal(err) - } - leftHash, err := DefaultHashFunc( - append(dummyHashList[0], dummyHashList[1]...), - ) - if err != nil { - t.Fatal(err) - } - rightHash, err := DefaultHashFunc( - append(dummyHashList[2], dummyHashList[2]...), - ) - if err != nil { - t.Fatal(err) - } - threeDummyRoot, err := DefaultHashFunc(append(leftHash, rightHash...)) - if err != nil { - t.Fatal(err) - } - type args struct { - blocks []DataBlock - config *Config - } - tests := []struct { - name string - args args - wantErr bool - wantRoot []byte - }{ - { - name: "test_0", - args: args{ - blocks: generatedTestDataBlocks(0), - }, - wantErr: true, - }, - { - name: "test_1", - args: args{ - blocks: generatedTestDataBlocks(1), - }, - wantErr: true, - }, - { - name: "test_2", - args: args{ - blocks: []DataBlock{dummyDataBlocks[0], dummyDataBlocks[1]}, - }, - wantErr: false, - wantRoot: twoDummyRoot, - }, - { - name: "test_3", - args: args{ - blocks: dummyDataBlocks, - }, - wantErr: false, - wantRoot: threeDummyRoot, - }, - { - name: "test_8", - args: args{ - blocks: generatedTestDataBlocks(8), - }, - wantErr: false, - }, - { - name: "test_5", - args: args{ - blocks: generatedTestDataBlocks(5), - }, - wantErr: false, - }, - { - name: "test_1000", - args: args{ - blocks: generatedTestDataBlocks(1000), - }, - wantErr: false, - }, - { - name: "test_100_parallel_4", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - RunInParallel: true, - NumRoutines: 4, - }, - }, - wantErr: false, - }, - { - name: "test_10_parallel_32", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - RunInParallel: true, - NumRoutines: 32, - }, - }, - wantErr: false, - }, - { - name: "test_100_parallel_no_specify_num_of_routines", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - RunInParallel: true, - }, - }, - wantErr: false, - }, - { - name: "test_8_sorted", - args: args{ - blocks: generatedTestDataBlocks(8), - config: &Config{ - SortSiblingPairs: true, - }, - }, - wantErr: false, - }, - { - name: "test_hash_func_error", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - HashFunc: func([]byte) ([]byte, error) { - return nil, fmt.Errorf("hash func error") - }, - }, - }, - wantErr: true, - }, - { - name: "test_hash_func_error_parallel", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - HashFunc: func([]byte) ([]byte, error) { - return nil, fmt.Errorf("hash func error") - }, - RunInParallel: true, - }, - }, - wantErr: true, - }, - { - name: "test_100_disable_leaf_hashing", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - DisableLeafHashing: true, - }, - }, - wantErr: false, - }, - { - name: "test_100_disable_leaf_hashing_parallel_4", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - DisableLeafHashing: true, - RunInParallel: true, - NumRoutines: 4, - }, - }, - wantErr: false, - }, - { - name: "invalid_mode", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - Mode: 5, - }, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mt, err := New(tt.args.config, tt.args.blocks) - if (err != nil) != tt.wantErr { - t.Errorf("Build() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.wantErr { - return - } - if tt.wantRoot == nil { - for idx, block := range tt.args.blocks { - if ok, _ := mt.Verify(block, mt.Proofs[idx]); !ok { - t.Errorf("proof verification failed") - return - } - } - } else { - if !bytes.Equal(mt.Root, tt.wantRoot) { - t.Errorf("root mismatch, got %x, want %x", mt.Root, tt.wantRoot) - return - } - } - }) - } + return blocks } -func TestMerkleTreeNew_modeTreeBuild(t *testing.T) { - type args struct { - blocks []DataBlock - config *Config - } - tests := []struct { - name string - args args - checkingConfig *Config - wantErr bool - }{ - { - name: "test_build_tree_2", - args: args{ - blocks: generatedTestDataBlocks(2), - config: &Config{ - Mode: ModeTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_5", - args: args{ - blocks: generatedTestDataBlocks(5), - config: &Config{ - Mode: ModeTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_8", - args: args{ - blocks: generatedTestDataBlocks(8), - config: &Config{ - Mode: ModeTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_1000", - args: args{ - blocks: generatedTestDataBlocks(1000), - config: &Config{ - Mode: ModeTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_hash_func_error", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - HashFunc: func([]byte) ([]byte, error) { - return nil, fmt.Errorf("hash func error") - }, - Mode: ModeTreeBuild, - }, - }, - wantErr: true, - }, - { - name: "test_disable_leaf_hashing", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - DisableLeafHashing: true, - Mode: ModeTreeBuild, - }, - }, - checkingConfig: &Config{ - DisableLeafHashing: true, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m, err := New(tt.args.config, tt.args.blocks) - if (err != nil) != tt.wantErr { - t.Errorf("Build() error = %v, wantErr %v", err, tt.wantErr) - return - } - m1, err := New(tt.checkingConfig, tt.args.blocks) - if err != nil { - t.Errorf("test setup error %v", err) - return - } - if !tt.wantErr && !bytes.Equal(m.Root, m1.Root) && !tt.wantErr { - fmt.Println("m", m.Root) - fmt.Println("m1", m1.Root) - t.Errorf("tree generated is wrong") - return - } - }) - } -} +func FuzzMerkleTreeNew(f *testing.F) { + f.Add(10, 0, 4, false, false, false) + f.Add(128, 1, 3, true, true, true) + f.Add(256, 2, 2, false, false, true) + f.Add(512, 0, 1, true, true, false) + f.Fuzz(func(t *testing.T, + numBlocks, modeInt, numRoutines int, + runInParallel, sortSiblingPairs, disableLeafHashing bool, + ) { + // 0 < numBlocks < 262144 + if numBlocks < 0 { + numBlocks = -numBlocks + } + numBlocks %= 262144 + numBlocks++ + dataBlocks := mockDataBlocks(numBlocks) -func TestMerkleTreeNew_modeTreeBuildRunInParallel(t *testing.T) { - type args struct { - blocks []DataBlock - config *Config - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "test_build_tree_parallel_2", - args: args{ - blocks: generatedTestDataBlocks(2), - config: &Config{ - RunInParallel: true, - NumRoutines: 4, - Mode: ModeTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_parallel_4", - args: args{ - blocks: generatedTestDataBlocks(4), - config: &Config{ - RunInParallel: true, - NumRoutines: 4, - Mode: ModeTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_parallel_5", - args: args{ - blocks: generatedTestDataBlocks(5), - config: &Config{ - RunInParallel: true, - NumRoutines: 4, - Mode: ModeTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_parallel_8", - args: args{ - blocks: generatedTestDataBlocks(8), - config: &Config{ - RunInParallel: true, - NumRoutines: 4, - Mode: ModeTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_parallel_8_32", - args: args{ - blocks: generatedTestDataBlocks(8), - config: &Config{ - RunInParallel: true, - NumRoutines: 32, - Mode: ModeTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_hash_func_error_parallel", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - HashFunc: func([]byte) ([]byte, error) { - return nil, fmt.Errorf("hash func error") - }, - RunInParallel: true, - Mode: ModeTreeBuild, - }, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m, err := New(tt.args.config, tt.args.blocks) - if (err != nil) != tt.wantErr { - t.Errorf("Build() error = %v, wantErr %v", err, tt.wantErr) - return - } - m1, err := New(nil, tt.args.blocks) - if err != nil { - t.Errorf("test setup error %v", err) - return - } - if !tt.wantErr && !bytes.Equal(m.Root, m1.Root) && !tt.wantErr { - fmt.Println("m", m.Root) - fmt.Println("m1", m1.Root) - t.Errorf("tree generated is wrong") - return - } - }) - } -} + // 0 <= modeInt < 3 + if modeInt < 0 { + modeInt = -modeInt + } + modeInt %= 3 + mode := TypeConfigMode(modeInt) -func TestMerkleTreeNew_modeProofGenAndTreeBuild(t *testing.T) { - type args struct { - blocks []DataBlock - config *Config - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "test_build_tree_proof_2", - args: args{ - blocks: generatedTestDataBlocks(2), - config: &Config{ - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_proof_4", - args: args{ - blocks: generatedTestDataBlocks(4), - config: &Config{ - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_proof_5", - args: args{ - blocks: generatedTestDataBlocks(5), - config: &Config{ - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_proof_8", - args: args{ - blocks: generatedTestDataBlocks(8), - config: &Config{ - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_proof_9", - args: args{ - blocks: generatedTestDataBlocks(9), - config: &Config{ - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_hash_func_error", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - HashFunc: func([]byte) ([]byte, error) { - return nil, fmt.Errorf("hash func error") - }, - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: true, - }, - { - name: "test_tree_build_hash_func_error", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - HashFunc: func(block []byte) ([]byte, error) { - if len(block) == 64 { - return nil, fmt.Errorf("hash func error") - } - sha256Func := sha256.New() - sha256Func.Write(block) - return sha256Func.Sum(nil), nil - }, - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m, err := New(tt.args.config, tt.args.blocks) - if (err != nil) != tt.wantErr { - t.Errorf("Build() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.wantErr { - return - } - m1, err := New(nil, tt.args.blocks) - if err != nil { - t.Errorf("test setup error %v", err) - return - } - for i := 0; i < len(tt.args.blocks); i++ { - if !reflect.DeepEqual(m.Proofs[i], m1.Proofs[i]) { - t.Errorf("proofs generated are wrong for block %d", i) - return - } - } - }) - } -} + // 0 <= numRoutines < 1024 + if numRoutines < 0 { + numRoutines = -numRoutines + } + numRoutines %= 1024 + + config := &Config{ + NumRoutines: numRoutines, + Mode: mode, + RunInParallel: runInParallel, + SortSiblingPairs: sortSiblingPairs, + DisableLeafHashing: disableLeafHashing, + } + mt, err := New(config, dataBlocks) + if err != nil { + return + } + if mt == nil { + return + } -func TestMerkleTreeNew_modeProofGenAndTreeBuildRunInParallel(t *testing.T) { - type args struct { - blocks []DataBlock - config *Config - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "test_build_tree_proof_parallel_2", - args: args{ - blocks: generatedTestDataBlocks(2), - config: &Config{ - RunInParallel: true, - NumRoutines: 4, - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_proof_parallel_4", - args: args{ - blocks: generatedTestDataBlocks(4), - config: &Config{ - RunInParallel: true, - NumRoutines: 4, - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_proof_parallel_5", - args: args{ - blocks: generatedTestDataBlocks(5), - config: &Config{ - RunInParallel: true, - NumRoutines: 4, - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_build_tree_proof_parallel_8", - args: args{ - blocks: generatedTestDataBlocks(8), - config: &Config{ - RunInParallel: true, - NumRoutines: 4, - Mode: ModeProofGenAndTreeBuild, - }, - }, - wantErr: false, - }, - { - name: "test_hash_func_error", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - HashFunc: func([]byte) ([]byte, error) { - return nil, fmt.Errorf("hash func error") - }, - Mode: ModeProofGenAndTreeBuild, - RunInParallel: true, - }, - }, - wantErr: true, - }, - { - name: "test_tree_build_hash_func_error", - args: args{ - blocks: generatedTestDataBlocks(100), - config: &Config{ - HashFunc: func(block []byte) ([]byte, error) { - if len(block) == 64 { - return nil, fmt.Errorf("hash func error") - } - sha256Func := sha256.New() - sha256Func.Write(block) - return sha256Func.Sum(nil), nil - }, - Mode: ModeProofGenAndTreeBuild, - RunInParallel: true, - }, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m, err := New(tt.args.config, tt.args.blocks) - if (err != nil) != tt.wantErr { - t.Errorf("Build() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.wantErr { - return - } - m1, err := New(nil, tt.args.blocks) - if err != nil { - t.Errorf("test setup error %v", err) - return - } - for i := 0; i < len(tt.args.blocks); i++ { - if !reflect.DeepEqual(m.Proofs[i], m1.Proofs[i]) { - t.Errorf("proofs generated are wrong for block %d", i) + if mode == ModeProofGen || mode == ModeProofGenAndTreeBuild { + for idx, block := range dataBlocks { + ok, err := mt.Verify(block, mt.Proofs[idx]) + if err != nil { + t.Errorf("proof verification error, idx %d, err %v", idx, err) return } - } - }) - } -} - -func setupTestVerify(size int) (*MerkleTree, []DataBlock) { - blocks := generatedTestDataBlocks(size) - m, err := New(nil, blocks) - if err != nil { - panic(err) - } - return m, blocks -} - -func setupTestVerifyRunInParallel(size int) (*MerkleTree, []DataBlock) { - blocks := generatedTestDataBlocks(size) - m, err := New(&Config{ - RunInParallel: true, - NumRoutines: 1, - }, blocks) - if err != nil { - panic(err) - } - return m, blocks -} - -func TestMerkleTree_Verify(t *testing.T) { - tests := []struct { - name string - setupFunc func(int) (*MerkleTree, []DataBlock) - blockSize int - want bool - wantErr bool - }{ - { - name: "test_2", - setupFunc: setupTestVerify, - blockSize: 2, - want: true, - wantErr: false, - }, - { - name: "test_3", - setupFunc: setupTestVerify, - blockSize: 3, - want: true, - wantErr: false, - }, - { - name: "test_4", - setupFunc: setupTestVerify, - blockSize: 4, - want: true, - wantErr: false, - }, - { - name: "test_5", - setupFunc: setupTestVerify, - blockSize: 5, - want: true, - wantErr: false, - }, - { - name: "test_6", - setupFunc: setupTestVerify, - blockSize: 6, - want: true, - wantErr: false, - }, - { - name: "test_8", - setupFunc: setupTestVerify, - blockSize: 8, - want: true, - wantErr: false, - }, - { - name: "test_9", - setupFunc: setupTestVerify, - blockSize: 9, - want: true, - wantErr: false, - }, - { - name: "test_1001", - setupFunc: setupTestVerify, - blockSize: 1001, - want: true, - wantErr: false, - }, - { - name: "test_2_parallel", - setupFunc: setupTestVerifyRunInParallel, - blockSize: 2, - want: true, - wantErr: false, - }, - { - name: "test_4_parallel", - setupFunc: setupTestVerifyRunInParallel, - blockSize: 4, - want: true, - wantErr: false, - }, - { - name: "test_64_parallel", - setupFunc: setupTestVerifyRunInParallel, - blockSize: 64, - want: true, - wantErr: false, - }, - { - name: "test_1001_parallel", - setupFunc: setupTestVerifyRunInParallel, - blockSize: 1001, - want: true, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m, blocks := tt.setupFunc(tt.blockSize) - for i := 0; i < tt.blockSize; i++ { - got, err := m.Verify(blocks[i], m.Proofs[i]) - if (err != nil) != tt.wantErr { - t.Errorf("Verify() error = %v, wantErr %v", err, tt.wantErr) + if !ok { + t.Errorf("proof verification failed, idx %d", idx) return } - if got != tt.want { - t.Errorf("Verify() got = %v, want %v", got, tt.want) - } } - }) - } -} - -func TestMerkleTree_Proof(t *testing.T) { - patches := gomonkey.NewPatches() - defer patches.Reset() - tests := []struct { - name string - config *Config - mock func() - blocks []DataBlock - proofBlocks []DataBlock - wantErr bool - }{ - { - name: "test_2", - config: &Config{Mode: ModeTreeBuild}, - blocks: generatedTestDataBlocks(2), - }, - { - name: "test_4", - config: &Config{Mode: ModeTreeBuild}, - blocks: generatedTestDataBlocks(4), - }, - { - name: "test_5", - config: &Config{Mode: ModeTreeBuild}, - blocks: generatedTestDataBlocks(5), - }, - { - name: "test_wrong_mode", - config: &Config{Mode: ModeProofGen}, - blocks: generatedTestDataBlocks(5), - wantErr: true, - }, - { - name: "test_wrong_blocks", - config: &Config{Mode: ModeTreeBuild}, - blocks: generatedTestDataBlocks(5), - proofBlocks: []DataBlock{ - &mock.DataBlock{ - Data: []byte("test_wrong_blocks"), - }, - }, - wantErr: true, - }, - { - name: "test_data_block_serialize_error", - config: &Config{Mode: ModeTreeBuild}, - mock: func() { - patches.ApplyMethod(reflect.TypeOf(&mock.DataBlock{}), "Serialize", - func(*mock.DataBlock) ([]byte, error) { - return nil, errors.New("data block serialize error") - }) - }, - blocks: generatedTestDataBlocks(5), - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m1, err := New(nil, tt.blocks) + } else { + compareTree, err := New(&Config{ + SortSiblingPairs: sortSiblingPairs, + DisableLeafHashing: disableLeafHashing, + }, dataBlocks) if err != nil { - t.Errorf("m1 New() error = %v", err) return } - m2, err := New(tt.config, tt.blocks) - if err != nil { - t.Errorf("m2 New() error = %v", err) + if !bytes.Equal(mt.Root, compareTree.Root) { + t.Errorf("tree generated is wrong") return } - if tt.proofBlocks == nil { - tt.proofBlocks = tt.blocks - } - if tt.mock != nil { - tt.mock() - } - defer patches.Reset() - for idx, block := range tt.proofBlocks { - got, err := m2.Proof(block) - if (err != nil) != tt.wantErr { - t.Errorf("Proof() error = %v, wantErr %v", err, tt.wantErr) + for idx, block := range dataBlocks { + proof, err := mt.Proof(block) + if err != nil { return } - if tt.wantErr { + ok, err := mt.Verify(block, proof) + if err != nil { + t.Errorf("proof verification error, idx %d, err %v", idx, err) return } - if !reflect.DeepEqual(got, m1.Proofs[idx]) && !tt.wantErr { - t.Errorf("Proof() %d got = %v, want %v", idx, got, m1.Proofs[idx]) + if !ok { + t.Errorf("proof verification failed, idx %d", idx) + return + } + ok, err = compareTree.Verify(block, proof) + if err != nil { + t.Errorf("proof verification error, idx %d, err %v", idx, err) + return + } + if !ok { + t.Errorf("proof verification failed, idx %d", idx) return } } - }) - } + } + }) } -func mockHashFunc(data []byte) ([]byte, error) { - sha256Func := sha256.New() - sha256Func.Write(data) - return sha256Func.Sum(nil), nil -} +const benchSize = 65536 -func TestMerkleTree_generateProofs(t *testing.T) { - patches := gomonkey.NewPatches() - defer patches.Reset() - type args struct { - config *Config - blocks []DataBlock - } - tests := []struct { - name string - args args - mock func() - wantErr bool - }{ - { - name: "test_hash_func_err", - args: args{ - config: &Config{ - HashFunc: mockHashFunc, - }, - blocks: generatedTestDataBlocks(5), - }, - mock: func() { - patches.ApplyFunc(mockHashFunc, - func([]byte) ([]byte, error) { - return nil, errors.New("test_hash_func_err") - }) - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m, err := New(tt.args.config, tt.args.blocks) - if err != nil { - t.Errorf("New() error = %v", err) - return - } - if tt.mock != nil { - tt.mock() - } - defer patches.Reset() - if err := m.generateProofs(); (err != nil) != tt.wantErr { - t.Errorf("generateProofs() error = %v, wantErr %v", err, tt.wantErr) - } - }) +func BenchmarkMerkleTreeNew_modeProofGen(b *testing.B) { + testCases := mockDataBlocksFixedSize(benchSize) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := New(nil, testCases) + if err != nil { + b.Errorf("New() proof gen error = %v", err) + } } } -func TestVerify(t *testing.T) { - blocks := generatedTestDataBlocks(5) - m, err := New(nil, blocks) - if err != nil { - t.Errorf("New() error = %v", err) - return - } - patches := gomonkey.NewPatches() - defer patches.Reset() - type args struct { - dataBlock DataBlock - proof *Proof - root []byte - config *Config - } - tests := []struct { - name string - args args - mock func() - want bool - wantErr bool - }{ - { - name: "test_ok", - args: args{ - dataBlock: blocks[0], - proof: m.Proofs[0], - root: m.Root, - config: &Config{ - HashFunc: m.HashFunc, - }, - }, - want: true, - }, - { - name: "test_config_nil", - args: args{ - dataBlock: blocks[0], - proof: m.Proofs[0], - root: m.Root, - }, - want: true, - }, - { - name: "test_wrong_root", - args: args{ - dataBlock: blocks[0], - proof: m.Proofs[0], - root: []byte("test_wrong_root"), - config: &Config{ - HashFunc: m.HashFunc, - }, - }, - want: false, - }, - { - name: "test_wrong_hash_func", - args: args{ - dataBlock: blocks[0], - proof: m.Proofs[0], - root: m.Root, - config: &Config{ - HashFunc: func([]byte) ([]byte, error) { return []byte("test_wrong_hash_hash"), nil }, - }, - }, - want: false, - }, - { - name: "test_proof_nil", - args: args{ - dataBlock: blocks[0], - proof: nil, - root: m.Root, - config: &Config{ - HashFunc: m.HashFunc, - }, - }, - want: false, - wantErr: true, - }, - { - name: "test_data_block_nil", - args: args{ - dataBlock: nil, - proof: m.Proofs[0], - root: m.Root, - config: &Config{ - HashFunc: m.HashFunc, - }, - }, - want: false, - wantErr: true, - }, - { - name: "test_hash_func_nil", - args: args{ - dataBlock: blocks[0], - proof: m.Proofs[0], - root: m.Root, - config: &Config{ - HashFunc: nil, - }, - }, - want: true, - wantErr: false, - }, - { - name: "test_hash_func_err", - args: args{ - dataBlock: blocks[0], - proof: m.Proofs[0], - root: m.Root, - config: &Config{ - HashFunc: func([]byte) ([]byte, error) { - return nil, errors.New("test_hash_func_err") - }, - }, - }, - want: false, - wantErr: true, - }, - { - name: "data_block_serialize_err", - args: args{ - dataBlock: blocks[0], - proof: m.Proofs[0], - root: m.Root, - config: &Config{ - HashFunc: m.HashFunc, - }, - }, - mock: func() { - patches.ApplyMethod(reflect.TypeOf(&mock.DataBlock{}), "Serialize", - func(m *mock.DataBlock) ([]byte, error) { - return nil, errors.New("test_data_block_serialize_err") - }) - }, - want: false, - wantErr: true, - }, +func BenchmarkMerkleTreeNew_modeProofGenParallel(b *testing.B) { + config := &Config{ + RunInParallel: true, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.mock != nil { - tt.mock() - } - defer patches.Reset() - got, err := Verify(tt.args.dataBlock, tt.args.proof, tt.args.root, tt.args.config) - if (err != nil) != tt.wantErr { - t.Errorf("Verify() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("Verify() = %v, want %v", got, tt.want) - } - }) + testCases := mockDataBlocksFixedSize(benchSize) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := New(config, testCases) + if err != nil { + b.Errorf("New() proof gen parallel error = %v", err) + } } } -func Test_workerGenerateProofs(t *testing.T) { - patches := gomonkey.NewPatches() - defer patches.Reset() - type args struct { - arg workerArgs - } - mt, err := New(nil, generatedTestDataBlocks(5)) - if err != nil { - t.Errorf("New() error = %v", err) - return - } - mt.HashFunc = func([]byte) ([]byte, error) { - return nil, errors.New("test_hash_func_err") - } - tests := []struct { - mock func() - name string - args args - wantErr bool - }{ - { - name: "test_hash_func_err", - args: args{ - arg: workerArgs{ - generateProofs: &workerArgsGenerateProofs{ - hashFunc: mt.HashFunc, - concatHashFunc: mt.concatHashFunc, - buffer: [][]byte{[]byte("test_buf1"), []byte("test_buf1")}, - tempBuffer: [][]byte{[]byte("test_buf2")}, - bufferLength: 2, - numRoutines: 2, - }, - }, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.mock != nil { - tt.mock() - } - defer patches.Reset() - if err := workerGenerateProofs(tt.args.arg); (err != nil) != tt.wantErr { - t.Errorf("workerGenerateProofs() error = %v, wantErr %v", err, tt.wantErr) - } - }) +func BenchmarkMerkleTreeNew_modeTreeBuild(b *testing.B) { + testCases := mockDataBlocksFixedSize(benchSize) + config := &Config{ + Mode: ModeTreeBuild, } -} - -func BenchmarkMerkleTreeNew(b *testing.B) { - testCases := generatedTestDataBlocks(benchSize) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := New(nil, testCases) + _, err := New(config, testCases) if err != nil { - b.Errorf("Build() error = %v", err) + b.Errorf("New() tree build error = %v", err) } } } -func BenchmarkMerkleTreeNew_modeRunInParallel(b *testing.B) { +func BenchmarkMerkleTreeNew_modeTreeBuildParallel(b *testing.B) { config := &Config{ + Mode: ModeTreeBuild, RunInParallel: true, } - testCases := generatedTestDataBlocks(benchSize) + testCases := mockDataBlocksFixedSize(benchSize) b.ResetTimer() for i := 0; i < b.N; i++ { _, err := New(config, testCases) if err != nil { - b.Errorf("Build() error = %v", err) + b.Errorf("New() tree build parallel error = %v", err) } } } -func BenchmarkMerkleTreeNew_modeTreeBuild(b *testing.B) { - testCases := generatedTestDataBlocks(benchSize) +func BenchmarkMerkleTreeNew_modeProofGenAndTreeBuild(b *testing.B) { config := &Config{ - Mode: ModeTreeBuild, + Mode: ModeProofGenAndTreeBuild, } + testCases := mockDataBlocksFixedSize(benchSize) b.ResetTimer() for i := 0; i < b.N; i++ { _, err := New(config, testCases) if err != nil { - b.Errorf("Build() error = %v", err) + b.Errorf("New() proof gen and tree build error = %v", err) } } } -func BenchmarkMerkleTreeNew_modeTreeBuildRunInParallel(b *testing.B) { +func BenchmarkMerkleTreeNew_modeProofGenAndTreeBuildParallel(b *testing.B) { config := &Config{ - Mode: ModeTreeBuild, + Mode: ModeProofGenAndTreeBuild, RunInParallel: true, } - testCases := generatedTestDataBlocks(benchSize) + testCases := mockDataBlocksFixedSize(benchSize) b.ResetTimer() for i := 0; i < b.N; i++ { _, err := New(config, testCases) if err != nil { - b.Errorf("Build() error = %v", err) + b.Errorf("New() proof gen and tree build parallel error = %v", err) } } } diff --git a/mock/mock_datablock.go b/mock/mock_data_block.go similarity index 100% rename from mock/mock_datablock.go rename to mock/mock_data_block.go diff --git a/proof.go b/proof.go new file mode 100644 index 0000000..7426904 --- /dev/null +++ b/proof.go @@ -0,0 +1,75 @@ +// MIT License +// +// Copyright (c) 2023 Tommy TIAN +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Package merkletree implements a high-performance Merkle Tree in Go. +// It supports parallel execution for enhanced performance and +// offers compatibility with OpenZeppelin through sorted sibling pairs. +package merkletree + +// Proof represents a Merkle Tree proof. +type Proof struct { + Siblings [][]byte // Sibling nodes to the Merkle Tree path of the data block. + Path uint32 // Path variable indicating whether the neighbor is on the left or right. +} + +// Proof generates the Merkle proof for a data block using the previously generated Merkle Tree structure. +// This method is only available when the configuration mode is ModeTreeBuild or ModeProofGenAndTreeBuild. +// In ModeProofGen, proofs for all the data blocks are already generated, and the Merkle Tree structure +// is not cached. +func (m *MerkleTree) Proof(dataBlock DataBlock) (*Proof, error) { + if m.Mode != ModeTreeBuild && m.Mode != ModeProofGenAndTreeBuild { + return nil, ErrProofInvalidModeTreeNotBuilt + } + + // Convert the data block to a leaf. + leaf, err := dataBlockToLeaf(dataBlock, m.HashFunc, m.DisableLeafHashing) + if err != nil { + return nil, err + } + + // Retrieve the index of the leaf in the Merkle Tree. + m.leafMapMu.Lock() + idx, ok := m.leafMap[string(leaf)] + m.leafMapMu.Unlock() + if !ok { + return nil, ErrProofInvalidDataBlock + } + + // Compute the path and siblings for the proof. + var ( + path uint32 + siblings = make([][]byte, m.Depth) + ) + for i := 0; i < m.Depth; i++ { + if idx&1 == 1 { + siblings[i] = m.nodes[i][idx-1] + } else { + path += 1 << i + siblings[i] = m.nodes[i][idx+1] + } + idx >>= 1 + } + return &Proof{ + Path: path, + Siblings: siblings, + }, nil +} diff --git a/proof_gen.go b/proof_gen.go new file mode 100644 index 0000000..df057aa --- /dev/null +++ b/proof_gen.go @@ -0,0 +1,189 @@ +// MIT License +// +// Copyright (c) 2023 Tommy TIAN +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Package merkletree implements a high-performance Merkle Tree in Go. +// It supports parallel execution for enhanced performance and +// offers compatibility with OpenZeppelin through sorted sibling pairs. +package merkletree + +import ( + "sync" + + "golang.org/x/sync/errgroup" +) + +// generateProofs constructs the Merkle Tree and generates the Merkle proofs for each leaf. +// It returns an error if there is an issue during the generation process. +func (m *MerkleTree) generateProofs() (err error) { + m.initProofs() + buffer, bufferSize := initBuffer(m.Leaves) + for step := 0; step < m.Depth; step++ { + bufferSize = fixOddNumOfNodes(buffer, bufferSize, step) + updateProofs(m.Proofs, buffer, bufferSize, step) + for idx := 0; idx < bufferSize; idx += 2 { + leftIdx := idx << step + rightIdx := min(leftIdx+(1<>= 1 + } + m.Root = buffer[0] + return +} + +// generateProofsParallel generates proofs concurrently for the MerkleTree. +func (m *MerkleTree) generateProofsParallel() (err error) { + m.initProofs() + buffer, bufferSize := initBuffer(m.Leaves) + numRoutines := m.NumRoutines + for step := 0; step < m.Depth; step++ { + // Limit the number of workers to the previous level length. + numRoutines = min(numRoutines, bufferSize) + bufferSize = fixOddNumOfNodes(buffer, bufferSize, step) + updateProofsParallel(m.Proofs, buffer, bufferSize, step, m.NumLeaves) + var ( + eg = new(errgroup.Group) + hashFunc = m.HashFunc + concatHashFunc = m.concatHashFunc + ) + for startIdx := 0; startIdx < numRoutines; startIdx++ { + startIdx := startIdx << 1 + eg.Go(func() error { + return workerProofGen( + hashFunc, concatHashFunc, + buffer, bufferSize, numRoutines, startIdx, step, + ) + }) + } + if err = eg.Wait(); err != nil { + return + } + bufferSize >>= 1 + } + m.Root = buffer[0] + return +} + +func workerProofGen( + hashFunc TypeHashFunc, concatHashFunc typeConcatHashFunc, + buffer [][]byte, bufferSize, numRoutine, startIdx, step int, +) error { + var err error + for i := startIdx; i < bufferSize; i += numRoutine << 1 { + leftIdx := i << step + rightIdx := min(leftIdx+(1<>1) + for j := 0; j < numNodes; j += 2 { + if m.nodes[i+1][j>>1], err = m.HashFunc( + m.concatHashFunc(m.nodes[i][j], m.nodes[i][j+1]), + ); err != nil { + return + } + } + } + if m.Root, err = m.HashFunc(m.concatHashFunc( + m.nodes[m.Depth-1][0], m.nodes[m.Depth-1][1], + )); err != nil { + return + } + <-finishMap + return +} + +// treeBuildParallel builds the Merkle Tree and stores all the nodes in parallel. +func (m *MerkleTree) treeBuildParallel() (err error) { + finishMap := make(chan struct{}) + go m.workerBuildLeafMap(finishMap) + m.initNodes() + for i := 0; i < m.Depth-1; i++ { + m.nodes[i] = appendNodeIfOdd(m.nodes[i]) + numNodes := len(m.nodes[i]) + m.nodes[i+1] = make([][]byte, numNodes>>1) + numRoutines := min(m.NumRoutines, numNodes) + eg := new(errgroup.Group) + for startIdx := 0; startIdx < numRoutines; startIdx++ { + startIdx := startIdx + eg.Go(func() error { + for j := startIdx << 1; j < numNodes; j += numRoutines << 1 { + newHash, err := m.HashFunc(m.concatHashFunc( + m.nodes[i][j], m.nodes[i][j+1], + )) + if err != nil { + return err + } + m.nodes[i+1][j>>1] = newHash + } + return nil + }) + } + if err = eg.Wait(); err != nil { + return + } + } + if m.Root, err = m.HashFunc(m.concatHashFunc( + m.nodes[m.Depth-1][0], m.nodes[m.Depth-1][1], + )); err != nil { + return + } + <-finishMap + return +} + +func (m *MerkleTree) workerBuildLeafMap(finishChan chan struct{}) { + m.leafMapMu.Lock() + defer m.leafMapMu.Unlock() + for i := 0; i < m.NumLeaves; i++ { + m.leafMap[string(m.Leaves[i])] = i + } + finishChan <- struct{}{} // empty channel to serve as a wait group for map generation +} + +func (m *MerkleTree) initNodes() { + m.nodes = make([][][]byte, m.Depth) + m.nodes[0] = make([][]byte, m.NumLeaves) + copy(m.nodes[0], m.Leaves) +} + +func appendNodeIfOdd(buffer [][]byte) [][]byte { + if len(buffer)&1 == 0 { + return buffer + } + appendNode := buffer[len(buffer)-1] + buffer = append(buffer, appendNode) + return buffer +} diff --git a/tree_build_test.go b/tree_build_test.go new file mode 100644 index 0000000..f076754 --- /dev/null +++ b/tree_build_test.go @@ -0,0 +1,320 @@ +// MIT License +// +// # Copyright (c) 2023 Tommy TIAN +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package merkletree + +import ( + "bytes" + "crypto/sha256" + "fmt" + "sync/atomic" + "testing" +) + +func TestMerkleTreeNew_modeTreeBuild(t *testing.T) { + var hashFuncCounter int + type args struct { + blocks []DataBlock + config *Config + } + tests := []struct { + name string + args args + checkingConfig *Config + wantErr bool + }{ + { + name: "test_build_tree_2", + args: args{ + blocks: mockDataBlocks(2), + config: &Config{ + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_3", + args: args{ + blocks: mockDataBlocks(3), + config: &Config{ + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_5", + args: args{ + blocks: mockDataBlocks(5), + config: &Config{ + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_8", + args: args{ + blocks: mockDataBlocks(8), + config: &Config{ + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_16", + args: args{ + blocks: mockDataBlocks(16), + config: &Config{ + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_32", + args: args{ + blocks: mockDataBlocks(32), + config: &Config{ + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_36", + args: args{ + blocks: mockDataBlocks(36), + config: &Config{ + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_1000", + args: args{ + blocks: mockDataBlocks(1000), + config: &Config{ + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_hash_func_error", + args: args{ + blocks: mockDataBlocks(100), + config: &Config{ + HashFunc: func([]byte) ([]byte, error) { + return nil, fmt.Errorf("hash func error") + }, + Mode: ModeTreeBuild, + }, + }, + wantErr: true, + }, + { + name: "test_hash_func_error_when_computing_root", + args: args{ + blocks: mockDataBlocks(4), + config: &Config{ + HashFunc: func(block []byte) ([]byte, error) { + if hashFuncCounter == 6 { + return nil, fmt.Errorf("hash func error") + } + hashFuncCounter++ + sha256Func := sha256.New() + sha256Func.Write(block) + return sha256Func.Sum(nil), nil + }, + Mode: ModeTreeBuild, + }, + }, + wantErr: true, + }, + { + name: "test_disable_leaf_hashing", + args: args{ + blocks: mockDataBlocks(100), + config: &Config{ + DisableLeafHashing: true, + Mode: ModeTreeBuild, + }, + }, + checkingConfig: &Config{ + DisableLeafHashing: true, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m, err := New(tt.args.config, tt.args.blocks) + if (err != nil) != tt.wantErr { + t.Errorf("Build() error = %v, wantErr %v", err, tt.wantErr) + return + } + m1, err := New(tt.checkingConfig, tt.args.blocks) + if err != nil { + t.Errorf("test setup error %v", err) + return + } + if !tt.wantErr && !bytes.Equal(m.Root, m1.Root) && !tt.wantErr { + fmt.Println("m", m.Root) + fmt.Println("m1", m1.Root) + t.Errorf("tree generated is wrong") + return + } + }) + } +} + +func TestMerkleTreeNew_modeTreeBuildParallel(t *testing.T) { + var hashFuncCounter atomic.Uint32 + type args struct { + blocks []DataBlock + config *Config + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "test_build_tree_parallel_2", + args: args{ + blocks: mockDataBlocks(2), + config: &Config{ + RunInParallel: true, + NumRoutines: 4, + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_parallel_4", + args: args{ + blocks: mockDataBlocks(4), + config: &Config{ + RunInParallel: true, + NumRoutines: 4, + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_parallel_5", + args: args{ + blocks: mockDataBlocks(5), + config: &Config{ + RunInParallel: true, + NumRoutines: 4, + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_parallel_8", + args: args{ + blocks: mockDataBlocks(8), + config: &Config{ + RunInParallel: true, + NumRoutines: 4, + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_build_tree_parallel_8_32", + args: args{ + blocks: mockDataBlocks(8), + config: &Config{ + RunInParallel: true, + NumRoutines: 32, + Mode: ModeTreeBuild, + }, + }, + wantErr: false, + }, + { + name: "test_hash_func_error_parallel", + args: args{ + blocks: mockDataBlocks(100), + config: &Config{ + HashFunc: func([]byte) ([]byte, error) { + return nil, fmt.Errorf("hash func error") + }, + RunInParallel: true, + Mode: ModeTreeBuild, + }, + }, + wantErr: true, + }, + { + name: "test_hash_func_error_when_computing_root_parallel", + args: args{ + blocks: mockDataBlocks(4), + config: &Config{ + HashFunc: func(block []byte) ([]byte, error) { + if hashFuncCounter.Load() == 6 { + return nil, fmt.Errorf("hash func error") + } + hashFuncCounter.Add(1) + sha256Func := sha256.New() + sha256Func.Write(block) + return sha256Func.Sum(nil), nil + }, + Mode: ModeTreeBuild, + RunInParallel: true, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m, err := New(tt.args.config, tt.args.blocks) + if (err != nil) != tt.wantErr { + t.Errorf("Build() error = %v, wantErr %v", err, tt.wantErr) + return + } + m1, err := New(nil, tt.args.blocks) + if err != nil { + t.Errorf("test setup error %v", err) + return + } + if !tt.wantErr && !bytes.Equal(m.Root, m1.Root) && !tt.wantErr { + fmt.Println("m", m.Root) + fmt.Println("m1", m1.Root) + t.Errorf("tree generated is wrong") + return + } + }) + } +} diff --git a/verify.go b/verify.go new file mode 100644 index 0000000..321a2d2 --- /dev/null +++ b/verify.go @@ -0,0 +1,82 @@ +// MIT License +// +// Copyright (c) 2023 Tommy TIAN +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Package merkletree implements a high-performance Merkle Tree in Go. +// It supports parallel execution for enhanced performance and +// offers compatibility with OpenZeppelin through sorted sibling pairs. +package merkletree + +import "bytes" + +// Verify checks if the data block is valid using the Merkle Tree proof and the cached Merkle root hash. +func (m *MerkleTree) Verify(dataBlock DataBlock, proof *Proof) (bool, error) { + return Verify(dataBlock, proof, m.Root, m.Config) +} + +// Verify checks if the data block is valid using the Merkle Tree proof and the provided Merkle root hash. +// It returns true if the data block is valid, false otherwise. An error is returned in case of any issues +// during the verification process. +func Verify(dataBlock DataBlock, proof *Proof, root []byte, config *Config) (bool, error) { + // Validate input parameters. + if dataBlock == nil { + return false, ErrDataBlockIsNil + } + if proof == nil { + return false, ErrProofIsNil + } + if config == nil { + config = new(Config) + } + if config.HashFunc == nil { + config.HashFunc = DefaultHashFunc + } + + // Determine the concatenation function based on the configuration. + concatFunc := concatHash + if config.SortSiblingPairs { + concatFunc = concatSortHash + } + + // Convert the data block to a leaf. + leaf, err := dataBlockToLeaf(dataBlock, config.HashFunc, config.DisableLeafHashing) + if err != nil { + return false, err + } + + // Traverse the Merkle proof and compute the resulting hash. + // Copy the slice so that the original leaf won't be modified. + result := make([]byte, len(leaf)) + copy(result, leaf) + path := proof.Path + for _, sib := range proof.Siblings { + if path&1 == 1 { + result, err = config.HashFunc(concatFunc(result, sib)) + } else { + result, err = config.HashFunc(concatFunc(sib, result)) + } + if err != nil { + return false, err + } + path >>= 1 + } + return bytes.Equal(result, root), nil +} diff --git a/verify_test.go b/verify_test.go new file mode 100644 index 0000000..fc45ea4 --- /dev/null +++ b/verify_test.go @@ -0,0 +1,339 @@ +// MIT License +// +// # Copyright (c) 2023 Tommy TIAN +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package merkletree + +import ( + "errors" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + + "github.com/txaty/go-merkletree/mock" +) + +func setupTestVerify(size int) (*MerkleTree, []DataBlock) { + blocks := mockDataBlocks(size) + m, err := New(nil, blocks) + if err != nil { + panic(err) + } + return m, blocks +} + +func setupTestVerifyParallel(size int) (*MerkleTree, []DataBlock) { + blocks := mockDataBlocks(size) + m, err := New(&Config{ + RunInParallel: true, + NumRoutines: 1, + }, blocks) + if err != nil { + panic(err) + } + return m, blocks +} + +func TestMerkleTreeVerify(t *testing.T) { + tests := []struct { + name string + setupFunc func(int) (*MerkleTree, []DataBlock) + blockSize int + want bool + wantErr bool + }{ + { + name: "test_2", + setupFunc: setupTestVerify, + blockSize: 2, + want: true, + wantErr: false, + }, + { + name: "test_3", + setupFunc: setupTestVerify, + blockSize: 3, + want: true, + wantErr: false, + }, + { + name: "test_4", + setupFunc: setupTestVerify, + blockSize: 4, + want: true, + wantErr: false, + }, + { + name: "test_5", + setupFunc: setupTestVerify, + blockSize: 5, + want: true, + wantErr: false, + }, + { + name: "test_6", + setupFunc: setupTestVerify, + blockSize: 6, + want: true, + wantErr: false, + }, + { + name: "test_8", + setupFunc: setupTestVerify, + blockSize: 8, + want: true, + wantErr: false, + }, + { + name: "test_9", + setupFunc: setupTestVerify, + blockSize: 9, + want: true, + wantErr: false, + }, + { + name: "test_1001", + setupFunc: setupTestVerify, + blockSize: 1001, + want: true, + wantErr: false, + }, + { + name: "test_2_parallel", + setupFunc: setupTestVerifyParallel, + blockSize: 2, + want: true, + wantErr: false, + }, + { + name: "test_4_parallel", + setupFunc: setupTestVerifyParallel, + blockSize: 4, + want: true, + wantErr: false, + }, + { + name: "test_64_parallel", + setupFunc: setupTestVerifyParallel, + blockSize: 64, + want: true, + wantErr: false, + }, + { + name: "test_1001_parallel", + setupFunc: setupTestVerifyParallel, + blockSize: 1001, + want: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m, blocks := tt.setupFunc(tt.blockSize) + for i := 0; i < tt.blockSize; i++ { + got, err := m.Verify(blocks[i], m.Proofs[i]) + if (err != nil) != tt.wantErr { + t.Errorf("Verify() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Verify() got = %v, want %v", got, tt.want) + } + } + }) + } +} + +func TestVerify(t *testing.T) { + blocks := mockDataBlocks(5) + m, err := New(nil, blocks) + if err != nil { + t.Errorf("New() error = %v", err) + return + } + patches := gomonkey.NewPatches() + defer patches.Reset() + type args struct { + dataBlock DataBlock + proof *Proof + root []byte + config *Config + } + tests := []struct { + name string + args args + mock func() + want bool + wantErr bool + }{ + { + name: "test_ok", + args: args{ + dataBlock: blocks[0], + proof: m.Proofs[0], + root: m.Root, + config: &Config{ + HashFunc: m.HashFunc, + }, + }, + want: true, + }, + { + name: "test_config_nil", + args: args{ + dataBlock: blocks[0], + proof: m.Proofs[0], + root: m.Root, + }, + want: true, + }, + { + name: "test_wrong_root", + args: args{ + dataBlock: blocks[0], + proof: m.Proofs[0], + root: []byte("test_wrong_root"), + config: &Config{ + HashFunc: m.HashFunc, + }, + }, + want: false, + }, + { + name: "test_wrong_hash_func", + args: args{ + dataBlock: blocks[0], + proof: m.Proofs[0], + root: m.Root, + config: &Config{ + HashFunc: func([]byte) ([]byte, error) { return []byte("test_wrong_hash_hash"), nil }, + }, + }, + want: false, + }, + { + name: "test_proof_nil", + args: args{ + dataBlock: blocks[0], + proof: nil, + root: m.Root, + config: &Config{ + HashFunc: m.HashFunc, + }, + }, + want: false, + wantErr: true, + }, + { + name: "test_data_block_nil", + args: args{ + dataBlock: nil, + proof: m.Proofs[0], + root: m.Root, + config: &Config{ + HashFunc: m.HashFunc, + }, + }, + want: false, + wantErr: true, + }, + { + name: "test_hash_func_nil", + args: args{ + dataBlock: blocks[0], + proof: m.Proofs[0], + root: m.Root, + config: &Config{ + HashFunc: nil, + }, + }, + want: true, + wantErr: false, + }, + { + name: "test_hash_func_err", + args: args{ + dataBlock: blocks[0], + proof: m.Proofs[0], + root: m.Root, + config: &Config{ + HashFunc: func([]byte) ([]byte, error) { + return nil, errors.New("test_hash_func_err") + }, + }, + }, + want: false, + wantErr: true, + }, + { + name: "test_disable_leaf_hashing_and_hash_func_err", + args: args{ + dataBlock: blocks[0], + proof: m.Proofs[0], + root: m.Root, + config: &Config{ + DisableLeafHashing: true, + HashFunc: func([]byte) ([]byte, error) { + return nil, errors.New("test_hash_func_err") + }, + }, + }, + want: false, + wantErr: true, + }, + { + name: "data_block_serialize_err", + args: args{ + dataBlock: blocks[0], + proof: m.Proofs[0], + root: m.Root, + config: &Config{ + HashFunc: m.HashFunc, + }, + }, + mock: func() { + patches.ApplyMethod(reflect.TypeOf(&mock.DataBlock{}), "Serialize", + func(m *mock.DataBlock) ([]byte, error) { + return nil, errors.New("test_data_block_serialize_err") + }) + }, + want: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.mock != nil { + tt.mock() + } + defer patches.Reset() + got, err := Verify(tt.args.dataBlock, tt.args.proof, tt.args.root, tt.args.config) + if (err != nil) != tt.wantErr { + t.Errorf("Verify() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Verify() = %v, want %v", got, tt.want) + } + }) + } +}