Skip to content

Commit

Permalink
fix: atomic and map pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Dec 5, 2024
1 parent 8352f35 commit 865b41e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 31 deletions.
4 changes: 2 additions & 2 deletions core/chainio.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ type Reader interface {
GetAllVersionedBlobParams(ctx context.Context) (map[uint8]*BlobVersionParameters, error)

// GetActiveReservations returns active reservations (end timestamp > current timestamp)
GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (*map[gethcommon.Address]ActiveReservation, error)
GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*ActiveReservation, error)

// GetActiveReservationByAccount returns active reservation by account ID
GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*ActiveReservation, error)

// GetOnDemandPayments returns all on-demand payments
GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (*map[gethcommon.Address]OnDemandPayment, error)
GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*OnDemandPayment, error)

// GetOnDemandPaymentByAccount returns on-demand payment of an account
GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*OnDemandPayment, error)
Expand Down
20 changes: 10 additions & 10 deletions core/eth/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,8 @@ func (t *Reader) GetAllVersionedBlobParams(ctx context.Context) (map[uint8]*core
return res, nil
}

func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (*map[gethcommon.Address]core.ActiveReservation, error) {
reservationsMap := make(map[gethcommon.Address]core.ActiveReservation)
func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ActiveReservation, error) {
reservationsMap := make(map[gethcommon.Address]*core.ActiveReservation)
reservations, err := t.bindings.PaymentVault.GetReservations(&bind.CallOpts{
Context: ctx,
}, accountIDs)
Expand All @@ -670,17 +670,17 @@ func (t *Reader) GetActiveReservations(ctx context.Context, accountIDs []gethcom

// since reservations are returned in the same order as the accountIDs, we can directly map them
for i, reservation := range reservations {
reservationsMap[accountIDs[i]] = reservation
reservationsMap[accountIDs[i]] = &reservation
}

// filter out all zero-valued reservations
for accountID, reservation := range reservationsMap {
if isZeroValuedReservation(reservation) {
if isZeroValuedReservation(*reservation) {
delete(reservationsMap, accountID)
}
}

return &reservationsMap, nil
return reservationsMap, nil
}

func (t *Reader) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) {
Expand All @@ -691,13 +691,13 @@ func (t *Reader) GetActiveReservationByAccount(ctx context.Context, accountID ge
return nil, err
}
if isZeroValuedReservation(reservation) {
return nil, errors.New("reservation is zero-valued")
return nil, errors.New("reservation does not exist for given account")
}
return &reservation, nil
}

func (t *Reader) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (*map[gethcommon.Address]core.OnDemandPayment, error) {
paymentsMap := make(map[gethcommon.Address]core.OnDemandPayment)
func (t *Reader) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.OnDemandPayment, error) {
paymentsMap := make(map[gethcommon.Address]*core.OnDemandPayment)
payments, err := t.bindings.PaymentVault.GetOnDemandAmounts(&bind.CallOpts{
Context: ctx,
}, accountIDs)
Expand All @@ -707,7 +707,7 @@ func (t *Reader) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommo

// since payments are returned in the same order as the accountIDs, we can directly map them
for i, payment := range payments {
paymentsMap[accountIDs[i]] = core.OnDemandPayment{
paymentsMap[accountIDs[i]] = &core.OnDemandPayment{
CumulativePayment: payment,
}
}
Expand All @@ -719,7 +719,7 @@ func (t *Reader) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommo
}
}

return &paymentsMap, nil
return paymentsMap, nil
}

func (t *Reader) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) {
Expand Down
32 changes: 17 additions & 15 deletions core/meterer/onchain_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"unsafe"

"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/core/eth"
Expand All @@ -30,8 +32,8 @@ var _ OnchainPayment = (*OnchainPaymentState)(nil)
type OnchainPaymentState struct {
tx *eth.Reader

ActiveReservations *map[gethcommon.Address]core.ActiveReservation
OnDemandPayments *map[gethcommon.Address]core.OnDemandPayment
ActiveReservations map[gethcommon.Address]*core.ActiveReservation
OnDemandPayments map[gethcommon.Address]*core.OnDemandPayment

ReservationsLock sync.RWMutex
OnDemandLocks sync.RWMutex
Expand All @@ -56,8 +58,8 @@ func NewOnchainPaymentState(ctx context.Context, tx *eth.Reader) (OnchainPayment

return OnchainPaymentState{
tx: tx,
ActiveReservations: &map[gethcommon.Address]core.ActiveReservation{},
OnDemandPayments: &map[gethcommon.Address]core.OnDemandPayment{},
ActiveReservations: make(map[gethcommon.Address]*core.ActiveReservation),
OnDemandPayments: make(map[gethcommon.Address]*core.OnDemandPayment),
PaymentVaultParams: paymentVaultParams,
}, nil
}
Expand Down Expand Up @@ -109,11 +111,11 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
return err
}
// These parameters should be rarely updated, but we refresh them anyway
pcs.PaymentVaultParams = paymentVaultParams
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&pcs.PaymentVaultParams)), unsafe.Pointer(paymentVaultParams))

pcs.ReservationsLock.Lock()
accountIDs := make([]gethcommon.Address, 0, len(*pcs.ActiveReservations))
for accountID := range *pcs.ActiveReservations {
accountIDs := make([]gethcommon.Address, 0, len(pcs.ActiveReservations))
for accountID := range pcs.ActiveReservations {
accountIDs = append(accountIDs, accountID)
}

Expand All @@ -125,8 +127,8 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
pcs.ReservationsLock.Unlock()

pcs.OnDemandLocks.Lock()
accountIDs = make([]gethcommon.Address, 0, len(*pcs.OnDemandPayments))
for accountID := range *pcs.OnDemandPayments {
accountIDs = make([]gethcommon.Address, 0, len(pcs.OnDemandPayments))
for accountID := range pcs.OnDemandPayments {
accountIDs = append(accountIDs, accountID)
}

Expand All @@ -142,8 +144,8 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,

// GetActiveReservationByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation
func (pcs *OnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) {
if reservation, ok := (*pcs.ActiveReservations)[accountID]; ok {
return &reservation, nil
if reservation, ok := (pcs.ActiveReservations)[accountID]; ok {
return reservation, nil
}

// pulls the chain state
Expand All @@ -152,7 +154,7 @@ func (pcs *OnchainPaymentState) GetActiveReservationByAccount(ctx context.Contex
return nil, err
}
pcs.ReservationsLock.Lock()
(*pcs.ActiveReservations)[accountID] = *res
(pcs.ActiveReservations)[accountID] = res
pcs.ReservationsLock.Unlock()
return res, nil
}
Expand All @@ -168,8 +170,8 @@ func (pcs *OnchainPaymentState) GetActiveReservationByAccountOnChain(ctx context

// GetOnDemandPaymentByAccount returns a pointer to the on-demand payment for the given account ID; no writes will be made to the payment
func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) {
if payment, ok := (*pcs.OnDemandPayments)[accountID]; ok {
return &payment, nil
if payment, ok := (pcs.OnDemandPayments)[accountID]; ok {
return payment, nil
}
// pulls the chain state
res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, accountID)
Expand All @@ -178,7 +180,7 @@ func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context,
}

pcs.OnDemandLocks.Lock()
(*pcs.OnDemandPayments)[accountID] = *res
(pcs.OnDemandPayments)[accountID] = res
pcs.OnDemandLocks.Unlock()
return res, nil
}
Expand Down
8 changes: 4 additions & 4 deletions core/mock/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ func (t *MockWriter) PubkeyHashToOperator(ctx context.Context, operatorId core.O
return result.(gethcommon.Address), args.Error(1)
}

func (t *MockWriter) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (*map[gethcommon.Address]core.ActiveReservation, error) {
func (t *MockWriter) GetActiveReservations(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.ActiveReservation, error) {
args := t.Called()
result := args.Get(0)
return result.(*map[gethcommon.Address]core.ActiveReservation), args.Error(1)
return result.(map[gethcommon.Address]*core.ActiveReservation), args.Error(1)
}

func (t *MockWriter) GetActiveReservationByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ActiveReservation, error) {
Expand All @@ -233,10 +233,10 @@ func (t *MockWriter) GetActiveReservationByAccount(ctx context.Context, accountI
return result.(*core.ActiveReservation), args.Error(1)
}

func (t *MockWriter) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (*map[gethcommon.Address]core.OnDemandPayment, error) {
func (t *MockWriter) GetOnDemandPayments(ctx context.Context, accountIDs []gethcommon.Address) (map[gethcommon.Address]*core.OnDemandPayment, error) {
args := t.Called()
result := args.Get(0)
return result.(*map[gethcommon.Address]core.OnDemandPayment), args.Error(1)
return result.(map[gethcommon.Address]*core.OnDemandPayment), args.Error(1)
}

func (t *MockWriter) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) {
Expand Down

0 comments on commit 865b41e

Please sign in to comment.