Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions arrow-pyarrow-testing/tests/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ fn test_to_pyarrow_byte_view() {
])
.unwrap();

println!("input: {input:?}");
let res = Python::attach(|py| {
let py_input = input.to_pyarrow(py)?;
let records = RecordBatch::from_pyarrow_bound(&py_input)?;
Expand Down Expand Up @@ -120,7 +119,7 @@ value = NotATuple()
assert!(err.is_instance_of::<PyTypeError>(py));
assert_eq!(
err.to_string(),
"TypeError: Expected __arrow_c_array__ to return a tuple of (schema, array) capsules."
"TypeError: Expected __arrow_c_array__ to return a tuple of (arrow_schema, arrow_array) capsules."
);
});
}
Expand Down
203 changes: 103 additions & 100 deletions arrow-pyarrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@

use std::convert::{From, TryFrom};
use std::ffi::CStr;
use std::ptr::NonNull;
use std::sync::Arc;

use arrow_array::ffi;
Expand All @@ -77,15 +78,12 @@ use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::sync::PyOnceLock;
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use pyo3::types::{PyCapsule, PyDict, PyList, PyType};

import_exception!(pyarrow, ArrowException);
/// Represents an exception raised by PyArrow.
pub type PyArrowException = ArrowException;

const ARROW_ARRAY_STREAM_CAPSULE_NAME: &CStr = c"arrow_array_stream";
const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c"arrow_schema";
const ARROW_ARRAY_CAPSULE_NAME: &CStr = c"arrow_array";

fn to_py_err(err: ArrowError) -> PyErr {
PyArrowException::new_err(err.to_string())
Expand Down Expand Up @@ -131,54 +129,16 @@ fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> PyResult<()
Ok(())
}

/// Checks that `capsule` carries exactly the expected `name`.
///
/// # Errors
///
/// Returns a `PyValueError` when the capsule has no name at all or when its
/// name differs from `name`. Errors from reading the name, or from decoding it
/// as UTF-8, are propagated unchanged.
fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
    let Some(raw_name) = capsule.name()? else {
        return Err(PyValueError::new_err(
            "Expected schema PyCapsule to have name set.",
        ));
    };

    // NOTE(review): `as_cstr` is an unsafe accessor; presumably it requires the
    // capsule's name pointer to still be valid here — confirm against pyo3 docs.
    let capsule_name = unsafe { raw_name.as_cstr() }.to_str()?;
    if capsule_name == name {
        return Ok(());
    }

    Err(PyValueError::new_err(format!(
        "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'",
    )))
}

/// Calls `__arrow_c_array__` on `value` and returns the two capsules it yields.
///
/// # Errors
///
/// Propagates any error raised by the method call itself. Returns a
/// `PyTypeError` when the result is not a tuple, or is a tuple that cannot be
/// extracted as a pair of capsules.
fn extract_arrow_c_array_capsules<'py>(
    value: &Bound<'py, PyAny>,
) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
    // Both failure paths report the same message, so build it in one place.
    let not_a_capsule_pair = || {
        PyTypeError::new_err(
            "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
        )
    };

    let returned = value.call_method0("__arrow_c_array__")?;
    if returned.is_instance_of::<PyTuple>() {
        returned.extract().map_err(|_| not_a_capsule_pair())
    } else {
        Err(not_a_capsule_pair())
    }
}

impl FromPyArrow for DataType {
fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
// Newer versions of PyArrow as well as other libraries with Arrow data implement this
// method, so prefer it over _export_to_c.
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
if value.hasattr("__arrow_c_schema__")? {
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
validate_pycapsule(&capsule, "arrow_schema")?;

let schema_ptr = capsule
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
.cast::<FFI_ArrowSchema>();
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
value,
"__arrow_c_schema__",
)?;
return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
}

Expand All @@ -203,12 +163,10 @@ impl FromPyArrow for Field {
// method, so prefer it over _export_to_c.
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
if value.hasattr("__arrow_c_schema__")? {
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
validate_pycapsule(&capsule, "arrow_schema")?;

let schema_ptr = capsule
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
.cast::<FFI_ArrowSchema>();
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
value,
"__arrow_c_schema__",
)?;
return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
}

Expand All @@ -233,12 +191,10 @@ impl FromPyArrow for Schema {
// method, so prefer it over _export_to_c.
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
if value.hasattr("__arrow_c_schema__")? {
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
validate_pycapsule(&capsule, "arrow_schema")?;

let schema_ptr = capsule
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
.cast::<FFI_ArrowSchema>();
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
value,
"__arrow_c_schema__",
)?;
return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
}

Expand All @@ -263,22 +219,12 @@ impl FromPyArrow for ArrayData {
// method, so prefer it over _export_to_c.
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
if value.hasattr("__arrow_c_array__")? {
let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;

validate_pycapsule(&schema_capsule, "arrow_schema")?;
validate_pycapsule(&array_capsule, "arrow_array")?;

let schema_ptr = schema_capsule
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
.cast::<FFI_ArrowSchema>();
let array = unsafe {
FFI_ArrowArray::from_raw(
array_capsule
.pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
.cast::<FFI_ArrowArray>()
.as_ptr(),
)
};
let (schema_ptr, array_ptr) =
extract_capsule_pair_from_method::<FFI_ArrowSchema, FFI_ArrowArray>(
value,
"__arrow_c_array__",
)?;
let array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err);
}

Expand Down Expand Up @@ -341,23 +287,17 @@ impl FromPyArrow for RecordBatch {
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html

if value.hasattr("__arrow_c_array__")? {
let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;

validate_pycapsule(&schema_capsule, "arrow_schema")?;
validate_pycapsule(&array_capsule, "arrow_array")?;

let schema_ptr = schema_capsule
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
.cast::<FFI_ArrowSchema>();
let array_ptr = array_capsule
.pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
.cast::<FFI_ArrowArray>();
let (schema_ptr, array_ptr) =
extract_capsule_pair_from_method::<FFI_ArrowSchema, FFI_ArrowArray>(
value,
"__arrow_c_array__",
)?;
let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
let mut array_data =
unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
if !matches!(array_data.data_type(), DataType::Struct(_)) {
return Err(PyTypeError::new_err(
"Expected Struct type from __arrow_c_array.",
format!("Expected Struct type from __arrow_c_array__, found {}.", array_data.data_type()),
));
}
let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
Expand Down Expand Up @@ -421,18 +361,11 @@ impl FromPyArrow for ArrowArrayStreamReader {
// method, so prefer it over _export_to_c.
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
if value.hasattr("__arrow_c_stream__")? {
let capsule = value.call_method0("__arrow_c_stream__")?.extract()?;

validate_pycapsule(&capsule, "arrow_array_stream")?;

let stream = unsafe {
FFI_ArrowArrayStream::from_raw(
capsule
.pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))?
.cast::<FFI_ArrowArrayStream>()
.as_ptr(),
)
};
let stream_ptr = extract_capsule_from_method::<FFI_ArrowArrayStream>(
value,
"__arrow_c_stream__",
)?;
let stream = unsafe { FFI_ArrowArrayStream::from_raw(stream_ptr.as_ptr()) };

let stream_reader = ArrowArrayStreamReader::try_new(stream)
.map_err(|err| PyValueError::new_err(err.to_string()))?;
Expand All @@ -448,8 +381,7 @@ impl FromPyArrow for ArrowArrayStreamReader {
// make the conversion through PyArrow's private API
// this changes the pointer's memory and is thus unsafe.
// In particular, `_export_to_c` can go out of bounds
let args = PyTuple::new(value.py(), [&raw mut stream as Py_uintptr_t])?;
value.call_method1("_export_to_c", args)?;
value.call_method1("_export_to_c", (&raw mut stream as Py_uintptr_t,))?;

ArrowArrayStreamReader::try_new(stream)
.map_err(|err| PyValueError::new_err(err.to_string()))
Expand Down Expand Up @@ -631,3 +563,74 @@ impl<T> From<T> for PyArrowType<T> {
Self(s)
}
}

/// Ties a C Data Interface struct to the name its `PyCapsule` is expected to carry.
///
/// The capsule-extraction helpers in this file pass `NAME` to `pointer_checked`,
/// so a capsule is only cast to `Self` when its name matches.
trait PyCapsuleType {
/// Capsule name that identifies a pointer to `Self`.
const NAME: &CStr;
}

// Schema structs are looked up in capsules named `arrow_schema`.
impl PyCapsuleType for FFI_ArrowSchema {
const NAME: &CStr = c"arrow_schema";
}

// Array structs are looked up in capsules named `arrow_array`.
impl PyCapsuleType for FFI_ArrowArray {
const NAME: &CStr = c"arrow_array";
}

// Stream structs are looked up in capsules named `arrow_array_stream`.
impl PyCapsuleType for FFI_ArrowArrayStream {
const NAME: &CStr = c"arrow_array_stream";
}

/// Calls the zero-argument method `method_name` on `object`, expects a
/// `PyCapsule` named `T::NAME` back, and returns the capsule's pointer cast to `T`.
///
/// # Errors
///
/// Any failure along the way (the call raising, the result not being a
/// capsule, or the capsule name check failing) is reported as a `PyTypeError`
/// whose cause is the original error.
///
/// NOTE(review): the capsule object obtained here is dropped before this
/// function returns, so the returned pointer does not keep the capsule alive.
/// If the capsule's destructor releases or frees the underlying struct (as the
/// Arrow PyCapsule interface asks producers to do), callers would dereference
/// freed memory — confirm that something else holds a reference, or return the
/// capsule alongside the pointer.
fn extract_capsule_from_method<T: PyCapsuleType>(
    object: &Bound<'_, PyAny>,
    method_name: &'static str,
) -> PyResult<NonNull<T>> {
    let try_extract = || -> PyResult<NonNull<T>> {
        let returned = object.call_method0(method_name)?;
        let capsule = returned.extract::<Bound<'_, PyCapsule>>()?;
        // The name check happens inside `pointer_checked`.
        let raw = capsule.pointer_checked(Some(T::NAME))?;
        Ok(raw.cast::<T>())
    };

    try_extract().map_err(|cause| {
        let message = format!(
            "Expected {method_name} to return a {} capsule.",
            T::NAME.to_str().unwrap(),
        );
        wrapping_type_error(object.py(), cause, message)
    })
}

/// Calls the zero-argument method `method_name` on `object`, expects a pair of
/// `PyCapsule`s named `T1::NAME` and `T2::NAME` back, and returns both capsule
/// pointers cast to their respective types.
///
/// # Errors
///
/// Any failure along the way (the call raising, the result not being a pair of
/// capsules, or either capsule name check failing) is reported as a
/// `PyTypeError` whose cause is the original error.
///
/// NOTE(review): both capsule objects are dropped before this function
/// returns, so the returned pointers do not keep the capsules alive. If a
/// capsule's destructor releases or frees the underlying struct (as the Arrow
/// PyCapsule interface asks producers to do), callers would dereference freed
/// memory — confirm that something else holds a reference, or return the
/// capsules alongside the pointers.
fn extract_capsule_pair_from_method<T1: PyCapsuleType, T2: PyCapsuleType>(
    object: &Bound<'_, PyAny>,
    method_name: &'static str,
) -> PyResult<(NonNull<T1>, NonNull<T2>)> {
    let try_extract = || -> PyResult<(NonNull<T1>, NonNull<T2>)> {
        let returned = object.call_method0(method_name)?;
        let (first, second) =
            returned.extract::<(Bound<'_, PyCapsule>, Bound<'_, PyCapsule>)>()?;
        // Check the first capsule before the second.
        let first_ptr = first.pointer_checked(Some(T1::NAME))?.cast::<T1>();
        let second_ptr = second.pointer_checked(Some(T2::NAME))?.cast::<T2>();
        Ok((first_ptr, second_ptr))
    };

    try_extract().map_err(|cause| {
        let message = format!(
            "Expected {method_name} to return a tuple of ({}, {}) capsules.",
            T1::NAME.to_str().unwrap(),
            T2::NAME.to_str().unwrap()
        );
        wrapping_type_error(object.py(), cause, message)
    })
}

/// Builds a `PyTypeError` carrying `message` and attaches `error` as its cause,
/// so the underlying failure stays reachable from the new exception.
fn wrapping_type_error(py: Python<'_>, error: PyErr, message: String) -> PyErr {
    let type_error = PyTypeError::new_err(message);
    let cause = Some(error);
    type_error.set_cause(py, cause);
    type_error
}
Loading