From aaf62c9498643ce09d0f8512aeccdf5d041ab2e0 Mon Sep 17 00:00:00 2001 From: Gleb Mazovetskiy Date: Thu, 24 Jun 2021 08:52:54 +0100 Subject: [PATCH 1/2] Make `Agent` copyable and moveable This allows implementing C++ iterator interface on top of the Agent more efficiently. C++ iterators must be copyable, and the only way to copy one previously was to repeat the query until the index. --- configure.ac | 7 +++ include/marisa/agent.h | 21 +++++++-- include/marisa/scoped-array.h | 12 +++++ include/marisa/scoped-ptr.h | 12 +++++ lib/marisa/agent.cc | 66 +++++++++++++++++++++----- lib/marisa/grimoire/trie/state.h | 12 +++-- lib/marisa/grimoire/vector/vector.h | 50 ++++++++++++++++---- tests/marisa-test.cc | 72 +++++++++++++++++++++++++++++ 8 files changed, 222 insertions(+), 30 deletions(-) diff --git a/configure.ac b/configure.ac index 118a677..c7adec0 100644 --- a/configure.ac +++ b/configure.ac @@ -13,6 +13,9 @@ AC_PROG_INSTALL AC_CONFIG_MACRO_DIR([m4]) +# Sanitizers +AC_ARG_ENABLE([asan], AS_HELP_STRING([--enable-asan], [Enable address sanitizer])) + # Macros for SSE availability check. AC_DEFUN([MARISA_ENABLE_SSE2], [AC_EGREP_CPP([yes], [ @@ -241,6 +244,10 @@ elif test "x${enable_sse2}" != "xno"; then CXXFLAGS="$CXXFLAGS -DMARISA_USE_SSE2 -msse2" fi +AS_IF([test "x$enable_asan" = "xyes"], [ + CXXFLAGS="$CXXFLAGS -fsanitize=address" +]) + AC_CONFIG_FILES([Makefile marisa.pc include/Makefile diff --git a/include/marisa/agent.h b/include/marisa/agent.h index b549d36..117aed8 100644 --- a/include/marisa/agent.h +++ b/include/marisa/agent.h @@ -22,6 +22,14 @@ class Agent { Agent(); ~Agent(); + Agent(const Agent &other); + Agent &operator=(const Agent &other); + +#if __cplusplus >= 201103L + Agent(Agent &&other) noexcept; + Agent &operator=(Agent &&other) noexcept; +#endif + const Query &query() const { return query_; } @@ -37,6 +45,9 @@ class Agent { void set_query(const char *str); void set_query(const char *ptr, std::size_t length); void set_query(std::size_t key_id); + void set_query(const Query &query) { + query_ = query; + } const grimoire::trie::State &state() const { return *state_; @@ -65,7 +76,7 @@ class Agent { } bool has_state() const { - return state_.get() != NULL; + return state_ != NULL; } void init_state(); @@ -75,11 +86,11 @@ class Agent { private: Query query_; Key key_; - scoped_ptr state_; - // Disallows copy and assignment. - Agent(const Agent &); - Agent &operator=(const Agent &); + // Cannot be `scoped_ptr` because `State` is forward-declared. + grimoire::trie::State *state_; + + void clear_state(); }; } // namespace marisa diff --git a/include/marisa/scoped-array.h b/include/marisa/scoped-array.h index 34cefa4..7d58c98 100644 --- a/include/marisa/scoped-array.h +++ b/include/marisa/scoped-array.h @@ -11,6 +11,18 @@ class scoped_array { scoped_array() : array_(NULL) {} explicit scoped_array(T *array) : array_(array) {} +#if __cplusplus >= 201103L + scoped_array(scoped_array &&other) noexcept : array_(other.array_) { + other.array_ = NULL; + } + scoped_array &operator=(scoped_array &&other) noexcept { + delete [] array_; + array_ = other.array_; + other.array_ = NULL; + return *this; + } +#endif + ~scoped_array() { delete [] array_; } diff --git a/include/marisa/scoped-ptr.h b/include/marisa/scoped-ptr.h index abf48d8..2330439 100644 --- a/include/marisa/scoped-ptr.h +++ b/include/marisa/scoped-ptr.h @@ -15,6 +15,18 @@ class scoped_ptr { delete ptr_; } +#if __cplusplus >= 201103L + scoped_ptr(scoped_ptr &&other) noexcept : ptr_(other.ptr_) { + other.ptr_ = NULL; + } + scoped_ptr &operator=(scoped_ptr &&other) noexcept { + delete ptr_; + ptr_ = other.ptr_; + other.ptr_ = NULL; + return *this; + } +#endif + void reset(T *ptr = NULL) { MARISA_DEBUG_IF((ptr != NULL) && (ptr == ptr_), MARISA_RESET_ERROR); scoped_ptr(ptr).swap(*this); diff --git a/lib/marisa/agent.cc b/lib/marisa/agent.cc index 7fa7cb1..51baa5b 100644 --- a/lib/marisa/agent.cc +++ b/lib/marisa/agent.cc @@ -5,47 +5,89 @@ namespace marisa { -Agent::Agent() : query_(), key_(), state_() {} +Agent::Agent() : query_(), key_(), state_(NULL) {} -Agent::~Agent() {} +Agent::~Agent() { + delete state_; +} + +Agent::Agent(const Agent &other) + : query_(other.query_), + key_(other.key_), + state_(other.has_state() ? new (std::nothrow) grimoire::trie::State(other.state()) : NULL) {} + +Agent &Agent::operator=(const Agent &other) { + query_ = other.query_; + key_ = other.key_; + delete state_; + if (other.has_state()) { + state_ = new (std::nothrow) grimoire::trie::State(other.state()); + } else { + state_ = NULL; + } + return *this; +} + +#if __cplusplus >= 201103L +Agent::Agent(Agent &&other) noexcept + : query_(other.query_), key_(other.key_), state_(other.state_) { + other.state_ = NULL; +} + +Agent &Agent::operator=(Agent &&other) noexcept { + query_ = other.query_; + key_ = other.key_; + delete state_; + state_ = other.state_; + other.state_ = NULL; + return *this; +} +#endif void Agent::set_query(const char *str) { MARISA_THROW_IF(str == NULL, MARISA_NULL_ERROR); - if (state_.get() != NULL) { - state_->reset(); + if (state_ != NULL) { + clear_state(); } query_.set_str(str); } void Agent::set_query(const char *ptr, std::size_t length) { MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); - if (state_.get() != NULL) { - state_->reset(); + if (state_ != NULL) { + clear_state(); } query_.set_str(ptr, length); } void Agent::set_query(std::size_t key_id) { - if (state_.get() != NULL) { - state_->reset(); + if (state_ != NULL) { + clear_state(); } query_.set_id(key_id); } void Agent::init_state() { - MARISA_THROW_IF(state_.get() != NULL, MARISA_STATE_ERROR); - state_.reset(new (std::nothrow) grimoire::State); - MARISA_THROW_IF(state_.get() == NULL, MARISA_MEMORY_ERROR); + MARISA_THROW_IF(state_ != NULL, MARISA_STATE_ERROR); + delete state_; + state_ = new (std::nothrow) grimoire::State; + MARISA_THROW_IF(state_ == NULL, MARISA_MEMORY_ERROR); } void Agent::clear() { Agent().swap(*this); } + +void Agent::clear_state() { + delete state_; + state_ = nullptr; +} + void Agent::swap(Agent &rhs) { query_.swap(rhs.query_); key_.swap(rhs.key_); - state_.swap(rhs.state_); + marisa::swap(state_, rhs.state_); } } // namespace marisa diff --git a/lib/marisa/grimoire/trie/state.h b/lib/marisa/grimoire/trie/state.h index df605a6..07bda52 100644 --- a/lib/marisa/grimoire/trie/state.h +++ b/lib/marisa/grimoire/trie/state.h @@ -24,6 +24,14 @@ class State { : key_buf_(), history_(), node_id_(0), query_pos_(0), history_pos_(0), status_code_(MARISA_READY_TO_ALL) {} + State(const State &) = default; + State &operator=(const State &) = default; + +#if __cplusplus >= 201103L + State(State &&) noexcept = default; + State &operator=(State &&) noexcept = default; +#endif + void set_node_id(std::size_t node_id) { MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); node_id_ = (UInt32)node_id; @@ -104,10 +112,6 @@ class State { UInt32 query_pos_; UInt32 history_pos_; StatusCode status_code_; - - // Disallows copy and assignment. - State(const State &); - State &operator=(const State &); }; } // namespace trie diff --git a/lib/marisa/grimoire/vector/vector.h b/lib/marisa/grimoire/vector/vector.h index 2bfccdb..209bc70 100644 --- a/lib/marisa/grimoire/vector/vector.h +++ b/lib/marisa/grimoire/vector/vector.h @@ -23,6 +23,38 @@ class Vector { } } + Vector(const Vector &other) + : buf_(), objs_(NULL), const_objs_(NULL), + size_(0), capacity_(0), fixed_(other.fixed_) { + if (other.buf_.get() == NULL) { + objs_ = other.objs_; + const_objs_ = other.const_objs_; + size_ = other.size_; + capacity_ = other.capacity_; + } else { + copy(other.const_objs_, other.size_, other.capacity_); + } + } + + Vector &operator=(const Vector &other) { + clear(); + fixed_ = other.fixed_; + if (other.buf_.get() == NULL) { + objs_ = other.objs_; + const_objs_ = other.const_objs_; + size_ = other.size_; + capacity_ = other.capacity_; + } else { + copy(other.const_objs_, other.size_, other.capacity_); + } + return *this; + } + +#if __cplusplus >= 201103L + Vector(Vector &&) noexcept = default; + Vector &operator=(Vector &&) noexcept = default; +#endif + void map(Mapper &mapper) { Vector temp; temp.map_(mapper); @@ -225,14 +257,17 @@ class Vector { // realloc() assumes that T's placement new does not throw an exception. void realloc(std::size_t new_capacity) { MARISA_DEBUG_IF(new_capacity > max_size(), MARISA_SIZE_ERROR); + copy(objs_, size_, new_capacity); + } - scoped_array new_buf( - new (std::nothrow) char[sizeof(T) * new_capacity]); + // copy() assumes that T's placement new does not throw an exception. + void copy(const T *src, std::size_t src_size, std::size_t capacity) { + scoped_array new_buf(new (std::nothrow) char[sizeof(T) * capacity]); MARISA_DEBUG_IF(new_buf.get() == NULL, MARISA_MEMORY_ERROR); T *new_objs = reinterpret_cast(new_buf.get()); - for (std::size_t i = 0; i < size_; ++i) { - new (&new_objs[i]) T(objs_[i]); + for (std::size_t i = 0; i < src_size; ++i) { + new (&new_objs[i]) T(src[i]); } for (std::size_t i = 0; i < size_; ++i) { objs_[i].~T(); @@ -241,12 +276,9 @@ class Vector { buf_.swap(new_buf); objs_ = new_objs; const_objs_ = new_objs; - capacity_ = new_capacity; + size_ = src_size; + capacity_ = capacity; } - - // Disallows copy and assignment. - Vector(const Vector &); - Vector &operator=(const Vector &); }; } // namespace vector diff --git a/tests/marisa-test.cc b/tests/marisa-test.cc index 36e4258..29e8377 100644 --- a/tests/marisa-test.cc +++ b/tests/marisa-test.cc @@ -1,7 +1,10 @@ +#include #include #include #include #include +#include +#include #include @@ -258,6 +261,71 @@ void TestPredictiveSearch(const marisa::Trie &trie, } } +void TestPredictiveSearchAgentCopy(const marisa::Trie &trie, + const marisa::Keyset &keyset) { + marisa::Agent agent; + for (std::size_t i = 0; i < keyset.size(); ++i) { + agent.set_query(keyset[i].ptr(), keyset[i].length()); + ASSERT(trie.predictive_search(agent)); + ASSERT(agent.key().id() == keyset[i].id()); + + std::vector agent_copies; + std::vector ids; + while (trie.predictive_search(agent)) { + ASSERT(agent.key().id() > keyset[i].id()); + ids.push_back(agent.key().id()); + + // Tests copy constructor. + agent_copies.push_back(agent); + } + + for (std::size_t i = 0; i < agent_copies.size(); ++i) { + marisa::Agent agent_copy; + + // Tests copy assignment. + agent_copy = agent_copies[i]; + + ASSERT(agent_copy.key().id() == ids[i]); + if (i + 1 < agent_copies.size()) { + ASSERT(trie.predictive_search(agent_copy)); + ASSERT(agent_copy.key().id() == ids[i + 1]); + } else { + ASSERT(!trie.predictive_search(agent_copy)); + } + } + } +} + +#if __cplusplus >= 201103L +void TestPredictiveSearchAgentMove(const marisa::Trie &trie, + const marisa::Keyset &keyset) { + marisa::Agent agents[2]; + std::size_t current_agent = 0; + + const auto move_agent = [&]() { + const std::size_t other_agent = (current_agent + 1) % 2; + agents[other_agent] = std::move(agents[current_agent]); + agents[current_agent] = {}; + current_agent = other_agent; + }; + + for (std::size_t i = 0; i < keyset.size(); ++i) { + agents[current_agent].set_query(keyset[i].ptr(), keyset[i].length()); + move_agent(); + + ASSERT(trie.predictive_search(agents[current_agent])); + move_agent(); + + ASSERT(agents[current_agent].key().id() == keyset[i].id()); + + while (trie.predictive_search(agents[current_agent])) { + move_agent(); + ASSERT(agents[current_agent].key().id() > keyset[i].id()); + } + } +} +#endif // __cplusplus >= 201103L + void TestTrie(int num_tries, marisa::TailMode tail_mode, marisa::NodeOrder node_order, marisa::Keyset &keyset) { for (std::size_t i = 0; i < keyset.size(); ++i) { @@ -276,6 +344,10 @@ void TestTrie(int num_tries, marisa::TailMode tail_mode, TestLookup(trie, keyset); TestCommonPrefixSearch(trie, keyset); TestPredictiveSearch(trie, keyset); + TestPredictiveSearchAgentCopy(trie, keyset); +#if __cplusplus >= 201103L + TestPredictiveSearchAgentMove(trie, keyset); +#endif trie.save("marisa-test.dat"); From af8fc6a2c20b0a1011163427e691707aaa3234b2 Mon Sep 17 00:00:00 2001 From: Gleb Mazovetskiy Date: Fri, 25 Jun 2021 08:36:23 +0100 Subject: [PATCH 2/2] Make `Trie` moveable Adds move constructor and move assignment to `Trie`, allowing modern C++ code to use it without a pointer indirection. --- include/marisa/scoped-ptr.h | 5 +++ include/marisa/trie.h | 8 +++- lib/marisa/trie.cc | 74 +++++++++++++++++++++++-------------- 3 files changed, 59 insertions(+), 28 deletions(-) diff --git a/include/marisa/scoped-ptr.h b/include/marisa/scoped-ptr.h index 2330439..c485f08 100644 --- a/include/marisa/scoped-ptr.h +++ b/include/marisa/scoped-ptr.h @@ -43,6 +43,11 @@ class scoped_ptr { T *get() const { return ptr_; } + T *release() { + T *ptr = ptr_; + ptr_ = NULL; + return ptr; + } void clear() { scoped_ptr().swap(*this); diff --git a/include/marisa/trie.h b/include/marisa/trie.h index 30f3c68..f39808c 100644 --- a/include/marisa/trie.h +++ b/include/marisa/trie.h @@ -20,6 +20,11 @@ class Trie { Trie(); ~Trie(); +#if __cplusplus >= 201103L + Trie(Trie &&other) noexcept; + Trie &operator=(Trie &&other) noexcept; +#endif + void build(Keyset &keyset, int config_flags = 0); void mmap(const char *filename); @@ -52,7 +57,8 @@ class Trie { void swap(Trie &rhs); private: - scoped_ptr trie_; + // Cannot be `scoped_ptr` because `LoudsTrie` is forward-declared. + grimoire::trie::LoudsTrie *trie_; // Disallows copy and assignment. Trie(const Trie &); diff --git a/lib/marisa/trie.cc b/lib/marisa/trie.cc index 6805001..5268031 100644 --- a/lib/marisa/trie.cc +++ b/lib/marisa/trie.cc @@ -5,16 +5,30 @@ namespace marisa { -Trie::Trie() : trie_() {} +Trie::Trie() : trie_(NULL) {} -Trie::~Trie() {} +Trie::~Trie() { delete trie_; } + +#if __cplusplus >= 201103L +Trie::Trie(Trie &&other) noexcept : trie_(other.trie_) { + other.trie_ = NULL; +} + +Trie &Trie::operator=(Trie &&other) noexcept { + delete trie_; + trie_ = other.trie_; + other.trie_ = NULL; + return *this; +} +#endif void Trie::build(Keyset &keyset, int config_flags) { scoped_ptr temp(new (std::nothrow) grimoire::LoudsTrie); MARISA_THROW_IF(temp.get() == NULL, MARISA_MEMORY_ERROR); temp->build(keyset, config_flags); - trie_.swap(temp); + delete trie_; + trie_ = temp.release(); } void Trie::mmap(const char *filename) { @@ -26,7 +40,8 @@ void Trie::mmap(const char *filename) { grimoire::Mapper mapper; mapper.open(filename); temp->map(mapper); - trie_.swap(temp); + delete trie_; + trie_ = temp.release(); } void Trie::map(const void *ptr, std::size_t size) { @@ -38,7 +53,8 @@ void Trie::map(const void *ptr, std::size_t size) { grimoire::Mapper mapper; mapper.open(ptr, size); temp->map(mapper); - trie_.swap(temp); + delete trie_; + trie_ = temp.release(); } void Trie::load(const char *filename) { @@ -50,7 +66,8 @@ void Trie::load(const char *filename) { grimoire::Reader reader; reader.open(filename); temp->read(reader); - trie_.swap(temp); + delete trie_; + trie_ = temp.release(); } void Trie::read(int fd) { @@ -62,11 +79,12 @@ void Trie::read(int fd) { grimoire::Reader reader; reader.open(fd); temp->read(reader); - trie_.swap(temp); + delete trie_; + trie_ = temp.release(); } void Trie::save(const char *filename) const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); MARISA_THROW_IF(filename == NULL, MARISA_NULL_ERROR); grimoire::Writer writer; @@ -75,7 +93,7 @@ void Trie::save(const char *filename) const { } void Trie::write(int fd) const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); MARISA_THROW_IF(fd == -1, MARISA_CODE_ERROR); grimoire::Writer writer; @@ -84,7 +102,7 @@ void Trie::write(int fd) const { } bool Trie::lookup(Agent &agent) const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); if (!agent.has_state()) { agent.init_state(); } @@ -92,7 +110,7 @@ bool Trie::lookup(Agent &agent) const { } void Trie::reverse_lookup(Agent &agent) const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); if (!agent.has_state()) { agent.init_state(); } @@ -100,7 +118,7 @@ void Trie::reverse_lookup(Agent &agent) const { } bool Trie::common_prefix_search(Agent &agent) const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); if (!agent.has_state()) { agent.init_state(); } @@ -108,7 +126,7 @@ bool Trie::common_prefix_search(Agent &agent) const { } bool Trie::predictive_search(Agent &agent) const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); if (!agent.has_state()) { agent.init_state(); } @@ -116,47 +134,47 @@ bool Trie::predictive_search(Agent &agent) const { } std::size_t Trie::num_tries() const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); return trie_->num_tries(); } std::size_t Trie::num_keys() const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); return trie_->num_keys(); } std::size_t Trie::num_nodes() const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); return trie_->num_nodes(); } TailMode Trie::tail_mode() const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); return trie_->tail_mode(); } NodeOrder Trie::node_order() const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); return trie_->node_order(); } bool Trie::empty() const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); return trie_->empty(); } std::size_t Trie::size() const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); return trie_->size(); } std::size_t Trie::total_size() const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); return trie_->total_size(); } std::size_t Trie::io_size() const { - MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie_ == NULL, MARISA_STATE_ERROR); return trie_->io_size(); } @@ -165,7 +183,7 @@ void Trie::clear() { } void Trie::swap(Trie &rhs) { - trie_.swap(rhs.trie_); + marisa::swap(trie_, rhs.trie_); } } // namespace marisa @@ -186,11 +204,12 @@ class TrieIO { grimoire::Reader reader; reader.open(file); temp->read(reader); - trie->trie_.swap(temp); + delete trie->trie_; + trie->trie_ = temp.release(); } static void fwrite(std::FILE *file, const Trie &trie) { MARISA_THROW_IF(file == NULL, MARISA_NULL_ERROR); - MARISA_THROW_IF(trie.trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie.trie_ == NULL, MARISA_STATE_ERROR); grimoire::Writer writer; writer.open(file); trie.trie_->write(writer); @@ -206,11 +225,12 @@ class TrieIO { grimoire::Reader reader; reader.open(stream); temp->read(reader); - trie->trie_.swap(temp); + delete trie->trie_; + trie->trie_ = temp.release(); return stream; } static std::ostream &write(std::ostream &stream, const Trie &trie) { - MARISA_THROW_IF(trie.trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(trie.trie_ == NULL, MARISA_STATE_ERROR); grimoire::Writer writer; writer.open(stream); trie.trie_->write(writer);