diff --git a/cache/cache.go b/cache/cache.go index 82cbd16..73e3b0a 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -3,16 +3,31 @@ package cache import ( "errors" "fmt" - "github.com/spacemeshos/merkle-tree/cache/readwriters" - "math" + "github.com/spacemeshos/merkle-tree/shared" ) -const NodeSize = 32 +const NodeSize = shared.NodeSize + +type ( + HashFunc = shared.HashFunc + LayerWriter = shared.LayerWriter + LayerReader = shared.LayerReader + LayerReadWriter = shared.LayerReadWriter + CacheWriter = shared.CacheWriter + CacheReader = shared.CacheReader + LayerFactory = shared.LayerFactory + CachingPolicy = shared.CachingPolicy +) + +var RootHeightFromWidth = shared.RootHeightFromWidth type Writer struct { *cache } +// A compile time check to ensure that Writer fully implements CacheWriter. +var _ CacheWriter = (*Writer)(nil) + func NewWriter(shouldCacheLayer CachingPolicy, generateLayer LayerFactory) *Writer { return &Writer{ cache: &cache{ @@ -40,14 +55,14 @@ func (c *Writer) GetLayerWriter(layerHeight uint) (LayerWriter, error) { return layerReadWriter, nil } -func (c *Writer) SetHash(hashFunc func(lChild, rChild []byte) []byte) { +func (c *Writer) SetHash(hashFunc HashFunc) { c.hash = hashFunc } // GetReader returns a cache reader that can be passed into GenerateProof. It first flushes the layer writers to support // layer writers that have internal buffers that may not be reflected in the reader until flushed. After flushing, this // method validates the structure of the cache, including that a base layer is cached. -func (c *Writer) GetReader() (*Reader, error) { +func (c *Writer) GetReader() (CacheReader, error) { if err := c.flush(); err != nil { return nil, err } @@ -69,17 +84,32 @@ type Reader struct { *cache } +// A compile time check to ensure that Reader fully implements CacheReader. +var _ CacheReader = (*Reader)(nil) + +func (c *Reader) Layers() map[uint]LayerReadWriter { + return c.layers +} + func (c *Reader) GetLayerReader(layerHeight uint) LayerReader { return c.layers[layerHeight] } -func (c *Reader) GetHashFunc() func(lChild, rChild []byte) []byte { +func (c *Reader) GetHashFunc() HashFunc { return c.hash } +func (c *Reader) GetLayerFactory() LayerFactory { + return c.generateLayer +} + +func (c *Reader) GetCachingPolicy() CachingPolicy { + return c.shouldCacheLayer +} + type cache struct { layers map[uint]LayerReadWriter - hash func(lChild, rChild []byte) []byte + hash HashFunc shouldCacheLayer CachingPolicy generateLayer LayerFactory } @@ -113,35 +143,6 @@ func (c *cache) validateStructure() error { return nil } -type CachingPolicy func(layerHeight uint) (shouldCacheLayer bool) - -type LayerFactory func(layerHeight uint) (LayerReadWriter, error) - -// LayerReadWriter is a combined reader-writer. Note that the Seek() method only belongs to the LayerReader interface -// and does not affect the LayerWriter. -type LayerReadWriter interface { - LayerReader - LayerWriter -} - -var _ LayerReadWriter = &readwriters.FileReadWriter{} -var _ LayerReadWriter = &readwriters.SliceReadWriter{} - -type LayerReader interface { - Seek(index uint64) error - ReadNext() ([]byte, error) - Width() (uint64, error) -} - -type LayerWriter interface { - Append(p []byte) (n int, err error) - Flush() error -} - -func RootHeightFromWidth(width uint64) uint { - return uint(math.Ceil(math.Log2(float64(width)))) -} - //func (c *cache) Print(bottom, top int) { // for i := top; i >= bottom; i-- { // print("| ") diff --git a/cache/cache_test.go b/cache/cache_test.go index 25071f4..5626557 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -10,7 +10,8 @@ var someError = errors.New("some error") type widthReader struct{ width uint64 } -var _ LayerReadWriter = &widthReader{} +// A compile time check to ensure that widthReader fully implements LayerReadWriter. +var _ LayerReadWriter = (*widthReader)(nil) func (r widthReader) Seek(index uint64) error { return nil } func (r widthReader) ReadNext() ([]byte, error) { return nil, someError } diff --git a/cache/group.go b/cache/group.go new file mode 100644 index 0000000..8cf05a7 --- /dev/null +++ b/cache/group.go @@ -0,0 +1,103 @@ +package cache + +import ( + "errors" + "io" +) + +type GroupLayerReadWriter struct { + chunks []LayerReadWriter + activeChunkIndex int + widthPerChunk uint64 + lastChunkWidth uint64 +} + +// A compile time check to ensure that GroupLayerReadWriter fully implements LayerReadWriter. +var _ LayerReadWriter = (*GroupLayerReadWriter)(nil) + +// groupLayers groups a slice of layers into one unified layer. +func groupLayers(layers []LayerReadWriter) (*GroupLayerReadWriter, error) { + if len(layers) < 2 { + return nil, errors.New("number of layers must be at least 2") + } + + widthPerLayer, err := layers[0].Width() + if err != nil { + return nil, err + } + if widthPerLayer == 0 { + return nil, errors.New("0 width layers are not allowed") + } + + // Verify that all layers, except the last one, have the same width. + var lastLayerWidth uint64 + for i := 1; i < len(layers); i++ { + layer := layers[i] + if layer == nil { + return nil, errors.New("nil layers are not allowed") + } + width, err := layers[i].Width() + if err != nil { + return nil, err + } + + if i == len(layers)-1 { + lastLayerWidth = width + } else { + if width != widthPerLayer && i < len(layers)-1 { + return nil, errors.New("layers width mismatch") + } + } + } + + g := &GroupLayerReadWriter{ + chunks: layers, + widthPerChunk: widthPerLayer, + lastChunkWidth: lastLayerWidth, + } + + return g, nil +} + +func (g *GroupLayerReadWriter) Seek(index uint64) error { + // Find the target chunk. + chunkIndex := int(index / g.widthPerChunk) + if chunkIndex >= len(g.chunks) { + return io.EOF + } + + // Reset the previous active chunk position + // and set the new active chunk. + if g.activeChunkIndex != chunkIndex { + err := g.chunks[g.activeChunkIndex].Seek(0) + if err != nil { + return err + } + + g.activeChunkIndex = chunkIndex + } + + indexInChunk := index % g.widthPerChunk + return g.chunks[g.activeChunkIndex].Seek(indexInChunk) +} + +func (g *GroupLayerReadWriter) ReadNext() ([]byte, error) { + val, err := g.chunks[g.activeChunkIndex].ReadNext() + if err != nil { + if err == io.EOF && g.activeChunkIndex < len(g.chunks)-1 { + g.activeChunkIndex++ + return g.ReadNext() + } + return nil, err + } + + return val, nil +} + +func (g *GroupLayerReadWriter) Width() (uint64, error) { + return uint64(len(g.chunks)-1)*g.widthPerChunk + g.lastChunkWidth, nil +} + +func (g *GroupLayerReadWriter) Append(p []byte) (n int, err error) { panic("not implemented") } + +func (g *GroupLayerReadWriter) Flush() error { panic("not implemented") } diff --git a/cache/group_test.go b/cache/group_test.go new file mode 100644 index 0000000..306ed5a --- /dev/null +++ b/cache/group_test.go @@ -0,0 +1,157 @@ +package cache + +import ( + "github.com/spacemeshos/merkle-tree/cache/readwriters" + "github.com/stretchr/testify/require" + "io" + "testing" +) + +func TestGroupLayers(t *testing.T) { + r := require.New(t) + + // Create 9 nodes. + nodes := genNodes(9) + + // Split the nodes into 3 separate layers. + layers := make([]LayerReadWriter, 3) + layers[0] = &readwriters.SliceReadWriter{} + _, _ = layers[0].Append(nodes[0]) + _, _ = layers[0].Append(nodes[1]) + _, _ = layers[0].Append(nodes[2]) + layers[1] = &readwriters.SliceReadWriter{} + _, _ = layers[1].Append(nodes[3]) + _, _ = layers[1].Append(nodes[4]) + _, _ = layers[1].Append(nodes[5]) + layers[2] = &readwriters.SliceReadWriter{} + _, _ = layers[2].Append(nodes[6]) + _, _ = layers[2].Append(nodes[7]) + _, _ = layers[2].Append(nodes[8]) + + // Group the layers. + layer, err := groupLayers(layers) + r.NoError(err) + + width, err := layer.Width() + r.NoError(err) + r.Equal(width, uint64(len(nodes))) + + // Iterate over the layer. + for _, node := range nodes { + val, err := layer.ReadNext() + r.NoError(err) + r.Equal(val, node) + } + + // Iterate over the layer with Seek. + for i, node := range nodes { + err := layer.Seek(uint64(i)) + r.NoError(err) + val, err := layer.ReadNext() + r.NoError(err) + r.Equal(val, node) + } + _, err = layer.ReadNext() + r.Equal(err, io.EOF) + + // Iterate over the layer with Seek in reverse. + for i := len(nodes) - 1; i >= 0; i-- { + err := layer.Seek(uint64(i)) + r.NoError(err) + val, err := layer.ReadNext() + r.NoError(err) + r.Equal(val, nodes[i]) + } + + // Verify that deactivated chunk position is being reset. + // (target chunk 1 position 1) + err = layer.Seek(uint64(3)) + r.NoError(err) + val, err := layer.ReadNext() + r.NoError(err) + r.Equal(val, nodes[3]) + // (target chunk 0 position 2) + err = layer.Seek(uint64(2)) + r.NoError(err) + val, err = layer.ReadNext() + r.NoError(err) + r.Equal(val, nodes[2]) + // (target chunk 1 position 0) + val, err = layer.ReadNext() + r.NoError(err) + r.Equal(val, nodes[3]) +} + +func TestGroupLayersWithShorterLastLayer(t *testing.T) { + r := require.New(t) + + // Create 7 nodes. + nodes := genNodes(7) + + // Split the nodes into 3 separate layers in groups of [3,3,1]. + layers := make([]LayerReadWriter, 3) + layers[0] = &readwriters.SliceReadWriter{} + _, _ = layers[0].Append(nodes[0]) + _, _ = layers[0].Append(nodes[1]) + _, _ = layers[0].Append(nodes[2]) + layers[1] = &readwriters.SliceReadWriter{} + _, _ = layers[1].Append(nodes[3]) + _, _ = layers[1].Append(nodes[4]) + _, _ = layers[1].Append(nodes[5]) + layers[2] = &readwriters.SliceReadWriter{} + _, _ = layers[2].Append(nodes[6]) + + // Group the layers. + layer, err := groupLayers(layers) + r.NoError(err) + + width, err := layer.Width() + r.NoError(err) + r.Equal(width, uint64(len(nodes))) + + // Iterate over the layer. + for _, node := range nodes { + val, err := layer.ReadNext() + r.NoError(err) + r.Equal(val, node) + } + + // Arrive to EOF with ReadNext. + err = layer.Seek(uint64(6)) + r.NoError(err) + val, err := layer.ReadNext() + r.NoError(err) + r.Equal(val, nodes[6]) + val, err = layer.ReadNext() + r.Equal(io.EOF, err) + + // Arrive to EOF with Seek. + err = layer.Seek(uint64(7)) + r.Equal(io.EOF, err) + err = layer.Seek(uint64(666)) + r.Equal(io.EOF, err) +} + +func TestGroupLayersWithShorterMidLayer(t *testing.T) { + r := require.New(t) + + // Create 7 nodes. + nodes := genNodes(7) + + // Split the nodes into 3 separate layers in groups of [3,1,3]. + layers := make([]LayerReadWriter, 3) + layers[0] = &readwriters.SliceReadWriter{} + _, _ = layers[0].Append(nodes[0]) + _, _ = layers[0].Append(nodes[1]) + _, _ = layers[0].Append(nodes[2]) + layers[1] = &readwriters.SliceReadWriter{} + _, _ = layers[1].Append(nodes[3]) + layers[2] = &readwriters.SliceReadWriter{} + _, _ = layers[2].Append(nodes[4]) + _, _ = layers[2].Append(nodes[5]) + _, _ = layers[2].Append(nodes[6]) + + // Group the layers. + _, err := groupLayers(layers) + r.Equal("layers width mismatch", err.Error()) +} diff --git a/cache/merge.go b/cache/merge.go new file mode 100644 index 0000000..6cb1264 --- /dev/null +++ b/cache/merge.go @@ -0,0 +1,120 @@ +package cache + +import ( + "errors" + "github.com/spacemeshos/merkle-tree" + "io" +) + +// Merge merges a slice of caches into one unified cache. +// Layers of all caches per each height are appended and grouped, while +// the hash function, caching policy and layer factory are taken +// from the first cache of the slice. +func Merge(caches []CacheReader) (*Reader, error) { + if len(caches) < 2 { + return nil, errors.New("number of caches must be at least 2") + } + + // Aggregate caches' layers by height. + layerGroups := make(map[uint][]LayerReadWriter) + for _, cache := range caches { + for height, layer := range cache.Layers() { + layerGroups[height] = append(layerGroups[height], layer) + } + } + + // Group layer groups. + layers := make(map[uint]LayerReadWriter) + for height, layerGroup := range layerGroups { + if len(layerGroup) != len(caches) { + return nil, errors.New("number of layers per height mismatch") + } + + group, err := groupLayers(layerGroup) + if err != nil { + return nil, err + } + layers[height] = group + } + + hashFunc := caches[0].GetHashFunc() + layerFactory := caches[0].GetLayerFactory() + cachingPolicy := caches[0].GetCachingPolicy() + + cache := &cache{ + layers: layers, + hash: hashFunc, + shouldCacheLayer: cachingPolicy, + generateLayer: layerFactory, + } + return &Reader{cache}, nil +} + +// BuildTop builds the top layers of a cache, and returns +// its new version in addition to its root. +func BuildTop(cacheReader CacheReader) (*Reader, []byte, error) { + // Find the cache highest layer. + var maxHeight uint + for height := range cacheReader.Layers() { + if height > maxHeight { + maxHeight = height + } + } + + // Create a subtree with adjusted CachingPolicy and LayerFactory. + // Use the cache highest layer as its leaves. + subtreeWriter := NewWriter( + func(layerHeight uint) bool { + return cacheReader.GetCachingPolicy()(maxHeight + layerHeight) + }, + func(layerHeight uint) (LayerReadWriter, error) { + return cacheReader.GetLayerFactory()(maxHeight + layerHeight) + }) + subtree, err := merkle.NewTreeBuilder(). + WithHashFunc(cacheReader.GetHashFunc()). + WithCacheWriter(subtreeWriter). + Build() + if err != nil { + return nil, nil, err + } + + layer := cacheReader.GetLayerReader(maxHeight) + for { + val, err := layer.ReadNext() + if err != nil { + if err == io.EOF { + break + } else { + return nil, nil, err + } + } + + err = subtree.AddLeaf(val) + if err != nil { + return nil, nil, err + } + } + + // Clone the existing cache. + newCache := &cache{ + layers: cacheReader.Layers(), + hash: cacheReader.GetHashFunc(), + shouldCacheLayer: cacheReader.GetCachingPolicy(), + generateLayer: cacheReader.GetLayerFactory(), + } + + // Add the subtree cache layers on top of the existing ones. + for height, layer := range subtreeWriter.layers { + if height == 0 { + continue + } + newCache.layers[height+maxHeight] = layer + } + + err = newCache.validateStructure() + if err != nil { + return nil, nil, err + } + + return &Reader{cache: newCache}, subtree.Root(), nil +} diff --git a/cache/merge_test.go b/cache/merge_test.go new file mode 100644 index 0000000..41c2162 --- /dev/null +++ b/cache/merge_test.go @@ -0,0 +1,289 @@ +package cache + +import ( + "encoding/binary" + "github.com/spacemeshos/merkle-tree" + "github.com/spacemeshos/merkle-tree/cache/readwriters" + "github.com/spacemeshos/sha256-simd" + "github.com/stretchr/testify/require" + "testing" +) + +func TestMerge(t *testing.T) { + r := require.New(t) + + readers := make([]*Reader, 3) + readers[0] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}} + readers[1] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}} + readers[2] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}} + + // Create 9 nodes. + nodes := genNodes(9) + + // Split the nodes into 3 layers. + splitLayer := make([]LayerReadWriter, 3) + splitLayer[0] = &readwriters.SliceReadWriter{} + splitLayer[1] = &readwriters.SliceReadWriter{} + splitLayer[2] = &readwriters.SliceReadWriter{} + _, _ = splitLayer[0].Append(nodes[0]) + _, _ = splitLayer[0].Append(nodes[1]) + _, _ = splitLayer[0].Append(nodes[2]) + _, _ = splitLayer[1].Append(nodes[3]) + _, _ = splitLayer[1].Append(nodes[4]) + _, _ = splitLayer[1].Append(nodes[5]) + _, _ = splitLayer[2].Append(nodes[6]) + _, _ = splitLayer[2].Append(nodes[7]) + _, _ = splitLayer[2].Append(nodes[8]) + + // Assign the split layer into 3 different readers on height 0. + readers[0].cache.layers[0] = splitLayer[0] + readers[1].cache.layers[0] = splitLayer[1] + readers[2].cache.layers[0] = splitLayer[2] + + var caches []CacheReader + for _, reader := range readers { + caches = append(caches, CacheReader(reader)) + } + cache, err := Merge(caches) + r.NoError(err) + + // Verify the split layers group. + layer := cache.GetLayerReader(0) + width, err := layer.Width() + r.NoError(err) + r.Equal(width, uint64(len(nodes))) + + // Iterate over the layer. + for _, node := range nodes { + val, err := layer.ReadNext() + r.NoError(err) + r.Equal(val, node) + } +} + +func TestMergeFailure1(t *testing.T) { + r := require.New(t) + + readers := make([]*Reader, 1) + readers[0] = &Reader{&cache{layers: make(map[uint]LayerReadWriter)}} + + var caches []CacheReader + for _, reader := range readers { + caches = append(caches, CacheReader(reader)) + } + _, err := Merge(caches) + r.Equal("number of caches must be at least 2", err.Error()) +} + +func TestMergeFailure2(t *testing.T) { + r := require.New(t) + + readers := make([]*Reader, 2) + readers[0] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}} + readers[1] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}} + + readers[0].cache.layers[0] = &readwriters.SliceReadWriter{} + + var caches []CacheReader + for _, reader := range readers { + caches = append(caches, CacheReader(reader)) + } + _, err := Merge(caches) + r.Equal("number of layers per height mismatch", err.Error()) +} + +func TestMergeAndBuildTopCache(t *testing.T) { + r := require.New(t) + + // Create 4 trees. + cacheReaders := make([]CacheReader, 4) + for i := 0; i < 4; i++ { + cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory()) + tree, err := merkle.NewCachingTree(cacheWriter) + r.NoError(err) + for i := uint64(0); i < 8; i++ { + err := tree.AddLeaf(NewNodeFromUint64(i)) + r.NoError(err) + } + + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) + + assertWidth(r, 8, cacheReader.GetLayerReader(0)) + assertWidth(r, 4, cacheReader.GetLayerReader(1)) + assertWidth(r, 2, cacheReader.GetLayerReader(2)) + assertWidth(r, 1, cacheReader.GetLayerReader(3)) + cacheRoot, err := cacheReader.GetLayerReader(3).ReadNext() + r.NoError(err) + r.Equal(cacheRoot, tree.Root()) + err = cacheReader.GetLayerReader(3).Seek(0) // Reset position. + r.NoError(err) + + cacheReaders[i] = cacheReader + } + + // Merge caches, and verify that the upper subtree is missing. + cacheReader, err := Merge(cacheReaders) + r.NoError(err) + assertWidth(r, 32, cacheReader.GetLayerReader(0)) + assertWidth(r, 16, cacheReader.GetLayerReader(1)) + assertWidth(r, 8, cacheReader.GetLayerReader(2)) + assertWidth(r, 4, cacheReader.GetLayerReader(3)) + r.Nil(cacheReader.GetLayerReader(4)) + r.Nil(cacheReader.GetLayerReader(5)) + + // Create the upper subtree. + cacheReader, root, err := BuildTop(cacheReader) + r.NoError(err) + assertWidth(r, 32, cacheReader.GetLayerReader(0)) + assertWidth(r, 16, cacheReader.GetLayerReader(1)) + assertWidth(r, 8, cacheReader.GetLayerReader(2)) + assertWidth(r, 4, cacheReader.GetLayerReader(3)) + assertWidth(r, 2, cacheReader.GetLayerReader(4)) + assertWidth(r, 1, cacheReader.GetLayerReader(5)) + + // Compare the cache root with the root received from BuildTop. + cacheRoot, err := cacheReader.GetLayerReader(5).ReadNext() + r.NoError(err) + r.Equal(cacheRoot, root) + err = cacheReader.GetLayerReader(5).Seek(0) // Reset position. + r.NoError(err) +} + +func TestMergeAndBuildTop(t *testing.T) { + r := require.New(t) + + // Create 32 nodes. + nodes := genNodes(32) + + // Create a custom (non-default) hash function. + hashFunc := func(lChild, rChild []byte) []byte { + bytes := append(lChild, rChild...) + res := sha256.Sum256(append([]byte("0"), bytes...)) + return res[:] + } + + // Add the nodes as leaves to a baseline tree. + cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory()) + baselineTree, err := merkle.NewTreeBuilder(). + WithHashFunc(hashFunc). + WithCacheWriter(cacheWriter). + Build() + r.NoError(err) + for i := 0; i < len(nodes); i++ { + err := baselineTree.AddLeaf(NewNodeFromUint64(uint64(i))) + r.NoError(err) + } + + // Add the nodes as leaves to 4 separate trees. + cacheWriters := make([]CacheWriter, 4) + cacheReaders := make([]CacheReader, 4) + trees := make([]*merkle.Tree, 4) + for i := 0; i < 4; i++ { + cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory()) + tree, err := merkle.NewTreeBuilder(). + WithHashFunc(hashFunc). + WithCacheWriter(cacheWriter). + Build() + r.NoError(err) + + cacheWriters[i] = cacheWriter + trees[i] = tree + } + for i := 0; i < len(nodes); i++ { + err := trees[i/8].AddLeaf(NewNodeFromUint64(uint64(i))) + r.NoError(err) + } + for i := 0; i < 4; i++ { + reader, err := cacheWriters[i].GetReader() + r.NoError(err) + cacheReaders[i] = reader + } + + // Merge caches. + cacheReader, err := Merge(cacheReaders) + r.NoError(err) + r.NotNil(cacheReader) + + // Create the upper subtree. + cacheReader, mergeRoot, err := BuildTop(cacheReader) + r.NoError(err) + r.NotNil(cacheReader) + + // Verify that the 4 trees merge root is the same as the baseline tree root. + r.Equal(mergeRoot, baselineTree.Root()) +} + +// TODO(moshababo): support unbalanced merge to make this test pass. +//func TestMergeAndBuildTopUnbalanced(t *testing.T) { +// r := require.New(t) +// +// // Create 29 nodes. +// nodes := genNodes(29) +// +// // Add the nodes as leaves to one tree and save its root. +// cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory()) +// tree, err := merkle.NewCachingTree(cacheWriter) +// r.NoError(err) +// for i := 0; i < len(nodes); i++ { +// err := tree.AddLeaf(NewNodeFromUint64(uint64(i))) +// r.NoError(err) +// } +// treeRoot := tree.Root() +// +// // Add the nodes as leaves to 4 separate trees. +// cacheWriters := make([]CacheWriter, 4) +// cacheReaders := make([]CacheReader, 4) +// trees := make([]*merkle.Tree, 4) +// for i := 0; i < 4; i++ { +// cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory()) +// tree, err := merkle.NewCachingTree(cacheWriter) +// r.NoError(err) +// +// cacheWriters[i] = cacheWriter +// trees[i] = tree +// } +// for i := 0; i < len(nodes); i++ { +// err := trees[i/8].AddLeaf(NewNodeFromUint64(uint64(i))) +// r.NoError(err) +// } +// for i := 0; i < 4; i++ { +// reader, err := cacheWriters[i].GetReader() +// r.NoError(err) +// cacheReaders[i] = reader +// } +// +// // Merge caches. +// cacheReader, err := Merge(cacheReaders) +// r.NoError(err) +// r.NotNil(cacheReader) +// +// // Create the upper subtree. +// cacheReader, mergeRoot, err := BuildTop(cacheReader) +// r.NoError(err) +// r.NotNil(cacheReader) +// +// // Verify that the 4 trees merge root is the same as the main tree root. +// r.Equal(mergeRoot, treeRoot) +//} + +func genNodes(num int) [][]byte { + nodes := make([][]byte, num) + for i := 0; i < num; i++ { + nodes[i] = NewNodeFromUint64(uint64(i)) + } + return nodes +} + +func NewNodeFromUint64(i uint64) []byte { + b := make([]byte, NodeSize) + binary.LittleEndian.PutUint64(b, i) + return b +} + +func assertWidth(r *require.Assertions, expectedWidth int, layerReader LayerReader) { + r.NotNil(layerReader) + width, err := layerReader.Width() + r.NoError(err) + r.Equal(uint64(expectedWidth), width) +} diff --git a/cache/readwriters/file.go b/cache/readwriters/file.go index 3f3093c..684130d 100644 --- a/cache/readwriters/file.go +++ b/cache/readwriters/file.go @@ -3,6 +3,7 @@ package readwriters import ( "bufio" "fmt" + "github.com/spacemeshos/merkle-tree/shared" "io" "os" ) @@ -25,6 +26,9 @@ type FileReadWriter struct { b *bufio.ReadWriter } +// A compile time check to ensure that FileReadWriter fully implements LayerReadWriter. +var _ shared.LayerReadWriter = (*FileReadWriter)(nil) + func (rw *FileReadWriter) Seek(index uint64) error { _, err := rw.f.Seek(int64(index*NodeSize), io.SeekStart) if err != nil { diff --git a/cache/readwriters/slice.go b/cache/readwriters/slice.go index bbdf623..95d1bef 100644 --- a/cache/readwriters/slice.go +++ b/cache/readwriters/slice.go @@ -1,16 +1,20 @@ package readwriters import ( + "github.com/spacemeshos/merkle-tree/shared" "io" ) -const NodeSize = 32 +const NodeSize = shared.NodeSize type SliceReadWriter struct { slice [][]byte position uint64 } +// A compile time check to ensure that SliceReadWriter fully implements LayerReadWriter. +var _ shared.LayerReadWriter = (*SliceReadWriter)(nil) + func (s *SliceReadWriter) Width() (uint64, error) { return uint64(len(s.slice)), nil } diff --git a/iterators.go b/iterators.go index 7ceccb1..5ca0d1d 100644 --- a/iterators.go +++ b/iterators.go @@ -7,9 +7,9 @@ import ( var noMoreItems = errors.New("no more items") -type set map[uint64]bool +type Set map[uint64]bool -func (s set) asSortedSlice() []uint64 { +func (s Set) AsSortedSlice() []uint64 { var ret []uint64 for key, value := range s { if value { @@ -20,8 +20,8 @@ func (s set) asSortedSlice() []uint64 { return ret } -func setOf(members ...uint64) set { - ret := make(set) +func SetOf(members ...uint64) Set { + ret := make(Set) for _, member := range members { ret[member] = true } @@ -32,22 +32,22 @@ type positionsIterator struct { s []uint64 } -func newPositionsIterator(positions set) *positionsIterator { - s := positions.asSortedSlice() +func NewPositionsIterator(positions Set) *positionsIterator { + s := positions.AsSortedSlice() return &positionsIterator{s: s} } -func (it *positionsIterator) peek() (pos position, found bool) { +func (it *positionsIterator) peek() (pos Position, found bool) { if len(it.s) == 0 { - return position{}, false + return Position{}, false } index := it.s[0] - return position{index: index}, true + return Position{Index: index}, true } // batchPop returns the indices of all positions up to endIndex. -func (it *positionsIterator) batchPop(endIndex uint64) set { - res := make(set) +func (it *positionsIterator) batchPop(endIndex uint64) Set { + res := make(Set) for len(it.s) > 0 && it.s[0] < endIndex { res[it.s[0]] = true it.s = it.s[1:] @@ -68,27 +68,27 @@ func (it *proofIterator) next() ([]byte, error) { return n, nil } -type leafIterator struct { +type LeafIterator struct { indices []uint64 leaves [][]byte } -// leafIterator.next() returns the leaf index and value -func (it *leafIterator) next() (position, []byte, error) { +// LeafIterator.next() returns the leaf index and value +func (it *LeafIterator) next() (Position, []byte, error) { if len(it.indices) == 0 { - return position{}, nil, noMoreItems + return Position{}, nil, noMoreItems } idx := it.indices[0] leaf := it.leaves[0] it.indices = it.indices[1:] it.leaves = it.leaves[1:] - return position{index: idx}, leaf, nil + return Position{Index: idx}, leaf, nil } -// leafIterator.peek() returns the leaf index but doesn't move the iterator to this leaf as next would do -func (it *leafIterator) peek() (position, []byte, error) { +// LeafIterator.peek() returns the leaf index but doesn't move the iterator to this leaf as next would do +func (it *LeafIterator) peek() (Position, []byte, error) { if len(it.indices) == 0 { - return position{}, nil, noMoreItems + return Position{}, nil, noMoreItems } - return position{index: it.indices[0]}, it.leaves[0], nil + return Position{Index: it.indices[0]}, it.leaves[0], nil } diff --git a/merkle.go b/merkle.go index b49b90a..5573e69 100644 --- a/merkle.go +++ b/merkle.go @@ -2,27 +2,36 @@ package merkle import ( "errors" - "github.com/spacemeshos/merkle-tree/cache" + "github.com/spacemeshos/merkle-tree/shared" "github.com/spacemeshos/sha256-simd" ) -const NodeSize = cache.NodeSize +const NodeSize = shared.NodeSize -type HashFunc func(lChild, rChild []byte) []byte +type ( + HashFunc = shared.HashFunc + LayerWriter = shared.LayerWriter + LayerReader = shared.LayerReader + LayerReadWriter = shared.LayerReadWriter + CacheWriter = shared.CacheWriter + CacheReader = shared.CacheReader +) + +var RootHeightFromWidth = shared.RootHeightFromWidth -var emptyNode node +var EmptyNode node // PaddingValue is used for padding unbalanced trees. This value should not be permitted at the leaf layer to // distinguish padding from actual members of the tree. var PaddingValue = node{ value: make([]byte, NodeSize), // Zero filled. - onProvenPath: false, + OnProvenPath: false, } // node is a node in the merkle tree. type node struct { value []byte - onProvenPath bool // Whether this node is an ancestor of a leaf whose membership in the tree is being proven. + OnProvenPath bool // Whether this node is an ancestor of a leaf whose membership in the tree is being proven. } func (n node) IsEmpty() bool { @@ -34,11 +43,11 @@ type layer struct { height uint parking node // This is where we park a node until its sibling is processed and we can calculate their parent. next *layer - cache cache.LayerWriter + cache LayerWriter } // ensureNextLayerExists creates the next layer if it doesn't exist. -func (l *layer) ensureNextLayerExists(cacheWriter *cache.Writer) error { +func (l *layer) ensureNextLayerExists(cacheWriter shared.CacheWriter) error { if l.next == nil { writer, err := cacheWriter.GetLayerWriter(l.height + 1) if err != nil { @@ -49,7 +58,7 @@ func (l *layer) ensureNextLayerExists(cacheWriter *cache.Writer) error { return nil } -func newLayer(height uint, cache cache.LayerWriter) *layer { +func newLayer(height uint, cache LayerWriter) *layer { return &layer{height: height, cache: cache} } @@ -58,8 +67,8 @@ type sparseBoolStack struct { currentIndex uint64 } -func newSparseBoolStack(trueIndices set) *sparseBoolStack { - sorted := trueIndices.asSortedSlice() +func NewSparseBoolStack(trueIndices Set) *sparseBoolStack { + sorted := trueIndices.AsSortedSlice() return &sparseBoolStack{sortedTrueIndices: sorted} } @@ -85,7 +94,7 @@ type Tree struct { hash HashFunc proof [][]byte leavesToProve *sparseBoolStack - cacheWriter *cache.Writer + cacheWriter CacheWriter minHeight uint } @@ -94,7 +103,7 @@ type Tree struct { func (t *Tree) AddLeaf(value []byte) error { n := node{ value: value, - onProvenPath: t.leavesToProve.Pop(), + OnProvenPath: t.leavesToProve.Pop(), } l := t.baseLayer var parent, lChild, rChild node @@ -122,11 +131,11 @@ func (t *Tree) AddLeaf(value []byte) error { // A given node is required in the proof if and only if its parent is an ancestor // of a leaf whose membership in the tree is being proven, but the given node isn't. - if parent.onProvenPath { - if !lChild.onProvenPath { + if parent.OnProvenPath { + if !lChild.OnProvenPath { t.proof = append(t.proof, lChild.value) } - if !rChild.onProvenPath { + if !rChild.OnProvenPath { t.proof = append(t.proof, rChild.value) } } @@ -187,11 +196,11 @@ func (t *Tree) RootAndProof() ([]byte, [][]byte) { // Consider adding children to the ephemeralProof. `onProvenPath` must be explicitly set -- an empty node has // the default value `false` and would never pass this point. - if parent.onProvenPath { - if !lChild.onProvenPath { + if parent.OnProvenPath { + if !lChild.OnProvenPath { ephemeralProof = append(ephemeralProof, lChild.value) } - if !rChild.onProvenPath { + if !rChild.OnProvenPath { ephemeralProof = append(ephemeralProof, rChild.value) } } @@ -231,7 +240,7 @@ func (t *Tree) calcEphemeralParent(parking, ephemeralNode node) (parent, lChild, lChild, rChild = ephemeralNode, PaddingValue default: // both are empty - return emptyNode, emptyNode, emptyNode + return EmptyNode, EmptyNode, EmptyNode } return t.calcParent(lChild, rChild), lChild, rChild } @@ -240,7 +249,7 @@ func (t *Tree) calcEphemeralParent(parking, ephemeralNode node) (parent, lChild, func (t *Tree) calcParent(lChild, rChild node) node { return node{ value: t.hash(lChild.value, rChild.value), - onProvenPath: lChild.onProvenPath || rChild.onProvenPath, + OnProvenPath: lChild.OnProvenPath || rChild.OnProvenPath, } } diff --git a/merkle_test.go b/merkle_test.go index 2b7e792..5acc522 100644 --- a/merkle_test.go +++ b/merkle_test.go @@ -1,15 +1,40 @@ -package merkle +package merkle_test import ( "encoding/binary" "encoding/hex" "fmt" + "github.com/spacemeshos/merkle-tree" "github.com/spacemeshos/merkle-tree/cache" "github.com/stretchr/testify/require" "testing" "time" ) +var ( + NewTree = merkle.NewTree + NewTreeBuilder = merkle.NewTreeBuilder + NewProvingTree = merkle.NewProvingTree + NewCachingTree = merkle.NewCachingTree + GenerateProof = merkle.GenerateProof + ValidatePartialTree = merkle.ValidatePartialTree + ValidatePartialTreeWithParkingSnapshots = merkle.ValidatePartialTreeWithParkingSnapshots + GetSha256Parent = merkle.GetSha256Parent + GetNode = merkle.GetNode + setOf = merkle.SetOf + newSparseBoolStack = merkle.NewSparseBoolStack + emptyNode = merkle.EmptyNode + NodeSize = merkle.NodeSize +) + +type ( + set = merkle.Set + position = merkle.Position + validator = merkle.Validator + leafIterator = merkle.LeafIterator + CacheReader = cache.CacheReader +) + /* 8-leaf tree (1st 2 bytes of each node): @@ -171,6 +196,7 @@ func TestNewTreeUnbalancedProof(t *testing.T) { } func assertWidth(r *require.Assertions, expectedWidth int, layerReader cache.LayerReader) { + r.NotNil(layerReader) width, err := layerReader.Width() r.NoError(err) r.Equal(uint64(expectedWidth), width) @@ -398,7 +424,7 @@ func TestEmptyNode(t *testing.T) { r := require.New(t) r.True(emptyNode.IsEmpty()) - r.False(emptyNode.onProvenPath) + r.False(emptyNode.OnProvenPath) } func TestTree_GetParkedNodes(t *testing.T) { @@ -477,8 +503,8 @@ func ExampleTree() { // We now have access to a sorted list of proven leaves, the values of those leaves and the Merkle proof for them: fmt.Println(sortedProvenLeafIndices) // 0 4 7 - fmt.Println(nodes(provenLeaves)) // 0000 0400 0700 - fmt.Println(nodes(proof)) // 0100 0094 0500 0600 + fmt.Println(nodes(provenLeaves)) // 0000 0400 0700 + fmt.Println(nodes(proof)) // 0100 0094 0500 0600 // We can validate these values using ValidatePartialTree: valid, err := ValidatePartialTree(sortedProvenLeafIndices, provenLeaves, proof, tree.Root(), GetSha256Parent) @@ -494,4 +520,4 @@ func ExampleTree() { | cb59 .0094. bd50 fa67 | | =0000=.0100. 0200 0300 =0400=.0500..0600.=0700= | ***************************************************/ -} \ No newline at end of file +} diff --git a/position.go b/position.go index b79c589..3e4b9dd 100644 --- a/position.go +++ b/position.go @@ -2,57 +2,57 @@ package merkle import "fmt" -type position struct { - index uint64 - height uint +type Position struct { + Index uint64 + Height uint } -func (p position) String() string { - return fmt.Sprintf("", p.height, p.index) +func (p Position) String() string { + return fmt.Sprintf("", p.Height, p.Index) } -func (p position) sibling() position { - return position{ - index: p.index ^ 1, - height: p.height, +func (p Position) sibling() Position { + return Position{ + Index: p.Index ^ 1, + Height: p.Height, } } -func (p position) isAncestorOf(other position) bool { - if p.height < other.height { +func (p Position) isAncestorOf(other Position) bool { + if p.Height < other.Height { return false } - return p.index == (other.index >> (p.height - other.height)) + return p.Index == (other.Index >> (p.Height - other.Height)) } -func (p position) isRightSibling() bool { - return p.index%2 == 1 +func (p Position) isRightSibling() bool { + return p.Index%2 == 1 } -func (p position) parent() position { - return position{ - index: p.index >> 1, - height: p.height + 1, +func (p Position) parent() Position { + return Position{ + Index: p.Index >> 1, + Height: p.Height + 1, } } -func (p position) leftChild() position { - return position{ - index: p.index << 1, - height: p.height - 1, +func (p Position) leftChild() Position { + return Position{ + Index: p.Index << 1, + Height: p.Height - 1, } } type positionsStack struct { - positions []position + positions []Position } -func (s *positionsStack) Push(v position) { +func (s *positionsStack) Push(v Position) { s.positions = append(s.positions, v) } // Check the top of the stack for equality and pop the element if it's equal. -func (s *positionsStack) PopIfEqual(p position) bool { +func (s *positionsStack) PopIfEqual(p Position) bool { l := len(s.positions) if l == 0 { return false diff --git a/position_test.go b/position_test.go index ded60b4..ac4e40d 100644 --- a/position_test.go +++ b/position_test.go @@ -6,14 +6,14 @@ import ( ) func TestPosition_isAncestorOf(t *testing.T) { - lower := position{ - index: 0, - height: 0, + lower := Position{ + Index: 0, + Height: 0, } - higher := position{ - index: 0, - height: 1, + higher := Position{ + Index: 0, + Height: 1, } isAncestor := lower.isAncestorOf(higher) diff --git a/proving.go b/proving.go index 3cb689b..3606f67 100644 --- a/proving.go +++ b/proving.go @@ -3,7 +3,6 @@ package merkle import ( "errors" "fmt" - "github.com/spacemeshos/merkle-tree/cache" "io" ) @@ -11,16 +10,16 @@ var ErrMissingValueAtBaseLayer = errors.New("reader for base layer must be inclu func GenerateProof( provenLeafIndices map[uint64]bool, - treeCache *cache.Reader, + treeCache CacheReader, ) (sortedProvenLeafIndices []uint64, provenLeaves, proofNodes [][]byte, err error) { - provenLeafIndexIt := newPositionsIterator(provenLeafIndices) + provenLeafIndexIt := NewPositionsIterator(provenLeafIndices) skipPositions := &positionsStack{} width, err := treeCache.GetLayerReader(0).Width() if err != nil { return nil, nil, nil, err } - rootHeight := cache.RootHeightFromWidth(width) + rootHeight := RootHeightFromWidth(width) for { // Process proven leaves: @@ -38,7 +37,7 @@ func GenerateProof( } // Prepare list of leaves to prove in the subtree. - leavesToProve := provenLeafIndexIt.batchPop(subtreeStart.index + width) + leavesToProve := provenLeafIndexIt.batchPop(subtreeStart.Index + width) additionalProof, additionalLeaves, err := calcSubtreeProof(treeCache, leavesToProve, subtreeStart, width) if err != nil { @@ -47,7 +46,7 @@ func GenerateProof( proofNodes = append(proofNodes, additionalProof...) provenLeaves = append(provenLeaves, additionalLeaves...) - for ; currentPos.height < rootHeight; currentPos = currentPos.parent() { // Traverse treeCache: + for ; currentPos.Height < rootHeight; currentPos = currentPos.parent() { // Traverse treeCache: // Check if we're revisiting a node. If we've descended into a subtree and just got back, we shouldn't add // the sibling to the proof and instead move on to the parent. @@ -72,21 +71,21 @@ func GenerateProof( } } - return set(provenLeafIndices).asSortedSlice(), provenLeaves, proofNodes, nil + return Set(provenLeafIndices).AsSortedSlice(), provenLeaves, proofNodes, nil } -func calcSubtreeProof(c *cache.Reader, leavesToProve set, subtreeStart position, width uint64) ( +func calcSubtreeProof(c CacheReader, leavesToProve Set, subtreeStart Position, width uint64) ( additionalProof, additionalLeaves [][]byte, err error) { // By subtracting subtreeStart.index we get the index relative to the subtree. - relativeLeavesToProve := make(set) + relativeLeavesToProve := make(Set) for leafIndex, prove := range leavesToProve { - relativeLeavesToProve[leafIndex-subtreeStart.index] = prove + relativeLeavesToProve[leafIndex-subtreeStart.Index] = prove } // Prepare leaf reader to read subtree leaves. reader := c.GetLayerReader(0) - err = reader.Seek(subtreeStart.index) + err = reader.Seek(subtreeStart.Index) if err != nil { return nil, nil, errors.New("while preparing to traverse subtree: " + err.Error()) } @@ -99,14 +98,14 @@ func calcSubtreeProof(c *cache.Reader, leavesToProve set, subtreeStart position, return additionalProof, additionalLeaves, err } -func traverseSubtree(leafReader cache.LayerReader, width uint64, hash HashFunc, leavesToProve set, +func traverseSubtree(leafReader LayerReader, width uint64, hash HashFunc, leavesToProve Set, externalPadding []byte) (root []byte, proof, provenLeaves [][]byte, err error) { shouldUseExternalPadding := externalPadding != nil t, err := NewTreeBuilder(). WithHashFunc(hash). WithLeavesToProve(leavesToProve). - WithMinHeight(cache.RootHeightFromWidth(width)). // This ensures the correct size tree, even if padding is needed. + WithMinHeight(RootHeightFromWidth(width)). // This ensures the correct size tree, even if padding is needed. Build() if err != nil { return nil, nil, nil, errors.New("while building a tree: " + err.Error()) @@ -135,22 +134,22 @@ func traverseSubtree(leafReader cache.LayerReader, width uint64, hash HashFunc, return root, proof, provenLeaves, nil } -// GetNode reads the node at the requested position from the cache or calculates it if not available. -func GetNode(c *cache.Reader, nodePos position) ([]byte, error) { +// GetNode reads the node at the requested Position from the cache or calculates it if not available. +func GetNode(c CacheReader, nodePos Position) ([]byte, error) { // Get the cache reader for the requested node's layer. - reader := c.GetLayerReader(nodePos.height) + reader := c.GetLayerReader(nodePos.Height) // If the cache wasn't found, we calculate the minimal subtree that will get us the required node. if reader == nil { return calcNode(c, nodePos) } - err := reader.Seek(nodePos.index) + err := reader.Seek(nodePos.Index) if err == io.EOF { return calcNode(c, nodePos) } if err != nil { - return nil, errors.New("while seeking to position " + nodePos.String() + " in cache: " + err.Error()) + return nil, errors.New("while seeking to Position " + nodePos.String() + " in cache: " + err.Error()) } currentVal, err := reader.ReadNext() if err != nil { @@ -159,45 +158,45 @@ func GetNode(c *cache.Reader, nodePos position) ([]byte, error) { return currentVal, nil } -func calcNode(c *cache.Reader, nodePos position) ([]byte, error) { - var subtreeStart position - var reader cache.LayerReader +func calcNode(c CacheReader, nodePos Position) ([]byte, error) { + var subtreeStart Position + var reader LayerReader - if nodePos.height == 0 { + if nodePos.Height == 0 { return nil, ErrMissingValueAtBaseLayer } // Find the next cached layer below the current one. for subtreeStart = nodePos; reader == nil; { subtreeStart = subtreeStart.leftChild() - reader = c.GetLayerReader(subtreeStart.height) + reader = c.GetLayerReader(subtreeStart.Height) } // Prepare the reader for traversing the subtree. - err := reader.Seek(subtreeStart.index) + err := reader.Seek(subtreeStart.Index) if err == io.EOF { return PaddingValue.value, nil } if err != nil { - return nil, errors.New("while seeking to position " + subtreeStart.String() + " in cache: " + err.Error()) + return nil, errors.New("while seeking to Position " + subtreeStart.String() + " in cache: " + err.Error()) } var paddingValue []byte - width := uint64(1) << (nodePos.height - subtreeStart.height) + width := uint64(1) << (nodePos.Height - subtreeStart.Height) readerWidth, err := reader.Width() if err != nil { return nil, fmt.Errorf("while getting reader width: %v", err) } - if readerWidth < subtreeStart.index+width { - paddingPos := position{ - index: readerWidth, - height: subtreeStart.height, + if readerWidth < subtreeStart.Index+width { + paddingPos := Position{ + Index: readerWidth, + Height: subtreeStart.Height, } paddingValue, err = calcNode(c, paddingPos) if err == ErrMissingValueAtBaseLayer { paddingValue = PaddingValue.value } else if err != nil { - return nil, errors.New("while calculating ephemeral node at position " + paddingPos.String() + ": " + err.Error()) + return nil, errors.New("while calculating ephemeral node at Position " + paddingPos.String() + ": " + err.Error()) } } @@ -212,23 +211,23 @@ func calcNode(c *cache.Reader, nodePos position) ([]byte, error) { // subtreeDefinition returns the definition (firstLeaf and root positions, width) for the minimal subtree whose // base layer includes p and where the root is on a cached layer. If no cached layer exists above the base layer, the // subtree will reach the root of the original tree. -func subtreeDefinition(c *cache.Reader, p position) (root, firstLeaf position, width uint64, err error) { +func subtreeDefinition(c CacheReader, p Position) (root, firstLeaf Position, width uint64, err error) { // maxRootHeight represents the max height of the tree, based on the width of base layer. This is used to prevent an // infinite loop. - width, err = c.GetLayerReader(p.height).Width() + width, err = c.GetLayerReader(p.Height).Width() if err != nil { - return position{}, position{}, 0, err + return Position{}, Position{}, 0, err } - maxRootHeight := cache.RootHeightFromWidth(width) - for root = p.parent(); root.height < maxRootHeight; root = root.parent() { - if layer := c.GetLayerReader(root.height); layer != nil { + maxRootHeight := RootHeightFromWidth(width) + for root = p.parent(); root.Height < maxRootHeight; root = root.parent() { + if layer := c.GetLayerReader(root.Height); layer != nil { break } } - subtreeHeight := root.height - p.height - firstLeaf = position{ - index: root.index << subtreeHeight, - height: p.height, + subtreeHeight := root.Height - p.Height + firstLeaf = Position{ + Index: root.Index << subtreeHeight, + Height: p.Height, } return root, firstLeaf, 1 << subtreeHeight, err } diff --git a/proving_test.go b/proving_test.go index 5bb3e1e..eed1b9c 100644 --- a/proving_test.go +++ b/proving_test.go @@ -1,4 +1,4 @@ -package merkle +package merkle_test import ( "encoding/hex" @@ -57,7 +57,7 @@ func TestGenerateProof(t *testing.T) { r.EqualValues(expectedProof, proof) var expectedLeaves nodes - for _, i := range leavesToProve.asSortedSlice() { + for _, i := range leavesToProve.AsSortedSlice() { expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) } r.EqualValues(expectedLeaves, leaves) @@ -105,7 +105,7 @@ func BenchmarkGenerateProof(b *testing.B) { r.EqualValues(expectedProof, proof) var expectedLeaves nodes - for _, i := range leavesToProve.asSortedSlice() { + for _, i := range leavesToProve.AsSortedSlice() { expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) } r.EqualValues(expectedLeaves, leaves) @@ -153,7 +153,7 @@ func TestGenerateProofWithRoot(t *testing.T) { r.EqualValues(expectedProof, proof) var expectedLeaves nodes - for _, i := range leavesToProve.asSortedSlice() { + for _, i := range leavesToProve.AsSortedSlice() { expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) } r.EqualValues(expectedLeaves, leaves) @@ -189,7 +189,7 @@ func TestGenerateProofWithoutCache(t *testing.T) { r.EqualValues(expectedProof, proof) var expectedLeaves nodes - for _, i := range leavesToProve.asSortedSlice() { + for _, i := range leavesToProve.AsSortedSlice() { expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) } r.EqualValues(expectedLeaves, leaves) @@ -228,7 +228,7 @@ func TestGenerateProofWithSingleLayerCache(t *testing.T) { r.EqualValues(expectedProof, proof) var expectedLeaves nodes - for _, i := range leavesToProve.asSortedSlice() { + for _, i := range leavesToProve.AsSortedSlice() { expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) } r.EqualValues(expectedLeaves, leaves) @@ -267,7 +267,7 @@ func TestGenerateProofWithSingleLayerCache2(t *testing.T) { r.EqualValues(expectedProof, proof) var expectedLeaves nodes - for _, i := range leavesToProve.asSortedSlice() { + for _, i := range leavesToProve.AsSortedSlice() { expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) } r.EqualValues(expectedLeaves, leaves) @@ -337,7 +337,7 @@ func TestGenerateProofUnbalanced(t *testing.T) { r.EqualValues(expectedProof, proof) var expectedLeaves nodes - for _, i := range leavesToProve.asSortedSlice() { + for _, i := range leavesToProve.AsSortedSlice() { expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) } r.EqualValues(expectedLeaves, leaves) @@ -375,7 +375,7 @@ func TestGenerateProofUnbalanced2(t *testing.T) { r.EqualValues(expectedProof, proof) var expectedLeaves nodes - for _, i := range leavesToProve.asSortedSlice() { + for _, i := range leavesToProve.AsSortedSlice() { expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) } r.EqualValues(expectedLeaves, leaves) @@ -413,7 +413,7 @@ func TestGenerateProofUnbalanced3(t *testing.T) { r.EqualValues(expectedProof, proof) var expectedLeaves nodes - for _, i := range leavesToProve.asSortedSlice() { + for _, i := range leavesToProve.AsSortedSlice() { expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) } r.EqualValues(expectedLeaves, leaves) @@ -434,7 +434,8 @@ var someError = errors.New("some error") type seekErrorReader struct{} -var _ cache.LayerReadWriter = &seekErrorReader{} +// A compile time check to ensure that seekErrorReader fully implements LayerReadWriter. +var _ cache.LayerReadWriter = (*seekErrorReader)(nil) func (seekErrorReader) Seek(index uint64) error { return someError } func (seekErrorReader) ReadNext() ([]byte, error) { panic("implement me") } @@ -444,7 +445,8 @@ func (seekErrorReader) Flush() error { return nil } type readErrorReader struct{} -var _ cache.LayerReadWriter = &readErrorReader{} +// A compile time check to ensure that readErrorReader fully implements LayerReadWriter. +var _ cache.LayerReadWriter = (*readErrorReader)(nil) func (readErrorReader) Seek(index uint64) error { return nil } func (readErrorReader) ReadNext() ([]byte, error) { return nil, someError } @@ -454,7 +456,8 @@ func (readErrorReader) Flush() error { return nil } type seekEOFReader struct{} -var _ cache.LayerReadWriter = &seekEOFReader{} +// A compile time check to ensure that seekEOFReader fully implements LayerReadWriter. +var _ cache.LayerReadWriter = (*seekEOFReader)(nil) func (seekEOFReader) Seek(index uint64) error { return io.EOF } func (seekEOFReader) ReadNext() ([]byte, error) { panic("implement me") } @@ -464,7 +467,8 @@ func (seekEOFReader) Flush() error { return nil } type widthReader struct{ width uint64 } -var _ cache.LayerReadWriter = &widthReader{} +// A compile time check to ensure that widthReader fully implements LayerReadWriter. +var _ cache.LayerReadWriter = (*widthReader)(nil) func (r widthReader) Seek(index uint64) error { return nil } func (r widthReader) ReadNext() ([]byte, error) { return nil, someError } @@ -485,7 +489,7 @@ func TestGetNode(t *testing.T) { node, err := GetNode(cacheReader, nodePos) r.Error(err) - r.Equal("while seeking to position in cache: some error", err.Error()) + r.Equal("while seeking to Position in cache: some error", err.Error()) r.Nil(node) } @@ -513,11 +517,11 @@ func TestGetNode3(t *testing.T) { cacheReader, err := cacheWriter.GetReader() r.NoError(err) - nodePos := position{height: 1} + nodePos := position{Height: 1} node, err := GetNode(cacheReader, nodePos) r.Error(err) - r.Equal("while seeking to position in cache: some error", err.Error()) + r.Equal("while seeking to Position in cache: some error", err.Error()) r.Nil(node) } @@ -530,11 +534,11 @@ func TestGetNode4(t *testing.T) { cacheReader, err := cacheWriter.GetReader() r.NoError(err) - nodePos := position{height: 2} + nodePos := position{Height: 2} node, err := GetNode(cacheReader, nodePos) r.Error(err) - r.Equal("while calculating ephemeral node at position : while seeking to position in cache: some error", err.Error()) + r.Equal("while calculating ephemeral node at Position : while seeking to Position in cache: some error", err.Error()) r.Nil(node) } @@ -546,7 +550,7 @@ func TestGetNode5(t *testing.T) { cacheReader, err := cacheWriter.GetReader() r.NoError(err) - nodePos := position{height: 1} + nodePos := position{Height: 1} node, err := GetNode(cacheReader, nodePos) r.Error(err) diff --git a/shared/consts.go b/shared/consts.go new file mode 100644 index 0000000..87a51a4 --- /dev/null +++ b/shared/consts.go @@ -0,0 +1,5 @@ +package shared + +const ( + NodeSize = 32 +) diff --git a/shared/types.go b/shared/types.go new file mode 100644 index 0000000..69759da --- /dev/null +++ b/shared/types.go @@ -0,0 +1,40 @@ +package shared + +type HashFunc func(lChild, rChild []byte) []byte + +// LayerReadWriter is a combined reader-writer. Note that the Seek() method only belongs to the LayerReader interface +// and does not affect the LayerWriter. +type LayerReadWriter interface { + LayerReader + LayerWriter +} + +type LayerReader interface { + Seek(index uint64) error + ReadNext() ([]byte, error) + Width() (uint64, error) +} + +type LayerWriter interface { + Append(p []byte) (n int, err error) + Flush() error +} + +type CacheWriter interface { + SetLayer(layerHeight uint, rw LayerReadWriter) + GetLayerWriter(layerHeight uint) (LayerWriter, error) + SetHash(hashFunc HashFunc) + GetReader() (CacheReader, error) +} + +type CacheReader interface { + Layers() map[uint]LayerReadWriter + GetLayerReader(layerHeight uint) LayerReader + GetHashFunc() HashFunc + GetLayerFactory() LayerFactory + GetCachingPolicy() CachingPolicy +} + +type CachingPolicy func(layerHeight uint) (shouldCacheLayer bool) + +type LayerFactory func(layerHeight uint) (LayerReadWriter, error) diff --git a/shared/utils.go b/shared/utils.go new file mode 100644 index 0000000..6912bb9 --- /dev/null +++ b/shared/utils.go @@ -0,0 +1,7 @@ +package shared + +import "math" + +func RootHeightFromWidth(width uint64) uint { + return uint(math.Ceil(math.Log2(float64(width)))) +} diff --git a/treebuilder.go b/treebuilder.go index e10501e..528b21d 100644 --- a/treebuilder.go +++ b/treebuilder.go @@ -1,11 +1,9 @@ package merkle -import "github.com/spacemeshos/merkle-tree/cache" - type TreeBuilder struct { hash HashFunc - leavesToProves set - cacheWriter *cache.Writer + leavesToProves Set + cacheWriter CacheWriter minHeight uint } @@ -18,7 +16,7 @@ func (tb TreeBuilder) Build() (*Tree, error) { tb.hash = GetSha256Parent } if tb.cacheWriter == nil { - tb.cacheWriter = cache.NewWriter(cache.SpecificLayersPolicy(map[uint]bool{}), nil) + tb.cacheWriter = disabledCacheWriter{} } tb.cacheWriter.SetHash(tb.hash) writer, err := tb.cacheWriter.GetLayerWriter(0) @@ -28,7 +26,7 @@ func (tb TreeBuilder) Build() (*Tree, error) { return &Tree{ baseLayer: newLayer(0, writer), hash: tb.hash, - leavesToProve: newSparseBoolStack(tb.leavesToProves), + leavesToProve: NewSparseBoolStack(tb.leavesToProves), cacheWriter: tb.cacheWriter, minHeight: tb.minHeight, }, nil @@ -44,7 +42,7 @@ func (tb TreeBuilder) WithLeavesToProve(leavesToProves map[uint64]bool) TreeBuil return tb } -func (tb TreeBuilder) WithCacheWriter(cacheWriter *cache.Writer) TreeBuilder { +func (tb TreeBuilder) WithCacheWriter(cacheWriter CacheWriter) TreeBuilder { tb.cacheWriter = cacheWriter return tb } @@ -62,6 +60,16 @@ func NewProvingTree(leavesToProves map[uint64]bool) (*Tree, error) { return NewTreeBuilder().WithLeavesToProve(leavesToProves).Build() } -func NewCachingTree(cacheWriter *cache.Writer) (*Tree, error) { +func NewCachingTree(cacheWriter CacheWriter) (*Tree, error) { return NewTreeBuilder().WithCacheWriter(cacheWriter).Build() } + +type disabledCacheWriter struct{} + +// A compile time check to ensure that disabledCacheWriter fully implements CacheWriter. +var _ CacheWriter = (*disabledCacheWriter)(nil) + +func (disabledCacheWriter) SetLayer(layerHeight uint, rw LayerReadWriter) {} +func (disabledCacheWriter) GetLayerWriter(layerHeight uint) (LayerWriter, error) { return nil, nil } +func (disabledCacheWriter) SetHash(hashFunc HashFunc) {} +func (disabledCacheWriter) GetReader() (CacheReader, error) { return nil, nil } diff --git a/validation.go b/validation.go index 6c91b37..b4eab3a 100644 --- a/validation.go +++ b/validation.go @@ -17,7 +17,7 @@ func ValidatePartialTree(leafIndices []uint64, leaves, proof [][]byte, expectedR if err != nil { return false, err } - root, _, err := v.calcRoot(MaxUint) + root, _, err := v.CalcRoot(MaxUint) return bytes.Equal(root, expectedRoot), err } @@ -30,11 +30,11 @@ func ValidatePartialTreeWithParkingSnapshots(leafIndices []uint64, leaves, proof if err != nil { return false, nil, err } - root, parkingSnapshots, err := v.calcRoot(MaxUint) + root, parkingSnapshots, err := v.CalcRoot(MaxUint) return bytes.Equal(root, expectedRoot), parkingSnapshots, err } -func newValidator(leafIndices []uint64, leaves, proof [][]byte, hash HashFunc, storeSnapshots bool) (*validator, error) { +func newValidator(leafIndices []uint64, leaves, proof [][]byte, hash HashFunc, storeSnapshots bool) (*Validator, error) { if len(leafIndices) != len(leaves) { return nil, fmt.Errorf("number of leaves (%d) must equal number of indices (%d)", len(leaves), len(leafIndices)) @@ -45,48 +45,48 @@ func newValidator(leafIndices []uint64, leaves, proof [][]byte, hash HashFunc, s if !sort.SliceIsSorted(leafIndices, func(i, j int) bool { return leafIndices[i] < leafIndices[j] }) { return nil, errors.New("leafIndices are not sorted") } - if len(setOf(leafIndices...)) != len(leafIndices) { + if len(SetOf(leafIndices...)) != len(leafIndices) { return nil, errors.New("leafIndices contain duplicates") } proofNodes := &proofIterator{proof} - leafIt := &leafIterator{leafIndices, leaves} + leafIt := &LeafIterator{leafIndices, leaves} - return &validator{leaves: leafIt, proofNodes: proofNodes, hash: hash, storeSnapshots: storeSnapshots}, nil + return &Validator{Leaves: leafIt, ProofNodes: proofNodes, Hash: hash, StoreSnapshots: storeSnapshots}, nil } -type validator struct { - leaves *leafIterator - proofNodes *proofIterator - hash HashFunc - storeSnapshots bool +type Validator struct { + Leaves *LeafIterator + ProofNodes *proofIterator + Hash HashFunc + StoreSnapshots bool } type ParkingSnapshot [][]byte -func (v *validator) calcRoot(stopAtLayer uint) ([]byte, []ParkingSnapshot, error) { - activePos, activeNode, err := v.leaves.next() +func (v *Validator) CalcRoot(stopAtLayer uint) ([]byte, []ParkingSnapshot, error) { + activePos, activeNode, err := v.Leaves.next() if err != nil { return nil, nil, err } var lChild, rChild, sibling []byte var parkingSnapshots, subTreeSnapshots []ParkingSnapshot - if v.storeSnapshots { + if v.StoreSnapshots { parkingSnapshots = []ParkingSnapshot{nil} } for { - if activePos.height == stopAtLayer { + if activePos.Height == stopAtLayer { break } // The activeNode's sibling should be calculated iff it's an ancestor of the next proven leaf. Otherwise, the // sibling is the next node in the proof. - nextLeafPos, _, err := v.leaves.peek() + nextLeafPos, _, err := v.Leaves.peek() if err == nil && activePos.sibling().isAncestorOf(nextLeafPos) { - sibling, subTreeSnapshots, err = v.calcRoot(activePos.height) + sibling, subTreeSnapshots, err = v.CalcRoot(activePos.Height) if err != nil { return nil, nil, err } } else { - sibling, err = v.proofNodes.next() + sibling, err = v.ProofNodes.next() if err == noMoreItems { break } @@ -102,7 +102,7 @@ func (v *validator) calcRoot(stopAtLayer uint) ([]byte, []ParkingSnapshot, error subTreeSnapshots = nil } } - activeNode = v.hash(lChild, rChild) + activeNode = v.Hash(lChild, rChild) activePos = activePos.parent() } return activeNode, parkingSnapshots, nil diff --git a/validation_test.go b/validation_test.go index ff3b04e..930dd3e 100644 --- a/validation_test.go +++ b/validation_test.go @@ -1,4 +1,4 @@ -package merkle +package merkle_test import ( "fmt" @@ -300,13 +300,13 @@ func TestValidatePartialTreeErrors(t *testing.T) { func TestValidator_calcRoot(t *testing.T) { r := require.New(t) v := validator{ - leaves: &leafIterator{}, - proofNodes: nil, - hash: nil, - storeSnapshots: false, + Leaves: &leafIterator{}, + ProofNodes: nil, + Hash: nil, + StoreSnapshots: false, } - root, _, err := v.calcRoot(0) + root, _, err := v.CalcRoot(0) r.Error(err) r.Equal("no more items", err.Error())