diff --git a/sparta/include/sparta/AbstractMap.h b/sparta/include/sparta/AbstractMap.h new file mode 100644 index 00000000000..7d9902c8633 --- /dev/null +++ b/sparta/include/sparta/AbstractMap.h @@ -0,0 +1,271 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace sparta { + +enum class AbstractMapMutability { + // The abstract map is immutable. + // Unary operators will have signature `Domain(const Domain&)` + // Binary operators will have signature `Domain(const Domain&, const Domain&)` + Immutable, + + // The abstract map is mutable. + // Unary operators will have signature `void(Domain*)` + // Binary operators will have signature `void(Domain*, const Domain&)` + Mutable, +}; + +/* + * This describes the API for a generic map container. + */ +template +class AbstractMap { + public: + ~AbstractMap() { + // The destructor is the only method that is guaranteed to be created when a + // class template is instantiated. This is a good place to perform all the + // sanity checks on the template parameters. + static_assert(std::is_base_of, Derived>::value, + "Derived doesn't inherit from AbstractMap"); + static_assert(std::is_final::value, "Derived is not final"); + + using Key = typename Derived::key_type; + using Value = typename Derived::mapped_type; + using KeyValuePair = typename Derived::value_type; + using Iterator = typename Derived::iterator; + + // Derived(); + static_assert(std::is_default_constructible::value, + "Derived is not default constructible"); + + // Derived(const Derived&); + static_assert(std::is_copy_constructible::value, + "Derived is not copy constructible"); + + // Derived& operator=(const Derived&); + static_assert(std::is_copy_assignable::value, + "Derived is not copy assignable"); + + // constexpr static AbstractMapMutability mutability; + static_assert(std::is_same::value, + "Derived::mutability does not exist"); + + // bool empty() const; + static_assert(std::is_same().empty()), + bool>::value, + "Derived::empty() does not exist"); + + // std::size_t size() const; + static_assert(std::is_same().size()), + std::size_t>::value, + "Derived::size() does not exist"); + + // std::size_t max_size() const; + static_assert( + std::is_same().max_size()), + std::size_t>::value, + "Derived::max_size() does not exist"); + + // iterator begin() const; + static_assert(std::is_same().begin()), + Iterator>::value, + "Derived::begin() does not exist"); + + // iterator end() const; + static_assert(std::is_same().end()), + Iterator>::value, + "Derived::begin() does not exist"); + + // const Value& at(const Key&) const; + static_assert(std::is_same().at( + std::declval())), + const Value&>::value, + "Derived::at(const Key&) does not exist"); + + // Derived& insert_or_assign(const Key& key, Value value); + static_assert( + std::is_same().insert_or_assign( + std::declval(), std::declval())), + Derived&>::value, + "Derived::insert_or_assign(const Key&, Value) does not exist"); + + // Derived& remove(const Key& key); + static_assert(std::is_same().remove( + std::declval())), + Derived&>::value, + "Derived::remove(const Key&) does not exist"); + + // void clear(); + static_assert(std::is_void_v().clear())>, + "Derived::clear() does not exist"); + + // void visit(Visitor&& visitor) const; + static_assert(std::is_same().visit( + std::declval())), + void>::value, + "Derived::visit(Visitor&&) does not exist"); + + // Derived& filter(Predicate&& predicate); + static_assert( + std::is_same().filter( + std::declval())), + Derived&>::value, + "Derived::filter(Predicate&&) does not exist"); + + /* + * The partial order relation. + */ + // bool leq(const Derived& other) const; + static_assert(std::is_same().leq( + std::declval())), + bool>::value, + "Derived::leq(const Derived&) does not exist"); + + /* + * a.equals(b) is semantically equivalent to a.leq(b) && b.leq(a). + */ + // bool equals(const Derived& other) const; + static_assert(std::is_same().equals( + std::declval())), + bool>::value, + "Derived::equals(const Derived&) does not exist"); + + if constexpr (Derived::mutability == AbstractMapMutability::Immutable) { + // Derived& update(Operation&& operation, const Key& key); + static_assert(std::is_same().update( + std::declval(), + std::declval())), + Derived&>::value, + "Derived::update(Operation&&, const Key&) does not exist"); + + // bool transform(MappingFunction&& f); + static_assert(std::is_same().transform( + std::declval())), + bool>::value, + "Derived::transform(MappingFunction&&) does not exist"); + + // Derived& union_with(CombiningFunction&& combine, const Derived& other); + static_assert( + std::is_same().union_with( + std::declval(), + std::declval())), + Derived&>::value, + "Derived::union_with(CombiningFunction&&, const Derived&) does not " + "exist"); + + // Derived& intersection_with(CombiningFunction&& combine, const Derived& + // other); + static_assert( + std::is_same().intersection_with( + std::declval(), + std::declval())), + Derived&>::value, + "Derived::intersection_with(CombiningFunction&&, const Derived&) " + "does " + "not exist"); + + // Derived& difference_with(CombiningFunction&& combine, const Derived& + // other); + static_assert( + std::is_same().difference_with( + std::declval(), + std::declval())), + Derived&>::value, + "Derived::difference_with(CombiningFunction&&, const Derived&) does " + "not exist"); + } else if constexpr (Derived::mutability == + AbstractMapMutability::Mutable) { + // Derived& update(Operation&& operation, const Key& key); + static_assert(std::is_same().update( + std::declval(), + std::declval())), + Derived&>::value, + "Derived::update(Operation&&, const Key&) does not exist"); + + // void transform(MappingFunction&& f); + static_assert(std::is_same().transform( + std::declval())), + void>::value, + "Derived::transform(MappingFunction&&) does not exist"); + + // Derived& union_with(CombiningFunction&& combine, const Derived& other); + static_assert( + std::is_same().union_with( + std::declval(), + std::declval())), + Derived&>::value, + "Derived::union_with(CombiningFunction&&, const Derived&) does not " + "exist"); + + // Derived& intersection_with(CombiningFunction&& combine, const Derived& + // other); + static_assert( + std::is_same().intersection_with( + std::declval(), + std::declval())), + Derived&>::value, + "Derived::intersection_with(CombiningFunction&&, const Derived&) " + "does not exist"); + + // Derived& difference_with(CombiningFunction&& combine, const Derived& + // other); + static_assert( + std::is_same().difference_with( + std::declval(), + std::declval())), + Derived&>::value, + "Derived::difference_with(CombiningFunction&&, const Derived&) does " + "not exist"); + } + } + + /* + * Many C++ libraries default to using operator== to check for equality, + * so we define it here as an alias of equals(). + */ + friend bool operator==(const Derived& self, const Derived& other) { + return self.equals(other); + } + + friend bool operator!=(const Derived& self, const Derived& other) { + return !self.equals(other); + } + + template + Derived get_union_with(CombiningFunction&& combine, + const Derived& other) const { + // Here and below: the static_cast is required in order to instruct + // the compiler to use the copy constructor of the derived class. + Derived result(static_cast(*this)); + result.union_with(std::forward(combine), other); + return result; + } + + template + Derived get_intersection_with(CombiningFunction&& combine, + const Derived& other) const { + Derived result(static_cast(*this)); + result.intersection_with(std::forward(combine), other); + return result; + } + + template + Derived get_difference_with(CombiningFunction&& combine, + const Derived& other) const { + Derived result(static_cast(*this)); + result.difference_with(std::forward(combine), other); + return result; + } +}; + +} // namespace sparta diff --git a/sparta/include/sparta/FlatMap.h b/sparta/include/sparta/FlatMap.h index ecef133d18d..75454c2451b 100644 --- a/sparta/include/sparta/FlatMap.h +++ b/sparta/include/sparta/FlatMap.h @@ -14,6 +14,7 @@ #include +#include #include #include @@ -33,7 +34,12 @@ template , typename AllocatorOrContainer = boost::container::new_allocator>> -class FlatMap final { +class FlatMap final : public AbstractMap> { private: using BoostFlatMap = boost::container::flat_map; @@ -50,6 +56,9 @@ class FlatMap final { using const_reference = typename BoostFlatMap::const_reference; using const_pointer = typename BoostFlatMap::const_pointer; + constexpr static AbstractMapMutability mutability = + AbstractMapMutability::Mutable; + ~FlatMap() { static_assert(std::is_same_v, "Value must be equal to ValueInterface::type"); @@ -227,14 +236,6 @@ class FlatMap final { other.m_map.end(), PairEqual()); } - friend bool operator==(const FlatMap& m1, const FlatMap& m2) { - return m1.equals(m2); - } - - friend bool operator!=(const FlatMap& m1, const FlatMap& m2) { - return !m1.equals(m2); - } - template // void(mapped_type*) void transform(MappingFunction&& f) { bool has_default_value = false; @@ -293,7 +294,7 @@ class FlatMap final { // Requires CombiningFunction to coerce to // std::function template - void union_with(CombiningFunction&& combine, const FlatMap& other) { + FlatMap& union_with(CombiningFunction&& combine, const FlatMap& other) { auto it = m_map.begin(), end = m_map.end(); auto other_it = other.m_map.begin(), other_end = other.m_map.end(); while (other_it != other_end) { @@ -313,12 +314,14 @@ class FlatMap final { ++other_it; } erase_default_values(); + return *this; } // Requires CombiningFunction to coerce to // std::function template - void intersection_with(CombiningFunction&& combine, const FlatMap& other) { + FlatMap& intersection_with(CombiningFunction&& combine, + const FlatMap& other) { auto it = m_map.begin(), end = m_map.end(); auto other_it = other.m_map.begin(), other_end = other.m_map.end(); while (it != end) { @@ -338,19 +341,20 @@ class FlatMap final { } } erase_default_values(); + return *this; } // Requires CombiningFunction to coerce to // std::function // Requires `combine(bottom, ...)` to be a no-op. template - void difference_with(CombiningFunction&& combine, const FlatMap& other) { + FlatMap& difference_with(CombiningFunction&& combine, const FlatMap& other) { auto it = m_map.begin(), end = m_map.end(); auto other_it = other.m_map.begin(), other_end = other.m_map.end(); while (other_it != other_end) { it = std::lower_bound(it, end, other_it->first, ComparePairWithKey()); if (it == end) { - return; + break; } if (KeyEqual()(it->first, other_it->first)) { combine(&it->second, other_it->second); @@ -359,6 +363,7 @@ class FlatMap final { ++other_it; } erase_default_values(); + return *this; } void clear() { m_map.clear(); } diff --git a/sparta/include/sparta/HashMap.h b/sparta/include/sparta/HashMap.h index 92263a92125..a7ef63fb3d9 100644 --- a/sparta/include/sparta/HashMap.h +++ b/sparta/include/sparta/HashMap.h @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -31,7 +32,9 @@ template , typename KeyHash = std::hash, typename KeyEqual = std::equal_to> -class HashMap final { +class HashMap final + : public AbstractMap< + HashMap> { public: using StdUnorderedMap = std::unordered_map; @@ -46,6 +49,9 @@ class HashMap final { using const_reference = typename StdUnorderedMap::const_reference; using const_pointer = typename StdUnorderedMap::const_pointer; + constexpr static AbstractMapMutability mutability = + AbstractMapMutability::Mutable; + ~HashMap() { static_assert(std::is_same_v, "Value must be equal to ValueInterface::type"); @@ -213,16 +219,8 @@ class HashMap final { return true; } - friend bool operator==(const HashMap& m1, const HashMap& m2) { - return m1.equals(m2); - } - - friend bool operator!=(const HashMap& m1, const HashMap& m2) { - return !m1.equals(m2); - } - template // void(mapped_type*) - HashMap& transform(MappingFunction&& f) { + void transform(MappingFunction&& f) { auto it = m_map.begin(), end = m_map.end(); while (it != end) { f(&it->second); @@ -232,7 +230,6 @@ class HashMap final { ++it; } } - return *this; } template // void(const value_type&) @@ -258,7 +255,7 @@ class HashMap final { // Requires CombiningFunction to coerce to // std::function template - void union_with(CombiningFunction&& combine, const HashMap& other) { + HashMap& union_with(CombiningFunction&& combine, const HashMap& other) { for (const auto& other_binding : other.m_map) { auto binding = m_map.find(other_binding.first); if (binding == m_map.end()) { @@ -270,12 +267,14 @@ class HashMap final { } } } + return *this; } // Requires CombiningFunction to coerce to // std::function template - void intersection_with(CombiningFunction&& combine, const HashMap& other) { + HashMap& intersection_with(CombiningFunction&& combine, + const HashMap& other) { auto it = m_map.begin(), end = m_map.end(); while (it != end) { auto other_binding = other.m_map.find(it->first); @@ -290,13 +289,14 @@ class HashMap final { } } } + return *this; } // Requires CombiningFunction to coerce to // std::function // Requires `combine(bottom, ...)` to be a no-op. template - void difference_with(CombiningFunction&& combine, const HashMap& other) { + HashMap& difference_with(CombiningFunction&& combine, const HashMap& other) { for (const auto& other_binding : other.m_map) { auto binding = m_map.find(other_binding.first); if (binding != m_map.end()) { @@ -306,6 +306,7 @@ class HashMap final { } } } + return *this; } void clear() { m_map.clear(); } diff --git a/sparta/include/sparta/PatriciaTreeHashMap.h b/sparta/include/sparta/PatriciaTreeHashMap.h index 745af805dac..09060999e9a 100644 --- a/sparta/include/sparta/PatriciaTreeHashMap.h +++ b/sparta/include/sparta/PatriciaTreeHashMap.h @@ -15,6 +15,7 @@ #include +#include #include #include #include @@ -38,7 +39,13 @@ template , typename KeyCompare = std::less, typename KeyEqual = std::equal_to> -class PatriciaTreeHashMap final { +class PatriciaTreeHashMap final + : public AbstractMap> { private: using SmallVector = boost::container::small_vector, 1>; using FlatMapT = @@ -89,6 +96,9 @@ class PatriciaTreeHashMap final { using const_reference = typename FlattenIteratorT::reference; using const_pointer = typename FlattenIteratorT::pointer; + constexpr static AbstractMapMutability mutability = + AbstractMapMutability::Mutable; + ~PatriciaTreeHashMap() { static_assert(std::is_same_v, "Value must be equal to ValueInterface::type"); @@ -128,16 +138,6 @@ class PatriciaTreeHashMap final { return m_tree.equals(other.m_tree); } - friend bool operator==(const PatriciaTreeHashMap& m1, - const PatriciaTreeHashMap& m2) { - return m1.equals(m2); - } - - friend bool operator!=(const PatriciaTreeHashMap& m1, - const PatriciaTreeHashMap& m2) { - return !m1.equals(m2); - } - /* See `PatriciaTreeMap::reference_equals` */ bool reference_equals(const PatriciaTreeHashMap& other) const { return m_tree.reference_equals(other.m_tree); @@ -169,8 +169,8 @@ class PatriciaTreeHashMap final { } template // void(mapped_type*) - bool transform(MappingFunction&& f) { - return m_tree.transform( + void transform(MappingFunction&& f) { + m_tree.transform( [f = std::forward(f)](FlatMapT flat_map) -> FlatMapT { flat_map.transform(f); return flat_map; @@ -249,32 +249,6 @@ class PatriciaTreeHashMap final { return *this; } - template - PatriciaTreeHashMap get_union_with(const CombiningFunction& combine, - const PatriciaTreeHashMap& other) const { - auto result = *this; - result.union_with(combine, other); - return result; - } - - template - PatriciaTreeHashMap get_intersection_with( - const CombiningFunction& combine, - const PatriciaTreeHashMap& other) const { - auto result = *this; - result.intersection_with(combine, other); - return result; - } - - template - PatriciaTreeHashMap get_difference_with( - const CombiningFunction& combine, - const PatriciaTreeHashMap& other) const { - auto result = *this; - result.difference_with(combine, other); - return result; - } - void clear() { m_tree.clear(); } friend std::ostream& operator<<(std::ostream& o, diff --git a/sparta/include/sparta/PatriciaTreeHashMapAbstractEnvironment.h b/sparta/include/sparta/PatriciaTreeHashMapAbstractEnvironment.h index 625a14a7267..44b5c8726a9 100644 --- a/sparta/include/sparta/PatriciaTreeHashMapAbstractEnvironment.h +++ b/sparta/include/sparta/PatriciaTreeHashMapAbstractEnvironment.h @@ -103,13 +103,13 @@ class PatriciaTreeHashMapAbstractEnvironment final } template // void(Domain*) - bool transform(Operation&& f) { + void transform(Operation&& f) { if (this->is_bottom()) { - return false; + return; } - bool res = this->get_value()->transform(std::forward(f)); + this->get_value()->transform(std::forward(f)); this->normalize(); - return res; + return; } template // void(const std::pair&) @@ -280,8 +280,8 @@ class MapValue final : public AbstractValue> { } template // void(Domain*) - bool transform(Operation&& f) { - return m_map.transform(std::forward(f)); + void transform(Operation&& f) { + m_map.transform(std::forward(f)); } template // void(const std::pair&) diff --git a/sparta/include/sparta/PatriciaTreeHashMapAbstractPartition.h b/sparta/include/sparta/PatriciaTreeHashMapAbstractPartition.h index b13d07489b2..2bdffd2f0d0 100644 --- a/sparta/include/sparta/PatriciaTreeHashMapAbstractPartition.h +++ b/sparta/include/sparta/PatriciaTreeHashMapAbstractPartition.h @@ -114,11 +114,11 @@ class PatriciaTreeHashMapAbstractPartition final } template // void(Domain*) - bool transform(Operation&& f) { + void transform(Operation&& f) { if (is_top()) { - return false; + return; } - return m_map.transform(std::forward(f)); + m_map.transform(std::forward(f)); } template // void(const std::pair&) diff --git a/sparta/include/sparta/PatriciaTreeMap.h b/sparta/include/sparta/PatriciaTreeMap.h index 692b02fe3bf..8119f45ac70 100644 --- a/sparta/include/sparta/PatriciaTreeMap.h +++ b/sparta/include/sparta/PatriciaTreeMap.h @@ -16,6 +16,7 @@ #include +#include #include #include #include @@ -72,7 +73,8 @@ namespace sparta { template > -class PatriciaTreeMap final { +class PatriciaTreeMap final + : public AbstractMap> { using Core = pt_core::PatriciaTreeCore; using Codec = typename Core::Codec; @@ -90,6 +92,9 @@ class PatriciaTreeMap final { using IntegerType = typename Codec::IntegerType; + constexpr static AbstractMapMutability mutability = + AbstractMapMutability::Immutable; + ~PatriciaTreeMap() { static_assert(std::is_same_v, "Value must be equal to ValueInterface::type"); @@ -119,14 +124,6 @@ class PatriciaTreeMap final { return m_core.equals(other.m_core); } - friend bool operator==(const PatriciaTreeMap& m1, const PatriciaTreeMap& m2) { - return m1.equals(m2); - } - - friend bool operator!=(const PatriciaTreeMap& m1, const PatriciaTreeMap& m2) { - return !m1.equals(m2); - } - /* * This faster equality predicate can be used to check whether a sequence of * in-place modifications leaves a Patricia-tree map unchanged. For comparing @@ -217,30 +214,6 @@ class PatriciaTreeMap final { return *this; } - template - PatriciaTreeMap get_union_with(CombiningFunction&& combine, - const PatriciaTreeMap& other) const { - auto result = *this; - result.union_with(std::forward(combine), other); - return result; - } - - template - PatriciaTreeMap get_intersection_with(CombiningFunction&& combine, - const PatriciaTreeMap& other) const { - auto result = *this; - result.intersection_with(std::forward(combine), other); - return result; - } - - template - PatriciaTreeMap get_difference_with(CombiningFunction&& combine, - const PatriciaTreeMap& other) const { - auto result = *this; - result.difference_with(std::forward(combine), other); - return result; - } - void clear() { m_core.clear(); } friend std::ostream& operator<<(std::ostream& o, const PatriciaTreeMap& s) { diff --git a/sparta/test/PatriciaTreeHashMapAbstractEnvironmentTest.cpp b/sparta/test/PatriciaTreeHashMapAbstractEnvironmentTest.cpp index 2e2a788750a..4df4672a082 100644 --- a/sparta/test/PatriciaTreeHashMapAbstractEnvironmentTest.cpp +++ b/sparta/test/PatriciaTreeHashMapAbstractEnvironmentTest.cpp @@ -244,11 +244,10 @@ TEST_F(PatriciaTreeHashMapAbstractEnvironmentTest, whiteBox) { TEST_F(PatriciaTreeHashMapAbstractEnvironmentTest, transform) { Environment e1({{1, Domain({"a", "b"})}}); - bool any_changes = e1.transform([](Domain*) {}); - EXPECT_FALSE(any_changes); + e1.transform([](Domain*) {}); + EXPECT_EQ(e1.size(), 1); - any_changes = e1.transform([](Domain* d) { d->set_to_top(); }); - EXPECT_TRUE(any_changes); + e1.transform([](Domain* d) { d->set_to_top(); }); EXPECT_TRUE(e1.is_top()); } diff --git a/sparta/test/PatriciaTreeHashMapAbstractPartitionTest.cpp b/sparta/test/PatriciaTreeHashMapAbstractPartitionTest.cpp index 6715652cb2c..0fecd08d34c 100644 --- a/sparta/test/PatriciaTreeHashMapAbstractPartitionTest.cpp +++ b/sparta/test/PatriciaTreeHashMapAbstractPartitionTest.cpp @@ -247,10 +247,9 @@ TEST(PatriciaTreeHashMapAbstractPartitionTest, destructiveOperations) { TEST(PatriciaTreeHashMapAbstractPartitionTest, transform) { Partition p1({{1, Domain({"a", "b"})}}); - bool any_changes = p1.transform([](Domain*) {}); - EXPECT_FALSE(any_changes); + p1.transform([](Domain*) {}); + EXPECT_EQ(p1.size(), 1); - any_changes = p1.transform([](Domain* d) { d->set_to_bottom(); }); - EXPECT_TRUE(any_changes); + p1.transform([](Domain* d) { d->set_to_bottom(); }); EXPECT_TRUE(p1.is_bottom()); } diff --git a/sparta/test/PatriciaTreeHashMapTest.cpp b/sparta/test/PatriciaTreeHashMapTest.cpp index d1238ac7dff..2eb6b460988 100644 --- a/sparta/test/PatriciaTreeHashMapTest.cpp +++ b/sparta/test/PatriciaTreeHashMapTest.cpp @@ -65,15 +65,13 @@ TEST(PatriciaTreeHashMapTest, map) { constexpr uint32_t default_value = 0; pth_map m1 = create_pth_map({{0, 1}, {1, 2}, {2, 4}}); - bool any_changes = m1.transform([](uint32_t*) {}); - EXPECT_FALSE(any_changes); + m1.transform([](uint32_t*) {}); EXPECT_EQ(3, m1.size()); EXPECT_EQ(m1.at(0), 1); EXPECT_EQ(m1.at(1), 2); EXPECT_EQ(m1.at(2), 4); - any_changes = m1.transform([](uint32_t* value) { --(*value); }); - EXPECT_TRUE(any_changes); + m1.transform([](uint32_t* value) { --(*value); }); EXPECT_EQ(2, m1.size()); EXPECT_EQ(m1.at(0), default_value); EXPECT_EQ(m1.at(1), 1);