-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathvalidation.go
129 lines (118 loc) · 3.7 KB
/
validation.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package merkle
import (
"bytes"
"errors"
"fmt"
"golang.org/x/exp/slices"
)
// MaxUint is the largest value representable by uint (all bits set). The validation entry points in this file pass it
// to Validator.CalcRoot as stopAtLayer; presumably no tree position ever reaches that height, so the calculation only
// stops once the proof is exhausted (verify against the position type).
const MaxUint = ^uint(0)
// ValidatePartialTree uses leafIndices, leaves and proof to calculate the merkle root of the tree and then compares it
// to expectedRoot.
//
// leafIndices must be sorted, free of duplicates and have the same length as leaves, and at least one leaf is
// required; otherwise an error is returned. hash is the function used to combine two child nodes into their parent.
//
// The boolean result is true only when the calculated root equals expectedRoot. Whenever a non-nil error is returned
// the boolean result is false.
func ValidatePartialTree(
	leafIndices []uint64,
	leaves, proof [][]byte,
	expectedRoot []byte,
	hash HashFunc,
) (bool, error) {
	v, err := newValidator(leafIndices, leaves, proof, hash, false)
	if err != nil {
		return false, err
	}
	root, _, err := v.CalcRoot(MaxUint)
	if err != nil {
		// Never report a match alongside an error: a nil root would otherwise compare equal to an empty
		// expectedRoot.
		return false, err
	}
	return bytes.Equal(root, expectedRoot), nil
}
// ValidatePartialTreeWithParkingSnapshots uses leafIndices, leaves and proof to calculate the merkle root of the tree
// and then compares it to expectedRoot. Additionally, it reconstructs the parked nodes when each proven leaf was
// originally added to the tree and returns a list of snapshots. This method is ~15% slower than ValidatePartialTree.
//
// leafIndices must be sorted, free of duplicates and have the same length as leaves, and at least one leaf is
// required; otherwise an error is returned. hash is the function used to combine two child nodes into their parent.
//
// The boolean result is true only when the calculated root equals expectedRoot. Whenever a non-nil error is returned
// the boolean result is false and the snapshot list is nil.
func ValidatePartialTreeWithParkingSnapshots(
	leafIndices []uint64,
	leaves, proof [][]byte,
	expectedRoot []byte,
	hash HashFunc,
) (bool, []ParkingSnapshot, error) {
	v, err := newValidator(leafIndices, leaves, proof, hash, true)
	if err != nil {
		return false, nil, err
	}
	root, parkingSnapshots, err := v.CalcRoot(MaxUint)
	if err != nil {
		// Never report a match alongside an error: a nil root would otherwise compare equal to an empty
		// expectedRoot.
		return false, nil, err
	}
	return bytes.Equal(root, expectedRoot), parkingSnapshots, nil
}
// newValidator checks the validation inputs and wraps them in a Validator.
//
// It returns an error when the number of leaves differs from the number of indices, when no leaf is supplied, or when
// leafIndices is not sorted or contains duplicates. leafIndices is only read, never modified. storeSnapshots is copied
// to Validator.StoreSnapshots.
func newValidator(
	leafIndices []uint64,
	leaves, proof [][]byte,
	hash HashFunc,
	storeSnapshots bool,
) (*Validator, error) {
	if len(leafIndices) != len(leaves) {
		return nil, fmt.Errorf("number of leaves (%d) must equal number of indices (%d)", len(leaves), len(leafIndices))
	}
	if len(leaves) == 0 {
		return nil, errors.New("at least one leaf is required for validation")
	}
	if !slices.IsSorted(leafIndices) {
		return nil, errors.New("leafIndices are not sorted")
	}
	// The indices are sorted at this point, so duplicates can only be adjacent. A read-only scan is used instead of
	// slices.Compact, which would rewrite the caller's slice in place.
	for i := 1; i < len(leafIndices); i++ {
		if leafIndices[i] == leafIndices[i-1] {
			return nil, errors.New("leafIndices contain duplicates")
		}
	}
	proofNodes := &proofIterator{proof}
	leafIt := &LeafIterator{leafIndices, leaves}
	return &Validator{Leaves: leafIt, ProofNodes: proofNodes, Hash: hash, StoreSnapshots: storeSnapshots}, nil
}
// Validator bundles the inputs that CalcRoot consumes while recomputing a merkle root from a set of proven leaves and
// proof nodes.
type Validator struct {
// Leaves iterates over the proven leaves; newValidator builds it from the leaf indices and leaf values.
Leaves *LeafIterator
// ProofNodes iterates over the proof nodes, which supply the siblings that are not derived from proven leaves.
ProofNodes *proofIterator
// Hash combines a left and a right child into their parent node.
Hash HashFunc
// StoreSnapshots makes CalcRoot collect parking snapshots in addition to computing the root.
StoreSnapshots bool
}
// ParkingSnapshot is a list of nodes built up by CalcRoot, which appends one entry per processed layer: either a node
// or nil. Presumably index i holds the node parked at layer i at the time a proven leaf was added (confirm against the
// tree implementation).
type ParkingSnapshot [][]byte
// CalcRoot reads the next leaf and folds it upwards by repeatedly hashing the active node with its sibling and moving
// to the parent position. The sibling is either computed recursively from following proven leaves or taken from the
// proof. The loop ends when the active position reaches height stopAtLayer or when the proof iterator reports
// noMoreItems; the node computed so far is returned.
//
// When v.StoreSnapshots is set, the parking snapshots collected along the way are returned as well; otherwise the
// returned snapshot slice is nil. An error is returned when no leaf can be read or when a recursive sub-tree
// calculation fails.
func (v *Validator) CalcRoot(stopAtLayer uint) ([]byte, []ParkingSnapshot, error) {
activePos, activeNode, err := v.Leaves.next()
if err != nil {
return nil, nil, err
}
var lChild, rChild, sibling []byte
var parkingSnapshots, subTreeSnapshots []ParkingSnapshot
if v.StoreSnapshots {
// Start with a single empty snapshot (presumably the one belonging to the leaf just read). With snapshots
// disabled the slice stays nil and every addToAll call below is a no-op.
parkingSnapshots = []ParkingSnapshot{nil}
}
for {
if activePos.Height == stopAtLayer {
break
}
// The activeNode's sibling should be calculated if it's an ancestor of the next proven leaf. Otherwise, the
// sibling is the next node in the proof.
// Note: err is declared with := here, so it is scoped to this loop iteration and shadows the outer err.
nextLeafPos, _, err := v.Leaves.peek()
if err == nil && activePos.sibling().isAncestorOf(nextLeafPos) {
// The recursive call consumes leaves (and possibly proof nodes) from the shared iterators and stops once
// it reaches the active node's height, yielding the sibling at that height.
sibling, subTreeSnapshots, err = v.CalcRoot(activePos.Height)
if err != nil {
return nil, nil, err
}
} else {
sibling, err = v.ProofNodes.next()
// Proof exhausted: the active node is the final result of this call.
// NOTE(review): any other error from next() is not checked here — confirm proofIterator can only return
// noMoreItems.
if err == noMoreItems {
break
}
}
if activePos.isRightSibling() {
// The active node is the right child, so the sibling is the left child; it is appended to every snapshot
// collected so far.
lChild, rChild = sibling, activeNode
addToAll(parkingSnapshots, lChild)
} else {
// The active node is the left child: existing snapshots get a nil entry for this layer.
lChild, rChild = activeNode, sibling
addToAll(parkingSnapshots, nil)
if len(subTreeSnapshots) > 0 {
// Snapshots from the recursively computed right sub-tree get the active (left) node appended and are
// merged into this call's list; they are cleared so they are not merged twice.
parkingSnapshots = append(parkingSnapshots, addToAll(subTreeSnapshots, activeNode)...)
subTreeSnapshots = nil
}
}
activeNode = v.Hash(nil, lChild, rChild)
activePos = activePos.parent()
}
return activeNode, parkingSnapshots, nil
}
// addToAll appends node (which may be nil) to every snapshot in snapshots. The elements of snapshots are updated in
// place and the same slice is returned so the call can be used inline.
func addToAll(snapshots []ParkingSnapshot, node []byte) []ParkingSnapshot {
	for idx, snap := range snapshots {
		snapshots[idx] = append(snap, node)
	}
	return snapshots
}