diff --git a/src_cpp/include/node_connection.h b/src_cpp/include/node_connection.h index b5f4bfd..a47be64 100644 --- a/src_cpp/include/node_connection.h +++ b/src_cpp/include/node_connection.h @@ -70,10 +70,11 @@ namespace main { class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { public: ConnectionExecuteAsyncWorker(Napi::Function& callback, std::shared_ptr& connection, + std::shared_ptr& database, std::shared_ptr preparedStatement, NodeQueryResult* nodeQueryResult, std::unordered_map> params, Napi::Value progressCallback) - : Napi::AsyncWorker(callback), connection(connection), + : Napi::AsyncWorker(callback), connection(connection), database(database), preparedStatement(std::move(preparedStatement)), nodeQueryResult(nodeQueryResult), params(std::move(params)) { if (progressCallback.IsFunction()) { @@ -100,12 +101,11 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { auto result = connection ->executeWithParamsWithID(preparedStatement.get(), std::move(params), queryID); - auto* resultRaw = result.get(); - nodeQueryResult->AdoptQueryResult(std::move(result)); - if (!resultRaw->isSuccess()) { - SetError(resultRaw->getErrorMessage()); + if (!result->isSuccess()) { + SetError(result->getErrorMessage()); return; } + nodeQueryResult->AdoptQueryResult(std::move(result), database); } catch (const std::exception& exc) { SetError(std::string(exc.what())); } @@ -122,6 +122,7 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { private: std::shared_ptr connection; + std::shared_ptr database; std::shared_ptr preparedStatement; NodeQueryResult* nodeQueryResult; std::unordered_map> params; @@ -131,9 +132,10 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { class ConnectionQueryAsyncWorker : public Napi::AsyncWorker { public: ConnectionQueryAsyncWorker(Napi::Function& callback, std::shared_ptr& connection, + std::shared_ptr& database, std::string statement, NodeQueryResult* nodeQueryResult, Napi::Value progressCallback) - : Napi::AsyncWorker(callback), connection(connection), statement(std::move(statement)), - nodeQueryResult(nodeQueryResult) { + : Napi::AsyncWorker(callback), connection(connection), database(database), + statement(std::move(statement)), nodeQueryResult(nodeQueryResult) { if (progressCallback.IsFunction()) { this->progressCallback = Napi::ThreadSafeFunction::New(Env(), progressCallback.As(), "ProgressCallback", 0, 1); @@ -156,11 +158,11 @@ class ConnectionQueryAsyncWorker : public Napi::AsyncWorker { } try { auto result = connection->queryWithID(statement, queryID); - auto* resultRaw = result.get(); - nodeQueryResult->AdoptQueryResult(std::move(result)); - if (!resultRaw->isSuccess()) { - SetError(resultRaw->getErrorMessage()); + if (!result->isSuccess()) { + SetError(result->getErrorMessage()); + return; } + nodeQueryResult->AdoptQueryResult(std::move(result), database); } catch (const std::exception& exc) { SetError(std::string(exc.what())); } @@ -177,6 +179,7 @@ class ConnectionQueryAsyncWorker : public Napi::AsyncWorker { private: std::shared_ptr connection; + std::shared_ptr database; std::string statement; NodeQueryResult* nodeQueryResult; std::optional progressCallback; diff --git a/src_cpp/include/node_query_result.h b/src_cpp/include/node_query_result.h index ae71b27..07d1345 100644 --- a/src_cpp/include/node_query_result.h +++ b/src_cpp/include/node_query_result.h @@ -20,9 +20,10 @@ class NodeQueryResult : public Napi::ObjectWrap { public: static Napi::Object Init(Napi::Env env, Napi::Object exports); - static Napi::Object NewInstance(Napi::Env env, std::unique_ptr queryResult); + static Napi::Object NewInstance(Napi::Env env, std::unique_ptr queryResult, + std::shared_ptr db); explicit NodeQueryResult(const Napi::CallbackInfo& info); - void AdoptQueryResult(std::unique_ptr queryResult); + void AdoptQueryResult(std::unique_ptr queryResult, std::shared_ptr db); std::unique_ptr DetachNextQueryResult(); ~NodeQueryResult() override; @@ -52,6 +53,7 @@ class NodeQueryResult : public Napi::ObjectWrap { private: static Napi::FunctionReference constructor; std::unique_ptr ownedQueryResult = nullptr; + std::shared_ptr database = nullptr; std::unique_ptr> columnNames = nullptr; std::atomic activeAsyncUses = 0; }; @@ -202,7 +204,8 @@ class NodeQueryResultGetNextQueryResultAsyncWorker : public Napi::AsyncWorker { Callback().Call({env.Null(), env.Undefined()}); return; } - Callback().Call({env.Null(), NodeQueryResult::NewInstance(env, std::move(nextOwnedResult))}); + Callback().Call({env.Null(), NodeQueryResult::NewInstance(env, std::move(nextOwnedResult), + currQueryResult->database)}); } void OnError(Napi::Error const& error) override { diff --git a/src_cpp/node_connection.cpp b/src_cpp/node_connection.cpp index a0090f5..c093fbf 100644 --- a/src_cpp/node_connection.cpp +++ b/src_cpp/node_connection.cpp @@ -57,8 +57,6 @@ void NodeConnection::InitCppConnection() { this->connection = std::make_shared(database.get()); ProgressBar::Get(*connection->getClientContext()) ->setDisplay(std::make_shared()); - // After the connection is initialized, we do not need to hold a reference to the database. - database.reset(); } void NodeConnection::SetMaxNumThreadForExec(const Napi::CallbackInfo& info) { @@ -87,6 +85,7 @@ void NodeConnection::Close(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); Napi::HandleScope scope(env); this->connection.reset(); + this->database.reset(); } Napi::Value NodeConnection::ExecuteAsync(const Napi::CallbackInfo& info) { @@ -98,7 +97,7 @@ Napi::Value NodeConnection::ExecuteAsync(const Napi::CallbackInfo& info) { auto callback = info[3].As(); try { auto params = Util::TransformParametersForExec(info[2].As()); - auto asyncWorker = new ConnectionExecuteAsyncWorker(callback, connection, + auto asyncWorker = new ConnectionExecuteAsyncWorker(callback, connection, database, nodePreparedStatement->preparedStatement, nodeQueryResult, std::move(params), info[4]); asyncWorker->Queue(); } catch (const std::exception& exc) { @@ -114,11 +113,10 @@ Napi::Value NodeConnection::QuerySync(const Napi::CallbackInfo& info) { auto nodeQueryResult = Napi::ObjectWrap::Unwrap(info[1].As()); try { auto result = connection->query(statement); - auto* resultRaw = result.get(); - nodeQueryResult->AdoptQueryResult(std::move(result)); - if (!resultRaw->isSuccess()) { - Napi::Error::New(env, resultRaw->getErrorMessage()).ThrowAsJavaScriptException(); + if (!result->isSuccess()) { + Napi::Error::New(env, result->getErrorMessage()).ThrowAsJavaScriptException(); } + nodeQueryResult->AdoptQueryResult(std::move(result), database); } catch (const std::exception& exc) { Napi::Error::New(env, std::string(exc.what())).ThrowAsJavaScriptException(); } @@ -135,11 +133,10 @@ Napi::Value NodeConnection::ExecuteSync(const Napi::CallbackInfo& info) { auto params = Util::TransformParametersForExec(info[2].As()); auto result = connection->executeWithParams(nodePreparedStatement->preparedStatement.get(), std::move(params)); - auto* resultRaw = result.get(); - nodeQueryResult->AdoptQueryResult(std::move(result)); - if (!resultRaw->isSuccess()) { - Napi::Error::New(env, resultRaw->getErrorMessage()).ThrowAsJavaScriptException(); + if (!result->isSuccess()) { + Napi::Error::New(env, result->getErrorMessage()).ThrowAsJavaScriptException(); } + nodeQueryResult->AdoptQueryResult(std::move(result), database); } catch (const std::exception& exc) { Napi::Error::New(env, std::string(exc.what())).ThrowAsJavaScriptException(); } @@ -153,7 +150,7 @@ Napi::Value NodeConnection::QueryAsync(const Napi::CallbackInfo& info) { auto nodeQueryResult = Napi::ObjectWrap::Unwrap(info[1].As()); auto callback = info[2].As(); auto asyncWorker = - new ConnectionQueryAsyncWorker(callback, connection, statement, nodeQueryResult, info[3]); + new ConnectionQueryAsyncWorker(callback, connection, database, statement, nodeQueryResult, info[3]); asyncWorker->Queue(); return info.Env().Undefined(); } diff --git a/src_cpp/node_query_result.cpp b/src_cpp/node_query_result.cpp index 258abe7..2d8f854 100644 --- a/src_cpp/node_query_result.cpp +++ b/src_cpp/node_query_result.cpp @@ -37,10 +37,10 @@ Napi::Object NodeQueryResult::Init(Napi::Env env, Napi::Object exports) { } Napi::Object NodeQueryResult::NewInstance( - Napi::Env /*env*/, std::unique_ptr queryResult) { + Napi::Env /*env*/, std::unique_ptr queryResult, std::shared_ptr db) { auto obj = constructor.New({}); auto* nodeQueryResult = Napi::ObjectWrap::Unwrap(obj); - nodeQueryResult->AdoptQueryResult(std::move(queryResult)); + nodeQueryResult->AdoptQueryResult(std::move(queryResult), std::move(db)); return obj; } @@ -51,10 +51,12 @@ NodeQueryResult::~NodeQueryResult() { this->Close(); } -void NodeQueryResult::AdoptQueryResult(std::unique_ptr queryResult) { +void NodeQueryResult::AdoptQueryResult( + std::unique_ptr queryResult, std::shared_ptr db) { ThrowIfAsyncOperationInFlight("replace"); columnNames.reset(); ownedQueryResult = std::move(queryResult); + database = std::move(db); } std::unique_ptr NodeQueryResult::DetachNextQueryResult() { @@ -140,7 +142,7 @@ Napi::Value NodeQueryResult::GetNextQueryResultSync(const Napi::CallbackInfo& in .ThrowAsJavaScriptException(); return env.Undefined(); } - return NewInstance(env, std::move(nextOwnedResult)); + return NewInstance(env, std::move(nextOwnedResult), database); } catch (const std::exception& exc) { Napi::Error::New(env, std::string(exc.what())).ThrowAsJavaScriptException(); } @@ -286,4 +288,5 @@ void NodeQueryResult::Close(const Napi::CallbackInfo& info) { void NodeQueryResult::Close() { columnNames.reset(); ownedQueryResult.reset(); + database.reset(); } diff --git a/test/test_database.js b/test/test_database.js index 16825b1..916ba84 100644 --- a/test/test_database.js +++ b/test/test_database.js @@ -505,11 +505,35 @@ describe("Database close", function () { assert.equal(res.getNumTuples(), 1); const tuple = await res.getNext(); assert.deepEqual(tuple, { "+(1,1)": 2 }); + // Close in reverse order: db first, then conn, then result. None should crash. testDb.closeSync(); assert.isTrue(testDb._isClosed); - assert.throws(() => conn.querySync("RETURN 1+1"), Error, "Runtime exception: The current operation is not allowed because the parent database is closed."); conn.closeSync(); assert.isTrue(conn._isClosed); - assert.throws(() => res.resetIterator(), Error, "Runtime exception: The current operation is not allowed because the parent database is closed."); + res.close(); + }); + + it("should not crash when discarded query results are GC'd after database is closed", async function () { + // Regression test for a double-free bug: NodeQueryResult holds a + // MaterializedQueryResult whose FactorizedTable destructor accesses + // database-owned memory. If the Database is destroyed before the GC + // finalizer for NodeQueryResult runs, that destructor crashes. The fix + // is for NodeQueryResult to hold a shared_ptr so the Database + // cannot be freed until every result that references it is gone. + // + // The key pattern being tested is: query results are *not stored*, making + // them immediately eligible for GC. conn.closeSync() and db.closeSync() + // are then called before GC has had a chance to collect them. When the GC + // finalizer eventually runs (possibly later in this mocha process), it must + // not crash. + const testDb = new lbug.Database(); + const conn = new lbug.Connection(testDb); + await conn.query("CREATE NODE TABLE T(id STRING PRIMARY KEY)"); + await conn.query(`CREATE (:T {id: 'test-${Date.now()}'})`); + await conn.query("MATCH (t:T) RETURN t.id"); + conn.closeSync(); + testDb.closeSync(); + assert.isTrue(conn._isClosed); + assert.isTrue(testDb._isClosed); }); });