Skip to content

Commit

Permalink
Add simple Bitset types.
Browse files Browse the repository at this point in the history
  • Loading branch information
BenKaufmann committed Nov 19, 2024
1 parent 3915610 commit 7e112cc
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 3 deletions.
90 changes: 90 additions & 0 deletions potassco/basic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,96 @@ class DynamicBuffer {
};
inline void swap(DynamicBuffer& lhs, DynamicBuffer& rhs) noexcept { lhs.swap(rhs); }

template <std::unsigned_integral T, typename ElemType = unsigned>
requires requires(ElemType e) {
{ +e } -> std::convertible_to<unsigned>;
}
class Bitset {
public:
using StorageType = T;
//! Maximal number of elements in the set (i.e. maximal number of bits)
static constexpr auto max_count = sizeof(T) * CHAR_BIT;

//! Creates an empty set, i.e. all bits are zero.
constexpr Bitset() noexcept : set_{} {}
//! Creates a set with the given elements, i.e. bits at the given positions are set.
constexpr Bitset(std::initializer_list<ElemType> elems) : set_{} {
for (StorageType zero{}; auto e : elems) { set_ |= Potassco::set_bit(zero, +e); }
}
//! Constructs a bitset with all bits in `r` set.
static constexpr Bitset fromRep(StorageType r) noexcept { return Bitset(r); }
//! Returns whether the set contains the given element.
[[nodiscard]] constexpr bool contains(ElemType e) const { return Potassco::test_bit(set_, +e); }
//! Returns the number of elements in the set, i.e. the number of bits set.
[[nodiscard]] constexpr unsigned count() const noexcept { return Potassco::bit_count(set_); }
//! Adds the given element to the set and returns true if it was not already in the set.
constexpr bool add(ElemType e) { return not contains(e) && Potassco::store_set_bit(set_, +e); }
//! Removes the given element from the set and returns true if it was in the set.
constexpr bool remove(ElemType e) { return contains(e) && Potassco::store_clear_bit(set_, +e) >= 0u; }
//! Removes all elements (bits) >= max.
constexpr void removeMax(ElemType max) { set_ &= Potassco::bit_max<StorageType>(+max); }
//! Removes all elements from the set.
constexpr void clear() noexcept { set_ = {}; }

[[nodiscard]] constexpr StorageType rep() const noexcept { return set_; }

friend constexpr bool operator==(Bitset lhs, Bitset rhs) noexcept = default;
friend constexpr auto operator<=>(Bitset lhs, Bitset rhs) noexcept = default;

private:
constexpr explicit Bitset(StorageType r) : set_(r) {}
StorageType set_;
};
static_assert(Bitset<uint32_t>::max_count == 32);
static_assert(Bitset<uint32_t>{}.rep() == 0u);
static_assert(Bitset<uint32_t>::fromRep(8u).contains(3));
static_assert(Bitset<uint32_t>::fromRep(15u).count() == 4u);
static_assert(Bitset<uint32_t>{1, 2, 3}.rep() == 14u);

class DynamicBitset {
public:
using IndexType = unsigned;
using trivially_relocatable = std::true_type; // NOLINT

//! Creates an empty set.
DynamicBitset() noexcept = default;
//! Reserves space for at least `numBits`.
void reserve(unsigned numBits);
//! Returns whether the set contains the given bit.
[[nodiscard]] bool contains(IndexType bit) const {
auto [w, p] = pos(bit);
auto s = span();
return w < s.size() && s[w].contains(p);
}
//! Returns the number of elements in the set, i.e. the number of bits set.
[[nodiscard]] unsigned count() const noexcept;
//! Returns whether the set is empty.
[[nodiscard]] bool empty() const noexcept;
//! Adds the given bit to the set and returns true if it was not already in the set.
bool add(IndexType bit);
//! Removes the given bit from the set and returns true if it was in the set.
bool remove(IndexType bit);
//! Removes all elements from the set.
void clear() noexcept { buffer_.clear(); }

friend bool operator==(const DynamicBitset& lhs, const DynamicBitset& rhs) noexcept {
return lhs.compare(rhs) == std::strong_ordering::equal;
}
friend auto operator<=>(const DynamicBitset& lhs, const DynamicBitset& rhs) noexcept { return lhs.compare(rhs); }

private:
using SetType = Bitset<uint64_t>;
[[nodiscard]] static constexpr auto pos(IndexType bit) -> std::pair<unsigned, unsigned> {
return {bit / 64u, bit & 63u};
}
[[nodiscard]] auto compare(const DynamicBitset& other) const -> std::strong_ordering;
[[nodiscard]] auto data() const noexcept -> SetType* { return reinterpret_cast<SetType*>(buffer_.data()); }
[[nodiscard]] auto span() const noexcept -> std::span<SetType> {
return {data(), buffer_.size() / sizeof(SetType)};
}
DynamicBuffer buffer_;
};

class RuleBuilder;

//! A trivially relocatable immutable string type with small buffer optimization.
Expand Down
50 changes: 50 additions & 0 deletions src/match_basic_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,56 @@ void DynamicBuffer::append(const void* what, std::size_t n) {
}
}
/////////////////////////////////////////////////////////////////////////////////////////
// DynamicBitset
/////////////////////////////////////////////////////////////////////////////////////////
unsigned DynamicBitset::count() const noexcept {
unsigned c = 0;
for (auto xs : span()) { c += xs.count(); }
return c;
}
bool DynamicBitset::empty() const noexcept { return buffer_.size() == 0; }
void DynamicBitset::reserve(unsigned numBits) {
if (numBits) {
buffer_.reserve((1u + pos(numBits - 1).first) * sizeof(SetType));
}
}
bool DynamicBitset::add(IndexType bit) {
auto [w, p] = pos(bit);
auto s = span();
if (w >= s.size()) {
auto add = (w - s.size()) + 1;
buffer_.reserve((s.size() + add) * sizeof(SetType));
while (add--) { new (buffer_.alloc(sizeof(SetType)).data()) SetType(); }
s = span();
}
return s[w].add(bit);
}
bool DynamicBitset::remove(IndexType bit) {
auto [w, p] = pos(bit);
auto s = span();
auto res = w < s.size() && s[w].remove(p);
if (res && (w + 1) == s.size() && s[w].count() == 0) {
auto pop = 1u;
while (w-- && s[w].count() == 0) { ++pop; }
buffer_.pop(pop * sizeof(SetType));
return true;
}
return res;
}
auto DynamicBitset::compare(const DynamicBitset& other) const -> std::strong_ordering {
auto lhs = span();
auto rhs = other.span();
if (auto x = lhs.size() <=> rhs.size(); x != std::strong_ordering::equal) {
return x;
}
for (auto n = lhs.size(); n--;) {
if (auto x = lhs[n] <=> rhs[n]; x != std::strong_ordering::equal) {
return x;
}
}
return std::strong_ordering::equal;
}
/////////////////////////////////////////////////////////////////////////////////////////
// ConstString
/////////////////////////////////////////////////////////////////////////////////////////
ConstString::ConstString(std::string_view n, CreateMode m) {
Expand Down
116 changes: 113 additions & 3 deletions tests/test_aspif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ static std::ostream& operator<<(std::ostream& os, const std::pair<Atom_t, Potass
static std::ostream& operator<<(std::ostream& os, const Heuristic& h);
static std::ostream& operator<<(std::ostream& os, const Edge& e);
template <typename T>
std::string stringify(const std::span<T>& s) {
static std::string stringify(const std::span<T>& s) {
std::stringstream str;
str << "[";
const char* sep = "";
Expand All @@ -82,7 +82,8 @@ template <Potassco::ScopedEnum E>
[[maybe_unused]] static std::ostream& operator<<(std::ostream& os, E e) {
return os << Potassco::to_underlying(e);
}
class ReadObserver : public Test::ReadObserver {
namespace {
class ReadObserver final : public Test::ReadObserver {
public:
void rule(Head_t ht, const AtomSpan& head, const LitSpan& body) override {
rules.push_back({ht, {begin(head), end(head)}, Body_t::normal, bound_none, {}});
Expand Down Expand Up @@ -126,6 +127,22 @@ class ReadObserver : public Test::ReadObserver {
TheoryData theory;
};

enum class DummyEnum : uint8_t {
zero = 0,
one = 1,
two = 2,
three = 3,
four = 4,
five = 5,
six = 6,
seven = 7,
eight = 8,
};
POTASSCO_SET_DEFAULT_ENUM_MAX(DummyEnum::eight);
[[maybe_unused]] consteval auto enable_ops(std::type_identity<DummyEnum>) -> CmpOps { return {}; }

} // namespace

static unsigned compareRead(std::stringstream& input, ReadObserver& observer, const Rule* rules,
const std::pair<unsigned, unsigned>& subset) {
for (unsigned i = 0; i != subset.second; ++i) { rule(input, rules[subset.first + i]); }
Expand All @@ -141,6 +158,7 @@ static unsigned compareRead(std::stringstream& input, ReadObserver& observer, co
}
return subset.second;
}

TEST_CASE("Test DynamicBuffer", "[rule]") {
SECTION("starts empty") {
DynamicBuffer r;
Expand Down Expand Up @@ -261,6 +279,7 @@ TEST_CASE("Test DynamicBuffer", "[rule]") {
void* raw = m1.data();
POTASSCO_WARNING_PUSH()
POTASSCO_WARNING_IGNORE_GCC("-Wclass-memaccess")
POTASSCO_WARNING_IGNORE_CLANG("-Wnontrivial-memaccess")
std::memcpy(&m2, &m1, sizeof(DynamicBuffer)); // NOLINT(*-undefined-memory-manipulation)
std::memcpy(&m1, &empty, sizeof(DynamicBuffer)); // NOLINT(*-undefined-memory-manipulation)
POTASSCO_WARNING_POP()
Expand Down Expand Up @@ -490,6 +509,96 @@ TEST_CASE("Test Basic", "[rule]") {
CHECK(store_clear_bit(n, 3u) == 0u);
CHECK(n == 0);
}
SECTION("bitset") {
Bitset<unsigned> bs({1u, 2u, 5u});
CHECK(bs.count() == 3);
CHECK(bs.contains(1));
CHECK(bs.contains(2));
CHECK(bs.contains(5));
CHECK_FALSE(bs.contains(0));
CHECK_FALSE(bs.contains(3));
CHECK_FALSE(bs.contains(4));

bs.removeMax(5);
CHECK_FALSE(bs.contains(5));
CHECK(bs.count() == 2);
bs.add(3);
bs.add(4);
bs.add(5);
CHECK(bs.count() == 5);
bs.removeMax(4);
CHECK_FALSE(bs.contains(5));
CHECK_FALSE(bs.contains(4));
CHECK(bs.contains(3));
CHECK(bs.count() == 3);
bs.remove(3);
CHECK(bs.count() == 2);
CHECK_FALSE(bs.contains(3));

auto copy = bs;
bs.removeMax(0);
CHECK(bs.count() == 0);
CHECK(copy.count() == 2);
copy.clear();
CHECK(copy.count() == 0);

bs.add(31);
bs.add(30);
CHECK(bs.count() == 2);
bs.removeMax(32);
CHECK(bs.count() == 2);
bs.removeMax(31);
CHECK(bs.count() == 1);
}

SECTION("bitset enum") {
using SetType = Bitset<unsigned, DummyEnum>;
static_assert(sizeof(SetType) == sizeof(unsigned));

SetType dummy;
dummy.add(DummyEnum::eight);
CHECK(dummy.count() == 1u);
CHECK(dummy.contains(DummyEnum::eight));
CHECK_FALSE(dummy.contains(DummyEnum::seven));

dummy.add(DummyEnum::five);
dummy.removeMax(DummyEnum::seven);
CHECK(dummy.count() == 1u);
CHECK(dummy.contains(DummyEnum::five));
}

SECTION("dynamic biset") {
DynamicBitset bitset;
CHECK(bitset.count() == 0);
CHECK(bitset == bitset);
CHECK_FALSE(bitset < bitset);
CHECK_FALSE(bitset > bitset);
CHECK(bitset <= bitset);

bitset.add(63);
DynamicBitset other;
CHECK(bitset.count() == 1);
CHECK(other < bitset);
bitset.add(64);
CHECK(bitset.count() == 2);
other.add(64);
CHECK(other < bitset);
other.add(65);
CHECK(other > bitset);
other.add(128);
CHECK(other > bitset);
CHECK(other.count() == 3);
other.remove(65);
CHECK(other > bitset);
other.add(63);
other.remove(128);
CHECK(other == bitset);
other.add(4096);
other.add(100000);
CHECK(other.count() == 4);
other.remove(100000);
CHECK(other.count() == 3);
}
}
TEST_CASE("Test RuleBuilder", "[rule]") {
RuleBuilder rb;
Expand Down Expand Up @@ -543,7 +652,7 @@ TEST_CASE("Test RuleBuilder", "[rule]") {
rb.startSum(2).addGoal(2, 1).addGoal(-3, 1).addGoal(4, 2).addHead(1).end();
REQUIRE_THROWS_AS(rb.setBound(4), std::logic_error);
}
SECTION("weakean to cardinality rule") {
SECTION("weaken to cardinality rule") {
rb.start().addHead(1).startSum(2).addGoal(2, 2).addGoal(-3, 2).addGoal(4, 2).weaken(Body_t::count).end();
REQUIRE(spanEq(rb.head(), std::vector<Atom_t>{1}));
REQUIRE(rb.bodyType() == Body_t::count);
Expand Down Expand Up @@ -699,6 +808,7 @@ TEST_CASE("Test RuleBuilder", "[rule]") {
RuleBuilder mc;
POTASSCO_WARNING_PUSH()
POTASSCO_WARNING_IGNORE_GCC("-Wclass-memaccess")
POTASSCO_WARNING_IGNORE_CLANG("-Wnontrivial-memaccess")
std::memcpy(&mc, &rb, sizeof(RuleBuilder)); // NOLINT(*-undefined-memory-manipulation)
std::memcpy(&rb, &empty, sizeof(RuleBuilder)); // NOLINT(*-undefined-memory-manipulation)
POTASSCO_WARNING_POP()
Expand Down

0 comments on commit 7e112cc

Please sign in to comment.