Skip to content

Commit

Permalink
refactor: dispatcher logic for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
luan committed Oct 13, 2023
1 parent 4d10010 commit d17bf26
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 85 deletions.
108 changes: 61 additions & 47 deletions src/game/scheduling/dispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "utils/tools.hpp"

constexpr static auto ASYNC_TIME_OUT = std::chrono::seconds(15);
constexpr static auto SLEEP_TIME_MS = 15;
static std::mutex dummyMutex; // This is only used for signaling the condition variable and not as an actual lock.

Dispatcher &Dispatcher::getInstance() {
return inject<Dispatcher>();
Expand All @@ -25,51 +25,60 @@ void Dispatcher::init() {
updateClock();

threadPool.addLoad([this] {
std::unique_lock asyncLock(mutex);
std::unique_lock asyncLock(dummyMutex);

while (!threadPool.getIoContext().stopped()) {
updateClock();

// Execute all asynchronous events separately by context
for (uint_fast8_t i = 0; i < static_cast<uint8_t>(AsyncEventContext::Last); ++i) {
executeAsyncEvents(i, asyncLock);
for (uint_fast8_t i = 0; i < static_cast<uint8_t>(TaskGroup::Last); ++i) {
executeEvents(i, asyncLock);
}

// Merge all events that were created by async events
mergeEvents();

executeEvents();
executeScheduledEvents();

// Merge all events that were created by events and scheduled events
mergeEvents();

auto waitDuration = timeUntilNextScheduledTask();
cv.wait_for(asyncLock, waitDuration);
if (!hasPendingTasks) {
auto waitDuration = timeUntilNextScheduledTask();
signalSchedule.wait_for(asyncLock, waitDuration);
}
}
});
}

void Dispatcher::addEvent(std::function<void(void)> &&f, std::string_view context, uint32_t expiresAfterMs) {
auto &thread = threads[getThreadId()];
std::scoped_lock lock(thread->mutex);
thread->tasks.emplace_back(expiresAfterMs, std::move(f), context);
cv.notify_one();
bool notify = !hasPendingTasks;
thread->tasks[static_cast<uint8_t>(TaskGroup::Serial)].emplace_back(expiresAfterMs, std::move(f), context);
if (notify && !hasPendingTasks) {
hasPendingTasks = true;
signalSchedule.notify_one();
}
}

void Dispatcher::addEvent_async(std::function<void(void)> &&f, AsyncEventContext context) {
void Dispatcher::addEvent_async(std::function<void(void)> &&f, TaskGroup group) {
auto &thread = threads[getThreadId()];
std::scoped_lock lock(thread->mutex);
thread->asyncTasks[static_cast<uint8_t>(context)].emplace_back(0, std::move(f), "Dispatcher::addEvent_async");
cv.notify_one();
bool notify = !hasPendingTasks;
thread->tasks[static_cast<uint8_t>(group)].emplace_back(0, std::move(f), "Dispatcher::addEvent_async");
if (notify && !hasPendingTasks) {
hasPendingTasks = true;
signalSchedule.notify_one();
}
}

uint64_t Dispatcher::scheduleEvent(const std::shared_ptr<Task> &task) {
auto &thread = threads[getThreadId()];
std::scoped_lock lock(thread->mutex);
thread->scheduledTasks.emplace_back(task);
cv.notify_one();
return scheduledTasksRef.emplace(task->generateId(), task).first->first;
bool notify = !hasPendingTasks;
signalSchedule.notify_one();
auto eventId = scheduledTasksRef.emplace(task->generateId(), task).first->first;
if (notify && !hasPendingTasks) {
hasPendingTasks = true;
signalSchedule.notify_one();
}
return eventId;
}

uint64_t Dispatcher::scheduleEvent(uint32_t delay, std::function<void(void)> &&f, std::string_view context, bool cycle) {
Expand All @@ -87,41 +96,45 @@ void Dispatcher::stopEvent(uint64_t eventId) {
scheduledTasksRef.erase(it);
}

void Dispatcher::executeEvents() {
for (const auto &task : eventTasks) {
void Dispatcher::executeSerialEvents(std::vector<Task> &tasks) {
for (const auto &task : tasks) {
if (task.execute()) {
++dispatcherCycle;
}
}
eventTasks.clear();
tasks.clear();
}

void Dispatcher::executeAsyncEvents(const uint8_t contextId, std::unique_lock<std::mutex> &asyncLock) {
auto &asyncTasks = asyncEventTasks[contextId];
if (asyncTasks.empty()) {
return;
}

void Dispatcher::executeParallelEvents(std::vector<Task> &tasks, const uint8_t groupId, std::unique_lock<std::mutex> &asyncLock) {
std::atomic_uint_fast64_t executedTasks = 0;

// Execute Async Task
for (const auto &task : asyncTasks) {
threadPool.addLoad([this, &task, &executedTasks, totalTaskSize = asyncTasks.size()] {
for (const auto &task : tasks) {
threadPool.addLoad([this, &task, &executedTasks, totalTaskSize = tasks.size()] {
task.execute();

if (executedTasks.fetch_add(1) == totalTaskSize) {
asyncTasks_cv.notify_one();
signalAsync.notify_one();
}
});
}

// Wait for all the tasks in the current context to be executed.
if (asyncTasks_cv.wait_for(asyncLock, ASYNC_TIME_OUT) == std::cv_status::timeout) {
g_logger().warn("A timeout occurred when executing the async dispatch in the context({}).", contextId);
if (signalAsync.wait_for(asyncLock, ASYNC_TIME_OUT) == std::cv_status::timeout) {
g_logger().warn("A timeout occurred when executing the async dispatch in the context({}).", groupId);
}
tasks.clear();
}

// Clear all async tasks
asyncTasks.clear();
void Dispatcher::executeEvents(const uint8_t groupId, std::unique_lock<std::mutex> &asyncLock) {
auto &tasks = m_tasks[groupId];
if (tasks.empty()) {
return;
}

if (groupId == static_cast<uint8_t>(TaskGroup::Serial)) {
executeSerialEvents(tasks);
} else {
executeParallelEvents(tasks, groupId, asyncLock);
}
}

void Dispatcher::executeScheduledEvents() {
Expand All @@ -147,15 +160,9 @@ void Dispatcher::mergeEvents() {
for (auto &thread : threads) {
std::scoped_lock lock(thread->mutex);
if (!thread->tasks.empty()) {
eventTasks.insert(eventTasks.end(), make_move_iterator(thread->tasks.begin()), make_move_iterator(thread->tasks.end()));
thread->tasks.clear();
}

for (uint_fast8_t i = 0; i < static_cast<uint8_t>(AsyncEventContext::Last); ++i) {
auto &context = thread->asyncTasks[i];
if (!context.empty()) {
asyncEventTasks[i].insert(asyncEventTasks[i].end(), make_move_iterator(context.begin()), make_move_iterator(context.end()));
context.clear();
for (uint_fast8_t i = 0; i < static_cast<uint8_t>(TaskGroup::Last); ++i) {
m_tasks[i].insert(m_tasks[i].end(), make_move_iterator(thread->tasks[i].begin()), make_move_iterator(thread->tasks[i].end()));
thread->tasks[i].clear();
}
}

Expand All @@ -167,6 +174,13 @@ void Dispatcher::mergeEvents() {
thread->scheduledTasks.clear();
}
}
hasPendingTasks = false;
for (uint_fast8_t i = 0; i < static_cast<uint8_t>(TaskGroup::Last); ++i) {
if (!m_tasks[i].empty()) {
hasPendingTasks = true;
break;
}
}
}

std::chrono::nanoseconds Dispatcher::timeUntilNextScheduledTask() {
Expand Down
32 changes: 17 additions & 15 deletions src/game/scheduling/dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
static constexpr uint16_t DISPATCHER_TASK_EXPIRATION = 2000;
static constexpr uint16_t SCHEDULER_MINTICKS = 50;

enum class AsyncEventContext : uint8_t {
First,
// Execution groups for dispatcher tasks. Also used as array indices:
// containers are sized with static_cast<uint8_t>(TaskGroup::Last) and
// iterated from 0 to Last, so enumerator order matters.
enum class TaskGroup : uint8_t {
	// Tasks executed one-by-one on the dispatcher thread (see executeSerialEvents).
	Serial,
	// Tasks fanned out to the thread pool for parallel execution.
	GenericParallel,
	// Sentinel — count of groups; not a real group.
	Last
};

Expand All @@ -43,11 +44,11 @@ class Dispatcher {

void init();
void shutdown() {
asyncTasks_cv.notify_all();
signalAsync.notify_all();
}

void addEvent(std::function<void(void)> &&f, std::string_view context, uint32_t expiresAfterMs = 0);
void addEvent_async(std::function<void(void)> &&f, AsyncEventContext context = AsyncEventContext::First);
void addEvent_async(std::function<void(void)> &&f, TaskGroup group = TaskGroup::Serial);

uint64_t scheduleEvent(const std::shared_ptr<Task> &task);
uint64_t scheduleEvent(uint32_t delay, std::function<void(void)> &&f, std::string_view context) {
Expand Down Expand Up @@ -84,36 +85,37 @@ class Dispatcher {
uint64_t scheduleEvent(uint32_t delay, std::function<void(void)> &&f, std::string_view context, bool cycle);

inline void mergeEvents();
inline void executeEvents();
inline void executeAsyncEvents(const uint8_t contextId, std::unique_lock<std::mutex> &asyncLock);
inline void executeEvents(const uint8_t groupId, std::unique_lock<std::mutex> &asyncLock);
inline void executeScheduledEvents();

inline void executeSerialEvents(std::vector<Task> &tasks);
inline void executeParallelEvents(std::vector<Task> &tasks, const uint8_t groupId, std::unique_lock<std::mutex> &asyncLock);
inline std::chrono::nanoseconds timeUntilNextScheduledTask();

uint_fast64_t dispatcherCycle = 0;

ThreadPool &threadPool;
std::mutex mutex;
std::condition_variable asyncTasks_cv;
std::condition_variable cv;
bool hasPendingTasks = false;
std::condition_variable signalAsync;
std::condition_variable signalSchedule;
std::atomic_bool hasPendingTasks = false;

// Thread Events
struct ThreadTask {
ThreadTask() {
tasks.reserve(2000);
for (auto &task : tasks) {
task.reserve(2000);
}
scheduledTasks.reserve(2000);
}

std::vector<Task> tasks;
std::array<std::vector<Task>, static_cast<uint8_t>(AsyncEventContext::Last)> asyncTasks;
std::array<std::vector<Task>, static_cast<uint8_t>(TaskGroup::Last)> tasks;
std::vector<std::shared_ptr<Task>> scheduledTasks;
std::mutex mutex;
};
std::vector<std::unique_ptr<ThreadTask>> threads;

// Main Events
std::vector<Task> eventTasks;
std::array<std::vector<Task>, static_cast<uint8_t>(AsyncEventContext::Last)> asyncEventTasks;
std::array<std::vector<Task>, static_cast<uint8_t>(TaskGroup::Last)> m_tasks;
std::priority_queue<std::shared_ptr<Task>, std::deque<std::shared_ptr<Task>>, Task::Compare> scheduledTasks;
phmap::parallel_flat_hash_map_m<uint64_t, std::shared_ptr<Task>> scheduledTasksRef;
};
Expand Down
48 changes: 25 additions & 23 deletions src/game/scheduling/task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,29 +94,31 @@ class Task {
static std::atomic_uint_fast64_t LAST_EVENT_ID;

bool hasTraceableContext() const {
const static auto tasksContext = phmap::flat_hash_set<std::string>({ "Creature::checkCreatureWalk",
"Decay::checkDecay",
"Dispatcher::addEvent_async",
"Game::checkCreatureAttack",
"Game::checkCreatures",
"Game::checkImbuements",
"Game::checkLight",
"Game::createFiendishMonsters",
"Game::createInfluencedMonsters",
"Game::updateCreatureWalk",
"Game::updateForgeableMonsters",
"GlobalEvents::think",
"LuaEnvironment::executeTimerEvent",
"Modules::executeOnRecvbyte",
"OutputMessagePool::sendAll",
"ProtocolGame::addGameTask",
"ProtocolGame::parsePacketFromDispatcher",
"Raids::checkRaids",
"SpawnMonster::checkSpawnMonster",
"SpawnMonster::scheduleSpawn",
"SpawnNpc::checkSpawnNpc",
"Webhook::run",
"sendRecvMessageCallback" });
const static auto tasksContext = phmap::flat_hash_set<std::string>({
"Creature::checkCreatureWalk",
"Decay::checkDecay",
"Dispatcher::addEvent_async",
"Game::checkCreatureAttack",
"Game::checkCreatures",
"Game::checkImbuements",
"Game::checkLight",
"Game::createFiendishMonsters",
"Game::createInfluencedMonsters",
"Game::updateCreatureWalk",
"Game::updateForgeableMonsters",
"GlobalEvents::think",
"LuaEnvironment::executeTimerEvent",
"Modules::executeOnRecvbyte",
"OutputMessagePool::sendAll",
"ProtocolGame::addGameTask",
"ProtocolGame::parsePacketFromDispatcher",
"Raids::checkRaids",
"SpawnMonster::checkSpawnMonster",
"SpawnMonster::scheduleSpawn",
"SpawnNpc::checkSpawnNpc",
"Webhook::run",
"sendRecvMessageCallback",
});

return tasksContext.contains(context);
}
Expand Down

0 comments on commit d17bf26

Please sign in to comment.