Skip to content

Commit

Permalink
harmonize insert/emplace on InsertOnlyConcurrent* collections
Browse files Browse the repository at this point in the history
Summary:
Besides helping with the next diff, this removes the need for look-ups immediately after inserting an element.

This is a behavior-preserving change.

Reviewed By: agampe

Differential Revision: D51519942

fbshipit-source-id: 08e64180aa037d54f3f1bcc40bd98f65cfff72ec
  • Loading branch information
Nikolai Tillmann authored and facebook-github-bot committed Nov 22, 2023
1 parent a30308c commit 94e97e5
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 53 deletions.
89 changes: 58 additions & 31 deletions libredex/ConcurrentContainers.h
Original file line number Diff line number Diff line change
Expand Up @@ -1562,24 +1562,45 @@ class InsertOnlyConcurrentMap final
const Value& at_unsafe(const Key& key) const { return at(key); }

/*
* The Boolean return value denotes whether the insertion took place.
* This operation is always thread-safe.
*
* Note that while the STL containers' insert() methods return both an
* iterator and a boolean success value, we only return the boolean value
* here as any operations on a returned iterator are not guaranteed to be
* thread-safe.
* Returns a pair consisting of a pointer on the inserted element (or the
* element that prevented the insertion) and a boolean denoting whether the
* insertion took place. This operation is always thread-safe.
*/
bool insert(const KeyValuePair& entry) {
std::pair<const Value*, bool> insert(const KeyValuePair& entry) {
size_t slot = Hash()(entry.first) % n_slots;
auto& map = this->get_container(slot);
return map.try_insert(entry).success;
auto insertion_result = map.try_insert(entry);
return std::make_pair(&insertion_result.stored_value_ptr->second,
insertion_result.success);
}

bool insert(KeyValuePair&& entry) {
/*
* Returns a pair consisting of a pointer on the inserted element (or the
* element that prevented the insertion) and a boolean denoting whether the
* insertion took place. This operation is always thread-safe.
*/
std::pair<const Value*, bool> insert(KeyValuePair&& entry) {
size_t slot = Hash()(entry.first) % n_slots;
auto& map = this->get_container(slot);
return map.try_insert(std::forward<KeyValuePair>(entry)).success;
auto insertion_result = map.try_insert(std::move(entry));
return std::make_pair(&insertion_result.stored_value_ptr->second,
insertion_result.success);
}

std::pair<Value*, bool> insert_unsafe(const KeyValuePair& entry) {
size_t slot = Hash()(entry.first) % n_slots;
auto& map = this->get_container(slot);
auto insertion_result = map.try_insert(entry);
return std::make_pair(&insertion_result.stored_value_ptr->second,
insertion_result.success);
}

std::pair<Value*, bool> insert_unsafe(KeyValuePair entry) {
size_t slot = Hash()(entry.first) % n_slots;
auto& map = this->get_container(slot);
auto insertion_result = map.try_insert(std::move(entry));
return std::make_pair(&insertion_result.stored_value_ptr->second,
insertion_result.success);
}

/*
Expand All @@ -1601,28 +1622,12 @@ class InsertOnlyConcurrentMap final
}
}

void insert_or_assign_unsafe(const KeyValuePair& entry) {
size_t slot = Hash()(entry.first) % n_slots;
auto& map = this->get_container(slot);
auto insertion_result = map.try_emplace(entry);
if (insertion_result.success) {
return;
}
auto* constructed_value = insertion_result.incidentally_constructed_value();
if (constructed_value) {
insertion_result.stored_value_ptr->second =
std::move(constructed_value->second);
} else {
insertion_result.stored_value_ptr->second = entry.second;
}
}

void insert_or_assign_unsafe(KeyValuePair&& entry) {
std::pair<Value*, bool> insert_or_assign_unsafe(KeyValuePair&& entry) {
size_t slot = Hash()(entry.first) % n_slots;
auto& map = this->get_container(slot);
auto insertion_result = map.try_emplace(std::forward<KeyValuePair>(entry));
if (insertion_result.success) {
return;
return std::make_pair(&insertion_result.stored_value_ptr->second, true);
}
auto* constructed_value = insertion_result.incidentally_constructed_value();
if (constructed_value) {
Expand All @@ -1632,17 +1637,20 @@ class InsertOnlyConcurrentMap final
insertion_result.stored_value_ptr->second =
std::forward<Value>(entry.second);
}
return std::make_pair(&insertion_result.stored_value_ptr->second, false);
}

/*
* This operation is always thread-safe.
*/
template <typename... Args>
bool emplace(Args&&... args) {
std::pair<const Value*, bool> emplace(Args&&... args) {
KeyValuePair entry(std::forward<Args>(args)...);
size_t slot = Hash()(entry.first) % n_slots;
auto& map = this->get_container(slot);
return map.try_insert(std::move(entry)).success;
auto insertion_result = map.try_insert(std::move(entry));
return std::make_pair(&insertion_result.stored_value_ptr->second,
insertion_result.success);
}

template <typename... Args>
Expand Down Expand Up @@ -2080,13 +2088,32 @@ class InsertOnlyConcurrentSet final
return {insertion_result.stored_value_ptr, insertion_result.success};
}

/*
* Returns a pair consisting of a pointer on the inserted element (or the
* element that prevented the insertion) and a boolean denoting whether the
* insertion took place. This operation is always thread-safe.
*/
std::pair<const Key*, bool> insert(Key&& key) {
size_t slot = Hash()(key) % n_slots;
auto& set = this->get_container(slot);
auto insertion_result = set.try_insert(std::forward<Key>(key));
return {insertion_result.stored_value_ptr, insertion_result.success};
}

std::pair<Key*, bool> insert_unsafe(const Key& key) {
size_t slot = Hash()(key) % n_slots;
auto& set = this->get_container(slot);
auto insertion_result = set.try_insert(key);
return {insertion_result.stored_value_ptr, insertion_result.success};
}

std::pair<Key*, bool> insert_unsafe(Key&& key) {
size_t slot = Hash()(key) % n_slots;
auto& set = this->get_container(slot);
auto insertion_result = set.try_insert(std::forward<Key>(key));
return {insertion_result.stored_value_ptr, insertion_result.success};
}

/*
* Return a pointer on the element, or `nullptr` if the element is not in the
* set. This operation is always thread-safe.
Expand Down
12 changes: 6 additions & 6 deletions libredex/KeepReason.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ namespace {

// Lint will complain about this, but it is better than having to
// forward-declare all of concurrent containers.
std::unique_ptr<InsertOnlyConcurrentMap<keep_reason::Reason*,
keep_reason::Reason*,
std::unique_ptr<InsertOnlyConcurrentSet<keep_reason::Reason*,
keep_reason::ReasonPtrHash,
keep_reason::ReasonPtrEqual>>
s_keep_reasons{nullptr};
Expand All @@ -66,17 +65,18 @@ bool Reason::s_record_keep_reasons = false;
void Reason::set_record_keep_reasons(bool v) {
s_record_keep_reasons = v;
if (v && s_keep_reasons == nullptr) {
s_keep_reasons = std::make_unique<InsertOnlyConcurrentMap<
keep_reason::Reason*, keep_reason::Reason*, keep_reason::ReasonPtrHash,
s_keep_reasons = std::make_unique<InsertOnlyConcurrentSet<
keep_reason::Reason*, keep_reason::ReasonPtrHash,
keep_reason::ReasonPtrEqual>>();
}
}

Reason* Reason::try_insert(std::unique_ptr<Reason> to_insert) {
if (s_keep_reasons->emplace(to_insert.get(), to_insert.get())) {
auto [reason_ptr, emplaced] = s_keep_reasons->insert(to_insert.get());
if (emplaced) {
return to_insert.release();
}
return s_keep_reasons->at(to_insert.get());
return const_cast<Reason*>(*reason_ptr);
}

void Reason::release_keep_reasons() { s_keep_reasons.reset(); }
Expand Down
31 changes: 16 additions & 15 deletions libredex/MethodOverrideGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class GraphBuilder {
}

private:
ClassSignatureMap analyze_non_interface(const DexClass* cls) {
const ClassSignatureMap& analyze_non_interface(const DexClass* cls) {
always_assert(!is_interface(cls));
auto* res = m_class_signature_maps.get(cls);
if (res) {
Expand Down Expand Up @@ -155,7 +155,9 @@ class GraphBuilder {
&class_signatures.unimplemented);
}

if (m_class_signature_maps.emplace(cls, class_signatures)) {
auto [map_ptr, emplaced] =
m_class_signature_maps.emplace(cls, class_signatures);
if (emplaced) {
// Mark all overriding methods as reachable via their parent method ref.
for (auto* method : cls->get_vmethods()) {
const auto& overridden_set =
Expand All @@ -181,12 +183,12 @@ class GraphBuilder {
}
}
}
return class_signatures;
}
return m_class_signature_maps.at(cls);

return *map_ptr;
}

SignatureMap analyze_interface(const DexClass* cls) {
const SignatureMap& analyze_interface(const DexClass* cls) {
always_assert(is_interface(cls));
auto* res = m_interface_signature_maps.get(cls);
if (res) {
Expand All @@ -199,7 +201,9 @@ class GraphBuilder {
update_signature_map(method, MethodSet{method}, &interface_signatures);
}

if (m_interface_signature_maps.emplace(cls, interface_signatures)) {
auto [map_ptr, emplaced] =
m_interface_signature_maps.emplace(cls, interface_signatures);
if (emplaced) {
for (auto* method : cls->get_vmethods()) {
const auto& overridden_set =
inherited_interface_signatures.at(method->get_name())
Expand All @@ -226,13 +230,12 @@ class GraphBuilder {
method, /* overriding_is_interface */ true);
}
}

return interface_signatures;
}
return m_interface_signature_maps.at(cls);

return *map_ptr;
}

SignatureMap unify_super_interface_signatures(const DexClass* cls) {
const SignatureMap& unify_super_interface_signatures(const DexClass* cls) {
auto* type_list = cls->get_interfaces();
auto* res = m_unified_interfaces_signature_maps.get(type_list);
if (res) {
Expand All @@ -248,11 +251,9 @@ class GraphBuilder {
}
}

if (m_unified_interfaces_signature_maps.emplace(
type_list, super_interface_signatures)) {
return super_interface_signatures;
}
return m_unified_interfaces_signature_maps.at(type_list);
auto [map_ptr, _] = m_unified_interfaces_signature_maps.emplace(
type_list, super_interface_signatures);
return *map_ptr;
}

std::unique_ptr<Graph> m_graph;
Expand Down
3 changes: 2 additions & 1 deletion libredex/Purity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,8 @@ get_wto_successors(
}
}
}
auto emplaced = concurrent_cache->emplace(m, std::move(successors));
auto [_, emplaced] =
concurrent_cache->emplace(m, std::move(successors));
always_assert(emplaced);
},
wto_nodes);
Expand Down

0 comments on commit 94e97e5

Please sign in to comment.