From f4c4bb0c8d4b2857adf8ed422535930b48136083 Mon Sep 17 00:00:00 2001 From: hopeyen <60078528+hopeyen@users.noreply.github.com> Date: Tue, 17 Dec 2024 00:19:33 +0000 Subject: [PATCH] refactor: refreshOnchainPaymentState arg --- core/meterer/onchain_state.go | 45 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/core/meterer/onchain_state.go b/core/meterer/onchain_state.go index 48b15c43aa..951d60b974 100644 --- a/core/meterer/onchain_state.go +++ b/core/meterer/onchain_state.go @@ -14,7 +14,7 @@ import ( // OnchainPaymentState is an interface for getting information about the current chain state for payments. type OnchainPayment interface { - RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error + RefreshOnchainPaymentState(ctx context.Context) error GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) @@ -49,49 +49,45 @@ type PaymentVaultParams struct { } func NewOnchainPaymentState(ctx context.Context, tx *eth.Reader) (*OnchainPaymentState, error) { - paymentVaultParams, err := GetPaymentVaultParams(ctx, tx) - if err != nil { - return nil, err - } - state := OnchainPaymentState{ tx: tx, ReservedPayments: make(map[gethcommon.Address]*core.ReservedPayment), OnDemandPayments: make(map[gethcommon.Address]*core.OnDemandPayment), PaymentVaultParams: atomic.Pointer[PaymentVaultParams]{}, } - state.PaymentVaultParams.Store(paymentVaultParams) - - return &state, nil -} -func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultParams, error) { - blockNumber, err := tx.GetCurrentBlockNumber(ctx) + paymentVaultParams, err := state.GetPaymentVaultParams(ctx) if err != nil { return nil, err } - quorumNumbers, err := tx.GetRequiredQuorumNumbers(ctx, blockNumber) + state.PaymentVaultParams.Store(paymentVaultParams) + + return &state, nil +} + +func (pcs *OnchainPaymentState) GetPaymentVaultParams(ctx context.Context) (*PaymentVaultParams, error) { + quorumNumbers, err := pcs.GetOnDemandQuorumNumbers(ctx) if err != nil { return nil, err } - globalSymbolsPerSecond, err := tx.GetGlobalSymbolsPerSecond(ctx) + globalSymbolsPerSecond, err := pcs.tx.GetGlobalSymbolsPerSecond(ctx) if err != nil { return nil, err } - minNumSymbols, err := tx.GetMinNumSymbols(ctx) + minNumSymbols, err := pcs.tx.GetMinNumSymbols(ctx) if err != nil { return nil, err } - pricePerSymbol, err := tx.GetPricePerSymbol(ctx) + pricePerSymbol, err := pcs.tx.GetPricePerSymbol(ctx) if err != nil { return nil, err } - reservationWindow, err := tx.GetReservationWindow(ctx) + reservationWindow, err := pcs.tx.GetReservationWindow(ctx) if err != nil { return nil, err } @@ -106,8 +102,8 @@ func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultPa } // RefreshOnchainPaymentState returns the current onchain payment state -func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error { - paymentVaultParams, err := GetPaymentVaultParams(ctx, tx) +func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context) error { + paymentVaultParams, err := pcs.GetPaymentVaultParams(ctx) if err != nil { return err } @@ -120,7 +116,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, accountIDs = append(accountIDs, accountID) } - reservedPayments, err := tx.GetReservedPayments(ctx, accountIDs) + reservedPayments, err := pcs.tx.GetReservedPayments(ctx, accountIDs) if err != nil { return err } @@ -133,7 +129,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, accountIDs = append(accountIDs, accountID) } - onDemandPayments, err := tx.GetOnDemandPayments(ctx, accountIDs) + onDemandPayments, err := pcs.tx.GetOnDemandPayments(ctx, accountIDs) if err != nil { return err } @@ -146,10 +142,11 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, // GetReservedPaymentByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { pcs.ReservationsLock.RLock() - defer pcs.ReservationsLock.RUnlock() if reservation, ok := (pcs.ReservedPayments)[accountID]; ok { + pcs.ReservationsLock.RUnlock() return reservation, nil } + pcs.ReservationsLock.RUnlock() // pulls the chain state res, err := pcs.tx.GetReservedPaymentByAccount(ctx, accountID) @@ -166,10 +163,12 @@ func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, // GetOnDemandPaymentByAccount returns a pointer to the on-demand payment for the given account ID; no writes will be made to the payment func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) { pcs.OnDemandLocks.RLock() - defer pcs.OnDemandLocks.RUnlock() if payment, ok := (pcs.OnDemandPayments)[accountID]; ok { + pcs.OnDemandLocks.RUnlock() return payment, nil } + pcs.OnDemandLocks.RUnlock() + // pulls the chain state res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, accountID) if err != nil {