Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trees cache merging support #8

Merged
merged 4 commits into from
May 15, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ type (
LayerReadWriter = shared.LayerReadWriter
CacheWriter = shared.CacheWriter
CacheReader = shared.CacheReader
LayerFactory = shared.LayerFactory
CachingPolicy = shared.CachingPolicy
)

var RootHeightFromWidth = shared.RootHeightFromWidth

type CachingPolicy func(layerHeight uint) (shouldCacheLayer bool)

type LayerFactory func(layerHeight uint) (LayerReadWriter, error)

// Writer embeds cache, exposing the embedded cache's methods for
// populating it while a tree is being built.
type Writer struct {
	*cache
}
Expand Down Expand Up @@ -101,6 +99,14 @@ func (c *Reader) GetHashFunc() HashFunc {
return c.hash
}

// GetLayerFactory returns the reader's layer factory.
// The name keeps the Get prefix to satisfy the shared.CacheReader interface.
func (c *Reader) GetLayerFactory() LayerFactory {
	factory := c.generateLayer
	return factory
}

// GetCachingPolicy returns the reader's caching policy.
// The name keeps the Get prefix to satisfy the shared.CacheReader interface.
func (c *Reader) GetCachingPolicy() CachingPolicy {
	policy := c.shouldCacheLayer
	return policy
}

type cache struct {
layers map[uint]LayerReadWriter
hash HashFunc
Expand Down
31 changes: 25 additions & 6 deletions cache/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,16 @@ func Merge(caches []CacheReader) (*Reader, error) {
layers[height] = group
}

cache := &cache{layers: layers}
hashFunc := caches[0].GetHashFunc()
layerFactory := caches[0].GetLayerFactory()
cachingPolicy := caches[0].GetCachingPolicy()

cache := &cache{
layers: layers,
hash: hashFunc,
shouldCacheLayer: cachingPolicy,
generateLayer: layerFactory,
}
return &Reader{cache}, nil
}

Expand All @@ -46,8 +55,13 @@ func BuildTop(cacheReader CacheReader) (*Reader, []byte, error) {
}
}

// Create a new subtree with the cache highest layer as its leaves.
subtreeWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory())
// Create an adjusted caching policy for the new subtree.
newCachingPolicy := func(layerHeight uint) bool {
return cacheReader.GetCachingPolicy()(maxHeight + layerHeight)
}

// Create a subtree with the cache highest layer as its leaves.
subtreeWriter := NewWriter(newCachingPolicy, cacheReader.GetLayerFactory())
moshababo marked this conversation as resolved.
Show resolved Hide resolved
subtree, err := merkle.NewTreeBuilder().
WithHashFunc(cacheReader.GetHashFunc()).
WithCacheWriter(subtreeWriter).
Expand All @@ -73,10 +87,15 @@ func BuildTop(cacheReader CacheReader) (*Reader, []byte, error) {
}
}

// Create a new cache with the existing layers.
newCache := &cache{layers: cacheReader.Layers()}
// Clone the existing cache.
newCache := &cache{
layers: cacheReader.Layers(),
hash: cacheReader.GetHashFunc(),
shouldCacheLayer: cacheReader.GetCachingPolicy(),
generateLayer: cacheReader.GetLayerFactory(),
}

// Add the subtree layers on top of the existing ones.
// Add the subtree cache layers on top of the existing ones.
for height, layer := range subtreeWriter.layers {
if height == 0 {
continue
Expand Down
140 changes: 77 additions & 63 deletions cache/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/binary"
"github.com/spacemeshos/merkle-tree"
"github.com/spacemeshos/merkle-tree/cache/readwriters"
"github.com/spacemeshos/sha256-simd"
"github.com/stretchr/testify/require"
"testing"
)
Expand All @@ -12,9 +13,9 @@ func TestMerge(t *testing.T) {
r := require.New(t)

readers := make([]*Reader, 3)
readers[0] = &Reader{&cache{layers: make(map[uint]LayerReadWriter)}}
readers[1] = &Reader{&cache{layers: make(map[uint]LayerReadWriter)}}
readers[2] = &Reader{&cache{layers: make(map[uint]LayerReadWriter)}}
readers[0] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}}
readers[1] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}}
readers[2] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}}

// Create 9 nodes.
nodes := genNodes(9)
Expand Down Expand Up @@ -78,8 +79,8 @@ func TestMergeFailure2(t *testing.T) {
r := require.New(t)

readers := make([]*Reader, 2)
readers[0] = &Reader{&cache{layers: make(map[uint]LayerReadWriter)}}
readers[1] = &Reader{&cache{layers: make(map[uint]LayerReadWriter)}}
readers[0] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}}
readers[1] = &Reader{&cache{layers: make(map[uint]LayerReadWriter), hash: merkle.GetSha256Parent}}

readers[0].cache.layers[0] = &readwriters.SliceReadWriter{}

Expand Down Expand Up @@ -155,76 +156,35 @@ func TestMergeAndBuildTop(t *testing.T) {
// Create 32 nodes.
nodes := genNodes(32)

// Add the nodes as leaves to one tree and save its root.
cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory())
tree, err := merkle.NewCachingTree(cacheWriter)
r.NoError(err)
for i := 0; i < len(nodes); i++ {
err := tree.AddLeaf(NewNodeFromUint64(uint64(i)))
r.NoError(err)
// Create a custom (non-default) hash function.
hashFunc := func(lChild, rChild []byte) []byte {
bytes := append(lChild, rChild...)
res := sha256.Sum256(append([]byte("0"), bytes...))
return res[:]
}
treeRoot := tree.Root()

// Add the nodes as leaves to 4 separate trees.
cacheWriters := make([]CacheWriter, 4)
cacheReaders := make([]CacheReader, 4)
trees := make([]*merkle.Tree, 4)
for i := 0; i < 4; i++ {
cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory())
tree, err := merkle.NewCachingTree(cacheWriter)
r.NoError(err)

cacheWriters[i] = cacheWriter
trees[i] = tree
}
for i := 0; i < len(nodes); i++ {
err := trees[i/8].AddLeaf(NewNodeFromUint64(uint64(i)))
r.NoError(err)
}
for i := 0; i < 4; i++ {
reader, err := cacheWriters[i].GetReader()
r.NoError(err)
cacheReaders[i] = reader
}

// Merge caches.
cacheReader, err := Merge(cacheReaders)
r.NoError(err)
r.NotNil(cacheReader)

// Create the upper subtree.
cacheReader, mergeRoot, err := BuildTop(cacheReader)
r.NoError(err)
r.NotNil(cacheReader)

// Verify that the 4 trees merge root is the same as the main tree root.
r.Equal(mergeRoot, treeRoot)
}

// -- FAILING --
func TestMergeAndBuildTopUnbalanced(t *testing.T) {
r := require.New(t)

// Create 29 nodes.
nodes := genNodes(29)

// Add the nodes as leaves to one tree and save its root.
// Add the nodes as leaves to a baseline tree.
cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory())
tree, err := merkle.NewCachingTree(cacheWriter)
baselineTree, err := merkle.NewTreeBuilder().
WithHashFunc(hashFunc).
WithCacheWriter(cacheWriter).
Build()
r.NoError(err)
for i := 0; i < len(nodes); i++ {
err := tree.AddLeaf(NewNodeFromUint64(uint64(i)))
err := baselineTree.AddLeaf(NewNodeFromUint64(uint64(i)))
r.NoError(err)
}
treeRoot := tree.Root()

// Add the nodes as leaves to 4 separate trees.
cacheWriters := make([]CacheWriter, 4)
cacheReaders := make([]CacheReader, 4)
trees := make([]*merkle.Tree, 4)
for i := 0; i < 4; i++ {
cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory())
tree, err := merkle.NewCachingTree(cacheWriter)
tree, err := merkle.NewTreeBuilder().
WithHashFunc(hashFunc).
WithCacheWriter(cacheWriter).
Build()
r.NoError(err)

cacheWriters[i] = cacheWriter
Expand All @@ -250,10 +210,63 @@ func TestMergeAndBuildTopUnbalanced(t *testing.T) {
r.NoError(err)
r.NotNil(cacheReader)

// Verify that the 4 trees merge root is the same as the main tree root.
r.Equal(mergeRoot, treeRoot)
// Verify that the 4 trees merge root is the same as the baseline tree root.
r.Equal(mergeRoot, baselineTree.Root())
}

// TODO(moshababo): support unbalanced merge to make this test pass.
//func TestMergeAndBuildTopUnbalanced(t *testing.T) {
// r := require.New(t)
//
// // Create 29 nodes.
// nodes := genNodes(29)
//
// // Add the nodes as leaves to one tree and save its root.
// cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory())
// tree, err := merkle.NewCachingTree(cacheWriter)
// r.NoError(err)
// for i := 0; i < len(nodes); i++ {
// err := tree.AddLeaf(NewNodeFromUint64(uint64(i)))
// r.NoError(err)
// }
// treeRoot := tree.Root()
//
// // Add the nodes as leaves to 4 separate trees.
// cacheWriters := make([]CacheWriter, 4)
// cacheReaders := make([]CacheReader, 4)
// trees := make([]*merkle.Tree, 4)
// for i := 0; i < 4; i++ {
// cacheWriter := NewWriter(MinHeightPolicy(0), MakeSliceReadWriterFactory())
// tree, err := merkle.NewCachingTree(cacheWriter)
// r.NoError(err)
//
// cacheWriters[i] = cacheWriter
// trees[i] = tree
// }
// for i := 0; i < len(nodes); i++ {
// err := trees[i/8].AddLeaf(NewNodeFromUint64(uint64(i)))
// r.NoError(err)
// }
// for i := 0; i < 4; i++ {
// reader, err := cacheWriters[i].GetReader()
// r.NoError(err)
// cacheReaders[i] = reader
// }
//
// // Merge caches.
// cacheReader, err := Merge(cacheReaders)
// r.NoError(err)
// r.NotNil(cacheReader)
//
// // Create the upper subtree.
// cacheReader, mergeRoot, err := BuildTop(cacheReader)
// r.NoError(err)
// r.NotNil(cacheReader)
//
// // Verify that the 4 trees merge root is the same as the main tree root.
// r.Equal(mergeRoot, treeRoot)
//}

func genNodes(num int) [][]byte {
nodes := make([][]byte, num)
for i := 0; i < num; i++ {
Expand All @@ -269,6 +282,7 @@ func NewNodeFromUint64(i uint64) []byte {
}

func assertWidth(r *require.Assertions, expectedWidth int, layerReader LayerReader) {
r.NotNil(layerReader)
width, err := layerReader.Width()
r.NoError(err)
r.Equal(uint64(expectedWidth), width)
Expand Down
1 change: 1 addition & 0 deletions merkle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ func TestNewTreeUnbalancedProof(t *testing.T) {
}

func assertWidth(r *require.Assertions, expectedWidth int, layerReader cache.LayerReader) {
r.NotNil(layerReader)
width, err := layerReader.Width()
r.NoError(err)
r.Equal(uint64(expectedWidth), width)
Expand Down
6 changes: 6 additions & 0 deletions shared/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,10 @@ type CacheReader interface {
Layers() map[uint]LayerReadWriter
GetLayerReader(layerHeight uint) LayerReader
GetHashFunc() HashFunc
GetLayerFactory() LayerFactory
GetCachingPolicy() CachingPolicy
}

// CachingPolicy reports whether the layer at layerHeight should be cached.
type CachingPolicy func(layerHeight uint) (shouldCacheLayer bool)

// LayerFactory constructs a LayerReadWriter for the layer at layerHeight.
type LayerFactory func(layerHeight uint) (LayerReadWriter, error)