diff --git a/merkle_tree.go b/merkle_tree.go index 4794fd9..88ed79c 100644 --- a/merkle_tree.go +++ b/merkle_tree.go @@ -78,6 +78,8 @@ type Config struct { // SortSiblingPairs is the parameter for OpenZeppelin compatibility. // If set to `true`, the hashing sibling pairs are sorted. SortSiblingPairs bool + // If true, the leaf nodes are NOT hashed before being added to the Merkle Tree. + DoNotHashLeaves bool } // MerkleTree implements the Merkle Tree structure @@ -479,6 +481,10 @@ func (m *MerkleTree) leafGen(blocks []DataBlock) ([][]byte, error) { if err != nil { return nil, err } + if m.DoNotHashLeaves { + leaves[i] = data + continue + } var hash []byte if hash, err = m.HashFunc(data); err != nil { return nil, err @@ -712,8 +718,13 @@ func (m *MerkleTree) GenerateProof(dataBlock DataBlock) (*Proof, error) { return nil, err } var blockHash []byte - if blockHash, err = m.HashFunc(blockByte); err != nil { - return nil, err + + if m.DoNotHashLeaves { + blockHash = blockByte + } else { + if blockHash, err = m.HashFunc(blockByte); err != nil { + return nil, err + } } val, ok := m.leafMap.Load(string(blockHash)) if !ok { diff --git a/merkle_tree_test.go b/merkle_tree_test.go index 05ead5b..62e9e80 100644 --- a/merkle_tree_test.go +++ b/merkle_tree_test.go @@ -210,6 +210,89 @@ func TestMerkleTreeNew_proofGen(t *testing.T) { } } +func Test_doNotHashLeaves(t *testing.T) { + // Generate some random blocks + blocks := genTestDataBlocks(100) + + // Manually hash the blocks + hashedBlocks := make([]DataBlock, 0) + for _, block := range blocks { + ser, _ := block.Serialize() + hashedBlock, _ := testHashFunc(ser) + hashedBlocks = append(hashedBlocks, &mockDataBlock{ + data: hashedBlock, + }) + } + + // Create a tree that does not hash the leaves, using the already + // hashed blocks + mtNoHash, err := New(&Config{ + DoNotHashLeaves: true, + Mode: ModeProofGenAndTreeBuild, + }, hashedBlocks) + if err != nil { + t.Errorf("error creating tree: %v", err) + return + } + + // Create a tree that hashes the leaves, but provide unhashed blocks + mtHash, err := New(&Config{ + Mode: ModeProofGenAndTreeBuild, + }, blocks) + + // Assert that both trees are identical + + if err != nil { + t.Errorf("error creating tree: %v", err) + return + } + + if !bytes.Equal(mtNoHash.Root, mtHash.Root) { + fmt.Println("root1", mtNoHash.Root) + fmt.Println("root2", mtHash.Root) + t.Errorf("merkle root mismatch") + return + } + + if !reflect.DeepEqual(mtNoHash.Leaves, mtHash.Leaves) { + fmt.Println("leaves1", mtNoHash.Leaves) + fmt.Println("leaves2", mtHash.Leaves) + t.Errorf("leaves mismatch") + return + } + if !reflect.DeepEqual(mtNoHash.Proofs, mtHash.Proofs) { + fmt.Println("proof1", mtNoHash.Proofs) + fmt.Println("proof2", mtHash.Proofs) + t.Errorf("proofs mismatch") + return + } + if mtNoHash.Depth != mtHash.Depth { + fmt.Println("depth1", mtNoHash.Depth) + fmt.Println("depth2", mtHash.Depth) + t.Errorf("merkle tree depth mismatch") + return + } + + if len(mtNoHash.Proofs) != len(mtHash.Proofs) { + fmt.Println("len proofs 1", len(mtNoHash.Proofs)) + fmt.Println("len proofs 2", len(mtHash.Proofs)) + t.Errorf("proofs length mismatch") + return + } + + for i := 0; i < len(blocks); i++ { + proofHash, _ := mtHash.GenerateProof(blocks[i]) + proofNoHash, _ := mtNoHash.GenerateProof(hashedBlocks[i]) + + if !reflect.DeepEqual(proofHash, proofNoHash) { + fmt.Println("proof1", proofHash) + fmt.Println("proof2", proofNoHash) + t.Errorf("proof value mismatch") + return + } + } +} + func TestMerkleTreeNew_buildTree(t *testing.T) { type args struct { blocks []DataBlock