From a3e448d21bf06f7627dbf6c893e2b15dc8720474 Mon Sep 17 00:00:00 2001 From: Oscar Franco Date: Mon, 16 Dec 2024 17:04:27 +0100 Subject: [PATCH] Make hooks work and also react-queries --- cpp/DBHostObject.cpp | 197 ++++++++++++++++++++++--------------------- cpp/DBHostObject.h | 3 + cpp/bindings.cpp | 11 +-- cpp/bridge.cpp | 118 +++++--------------------- cpp/bridge.h | 15 ++-- example/src/App.tsx | 4 +- 6 files changed, 136 insertions(+), 212 deletions(-) diff --git a/cpp/DBHostObject.cpp b/cpp/DBHostObject.cpp index 991ed89..9837374 100644 --- a/cpp/DBHostObject.cpp +++ b/cpp/DBHostObject.cpp @@ -49,70 +49,81 @@ void DBHostObject::flush_pending_reactive_queries( [this, resolve]() { resolve->asObject(rt).asFunction(rt).call(rt, {}); }); } -void DBHostObject::auto_register_update_hook() { - if (update_hook_callback == nullptr && reactive_queries.empty() && - is_update_hook_registered) { - opsqlite_deregister_update_hook(db_name); - is_update_hook_registered = false; - return; +void DBHostObject::on_commit() { + invoker->invokeAsync([this] { + commit_hook_callback->asObject(rt).asFunction(rt).call(rt); + }); +} + + void DBHostObject::on_rollback() { + invoker->invokeAsync([this] { + rollback_hook_callback->asObject(rt).asFunction(rt).call(rt); + }); } - if (is_update_hook_registered) { - return; +void DBHostObject::on_update(std::string table, std::string operation, + int rowid) { + if (update_hook_callback != nullptr) { + invoker->invokeAsync([this, callback = update_hook_callback, table, + operation = std::move(operation), rowid] { + auto res = jsi::Object(rt); + res.setProperty(rt, "table", jsi::String::createFromUtf8(rt, table)); + res.setProperty(rt, "operation", + jsi::String::createFromUtf8(rt, operation)); + res.setProperty(rt, "rowId", jsi::Value(rowid)); + + callback->asObject(rt).asFunction(rt).call(rt, res); + }); } - auto hook = [this](std::string name, std::string table_name, - std::string operation, int rowid) { - if (update_hook_callback != nullptr) { - invoker->invokeAsync([this, callback = update_hook_callback, table_name, - operation = std::move(operation), rowid] { - auto res = jsi::Object(rt); - res.setProperty(rt, "table", - jsi::String::createFromUtf8(rt, table_name)); - res.setProperty(rt, "operation", - jsi::String::createFromUtf8(rt, operation)); - res.setProperty(rt, "rowId", jsi::Value(rowid)); - - callback->asObject(rt).asFunction(rt).call(rt, res); - }); + for (const auto &query_ptr : reactive_queries) { + auto query = query_ptr.get(); + if (query->discriminators.empty()) { + continue; } - for (const auto &query_ptr : reactive_queries) { - auto query = query_ptr.get(); - if (query->discriminators.empty()) { + bool shouldFire = false; + + for (const auto &discriminator : query->discriminators) { + // Tables don't match then skip + if (discriminator.table != table) { continue; } - bool shouldFire = false; - - for (const auto &discriminator : query->discriminators) { - // Tables don't match then skip - if (discriminator.table != table_name) { - continue; - } + // If no ids are specified, then we should fire + if (discriminator.ids.size() == 0) { + shouldFire = true; + break; + } - // If no ids are specified, then we should fire - if (discriminator.ids.size() == 0) { + // If ids are specified, then we should check if the rowId matches + for (const auto &discrimator_id : discriminator.ids) { + if (rowid == discrimator_id) { shouldFire = true; break; } - - // If ids are specified, then we should check if the rowId matches - for (const auto &discrimator_id : discriminator.ids) { - if (rowid == discrimator_id) { - shouldFire = true; - break; - } - } } + } - if (shouldFire) { - pending_reactive_queries.insert(query_ptr); - } + if (shouldFire) { + pending_reactive_queries.insert(query_ptr); } - }; + } +} - opsqlite_register_update_hook(db_name, std::move(hook)); +void DBHostObject::auto_register_update_hook() { + if (update_hook_callback == nullptr && reactive_queries.empty() && + is_update_hook_registered) { + opsqlite_deregister_update_hook(db); + is_update_hook_registered = false; + return; + } + + if (is_update_hook_registered) { + return; + } + + opsqlite_register_update_hook(db, this); is_update_hook_registered = true; } #endif @@ -165,8 +176,9 @@ DBHostObject::DBHostObject(jsi::Runtime &rt, std::string &base_path, std::string &crsqlite_path, std::string &sqlite_vec_path, std::string &encryption_key) - : base_path(base_path), invoker(std::move(invoker)), db_name(db_name), rt(rt) { - _thread_pool = std::make_shared(); + : base_path(base_path), invoker(std::move(invoker)), db_name(db_name), + rt(rt) { + _thread_pool = std::make_shared(); #ifdef OP_SQLITE_USE_SQLCIPHER BridgeResult result = opsqlite_open(db_name, path, crsqlite_path, @@ -290,8 +302,7 @@ void DBHostObject::create_jsi_functions() { auto resolve = std::make_shared(rt, args[0]); auto reject = std::make_shared(rt, args[1]); - auto task = [this, &rt, query, params, resolve, - reject]() { + auto task = [this, &rt, query, params, resolve, reject]() { try { std::vector> results; @@ -312,13 +323,13 @@ void DBHostObject::create_jsi_functions() { resolve->asObject(rt).asFunction(rt).call(rt, std::move(jsiResult)); }); } catch (std::runtime_error &e) { - auto what = e.what(); - invoker->invokeAsync([&rt, what = std::string(what), reject] { - auto errorCtr = rt.global().getPropertyAsFunction(rt, "Error"); - auto error = errorCtr.callAsConstructor( - rt, jsi::String::createFromAscii(rt, what)); - reject->asObject(rt).asFunction(rt).call(rt, error); - }); + auto what = e.what(); + invoker->invokeAsync([&rt, what = std::string(what), reject] { + auto errorCtr = rt.global().getPropertyAsFunction(rt, "Error"); + auto error = errorCtr.callAsConstructor( + rt, jsi::String::createFromAscii(rt, what)); + reject->asObject(rt).asFunction(rt).call(rt, error); + }); } catch (std::exception &exc) { auto what = exc.what(); invoker->invokeAsync([&rt, what = std::move(what), reject] { @@ -459,13 +470,13 @@ void DBHostObject::create_jsi_functions() { std::move(jsiResult)); }); } catch (std::runtime_error &e) { - auto what = e.what(); - invoker->invokeAsync([&rt, what = std::string(what), reject] { - auto errorCtr = rt.global().getPropertyAsFunction(rt, "Error"); - auto error = errorCtr.callAsConstructor( - rt, jsi::String::createFromAscii(rt, what)); - reject->asObject(rt).asFunction(rt).call(rt, error); - }); + auto what = e.what(); + invoker->invokeAsync([&rt, what = std::string(what), reject] { + auto errorCtr = rt.global().getPropertyAsFunction(rt, "Error"); + auto error = errorCtr.callAsConstructor( + rt, jsi::String::createFromAscii(rt, what)); + reject->asObject(rt).asFunction(rt).call(rt, error); + }); } catch (std::exception &exc) { auto what = exc.what(); invoker->invokeAsync([&rt, what = std::move(what), reject] { @@ -535,13 +546,13 @@ void DBHostObject::create_jsi_functions() { resolve->asObject(rt).asFunction(rt).call(rt, std::move(res)); }); } catch (std::runtime_error &e) { - auto what = e.what(); - invoker->invokeAsync([&rt, what = std::string(what), reject] { - auto errorCtr = rt.global().getPropertyAsFunction(rt, "Error"); - auto error = errorCtr.callAsConstructor( - rt, jsi::String::createFromAscii(rt, what)); - reject->asObject(rt).asFunction(rt).call(rt, error); - }); + auto what = e.what(); + invoker->invokeAsync([&rt, what = std::string(what), reject] { + auto errorCtr = rt.global().getPropertyAsFunction(rt, "Error"); + auto error = errorCtr.callAsConstructor( + rt, jsi::String::createFromAscii(rt, what)); + reject->asObject(rt).asFunction(rt).call(rt, error); + }); } catch (std::exception &exc) { auto what = exc.what(); invoker->invokeAsync([&rt, what = std::move(what), reject] { @@ -597,13 +608,13 @@ void DBHostObject::create_jsi_functions() { resolve->asObject(rt).asFunction(rt).call(rt, std::move(res)); }); } catch (std::runtime_error &e) { - auto what = e.what(); - invoker->invokeAsync([&rt, what = std::string(what), reject] { - auto errorCtr = rt.global().getPropertyAsFunction(rt, "Error"); - auto error = errorCtr.callAsConstructor( - rt, jsi::String::createFromAscii(rt, what)); - reject->asObject(rt).asFunction(rt).call(rt, error); - }); + auto what = e.what(); + invoker->invokeAsync([&rt, what = std::string(what), reject] { + auto errorCtr = rt.global().getPropertyAsFunction(rt, "Error"); + auto error = errorCtr.callAsConstructor( + rt, jsi::String::createFromAscii(rt, what)); + reject->asObject(rt).asFunction(rt).call(rt, error); + }); } catch (std::exception &exc) { auto what = exc.what(); invoker->invokeAsync([&rt, what = std::string(what), reject] { @@ -629,6 +640,7 @@ void DBHostObject::create_jsi_functions() { } else { update_hook_callback = callback; } + auto_register_update_hook(); return {}; }); @@ -641,17 +653,11 @@ void DBHostObject::create_jsi_functions() { auto callback = std::make_shared(rt, args[0]); if (callback->isUndefined() || callback->isNull()) { - opsqlite_deregister_commit_hook(db_name); + opsqlite_deregister_commit_hook(db); return {}; } commit_hook_callback = callback; - - auto hook = [&rt, this, callback](std::string dbName) { - invoker->invokeAsync( - [&rt, callback] { callback->asObject(rt).asFunction(rt).call(rt); }); - }; - - opsqlite_register_commit_hook(db_name, std::move(hook)); + opsqlite_register_commit_hook(db, this); return {}; }); @@ -659,23 +665,17 @@ void DBHostObject::create_jsi_functions() { auto rollback_hook = HOSTFN("rollbackHook") { if (sizeof(args) < 1) { throw std::runtime_error("[op-sqlite][rollbackHook] callback needed"); - return {}; } auto callback = std::make_shared(rt, args[0]); if (callback->isUndefined() || callback->isNull()) { - opsqlite_deregister_rollback_hook(db_name); + opsqlite_deregister_rollback_hook(db); return {}; } rollback_hook_callback = callback; - auto hook = [&rt, this, callback](std::string db_name) { - invoker->invokeAsync( - [&rt, callback] { callback->asObject(rt).asFunction(rt).call(rt); }); - }; - - opsqlite_register_rollback_hook(db_name, std::move(hook)); + opsqlite_register_rollback_hook(db, this); return {}; }); @@ -872,12 +872,15 @@ void DBHostObject::set(jsi::Runtime &rt, const jsi::PropNameID &name, void DBHostObject::invalidate() { invalidated = true; +// opsqlite_deregister_commit_hook(db); +// opsqlite_deregister_update_hook(db); +// opsqlite_deregister_rollback_hook(db); + _thread_pool->restartPool(); opsqlite_close(db); } DBHostObject::~DBHostObject() { - invalidated = true; - opsqlite_close(db); + invalidate(); } } // namespace opsqlite diff --git a/cpp/DBHostObject.h b/cpp/DBHostObject.h index 9940beb..8c6526d 100644 --- a/cpp/DBHostObject.h +++ b/cpp/DBHostObject.h @@ -56,6 +56,9 @@ class JSI_EXPORT DBHostObject : public jsi::HostObject { jsi::Value get(jsi::Runtime &rt, const jsi::PropNameID &propNameID); void set(jsi::Runtime &rt, const jsi::PropNameID &name, const jsi::Value &value); + void on_update(std::string table, std::string operation, int rowid); + void on_commit(); + void on_rollback(); void invalidate(); ~DBHostObject(); diff --git a/cpp/bindings.cpp b/cpp/bindings.cpp index 85696ea..b111c23 100644 --- a/cpp/bindings.cpp +++ b/cpp/bindings.cpp @@ -25,25 +25,18 @@ std::string _sqlite_vec_path; std::vector> dbs; // React native will try to clean the module on JS context invalidation -// (CodePush/Hot Reload) The clearState function is called and we use this flag -// to prevent any ongoing operations from continuing work and can return early -bool invalidated = false; - +// (CodePush/Hot Reload) The clearState function is called void clearState() { for (const auto &db : dbs) { db->invalidate(); } - invalidated = true; - - // We then join all the threads before the context gets invalidated -// thread_pool->restartPool(); + dbs.clear(); } void install(jsi::Runtime &rt, const std::shared_ptr &invoker, const char *base_path, const char *crsqlite_path, const char *sqlite_vec_path) { - invalidated = false; _base_path = std::string(base_path); _crsqlite_path = std::string(crsqlite_path); _sqlite_vec_path = std::string(sqlite_vec_path); diff --git a/cpp/bridge.cpp b/cpp/bridge.cpp index a11bfd4..c657d82 100644 --- a/cpp/bridge.cpp +++ b/cpp/bridge.cpp @@ -3,6 +3,7 @@ // so that threading operations are safe and contained within DBHostObject #include "bridge.h" +#include "DBHostObject.h" #include "DumbHostObject.h" #include "SmartHostObject.h" #include "logs.h" @@ -740,120 +741,47 @@ std::string operation_to_string(int operation_type) { } } -void update_callback(void *dbName, int operation_type, +void update_callback(void *db_host_object_ptr, int operation_type, [[maybe_unused]] char const *database, char const *table, sqlite3_int64 row_id) { - // std::string &strDbName = *(static_cast(dbName)); - // auto callback = updateCallbackMap[strDbName]; - // callback(strDbName, std::string(table), - // operation_to_string(operation_type), - // static_cast(row_id)); + auto db_host_object = reinterpret_cast(db_host_object_ptr); + db_host_object->on_update(std::string(table), + operation_to_string(operation_type), row_id); } -BridgeResult opsqlite_register_update_hook(std::string const &dbName, - UpdateCallback const &callback) { - // check_db_open(dbName); - // - // sqlite3 *db = dbMap[dbName]; - // updateCallbackMap[dbName] = callback; - // const std::string *key = nullptr; - // - // // TODO find a more elegant way to retrieve a reference to the key - // for (auto const &element : dbMap) { - // if (element.first == dbName) { - // key = &element.first; - // } - // } - // - // sqlite3_update_hook(db, &update_callback, (void *)key); - - return {}; +void opsqlite_register_update_hook(sqlite3 *db, void *db_host_object) { + sqlite3_update_hook(db, &update_callback, (void *)db_host_object); } -BridgeResult opsqlite_deregister_update_hook(std::string const &dbName) { - // check_db_open(dbName); - // - // sqlite3 *db = dbMap[dbName]; - // updateCallbackMap.erase(dbName); - // - // sqlite3_update_hook(db, nullptr, nullptr); - - return {}; +void opsqlite_deregister_update_hook(sqlite3 *db) { + sqlite3_update_hook(db, nullptr, nullptr); } -int commit_callback(void *dbName) { - // std::string &strDbName = *(static_cast(dbName)); - // auto callback = commitCallbackMap[strDbName]; - // callback(strDbName); - // You need to return 0 to allow commits to continue +int commit_callback(void *db_host_object_ptr) { + auto db_host_object = reinterpret_cast(db_host_object_ptr); + db_host_object->on_commit(); return 0; } -BridgeResult opsqlite_register_commit_hook(std::string const &dbName, - CommitCallback const &callback) { - // check_db_open(dbName); - // - // sqlite3 *db = dbMap[dbName]; - // commitCallbackMap[dbName] = callback; - // const std::string *key = nullptr; - // - // // TODO find a more elegant way to retrieve a reference to the key - // for (auto const &element : dbMap) { - // if (element.first == dbName) { - // key = &element.first; - // } - // } - // - // sqlite3_commit_hook(db, &commit_callback, (void *)key); - - return {}; +void opsqlite_register_commit_hook(sqlite3 *db, void *db_host_object_ptr) { + sqlite3_commit_hook(db, &commit_callback, db_host_object_ptr); } -BridgeResult opsqlite_deregister_commit_hook(std::string const &dbName) { - // check_db_open(dbName); - // - // sqlite3 *db = dbMap[dbName]; - // commitCallbackMap.erase(dbName); - // sqlite3_commit_hook(db, nullptr, nullptr); - - return {}; +void opsqlite_deregister_commit_hook(sqlite3 *db) { + sqlite3_commit_hook(db, nullptr, nullptr); } -void rollback_callback(void *dbName) { - // std::string &strDbName = *(static_cast(dbName)); - // auto callback = rollbackCallbackMap[strDbName]; - // callback(strDbName); +void rollback_callback(void *db_host_object_ptr) { + auto db_host_object = reinterpret_cast(db_host_object_ptr); + db_host_object->on_rollback(); } -BridgeResult opsqlite_register_rollback_hook(std::string const &dbName, - RollbackCallback const &callback) { - // check_db_open(dbName); - // - // sqlite3 *db = dbMap[dbName]; - // rollbackCallbackMap[dbName] = callback; - // const std::string *key = nullptr; - // - // // TODO find a more elegant way to retrieve a reference to the key - // for (auto const &element : dbMap) { - // if (element.first == dbName) { - // key = &element.first; - // } - // } - // - // sqlite3_rollback_hook(db, &rollback_callback, (void *)key); - - return {}; +void opsqlite_register_rollback_hook(sqlite3 *db, void *db_host_object_ptr) { + sqlite3_rollback_hook(db, &rollback_callback, db_host_object_ptr); } -BridgeResult opsqlite_deregister_rollback_hook(std::string const &dbName) { - // check_db_open(dbName); - // - // sqlite3 *db = dbMap[dbName]; - // rollbackCallbackMap.erase(dbName); - // - // sqlite3_rollback_hook(db, nullptr, nullptr); - - return {}; +void opsqlite_deregister_rollback_hook(sqlite3 *db) { + sqlite3_rollback_hook(db, nullptr, nullptr); } void opsqlite_load_extension(sqlite3 *db, std::string &path, diff --git a/cpp/bridge.h b/cpp/bridge.h index e073b64..9881b8e 100644 --- a/cpp/bridge.h +++ b/cpp/bridge.h @@ -61,15 +61,12 @@ BridgeResult opsqlite_execute_raw(sqlite3 *db, std::string const &query, const std::vector *params, std::vector> *results); -BridgeResult opsqlite_register_update_hook(std::string const &dbName, - const UpdateCallback &callback); -BridgeResult opsqlite_deregister_update_hook(std::string const &dbName); -BridgeResult opsqlite_register_commit_hook(std::string const &dbName, - const CommitCallback &callback); -BridgeResult opsqlite_deregister_commit_hook(std::string const &dbName); -BridgeResult opsqlite_register_rollback_hook(std::string const &dbName, - const RollbackCallback &callback); -BridgeResult opsqlite_deregister_rollback_hook(std::string const &dbName); +void opsqlite_register_update_hook(sqlite3 *db, void *db_host_object_ptr); +void opsqlite_deregister_update_hook(sqlite3 *db); +void opsqlite_register_commit_hook(sqlite3 *db, void *db_host_object_ptr); +void opsqlite_deregister_commit_hook(sqlite3 *db); +void opsqlite_register_rollback_hook(sqlite3 *db, void *db_host_object_ptr); +void opsqlite_deregister_rollback_hook(sqlite3 *db); sqlite3_stmt *opsqlite_prepare_statement(sqlite3 *db, std::string const &query); diff --git a/example/src/App.tsx b/example/src/App.tsx index c78c1b8..8f11507 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -41,10 +41,10 @@ export default function App() { queriesTests, dbSetupTests, blobTests, - // registerHooksTests, + registerHooksTests, preparedStatementsTests, constantsTests, - // reactiveTests, + reactiveTests, tokenizerTests, ) .then(results => {