diff --git a/arrow-pyarrow-testing/tests/pyarrow.rs b/arrow-pyarrow-testing/tests/pyarrow.rs index 6f3606478c72..09a4f9756822 100644 --- a/arrow-pyarrow-testing/tests/pyarrow.rs +++ b/arrow-pyarrow-testing/tests/pyarrow.rs @@ -84,7 +84,6 @@ fn test_to_pyarrow_byte_view() { ]) .unwrap(); - println!("input: {input:?}"); let res = Python::attach(|py| { let py_input = input.to_pyarrow(py)?; let records = RecordBatch::from_pyarrow_bound(&py_input)?; @@ -120,7 +119,7 @@ value = NotATuple() assert!(err.is_instance_of::(py)); assert_eq!( err.to_string(), - "TypeError: Expected __arrow_c_array__ to return a tuple of (schema, array) capsules." + "TypeError: Expected __arrow_c_array__ to return a tuple of (arrow_schema, arrow_array) capsules." ); }); } diff --git a/arrow-pyarrow/src/lib.rs b/arrow-pyarrow/src/lib.rs index d8f584e396d3..52487bba99f0 100644 --- a/arrow-pyarrow/src/lib.rs +++ b/arrow-pyarrow/src/lib.rs @@ -61,6 +61,7 @@ use std::convert::{From, TryFrom}; use std::ffi::CStr; +use std::ptr::NonNull; use std::sync::Arc; use arrow_array::ffi; @@ -77,15 +78,12 @@ use pyo3::ffi::Py_uintptr_t; use pyo3::import_exception; use pyo3::prelude::*; use pyo3::sync::PyOnceLock; -use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; +use pyo3::types::{PyCapsule, PyDict, PyList, PyType}; import_exception!(pyarrow, ArrowException); /// Represents an exception raised by PyArrow. pub type PyArrowException = ArrowException; -const ARROW_ARRAY_STREAM_CAPSULE_NAME: &CStr = c"arrow_array_stream"; -const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c"arrow_schema"; -const ARROW_ARRAY_CAPSULE_NAME: &CStr = c"arrow_array"; fn to_py_err(err: ArrowError) -> PyErr { PyArrowException::new_err(err.to_string()) @@ -131,54 +129,16 @@ fn validate_class(expected: &Bound, value: &Bound) -> PyResult<() Ok(()) } -fn validate_pycapsule(capsule: &Bound, name: &str) -> PyResult<()> { - let capsule_name = capsule.name()?; - if capsule_name.is_none() { - return Err(PyValueError::new_err( - "Expected schema PyCapsule to have name set.", - )); - } - - let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? }; - if capsule_name != name { - return Err(PyValueError::new_err(format!( - "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'", - ))); - } - - Ok(()) -} - -fn extract_arrow_c_array_capsules<'py>( - value: &Bound<'py, PyAny>, -) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> { - let tuple = value.call_method0("__arrow_c_array__")?; - - if !tuple.is_instance_of::() { - return Err(PyTypeError::new_err( - "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.", - )); - } - - tuple.extract().map_err(|_| { - PyTypeError::new_err( - "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.", - ) - }) -} - impl FromPyArrow for DataType { fn from_pyarrow_bound(value: &Bound) -> PyResult { // Newer versions of PyArrow as well as other libraries with Arrow data implement this // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_schema__")? { - let capsule = value.call_method0("__arrow_c_schema__")?.extract()?; - validate_pycapsule(&capsule, "arrow_schema")?; - - let schema_ptr = capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); + let schema_ptr = extract_capsule_from_method::( + value, + "__arrow_c_schema__", + )?; return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err); } @@ -203,12 +163,10 @@ impl FromPyArrow for Field { // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_schema__")? { - let capsule = value.call_method0("__arrow_c_schema__")?.extract()?; - validate_pycapsule(&capsule, "arrow_schema")?; - - let schema_ptr = capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); + let schema_ptr = extract_capsule_from_method::( + value, + "__arrow_c_schema__", + )?; return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err); } @@ -233,12 +191,10 @@ impl FromPyArrow for Schema { // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_schema__")? { - let capsule = value.call_method0("__arrow_c_schema__")?.extract()?; - validate_pycapsule(&capsule, "arrow_schema")?; - - let schema_ptr = capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); + let schema_ptr = extract_capsule_from_method::( + value, + "__arrow_c_schema__", + )?; return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err); } @@ -263,22 +219,12 @@ impl FromPyArrow for ArrayData { // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_array__")? { - let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?; - - validate_pycapsule(&schema_capsule, "arrow_schema")?; - validate_pycapsule(&array_capsule, "arrow_array")?; - - let schema_ptr = schema_capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); - let array = unsafe { - FFI_ArrowArray::from_raw( - array_capsule - .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))? - .cast::() - .as_ptr(), - ) - }; + let (schema_ptr, array_ptr) = + extract_capsule_pair_from_method::( + value, + "__arrow_c_array__", + )?; + let array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) }; return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err); } @@ -341,23 +287,17 @@ impl FromPyArrow for RecordBatch { // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_array__")? { - let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?; - - validate_pycapsule(&schema_capsule, "arrow_schema")?; - validate_pycapsule(&array_capsule, "arrow_array")?; - - let schema_ptr = schema_capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); - let array_ptr = array_capsule - .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))? - .cast::(); + let (schema_ptr, array_ptr) = + extract_capsule_pair_from_method::( + value, + "__arrow_c_array__", + )?; let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) }; let mut array_data = unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?; if !matches!(array_data.data_type(), DataType::Struct(_)) { return Err(PyTypeError::new_err( - "Expected Struct type from __arrow_c_array.", + format!("Expected Struct type from __arrow_c_array__, found {}.", array_data.data_type()), )); } let options = RecordBatchOptions::default().with_row_count(Some(array_data.len())); @@ -421,18 +361,11 @@ impl FromPyArrow for ArrowArrayStreamReader { // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_stream__")? { - let capsule = value.call_method0("__arrow_c_stream__")?.extract()?; - - validate_pycapsule(&capsule, "arrow_array_stream")?; - - let stream = unsafe { - FFI_ArrowArrayStream::from_raw( - capsule - .pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))? - .cast::() - .as_ptr(), - ) - }; + let stream_ptr = extract_capsule_from_method::( + value, + "__arrow_c_stream__", + )?; + let stream = unsafe { FFI_ArrowArrayStream::from_raw(stream_ptr.as_ptr()) }; let stream_reader = ArrowArrayStreamReader::try_new(stream) .map_err(|err| PyValueError::new_err(err.to_string()))?; @@ -448,8 +381,7 @@ impl FromPyArrow for ArrowArrayStreamReader { // make the conversion through PyArrow's private API // this changes the pointer's memory and is thus unsafe. // In particular, `_export_to_c` can go out of bounds - let args = PyTuple::new(value.py(), [&raw mut stream as Py_uintptr_t])?; - value.call_method1("_export_to_c", args)?; + value.call_method1("_export_to_c", (&raw mut stream as Py_uintptr_t,))?; ArrowArrayStreamReader::try_new(stream) .map_err(|err| PyValueError::new_err(err.to_string())) @@ -631,3 +563,74 @@ impl From for PyArrowType { Self(s) } } + +trait PyCapsuleType { + const NAME: &CStr; +} + +impl PyCapsuleType for FFI_ArrowSchema { + const NAME: &CStr = c"arrow_schema"; +} + +impl PyCapsuleType for FFI_ArrowArray { + const NAME: &CStr = c"arrow_array"; +} + +impl PyCapsuleType for FFI_ArrowArrayStream { + const NAME: &CStr = c"arrow_array_stream"; +} + +fn extract_capsule_from_method( + object: &Bound<'_, PyAny>, + method_name: &'static str, +) -> PyResult> { + (|| { + Ok(object + .call_method0(method_name)? + .extract::>()? + .pointer_checked(Some(T::NAME))? + .cast::()) + })() + .map_err(|e| { + wrapping_type_error( + object.py(), + e, + format!( + "Expected {method_name} to return a {} capsule.", + T::NAME.to_str().unwrap(), + ), + ) + }) +} + +fn extract_capsule_pair_from_method( + object: &Bound<'_, PyAny>, + method_name: &'static str, +) -> PyResult<(NonNull, NonNull)> { + (|| { + let (c1, c2) = object + .call_method0(method_name)? + .extract::<(Bound<'_, PyCapsule>, Bound<'_, PyCapsule>)>()?; + Ok(( + c1.pointer_checked(Some(T1::NAME))?.cast::(), + c2.pointer_checked(Some(T2::NAME))?.cast::(), + )) + })() + .map_err(|e| { + wrapping_type_error( + object.py(), + e, + format!( + "Expected {method_name} to return a tuple of ({}, {}) capsules.", + T1::NAME.to_str().unwrap(), + T2::NAME.to_str().unwrap() + ), + ) + }) +} + +fn wrapping_type_error(py: Python<'_>, error: PyErr, message: String) -> PyErr { + let e = PyTypeError::new_err(message); + e.set_cause(py, Some(error)); + e +}