Skip to content

Commit

Permalink
Fixed lua reference leak involving db:ping()
Browse files Browse the repository at this point in the history
Added integration test for waiting on a failed database
Added reference created and freed debug statistics
  • Loading branch information
FredyH committed Jan 24, 2022
1 parent 9f10833 commit 18046a0
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 51 deletions.
11 changes: 11 additions & 0 deletions IntegrationTest/lua/mysqloo/tests/query_tests.lua
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,15 @@ TestFramework:RegisterTest("[Query] abort query correctly", function(test)
end
test:shouldBeEqual(qu2:abort(), true)
test:shouldBeEqual(qu:abort(), false)
end)

TestFramework:RegisterTest("[Query] not crash if waiting on query of a failed database", function(test)
local db = mysqloo.connect("127.0.0.1", "root", "test", "test", 33406)
db:connect()
local qu = db:query("SELECT 1")
qu:start()
function qu:onError()
test:Complete()
end
qu:wait()
end)
8 changes: 5 additions & 3 deletions src/BlockingQueue.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ template<typename T>
class BlockingQueue {
public:
void put(T elem) {
std::lock_guard<std::recursive_mutex> lock(mutex);
backingQueue.push_back(elem);
{
std::lock_guard<std::recursive_mutex> lock(mutex);
backingQueue.push_back(elem);
}
waitObj.notify_all();
}

Expand Down Expand Up @@ -49,7 +51,7 @@ class BlockingQueue {

T take() {
std::unique_lock<std::recursive_mutex> lock(mutex);
while (size() == 0) waitObj.wait(lock);
waitObj.wait(lock, [this] { return this->size() > 0; });
auto front = backingQueue.front();
backingQueue.pop_front();
return front;
Expand Down
23 changes: 19 additions & 4 deletions src/lua/GMModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ static int versionCheckConVar = 0;
GMOD_MODULE_CLOSE() {
// Free the version check ConVar object reference
if (versionCheckConVar != 0) {
LUA->ReferenceFree(versionCheckConVar);
LuaReferenceFree(LUA, versionCheckConVar);
versionCheckConVar = 0;
}
mysql_thread_end();
Expand Down Expand Up @@ -55,13 +55,13 @@ static void printMessage(GarrysMod::Lua::ILuaBase *LUA, const char *str, int r,
LUA->PushNumber(g);
LUA->PushNumber(b);
LUA->Call(3, 1);
int ref = LUA->ReferenceCreate();
int ref = LuaReferenceCreate(LUA);
LUA->GetField(-1, "MsgC");
LUA->ReferencePush(ref);
LUA->PushString(str);
LUA->Call(2, 0);
LUA->Pop();
LUA->ReferenceFree(ref);
LuaReferenceFree(LUA, ref);
}

static int printOutdatedVersion(lua_State *state) {
Expand Down Expand Up @@ -134,6 +134,16 @@ LUA_FUNCTION(deallocationCount) {
return 1;
}

LUA_FUNCTION(referenceCreatedCount) {
LUA->PushNumber((double) LuaObject::referenceCreatedCount);
return 1;
}

LUA_FUNCTION(referenceFreedCount) {
LUA->PushNumber((double) LuaObject::referenceFreedCount);
return 1;
}

LUA_FUNCTION(mysqlooThink) {
LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB);
LUA->GetField(-1, "mysqloo");
Expand Down Expand Up @@ -214,6 +224,11 @@ GMOD_MODULE_OPEN() {
LUA->SetField(-2, "allocationCount");
LUA->PushCFunction(deallocationCount);
LUA->SetField(-2, "deallocationCount");
LUA->PushCFunction(referenceFreedCount);
LUA->SetField(-2, "referenceFreedCount");
LUA->PushCFunction(referenceCreatedCount);
LUA->SetField(-2, "referenceCreatedCount");


LuaDatabase::createWeakTable(LUA);

Expand All @@ -230,7 +245,7 @@ GMOD_MODULE_OPEN() {
LUA->PushNumber(0); // Min value
LUA->PushNumber(1); // Max value
LUA->Call(6, 1); // Call with 6 arguments and 1 result
versionCheckConVar = LUA->ReferenceCreate(); // Store the created ConVar object as a global variable
versionCheckConVar = LuaReferenceCreate(LUA); // Store the created ConVar object as a global variable
LUA->Pop(); // Pop the global table

runInTimer(LUA, 5, doVersionCheck);
Expand Down
14 changes: 7 additions & 7 deletions src/lua/LuaDatabase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ MYSQLOO_LUA_FUNCTION(query) {
auto query = Query::create(database->m_database, std::string(queryStr, outLen));

LUA->Push(1);
int databaseRef = LUA->ReferenceCreate();
int databaseRef = LuaReferenceCreate(LUA);

auto luaQuery = new LuaQuery(query, databaseRef);

Expand All @@ -76,7 +76,7 @@ MYSQLOO_LUA_FUNCTION(prepare) {
auto query = PreparedQuery::create(database->m_database, std::string(queryStr, outLen));

LUA->Push(1);
int databaseRef = LUA->ReferenceCreate();
int databaseRef = LuaReferenceCreate(LUA);

auto luaQuery = new LuaPreparedQuery(query, databaseRef);

Expand All @@ -89,7 +89,7 @@ MYSQLOO_LUA_FUNCTION(createTransaction) {
auto transaction = Transaction::create(database->m_database);

LUA->Push(1);
int databaseRef = LUA->ReferenceCreate();
int databaseRef = LuaReferenceCreate(LUA);

auto luaTransaction = new LuaTransaction(transaction, databaseRef);

Expand All @@ -101,7 +101,7 @@ MYSQLOO_LUA_FUNCTION(connect) {
auto database = LuaObject::getLuaObject<LuaDatabase>(LUA);
if (database->m_tableReference == 0) {
LUA->Push(1);
database->m_tableReference = LUA->ReferenceCreate();
database->m_tableReference = LuaReferenceCreate(LUA);
}
database->m_database->connect();
return 0;
Expand Down Expand Up @@ -324,7 +324,7 @@ void LuaDatabase::think(ILuaBase *LUA) {
LUA->Pop(); //Callback function
}

LUA->ReferenceFree(this->m_tableReference);
LuaReferenceFree(LUA, this->m_tableReference);
this->m_tableReference = 0;
}

Expand Down Expand Up @@ -390,7 +390,7 @@ void LuaDatabase::runAllThinkHooks(ILuaBase *LUA) {
while (LUA->Next(-2) != 0) {
//The key is the table of the database
LUA->Push(-2); //The key, i.e. the database table
databaseReferences.push_back(LUA->ReferenceCreate());
databaseReferences.push_back(LuaReferenceCreate(LUA));

LUA->Pop(); //The value, keep key on stack for next()
}
Expand All @@ -399,7 +399,7 @@ void LuaDatabase::runAllThinkHooks(ILuaBase *LUA) {
//Call think function of each alive database instance
for (auto &ref: databaseReferences) {
LUA->ReferencePush(ref);
LUA->ReferenceFree(ref); //We can immediately free this, the variable on the stack keeps it alive.
LuaReferenceFree(LUA, ref); //We can immediately free this, the variable on the stack keeps it alive.
auto database = LuaObject::getLuaObject<LuaDatabase>(LUA, -1);
database->think(LUA);
LUA->Pop(); //database
Expand Down
14 changes: 7 additions & 7 deletions src/lua/LuaIQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ void LuaIQuery::addMetaTableFunctions(ILuaBase *LUA) {

void LuaIQuery::referenceCallbacks(ILuaBase *LUA, int stackPosition, IQueryData &data) {
LUA->Push(stackPosition);
data.m_tableReference = LUA->ReferenceCreate();
data.m_tableReference = LuaReferenceCreate(LUA);

if (data.m_successReference == 0) {
data.m_successReference = getFunctionReference(LUA, stackPosition, "onSuccess");
Expand All @@ -131,19 +131,19 @@ void LuaIQuery::referenceCallbacks(ILuaBase *LUA, int stackPosition, IQueryData
void LuaIQuery::finishQueryData(GarrysMod::Lua::ILuaBase *LUA, const std::shared_ptr<IQuery> &query, const std::shared_ptr<IQueryData> &data) {
query->finishQueryData(data);
if (data->m_tableReference) {
LUA->ReferenceFree(data->m_tableReference);
LuaReferenceFree(LUA, data->m_tableReference);
}
if (data->m_onDataReference) {
LUA->ReferenceFree(data->m_onDataReference);
LuaReferenceFree(LUA, data->m_onDataReference);
}
if (data->m_errorReference) {
LUA->ReferenceFree(data->m_errorReference);
LuaReferenceFree(LUA, data->m_errorReference);
}
if (data->m_abortReference) {
LUA->ReferenceFree(data->m_abortReference);
LuaReferenceFree(LUA, data->m_abortReference);
}
if (data->m_successReference) {
LUA->ReferenceFree(data->m_successReference);
LuaReferenceFree(LUA, data->m_successReference);
}
data->m_onDataReference = 0;
data->m_errorReference = 0;
Expand Down Expand Up @@ -176,7 +176,7 @@ void LuaIQuery::runCallback(ILuaBase *LUA, const std::shared_ptr<IQuery> &iQuery

void LuaIQuery::onDestroyedByLua(ILuaBase *LUA) {
if (m_databaseReference != 0) {
LUA->ReferenceFree(m_databaseReference);
LuaReferenceFree(LUA, m_databaseReference);
m_databaseReference = 0;
}
}
15 changes: 14 additions & 1 deletion src/lua/LuaObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,22 @@ int LuaObject::getFunctionReference(ILuaBase *LUA, int stackPosition, const char
LUA->GetField(stackPosition, fieldName);
int reference = 0;
if (LUA->IsType(-1, GarrysMod::Lua::Type::Function)) {
reference = LUA->ReferenceCreate();
reference = LuaReferenceCreate(LUA);
} else {
LUA->Pop();
}
return reference;
}

uint64_t LuaObject::referenceCreatedCount = 0;
uint64_t LuaObject::referenceFreedCount = 0;

int LuaReferenceCreate(GarrysMod::Lua::ILuaBase *LUA) {
LuaObject::referenceCreatedCount++;
return LUA->ReferenceCreate();
}

void LuaReferenceFree(GarrysMod::Lua::ILuaBase *LUA, int ref) {
LuaObject::referenceFreedCount++;
LUA->ReferenceFree(ref);
}
8 changes: 8 additions & 0 deletions src/lua/LuaObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <atomic>
#include "GarrysMod/Lua/Interface.h"
#include "../mysql/MySQLOOException.h"
#include "GarrysMod/Lua/LuaBase.h"


#include <iostream>

Expand Down Expand Up @@ -68,10 +70,16 @@ class LuaObject {
static int getFunctionReference(ILuaBase *LUA, int stackPosition, const char *fieldName);
static std::atomic_long allocationCount;
static std::atomic_long deallocationCount;
static uint64_t referenceCreatedCount;
static uint64_t referenceFreedCount;
protected:
std::string m_className;
};

int LuaReferenceCreate(GarrysMod::Lua::ILuaBase *LUA);

void LuaReferenceFree(GarrysMod::Lua::ILuaBase *LUA, int ref);


#define MYSQLOO_LUA_FUNCTION(FUNC) \
static int FUNC##__Imp( GarrysMod::Lua::ILuaBase* LUA ); \
Expand Down
4 changes: 2 additions & 2 deletions src/lua/LuaQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ int LuaQuery::createDataReference(GarrysMod::Lua::ILuaBase *LUA, Query &query, Q
LUA->Pop(2); //data + row
}
}
query.m_dataReference = LUA->ReferenceCreate();
query.m_dataReference = LuaReferenceCreate(LUA);
return query.m_dataReference;
}

Expand Down Expand Up @@ -197,7 +197,7 @@ void LuaQuery::onDestroyedByLua(ILuaBase *LUA) {

void LuaQuery::freeDataReference(ILuaBase *LUA, Query &query) {
if (query.m_dataReference != 0) {
LUA->ReferenceFree(query.m_dataReference);
LuaReferenceFree(LUA, query.m_dataReference);
query.m_dataReference = 0;
}
}
16 changes: 12 additions & 4 deletions src/mysql/Database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,13 @@ bool Database::ping() {
auto queryData = query->buildQueryData();
query->start(queryData);
query->wait(true);
//Ping queries do not have a lua correspondence, so they need to be removed from finished queries
//(they are essentially just a hack)
this->finishedQueries.removeIf(
[queryData](std::pair<std::shared_ptr<IQuery>, std::shared_ptr<IQueryData>> const &p) {
return p.second == queryData;
});
query->finishQueryData(queryData);
return query->pingSuccess;
}

Expand Down Expand Up @@ -312,6 +319,7 @@ void Database::failWaitingQuery(const std::shared_ptr<IQuery> &query, const std:
data->setError(std::move(reason));
data->setResultStatus(QUERY_ERROR);
data->setStatus(QUERY_COMPLETE);
data->setFinished(true);
finishedQueries.put(std::make_pair(query, data));
}

Expand All @@ -331,12 +339,12 @@ void Database::abortWaitingQuery() {
}
failWaitingQuery(query, data, "The database of the query you were waiting on was disconnected.");
this->m_waitingQuery = std::make_pair(nullptr, nullptr);
query->notify();
this->m_queryWaitWakeupVariable.notify_all();
}

//Called from the main thread when calling query:wait()
//There can always only be at most one waiting query per database (since waiting blocks the main thread here!)
void Database::waitForQuery(const std::shared_ptr<IQuery>& query, const std::shared_ptr<IQueryData>& data) {
void Database::waitForQuery(const std::shared_ptr<IQuery> &query, const std::shared_ptr<IQueryData> &data) {
{
std::unique_lock<std::mutex> lock(this->m_queryWaitMutex);
if (!this->m_canWait) {
Expand All @@ -347,8 +355,8 @@ void Database::waitForQuery(const std::shared_ptr<IQuery>& query, const std::sha
return; //No need to wait
}
this->m_waitingQuery = std::make_pair(query, data);
this->m_queryWaitWakeupVariable.wait(lock, [data] { return data->isFinished(); });
}
query->waitForNotify(data);
}

/* Thread that connects to the database, on success it continues to handle queries in the run method.
Expand Down Expand Up @@ -437,7 +445,7 @@ void Database::run() {
this->m_waitingQuery = std::make_pair(nullptr, nullptr);
}
}
curQuery->notify();
this->m_queryWaitWakeupVariable.notify_all();
//So that statements get eventually freed even if the queue is constantly full
freeUnusedStatements();
}
Expand Down
1 change: 1 addition & 0 deletions src/mysql/Database.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class Database : public std::enable_shared_from_this<Database> {
std::atomic<bool> m_connectionDone{false};
std::atomic<bool> cachePreparedStatements{true};
std::condition_variable m_queryWakeupVariable{};
std::condition_variable m_queryWaitWakeupVariable{};
std::string database;
std::string host;
std::string username;
Expand Down
17 changes: 0 additions & 17 deletions src/mysql/IQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,21 +173,4 @@ void IQuery::finishQueryData(const std::shared_ptr<IQueryData> &data) {
runningQueryData.erase(std::remove(runningQueryData.begin(), runningQueryData.end(), data),
runningQueryData.end());
}
}

/*
* Waits for the query to be notified of the completion of the query data.
* This should not be called directly, but only from the database.
*/
void IQuery::waitForNotify(const std::shared_ptr<IQueryData> &data) {
std::unique_lock<std::mutex> lck(m_waitMutex);
while (!data->isFinished()) m_waitWakeupVariable.wait(lck);
}

/*
* Notifies a waiting query and wakes it up.
*/
void IQuery::notify() {
std::unique_lock<std::mutex> queryMutex(m_waitMutex);
m_waitWakeupVariable.notify_all();
}
6 changes: 0 additions & 6 deletions src/mysql/IQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,9 @@ class IQuery : public std::enable_shared_from_this<IQuery> {

//fields
std::shared_ptr<Database> m_database{};
std::condition_variable m_waitWakeupVariable;
std::mutex m_waitMutex;
int m_options = 0;
std::deque<std::shared_ptr<IQueryData>> runningQueryData;
bool hasBeenStarted = false;
private:
//Wakes up any waiting thread
void notify();
void waitForNotify(const std::shared_ptr<IQueryData> &data);
};

class IQueryData {
Expand Down

0 comments on commit 18046a0

Please sign in to comment.