Skip to content

Commit

Permalink
Add a base class for abstract maps
Browse files Browse the repository at this point in the history
Summary:
We now have multiple implementations of "map" containers.
We want to create a generic adapter that gives a partition abstract domain and environment abstract domain for any map.
First, let's create a base class for abstract maps that enforces a common API.

Note that map APIs can have slight differences if the map is a mutable structure (hash map, flat map) or an immutable structure (patricia tree).
This works by enforcing the map to declare it's mutability and using `if constexpr` to check for the different interfaces.

The alternative would be to enforce a common interface, but that is a costly refactoring.

Reviewed By: arnaudvenet

Differential Revision: D50019979

fbshipit-source-id: 0a22e23a85111f2e3bfd7152b5c605ecbf49fe09
  • Loading branch information
arthaud authored and facebook-github-bot committed Oct 9, 2023
1 parent eb70b21 commit 95dcf5b
Show file tree
Hide file tree
Showing 10 changed files with 340 additions and 120 deletions.
271 changes: 271 additions & 0 deletions sparta/include/sparta/AbstractMap.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <type_traits>
#include <utility>

namespace sparta {

enum class AbstractMapMutability {
// The abstract map is immutable.
// Unary operators will have signature `Domain(const Domain&)`
// Binary operators will have signature `Domain(const Domain&, const Domain&)`
Immutable,

// The abstract map is mutable.
// Unary operators will have signature `void(Domain*)`
// Binary operators will have signature `void(Domain*, const Domain&)`
Mutable,
};

/*
* This describes the API for a generic map container.
*/
template <typename Derived>
class AbstractMap {
public:
~AbstractMap() {
// The destructor is the only method that is guaranteed to be created when a
// class template is instantiated. This is a good place to perform all the
// sanity checks on the template parameters.
static_assert(std::is_base_of<AbstractMap<Derived>, Derived>::value,
"Derived doesn't inherit from AbstractMap");
static_assert(std::is_final<Derived>::value, "Derived is not final");

using Key = typename Derived::key_type;
using Value = typename Derived::mapped_type;
using KeyValuePair = typename Derived::value_type;
using Iterator = typename Derived::iterator;

// Derived();
static_assert(std::is_default_constructible<Derived>::value,
"Derived is not default constructible");

// Derived(const Derived&);
static_assert(std::is_copy_constructible<Derived>::value,
"Derived is not copy constructible");

// Derived& operator=(const Derived&);
static_assert(std::is_copy_assignable<Derived>::value,
"Derived is not copy assignable");

// constexpr static AbstractMapMutability mutability;
static_assert(std::is_same<decltype(Derived::mutability),
const AbstractMapMutability>::value,
"Derived::mutability does not exist");

// bool empty() const;
static_assert(std::is_same<decltype(std::declval<const Derived>().empty()),
bool>::value,
"Derived::empty() does not exist");

// std::size_t size() const;
static_assert(std::is_same<decltype(std::declval<const Derived>().size()),
std::size_t>::value,
"Derived::size() does not exist");

// std::size_t max_size() const;
static_assert(
std::is_same<decltype(std::declval<const Derived>().max_size()),
std::size_t>::value,
"Derived::max_size() does not exist");

// iterator begin() const;
static_assert(std::is_same<decltype(std::declval<const Derived>().begin()),
Iterator>::value,
"Derived::begin() does not exist");

// iterator end() const;
static_assert(std::is_same<decltype(std::declval<const Derived>().end()),
Iterator>::value,
"Derived::begin() does not exist");

// const Value& at(const Key&) const;
static_assert(std::is_same<decltype(std::declval<const Derived>().at(
std::declval<const Key>())),
const Value&>::value,
"Derived::at(const Key&) does not exist");

// Derived& insert_or_assign(const Key& key, Value value);
static_assert(
std::is_same<decltype(std::declval<Derived>().insert_or_assign(
std::declval<const Key>(), std::declval<Value>())),
Derived&>::value,
"Derived::insert_or_assign(const Key&, Value) does not exist");

// Derived& remove(const Key& key);
static_assert(std::is_same<decltype(std::declval<Derived>().remove(
std::declval<const Key>())),
Derived&>::value,
"Derived::remove(const Key&) does not exist");

// void clear();
static_assert(std::is_void_v<decltype(std::declval<Derived>().clear())>,
"Derived::clear() does not exist");

// void visit(Visitor&& visitor) const;
static_assert(std::is_same<decltype(std::declval<const Derived>().visit(
std::declval<void(const KeyValuePair&)>())),
void>::value,
"Derived::visit(Visitor&&) does not exist");

// Derived& filter(Predicate&& predicate);
static_assert(
std::is_same<decltype(std::declval<Derived>().filter(
std::declval<bool(const Key&, const Value&)>())),
Derived&>::value,
"Derived::filter(Predicate&&) does not exist");

/*
* The partial order relation.
*/
// bool leq(const Derived& other) const;
static_assert(std::is_same<decltype(std::declval<const Derived>().leq(
std::declval<const Derived>())),
bool>::value,
"Derived::leq(const Derived&) does not exist");

/*
* a.equals(b) is semantically equivalent to a.leq(b) && b.leq(a).
*/
// bool equals(const Derived& other) const;
static_assert(std::is_same<decltype(std::declval<const Derived>().equals(
std::declval<const Derived>())),
bool>::value,
"Derived::equals(const Derived&) does not exist");

if constexpr (Derived::mutability == AbstractMapMutability::Immutable) {
// Derived& update(Operation&& operation, const Key& key);
static_assert(std::is_same<decltype(std::declval<Derived>().update(
std::declval<Value(const Value&)>(),
std::declval<const Key>())),
Derived&>::value,
"Derived::update(Operation&&, const Key&) does not exist");

// bool transform(MappingFunction&& f);
static_assert(std::is_same<decltype(std::declval<Derived>().transform(
std::declval<Value(const Value&)>())),
bool>::value,
"Derived::transform(MappingFunction&&) does not exist");

// Derived& union_with(CombiningFunction&& combine, const Derived& other);
static_assert(
std::is_same<decltype(std::declval<Derived>().union_with(
std::declval<Value(const Value&, const Value&)>(),
std::declval<const Derived>())),
Derived&>::value,
"Derived::union_with(CombiningFunction&&, const Derived&) does not "
"exist");

// Derived& intersection_with(CombiningFunction&& combine, const Derived&
// other);
static_assert(
std::is_same<decltype(std::declval<Derived>().intersection_with(
std::declval<Value(const Value&, const Value&)>(),
std::declval<const Derived>())),
Derived&>::value,
"Derived::intersection_with(CombiningFunction&&, const Derived&) "
"does "
"not exist");

// Derived& difference_with(CombiningFunction&& combine, const Derived&
// other);
static_assert(
std::is_same<decltype(std::declval<Derived>().difference_with(
std::declval<Value(const Value&, const Value&)>(),
std::declval<const Derived>())),
Derived&>::value,
"Derived::difference_with(CombiningFunction&&, const Derived&) does "
"not exist");
} else if constexpr (Derived::mutability ==
AbstractMapMutability::Mutable) {
// Derived& update(Operation&& operation, const Key& key);
static_assert(std::is_same<decltype(std::declval<Derived>().update(
std::declval<void(Value*)>(),
std::declval<const Key>())),
Derived&>::value,
"Derived::update(Operation&&, const Key&) does not exist");

// void transform(MappingFunction&& f);
static_assert(std::is_same<decltype(std::declval<Derived>().transform(
std::declval<void(Value*)>())),
void>::value,
"Derived::transform(MappingFunction&&) does not exist");

// Derived& union_with(CombiningFunction&& combine, const Derived& other);
static_assert(
std::is_same<decltype(std::declval<Derived>().union_with(
std::declval<void(Value*, const Value&)>(),
std::declval<const Derived>())),
Derived&>::value,
"Derived::union_with(CombiningFunction&&, const Derived&) does not "
"exist");

// Derived& intersection_with(CombiningFunction&& combine, const Derived&
// other);
static_assert(
std::is_same<decltype(std::declval<Derived>().intersection_with(
std::declval<void(Value*, const Value&)>(),
std::declval<const Derived>())),
Derived&>::value,
"Derived::intersection_with(CombiningFunction&&, const Derived&) "
"does not exist");

// Derived& difference_with(CombiningFunction&& combine, const Derived&
// other);
static_assert(
std::is_same<decltype(std::declval<Derived>().difference_with(
std::declval<void(Value*, const Value&)>(),
std::declval<const Derived>())),
Derived&>::value,
"Derived::difference_with(CombiningFunction&&, const Derived&) does "
"not exist");
}
}

/*
* Many C++ libraries default to using operator== to check for equality,
* so we define it here as an alias of equals().
*/
friend bool operator==(const Derived& self, const Derived& other) {
return self.equals(other);
}

friend bool operator!=(const Derived& self, const Derived& other) {
return !self.equals(other);
}

template <typename CombiningFunction>
Derived get_union_with(CombiningFunction&& combine,
const Derived& other) const {
// Here and below: the static_cast is required in order to instruct
// the compiler to use the copy constructor of the derived class.
Derived result(static_cast<const Derived&>(*this));
result.union_with(std::forward<CombiningFunction>(combine), other);
return result;
}

template <typename CombiningFunction>
Derived get_intersection_with(CombiningFunction&& combine,
const Derived& other) const {
Derived result(static_cast<const Derived&>(*this));
result.intersection_with(std::forward<CombiningFunction>(combine), other);
return result;
}

template <typename CombiningFunction>
Derived get_difference_with(CombiningFunction&& combine,
const Derived& other) const {
Derived result(static_cast<const Derived&>(*this));
result.difference_with(std::forward<CombiningFunction>(combine), other);
return result;
}
};

} // namespace sparta
31 changes: 18 additions & 13 deletions sparta/include/sparta/FlatMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <boost/container/flat_map.hpp>

#include <sparta/AbstractMap.h>
#include <sparta/AbstractMapValue.h>
#include <sparta/PatriciaTreeCore.h>

Expand All @@ -33,7 +34,12 @@ template <typename Key,
typename KeyEqual = std::equal_to<Key>,
typename AllocatorOrContainer =
boost::container::new_allocator<std::pair<Key, Value>>>
class FlatMap final {
class FlatMap final : public AbstractMap<FlatMap<Key,
Value,
ValueInterface,
KeyCompare,
KeyEqual,
AllocatorOrContainer>> {
private:
using BoostFlatMap =
boost::container::flat_map<Key, Value, KeyCompare, AllocatorOrContainer>;
Expand All @@ -50,6 +56,9 @@ class FlatMap final {
using const_reference = typename BoostFlatMap::const_reference;
using const_pointer = typename BoostFlatMap::const_pointer;

constexpr static AbstractMapMutability mutability =
AbstractMapMutability::Mutable;

~FlatMap() {
static_assert(std::is_same_v<Value, mapped_type>,
"Value must be equal to ValueInterface::type");
Expand Down Expand Up @@ -227,14 +236,6 @@ class FlatMap final {
other.m_map.end(), PairEqual());
}

friend bool operator==(const FlatMap& m1, const FlatMap& m2) {
return m1.equals(m2);
}

friend bool operator!=(const FlatMap& m1, const FlatMap& m2) {
return !m1.equals(m2);
}

template <typename MappingFunction> // void(mapped_type*)
void transform(MappingFunction&& f) {
bool has_default_value = false;
Expand Down Expand Up @@ -293,7 +294,7 @@ class FlatMap final {
// Requires CombiningFunction to coerce to
// std::function<void(mapped_type*, const mapped_type&)>
template <typename CombiningFunction>
void union_with(CombiningFunction&& combine, const FlatMap& other) {
FlatMap& union_with(CombiningFunction&& combine, const FlatMap& other) {
auto it = m_map.begin(), end = m_map.end();
auto other_it = other.m_map.begin(), other_end = other.m_map.end();
while (other_it != other_end) {
Expand All @@ -313,12 +314,14 @@ class FlatMap final {
++other_it;
}
erase_default_values();
return *this;
}

// Requires CombiningFunction to coerce to
// std::function<void(mapped_type*, const mapped_type&)>
template <typename CombiningFunction>
void intersection_with(CombiningFunction&& combine, const FlatMap& other) {
FlatMap& intersection_with(CombiningFunction&& combine,
const FlatMap& other) {
auto it = m_map.begin(), end = m_map.end();
auto other_it = other.m_map.begin(), other_end = other.m_map.end();
while (it != end) {
Expand All @@ -338,19 +341,20 @@ class FlatMap final {
}
}
erase_default_values();
return *this;
}

// Requires CombiningFunction to coerce to
// std::function<void(mapped_type*, const mapped_type&)>
// Requires `combine(bottom, ...)` to be a no-op.
template <typename CombiningFunction>
void difference_with(CombiningFunction&& combine, const FlatMap& other) {
FlatMap& difference_with(CombiningFunction&& combine, const FlatMap& other) {
auto it = m_map.begin(), end = m_map.end();
auto other_it = other.m_map.begin(), other_end = other.m_map.end();
while (other_it != other_end) {
it = std::lower_bound(it, end, other_it->first, ComparePairWithKey());
if (it == end) {
return;
break;
}
if (KeyEqual()(it->first, other_it->first)) {
combine(&it->second, other_it->second);
Expand All @@ -359,6 +363,7 @@ class FlatMap final {
++other_it;
}
erase_default_values();
return *this;
}

void clear() { m_map.clear(); }
Expand Down
Loading

0 comments on commit 95dcf5b

Please sign in to comment.