Skip to content

Commit

Permalink
Add a base class for the interface of a value in a map
Browse files Browse the repository at this point in the history
Summary:
All our map containers expect a `ValueInterface` template parameter that must implement specific members.
This diff adds a base class for this interface, which allows us to check statically if it properly implements all members.

Reviewed By: arnaudvenet

Differential Revision: D50013140

fbshipit-source-id: 8621b664db2a4b4d1b93ac0eb1a6a59d683cb204
  • Loading branch information
arthaud authored and facebook-github-bot committed Oct 9, 2023
1 parent c9d491e commit eb70b21
Show file tree
Hide file tree
Showing 15 changed files with 260 additions and 117 deletions.
91 changes: 91 additions & 0 deletions sparta/include/sparta/AbstractMapValue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <type_traits>

#include <sparta/AbstractDomain.h>
#include <sparta/TypeTraits.h>

namespace sparta {

namespace {
SPARTA_HAS_STATIC_MEMBER_FUNCTION(leq, has_static_leq_member_function)
}

/*
* This describes the `ValueInterface` structure required for abstract maps.
*/
template <typename Derived>
class AbstractMapValue {
public:
/*
* Check that `Derived` implements the `AbstractMapValue` interface, using
* static assertions. This must be called from a method that is instantiated,
* for instance the destructor of the map.
*/
constexpr static void check_interface() {
static_assert(std::is_base_of<AbstractMapValue<Derived>, Derived>::value,
"Derived doesn't inherit from AbstractMapValue");
static_assert(std::is_final<Derived>::value, "Derived is not final");

// Derived::type
using type = typename Derived::type;

/*
* Returns the default value.
*/
// static type default_value();
static_assert(std::is_same<decltype(Derived::default_value()), type>::value,
"Derived::default_value() does not exist");

/*
* Tests whether a value is the default value.
*/
// static bool is_default_value(const type& x);
static_assert(std::is_same<decltype(Derived::is_default_value(
std::declval<const type>())),
bool>::value,
"Derived::is_default_value(const type&) does not exist");

/*
* The equality predicate for values.
*/
// static bool equals(const type& x, const type& y);
static_assert(
std::is_same<decltype(Derived::equals(std::declval<const type>(),
std::declval<const type>())),
bool>::value,
"Derived::equals(const type&, const type&) does not exist");

if constexpr (has_static_leq_member_function<Derived>::value) {
/*
* A partial order relation over values. In order to use the lifted
* partial order relation over maps PatriciaTreeMap::leq(), this method
* must be implemented. Additionally, value::type must be an
* implementation of an AbstractDomain.
*/
// static bool leq(const type& x, const type& y);
static_assert(
std::is_same<decltype(Derived::leq(std::declval<const type>(),
std::declval<const type>())),
bool>::value,
"Derived::leq(const type&, const type&) does not exist");
}

/*
* Whether the default value is top, bottom, or an arbitrary value.
*/
// constexpr static AbstractValueKind default_value_kind;
static_assert(std::is_same<decltype(Derived::default_value_kind),
const AbstractValueKind>::value,
"Derived::default_value_kind does not exist");
}
};

} // namespace sparta
38 changes: 6 additions & 32 deletions sparta/include/sparta/Analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,17 @@
#include <vector>

#include <sparta/AbstractDomain.h>
#include <sparta/TypeTraits.h>

namespace sparta {

// To achieve a high level of metaprogramming, we use a little C++ trick to
// check for existence of a member function. This is implemented with SFINAE.
//
// More details:
// https://en.wikibooks.org/wiki/More_C%2B%2B_Idioms/Member_Detector
// https://stackoverflow.com/questions/257288/templated-check-for-the-existence-of-a-class-member-function

#define HAS_MEM_FUNC(func, name) \
template <typename T, typename Sign> \
struct name { \
typedef char yes[1]; \
typedef char no[2]; \
template <typename U, U> \
struct type_check; \
template <typename _1> \
static yes& chk(type_check<Sign, &_1::func>*); \
template <typename> \
static no& chk(...); \
static bool const value = sizeof(chk<T>(0)) == sizeof(yes); \
}

template <bool C, typename T = void>
struct enable_if {
typedef T type;
};

template <typename T>
struct enable_if<false, T> {};

HAS_MEM_FUNC(analyze_edge, has_analyze_edge);
namespace {
SPARTA_HAS_MEMBER_FUNCTION_WITH_SIGNATURE(analyze_edge, has_analyze_edge);
}

// The compiler is going to optionally enable either of the following
template <typename Callsite, typename Edge, typename Domain>
typename enable_if<
typename std::enable_if<
has_analyze_edge<Callsite,
Domain (Callsite::*)(const Edge&, const Domain&)>::value,
Domain>::type
Expand All @@ -63,7 +37,7 @@ optionally_analyze_edge_if_exist(Callsite*,
}

template <typename Callsite, typename Edge, typename Domain>
typename enable_if<
typename std::enable_if<
!has_analyze_edge<Callsite,
Domain (Callsite::*)(const Edge&, const Domain&)>::value,
Domain>::type
Expand Down
64 changes: 37 additions & 27 deletions sparta/include/sparta/FlatMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <boost/container/flat_map.hpp>

#include <sparta/AbstractMapValue.h>
#include <sparta/PatriciaTreeCore.h>

namespace sparta {
Expand All @@ -26,21 +27,21 @@ namespace sparta {
* `PatriciaTreeMap`.
*/
template <typename Key,
typename ValueType,
typename Value = pt_core::SimpleValue<ValueType>,
typename Value,
typename ValueInterface = pt_core::SimpleValue<Value>,
typename KeyCompare = std::less<Key>,
typename KeyEqual = std::equal_to<Key>,
typename AllocatorOrContainer =
boost::container::new_allocator<std::pair<Key, ValueType>>>
boost::container::new_allocator<std::pair<Key, Value>>>
class FlatMap final {
private:
using BoostFlatMap = boost::container::
flat_map<Key, ValueType, KeyCompare, AllocatorOrContainer>;
using BoostFlatMap =
boost::container::flat_map<Key, Value, KeyCompare, AllocatorOrContainer>;

public:
// C++ container concept member types
using key_type = Key;
using mapped_type = typename Value::type;
using mapped_type = typename ValueInterface::type;
using value_type = typename BoostFlatMap::value_type;
using iterator = typename BoostFlatMap::const_iterator;
using const_iterator = iterator;
Expand All @@ -49,8 +50,14 @@ class FlatMap final {
using const_reference = typename BoostFlatMap::const_reference;
using const_pointer = typename BoostFlatMap::const_pointer;

static_assert(std::is_same_v<ValueType, mapped_type>,
"ValueType must be equal to Value::type");
~FlatMap() {
static_assert(std::is_same_v<Value, mapped_type>,
"Value must be equal to ValueInterface::type");
static_assert(std::is_base_of<AbstractMapValue<ValueInterface>,
ValueInterface>::value,
"ValueInterface doesn't inherit from AbstractMapValue");
ValueInterface::check_interface();
}

private:
struct ComparePairWithKey {
Expand All @@ -62,20 +69,20 @@ class FlatMap final {
struct PairEqual {
bool operator()(const value_type& left, const value_type& right) const {
return KeyEqual()(left.first, right.first) &&
Value::equals(left.second, right.second);
ValueInterface::equals(left.second, right.second);
}
};

void erase_default_values() {
this->filter([](const Key&, const mapped_type& value) {
return !Value::is_default_value(value);
return !ValueInterface::is_default_value(value);
});
}

public:
explicit FlatMap() = default;

explicit FlatMap(std::initializer_list<std::pair<Key, ValueType>> l) {
explicit FlatMap(std::initializer_list<std::pair<Key, Value>> l) {
for (const auto& p : l) {
insert_or_assign(p.first, p.second);
}
Expand All @@ -94,7 +101,7 @@ class FlatMap final {
const mapped_type& at(const Key& key) const {
auto it = m_map.find(key);
if (it == m_map.end()) {
static const ValueType default_value = Value::default_value();
static const Value default_value = ValueInterface::default_value();
return default_value;
} else {
return it->second;
Expand All @@ -108,7 +115,7 @@ class FlatMap final {

template <typename V>
FlatMap& insert_or_assign(const Key& key, V&& value) {
if (Value::is_default_value(value)) {
if (ValueInterface::is_default_value(value)) {
remove(key);
} else {
m_map.insert_or_assign(key, std::forward<V>(value));
Expand All @@ -118,9 +125,10 @@ class FlatMap final {

template <typename Operation> // void(mapped_type*)
FlatMap& update(Operation&& operation, const Key& key) {
auto [it, inserted] = m_map.try_emplace(key, Value::default_value());
auto [it, inserted] =
m_map.try_emplace(key, ValueInterface::default_value());
operation(&it->second);
if (Value::is_default_value(it->second)) {
if (ValueInterface::is_default_value(it->second)) {
m_map.erase(it);
}
return *this;
Expand Down Expand Up @@ -149,7 +157,7 @@ class FlatMap final {
if (it == end || !KeyEqual()(it->first, other_it->first)) {
return false;
}
if (!Value::leq(it->second, other_it->second)) {
if (!ValueInterface::leq(it->second, other_it->second)) {
return false;
} else {
++it;
Expand Down Expand Up @@ -182,7 +190,7 @@ class FlatMap final {
if (other_it == other_end || !KeyEqual()(it->first, other_it->first)) {
return false;
}
if (!Value::leq(it->second, other_it->second)) {
if (!ValueInterface::leq(it->second, other_it->second)) {
return false;
} else {
++it;
Expand All @@ -195,20 +203,22 @@ class FlatMap final {

public:
bool leq(const FlatMap& other) const {
static_assert(std::is_base_of_v<AbstractDomain<ValueType>, ValueType>,
static_assert(std::is_base_of_v<AbstractDomain<Value>, Value>,
"leq can only be used when Value implements AbstractDomain");

// Assumes Value::default_value() is either Top or Bottom.
if constexpr (Value::default_value_kind == AbstractValueKind::Top) {
// Assumes ValueInterface::default_value() is either Top or Bottom.
if constexpr (ValueInterface::default_value_kind ==
AbstractValueKind::Top) {
return this->leq_when_default_is_top(other);
} else if constexpr (Value::default_value_kind ==
} else if constexpr (ValueInterface::default_value_kind ==
AbstractValueKind::Bottom) {
return this->leq_when_default_is_bottom(other);
} else {
static_assert(
Value::default_value_kind == AbstractValueKind::Top ||
Value::default_value_kind == AbstractValueKind::Bottom,
"leq can only be used when Value::default_value() is top or bottom");
ValueInterface::default_value_kind == AbstractValueKind::Top ||
ValueInterface::default_value_kind == AbstractValueKind::Bottom,
"leq can only be used when ValueInterface::default_value() is top or "
"bottom");
}
}

Expand All @@ -230,7 +240,7 @@ class FlatMap final {
bool has_default_value = false;
for (auto& p : m_map) {
f(&p.second);
if (Value::is_default_value(p.second)) {
if (ValueInterface::is_default_value(p.second)) {
has_default_value = true;
}
}
Expand All @@ -246,7 +256,7 @@ class FlatMap final {
}
}

template <typename Predicate> // bool(const Key&, const ValueType&)
template <typename Predicate> // bool(const Key&, const mapped_type&)
FlatMap& filter(Predicate&& predicate) {
switch (m_map.size()) {
case 0:
Expand Down Expand Up @@ -323,7 +333,7 @@ class FlatMap final {
++it;
++other_it;
} else {
it->second = Value::default_value();
it->second = ValueInterface::default_value();
++it;
}
}
Expand Down
Loading

0 comments on commit eb70b21

Please sign in to comment.