Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class PyConnection {
const py::dict& params);

std::unique_ptr<PyQueryResult> query(const std::string& statement);
std::unique_ptr<PyQueryResult> queryAsArrow(const std::string& statement,
int64_t chunkSize);

void setMaxNumThreadForExec(uint64_t numThreads);

Expand Down
1 change: 1 addition & 0 deletions src_cpp/include/py_query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class PyQueryResult {
py::object getAsDF();

lbug::pyarrow::Table getAsArrow(std::int64_t chunkSize, bool fallbackExtensionTypes);
py::dict getCSR();

py::list getColumnDataTypes();

Expand Down
10 changes: 10 additions & 0 deletions src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ void PyConnection::initialize(py::handle& m) {
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
py::arg("parameters") = py::dict())
.def("query", &PyConnection::query, py::arg("statement"))
.def("query_as_arrow", &PyConnection::queryAsArrow, py::arg("statement"),
py::arg("chunk_size"))
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"))
.def("prepare", &PyConnection::prepare, py::arg("query"),
Expand Down Expand Up @@ -175,6 +177,14 @@ std::unique_ptr<PyQueryResult> PyConnection::query(const std::string& statement)
return checkAndWrapQueryResult(queryResult);
}

std::unique_ptr<PyQueryResult> PyConnection::queryAsArrow(const std::string& statement,
    int64_t chunkSize) {
    // Execute the Arrow-collector query without holding the GIL so other
    // Python threads can run; the GIL is reacquired automatically when the
    // scoped release goes out of scope, before any Python objects are built.
    auto runQuery = [&]() {
        py::gil_scoped_release release;
        return conn->queryAsArrow(statement, chunkSize);
    };
    auto queryResult = runQuery();
    // Wrapping (and error translation) happens with the GIL held.
    return checkAndWrapQueryResult(queryResult);
}

void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
conn->setMaxNumThreadForExec(numThreads);
}
Expand Down
55 changes: 55 additions & 0 deletions src_cpp/py_query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
#include "common/arrow/arrow_row_batch.h"
#include "common/constants.h"
#include "common/exception/not_implemented.h"
#include "common/exception/runtime.h"
#include "common/types/uuid.h"
#include "common/types/value/nested.h"
#include "common/types/value/node.h"
#include "common/types/value/rel.h"
#include "datetime.h" // python lib
#include "include/py_query_result_converter.h"
#include "main/query_result/arrow_query_result.h"

using namespace lbug::common;
using lbug::importCache;
Expand All @@ -30,6 +32,7 @@ void PyQueryResult::initialize(py::handle& m) {
.def("close", &PyQueryResult::close)
.def("getAsDF", &PyQueryResult::getAsDF)
.def("getAsArrow", &PyQueryResult::getAsArrow)
.def("getCSR", &PyQueryResult::getCSR)
.def("getColumnNames", &PyQueryResult::getColumnNames)
.def("getColumnDataTypes", &PyQueryResult::getColumnDataTypes)
.def("resetIterator", &PyQueryResult::resetIterator)
Expand Down Expand Up @@ -85,6 +88,30 @@ void PyQueryResult::close() {
}
}

namespace {

// Copies a C++ int64 vector into a freshly allocated, owning 1-D NumPy array.
// The returned array is independent of `values`' lifetime.
py::array_t<int64_t> copyToNumpyArray(const std::vector<int64_t>& values) {
    auto result = py::array_t<int64_t>(values.size());
    auto* data = static_cast<int64_t*>(result.request().ptr);
    std::copy(values.begin(), values.end(), data);
    return result;
}

// Assembles the Python-facing CSR dict with keys "indptr", "indices" and
// "edge_ids" ("edge_ids" is Python None when `includeEdgeIDs` is false).
// The vectors are taken by const reference: the caller passes the query
// result's cached CSR metadata, and taking them by value would copy each
// array once on the call and again inside copyToNumpyArray.
py::dict buildCSRResult(const std::vector<int64_t>& indptr,
    const std::vector<int64_t>& indices, const std::vector<int64_t>& edgeIDs,
    bool includeEdgeIDs) {
    py::dict result;
    result["indptr"] = copyToNumpyArray(indptr);
    result["indices"] = copyToNumpyArray(indices);
    if (includeEdgeIDs) {
        result["edge_ids"] = copyToNumpyArray(edgeIDs);
    } else {
        result["edge_ids"] = py::none();
    }
    return result;
}

} // namespace

static py::object converTimestampToPyObject(timestamp_t& timestamp) {
int32_t year = 0, month = 0, day = 0, hour = 0, min = 0, sec = 0, micros = 0;
date_t date;
Expand Down Expand Up @@ -320,6 +347,23 @@ py::object PyQueryResult::getArrowChunks(const std::vector<LogicalType>& types,

lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
bool fallbackExtensionTypes) {
if (queryResult->getType() == QueryResultType::ARROW) {
auto types = queryResult->getColumnDataTypes();
auto names = queryResult->getColumnNames();
py::list batches;
auto batchImportFunc = importCache->pyarrow.lib.RecordBatch._import_from_c();
while (queryResult->hasNextArrowChunk()) {
auto data = queryResult->getNextArrowChunk(chunkSize);
auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
batches.append(
batchImportFunc((std::uint64_t)data.get(), (std::uint64_t)schema.get()));
}
auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
auto fromBatchesFunc = importCache->pyarrow.lib.Table.from_batches();
auto schemaImportFunc = importCache->pyarrow.lib.Schema._import_from_c();
auto schemaObj = schemaImportFunc((std::uint64_t)schema.get());
return py::cast<lbug::pyarrow::Table>(fromBatchesFunc(batches, schemaObj));
}
auto types = queryResult->getColumnDataTypes();
auto names = queryResult->getColumnNames();
py::list batches = getArrowChunks(types, names, chunkSize, fallbackExtensionTypes);
Expand All @@ -330,6 +374,17 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
return py::cast<lbug::pyarrow::Table>(fromBatchesFunc(batches, schemaObj));
}

// Returns the native CSR arrays ("indptr", "indices", "edge_ids") of an
// Arrow-backed query result as a Python dict. Throws RuntimeException when
// the result is not an ArrowQueryResult or carries no CSR metadata.
py::dict PyQueryResult::getCSR() {
    auto* arrowQueryResult = dynamic_cast<lbug::main::ArrowQueryResult*>(queryResult);
    // Guard clause: only the Arrow collector path materializes CSR metadata.
    if (arrowQueryResult == nullptr || !arrowQueryResult->hasCSRMetadata()) {
        throw RuntimeException(
            "CSR export is only supported for Arrow query results with native CSR metadata.");
    }
    const auto& metadata = arrowQueryResult->getCSRMetadata();
    return buildCSRResult(metadata.indptr, metadata.indices, metadata.edgeIDs,
        metadata.hasEdgeIDs);
}

py::list PyQueryResult::getColumnDataTypes() {
auto columnDataTypes = queryResult->getColumnDataTypes();
py::tuple result(columnDataTypes.size());
Expand Down
4 changes: 3 additions & 1 deletion src_py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from .connection import Connection # noqa: E402
from .database import Database # noqa: E402
from .prepared_statement import PreparedStatement # noqa: E402
from .query_result import QueryResult # noqa: E402
from .query_result import ArrowQueryResult, CSRResult, QueryResult # noqa: E402
from .types import Type # noqa: E402

_VERSION_INFO: tuple[str, int] | None = None
Expand All @@ -80,7 +80,9 @@ def __getattr__(name: str) -> str | int:

__all__ = [
"AsyncConnection",
"ArrowQueryResult",
"Connection",
"CSRResult",
"Database",
"PreparedStatement",
"QueryResult",
Expand Down
5 changes: 5 additions & 0 deletions src_py/_lbug_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,11 @@ def getAsArrow(self, *_args: Any, **_kwargs: Any) -> Any:
"Arrow export is not yet implemented in C-API backend"
)

def getCSR(self, *_args: Any, **_kwargs: Any) -> Any:
    """Stub: native CSR export is unavailable in the C-API backend.

    Raises:
        NotImplementedError: always; use the pybind backend for CSR export.
    """
    raise NotImplementedError(
        "CSR export is not yet implemented in C-API backend"
    )

def getAsDF(self) -> Any:
raise NotImplementedError(
"DataFrame export is not yet implemented in C-API backend"
Expand Down
23 changes: 22 additions & 1 deletion src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ._backend import get_capi_module, get_pybind_module
from .prepared_statement import PreparedStatement
from .query_result import QueryResult
from .query_result import ArrowQueryResult, QueryResult

if TYPE_CHECKING:
import sys
Expand Down Expand Up @@ -369,6 +369,27 @@ def execute(
all_query_results.append(next_query_result)
return all_query_results

def query_as_arrow(self, query: str, chunk_size: int) -> ArrowQueryResult:
    """
    Run ``query`` through the native Arrow collector path and return an
    :class:`ArrowQueryResult`.

    This is the efficient path for CSR-aware Arrow export. Only the pybind
    backend implements it; the C-API backend raises ``NotImplementedError``.
    """
    self.init_connection()
    if not self._using_pybind_backend():
        raise NotImplementedError("query_as_arrow requires the pybind backend")
    internal_result = self._get_pybind_connection().query_as_arrow(
        query, chunk_size
    )
    if not internal_result.isSuccess():
        raise RuntimeError(internal_result.getErrorMessage())
    wrapped = ArrowQueryResult(self, internal_result, native_chunk_size=chunk_size)
    # Track the result so the connection can close it on teardown.
    self._register_query_result(wrapped)
    return wrapped

def _prepare(
self,
query: str,
Expand Down
54 changes: 54 additions & 0 deletions src_py/query_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from .constants import DST, ID, LABEL, NODES, RELS, SRC
Expand Down Expand Up @@ -525,6 +526,59 @@ def rows_as_dict(self, state=True) -> Self:
return self


class ArrowQueryResult(QueryResult):
    """A :class:`QueryResult` produced by the native Arrow collector path.

    Remembers the chunk size chosen at execution time so Arrow export can
    reuse it, and exposes native CSR arrays via :meth:`csr`.
    """

    def __init__(
        self, connection: Any, query_result: Any, native_chunk_size: int
    ) -> None:
        super().__init__(connection, query_result)
        # Chunk size passed to Connection.query_as_arrow(...); reused when
        # the caller does not request an explicit rechunk.
        self._native_chunk_size = native_chunk_size

    def get_as_arrow(
        self, chunk_size: int | None = None, *, fallbackExtensionTypes: bool = False
    ) -> pa.Table:
        """
        Get the query result as a PyArrow Table.

        ``None``, ``0``, or a negative ``chunk_size`` reuses the native
        chunk size chosen by ``Connection.query_as_arrow(...)`` instead of
        rechunking the result.
        """
        effective_chunk_size = (
            self._native_chunk_size
            if chunk_size is None or chunk_size <= 0
            else chunk_size
        )
        return super().get_as_arrow(
            effective_chunk_size, fallbackExtensionTypes=fallbackExtensionTypes
        )

    def csr(self) -> CSRResult:
        """
        Get native CSR arrays from an Arrow query result.

        Only available for Arrow results carrying CSR metadata, typically
        from ``Connection.query_as_arrow(...)`` on relationship-shaped
        projections; otherwise the backend raises a RuntimeError.
        """
        self.check_for_query_result_close()

        import pyarrow as pa

        raw = self._query_result.getCSR()
        raw_edge_ids = raw["edge_ids"]
        return CSRResult(
            indptr=pa.array(raw["indptr"]),
            indices=pa.array(raw["indices"]),
            edge_ids=None if raw_edge_ids is None else pa.array(raw_edge_ids),
        )


@dataclass(frozen=True)
class CSRResult:
    """Compressed-sparse-row (CSR) view of a relationship projection,
    as returned by ``ArrowQueryResult.csr()``."""

    # Row offsets: the neighbours of source row i live at
    # indices[indptr[i]:indptr[i+1]]; length is number_of_source_rows + 1.
    indptr: pa.Array
    # Destination row ids, grouped per source row.
    indices: pa.Array
    # Per-edge ids aligned with `indices`; None when the projection carried
    # no relationship id column.
    edge_ids: pa.Array | None = None


def _row_to_dict(columns: list[str], row: list[Any]) -> dict[str, Any]:
if len(columns) != len(row):
msg = "Number of columns in output row does not match number of columns"
Expand Down
64 changes: 64 additions & 0 deletions test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,3 +772,67 @@ def test_to_arrow1(conn: lb.Connection) -> None:
-1
) # what is a chunk size of -1 even supposed to mean?
assert arrow_tbl == []


def test_query_as_arrow_csr_with_rel_ids(conn_db_readonly: ConnDB) -> None:
    """CSR export round-trips: expanding (indptr, indices, edge_ids) back
    into (src, edge, dst) triples reproduces the row-based result."""
    conn, _ = conn_db_readonly
    query = """
    MATCH (a:person)-[b:knows]->(c:person)
    RETURN a.rowid, b.rowid, c.rowid
    """
    # Reference answer from the ordinary row-based execution path.
    rows = conn.execute(query).get_all()
    # Native Arrow path; a small chunk size (8) exercises chunked collection.
    csr = conn.query_as_arrow(query, 8).csr()

    # The projection includes b.rowid, so edge ids must be populated.
    assert csr.edge_ids is not None

    # Expand CSR back into triples: source row i's neighbours are
    # indices[indptr[i]:indptr[i+1]], with the matching edge id per slot.
    # NOTE(review): assumes the row-based result is ordered to match the
    # CSR expansion order — confirm the engine's ordering guarantee.
    reconstructed = []
    indptr = csr.indptr.to_pylist()
    indices = csr.indices.to_pylist()
    edge_ids = csr.edge_ids.to_pylist()
    for src_rowid in range(len(indptr) - 1):
        for idx in range(indptr[src_rowid], indptr[src_rowid + 1]):
            reconstructed.append([src_rowid, edge_ids[idx], indices[idx]])

    assert reconstructed == rows


def test_query_as_arrow_csr_with_extra_columns(conn_db_readonly: ConnDB) -> None:
    """A CSR-shaped projection may carry extra payload columns; both the CSR
    arrays and the full Arrow table must be retrievable from one result."""
    conn, _ = conn_db_readonly
    query = """
    MATCH (a:person)-[b:knows]->(c:person)
    RETURN a.rowid, b.rowid, c.rowid, b.date, c.fName
    """
    result = conn.query_as_arrow(query, 8)
    csr = result.csr()
    # chunk_size 0 reuses the native chunk size (ArrowQueryResult semantics).
    arrow_tbl = result.get_as_arrow(0)

    assert csr.edge_ids is not None
    assert arrow_tbl.column_names == ["a.rowid", "b.rowid", "c.rowid", "b.date", "c.fName"]
    # indptr has num_source_rows + 1 entries, so >= 2 means at least one row.
    assert len(csr.indptr) >= 2


def test_query_as_arrow_csr_without_rel_ids(conn_db_readonly: ConnDB) -> None:
    """When the projection omits the relationship id, csr() still works but
    edge_ids is None, and (indptr, indices) round-trips the row result."""
    conn, _ = conn_db_readonly
    query = """
    MATCH (a:person)-[:knows]->(c:person)
    RETURN a.rowid, c.rowid
    """
    # Reference answer from the ordinary row-based execution path.
    rows = conn.execute(query).get_all()
    csr = conn.query_as_arrow(query, 8).csr()

    # No b.rowid in the projection -> no edge ids in the CSR export.
    assert csr.edge_ids is None

    # Expand CSR back into (src, dst) pairs.
    # NOTE(review): assumes the row-based result is ordered to match the
    # CSR expansion order — confirm the engine's ordering guarantee.
    reconstructed = []
    indptr = csr.indptr.to_pylist()
    indices = csr.indices.to_pylist()
    for src_rowid in range(len(indptr) - 1):
        for idx in range(indptr[src_rowid], indptr[src_rowid + 1]):
            reconstructed.append([src_rowid, indices[idx]])

    assert reconstructed == rows


def test_query_as_arrow_csr_rejects_non_csr_shape(conn_db_readonly: ConnDB) -> None:
    """csr() must fail loudly when the result carries no CSR metadata."""
    conn, _ = conn_db_readonly
    # A node-only projection has no (src)-[rel]->(dst) shape, so the backend
    # surfaces its C++ RuntimeException as a Python RuntimeError.
    with pytest.raises(RuntimeError, match="CSR export is only supported"):
        conn.query_as_arrow("MATCH (a:person) RETURN a.fName", 8).csr()
Loading