Skip to content

Commit

Permalink
Move CalculateChunksForRBF() to the mempool changeset
Browse files Browse the repository at this point in the history
  • Loading branch information
sdaftuar committed Nov 13, 2024
1 parent 284a1d3 commit d7dc9fd
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 74 deletions.
8 changes: 2 additions & 6 deletions src/policy/rbf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,10 @@ std::optional<std::string> PaysForRBF(CAmount original_fees,
return std::nullopt;
}

std::optional<std::pair<DiagramCheckError, std::string>> ImprovesFeerateDiagram(CTxMemPool& pool,
const CTxMemPool::setEntries& direct_conflicts,
const CTxMemPool::setEntries& all_conflicts,
CAmount replacement_fees,
int64_t replacement_vsize)
std::optional<std::pair<DiagramCheckError, std::string>> ImprovesFeerateDiagram(CTxMemPool::ChangeSet& changeset)
{
// Require that the replacement strictly improves the mempool's feerate diagram.
const auto chunk_results{pool.CalculateChunksForRBF(replacement_fees, replacement_vsize, all_conflicts, all_conflicts)};
const auto chunk_results{changeset.CalculateChunksForRBF()};

if (!chunk_results.has_value()) {
return std::make_pair(DiagramCheckError::UNCALCULABLE, util::ErrorString(chunk_results).original);
Expand Down
14 changes: 2 additions & 12 deletions src/policy/rbf.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,9 @@ std::optional<std::string> PaysForRBF(CAmount original_fees,

/**
* The replacement transaction must improve the feerate diagram of the mempool.
* @param[in] pool The mempool.
* @param[in] direct_conflicts Set of in-mempool txids corresponding to the direct conflicts i.e.
* input double-spends with the proposed transaction
* @param[in] all_conflicts Set of mempool entries corresponding to all transactions to be evicted
* @param[in] replacement_fees Fees of proposed replacement package
* @param[in] replacement_vsize Size of proposed replacement package
* @param[in] changeset The changeset containing proposed additions/removals
* @returns error type and string if mempool diagram doesn't improve, otherwise std::nullopt.
*/
std::optional<std::pair<DiagramCheckError, std::string>> ImprovesFeerateDiagram(CTxMemPool& pool,
const CTxMemPool::setEntries& direct_conflicts,
const CTxMemPool::setEntries& all_conflicts,
CAmount replacement_fees,
int64_t replacement_vsize)
EXCLUSIVE_LOCKS_REQUIRED(pool.cs);
std::optional<std::pair<DiagramCheckError, std::string>> ImprovesFeerateDiagram(CTxMemPool::ChangeSet& changeset);

#endif // BITCOIN_POLICY_RBF_H
26 changes: 21 additions & 5 deletions src/test/fuzz/rbf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,18 @@ FUZZ_TARGET(package_rbf, .init = initialize_package_rbf)
std::vector<CTransaction> mempool_txs;
size_t iter{0};

int32_t replacement_vsize = fuzzed_data_provider.ConsumeIntegralInRange<int32_t>(1, 1000000);

// Keep track of the total vsize of CTxMemPoolEntry's being added to the mempool to avoid overflow
// Add replacement_vsize since this is added to new diagram during RBF check
std::optional<CMutableTransaction> replacement_tx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider, TX_WITH_WITNESS);
if (!replacement_tx) {
return;
}
assert(iter <= g_outpoints.size());
replacement_tx->vin.resize(1);
replacement_tx->vin[0].prevout = g_outpoints[iter++];
CTransaction replacement_tx_final{*replacement_tx};
auto replacement_entry = ConsumeTxMemPoolEntry(fuzzed_data_provider, replacement_tx_final);
int32_t replacement_vsize = replacement_entry.GetTxSize();
int64_t running_vsize_total{replacement_vsize};

LOCK2(cs_main, pool.cs);
Expand Down Expand Up @@ -162,9 +170,17 @@ FUZZ_TARGET(package_rbf, .init = initialize_package_rbf)
pool.CalculateDescendants(txiter, all_conflicts);
}

// Calculate the chunks for a replacement.
CAmount replacement_fees = ConsumeMoney(fuzzed_data_provider);
auto calc_results{pool.CalculateChunksForRBF(replacement_fees, replacement_vsize, direct_conflicts, all_conflicts)};
auto changeset = pool.GetChangeSet();
for (auto& txiter : all_conflicts) {
changeset->StageRemoval(txiter);
}
changeset->StageAddition(replacement_entry.GetSharedTx(), replacement_fees,
replacement_entry.GetTime().count(), replacement_entry.GetHeight(),
replacement_entry.GetSequence(), replacement_entry.GetSpendsCoinbase(),
replacement_entry.GetSigOpCost(), replacement_entry.GetLockPoints());
// Calculate the chunks for a replacement.
auto calc_results{changeset->CalculateChunksForRBF()};

if (calc_results.has_value()) {
// Sanity checks on the chunks.
Expand Down Expand Up @@ -192,7 +208,7 @@ FUZZ_TARGET(package_rbf, .init = initialize_package_rbf)
}

// If internals report error, wrapper should too
auto err_tuple{ImprovesFeerateDiagram(pool, direct_conflicts, all_conflicts, replacement_fees, replacement_vsize)};
auto err_tuple{ImprovesFeerateDiagram(*changeset)};
if (!calc_results.has_value()) {
assert(err_tuple.value().first == DiagramCheckError::UNCALCULABLE);
} else {
Expand Down
129 changes: 103 additions & 26 deletions src/test/rbf_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,44 +369,79 @@ BOOST_FIXTURE_TEST_CASE(improves_feerate, TestChain100Setup)

const auto entry1 = pool.GetIter(tx1->GetHash()).value();
const auto tx1_fee = entry1->GetModifiedFee();
const auto tx1_size = entry1->GetTxSize();
const auto entry2 = pool.GetIter(tx2->GetHash()).value();
const auto tx2_fee = entry2->GetModifiedFee();
const auto tx2_size = entry2->GetTxSize();

// conflicting transactions
const auto tx1_conflict = make_tx(/*inputs=*/ {m_coinbase_txns[0], m_coinbase_txns[2]}, /*output_values=*/ {10 * COIN});
const auto tx3 = make_tx(/*inputs=*/ {tx1_conflict}, /*output_values=*/ {995 * CENT});
auto entry3 = entry.FromTx(tx3);

// Now test ImprovesFeerateDiagram with various levels of "package rbf" feerates

// It doesn't improve itself
const auto res1 = ImprovesFeerateDiagram(pool, {entry1}, {entry1, entry2}, tx1_fee + tx2_fee, tx1_size + tx2_size);
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(entry1);
changeset->StageRemoval(entry2);
changeset->StageAddition(tx1_conflict, tx1_fee, 0, 1, 0, false, 4, LockPoints());
changeset->StageAddition(tx3, tx2_fee, 0, 1, 0, false, 4, LockPoints());
const auto res1 = ImprovesFeerateDiagram(*changeset);
BOOST_CHECK(res1.has_value());
BOOST_CHECK(res1.value().first == DiagramCheckError::FAILURE);
BOOST_CHECK(res1.value().second == "insufficient feerate: does not improve feerate diagram");

// With one more satoshi it does
BOOST_CHECK(ImprovesFeerateDiagram(pool, {entry1}, {entry1, entry2}, tx1_fee + tx2_fee + 1, tx1_size + tx2_size) == std::nullopt);

changeset.reset();
changeset = pool.GetChangeSet();
changeset->StageRemoval(entry1);
changeset->StageRemoval(entry2);
changeset->StageAddition(tx1_conflict, tx1_fee+1, 0, 1, 0, false, 4, LockPoints());
changeset->StageAddition(tx3, tx2_fee, 0, 1, 0, false, 4, LockPoints());
BOOST_CHECK(ImprovesFeerateDiagram(*changeset) == std::nullopt);

changeset.reset();
// With prioritisation of in-mempool conflicts, it affects the results of the comparison using the same args as just above
pool.PrioritiseTransaction(entry1->GetSharedTx()->GetHash(), /*nFeeDelta=*/1);
const auto res2 = ImprovesFeerateDiagram(pool, {entry1}, {entry1, entry2}, tx1_fee + tx2_fee + 1, tx1_size + tx2_size);
changeset = pool.GetChangeSet();
changeset->StageRemoval(entry1);
changeset->StageRemoval(entry2);
changeset->StageAddition(tx1_conflict, tx1_fee+1, 0, 1, 0, false, 4, LockPoints());
changeset->StageAddition(tx3, tx2_fee, 0, 1, 0, false, 4, LockPoints());
const auto res2 = ImprovesFeerateDiagram(*changeset);
BOOST_CHECK(res2.has_value());
BOOST_CHECK(res2.value().first == DiagramCheckError::FAILURE);
BOOST_CHECK(res2.value().second == "insufficient feerate: does not improve feerate diagram");
changeset.reset();

pool.PrioritiseTransaction(entry1->GetSharedTx()->GetHash(), /*nFeeDelta=*/-1);

// With one less vB it does
BOOST_CHECK(ImprovesFeerateDiagram(pool, {entry1}, {entry1, entry2}, tx1_fee + tx2_fee, tx1_size + tx2_size - 1) == std::nullopt);
// With fewer vbytes it does
CMutableTransaction tx4{entry3.GetTx()};
tx4.vin[0].scriptWitness = CScriptWitness(); // Clear out the witness, to reduce size
auto entry4 = entry.FromTx(MakeTransactionRef(tx4));
changeset = pool.GetChangeSet();
changeset->StageRemoval(entry1);
changeset->StageRemoval(entry2);
changeset->StageAddition(tx1_conflict, tx1_fee, 0, 1, 0, false, 4, LockPoints());
changeset->StageAddition(entry4.GetSharedTx(), tx2_fee, 0, 1, 0, false, 4, LockPoints());
BOOST_CHECK(ImprovesFeerateDiagram(*changeset) == std::nullopt);
changeset.reset();

// Adding a grandchild makes the cluster size 3, which is uncalculable
const auto tx3 = make_tx(/*inputs=*/ {tx2}, /*output_values=*/ {995 * CENT});
AddToMempool(pool, entry.Fee(normal_fee).FromTx(tx3));
const auto res3 = ImprovesFeerateDiagram(pool, {entry1}, {entry1, entry2}, tx1_fee + tx2_fee + 1, tx1_size + tx2_size);
const auto tx5 = make_tx(/*inputs=*/ {tx2}, /*output_values=*/ {995 * CENT});
AddToMempool(pool, entry.Fee(normal_fee).FromTx(tx5));
const auto entry5 = pool.GetIter(tx5->GetHash()).value();

changeset = pool.GetChangeSet();
changeset->StageRemoval(entry1);
changeset->StageRemoval(entry2);
changeset->StageRemoval(entry5);
changeset->StageAddition(tx1_conflict, tx1_fee, 0, 1, 0, false, 4, LockPoints());
changeset->StageAddition(entry4.GetSharedTx(), tx2_fee + entry5->GetModifiedFee() + 1, 0, 1, 0, false, 4, LockPoints());
const auto res3 = ImprovesFeerateDiagram(*changeset);
BOOST_CHECK(res3.has_value());
BOOST_CHECK(res3.value().first == DiagramCheckError::UNCALCULABLE);
BOOST_CHECK(res3.value().second == strprintf("%s has both ancestor and descendant, exceeding cluster limit of 2", tx2->GetHash().GetHex()));

BOOST_CHECK(res3.value().second == strprintf("%s has 2 ancestors, max 1 allowed", tx5->GetHash().GetHex()));
}

BOOST_FIXTURE_TEST_CASE(calc_feerate_diagram_rbf, TestChain100Setup)
Expand All @@ -427,19 +462,28 @@ BOOST_FIXTURE_TEST_CASE(calc_feerate_diagram_rbf, TestChain100Setup)
const auto entry_low = pool.GetIter(low_tx->GetHash()).value();
const auto low_size = entry_low->GetTxSize();

const auto replacement_tx = make_tx(/*inputs=*/ {m_coinbase_txns[0]}, /*output_values=*/ {9 * COIN});
auto entry_replacement = entry.FromTx(replacement_tx);

// Replacement of size 1
{
const auto replace_one{pool.CalculateChunksForRBF(/*replacement_fees=*/0, /*replacement_vsize=*/1, {entry_low}, {entry_low})};
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(entry_low);
changeset->StageAddition(replacement_tx, 0, 0, 1, 0, false, 4, LockPoints());
const auto replace_one{changeset->CalculateChunksForRBF()};
BOOST_CHECK(replace_one.has_value());
std::vector<FeeFrac> expected_old_chunks{{low_fee, low_size}};
BOOST_CHECK(replace_one->first == expected_old_chunks);
std::vector<FeeFrac> expected_new_chunks{{0, 1}};
std::vector<FeeFrac> expected_new_chunks{{0, int32_t(entry_replacement.GetTxSize())}};
BOOST_CHECK(replace_one->second == expected_new_chunks);
}

// Non-zero replacement fee/size
{
const auto replace_one_fee{pool.CalculateChunksForRBF(/*replacement_fees=*/high_fee, /*replacement_vsize=*/low_size, {entry_low}, {entry_low})};
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(entry_low);
changeset->StageAddition(replacement_tx, high_fee, 0, 1, 0, false, 4, LockPoints());
const auto replace_one_fee{changeset->CalculateChunksForRBF()};
BOOST_CHECK(replace_one_fee.has_value());
std::vector<FeeFrac> expected_old_diagram{{low_fee, low_size}};
BOOST_CHECK(replace_one_fee->first == expected_old_diagram);
Expand All @@ -454,7 +498,11 @@ BOOST_FIXTURE_TEST_CASE(calc_feerate_diagram_rbf, TestChain100Setup)
const auto high_size = entry_high->GetTxSize();

{
const auto replace_single_chunk{pool.CalculateChunksForRBF(/*replacement_fees=*/high_fee, /*replacement_vsize=*/low_size, {entry_low}, {entry_low, entry_high})};
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(entry_low);
changeset->StageRemoval(entry_high);
changeset->StageAddition(replacement_tx, high_fee, 0, 1, 0, false, 4, LockPoints());
const auto replace_single_chunk{changeset->CalculateChunksForRBF()};
BOOST_CHECK(replace_single_chunk.has_value());
std::vector<FeeFrac> expected_old_chunks{{low_fee + high_fee, low_size + high_size}};
BOOST_CHECK(replace_single_chunk->first == expected_old_chunks);
Expand All @@ -464,7 +512,10 @@ BOOST_FIXTURE_TEST_CASE(calc_feerate_diagram_rbf, TestChain100Setup)

// Conflict with the 2nd tx, resulting in new diagram with three entries
{
const auto replace_cpfp_child{pool.CalculateChunksForRBF(/*replacement_fees=*/high_fee, /*replacement_vsize=*/low_size, {entry_high}, {entry_high})};
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(entry_high);
changeset->StageAddition(replacement_tx, high_fee, 0, 1, 0, false, 4, LockPoints());
const auto replace_cpfp_child{changeset->CalculateChunksForRBF()};
BOOST_CHECK(replace_cpfp_child.has_value());
std::vector<FeeFrac> expected_old_chunks{{low_fee + high_fee, low_size + high_size}};
BOOST_CHECK(replace_cpfp_child->first == expected_old_chunks);
Expand All @@ -476,12 +527,16 @@ BOOST_FIXTURE_TEST_CASE(calc_feerate_diagram_rbf, TestChain100Setup)
const auto normal_tx = make_tx(/*inputs=*/ {high_tx}, /*output_values=*/ {995 * CENT});
AddToMempool(pool, entry.Fee(normal_fee).FromTx(normal_tx));
const auto entry_normal = pool.GetIter(normal_tx->GetHash()).value();
const auto normal_size = entry_normal->GetTxSize();

{
const auto replace_too_large{pool.CalculateChunksForRBF(/*replacement_fees=*/normal_fee, /*replacement_vsize=*/normal_size, {entry_low}, {entry_low, entry_high, entry_normal})};
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(entry_low);
changeset->StageRemoval(entry_high);
changeset->StageRemoval(entry_normal);
changeset->StageAddition(replacement_tx, high_fee, 0, 1, 0, false, 4, LockPoints());
const auto replace_too_large{changeset->CalculateChunksForRBF()};
BOOST_CHECK(!replace_too_large.has_value());
BOOST_CHECK_EQUAL(util::ErrorString(replace_too_large).original, strprintf("%s has 2 descendants, max 1 allowed", low_tx->GetHash().GetHex()));
BOOST_CHECK_EQUAL(util::ErrorString(replace_too_large).original, strprintf("%s has 2 ancestors, max 1 allowed", normal_tx->GetHash().GetHex()));
}

// Make a size 2 cluster that is itself two chunks; evict both txns
Expand All @@ -496,7 +551,11 @@ BOOST_FIXTURE_TEST_CASE(calc_feerate_diagram_rbf, TestChain100Setup)
const auto low_size_2 = entry_low_2->GetTxSize();

{
const auto replace_two_chunks_single_cluster{pool.CalculateChunksForRBF(/*replacement_fees=*/high_fee, /*replacement_vsize=*/low_size, {entry_high_2}, {entry_high_2, entry_low_2})};
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(entry_high_2);
changeset->StageRemoval(entry_low_2);
changeset->StageAddition(replacement_tx, high_fee, 0, 1, 0, false, 4, LockPoints());
const auto replace_two_chunks_single_cluster{changeset->CalculateChunksForRBF()};
BOOST_CHECK(replace_two_chunks_single_cluster.has_value());
std::vector<FeeFrac> expected_old_chunks{{high_fee, high_size_2}, {low_fee, low_size_2}};
BOOST_CHECK(replace_two_chunks_single_cluster->first == expected_old_chunks);
Expand All @@ -518,7 +577,12 @@ BOOST_FIXTURE_TEST_CASE(calc_feerate_diagram_rbf, TestChain100Setup)
const auto conflict_3_entry = pool.GetIter(conflict_3->GetHash()).value();

{
const auto replace_multiple_clusters{pool.CalculateChunksForRBF(/*replacement_fees=*/high_fee, /*replacement_vsize=*/low_size, {conflict_1_entry, conflict_2_entry, conflict_3_entry}, {conflict_1_entry, conflict_2_entry, conflict_3_entry})};
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(conflict_1_entry);
changeset->StageRemoval(conflict_2_entry);
changeset->StageRemoval(conflict_3_entry);
changeset->StageAddition(replacement_tx, high_fee, 0, 1, 0, false, 4, LockPoints());
const auto replace_multiple_clusters{changeset->CalculateChunksForRBF()};
BOOST_CHECK(replace_multiple_clusters.has_value());
BOOST_CHECK(replace_multiple_clusters->first.size() == 3);
BOOST_CHECK(replace_multiple_clusters->second.size() == 1);
Expand All @@ -530,7 +594,13 @@ BOOST_FIXTURE_TEST_CASE(calc_feerate_diagram_rbf, TestChain100Setup)
const auto conflict_1_child_entry = pool.GetIter(conflict_1_child->GetHash()).value();

{
const auto replace_multiple_clusters_2{pool.CalculateChunksForRBF(/*replacement_fees=*/high_fee, /*replacement_vsize=*/low_size, {conflict_1_entry, conflict_2_entry, conflict_3_entry}, {conflict_1_entry, conflict_2_entry, conflict_3_entry, conflict_1_child_entry})};
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(conflict_1_entry);
changeset->StageRemoval(conflict_2_entry);
changeset->StageRemoval(conflict_3_entry);
changeset->StageRemoval(conflict_1_child_entry);
changeset->StageAddition(replacement_tx, high_fee, 0, 1, 0, false, 4, LockPoints());
const auto replace_multiple_clusters_2{changeset->CalculateChunksForRBF()};

BOOST_CHECK(replace_multiple_clusters_2.has_value());
BOOST_CHECK(replace_multiple_clusters_2->first.size() == 4);
Expand All @@ -543,10 +613,17 @@ BOOST_FIXTURE_TEST_CASE(calc_feerate_diagram_rbf, TestChain100Setup)
const auto conflict_1_grand_child_entry = pool.GetIter(conflict_1_child->GetHash()).value();

{
const auto replace_cluster_size_3{pool.CalculateChunksForRBF(/*replacement_fees=*/high_fee, /*replacement_vsize=*/low_size, {conflict_1_entry, conflict_2_entry, conflict_3_entry}, {conflict_1_entry, conflict_2_entry, conflict_3_entry, conflict_1_child_entry, conflict_1_grand_child_entry})};
auto changeset = pool.GetChangeSet();
changeset->StageRemoval(conflict_1_entry);
changeset->StageRemoval(conflict_2_entry);
changeset->StageRemoval(conflict_3_entry);
changeset->StageRemoval(conflict_1_child_entry);
changeset->StageRemoval(conflict_1_grand_child_entry);
changeset->StageAddition(replacement_tx, high_fee, 0, 1, 0, false, 4, LockPoints());
const auto replace_cluster_size_3{changeset->CalculateChunksForRBF()};

BOOST_CHECK(!replace_cluster_size_3.has_value());
BOOST_CHECK_EQUAL(util::ErrorString(replace_cluster_size_3).original, strprintf("%s has 2 descendants, max 1 allowed", conflict_1->GetHash().GetHex()));
BOOST_CHECK_EQUAL(util::ErrorString(replace_cluster_size_3).original, strprintf("%s has both ancestor and descendant, exceeding cluster limit of 2", conflict_1_child->GetHash().GetHex()));
}
}

Expand Down
Loading

0 comments on commit d7dc9fd

Please sign in to comment.