diff --git a/internal/contractStore/contractStore.go b/internal/contractStore/contractStore.go index 6f403609..623c2f99 100644 --- a/internal/contractStore/contractStore.go +++ b/internal/contractStore/contractStore.go @@ -13,14 +13,12 @@ type ContractStore interface { FindOrCreateProxyContract(blockNumber uint64, contractAddress string, proxyContractAddress string) (*ProxyContract, bool, error) GetContractWithProxyContract(address string, atBlockNumber uint64) (*ContractsTree, error) SetContractCheckedForProxy(address string) (*Contract, error) - GetProxyContract(address string) (*ProxyContract, error) SetContractAbi(address string, abi string, verified bool) (*Contract, error) SetContractMatchingContractAddress(address string, matchingContractAddress string) (*Contract, error) } // Tables type Contract struct { - Id uint64 ContractAddress string ContractAbi string MatchingContractAddress string diff --git a/internal/contractStore/pgContractStore/pgContractStore.go b/internal/contractStore/pgContractStore/pgContractStore.go index 089388bc..7d67127d 100644 --- a/internal/contractStore/pgContractStore/pgContractStore.go +++ b/internal/contractStore/pgContractStore/pgContractStore.go @@ -55,7 +55,7 @@ func (p *PgContractStore) FindOrCreateContract( } // found contract - if contract.ContractAddress == address && contract.Id != 0 { + if contract.ContractAddress == address { found = true return contract, nil } @@ -204,21 +204,6 @@ func (p *PgContractStore) SetContractCheckedForProxy(address string) (*contractS return contract, nil } -func (p *PgContractStore) GetProxyContract(address string) (*contractStore.ProxyContract, error) { - var proxyContract *contractStore.ProxyContract - - result := p.Db.First(&proxyContract, "contract_address = ?", strings.ToLower(address)) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - p.Logger.Sugar().Debugf("Proxy contract not found in store '%s'", address) - return nil, nil - } - return nil, result.Error - } - - return proxyContract, nil -} - func (p *PgContractStore) SetContractAbi(address string, abi string, verified bool) (*contractStore.Contract, error) { contract := &contractStore.Contract{} diff --git a/internal/contractStore/sqliteContractStore/sqliteContractStore.go b/internal/contractStore/sqliteContractStore/sqliteContractStore.go index 9ffc59fe..9d864358 100644 --- a/internal/contractStore/sqliteContractStore/sqliteContractStore.go +++ b/internal/contractStore/sqliteContractStore/sqliteContractStore.go @@ -1,6 +1,7 @@ package sqliteContractStore import ( + "database/sql" "errors" "fmt" "github.com/Layr-Labs/sidecar/internal/contractStore" @@ -56,7 +57,7 @@ func (s *SqliteContractStore) FindOrCreateContract( } // found contract - if contract.ContractAddress == address && contract.Id != 0 { + if contract.ContractAddress == address { found = true return contract, nil } @@ -78,7 +79,7 @@ func (s *SqliteContractStore) FindOrCreateContract( return upsertedContract, found, err } -func (s *SqliteContractStore) indVerifiedContractWithMatchingBytecodeHash(bytecodeHash string, address string) (*contractStore.Contract, error) { +func (s *SqliteContractStore) FindVerifiedContractWithMatchingBytecodeHash(bytecodeHash string, address string) (*contractStore.Contract, error) { query := ` select * @@ -88,7 +89,7 @@ func (s *SqliteContractStore) indVerifiedContractWithMatchingBytecodeHash(byteco and verified = true and matching_contract_address = '' and contract_address != ? - order by id asc + order by rowid asc limit 1` var contract *contractStore.Contract @@ -158,19 +159,19 @@ func (s *SqliteContractStore) GetContractWithProxyContract(address string, atBlo select * from proxy_contracts - where contract_address = ? and block_number <= ? + where contract_address = @contractAddress and block_number <= @blockNumber order by block_number desc limit 1 ) as pc on (1=1) left join contracts as pcc on (pcc.contract_address = pc.proxy_contract_address) left join contracts as pcclike on (pcc.matching_contract_address = pcclike.contract_address) left join contracts as clike on (c.matching_contract_address = clike.contract_address) - where c.contract_address = ?` - + where + c.contract_address = @contractAddress + ` contractTree := &contractStore.ContractsTree{} result := s.Db.Raw(query, - address, - atBlockNumber, - address, + sql.Named("contractAddress", address), + sql.Named("blockNumber", atBlockNumber), ).Scan(&contractTree) if result.Error != nil { @@ -205,21 +206,6 @@ func (s *SqliteContractStore) SetContractCheckedForProxy(address string) (*contr return contract, nil } -func (s *SqliteContractStore) GetProxyContract(address string) (*contractStore.ProxyContract, error) { - var proxyContract *contractStore.ProxyContract - - result := s.Db.First(&proxyContract, "contract_address = ?", strings.ToLower(address)) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - s.Logger.Sugar().Debugf("Proxy contract not found in store '%s'", address) - return nil, nil - } - return nil, result.Error - } - - return proxyContract, nil -} - func (s *SqliteContractStore) SetContractAbi(address string, abi string, verified bool) (*contractStore.Contract, error) { contract := &contractStore.Contract{} diff --git a/internal/contractStore/sqliteContractStore/sqliteContractStore_test.go b/internal/contractStore/sqliteContractStore/sqliteContractStore_test.go new file mode 100644 index 00000000..c8a23c67 --- /dev/null +++ b/internal/contractStore/sqliteContractStore/sqliteContractStore_test.go @@ -0,0 +1,193 @@ +package sqliteContractStore + +import ( + "github.com/Layr-Labs/sidecar/internal/config" + "github.com/Layr-Labs/sidecar/internal/contractStore" + "github.com/Layr-Labs/sidecar/internal/logger" + "github.com/Layr-Labs/sidecar/internal/sqlite/migrations" + "github.com/Layr-Labs/sidecar/internal/tests" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "gorm.io/gorm" + "testing" +) + +func setup() ( + *config.Config, + *gorm.DB, + *zap.Logger, + error, +) { + cfg := tests.GetConfig() + l, _ := logger.NewLogger(&logger.LoggerConfig{Debug: cfg.Debug}) + + db, err := tests.GetSqliteDatabaseConnection() + if err != nil { + panic(err) + } + sqliteMigrator := migrations.NewSqliteMigrator(db, l) + if err := sqliteMigrator.MigrateAll(); err != nil { + l.Sugar().Fatalw("Failed to migrate", "error", err) + } + + return cfg, db, l, err +} + +func teardown() { + +} + +func Test_SqliteContractStore(t *testing.T) { + _, db, l, err := setup() + + if err != nil { + t.Fatal(err) + } + + cs := NewSqliteContractStore(db, l) + + createdContracts := make([]*contractStore.Contract, 0) + createdProxyContracts := make([]*contractStore.ProxyContract, 0) + + t.Run("Create contract", func(t *testing.T) { + contract := &contractStore.Contract{ + ContractAddress: "0x123", + ContractAbi: "[]", + Verified: true, + BytecodeHash: "0x123", + MatchingContractAddress: "", + } + + createdContract, found, err := cs.FindOrCreateContract(contract.ContractAddress, contract.ContractAbi, contract.Verified, contract.BytecodeHash, contract.MatchingContractAddress) + assert.Nil(t, err) + assert.False(t, found) + assert.Equal(t, contract.ContractAddress, createdContract.ContractAddress) + assert.Equal(t, contract.ContractAbi, createdContract.ContractAbi) + assert.Equal(t, contract.Verified, createdContract.Verified) + assert.Equal(t, contract.BytecodeHash, createdContract.BytecodeHash) + assert.Equal(t, contract.MatchingContractAddress, createdContract.MatchingContractAddress) + + createdContracts = append(createdContracts, createdContract) + }) + t.Run("Find contract rather than create", func(t *testing.T) { + contract := &contractStore.Contract{ + ContractAddress: "0x123", + ContractAbi: "[]", + Verified: true, + BytecodeHash: "0x123", + MatchingContractAddress: "", + } + + createdContract, found, err := cs.FindOrCreateContract(contract.ContractAddress, contract.ContractAbi, contract.Verified, contract.BytecodeHash, contract.MatchingContractAddress) + assert.Nil(t, err) + assert.True(t, found) + assert.Equal(t, contract.ContractAddress, createdContract.ContractAddress) + assert.Equal(t, contract.ContractAbi, createdContract.ContractAbi) + assert.Equal(t, contract.Verified, createdContract.Verified) + assert.Equal(t, contract.BytecodeHash, createdContract.BytecodeHash) + assert.Equal(t, contract.MatchingContractAddress, createdContract.MatchingContractAddress) + }) + t.Run("Create proxy contract", func(t *testing.T) { + proxyContract := &contractStore.ProxyContract{ + BlockNumber: 1, + ContractAddress: createdContracts[0].ContractAddress, + ProxyContractAddress: "0x456", + } + + proxy, found, err := cs.FindOrCreateProxyContract(uint64(proxyContract.BlockNumber), proxyContract.ContractAddress, proxyContract.ProxyContractAddress) + assert.Nil(t, err) + assert.False(t, found) + assert.Equal(t, proxyContract.BlockNumber, proxy.BlockNumber) + assert.Equal(t, proxyContract.ContractAddress, proxy.ContractAddress) + assert.Equal(t, proxyContract.ProxyContractAddress, proxy.ProxyContractAddress) + + newProxyContract := &contractStore.Contract{ + ContractAddress: proxyContract.ProxyContractAddress, + ContractAbi: "[]", + Verified: true, + BytecodeHash: "0x456", + MatchingContractAddress: "", + } + createdProxy, _, err := cs.FindOrCreateContract(newProxyContract.ContractAddress, newProxyContract.ContractAbi, newProxyContract.Verified, newProxyContract.BytecodeHash, newProxyContract.MatchingContractAddress) + assert.Nil(t, err) + createdContracts = append(createdContracts, createdProxy) + + createdProxyContracts = append(createdProxyContracts, proxy) + }) + t.Run("Find proxy contract rather than create", func(t *testing.T) { + proxyContract := &contractStore.ProxyContract{ + BlockNumber: 1, + ContractAddress: createdContracts[0].ContractAddress, + ProxyContractAddress: "0x456", + } + + proxy, found, err := cs.FindOrCreateProxyContract(uint64(proxyContract.BlockNumber), proxyContract.ContractAddress, proxyContract.ProxyContractAddress) + assert.Nil(t, err) + assert.True(t, found) + assert.Equal(t, proxyContract.BlockNumber, proxy.BlockNumber) + assert.Equal(t, proxyContract.ContractAddress, proxy.ContractAddress) + assert.Equal(t, proxyContract.ProxyContractAddress, proxy.ProxyContractAddress) + }) + t.Run("Get contract from address", func(t *testing.T) { + address := createdContracts[0].ContractAddress + + contract, err := cs.GetContractForAddress(address) + assert.Nil(t, err) + assert.Equal(t, address, contract.ContractAddress) + assert.Equal(t, createdContracts[0].ContractAbi, contract.ContractAbi) + assert.Equal(t, createdContracts[0].Verified, contract.Verified) + assert.Equal(t, createdContracts[0].BytecodeHash, contract.BytecodeHash) + assert.Equal(t, createdContracts[0].MatchingContractAddress, contract.MatchingContractAddress) + }) + t.Run("Find verified contract with matching bytecode hash", func(t *testing.T) { + bytecodeHash := createdContracts[0].BytecodeHash + address := createdContracts[0].ContractAddress + + contract, err := cs.FindVerifiedContractWithMatchingBytecodeHash(bytecodeHash, address) + assert.Nil(t, err) + assert.Nil(t, contract) + }) + t.Run("Get contract with proxy contract", func(t *testing.T) { + address := createdContracts[0].ContractAddress + + contracts := make([]contractStore.Contract, 0) + db.Raw(`select * from contracts`, address).Scan(&contracts) + + contractsTree, err := cs.GetContractWithProxyContract(address, 1) + assert.Nil(t, err) + assert.Equal(t, createdContracts[0].ContractAddress, contractsTree.BaseAddress) + assert.Equal(t, createdContracts[0].ContractAbi, contractsTree.BaseAbi) + assert.Equal(t, createdContracts[1].ContractAddress, contractsTree.BaseProxyAddress) + assert.Equal(t, createdContracts[1].ContractAbi, contractsTree.BaseProxyAbi) + assert.Equal(t, "", contractsTree.BaseLikeAddress) + assert.Equal(t, "", contractsTree.BaseLikeAbi) + }) + t.Run("Set contract checked for proxy", func(t *testing.T) { + address := createdContracts[0].ContractAddress + + contract, err := cs.SetContractCheckedForProxy(address) + assert.Nil(t, err) + assert.Equal(t, address, contract.ContractAddress) + assert.True(t, contract.CheckedForProxy) + }) + t.Run("Set contract ABI", func(t *testing.T) { + address := createdContracts[0].ContractAddress + abi := `[{ "type": "function", "name": "balanceOf", "inputs": [{ "name": "owner", "type": "address" }], "outputs": [{ "name": "balance", "type": "uint256" }] }]` + verified := true + + contract, err := cs.SetContractAbi(address, abi, verified) + assert.Nil(t, err) + assert.Equal(t, address, contract.ContractAddress) + assert.Equal(t, abi, contract.ContractAbi) + assert.Equal(t, verified, contract.Verified) + }) + t.Run("Set contract matching contract address", func(t *testing.T) { + address := createdContracts[0].ContractAddress + matchingContractAddress := "0x789" + + contract, err := cs.SetContractMatchingContractAddress(address, matchingContractAddress) + assert.Nil(t, err) + assert.Equal(t, address, contract.ContractAddress) + assert.Equal(t, matchingContractAddress, contract.MatchingContractAddress) + }) +} diff --git a/internal/sqlite/migrations/202409061249_bootstrapDb/up.go b/internal/sqlite/migrations/202409061249_bootstrapDb/up.go index 8e289481..f3f0ac7c 100644 --- a/internal/sqlite/migrations/202409061249_bootstrapDb/up.go +++ b/internal/sqlite/migrations/202409061249_bootstrapDb/up.go @@ -58,13 +58,13 @@ func (m *SqliteMigration) Up(grm *gorm.DB) error { verified INTEGER DEFAULT false, matching_contract_address TEXT DEFAULT NULL, checked_for_proxy INTEGER DEFAULT 0 NOT NULL, - checked_for_abi INTEGER NOT NULL, + checked_for_abi INTEGER NOT NULL DEFAULT 0, UNIQUE(contract_address) )`, `CREATE TABLE IF NOT EXISTS proxy_contracts ( block_number INTEGER NOT NULL, contract_address TEXT NOT NULL PRIMARY KEY REFERENCES contracts(contract_address) ON DELETE CASCADE, - proxy_contract_address TEXT NOT NULL REFERENCES contracts(contract_address) ON DELETE CASCADE, + proxy_contract_address TEXT NOT NULL, created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at DATETIME, deleted_at DATETIME diff --git a/internal/sqlite/sqlite.go b/internal/sqlite/sqlite.go index f6f98f2b..94a52c41 100644 --- a/internal/sqlite/sqlite.go +++ b/internal/sqlite/sqlite.go @@ -4,7 +4,6 @@ import ( "fmt" "gorm.io/driver/sqlite" "gorm.io/gorm" - "gorm.io/gorm/logger" ) func NewSqlite(path string) gorm.Dialector { @@ -14,7 +13,7 @@ func NewSqlite(path string) gorm.Dialector { func NewGormSqliteFromSqlite(sqlite gorm.Dialector) (*gorm.DB, error) { db, err := gorm.Open(sqlite, &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), + //Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { return nil, err