Skip to content

Commit

Permalink
fix: db tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Dec 11, 2024
1 parent e9299dc commit 1a58cb6
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 33 deletions.
4 changes: 2 additions & 2 deletions core/meterer/meterer.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func (m *Meterer) ValidatePayment(ctx context.Context, header core.PaymentMetada
return fmt.Errorf("insufficient cumulative payment increment")
}
// the current request must not break the payment magnitude for the next payment if the two requests were delivered out-of-order
if nextPmt != nil && header.CumulativePayment.Add(header.CumulativePayment, m.PaymentCharged(uint(nextPmtnumSymbols))).Cmp(nextPmt) > 0 {
if nextPmt.Cmp(big.NewInt(0)) != 0 && header.CumulativePayment.Add(header.CumulativePayment, m.PaymentCharged(uint(nextPmtnumSymbols))).Cmp(nextPmt) > 0 {
return fmt.Errorf("breaking cumulative payment invariants")
}
// check passed: blob can be safely inserted into the set of payments
Expand All @@ -244,7 +244,7 @@ func (m *Meterer) ValidatePayment(ctx context.Context, header core.PaymentMetada

// PaymentCharged returns the chargeable price for a given data length
func (m *Meterer) PaymentCharged(numSymbols uint) *big.Int {
return new(big.Int).Mul(big.NewInt(int64(m.SymbolsCharged(numSymbols))), big.NewInt(int64(m.ChainPaymentState.GetPricePerSymbol())))
return big.NewInt(int64(m.SymbolsCharged(numSymbols) * m.ChainPaymentState.GetPricePerSymbol()))
}

// SymbolsCharged returns the number of symbols charged for a given data length
Expand Down
52 changes: 26 additions & 26 deletions core/meterer/meterer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,16 @@ func TestMetererReservations(t *testing.T) {
paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(&core.ActiveReservation{}, fmt.Errorf("reservation not found"))

// test invalid quorom ID
header := createPaymentHeader(1, 0, accountID1)
header := createPaymentHeader(1, big.NewInt(0), accountID1)
err := mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2})
assert.ErrorContains(t, err, "quorum number mismatch")

// overwhelming bin overflow for empty bins
header = createPaymentHeader(binIndex-1, 0, accountID2)
header = createPaymentHeader(binIndex-1, big.NewInt(0), accountID2)
err = mt.MeterRequest(ctx, *header, 10, quoromNumbers)
assert.NoError(t, err)
// overwhelming bin overflow for empty bins
header = createPaymentHeader(binIndex-1, 0, accountID2)
header = createPaymentHeader(binIndex-1, big.NewInt(0), accountID2)
err = mt.MeterRequest(ctx, *header, 1000, quoromNumbers)
assert.ErrorContains(t, err, "overflow usage exceeds bin limit")

Expand All @@ -204,21 +204,21 @@ func TestMetererReservations(t *testing.T) {
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
header = createPaymentHeader(1, 0, crypto.PubkeyToAddress(unregisteredUser.PublicKey))
header = createPaymentHeader(1, big.NewInt(0), crypto.PubkeyToAddress(unregisteredUser.PublicKey))
assert.NoError(t, err)
err = mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2})
assert.ErrorContains(t, err, "failed to get active reservation by account: reservation not found")

// test invalid bin index
header = createPaymentHeader(binIndex, 0, accountID1)
header = createPaymentHeader(binIndex, big.NewInt(0), accountID1)
err = mt.MeterRequest(ctx, *header, 2000, quoromNumbers)
assert.ErrorContains(t, err, "invalid bin index for reservation")

// test bin usage metering
symbolLength := uint(20)
requiredLength := uint(21) // 21 should be charged for length of 20 since minNumSymbols is 3
for i := 0; i < 9; i++ {
header = createPaymentHeader(binIndex, 0, accountID2)
header = createPaymentHeader(binIndex, big.NewInt(0), accountID2)
err = mt.MeterRequest(ctx, *header, symbolLength, quoromNumbers)
assert.NoError(t, err)
item, err := dynamoClient.GetItem(ctx, reservationTableName, commondynamodb.Key{
Expand All @@ -232,7 +232,7 @@ func TestMetererReservations(t *testing.T) {

}
// first over flow is allowed
header = createPaymentHeader(binIndex, 0, accountID2)
header = createPaymentHeader(binIndex, big.NewInt(0), accountID2)
assert.NoError(t, err)
err = mt.MeterRequest(ctx, *header, 25, quoromNumbers)
assert.NoError(t, err)
Expand All @@ -248,7 +248,7 @@ func TestMetererReservations(t *testing.T) {
assert.Equal(t, strconv.Itoa(int(16)), item["BinUsage"].(*types.AttributeValueMemberN).Value)

// second over flow
header = createPaymentHeader(binIndex, 0, accountID2)
header = createPaymentHeader(binIndex, big.NewInt(0), accountID2)
assert.NoError(t, err)
err = mt.MeterRequest(ctx, *header, 1, quoromNumbers)
assert.ErrorContains(t, err, "bin has already been filled")
Expand All @@ -275,18 +275,18 @@ func TestMetererOnDemand(t *testing.T) {
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
header := createPaymentHeader(binIndex, 2, crypto.PubkeyToAddress(unregisteredUser.PublicKey))
header := createPaymentHeader(binIndex, big.NewInt(2), crypto.PubkeyToAddress(unregisteredUser.PublicKey))
assert.NoError(t, err)
err = mt.MeterRequest(ctx, *header, 1000, quorumNumbers)
assert.ErrorContains(t, err, "failed to get on-demand payment by account: payment not found")

// test invalid quorom ID
header = createPaymentHeader(binIndex, 1, accountID1)
header = createPaymentHeader(binIndex, big.NewInt(2), accountID1)
err = mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2})
assert.ErrorContains(t, err, "invalid quorum for On-Demand Request")

// test insufficient cumulative payment
header = createPaymentHeader(binIndex, 1, accountID1)
header = createPaymentHeader(binIndex, big.NewInt(1), accountID1)
err = mt.MeterRequest(ctx, *header, 1000, quorumNumbers)
assert.ErrorContains(t, err, "insufficient cumulative payment increment")
// No rollback after meter request
Expand All @@ -300,7 +300,7 @@ func TestMetererOnDemand(t *testing.T) {
// test duplicated cumulative payments
symbolLength := uint(100)
priceCharged := mt.PaymentCharged(symbolLength)
assert.Equal(t, uint64(102*mt.ChainPaymentState.GetPricePerSymbol()), priceCharged)
assert.Equal(t, big.NewInt(int64(102*mt.ChainPaymentState.GetPricePerSymbol())), priceCharged)
header = createPaymentHeader(binIndex, priceCharged, accountID2)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.NoError(t, err)
Expand All @@ -310,24 +310,24 @@ func TestMetererOnDemand(t *testing.T) {

// test valid payments
for i := 1; i < 9; i++ {
header = createPaymentHeader(binIndex, uint64(priceCharged)*uint64(i+1), accountID2)
header = createPaymentHeader(binIndex, new(big.Int).Mul(priceCharged, big.NewInt(int64(i+1))), accountID2)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.NoError(t, err)
}

// test cumulative payment on-chain constraint
header = createPaymentHeader(binIndex, 2023, accountID2)
header = createPaymentHeader(binIndex, big.NewInt(2023), accountID2)
err = mt.MeterRequest(ctx, *header, 1, quorumNumbers)
assert.ErrorContains(t, err, "invalid on-demand payment: request claims a cumulative payment greater than the on-chain deposit")

// test insufficient increment in cumulative payment
previousCumulativePayment := uint64(priceCharged) * uint64(9)
previousCumulativePayment := priceCharged.Mul(priceCharged, big.NewInt(9))
symbolLength = uint(2)
priceCharged = mt.PaymentCharged(symbolLength)
header = createPaymentHeader(binIndex, previousCumulativePayment+priceCharged-1, accountID2)
header = createPaymentHeader(binIndex, big.NewInt(0).Add(previousCumulativePayment, big.NewInt(0).Sub(priceCharged, big.NewInt(1))), accountID2)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.ErrorContains(t, err, "invalid on-demand payment: insufficient cumulative payment increment")
previousCumulativePayment = previousCumulativePayment + priceCharged
previousCumulativePayment = big.NewInt(0).Add(previousCumulativePayment, priceCharged)

// test cannot insert cumulative payment in out of order
header = createPaymentHeader(binIndex, mt.PaymentCharged(50), accountID2)
Expand All @@ -342,7 +342,7 @@ func TestMetererOnDemand(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, numPrevRecords, len(result))
// test failed global rate limit (previously payment recorded: 2, global limit: 1009)
header = createPaymentHeader(binIndex, previousCumulativePayment+mt.PaymentCharged(1010), accountID1)
header = createPaymentHeader(binIndex, big.NewInt(0).Add(previousCumulativePayment, mt.PaymentCharged(1010)), accountID1)
err = mt.MeterRequest(ctx, *header, 1010, quorumNumbers)
assert.ErrorContains(t, err, "failed global rate limiting")
// Correct rollback
Expand All @@ -360,42 +360,42 @@ func TestMeterer_paymentCharged(t *testing.T) {
symbolLength uint
pricePerSymbol uint32
minNumSymbols uint32
expected uint64
expected *big.Int
}{
{
name: "Data length equal to min chargeable size",
symbolLength: 1024,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 1024,
expected: big.NewInt(1024),
},
{
name: "Data length less than min chargeable size",
symbolLength: 512,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 1024,
expected: big.NewInt(1024),
},
{
name: "Data length greater than min chargeable size",
symbolLength: 2048,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 2048,
expected: big.NewInt(2048),
},
{
name: "Large data length",
symbolLength: 1 << 20, // 1 MB
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 1 << 20,
expected: big.NewInt(1 << 20),
},
{
name: "Price not evenly divisible by min chargeable size",
symbolLength: 1536,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 2048,
expected: big.NewInt(2048),
},
}

Expand Down Expand Up @@ -465,10 +465,10 @@ func TestMeterer_symbolsCharged(t *testing.T) {
}
}

func createPaymentHeader(binIndex uint32, cumulativePayment uint64, accountID gethcommon.Address) *core.PaymentMetadata {
func createPaymentHeader(binIndex uint32, cumulativePayment *big.Int, accountID gethcommon.Address) *core.PaymentMetadata {
return &core.PaymentMetadata{
AccountID: accountID.Hex(),
BinIndex: binIndex,
CumulativePayment: big.NewInt(int64(cumulativePayment)),
CumulativePayment: cumulativePayment,
}
}
12 changes: 7 additions & 5 deletions core/meterer/offchain_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,13 @@ func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountI
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to query smaller payments for account: %w", err)
}
var prevPayment *big.Int
prevPayment := big.NewInt(0)
if len(smallerResult) > 0 {
_, success := prevPayment.SetString(smallerResult[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10)
setPrevPayment, success := prevPayment.SetString(smallerResult[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10)
if !success {
return nil, nil, 0, fmt.Errorf("failed to parse previous payment: %w", err)
}
prevPayment = setPrevPayment
}

// Fetch the smallest entry larger than the given cumulativePayment
Expand All @@ -205,13 +206,14 @@ func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountI
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to query the next payment for account: %w", err)
}
var nextPayment *big.Int
var nextDataLength uint32
nextPayment := big.NewInt(0)
nextDataLength := uint32(0)
if len(largerResult) > 0 {
_, success := nextPayment.SetString(largerResult[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10)
setNextPayment, success := nextPayment.SetString(largerResult[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10)
if !success {
return nil, nil, 0, fmt.Errorf("failed to parse previous payment: %w", err)
}
nextPayment = setNextPayment
dataLength, err := strconv.ParseUint(largerResult[0]["DataLength"].(*types.AttributeValueMemberN).Value, 10, 32)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to parse blob size: %w", err)
Expand Down

0 comments on commit 1a58cb6

Please sign in to comment.