Skip to content

Commit

Permalink
refactor: remove start/count in bls-worker funcs due to spanification (
Browse files Browse the repository at this point in the history
…#5599)

## Issue being fixed or feature implemented
Follow-up changes for this PR:
#5586


## What was done?
Span has already "pointer + start + length", extra start/count variables
in function signatures are just duplicates.


## How Has This Been Tested?
Run unit/functional tests

## Breaking Changes
N/A


## Checklist:
- [x] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have added or updated relevant unit/integration/functional/e2e
tests
- [ ] I have made corresponding changes to the documentation
- [x] I have assigned this pull request to a milestone
  • Loading branch information
knst authored Oct 5, 2023
1 parent c814dca commit 1c66ac3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 99 deletions.
2 changes: 1 addition & 1 deletion src/bench/bls_dkg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class DKG
ReceiveVvecs();

bench.minEpochIterations(epoch_iters).run([&] {
quorumVvec = blsWorker.BuildQuorumVerificationVector(receivedVvecs, 0, 0, false);
quorumVvec = blsWorker.BuildQuorumVerificationVector(receivedVvecs, false);
});
}

Expand Down
123 changes: 45 additions & 78 deletions src/bls/bls_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
#include <utility>

template <typename T>
bool VerifyVectorHelper(Span<T> vec, size_t start, size_t count)
bool VerifyVectorHelper(Span<T> vec)
{
if (start == 0 && count == 0) {
count = vec.size();
}
std::set<uint256> set;
for (size_t i = start; i < start + count; i++) {
if (!vec[i].IsValid())
for (auto item : vec) {
if (!item.IsValid())
return false;
// check duplicates
if (!set.emplace(vec[i].GetHash()).second) {
if (!set.emplace(item.GetHash()).second) {
return false;
}
}
Expand Down Expand Up @@ -147,18 +144,16 @@ struct Aggregator : public std::enable_shared_from_this<Aggregator<T>> {

// TP can either be a pointer or a reference
template <typename TP>
Aggregator(Span<TP> _inputVec,
size_t start, size_t count,
bool _parallel,
Aggregator(Span<TP> _inputSpan, bool _parallel,
ctpl::thread_pool& _workerPool,
DoneCallback _doneCallback) :
inputVec(std::make_shared<std::vector<const T*>>(count)),
inputVec(std::make_shared<std::vector<const T*>>(_inputSpan.size())),
parallel(_parallel),
workerPool(_workerPool),
doneCallback(std::move(_doneCallback))
{
for (size_t i = 0; i < count; i++) {
(*inputVec)[i] = pointer(_inputVec[start + i]);
for (size_t i = 0; i < _inputSpan.size(); i++) {
(*inputVec)[i] = pointer(_inputSpan[i]);
}
}

Expand Down Expand Up @@ -341,8 +336,6 @@ struct VectorAggregator : public std::enable_shared_from_this<VectorAggregator<T
DoneCallback doneCallback;

VectorVectorType vecs;
size_t start;
size_t count;
bool parallel;
ctpl::thread_pool& workerPool;

Expand All @@ -352,13 +345,10 @@ struct VectorAggregator : public std::enable_shared_from_this<VectorAggregator<T
size_t vecSize;

VectorAggregator(VectorVectorType _vecs,
size_t _start, size_t _count,
bool _parallel, ctpl::thread_pool& _workerPool,
DoneCallback _doneCallback) :
doneCallback(std::move(_doneCallback)),
vecs(_vecs),
start(_start),
count(_count),
parallel(_parallel),
workerPool(_workerPool)
{
Expand All @@ -370,13 +360,13 @@ struct VectorAggregator : public std::enable_shared_from_this<VectorAggregator<T
void Start()
{
for (size_t i = 0; i < vecSize; i++) {
std::vector<const T*> tmp(count);
for (size_t j = 0; j < count; j++) {
tmp[j] = &(*vecs[start + j])[i];
std::vector<const T*> tmp(vecs.size());
for (size_t j = 0; j < vecs.size(); j++) {
tmp[j] = &(*vecs[j])[i];
}

auto self(this->shared_from_this());
auto aggregator = std::make_shared<AggregatorType>(Span{tmp}, 0, count, parallel, workerPool, [self, i](const T& agg) {self->CheckDone(agg, i);});
auto aggregator = std::make_shared<AggregatorType>(Span{tmp}, parallel, workerPool, [self, i](const T& agg) {self->CheckDone(agg, i);});
aggregator->Start();
}
}
Expand Down Expand Up @@ -492,8 +482,8 @@ struct ContributionVerifier : public std::enable_shared_from_this<ContributionVe

// aggregate vvecs and skShares of batch in parallel
auto self(this->shared_from_this());
auto vvecAgg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs, batchState.start, batchState.count, parallel, workerPool, [this, self, batchIdx] (const BLSVerificationVectorPtr& vvec) {HandleAggVvecDone(batchIdx, vvec);});
auto skShareAgg = std::make_shared<Aggregator<CBLSSecretKey>>(Span{skShares}, batchState.start, batchState.count, parallel, workerPool, [this, self, batchIdx] (const CBLSSecretKey& skShare) {HandleAggSkShareDone(batchIdx, skShare);});
auto vvecAgg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs.subspan(batchState.start, batchState.count), parallel, workerPool, [this, self, batchIdx] (const BLSVerificationVectorPtr& vvec) {HandleAggVvecDone(batchIdx, vvec);});
auto skShareAgg = std::make_shared<Aggregator<CBLSSecretKey>>(Span{skShares}.subspan(batchState.start, batchState.count), parallel, workerPool, [this, self, batchIdx] (const CBLSSecretKey& skShare) {HandleAggSkShareDone(batchIdx, skShare);});

vvecAgg->Start();
skShareAgg->Start();
Expand Down Expand Up @@ -594,109 +584,92 @@ struct ContributionVerifier : public std::enable_shared_from_this<ContributionVe
}
};

void CBLSWorker::AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs,
size_t start, size_t count, bool parallel,
void CBLSWorker::AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel,
std::function<void(const BLSVerificationVectorPtr&)> doneCallback)
{
if (start == 0 && count == 0) {
count = vvecs.size();
}
if (vvecs.empty() || count == 0 || start > vvecs.size() || start + count > vvecs.size()) {
if (vvecs.empty()) {
doneCallback(nullptr);
return;
}
if (!VerifyVerificationVectors(vvecs, start, count)) {
if (!VerifyVerificationVectors(vvecs)) {
doneCallback(nullptr);
return;
}

auto agg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs, start, count, parallel, workerPool, std::move(doneCallback));
auto agg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs, parallel, workerPool, std::move(doneCallback));
agg->Start();
}

std::future<BLSVerificationVectorPtr> CBLSWorker::AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs,
size_t start, size_t count, bool parallel)
std::future<BLSVerificationVectorPtr> CBLSWorker::AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel)
{
auto p = BuildFutureDoneCallback<BLSVerificationVectorPtr>();
AsyncBuildQuorumVerificationVector(vvecs, start, count, parallel, std::move(p.first));
AsyncBuildQuorumVerificationVector(vvecs, parallel, std::move(p.first));
return std::move(p.second);
}

BLSVerificationVectorPtr CBLSWorker::BuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs,
size_t start, size_t count, bool parallel)
BLSVerificationVectorPtr CBLSWorker::BuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel)
{
return AsyncBuildQuorumVerificationVector(vvecs, start, count, parallel).get();
return AsyncBuildQuorumVerificationVector(vvecs, parallel).get();
}

template <typename T>
void AsyncAggregateHelper(ctpl::thread_pool& workerPool,
Span<T> vec, size_t start, size_t count, bool parallel,
void AsyncAggregateHelper(ctpl::thread_pool& workerPool, Span<T> vec, bool parallel,
std::function<void(const T&)> doneCallback)
{
if (start == 0 && count == 0) {
count = vec.size();
}
if (vec.empty() || count == 0 || start > vec.size() || start + count > vec.size()) {
if (vec.empty()) {
doneCallback(T());
return;
}
if (!VerifyVectorHelper(vec, start, count)) {
if (!VerifyVectorHelper(vec)) {
doneCallback(T());
return;
}

auto agg = std::make_shared<Aggregator<T>>(vec, start, count, parallel, workerPool, std::move(doneCallback));
auto agg = std::make_shared<Aggregator<T>>(vec, parallel, workerPool, std::move(doneCallback));
agg->Start();
}

void CBLSWorker::AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys,
size_t start, size_t count, bool parallel,
void CBLSWorker::AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel,
std::function<void(const CBLSSecretKey&)> doneCallback)
{
AsyncAggregateHelper(workerPool, secKeys, start, count, parallel, std::move(doneCallback));
AsyncAggregateHelper(workerPool, secKeys, parallel, std::move(doneCallback));
}

std::future<CBLSSecretKey> CBLSWorker::AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys,
size_t start, size_t count, bool parallel)
std::future<CBLSSecretKey> CBLSWorker::AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel)
{
auto p = BuildFutureDoneCallback<CBLSSecretKey>();
AsyncAggregateSecretKeys(secKeys, start, count, parallel, std::move(p.first));
AsyncAggregateSecretKeys(secKeys, parallel, std::move(p.first));
return std::move(p.second);
}

CBLSSecretKey CBLSWorker::AggregateSecretKeys(Span<CBLSSecretKey> secKeys,
size_t start, size_t count, bool parallel)
CBLSSecretKey CBLSWorker::AggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel)
{
return AsyncAggregateSecretKeys(secKeys, start, count, parallel).get();
return AsyncAggregateSecretKeys(secKeys, parallel).get();
}

void CBLSWorker::AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys,
size_t start, size_t count, bool parallel,
void CBLSWorker::AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys, bool parallel,
std::function<void(const CBLSPublicKey&)> doneCallback)
{
AsyncAggregateHelper(workerPool, pubKeys, start, count, parallel, std::move(doneCallback));
AsyncAggregateHelper(workerPool, pubKeys, parallel, std::move(doneCallback));
}

std::future<CBLSPublicKey> CBLSWorker::AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys,
size_t start, size_t count, bool parallel)
std::future<CBLSPublicKey> CBLSWorker::AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys, bool parallel)
{
auto p = BuildFutureDoneCallback<CBLSPublicKey>();
AsyncAggregatePublicKeys(pubKeys, start, count, parallel, std::move(p.first));
AsyncAggregatePublicKeys(pubKeys, parallel, std::move(p.first));
return std::move(p.second);
}

void CBLSWorker::AsyncAggregateSigs(Span<CBLSSignature> sigs,
size_t start, size_t count, bool parallel,
void CBLSWorker::AsyncAggregateSigs(Span<CBLSSignature> sigs, bool parallel,
std::function<void(const CBLSSignature&)> doneCallback)
{
AsyncAggregateHelper(workerPool, sigs, start, count, parallel, std::move(doneCallback));
AsyncAggregateHelper(workerPool, sigs, parallel, std::move(doneCallback));
}

std::future<CBLSSignature> CBLSWorker::AsyncAggregateSigs(Span<CBLSSignature> sigs,
size_t start, size_t count, bool parallel)
std::future<CBLSSignature> CBLSWorker::AsyncAggregateSigs(Span<CBLSSignature> sigs, bool parallel)
{
auto p = BuildFutureDoneCallback<CBLSSignature>();
AsyncAggregateSigs(sigs, start, count, parallel, std::move(p.first));
AsyncAggregateSigs(sigs, parallel, std::move(p.first));
return std::move(p.second);
}

Expand Down Expand Up @@ -757,25 +730,19 @@ std::future<bool> CBLSWorker::AsyncVerifyContributionShare(const CBLSId& forId,
return workerPool.push(f);
}

bool CBLSWorker::VerifyVerificationVector(Span<CBLSPublicKey> vvec, size_t start, size_t count)
bool CBLSWorker::VerifyVerificationVector(Span<CBLSPublicKey> vvec)
{
return VerifyVectorHelper(vvec, start, count);
return VerifyVectorHelper(vvec);
}

bool CBLSWorker::VerifyVerificationVectors(Span<BLSVerificationVectorPtr> vvecs,
size_t start, size_t count)
bool CBLSWorker::VerifyVerificationVectors(Span<BLSVerificationVectorPtr> vvecs)
{
if (start == 0 && count == 0) {
count = vvecs.size();
}

std::set<uint256> set;
for (size_t i = 0; i < count; i++) {
const auto& vvec = vvecs[start + i];
for (const auto& vvec : vvecs) {
if (vvec == nullptr) {
return false;
}
if (vvec->size() != vvecs[start]->size()) {
if (vvec->size() != vvecs[0]->size()) {
return false;
}
for (size_t j = 0; j < vvec->size(); j++) {
Expand Down
32 changes: 12 additions & 20 deletions src/bls/bls_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,37 +69,29 @@ class CBLSWorker
// [ a1+a2+a3+a4, b1+b2+b3+b4, c1+c2+c3+c4, d1+d2+d3+d4]
// Multiple things can be parallelized here. For example, all 4 entries in the result vector can be calculated in parallel
// Also, each individual vector can be split into multiple batches and aggregating the batches can also be parallelized.
void AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs,
size_t start, size_t count, bool parallel,
void AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel,
std::function<void(const BLSVerificationVectorPtr&)> doneCallback);
std::future<BLSVerificationVectorPtr> AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs,
size_t start, size_t count, bool parallel);
BLSVerificationVectorPtr BuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs,
size_t start = 0, size_t count = 0, bool parallel = true);
std::future<BLSVerificationVectorPtr> AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel);
BLSVerificationVectorPtr BuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel = true);

// The following functions are all used to aggregate single vectors
// Inputs are in the following form:
// [a, b, c, d],
// The result is simply a+b+c+d
// Aggregation is parallelized by splitting up the input vector into multiple batches and then aggregating the individual batch results
void AsyncAggregateSecretKeys(Span<CBLSSecretKey>,
size_t start, size_t count, bool parallel,
bool parallel,
std::function<void(const CBLSSecretKey&)> doneCallback);
std::future<CBLSSecretKey> AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys,
size_t start, size_t count, bool parallel);
CBLSSecretKey AggregateSecretKeys(Span<CBLSSecretKey> secKeys, size_t start = 0, size_t count = 0, bool parallel = true);
std::future<CBLSSecretKey> AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel);
CBLSSecretKey AggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel = true);

void AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys,
size_t start, size_t count, bool parallel,
void AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys, bool parallel,
std::function<void(const CBLSPublicKey&)> doneCallback);
std::future<CBLSPublicKey> AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys,
size_t start, size_t count, bool parallel);
std::future<CBLSPublicKey> AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys, bool parallel);

void AsyncAggregateSigs(Span<CBLSSignature> sigs,
size_t start, size_t count, bool parallel,
void AsyncAggregateSigs(Span<CBLSSignature> sigs, bool parallel,
std::function<void(const CBLSSignature&)> doneCallback);
std::future<CBLSSignature> AsyncAggregateSigs(Span<CBLSSignature> sigs,
size_t start, size_t count, bool parallel);
std::future<CBLSSignature> AsyncAggregateSigs(Span<CBLSSignature> sigs, bool parallel);

// Calculate public key share from public key vector and id. Not parallelized
static CBLSPublicKey BuildPubKeyShare(const BLSVerificationVectorPtr& vvec, const CBLSId& id);
Expand All @@ -120,8 +112,8 @@ class CBLSWorker
std::future<bool> AsyncVerifyContributionShare(const CBLSId& forId, const BLSVerificationVectorPtr& vvec, const CBLSSecretKey& skContribution);

// Simple verification of vectors. Checks x.IsValid() for every entry and checks for duplicate entries
static bool VerifyVerificationVector(Span<CBLSPublicKey> vvec, size_t start = 0, size_t count = 0);
static bool VerifyVerificationVectors(Span<BLSVerificationVectorPtr> vvecs, size_t start = 0, size_t count = 0);
static bool VerifyVerificationVector(Span<CBLSPublicKey> vvec);
static bool VerifyVerificationVectors(Span<BLSVerificationVectorPtr> vvecs);

// Internally batched signature signing and verification
void AsyncSign(const CBLSSecretKey& secKey, const uint256& msgHash, const SignDoneCallback& doneCallback);
Expand Down

0 comments on commit 1c66ac3

Please sign in to comment.