# Patch: make marisa::Agent copyable.
#
# Everything ahead of the first "diff --git" header is descriptive text only.
#
# configure.ac
#   Adds an --enable-asan option; when it is "yes", -fsanitize=address is
#   appended to CXXFLAGS.
# include/marisa/agent.h, lib/marisa/agent.cc
#   Declares and defines Agent's copy constructor and copy assignment, and
#   removes the declarations that sat under "Disallows copy and assignment".
#   query_ and key_ are copied member-wise; state_ receives a newly allocated
#   State copy-constructed from other.state() when other.has_state(), and is
#   NULL / cleared otherwise. Move operations are defaulted when
#   __cplusplus >= 201103L. Also adds a set_query(const Query &) overload that
#   assigns query_.
# include/marisa/scoped-array.h
#   Adds move operations under the __cplusplus >= 201103L guard. They hand over
#   array_ and set the source's array_ to NULL: ~scoped_array() runs
#   delete [] array_, so a member-wise ("= default") move would leave two
#   objects deleting the same array.
# lib/marisa/grimoire/trie/state.h
#   Replaces State's "disallowed" copy declarations with defaulted copy
#   operations, plus defaulted moves under the C++11 guard.
# lib/marisa/grimoire/vector/vector.h
#   Adds Vector copy operations. When other.buf_ is NULL the pointers, size and
#   capacity are copied as they are; otherwise the elements are duplicated
#   through the new copy() helper. Copy assignment returns early on
#   self-assignment because it calls clear() before it reads `other`.
#   realloc() now forwards to copy(objs_, size_, new_capacity), and copy() also
#   sets size_. Move operations are defaulted under the C++11 guard.
# tests/marisa-test.cc
#   Adds TestPredictiveSearchAgentCopy and calls it from TestTrie. It
#   copy-constructs a snapshot of the agent after each predictive_search()
#   step, then copy-assigns every snapshot to a fresh Agent and checks that it
#   reports the recorded id and that the next search yields the following id
#   (or fails for the last snapshot).
#
# NOTE(review): line breaks are collapsed and text inside angle brackets
#   (#include targets, template arguments after scoped_ptr / scoped_array /
#   reinterpret_cast / std::vector) appears to be missing in this copy,
#   presumably extraction damage. Restore from the original commit before
#   applying.
# NOTE(review): the blob ids on the "index" lines were not regenerated for
#   scoped-array.h, vector.h and marisa-test.cc; regenerate them if they
#   matter (e.g. for a 3-way apply).
# NOTE(review): Vector's move operations are still "= default", so they copy
#   objs_, const_objs_, size_ and capacity_ without resetting the source.
#   Whether a moved-from Vector is then safe to destroy depends on ~Vector(),
#   which is not part of this patch. Confirm.
# NOTE(review): Agent's copy operations copy key_ member-wise. If key_ can
#   refer to storage owned by the State (not visible here), a copy would keep
#   referring to the source agent's State. Confirm.
# NOTE(review): the new (std::nothrow) results in Agent's copy operations are
#   not checked, so a failed allocation silently yields an agent without
#   state.
# NOTE(review): Agent's defaulted moves rely on scoped_ptr being movable, and
#   scoped_ptr is not touched by this patch. Confirm.
diff --git a/configure.ac b/configure.ac index 118a677..c7adec0 100644 --- a/configure.ac +++ b/configure.ac @@ -13,6 +13,9 @@ AC_PROG_INSTALL AC_CONFIG_MACRO_DIR([m4]) +# Sanitizers +AC_ARG_ENABLE([asan], AS_HELP_STRING([--enable-asan], [Enable address sanitizer])) + # Macros for SSE availability check. AC_DEFUN([MARISA_ENABLE_SSE2], [AC_EGREP_CPP([yes], [ @@ -241,6 +244,10 @@ elif test "x${enable_sse2}" != "xno"; then CXXFLAGS="$CXXFLAGS -DMARISA_USE_SSE2 -msse2" fi +AS_IF([test "x$enable_asan" = "xyes"], [ + CXXFLAGS="$CXXFLAGS -fsanitize=address" +]) + AC_CONFIG_FILES([Makefile marisa.pc include/Makefile diff --git a/include/marisa/agent.h b/include/marisa/agent.h index b549d36..1d3f913 100644 --- a/include/marisa/agent.h +++ b/include/marisa/agent.h @@ -22,6 +22,14 @@ class Agent { Agent(); ~Agent(); + Agent(const Agent &other); + Agent &operator=(const Agent &other); + +#if __cplusplus >= 201103L + Agent(Agent &&) noexcept = default; + Agent &operator=(Agent &&) noexcept = default; +#endif + const Query &query() const { return query_; } @@ -37,6 +45,9 @@ class Agent { void set_query(const char *str); void set_query(const char *ptr, std::size_t length); void set_query(std::size_t key_id); + void set_query(const Query &query) { + query_ = query; + } const grimoire::trie::State &state() const { return *state_; @@ -76,10 +87,6 @@ class Agent { Query query_; Key key_; scoped_ptr state_; - - // Disallows copy and assignment. 
- Agent(const Agent &); - Agent &operator=(const Agent &); }; } // namespace marisa diff --git a/include/marisa/scoped-array.h b/include/marisa/scoped-array.h index 34cefa4..b7d92be 100644 --- a/include/marisa/scoped-array.h +++ b/include/marisa/scoped-array.h @@ -11,6 +11,22 @@ class scoped_array { scoped_array() : array_(NULL) {} explicit scoped_array(T *array) : array_(array) {} +#if __cplusplus >= 201103L + // Moving transfers ownership and leaves the source empty, so the array is + // deleted exactly once. + scoped_array(scoped_array &&other) noexcept : array_(other.array_) { + other.array_ = NULL; + } + scoped_array &operator=(scoped_array &&other) noexcept { + if (this != &other) { + delete [] array_; + array_ = other.array_; + other.array_ = NULL; + } + return *this; + } +#endif + ~scoped_array() { delete [] array_; } diff --git a/lib/marisa/agent.cc b/lib/marisa/agent.cc index 7fa7cb1..5808f6c 100644 --- a/lib/marisa/agent.cc +++ b/lib/marisa/agent.cc @@ -9,6 +9,22 @@ Agent::Agent() : query_(), key_(), state_() {} Agent::~Agent() {} +Agent::Agent(const Agent &other) + : query_(other.query_), + key_(other.key_), + state_(other.has_state() ? new (std::nothrow) grimoire::trie::State(other.state()) : NULL) {} + +Agent &Agent::operator=(const Agent &other) { + query_ = other.query_; + key_ = other.key_; + if (other.has_state()) { + state_.reset(new (std::nothrow) grimoire::trie::State(other.state())); + } else { + state_.clear(); + } + return *this; +} + void Agent::set_query(const char *str) { MARISA_THROW_IF(str == NULL, MARISA_NULL_ERROR); if (state_.get() != NULL) { diff --git a/lib/marisa/grimoire/trie/state.h b/lib/marisa/grimoire/trie/state.h index df605a6..07bda52 100644 --- a/lib/marisa/grimoire/trie/state.h +++ b/lib/marisa/grimoire/trie/state.h @@ -24,6 +24,14 @@ class State { : key_buf_(), history_(), node_id_(0), query_pos_(0), history_pos_(0), status_code_(MARISA_READY_TO_ALL) {} + State(const State &) = default; + State &operator=(const State &) = default; + +#if __cplusplus >= 201103L + State(State &&) noexcept = default; + State &operator=(State &&) noexcept = default; +#endif + void set_node_id(std::size_t node_id) { MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); node_id_ = (UInt32)node_id; @@ -104,10 +112,6 @@ class 
State { UInt32 query_pos_; UInt32 history_pos_; StatusCode status_code_; - - // Disallows copy and assignment. - State(const State &); - State &operator=(const State &); }; } // namespace trie diff --git a/lib/marisa/grimoire/vector/vector.h b/lib/marisa/grimoire/vector/vector.h index 2bfccdb..6751f81 100644 --- a/lib/marisa/grimoire/vector/vector.h +++ b/lib/marisa/grimoire/vector/vector.h @@ -23,6 +23,42 @@ class Vector { } } + Vector(const Vector &other) + : buf_(), objs_(NULL), const_objs_(NULL), + size_(0), capacity_(0), fixed_(other.fixed_) { + if (other.buf_.get() == NULL) { + objs_ = other.objs_; + const_objs_ = other.const_objs_; + size_ = other.size_; + capacity_ = other.capacity_; + } else { + copy(other.const_objs_, other.size_, other.capacity_); + } + } + + Vector &operator=(const Vector &other) { + // clear() below would discard the data a self-assignment still needs. + if (this == &other) { + return *this; + } + clear(); + fixed_ = other.fixed_; + if (other.buf_.get() == NULL) { + objs_ = other.objs_; + const_objs_ = other.const_objs_; + size_ = other.size_; + capacity_ = other.capacity_; + } else { + copy(other.const_objs_, other.size_, other.capacity_); + } + return *this; + } + +#if __cplusplus >= 201103L + Vector(Vector &&) noexcept = default; + Vector &operator=(Vector &&) noexcept = default; +#endif + void map(Mapper &mapper) { Vector temp; temp.map_(mapper); @@ -225,14 +261,17 @@ class Vector { // realloc() assumes that T's placement new does not throw an exception. void realloc(std::size_t new_capacity) { MARISA_DEBUG_IF(new_capacity > max_size(), MARISA_SIZE_ERROR); + copy(objs_, size_, new_capacity); + } - scoped_array new_buf( - new (std::nothrow) char[sizeof(T) * new_capacity]); + // copy() assumes that T's placement new does not throw an exception. 
+ void copy(const T *src, std::size_t src_size, std::size_t capacity) { + scoped_array new_buf(new (std::nothrow) char[sizeof(T) * capacity]); MARISA_DEBUG_IF(new_buf.get() == NULL, MARISA_MEMORY_ERROR); T *new_objs = reinterpret_cast(new_buf.get()); - for (std::size_t i = 0; i < size_; ++i) { - new (&new_objs[i]) T(objs_[i]); + for (std::size_t i = 0; i < src_size; ++i) { + new (&new_objs[i]) T(src[i]); } for (std::size_t i = 0; i < size_; ++i) { objs_[i].~T(); @@ -241,12 +280,9 @@ class Vector { buf_.swap(new_buf); objs_ = new_objs; const_objs_ = new_objs; - capacity_ = new_capacity; + size_ = src_size; + capacity_ = capacity; } - - // Disallows copy and assignment. - Vector(const Vector &); - Vector &operator=(const Vector &); }; } // namespace vector diff --git a/tests/marisa-test.cc b/tests/marisa-test.cc index 36e4258..4ab4e7b 100644 --- a/tests/marisa-test.cc +++ b/tests/marisa-test.cc @@ -1,7 +1,10 @@ +#include #include #include #include #include +#include +#include #include @@ -258,6 +261,41 @@ void TestPredictiveSearch(const marisa::Trie &trie, } } +void TestPredictiveSearchAgentCopy(const marisa::Trie &trie, + const marisa::Keyset &keyset) { + marisa::Agent agent; + for (std::size_t i = 0; i < keyset.size(); ++i) { + agent.set_query(keyset[i].ptr(), keyset[i].length()); + ASSERT(trie.predictive_search(agent)); + ASSERT(agent.key().id() == keyset[i].id()); + + std::vector agent_copies; + std::vector ids; + while (trie.predictive_search(agent)) { + ASSERT(agent.key().id() > keyset[i].id()); + ids.push_back(agent.key().id()); + + // Tests copy constructor. + agent_copies.push_back(agent); + } + + for (std::size_t j = 0; j < agent_copies.size(); ++j) { + marisa::Agent agent_copy; + + // Tests copy assignment. 
+ agent_copy = agent_copies[j]; + + ASSERT(agent_copy.key().id() == ids[j]); + if (j + 1 < agent_copies.size()) { + ASSERT(trie.predictive_search(agent_copy)); + ASSERT(agent_copy.key().id() == ids[j + 1]); + } else { + ASSERT(!trie.predictive_search(agent_copy)); + } + } + } +} + void TestTrie(int num_tries, marisa::TailMode tail_mode, marisa::NodeOrder node_order, marisa::Keyset &keyset) { for (std::size_t i = 0; i < keyset.size(); ++i) { @@ -276,6 +314,7 @@ void TestTrie(int num_tries, marisa::TailMode tail_mode, TestLookup(trie, keyset); TestCommonPrefixSearch(trie, keyset); TestPredictiveSearch(trie, keyset); + TestPredictiveSearchAgentCopy(trie, keyset); trie.save("marisa-test.dat");