Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reductionista committed Jan 7, 2025
1 parent 1c74a3b commit e49250f
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 67 deletions.
6 changes: 4 additions & 2 deletions integration-tests/smoke/event_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

contract "github.com/smartcontractkit/chainlink-solana/contracts/generated/log_read_test"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/client"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller"

"github.com/smartcontractkit/chainlink-solana/integration-tests/solclient"
Expand All @@ -49,7 +50,8 @@ func TestEventLoader(t *testing.T) {
require.NoError(t, err)

rpcURL, wsURL := setupTestValidator(t, privateKey.PublicKey().String())
rpcClient := rpc.New(rpcURL)
cl, rpcClient, err := client.NewTestClient(rpcURL, config.NewDefault(), 1*time.Second, logger.Nop())
require.NoError(t, err)
wsClient, err := ws.Connect(ctx, wsURL)
require.NoError(t, err)

Expand All @@ -62,7 +64,7 @@ func TestEventLoader(t *testing.T) {
parser := &printParser{t: t}
sender := newLogSender(t, rpcClient, wsClient)
collector := logpoller.NewEncodedLogCollector(
rpcClient,
cl,
parser,
logger.Nop(),
)
Expand Down
15 changes: 11 additions & 4 deletions pkg/solana/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,25 @@ type Client struct {
requestGroup *singleflight.Group
}

func NewClient(endpoint string, cfg config.Config, requestTimeout time.Duration, log logger.Logger) (*Client, error) {
return &Client{
// Return both the client and the underlying rpc client for testing
func NewTestClient(endpoint string, cfg config.Config, requestTimeout time.Duration, log logger.Logger) (*Client, *rpc.Client, error) {
rpcClient := Client{
url: endpoint,
rpc: rpc.New(endpoint),
skipPreflight: cfg.SkipPreflight(),
commitment: cfg.Commitment(),
maxRetries: cfg.MaxRetries(),
txTimeout: cfg.TxTimeout(),
contextDuration: requestTimeout,
log: log,
requestGroup: &singleflight.Group{},
}, nil
}
rpcClient.rpc = rpc.New(endpoint)
return &rpcClient, rpcClient.rpc, nil
}

func NewClient(endpoint string, cfg config.Config, requestTimeout time.Duration, log logger.Logger) (*Client, error) {
rpcClient, _, err := NewTestClient(endpoint, cfg, requestTimeout, log)
return rpcClient, err
}

func (c *Client) latency(name string) func() {
Expand Down
22 changes: 10 additions & 12 deletions pkg/solana/logpoller/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (

"github.com/gagliardetto/solana-go"
"github.com/smartcontractkit/chainlink-common/pkg/logger"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller/utils"
)

type filters struct {
Expand Down Expand Up @@ -88,8 +86,6 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error {
return fmt.Errorf("failed to load filters: %w", err)
}

filter.EventSig = utils.Discriminator("event", filter.EventName)

fl.filtersMutex.Lock()
defer fl.filtersMutex.Unlock()

Expand Down Expand Up @@ -134,17 +130,17 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error {
}

programID := filter.Address.ToSolana().String()
if _, ok := fl.knownPrograms[programID]; !ok {
if _, ok = fl.knownPrograms[programID]; !ok {
fl.knownPrograms[programID] = 1
} else {
fl.knownPrograms[programID]++
}

discriminator := base64.StdEncoding.EncodeToString(filter.EventSig[:])[:10]
discriminatorHead := filter.Discriminator()[:10]
if _, ok := fl.knownPrograms[programID]; !ok {
fl.knownDiscriminators[discriminator] = 1
fl.knownDiscriminators[discriminatorHead] = 1
} else {
fl.knownDiscriminators[discriminator]++
fl.knownDiscriminators[discriminatorHead]++
}

return nil
Expand Down Expand Up @@ -220,13 +216,13 @@ func (fl *filters) removeFilterFromIndexes(filter Filter) {
}
}

discriminator := base64.StdEncoding.EncodeToString(filter.EventSig[:])[:10]
if refcount, ok := fl.knownDiscriminators[discriminator]; ok {
discriminatorHead := filter.Discriminator()[:10]
if refcount, ok := fl.knownDiscriminators[discriminatorHead]; ok {
refcount--
if refcount > 0 {
fl.knownDiscriminators[discriminator] = refcount
fl.knownDiscriminators[discriminatorHead] = refcount
} else {
delete(fl.knownDiscriminators, discriminator)
delete(fl.knownDiscriminators, discriminatorHead)
}
}
}
Expand Down Expand Up @@ -345,6 +341,8 @@ func (fl *filters) LoadFilters(ctx context.Context) error {
fl.filtersByAddress = make(map[PublicKey]map[EventSignature]map[int64]struct{})
fl.filtersToBackfill = make(map[int64]struct{})
fl.filtersToDelete = make(map[int64]Filter)
fl.knownPrograms = make(map[string]uint)
fl.knownDiscriminators = make(map[string]uint)

filters, err := fl.orm.SelectFilters(ctx)
if err != nil {
Expand Down
29 changes: 29 additions & 0 deletions pkg/solana/logpoller/filters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ func TestFilters_LoadFilters(t *testing.T) {
happyPath2,
}, nil).Once()

orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{
1: 18,
2: 25,
3: 0,
}, nil)

err := fs.LoadFilters(ctx)
require.EqualError(t, err, "failed to select filters from db: db failed")
err = fs.LoadFilters(ctx)
Expand Down Expand Up @@ -110,6 +116,7 @@ func TestFilters_RegisterFilter(t *testing.T) {
const filterName = "Filter"
dbFilter := Filter{Name: filterName}
orm.On("SelectFilters", mock.Anything).Return([]Filter{dbFilter}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil)
newFilter := dbFilter
tc.ModifyField(&newFilter)
err := fs.RegisterFilter(tests.Context(t), newFilter)
Expand All @@ -122,6 +129,7 @@ func TestFilters_RegisterFilter(t *testing.T) {
fs := newFilters(lggr, orm)
const filterName = "Filter"
orm.On("SelectFilters", mock.Anything).Return(nil, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
orm.On("InsertFilter", mock.Anything, mock.Anything).Return(int64(0), errors.New("failed to insert")).Once()
filter := Filter{Name: filterName}
err := fs.RegisterFilter(tests.Context(t), filter)
Expand Down Expand Up @@ -149,6 +157,7 @@ func TestFilters_RegisterFilter(t *testing.T) {
fs := newFilters(lggr, orm)
const filterName = "Filter"
orm.On("SelectFilters", mock.Anything).Return(nil, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
const filterID = int64(10)
orm.On("InsertFilter", mock.Anything, mock.Anything).Return(filterID, nil).Once()
err := fs.RegisterFilter(tests.Context(t), Filter{Name: filterName})
Expand Down Expand Up @@ -180,6 +189,7 @@ func TestFilters_UnregisterFilter(t *testing.T) {
fs := newFilters(lggr, orm)
const filterName = "Filter"
orm.On("SelectFilters", mock.Anything).Return(nil, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
err := fs.UnregisterFilter(tests.Context(t), filterName)
require.NoError(t, err)
})
Expand All @@ -189,6 +199,7 @@ func TestFilters_UnregisterFilter(t *testing.T) {
const filterName = "Filter"
const id int64 = 10
orm.On("SelectFilters", mock.Anything).Return([]Filter{{ID: id, Name: filterName}}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
orm.On("MarkFilterDeleted", mock.Anything, id).Return(errors.New("db query failed")).Once()
err := fs.UnregisterFilter(tests.Context(t), filterName)
require.EqualError(t, err, "failed to mark filter deleted: db query failed")
Expand All @@ -199,6 +210,7 @@ func TestFilters_UnregisterFilter(t *testing.T) {
const filterName = "Filter"
const id int64 = 10
orm.On("SelectFilters", mock.Anything).Return([]Filter{{ID: id, Name: filterName}}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
orm.On("MarkFilterDeleted", mock.Anything, id).Return(nil).Once()
err := fs.UnregisterFilter(tests.Context(t), filterName)
require.NoError(t, err)
Expand Down Expand Up @@ -226,6 +238,9 @@ func TestFilters_PruneFilters(t *testing.T) {
Name: "To keep",
},
}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{
2: 25,
}, nil).Once()
orm.On("DeleteFilters", mock.Anything, map[int64]Filter{toDelete.ID: toDelete}).Return(nil).Once()
err := fs.PruneFilters(tests.Context(t))
require.NoError(t, err)
Expand All @@ -246,6 +261,10 @@ func TestFilters_PruneFilters(t *testing.T) {
Name: "To keep",
},
}, nil).Once()
orm.EXPECT().SelectSeqNums(mock.Anything).Return(map[int64]int64{
1: 18,
2: 25,
}, nil).Once()
newToDelete := Filter{
ID: 3,
Name: "To delete 2",
Expand Down Expand Up @@ -291,6 +310,12 @@ func TestFilters_MatchingFilters(t *testing.T) {
EventSig: expectedFilter1.EventSig,
}
orm.On("SelectFilters", mock.Anything).Return([]Filter{expectedFilter1, expectedFilter2, sameAddress, sameEventSig}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{
1: 18,
2: 25,
3: 14,
4: 0,
}, nil)
filters := newFilters(lggr, orm)
err := filters.LoadFilters(tests.Context(t))
require.NoError(t, err)
Expand Down Expand Up @@ -319,6 +344,10 @@ func TestFilters_GetFiltersToBackfill(t *testing.T) {
Name: "notBackfilled",
}
orm.EXPECT().SelectFilters(mock.Anything).Return([]Filter{backfilledFilter, notBackfilled}, nil).Once()
orm.EXPECT().SelectSeqNums(mock.Anything).Return(map[int64]int64{
1: 18,
2: 25,
}, nil)
filters := newFilters(lggr, orm)
err := filters.LoadFilters(tests.Context(t))
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/solana/logpoller/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ func (j *getTransactionsFromBlockJob) Run(ctx context.Context) error {

if block.BlockTime == nil {
return fmt.Errorf("received block %d from rpc with missing block time", block.BlockHeight)
detail.blockTime = *block.BlockTime
}
detail.blockTime = *block.BlockTime

if len(block.Transactions) != len(blockSigsOnly.Signatures) {
return fmt.Errorf("block %d has %d transactions but %d signatures", block.BlockHeight, len(block.Transactions), len(blockSigsOnly.Signatures))
Expand Down
Loading

0 comments on commit e49250f

Please sign in to comment.