Review note (text before the first "diff --git" header is skipped by patch tools).

What this patch does:
  * arrow-pyarrow/src/lib.rs: adds the private helper
    extract_arrow_c_array_capsules(), which calls `__arrow_c_array__` on the
    given object and returns the (schema, array) capsule pair. If the result
    is not a tuple, or cannot be extracted as a pair of PyCapsule values, it
    returns a PyTypeError with a fixed message. Errors raised by the
    `__arrow_c_array__` call itself are propagated unchanged.
  * The ArrayData and RecordBatch FromPyArrow impls now call the helper
    instead of the inline `call_method0("__arrow_c_array__")?.extract()?`.
  * arrow-pyarrow-testing/tests/pyarrow.rs: adds test_from_pyarrow_non_tuple,
    which defines a Python class whose `__arrow_c_array__` returns 1 and
    asserts that RecordBatch::from_pyarrow_bound fails with a TypeError
    carrying that message.

NOTE(review): this copy was restored from a whitespace-collapsed rendering in
which generic arguments (<PyTypeError>, <PyTuple>, <PyCapsule>, <PyAny>,
<Self>) had been stripped. Blank context lines and indentation were
reconstructed from the hunk line counts; verify against the target tree
before applying.
NOTE(review): the error message literal appears twice in the helper and once
more in the test; a shared constant in lib.rs would keep the two helper
branches from drifting apart.

diff --git a/arrow-pyarrow-testing/tests/pyarrow.rs b/arrow-pyarrow-testing/tests/pyarrow.rs
index 4ca661b104d2..6f3606478c72 100644
--- a/arrow-pyarrow-testing/tests/pyarrow.rs
+++ b/arrow-pyarrow-testing/tests/pyarrow.rs
@@ -42,7 +42,10 @@ use arrow_array::{
     Array, ArrayRef, BinaryViewArray, Int32Array, RecordBatch, StringArray, StringViewArray,
 };
 use arrow_pyarrow::{FromPyArrow, ToPyArrow};
+use pyo3::exceptions::PyTypeError;
+use pyo3::types::{PyAnyMethods, PyModule};
 use pyo3::Python;
+use std::ffi::CString;
 use std::sync::Arc;
 
 #[test]
@@ -94,6 +97,34 @@ fn test_to_pyarrow_byte_view() {
     }
 }
 
+#[test]
+fn test_from_pyarrow_non_tuple() {
+    Python::initialize();
+
+    Python::attach(|py| {
+        let code = CString::new(
+            r#"
+class NotATuple:
+    def __arrow_c_array__(self):
+        return 1
+
+value = NotATuple()
+"#,
+        )
+        .unwrap();
+
+        let module = PyModule::from_code(py, code.as_c_str(), c"test.py", c"test_module").unwrap();
+        let value = module.getattr("value").unwrap();
+
+        let err = RecordBatch::from_pyarrow_bound(&value).unwrap_err();
+        assert!(err.is_instance_of::<PyTypeError>(py));
+        assert_eq!(
+            err.to_string(),
+            "TypeError: Expected __arrow_c_array__ to return a tuple of (schema, array) capsules."
+        );
+    });
+}
+
 fn binary_view_column(num_variadic_buffers: usize) -> BinaryViewArray {
     let long_scalar = b"but soft what light through yonder window breaks".as_slice();
     let mut builder = BinaryViewBuilder::new().with_fixed_block_size(long_scalar.len() as u32);
diff --git a/arrow-pyarrow/src/lib.rs b/arrow-pyarrow/src/lib.rs
index 95f1d38fddf3..d8f584e396d3 100644
--- a/arrow-pyarrow/src/lib.rs
+++ b/arrow-pyarrow/src/lib.rs
@@ -149,6 +149,24 @@ fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
     Ok(())
 }
 
+fn extract_arrow_c_array_capsules<'py>(
+    value: &Bound<'py, PyAny>,
+) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
+    let tuple = value.call_method0("__arrow_c_array__")?;
+
+    if !tuple.is_instance_of::<PyTuple>() {
+        return Err(PyTypeError::new_err(
+            "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
+        ));
+    }
+
+    tuple.extract().map_err(|_| {
+        PyTypeError::new_err(
+            "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
+        )
+    })
+}
+
 impl FromPyArrow for DataType {
     fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
         // Newer versions of PyArrow as well as other libraries with Arrow data implement this
@@ -245,8 +263,7 @@ impl FromPyArrow for ArrayData {
         // method, so prefer it over _export_to_c.
         // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
         if value.hasattr("__arrow_c_array__")? {
-            let (schema_capsule, array_capsule) =
-                value.call_method0("__arrow_c_array__")?.extract()?;
+            let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
 
             validate_pycapsule(&schema_capsule, "arrow_schema")?;
             validate_pycapsule(&array_capsule, "arrow_array")?;
@@ -324,8 +341,7 @@ impl FromPyArrow for RecordBatch {
         // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
 
         if value.hasattr("__arrow_c_array__")? {
-            let (schema_capsule, array_capsule) =
-                value.call_method0("__arrow_c_array__")?.extract()?;
+            let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
 
             validate_pycapsule(&schema_capsule, "arrow_schema")?;
             validate_pycapsule(&array_capsule, "arrow_array")?;