Skip to content

Commit

Permalink
Make Agent copyable and moveable
Browse files Browse the repository at this point in the history
This allows implementing C++ iterator interface on top of the Agent more
efficiently.

C++ iterators must be copyable, and previously the only way to copy one was
to re-run the query and advance it to the same index.
  • Loading branch information
glebm committed Jun 24, 2021
1 parent 006020c commit 6758d51
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 17 deletions.
7 changes: 7 additions & 0 deletions configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ AC_PROG_INSTALL

AC_CONFIG_MACRO_DIR([m4])

# Sanitizers
# --enable-asan sets $enable_asan. The help string is m4-quoted so that
# AS_HELP_STRING is expanded exactly once, as an argument of AC_ARG_ENABLE.
AC_ARG_ENABLE([asan],
              [AS_HELP_STRING([--enable-asan], [Enable address sanitizer])])

# Macros for SSE availability check.
AC_DEFUN([MARISA_ENABLE_SSE2],
[AC_EGREP_CPP([yes], [
Expand Down Expand Up @@ -241,6 +244,10 @@ elif test "x${enable_sse2}" != "xno"; then
CXXFLAGS="$CXXFLAGS -DMARISA_USE_SSE2 -msse2"
fi

# When configure was run with --enable-asan, append the AddressSanitizer
# flag to the C++ compiler flags.
AS_IF([test "x$enable_asan" = "xyes"], [
CXXFLAGS="$CXXFLAGS -fsanitize=address"
])

AC_CONFIG_FILES([Makefile
marisa.pc
include/Makefile
Expand Down
15 changes: 11 additions & 4 deletions include/marisa/agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ class Agent {
Agent();
~Agent();

// Copy operations. Both are defined out of line; they copy the query and
// the key and, when the source has a state, allocate a copy of that state.
Agent(const Agent &other);
Agent &operator=(const Agent &other);

// Move operations are only declared when compiling as C++11 or later.
// NOTE(review): the defaulted versions move each member, so they rely on
// scoped_ptr<State> providing correct move semantics - confirm.
#if __cplusplus >= 201103L
Agent(Agent &&) noexcept = default;
Agent &operator=(Agent &&) noexcept = default;
#endif

const Query &query() const {
return query_;
}
Expand All @@ -37,6 +45,9 @@ class Agent {
void set_query(const char *str);
void set_query(const char *ptr, std::size_t length);
void set_query(std::size_t key_id);
// Replaces the current query with a copy of `query`. Only query_ is
// assigned; state_ is not touched.
// NOTE(review): the const char * overload checks state_ before installing
// the new query - confirm whether this overload is meant to skip that step.
void set_query(const Query &query) {
query_ = query;
}

const grimoire::trie::State &state() const {
return *state_;
Expand Down Expand Up @@ -76,10 +87,6 @@ class Agent {
Query query_;
Key key_;
scoped_ptr<grimoire::trie::State> state_;

// Disallows copy and assignment.
Agent(const Agent &);
Agent &operator=(const Agent &);
};

} // namespace marisa
Expand Down
3 changes: 3 additions & 0 deletions include/marisa/scoped-array.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class scoped_array {
scoped_array() : array_(NULL) {}
explicit scoped_array(T *array) : array_(array) {}

#if __cplusplus >= 201103L
// Move operations transfer ownership of the array. They are written by hand
// because a defaulted move would merely copy array_, leaving both objects
// pointing at the same allocation, which the destructor would then
// delete [] twice.
scoped_array(scoped_array &&other) noexcept : array_(other.array_) {
  other.array_ = NULL;
}

// Releases the array currently owned by *this, then takes over other's.
// Self-move is a no-op.
scoped_array &operator=(scoped_array &&other) noexcept {
  if (this != &other) {
    delete [] array_;
    array_ = other.array_;
    other.array_ = NULL;
  }
  return *this;
}
#endif

~scoped_array() {
delete [] array_;
}
Expand Down
16 changes: 16 additions & 0 deletions lib/marisa/agent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@ Agent::Agent() : query_(), key_(), state_() {}

Agent::~Agent() {}

// Copy constructor: copies the query and the key and, when `other` has a
// state, allocates an independent copy of it.
//
// Throws MARISA_MEMORY_ERROR when the state copy cannot be allocated.
//
// NOTE(review): key_ is copied as-is. If the key can refer to storage owned
// by other's state, the copy would still refer to that storage rather than
// to the freshly copied state - confirm and repoint if necessary.
Agent::Agent(const Agent &other)
    : query_(other.query_),
      key_(other.key_),
      state_(other.has_state() ? new (std::nothrow) grimoire::trie::State(other.state()) : NULL) {
  // nothrow new signals failure by returning NULL; report it instead of
  // silently producing an agent without a state.
  MARISA_THROW_IF(other.has_state() && (state_.get() == NULL),
      MARISA_MEMORY_ERROR);
}

// Copy assignment: makes *this a copy of `other`, including an independent
// copy of other's state when it has one. Returns *this.
//
// Throws MARISA_MEMORY_ERROR when the state copy cannot be allocated.
//
// NOTE(review): key_ is copied as-is. If the key can refer to storage owned
// by other's state, it would still refer to that storage rather than to the
// freshly copied state - confirm and repoint if necessary.
Agent &Agent::operator=(const Agent &other) {
  // Self-assignment would needlessly reallocate the state and destroy the
  // old one while query_ and key_ may still refer to it.
  if (this == &other) {
    return *this;
  }

  query_ = other.query_;
  key_ = other.key_;
  if (other.has_state()) {
    state_.reset(new (std::nothrow) grimoire::trie::State(other.state()));
    // nothrow new signals failure by returning NULL; report it instead of
    // silently dropping the state.
    MARISA_THROW_IF(state_.get() == NULL, MARISA_MEMORY_ERROR);
  } else {
    state_.clear();
  }
  return *this;
}

void Agent::set_query(const char *str) {
MARISA_THROW_IF(str == NULL, MARISA_NULL_ERROR);
if (state_.get() != NULL) {
Expand Down
12 changes: 8 additions & 4 deletions lib/marisa/grimoire/trie/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ class State {
: key_buf_(), history_(), node_id_(0), query_pos_(0),
history_pos_(0), status_code_(MARISA_READY_TO_ALL) {}

#if __cplusplus >= 201103L
// Declaring the move operations suppresses the implicitly generated copy
// operations, so all four are defaulted together. `= default` is itself a
// C++11 feature, so the copy operations belong inside the guard as well;
// older compilers skip this block and fall back to the implicit copy
// constructor and copy assignment.
State(const State &) = default;
State &operator=(const State &) = default;
State(State &&) noexcept = default;
State &operator=(State &&) noexcept = default;
#endif

void set_node_id(std::size_t node_id) {
MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
node_id_ = (UInt32)node_id;
Expand Down Expand Up @@ -104,10 +112,6 @@ class State {
UInt32 query_pos_;
UInt32 history_pos_;
StatusCode status_code_;

// Disallows copy and assignment.
State(const State &);
State &operator=(const State &);
};

} // namespace trie
Expand Down
50 changes: 41 additions & 9 deletions lib/marisa/grimoire/vector/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,38 @@ class Vector {
}
}

Vector(const Vector<T> &other)
: buf_(), objs_(NULL), const_objs_(NULL),
size_(0), capacity_(0), fixed_(other.fixed_) {
if (other.buf_.get() == NULL) {
objs_ = other.objs_;
const_objs_ = other.const_objs_;
size_ = other.size_;
capacity_ = other.capacity_;
} else {
copy(other.const_objs_, other.size_, other.capacity_);
}
}

// Copy assignment. Discards the current contents and then mirrors the copy
// constructor: elements are deep-copied when `other` owns a buffer, and the
// pointers are shared otherwise. Returns *this.
Vector &operator=(const Vector<T> &other) {
  // Without this guard, clear() would empty the vector before it is read
  // back as the source, so `v = v` would lose the contents.
  if (this == &other) {
    return *this;
  }

  clear();
  fixed_ = other.fixed_;
  if (other.buf_.get() == NULL) {
    objs_ = other.objs_;
    const_objs_ = other.const_objs_;
    size_ = other.size_;
    capacity_ = other.capacity_;
  } else {
    copy(other.const_objs_, other.size_, other.capacity_);
  }
  return *this;
}

#if __cplusplus >= 201103L
// Move constructor: takes over other's buffer and bookkeeping and leaves
// `other` empty. Written by hand because a defaulted move would copy the
// raw pointers and the size, leaving both vectors referring to (and later
// cleaning up) the same elements.
Vector(Vector<T> &&other) noexcept
    : buf_(), objs_(other.objs_), const_objs_(other.const_objs_),
      size_(other.size_), capacity_(other.capacity_), fixed_(other.fixed_) {
  // buf_ starts out empty, so the swap hands other's buffer to *this and
  // leaves other without one.
  buf_.swap(other.buf_);
  other.objs_ = NULL;
  other.const_objs_ = NULL;
  other.size_ = 0;
  other.capacity_ = 0;
  other.fixed_ = false;
}

// Move assignment: discards the current contents, then takes over other's
// buffer and bookkeeping and leaves `other` without elements. Self-move is
// a no-op. Returns *this.
Vector &operator=(Vector<T> &&other) noexcept {
  if (this != &other) {
    clear();
    // Whatever buffer *this still holds after clear() goes to `other`,
    // which releases it when it is destroyed.
    buf_.swap(other.buf_);
    objs_ = other.objs_;
    const_objs_ = other.const_objs_;
    size_ = other.size_;
    capacity_ = other.capacity_;
    fixed_ = other.fixed_;
    other.objs_ = NULL;
    other.const_objs_ = NULL;
    other.size_ = 0;
    other.capacity_ = 0;
    other.fixed_ = false;
  }
  return *this;
}
#endif

void map(Mapper &mapper) {
Vector temp;
temp.map_(mapper);
Expand Down Expand Up @@ -225,14 +257,17 @@ class Vector {
// Moves the current elements into a new buffer with room for new_capacity
// elements by delegating to copy() with the vector's own contents as the
// source. The size is unchanged; only the capacity differs afterwards.
// In debug builds, a new_capacity above max_size() is reported as
// MARISA_SIZE_ERROR.
// realloc() assumes that T's placement new does not throw an exception.
void realloc(std::size_t new_capacity) {
MARISA_DEBUG_IF(new_capacity > max_size(), MARISA_SIZE_ERROR);
copy(objs_, size_, new_capacity);
}

scoped_array<char> new_buf(
new (std::nothrow) char[sizeof(T) * new_capacity]);
// copy() assumes that T's placement new does not throw an exception.
void copy(const T *src, std::size_t src_size, std::size_t capacity) {
scoped_array<char> new_buf(new (std::nothrow) char[sizeof(T) * capacity]);
MARISA_DEBUG_IF(new_buf.get() == NULL, MARISA_MEMORY_ERROR);
T *new_objs = reinterpret_cast<T *>(new_buf.get());

for (std::size_t i = 0; i < size_; ++i) {
new (&new_objs[i]) T(objs_[i]);
for (std::size_t i = 0; i < src_size; ++i) {
new (&new_objs[i]) T(src[i]);
}
for (std::size_t i = 0; i < size_; ++i) {
objs_[i].~T();
Expand All @@ -241,12 +276,9 @@ class Vector {
buf_.swap(new_buf);
objs_ = new_objs;
const_objs_ = new_objs;
capacity_ = new_capacity;
size_ = src_size;
capacity_ = capacity;
}

// Disallows copy and assignment.
Vector(const Vector &);
Vector &operator=(const Vector &);
};

} // namespace vector
Expand Down
39 changes: 39 additions & 0 deletions tests/marisa-test.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <sstream>
#include <utility>
#include <vector>

#include <marisa.h>

Expand Down Expand Up @@ -258,6 +261,41 @@ void TestPredictiveSearch(const marisa::Trie &trie,
}
}

// Checks that an Agent copied in the middle of a predictive search can be
// used to continue that search.
//
// For every key in `keyset`, the search is run to completion on `agent`,
// and after each hit a copy of the agent (copy constructor) and the hit's
// key id are recorded. Each recorded copy is then copy-assigned into a
// fresh agent, which must report the recorded id and, when searched again,
// produce the id recorded for the next hit, or no hit at all for the last
// copy.
void TestPredictiveSearchAgentCopy(const marisa::Trie &trie,
    const marisa::Keyset &keyset) {
  marisa::Agent agent;
  for (std::size_t i = 0; i < keyset.size(); ++i) {
    agent.set_query(keyset[i].ptr(), keyset[i].length());
    ASSERT(trie.predictive_search(agent));
    ASSERT(agent.key().id() == keyset[i].id());

    std::vector<marisa::Agent> agent_copies;
    std::vector<std::size_t> ids;
    while (trie.predictive_search(agent)) {
      ASSERT(agent.key().id() > keyset[i].id());
      ids.push_back(agent.key().id());

      // Tests copy constructor.
      agent_copies.push_back(agent);
    }

    // `j` indexes the recorded copies; a distinct name keeps it from
    // shadowing the keyset index `i` of the enclosing loop.
    for (std::size_t j = 0; j < agent_copies.size(); ++j) {
      marisa::Agent agent_copy;

      // Tests copy assignment.
      agent_copy = agent_copies[j];

      ASSERT(agent_copy.key().id() == ids[j]);
      if (j + 1 < agent_copies.size()) {
        ASSERT(trie.predictive_search(agent_copy));
        ASSERT(agent_copy.key().id() == ids[j + 1]);
      } else {
        ASSERT(!trie.predictive_search(agent_copy));
      }
    }
  }
}

void TestTrie(int num_tries, marisa::TailMode tail_mode,
marisa::NodeOrder node_order, marisa::Keyset &keyset) {
for (std::size_t i = 0; i < keyset.size(); ++i) {
Expand All @@ -276,6 +314,7 @@ void TestTrie(int num_tries, marisa::TailMode tail_mode,
TestLookup(trie, keyset);
TestCommonPrefixSearch(trie, keyset);
TestPredictiveSearch(trie, keyset);
TestPredictiveSearchAgentCopy(trie, keyset);

trie.save("marisa-test.dat");

Expand Down

0 comments on commit 6758d51

Please sign in to comment.