Skip to content

Commit

Permalink
Fix types to avoid overflows and too large chainID error (#697) (#698)
Browse files Browse the repository at this point in the history
* Fix types

* fix e2e + unit test types

* fix

* DB types

* Naming refactor
  • Loading branch information
ARR552 authored Nov 21, 2024
1 parent 002f436 commit 57ae55e
Show file tree
Hide file tree
Showing 49 changed files with 735 additions and 655 deletions.
38 changes: 19 additions & 19 deletions bridgectrl/bridgectrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ const (

// BridgeController struct
type BridgeController struct {
exitTrees []*MerkleTree
rollupsTree *MerkleTree
networkIDs map[uint]uint8
exitTrees []*MerkleTree
rollupsTree *MerkleTree
merkleTreeIDs map[uint32]uint8
}

// NewBridgeController creates new BridgeController.
func NewBridgeController(ctx context.Context, cfg Config, networks []uint, mtStore interface{}) (*BridgeController, error) {
func NewBridgeController(ctx context.Context, cfg Config, networkIDs []uint32, mtStore interface{}) (*BridgeController, error) {
var (
networkIDs = make(map[uint]uint8)
exitTrees []*MerkleTree
merkleTreeIDs = make(map[uint32]uint8)
exitTrees []*MerkleTree
)

for i, network := range networks {
networkIDs[network] = uint8(i)
mt, err := NewMerkleTree(ctx, mtStore.(merkleTreeStore), cfg.Height, network)
for i, networkID := range networkIDs {
merkleTreeIDs[networkID] = uint8(i)
mt, err := NewMerkleTree(ctx, mtStore.(merkleTreeStore), cfg.Height, networkID)
if err != nil {
return nil, err
}
Expand All @@ -44,14 +44,14 @@ func NewBridgeController(ctx context.Context, cfg Config, networks []uint, mtSto
}

return &BridgeController{
exitTrees: exitTrees,
rollupsTree: rollupsTree,
networkIDs: networkIDs,
exitTrees: exitTrees,
rollupsTree: rollupsTree,
merkleTreeIDs: merkleTreeIDs,
}, nil
}

func (bt *BridgeController) GetNetworkID(networkID uint) (uint8, error) {
tID, found := bt.networkIDs[networkID]
func (bt *BridgeController) GetMerkleTreeID(networkID uint32) (uint8, error) {
tID, found := bt.merkleTreeIDs[networkID]
if !found {
return 0, gerror.ErrNetworkNotRegister
}
Expand All @@ -61,16 +61,16 @@ func (bt *BridgeController) GetNetworkID(networkID uint) (uint8, error) {
// AddDeposit adds deposit information to the bridge tree.
func (bt *BridgeController) AddDeposit(ctx context.Context, deposit *etherman.Deposit, depositID uint64, dbTx pgx.Tx) error {
leaf := hashDeposit(deposit)
tID, err := bt.GetNetworkID(deposit.NetworkID)
tID, err := bt.GetMerkleTreeID(deposit.NetworkID)
if err != nil {
return err
}
return bt.exitTrees[tID].addLeaf(ctx, depositID, leaf, deposit.DepositCount, dbTx)
}

// ReorgMT reorg the specific merkle tree.
func (bt *BridgeController) ReorgMT(ctx context.Context, depositCount uint, networkID uint, dbTx pgx.Tx) error {
tID, err := bt.GetNetworkID(networkID)
func (bt *BridgeController) ReorgMT(ctx context.Context, depositCount uint32, networkID uint32, dbTx pgx.Tx) error {
tID, err := bt.GetMerkleTreeID(networkID)
if err != nil {
return err
}
Expand All @@ -79,8 +79,8 @@ func (bt *BridgeController) ReorgMT(ctx context.Context, depositCount uint, netw

// GetExitRoot returns the dedicated merkle tree's root.
// only use for the test purpose
func (bt *BridgeController) GetExitRoot(ctx context.Context, networkID int, dbTx pgx.Tx) ([]byte, error) {
return bt.exitTrees[networkID].getRoot(ctx, dbTx)
func (bt *BridgeController) GetExitRoot(ctx context.Context, tID uint8, dbTx pgx.Tx) ([]byte, error) {
return bt.exitTrees[tID].getRoot(ctx, dbTx)
}

func (bt *BridgeController) AddRollupExitLeaf(ctx context.Context, rollupLeaf etherman.RollupExitLeaf, dbTx pgx.Tx) error {
Expand Down
10 changes: 5 additions & 5 deletions bridgectrl/bridgectrl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestBridgeTree(t *testing.T) {
store, err := pgstorage.NewPostgresStorage(dbCfg)
require.NoError(t, err)
ctx := context.Background()
bt, err := NewBridgeController(ctx, cfg, []uint{0, 1000}, store)
bt, err := NewBridgeController(ctx, cfg, []uint32{0, 1000}, store)
require.NoError(t, err)

t.Run("Test adding deposit for the bridge tree", func(t *testing.T) {
Expand All @@ -71,7 +71,7 @@ func TestBridgeTree(t *testing.T) {
DestinationNetwork: testVector.DestinationNetwork,
DestinationAddress: common.HexToAddress(testVector.DestinationAddress),
BlockID: blockID,
DepositCount: uint(i),
DepositCount: uint32(i),
Metadata: common.FromHex(testVector.Metadata),
}
leafHash := hashDeposit(deposit)
Expand All @@ -82,10 +82,10 @@ func TestBridgeTree(t *testing.T) {
require.NoError(t, err)

// test reorg
orgRoot, err := bt.exitTrees[0].store.GetRoot(ctx, uint(i), 0, nil)
orgRoot, err := bt.exitTrees[0].store.GetRoot(ctx, uint32(i), 0, nil)
require.NoError(t, err)
require.NoError(t, store.Reset(ctx, uint64(i), deposit.NetworkID, nil))
err = bt.ReorgMT(ctx, uint(i), testVectors[i].OriginalNetwork, nil)
err = bt.ReorgMT(ctx, uint32(i), testVectors[i].OriginalNetwork, nil)
require.NoError(t, err)
blockID, err = store.AddBlock(context.TODO(), block, nil)
require.NoError(t, err)
Expand All @@ -94,7 +94,7 @@ func TestBridgeTree(t *testing.T) {
require.NoError(t, err)
err = bt.AddDeposit(ctx, deposit, depositID, nil)
require.NoError(t, err)
newRoot, err := bt.exitTrees[0].store.GetRoot(ctx, uint(i), 0, nil)
newRoot, err := bt.exitTrees[0].store.GetRoot(ctx, uint32(i), 0, nil)
require.NoError(t, err)
assert.Equal(t, orgRoot, newRoot)

Expand Down
6 changes: 3 additions & 3 deletions bridgectrl/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
type merkleTreeStore interface {
Get(ctx context.Context, key []byte, dbTx pgx.Tx) ([][]byte, error)
BulkSet(ctx context.Context, rows [][]interface{}, dbTx pgx.Tx) error
GetRoot(ctx context.Context, depositCount uint, network uint, dbTx pgx.Tx) ([]byte, error)
SetRoot(ctx context.Context, root []byte, depositID uint64, network uint, dbTx pgx.Tx) error
GetLastDepositCount(ctx context.Context, network uint, dbTx pgx.Tx) (uint, error)
GetRoot(ctx context.Context, depositCount uint32, network uint32, dbTx pgx.Tx) ([]byte, error)
SetRoot(ctx context.Context, root []byte, depositID uint64, network uint32, dbTx pgx.Tx) error
GetLastDepositCount(ctx context.Context, networkID uint32, dbTx pgx.Tx) (uint32, error)
AddRollupExitLeaves(ctx context.Context, rows [][]interface{}, dbTx pgx.Tx) error
GetRollupExitLeavesByRoot(ctx context.Context, root common.Hash, dbTx pgx.Tx) ([]etherman.RollupExitLeaf, error)
GetLatestRollupExitLeaves(ctx context.Context, dbTx pgx.Tx) ([]etherman.RollupExitLeaf, error)
Expand Down
36 changes: 18 additions & 18 deletions bridgectrl/merkletree.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ var zeroHashes [][KeyLen]byte
// MerkleTree struct
type MerkleTree struct {
// store is the database storage to store all node data
store merkleTreeStore
network uint
store merkleTreeStore
networkID uint32
// height is the depth of the merkle tree
height uint8
// count is the number of deposit
count uint
count uint32
// siblings is the array of sibling of the last leaf added
siblings [][KeyLen]byte
}
Expand All @@ -38,8 +38,8 @@ func init() {
}

// NewMerkleTree creates new MerkleTree.
func NewMerkleTree(ctx context.Context, store merkleTreeStore, height uint8, network uint) (*MerkleTree, error) {
depositCnt, err := store.GetLastDepositCount(ctx, network, nil)
func NewMerkleTree(ctx context.Context, store merkleTreeStore, height uint8, networkID uint32) (*MerkleTree, error) {
depositCnt, err := store.GetLastDepositCount(ctx, networkID, nil)
if err != nil {
if err != gerror.ErrStorageNotFound {
return nil, err
Expand All @@ -50,10 +50,10 @@ func NewMerkleTree(ctx context.Context, store merkleTreeStore, height uint8, net
}

mt := &MerkleTree{
store: store,
network: network,
height: height,
count: depositCnt,
store: store,
networkID: networkID,
height: height,
count: depositCnt,
}
mt.siblings, err = mt.initSiblings(ctx, nil)

Expand Down Expand Up @@ -110,7 +110,7 @@ func (mt *MerkleTree) initSiblings(ctx context.Context, dbTx pgx.Tx) ([][KeyLen]
return siblings, nil
}

func (mt *MerkleTree) addLeaf(ctx context.Context, depositID uint64, leaf [KeyLen]byte, index uint, dbTx pgx.Tx) error {
func (mt *MerkleTree) addLeaf(ctx context.Context, depositID uint64, leaf [KeyLen]byte, index uint32, dbTx pgx.Tx) error {
if index != mt.count {
return fmt.Errorf("mismatched deposit count: %d, expected: %d", index, mt.count)
}
Expand Down Expand Up @@ -141,7 +141,7 @@ func (mt *MerkleTree) addLeaf(ctx context.Context, depositID uint64, leaf [KeyLe
}
}

err := mt.store.SetRoot(ctx, cur[:], depositID, mt.network, dbTx)
err := mt.store.SetRoot(ctx, cur[:], depositID, mt.networkID, dbTx)
if err != nil {
return err
}
Expand All @@ -157,7 +157,7 @@ func (mt *MerkleTree) addLeaf(ctx context.Context, depositID uint64, leaf [KeyLe
return nil
}

func (mt *MerkleTree) resetLeaf(ctx context.Context, depositCount uint, dbTx pgx.Tx) error {
func (mt *MerkleTree) resetLeaf(ctx context.Context, depositCount uint32, dbTx pgx.Tx) error {
var err error
mt.count = depositCount
mt.siblings, err = mt.initSiblings(ctx, dbTx)
Expand All @@ -169,7 +169,7 @@ func (mt *MerkleTree) getRoot(ctx context.Context, dbTx pgx.Tx) ([]byte, error)
if mt.count == 0 {
return zeroHashes[mt.height][:], nil
}
return mt.store.GetRoot(ctx, mt.count-1, mt.network, dbTx)
return mt.store.GetRoot(ctx, mt.count-1, mt.networkID, dbTx)
}

func buildIntermediate(leaves [][KeyLen]byte) ([][][]byte, [][32]byte) {
Expand All @@ -191,7 +191,7 @@ func (mt *MerkleTree) updateLeaf(ctx context.Context, depositID uint64, leaves [
nodes [][][][]byte
ns [][][]byte
)
initLeavesCount := uint(len(leaves))
initLeavesCount := uint32(len(leaves))
if len(leaves) == 0 {
leaves = append(leaves, zeroHashes[0])
}
Expand All @@ -207,7 +207,7 @@ func (mt *MerkleTree) updateLeaf(ctx context.Context, depositID uint64, leaves [
return fmt.Errorf("error: more than one root detected: %+v", nodes)
}
log.Debug("Root calculated: ", common.Bytes2Hex(ns[0][0]))
err := mt.store.SetRoot(ctx, ns[0][0], depositID, mt.network, dbTx)
err := mt.store.SetRoot(ctx, ns[0][0], depositID, mt.networkID, dbTx)
if err != nil {
return err
}
Expand Down Expand Up @@ -329,7 +329,7 @@ func (mt MerkleTree) addRollupExitLeaf(ctx context.Context, rollupLeaf etherman.
for i := len(storedRollupLeaves); i < int(rollupLeaf.RollupId); i++ {
storedRollupLeaves = append(storedRollupLeaves, etherman.RollupExitLeaf{
BlockID: rollupLeaf.BlockID,
RollupId: uint(i + 1),
RollupId: uint32(i + 1),
})
}
if storedRollupLeaves[rollupLeaf.RollupId-1].RollupId == rollupLeaf.RollupId {
Expand All @@ -352,7 +352,7 @@ func (mt MerkleTree) addRollupExitLeaf(ctx context.Context, rollupLeaf etherman.
return nil
}

func ComputeSiblings(rollupIndex uint, leaves [][KeyLen]byte, height uint8) ([][KeyLen]byte, common.Hash, error) {
func ComputeSiblings(rollupIndex uint32, leaves [][KeyLen]byte, height uint8) ([][KeyLen]byte, common.Hash, error) {
var ns [][][]byte
if len(leaves) == 0 {
leaves = append(leaves, zeroHashes[0])
Expand Down Expand Up @@ -382,7 +382,7 @@ func ComputeSiblings(rollupIndex uint, leaves [][KeyLen]byte, height uint8) ([][
}
// Find the index of the leave in the next level of the tree.
// Divide the index by 2 to find the position in the upper level
index = uint(float64(index) / 2) //nolint:gomnd
index = uint32(float64(index) / 2) //nolint:gomnd
ns = nsi
leaves = hashes
}
Expand Down
16 changes: 8 additions & 8 deletions bridgectrl/merkletree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestLeafHash(t *testing.T) {
DestinationNetwork: testVector.DestinationNetwork,
DestinationAddress: common.HexToAddress(testVector.DestinationAddress),
BlockNumber: 0,
DepositCount: uint(ti + 1),
DepositCount: uint32(ti + 1),
Metadata: common.FromHex(testVector.Metadata),
}
leafHash := hashDeposit(deposit)
Expand Down Expand Up @@ -111,7 +111,7 @@ func TestMTAddLeaf(t *testing.T) {
DestinationNetwork: testVector.NewLeaf.DestinationNetwork,
DestinationAddress: common.HexToAddress(testVector.NewLeaf.DestinationAddress),
BlockNumber: 0,
DepositCount: uint(i),
DepositCount: uint32(i),
Metadata: common.FromHex(testVector.NewLeaf.Metadata),
}
depositID, err := store.AddDeposit(ctx, deposit, nil)
Expand All @@ -123,15 +123,15 @@ func TestMTAddLeaf(t *testing.T) {
leafValue, err := formatBytes32String(leaf[2:])
require.NoError(t, err)

err = mt.addLeaf(ctx, depositIDs[i], leafValue, uint(i), nil)
err = mt.addLeaf(ctx, depositIDs[i], leafValue, uint32(i), nil)
require.NoError(t, err)
}
curRoot, err := mt.getRoot(ctx, nil)
require.NoError(t, err)
assert.Equal(t, hex.EncodeToString(curRoot), testVector.CurrentRoot[2:])

leafHash := hashDeposit(deposit)
err = mt.addLeaf(ctx, depositIDs[len(depositIDs)-1], leafHash, uint(len(testVector.ExistingLeaves)), nil)
err = mt.addLeaf(ctx, depositIDs[len(depositIDs)-1], leafHash, uint32(len(testVector.ExistingLeaves)), nil)
require.NoError(t, err)
newRoot, err := mt.getRoot(ctx, nil)
require.NoError(t, err)
Expand Down Expand Up @@ -179,7 +179,7 @@ func TestMTGetProof(t *testing.T) {
DestinationNetwork: leaf.DestinationNetwork,
DestinationAddress: common.HexToAddress(leaf.DestinationAddress),
BlockID: blockID,
DepositCount: uint(li),
DepositCount: uint32(li),
Metadata: common.FromHex(leaf.Metadata),
}
depositID, err := store.AddDeposit(ctx, deposit, nil)
Expand All @@ -188,7 +188,7 @@ func TestMTGetProof(t *testing.T) {
if li == int(testVector.Index) {
cur = leafHash
}
err = mt.addLeaf(ctx, depositID, leafHash, uint(li), nil)
err = mt.addLeaf(ctx, depositID, leafHash, uint32(li), nil)
require.NoError(t, err)
}
root, err := mt.getRoot(ctx, nil)
Expand Down Expand Up @@ -239,7 +239,7 @@ func TestUpdateMT(t *testing.T) {
DestinationNetwork: testVector.NewLeaf.DestinationNetwork,
DestinationAddress: common.HexToAddress(testVector.NewLeaf.DestinationAddress),
BlockNumber: 0,
DepositCount: uint(i),
DepositCount: uint32(i),
Metadata: common.FromHex(testVector.NewLeaf.Metadata),
}
_, err := store.AddDeposit(ctx, deposit, nil)
Expand Down Expand Up @@ -348,7 +348,7 @@ func TestBuildMTRootAndStore(t *testing.T) {
require.Equal(t, len(leaves), len(result))
require.Equal(t, leaves[i][:], result[i].Leaf.Bytes())
require.Equal(t, newRoot, result[i].Root)
require.Equal(t, uint(i+1), result[i].RollupId)
require.Equal(t, uint32(i+1), result[i].RollupId)
}
}
}
Expand Down
Loading

0 comments on commit 57ae55e

Please sign in to comment.