Skip to content

Commit

Permalink
Make Agent copyable and moveable
Browse files Browse the repository at this point in the history
This allows implementing a C++ iterator interface on top of the Agent more
efficiently.

C++ iterators must be copyable, and previously the only way to copy one was
to repeat the query and advance it to the same index.
  • Loading branch information
glebm committed Jun 24, 2021
1 parent 006020c commit 8f330fb
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 30 deletions.
7 changes: 7 additions & 0 deletions configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ AC_PROG_INSTALL

AC_CONFIG_MACRO_DIR([m4])

# Sanitizers
AC_ARG_ENABLE([asan], AS_HELP_STRING([--enable-asan], [Enable address sanitizer]))

# Macros for SSE availability check.
AC_DEFUN([MARISA_ENABLE_SSE2],
[AC_EGREP_CPP([yes], [
Expand Down Expand Up @@ -241,6 +244,10 @@ elif test "x${enable_sse2}" != "xno"; then
CXXFLAGS="$CXXFLAGS -DMARISA_USE_SSE2 -msse2"
fi

# When configure is run with --enable-asan, append the AddressSanitizer flag
# to CXXFLAGS.
AS_IF([test "x$enable_asan" = "xyes"], [
CXXFLAGS="$CXXFLAGS -fsanitize=address"
])

AC_CONFIG_FILES([Makefile
marisa.pc
include/Makefile
Expand Down
21 changes: 16 additions & 5 deletions include/marisa/agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ class Agent {
Agent();
~Agent();

// Copy operations. The definitions in agent.cc copy the query and key and
// allocate a separate State for the copy when the source has one.
Agent(const Agent &other);
Agent &operator=(const Agent &other);

#if __cplusplus >= 201103L
// Move operations (C++11 and later). The definitions in agent.cc take over
// the source's State pointer and leave the source without a state.
Agent(Agent &&other) noexcept;
Agent &operator=(Agent &&other) noexcept;
#endif

const Query &query() const {
return query_;
}
Expand All @@ -37,6 +45,9 @@ class Agent {
void set_query(const char *str);
void set_query(const char *ptr, std::size_t length);
void set_query(std::size_t key_id);
// Replaces the stored query with a copy of `query`.
// NOTE(review): unlike the other set_query() overloads, which are defined in
// agent.cc and act on an existing state before storing the query, this
// overload leaves state_ untouched -- confirm that is intended.
void set_query(const Query &query) {
query_ = query;
}

const grimoire::trie::State &state() const {
return *state_;
Expand Down Expand Up @@ -65,7 +76,7 @@ class Agent {
}

bool has_state() const {
return state_.get() != NULL;
return state_ != NULL;
}
void init_state();

Expand All @@ -75,11 +86,11 @@ class Agent {
private:
Query query_;
Key key_;
scoped_ptr<grimoire::trie::State> state_;

// Disallows copy and assignment.
Agent(const Agent &);
Agent &operator=(const Agent &);
// Cannot be `scoped_ptr` because `State` is forward-declared.
grimoire::trie::State *state_;

void clear_state();
};

} // namespace marisa
Expand Down
10 changes: 10 additions & 0 deletions include/marisa/scoped-array.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ class scoped_array {
scoped_array() : array_(NULL) {}
explicit scoped_array(T *array) : array_(array) {}

#if __cplusplus >= 201103L
// Move constructor: takes over the array owned by `other` and leaves `other`
// holding NULL.
scoped_array(scoped_array &&other) noexcept : array_(other.array_) {
  other.array_ = NULL;
}
// Move assignment: frees the array currently owned by *this, then takes over
// the array owned by `other` and leaves `other` holding NULL.
// Self-move is a no-op.
scoped_array<T> &operator=(scoped_array<T> &&other) noexcept {
  if (this != &other) {
    delete [] array_;
    array_ = other.array_;
    other.array_ = NULL;
  }
  return *this;
}
#endif

~scoped_array() {
delete [] array_;
}
Expand Down
10 changes: 10 additions & 0 deletions include/marisa/scoped-ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ class scoped_ptr {
delete ptr_;
}

#if __cplusplus >= 201103L
// Move constructor: takes over the object owned by `other` and leaves `other`
// holding NULL.
scoped_ptr(scoped_ptr &&other) noexcept : ptr_(other.ptr_) {
  other.ptr_ = NULL;
}
// Move assignment: deletes the object currently owned by *this, then takes
// over the object owned by `other` and leaves `other` holding NULL.
// Self-move is a no-op.
scoped_ptr<T> &operator=(scoped_ptr<T> &&other) noexcept {
  if (this != &other) {
    delete ptr_;
    ptr_ = other.ptr_;
    other.ptr_ = NULL;
  }
  return *this;
}
#endif

void reset(T *ptr = NULL) {
MARISA_DEBUG_IF((ptr != NULL) && (ptr == ptr_), MARISA_RESET_ERROR);
scoped_ptr(ptr).swap(*this);
Expand Down
66 changes: 54 additions & 12 deletions lib/marisa/agent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,89 @@

namespace marisa {

Agent::Agent() : query_(), key_(), state_() {}
Agent::Agent() : query_(), key_(), state_(NULL) {}

Agent::~Agent() {}
Agent::~Agent() {
delete state_;
}

// Copy constructor: copies the query and key and, when `other` has a state,
// allocates an independent State copy for the new agent.
// Throws MARISA_MEMORY_ERROR if that allocation fails, matching init_state(),
// instead of silently producing a copy without a state.
// NOTE(review): key_ is copied member-wise; if Key refers to memory held by
// the source's State, the copy would still refer to the source's buffer --
// confirm against Key/State.
Agent::Agent(const Agent &other)
    : query_(other.query_),
      key_(other.key_),
      state_(other.has_state()
          ? new (std::nothrow) grimoire::trie::State(other.state())
          : NULL) {
  MARISA_THROW_IF(other.has_state() && (state_ == NULL), MARISA_MEMORY_ERROR);
}

// Copy assignment: makes *this a copy of `other`, with its own State when
// `other` has one.
// Self-assignment is a no-op; without the guard the state would be deleted
// and then copied from freed memory.
// The new State is allocated before anything in *this is modified, so a
// failed allocation throws MARISA_MEMORY_ERROR and leaves *this unchanged.
// NOTE(review): key_ is copied member-wise; if Key refers to memory held by
// the source's State, the copy would still refer to the source's buffer --
// confirm against Key/State.
Agent &Agent::operator=(const Agent &other) {
  if (this == &other) {
    return *this;
  }

  grimoire::trie::State *new_state = NULL;
  if (other.has_state()) {
    new_state = new (std::nothrow) grimoire::trie::State(other.state());
    MARISA_THROW_IF(new_state == NULL, MARISA_MEMORY_ERROR);
  }

  query_ = other.query_;
  key_ = other.key_;
  delete state_;
  state_ = new_state;
  return *this;
}

#if __cplusplus >= 201103L
// Move constructor: copies the query and key and takes over the State
// pointer of `other`, which is left without a state.
Agent::Agent(Agent &&other) noexcept
    : query_(other.query_),
      key_(other.key_),
      state_(other.state_) {
  // The State now belongs to *this; make sure `other` will not delete it.
  other.state_ = NULL;
}

// Move assignment: copies the query and key, deletes the State currently
// owned by *this and takes over the State pointer of `other`, which is left
// without a state.
// Self-move is a no-op; without the guard it would delete the agent's own
// state and leave it with none.
Agent &Agent::operator=(Agent &&other) noexcept {
  if (this != &other) {
    query_ = other.query_;
    key_ = other.key_;
    delete state_;
    state_ = other.state_;
    other.state_ = NULL;
  }
  return *this;
}
#endif

void Agent::set_query(const char *str) {
MARISA_THROW_IF(str == NULL, MARISA_NULL_ERROR);
if (state_.get() != NULL) {
state_->reset();
if (state_ != NULL) {
clear_state();
}
query_.set_str(str);
}

void Agent::set_query(const char *ptr, std::size_t length) {
MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR);
if (state_.get() != NULL) {
state_->reset();
if (state_ != NULL) {
clear_state();
}
query_.set_str(ptr, length);
}

void Agent::set_query(std::size_t key_id) {
if (state_.get() != NULL) {
state_->reset();
if (state_ != NULL) {
clear_state();
}
query_.set_id(key_id);
}

void Agent::init_state() {
MARISA_THROW_IF(state_.get() != NULL, MARISA_STATE_ERROR);
state_.reset(new (std::nothrow) grimoire::State);
MARISA_THROW_IF(state_.get() == NULL, MARISA_MEMORY_ERROR);
MARISA_THROW_IF(state_ != NULL, MARISA_STATE_ERROR);
delete state_;
state_ = new (std::nothrow) grimoire::State;
MARISA_THROW_IF(state_ == NULL, MARISA_MEMORY_ERROR);
}

void Agent::clear() {
Agent().swap(*this);
}


// Deletes the search state, if any, and leaves the agent without one.
// NULL is used rather than nullptr because the rest of this file uses NULL
// and guards C++11-only features behind `__cplusplus >= 201103L`, so this
// function must also compile as pre-C++11.
void Agent::clear_state() {
  delete state_;
  state_ = NULL;
}

void Agent::swap(Agent &rhs) {
query_.swap(rhs.query_);
key_.swap(rhs.key_);
state_.swap(rhs.state_);
marisa::swap(state_, rhs.state_);
}

} // namespace marisa
12 changes: 8 additions & 4 deletions lib/marisa/grimoire/trie/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ class State {
: key_buf_(), history_(), node_id_(0), query_pos_(0),
history_pos_(0), status_code_(MARISA_READY_TO_ALL) {}

// Copy and move operations are compiler-generated. Explicit `= default`
// declarations are C++11 syntax, so all of them are spelled out only for
// C++11 and later; under older standards nothing is declared here and the
// compiler supplies the implicit copy operations.
#if __cplusplus >= 201103L
State(const State &) = default;
State &operator=(const State &) = default;

State(State &&) noexcept = default;
State &operator=(State &&) noexcept = default;
#endif

void set_node_id(std::size_t node_id) {
MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
node_id_ = (UInt32)node_id;
Expand Down Expand Up @@ -104,10 +112,6 @@ class State {
UInt32 query_pos_;
UInt32 history_pos_;
StatusCode status_code_;

// Disallows copy and assignment.
State(const State &);
State &operator=(const State &);
};

} // namespace trie
Expand Down
50 changes: 41 additions & 9 deletions lib/marisa/grimoire/vector/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,38 @@ class Vector {
}
}

Vector(const Vector<T> &other)
: buf_(), objs_(NULL), const_objs_(NULL),
size_(0), capacity_(0), fixed_(other.fixed_) {
if (other.buf_.get() == NULL) {
objs_ = other.objs_;
const_objs_ = other.const_objs_;
size_ = other.size_;
capacity_ = other.capacity_;
} else {
copy(other.const_objs_, other.size_, other.capacity_);
}
}

// Copy assignment. Clears *this and then mirrors the copy constructor: if
// `other` owns a buffer (buf_ is non-NULL) its elements are copied into a
// newly allocated buffer via copy(); otherwise the pointers, size and
// capacity of `other` are taken over as they are.
// Self-assignment is a no-op; without the guard, clear() would empty the
// vector before it is read back as the source, losing its contents.
Vector &operator=(const Vector<T> &other) {
  if (this == &other) {
    return *this;
  }
  clear();
  fixed_ = other.fixed_;
  if (other.buf_.get() == NULL) {
    objs_ = other.objs_;
    const_objs_ = other.const_objs_;
    size_ = other.size_;
    capacity_ = other.capacity_;
  } else {
    copy(other.const_objs_, other.size_, other.capacity_);
  }
  return *this;
}

#if __cplusplus >= 201103L
// Defaulted member-wise move operations (C++11 and later).
// NOTE(review): a defaulted move copies the built-in members, so the source
// keeps its objs_, const_objs_, size_ and capacity_ values, while buf_ goes
// through scoped_array's own move operations. Confirm that a moved-from
// Vector is safe to destroy and reuse in that state.
Vector(Vector &&) noexcept = default;
Vector &operator=(Vector<T> &&) noexcept = default;
#endif

void map(Mapper &mapper) {
Vector temp;
temp.map_(mapper);
Expand Down Expand Up @@ -225,14 +257,17 @@ class Vector {
// realloc() assumes that T's placement new does not throw an exception.
void realloc(std::size_t new_capacity) {
MARISA_DEBUG_IF(new_capacity > max_size(), MARISA_SIZE_ERROR);
copy(objs_, size_, new_capacity);
}

scoped_array<char> new_buf(
new (std::nothrow) char[sizeof(T) * new_capacity]);
// copy() assumes that T's placement new does not throw an exception.
void copy(const T *src, std::size_t src_size, std::size_t capacity) {
scoped_array<char> new_buf(new (std::nothrow) char[sizeof(T) * capacity]);
MARISA_DEBUG_IF(new_buf.get() == NULL, MARISA_MEMORY_ERROR);
T *new_objs = reinterpret_cast<T *>(new_buf.get());

for (std::size_t i = 0; i < size_; ++i) {
new (&new_objs[i]) T(objs_[i]);
for (std::size_t i = 0; i < src_size; ++i) {
new (&new_objs[i]) T(src[i]);
}
for (std::size_t i = 0; i < size_; ++i) {
objs_[i].~T();
Expand All @@ -241,12 +276,9 @@ class Vector {
buf_.swap(new_buf);
objs_ = new_objs;
const_objs_ = new_objs;
capacity_ = new_capacity;
size_ = src_size;
capacity_ = capacity;
}

// Disallows copy and assignment.
Vector(const Vector &);
Vector &operator=(const Vector &);
};

} // namespace vector
Expand Down
39 changes: 39 additions & 0 deletions tests/marisa-test.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <sstream>
#include <utility>
#include <vector>

#include <marisa.h>

Expand Down Expand Up @@ -258,6 +261,41 @@ void TestPredictiveSearch(const marisa::Trie &trie,
}
}

// Checks that copies of an Agent taken in the middle of a predictive search
// carry the search position with them.
//
// For every key in `keyset`, the first predictive_search() result must be
// the key itself. Each further result is recorded in `ids`, and a copy of
// the agent at that point is stored in `agent_copies`. Afterwards each stored
// copy is copy-assigned into a fresh Agent, which must report the recorded
// id and, when searched further, must produce the next recorded id (or no
// result for the last copy).
void TestPredictiveSearchAgentCopy(const marisa::Trie &trie,
const marisa::Keyset &keyset) {
marisa::Agent agent;
for (std::size_t i = 0; i < keyset.size(); ++i) {
agent.set_query(keyset[i].ptr(), keyset[i].length());
ASSERT(trie.predictive_search(agent));
ASSERT(agent.key().id() == keyset[i].id());

// agent_copies[k] is the agent as it was right after producing ids[k].
std::vector<marisa::Agent> agent_copies;
std::vector<std::size_t> ids;
while (trie.predictive_search(agent)) {
ASSERT(agent.key().id() > keyset[i].id());
ids.push_back(agent.key().id());

// Tests copy constructor.
agent_copies.push_back(agent);
}

// NOTE: this `i` shadows the outer loop index; inside this loop it indexes
// agent_copies and ids, not keyset.
for (std::size_t i = 0; i < agent_copies.size(); ++i) {
marisa::Agent agent_copy;

// Tests copy assignment.
agent_copy = agent_copies[i];

ASSERT(agent_copy.key().id() == ids[i]);
if (i + 1 < agent_copies.size()) {
// Resuming from the copy must yield the result that followed it.
ASSERT(trie.predictive_search(agent_copy));
ASSERT(agent_copy.key().id() == ids[i + 1]);
} else {
// The last copy was taken at the final result, so nothing may follow.
ASSERT(!trie.predictive_search(agent_copy));
}
}
}
}

void TestTrie(int num_tries, marisa::TailMode tail_mode,
marisa::NodeOrder node_order, marisa::Keyset &keyset) {
for (std::size_t i = 0; i < keyset.size(); ++i) {
Expand All @@ -276,6 +314,7 @@ void TestTrie(int num_tries, marisa::TailMode tail_mode,
TestLookup(trie, keyset);
TestCommonPrefixSearch(trie, keyset);
TestPredictiveSearch(trie, keyset);
TestPredictiveSearchAgentCopy(trie, keyset);

trie.save("marisa-test.dat");

Expand Down

0 comments on commit 8f330fb

Please sign in to comment.