From e461bd60cbe5106313a7533a81d1b33ec7b3c81c Mon Sep 17 00:00:00 2001 From: hopeyen Date: Thu, 19 Dec 2024 00:21:19 -0800 Subject: [PATCH] fix: payment metadata check and symbol charged --- api/clients/v2/accountant.go | 18 +++++++++--------- api/clients/v2/accountant_test.go | 10 +++++----- api/clients/v2/disperser_client.go | 2 +- core/meterer/meterer.go | 2 +- disperser/apiserver/disperse_blob_v2.go | 2 +- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/api/clients/v2/accountant.go b/api/clients/v2/accountant.go index 4d1f80500d..fed44412a4 100644 --- a/api/clients/v2/accountant.go +++ b/api/clients/v2/accountant.go @@ -67,14 +67,15 @@ func NewAccountant(accountID string, reservation *core.ReservedPayment, onDemand // then on-demand if the reservation is not available. The returned values are // reservation period for reservation payments and cumulative payment for on-demand payments, // and both fields are used to create the payment header and signature -func (a *Accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint64, quorumNumbers []uint8) (uint32, *big.Int, error) { +func (a *Accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint32, quorumNumbers []uint8) (uint32, *big.Int, error) { now := time.Now().Unix() currentReservationPeriod := meterer.GetReservationPeriod(uint64(now), a.reservationWindow) + symbolUsage := uint64(a.SymbolsCharged(numSymbols)) a.usageLock.Lock() defer a.usageLock.Unlock() relativeBinRecord := a.GetRelativeBinRecord(currentReservationPeriod) - relativeBinRecord.Usage += numSymbols + relativeBinRecord.Usage += symbolUsage // first attempt to use the active reservation binLimit := a.reservation.SymbolsPerSecond * uint64(a.reservationWindow) @@ -87,7 +88,7 @@ func (a *Accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint64, quo overflowBinRecord := a.GetRelativeBinRecord(currentReservationPeriod + 2) // Allow one overflow when the overflow bin is empty, the current usage and new length are both less than the limit - if overflowBinRecord.Usage == 0 && relativeBinRecord.Usage-numSymbols < binLimit && numSymbols <= binLimit { + if overflowBinRecord.Usage == 0 && relativeBinRecord.Usage-symbolUsage < binLimit && symbolUsage <= binLimit { overflowBinRecord.Usage += relativeBinRecord.Usage - binLimit if err := QuorumCheck(quorumNumbers, a.reservation.QuorumNumbers); err != nil { return 0, big.NewInt(0), err @@ -97,8 +98,7 @@ func (a *Accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint64, quo // reservation not available, attempt on-demand //todo: rollback later if disperser respond with some type of rejection? - relativeBinRecord.Usage -= numSymbols - incrementRequired := big.NewInt(int64(a.PaymentCharged(uint(numSymbols)))) + incrementRequired := big.NewInt(int64(a.PaymentCharged(numSymbols))) a.cumulativePayment.Add(a.cumulativePayment, incrementRequired) if a.cumulativePayment.Cmp(a.onDemand.CumulativePayment) <= 0 { if err := QuorumCheck(quorumNumbers, requiredQuorums); err != nil { @@ -110,7 +110,7 @@ func (a *Accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint64, quo } // AccountBlob accountant provides and records payment information -func (a *Accountant) AccountBlob(ctx context.Context, numSymbols uint64, quorums []uint8, salt uint32) (*core.PaymentMetadata, error) { +func (a *Accountant) AccountBlob(ctx context.Context, numSymbols uint32, quorums []uint8, salt uint32) (*core.PaymentMetadata, error) { reservationPeriod, cumulativePayment, err := a.BlobPaymentInfo(ctx, numSymbols, quorums) if err != nil { return nil, err @@ -128,14 +128,14 @@ func (a *Accountant) AccountBlob(ctx context.Context, numSymbols uint64, quorums // TODO: PaymentCharged and SymbolsCharged copied from meterer, should be refactored // PaymentCharged returns the chargeable price for a given data length -func (a *Accountant) PaymentCharged(numSymbols uint) uint64 { +func (a *Accountant) PaymentCharged(numSymbols uint32) uint64 { return uint64(a.SymbolsCharged(numSymbols)) * uint64(a.pricePerSymbol) } // SymbolsCharged returns the number of symbols charged for a given data length // being at least MinNumSymbols or the nearest rounded-up multiple of MinNumSymbols. -func (a *Accountant) SymbolsCharged(numSymbols uint) uint32 { - if numSymbols <= uint(a.minNumSymbols) { +func (a *Accountant) SymbolsCharged(numSymbols uint32) uint32 { + if numSymbols <= a.minNumSymbols { return a.minNumSymbols } // Round up to the nearest multiple of MinNumSymbols diff --git a/api/clients/v2/accountant_test.go b/api/clients/v2/accountant_test.go index d28dd9f16b..ac8672ff5f 100644 --- a/api/clients/v2/accountant_test.go +++ b/api/clients/v2/accountant_test.go @@ -68,7 +68,7 @@ func TestAccountBlob_Reservation(t *testing.T) { accountant := NewAccountant(accountId, reservation, onDemand, reservationWindow, pricePerSymbol, minNumSymbols, numBins) ctx := context.Background() - symbolLength := uint64(500) + symbolLength := uint32(500) quorums := []uint8{0, 1} header, err := accountant.AccountBlob(ctx, symbolLength, quorums, salt) @@ -78,7 +78,7 @@ func TestAccountBlob_Reservation(t *testing.T) { assert.Equal(t, big.NewInt(0), header.CumulativePayment) assert.Equal(t, isRotation([]uint64{500, 0, 0}, mapRecordUsage(accountant.binRecords)), true) - symbolLength = uint64(700) + symbolLength = uint32(700) header, err = accountant.AccountBlob(ctx, symbolLength, quorums, salt) @@ -116,13 +116,13 @@ func TestAccountBlob_OnDemand(t *testing.T) { accountant := NewAccountant(accountId, reservation, onDemand, reservationWindow, pricePerSymbol, minNumSymbols, numBins) ctx := context.Background() - numSymbols := uint64(1500) + numSymbols := uint32(1500) quorums := []uint8{0, 1} header, err := accountant.AccountBlob(ctx, numSymbols, quorums, salt) assert.NoError(t, err) - expectedPayment := big.NewInt(int64(numSymbols * uint64(pricePerSymbol))) + expectedPayment := big.NewInt(int64(numSymbols * pricePerSymbol)) assert.Equal(t, uint32(0), header.ReservationPeriod) assert.Equal(t, expectedPayment, header.CumulativePayment) assert.Equal(t, isRotation([]uint64{0, 0, 0}, mapRecordUsage(accountant.binRecords)), true) @@ -144,7 +144,7 @@ func TestAccountBlob_InsufficientOnDemand(t *testing.T) { accountant := NewAccountant(accountId, reservation, onDemand, reservationWindow, pricePerSymbol, minNumSymbols, numBins) ctx := context.Background() - numSymbols := uint64(2000) + numSymbols := uint32(2000) quorums := []uint8{0, 1} _, err = accountant.AccountBlob(ctx, numSymbols, quorums, salt) diff --git a/api/clients/v2/disperser_client.go b/api/clients/v2/disperser_client.go index 22ab443d7d..6590fb8475 100644 --- a/api/clients/v2/disperser_client.go +++ b/api/clients/v2/disperser_client.go @@ -145,7 +145,7 @@ func (c *disperserClient) DisperseBlob( // } // symbolLength := encoding.GetBlobLengthPowerOf2(uint(len(data))) - // payment, err := c.accountant.AccountBlob(ctx, uint64(symbolLength), quorums, salt) + // payment, err := c.accountant.AccountBlob(ctx, uint32(symbolLength), quorums, salt) // if err != nil { // return nil, [32]byte{}, fmt.Errorf("error accounting blob: %w", err) // } diff --git a/core/meterer/meterer.go b/core/meterer/meterer.go index f7c13c401d..89335f96bf 100644 --- a/core/meterer/meterer.go +++ b/core/meterer/meterer.go @@ -159,7 +159,7 @@ func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMeta usageLimit := m.GetReservationBinLimit(reservation) if newUsage <= usageLimit { return nil - } else if newUsage-uint64(numSymbols) >= usageLimit { + } else if newUsage-uint64(symbolsCharged) >= usageLimit { // metered usage before updating the size already exceeded the limit return fmt.Errorf("bin has already been filled") } diff --git a/disperser/apiserver/disperse_blob_v2.go b/disperser/apiserver/disperse_blob_v2.go index c53c1bc66e..c117b0ea4e 100644 --- a/disperser/apiserver/disperse_blob_v2.go +++ b/disperser/apiserver/disperse_blob_v2.go @@ -145,7 +145,7 @@ func (s *DispersalServerV2) validateDispersalRequest(ctx context.Context, req *p // } // TODO(ian-shim): enable this check when we have payment metadata + authentication in disperser client - // if len(blobHeader.PaymentMetadata.AccountID) == 0 || blobHeader.PaymentMetadata.ReservationPeriod == 0 || blobHeader.PaymentMetadata.CumulativePayment == nil { + // if len(blobHeader.PaymentMetadata.AccountID) == 0 || (blobHeader.PaymentMetadata.ReservationPeriod == 0 && blobHeader.PaymentMetadata.CumulativePayment == nil) { // return api.NewErrorInvalidArg("invalid payment metadata") // }