Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 144 additions & 24 deletions bindings/python/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use pyo3::types::{
PyDeltaAccess, PyDict, PyList, PySequence, PySlice, PyTime, PyTimeAccess, PyTuple, PyType,
PyTzInfo,
};
use pyo3::{Bound, IntoPyObjectExt, Py, PyAny, PyRef, PyRefMut, PyResult, Python};
use pyo3_async_runtimes::tokio::future_into_py;
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -1863,6 +1864,13 @@ enum ScannerKind {
Batch(fcore::client::RecordBatchLogScanner),
}

/// The internal state of the scanner, protected by a Tokio Mutex for async cross-thread sharing
struct ScannerState {
/// Which underlying core scanner backs this instance (record-based or batch-based)
kind: ScannerKind,
/// A buffer to hold records polled from the network before yielding them one-by-one to Python
/// via `__anext__`; drained front-to-back before the network is polled again
pending_records: std::collections::VecDeque<Py<ScanRecord>>,
}

impl ScannerKind {
fn as_record(&self) -> PyResult<&fcore::client::LogScanner> {
match self {
Expand Down Expand Up @@ -1901,7 +1909,7 @@ macro_rules! with_scanner {
/// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches
#[pyclass]
pub struct LogScanner {
scanner: ScannerKind,
state: Arc<tokio::sync::Mutex<ScannerState>>,
admin: fcore::client::FlussAdmin,
table_info: fcore::metadata::TableInfo,
/// The projected Arrow schema to use for empty table creation
Expand All @@ -1922,7 +1930,8 @@ impl LogScanner {
fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(&self.scanner, subscribe(bucket_id, start_offset))
let state = self.state.lock().await;
with_scanner!(&state.kind, subscribe(bucket_id, start_offset))
.map_err(|e| FlussError::from_core_error(&e))
})
})
Expand All @@ -1935,7 +1944,8 @@ impl LogScanner {
fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap<i32, i64>) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(&self.scanner, subscribe_buckets(&bucket_offsets))
let state = self.state.lock().await;
with_scanner!(&state.kind, subscribe_buckets(&bucket_offsets))
.map_err(|e| FlussError::from_core_error(&e))
})
})
Expand All @@ -1956,8 +1966,9 @@ impl LogScanner {
) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
let state = self.state.lock().await;
with_scanner!(
&self.scanner,
&state.kind,
subscribe_partition(partition_id, bucket_id, start_offset)
)
.map_err(|e| FlussError::from_core_error(&e))
Expand All @@ -1976,8 +1987,9 @@ impl LogScanner {
) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
let state = self.state.lock().await;
with_scanner!(
&self.scanner,
&state.kind,
subscribe_partition_buckets(&partition_bucket_offsets)
)
.map_err(|e| FlussError::from_core_error(&e))
Expand All @@ -1992,7 +2004,8 @@ impl LogScanner {
fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(&self.scanner, unsubscribe(bucket_id))
let state = self.state.lock().await;
with_scanner!(&state.kind, unsubscribe(bucket_id))
.map_err(|e| FlussError::from_core_error(&e))
})
})
Expand All @@ -2006,11 +2019,9 @@ impl LogScanner {
fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(
&self.scanner,
unsubscribe_partition(partition_id, bucket_id)
)
.map_err(|e| FlussError::from_core_error(&e))
let state = self.state.lock().await;
with_scanner!(&state.kind, unsubscribe_partition(partition_id, bucket_id))
.map_err(|e| FlussError::from_core_error(&e))
})
})
}
Expand All @@ -2030,7 +2041,10 @@ impl LogScanner {
/// - Returns an empty ScanRecords if no records are available
/// - When timeout expires, returns an empty ScanRecords (NOT an error)
fn poll(&self, py: Python, timeout_ms: i64) -> PyResult<ScanRecords> {
let scanner = self.scanner.as_record()?;
let scanner_ref =
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
let scanner = lock.kind.as_record()?;

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
Expand Down Expand Up @@ -2079,7 +2093,10 @@ impl LogScanner {
/// - Returns an empty list if no batches are available
/// - When timeout expires, returns an empty list (NOT an error)
fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult<Vec<RecordBatch>> {
let scanner = self.scanner.as_batch()?;
let scanner_ref =
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
let scanner = lock.kind.as_batch()?;

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
Expand Down Expand Up @@ -2114,7 +2131,10 @@ impl LogScanner {
/// - Returns an empty table (with correct schema) if no records are available
/// - When timeout expires, returns an empty table (NOT an error)
fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult<Py<PyAny>> {
let scanner = self.scanner.as_batch()?;
let scanner_ref =
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
let scanner = lock.kind.as_batch()?;

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
Expand Down Expand Up @@ -2167,13 +2187,20 @@ impl LogScanner {
/// Returns:
/// PyArrow Table containing all data from subscribed buckets
fn to_arrow(&self, py: Python) -> PyResult<Py<PyAny>> {
let scanner = self.scanner.as_batch()?;
let subscribed = scanner.get_subscribed_buckets();
if subscribed.is_empty() {
return Err(FlussError::new_err(
"No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.",
));
}
let subscribed = {
let scanner_ref = unsafe {
&*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>)
};
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
let scanner = lock.kind.as_batch()?;
let subs = scanner.get_subscribed_buckets();
if subs.is_empty() {
return Err(FlussError::new_err(
"No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.",
));
}
subs.clone()
};

// 2. Query latest offsets for all subscribed buckets
let stopping_offsets = self.query_latest_offsets(py, &subscribed)?;
Expand All @@ -2199,6 +2226,90 @@ impl LogScanner {
Ok(df)
}

/// Python `__aiter__`: wraps this scanner in a small Python async-generator
/// adapter so `async for record in scanner` repeatedly drives `__anext__`
/// until the stream signals `StopAsyncIteration`.
fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult<Bound<'py, PyAny>> {
    let py = slf.py();
    // The adapter must be valid Python: the generator body is indented under
    // `async def`, awaits __anext__ in a loop, and exits cleanly on
    // StopAsyncIteration instead of propagating it.
    let code = pyo3::ffi::c_str!(
        r#"
async def _adapter(obj):
    while True:
        try:
            yield await obj.__anext__()
        except StopAsyncIteration:
            break
"#
    );
    let globals = pyo3::types::PyDict::new(py);
    py.run(code, Some(&globals), None)?;
    // `_adapter` was defined by the `py.run` above; if it is somehow absent,
    // surface a Python exception rather than panicking the extension module.
    let adapter = globals.get_item("_adapter")?.ok_or_else(|| {
        pyo3::exceptions::PyRuntimeError::new_err(
            "failed to define async iterator adapter",
        )
    })?;
    // Return _adapter(self): the async generator bound to this scanner.
    adapter.call1((slf.into_bound_py_any(py)?,))
}

/// Python `__anext__`: returns an awaitable that resolves to the next
/// `ScanRecord`, buffering whole polled batches and draining them one
/// record per call.
///
/// NOTE(review): the state mutex is held across the `.await`s on `poll`
/// below, so any concurrent operation on this scanner (subscribe, poll,
/// another __anext__) blocks until a record arrives — confirm this
/// serialization is intended.
fn __anext__<'py>(slf: PyRefMut<'py, Self>) -> PyResult<Option<Bound<'py, PyAny>>> {
// Clone the Arc and the row type so the async block is self-contained
// and does not borrow from `slf`.
let state_arc = slf.state.clone();
let projected_row_type = slf.projected_row_type.clone();
let py = slf.py();

let future = future_into_py(py, async move {
let mut state = state_arc.lock().await;

// 1. Fast path: serve a record already buffered by a previous call.
if let Some(record) = state.pending_records.pop_front() {
return Ok(record.into_any());
}

// 2. Buffer is empty: poll the network for the next batch. Only the
// record-based scanner supports per-record iteration; a batch-based
// scanner ends the async stream here.
let scanner = match state.kind.as_record() {
Ok(s) => s,
Err(_) => {
return Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
"Stream Ended",
));
}
};

// Per-poll wait budget; each poll blocks at most this long for data.
let timeout = core::time::Duration::from_millis(5000);

let mut current_records = scanner
.poll(timeout)
.await
.map_err(|e| FlussError::from_core_error(&e))?;

// Keep re-polling until at least one record arrives.
// NOTE(review): this loop never terminates for an idle stream, so the
// awaitable blocks indefinitely; callers are expected to impose their
// own deadline (e.g. asyncio.wait_for) — confirm this is the intended
// contract rather than raising StopAsyncIteration on timeout.
while current_records.is_empty() {
current_records = scanner
.poll(timeout)
.await
.map_err(|e| FlussError::from_core_error(&e))?;
}

// Convert all polled core records into Python ScanRecord objects and
// queue them; later __anext__ calls drain this buffer before polling.
Python::attach(|py| {
for (_, records) in current_records.into_records_by_buckets() {
for core_record in records {
let scan_record =
ScanRecord::from_core(py, &core_record, &projected_row_type)?;
state.pending_records.push_back(Py::new(py, scan_record)?);
}
}

// Return the first buffered record immediately; the rest stay queued.
if let Some(record) = state.pending_records.pop_front() {
Ok(record.into_any())
} else {
Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
"Stream Ended",
))
}
})
})?;

Ok(Some(future))
}

/// Human-readable representation shown by Python's `repr()`,
/// identifying the table this scanner reads from.
fn __repr__(&self) -> String {
    let path = &self.table_info.table_path;
    format!("LogScanner(table={path})")
}
Expand All @@ -2213,7 +2324,10 @@ impl LogScanner {
projected_row_type: fcore::metadata::RowType,
) -> Self {
Self {
scanner,
state: std::sync::Arc::new(tokio::sync::Mutex::new(ScannerState {
kind: scanner,
pending_records: std::collections::VecDeque::new(),
})),
admin,
table_info,
projected_schema,
Expand Down Expand Up @@ -2264,7 +2378,10 @@ impl LogScanner {
py: Python,
subscribed: &[(fcore::metadata::TableBucket, i64)],
) -> PyResult<HashMap<fcore::metadata::TableBucket, i64>> {
let scanner = self.scanner.as_batch()?;
let scanner_ref =
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
let scanner = lock.kind.as_batch()?;
let is_partitioned = scanner.is_partitioned();
let table_path = &self.table_info.table_path;

Expand Down Expand Up @@ -2367,7 +2484,10 @@ impl LogScanner {
py: Python,
mut stopping_offsets: HashMap<fcore::metadata::TableBucket, i64>,
) -> PyResult<Py<PyAny>> {
let scanner = self.scanner.as_batch()?;
let scanner_ref =
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
let scanner = lock.kind.as_batch()?;
let mut all_batches = Vec::new();

while !stopping_offsets.is_empty() {
Expand Down
49 changes: 49 additions & 0 deletions bindings/python/test/test_log_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,55 @@ async def test_scan_records_indexing_and_slicing(connection, admin):
await admin.drop_table(table_path, ignore_if_not_exists=False)


async def test_async_iterator(connection, admin):
    """Verify that LogScanner supports `async for` (the async-iterator protocol)."""
    table_path = fluss.TablePath("fluss", "py_test_async_iterator")
    await admin.drop_table(table_path, ignore_if_not_exists=True)

    arrow_schema = pa.schema(
        [pa.field("id", pa.int32()), pa.field("val", pa.string())]
    )
    await admin.create_table(
        table_path, fluss.TableDescriptor(fluss.Schema(arrow_schema))
    )

    table = await connection.get_table(table_path)
    writer = table.new_append().create_writer()

    # Append five rows: ids 1..5 paired with "async1".."async5".
    ids = pa.array(list(range(1, 6)), type=pa.int32())
    vals = pa.array([f"async{i}" for i in range(1, 6)])
    writer.write_arrow_batch(
        pa.RecordBatch.from_arrays([ids, vals], schema=arrow_schema)
    )
    await writer.flush()

    scanner = await table.new_scan().create_log_scanner()
    num_buckets = (await admin.get_table_info(table_path)).num_buckets
    scanner.subscribe_buckets({b: fluss.EARLIEST_OFFSET for b in range(num_buckets)})

    collected = []

    async def _consume():
        # Exercise the async-iterator path (Issue #424); stop after 5 records.
        async for record in scanner:
            collected.append(record)
            if len(collected) == 5:
                break

    # Bound the consumption with a timeout so a broken iterator fails the
    # test instead of hanging it forever.
    await asyncio.wait_for(_consume(), timeout=10.0)

    assert len(collected) == 5, f"Expected 5 records, got {len(collected)}"

    collected.sort(key=lambda r: r.row["id"])
    for i, record in enumerate(collected):
        assert record.row["id"] == i + 1
        assert record.row["val"] == f"async{i + 1}"

    await admin.drop_table(table_path, ignore_if_not_exists=False)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
Expand Down