Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ documentation = "https://docs.rs/crate/pythonize/"

[dependencies]
serde = { version = "1.0", default-features = false, features = ["std"] }
serde_json = { version = "1.0", optional = true }
pyo3 = { version = "0.27", default-features = false }

[dev-dependencies]
Expand All @@ -22,3 +23,6 @@ serde_json = "1.0"
serde_bytes = "0.11"
maplit = "1.0.2"
serde_path_to_error = "0.1.15"

[features]
arbitrary_precision = ["serde_json", "serde_json/arbitrary_precision"]
105 changes: 98 additions & 7 deletions src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use pyo3::types::{
PyDict, PyDictMethods, PyList, PyListMethods, PyMapping, PySequence, PyString, PyTuple,
PyTupleMethods,
};
#[cfg(feature = "arbitrary_precision")]
use pyo3::types::{PyAnyMethods, PyFloat, PyInt};
use pyo3::{Bound, BoundObject, IntoPyObject, PyAny, PyResult, Python};
use serde::{ser, Serialize};

Expand Down Expand Up @@ -229,6 +231,21 @@ pub struct PythonStructVariantSerializer<'py, P: PythonizeTypes> {
inner: PythonStructDictSerializer<'py, P>,
}

#[cfg(feature = "arbitrary_precision")]
#[doc(hidden)]
pub enum StructSerializer<'py, P: PythonizeTypes> {
Struct(PythonStructDictSerializer<'py, P>),
Number {
py: Python<'py>,
number_string: Option<String>,
_types: PhantomData<P>,
},
}

#[cfg(not(feature = "arbitrary_precision"))]
#[doc(hidden)]
pub type StructSerializer<'py, P> = PythonStructDictSerializer<'py, P>;

#[doc(hidden)]
pub struct PythonStructDictSerializer<'py, P: PythonizeTypes> {
py: Python<'py>,
Expand Down Expand Up @@ -266,7 +283,7 @@ impl<'py, P: PythonizeTypes> ser::Serializer for Pythonizer<'py, P> {
type SerializeTupleStruct = PythonCollectionSerializer<'py, P>;
type SerializeTupleVariant = PythonTupleVariantSerializer<'py, P>;
type SerializeMap = PythonMapSerializer<'py, P>;
type SerializeStruct = PythonStructDictSerializer<'py, P>;
type SerializeStruct = StructSerializer<'py, P>;
type SerializeStructVariant = PythonStructVariantSerializer<'py, P>;

fn serialize_bool(self, v: bool) -> Result<Bound<'py, PyAny>> {
Expand Down Expand Up @@ -439,12 +456,34 @@ impl<'py, P: PythonizeTypes> ser::Serializer for Pythonizer<'py, P> {
self,
name: &'static str,
len: usize,
) -> Result<PythonStructDictSerializer<'py, P>> {
Ok(PythonStructDictSerializer {
py: self.py,
builder: P::NamedMap::builder(self.py, len, name)?,
_types: PhantomData,
})
) -> Result<StructSerializer<'py, P>> {
#[cfg(feature = "arbitrary_precision")]
{
// With arbitrary_precision enabled, a serde_json::Number serializes as a "$serde_json::private::Number"
// struct with a "$serde_json::private::Number" field, whose value is the String in Number::n.
if name == "$serde_json::private::Number" && len == 1 {
return Ok(StructSerializer::Number {
py: self.py,
number_string: None,
_types: PhantomData,
});
}

Ok(StructSerializer::Struct(PythonStructDictSerializer {
py: self.py,
builder: P::NamedMap::builder(self.py, len, name)?,
_types: PhantomData,
}))
}

#[cfg(not(feature = "arbitrary_precision"))]
{
Ok(PythonStructDictSerializer {
py: self.py,
builder: P::NamedMap::builder(self.py, len, name)?,
_types: PhantomData,
})
}
}

fn serialize_struct_variant(
Expand Down Expand Up @@ -569,6 +608,58 @@ impl<'py, P: PythonizeTypes> ser::SerializeMap for PythonMapSerializer<'py, P> {
}
}

#[cfg(feature = "arbitrary_precision")]
impl<'py, P: PythonizeTypes> ser::SerializeStruct for StructSerializer<'py, P> {
type Ok = Bound<'py, PyAny>;
type Error = PythonizeError;

fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
match self {
StructSerializer::Struct(s) => s.serialize_field(key, value),
StructSerializer::Number { number_string, .. } => {
let serde_json::Value::String(s) = value
.serialize(serde_json::value::Serializer)
.map_err(|e| PythonizeError::msg(format!("Failed to serialize number: {}", e)))?
else {
return Err(PythonizeError::msg("Expected string in serde_json::Number"));
};

*number_string = Some(s);
Ok(())
}
}
}

fn end(self) -> Result<Bound<'py, PyAny>> {
match self {
StructSerializer::Struct(s) => s.end(),
StructSerializer::Number {
py, number_string: Some(s), ..
} => {
if let Ok(i) = s.parse::<i64>() {
return Ok(PyInt::new(py, i).into_any());
}
if let Ok(u) = s.parse::<u64>() {
return Ok(PyInt::new(py, u).into_any());
}
if s.chars().any(|c| c == '.' || c == 'e' || c == 'E') {
if let Ok(f) = s.parse::<f64>() {
return Ok(PyFloat::new(py, f).into_any());
}
}
// Fall back to Python's int() constructor, which supports arbitrary precision.
py.get_type::<PyInt>()
.call1((s.as_str(),))
.map_err(|e| PythonizeError::msg(format!("Invalid number: {}", e)))
}
StructSerializer::Number { .. } => Err(PythonizeError::msg("Empty serde_json::Number")),
}
}
}

impl<'py, P: PythonizeTypes> ser::SerializeStruct for PythonStructDictSerializer<'py, P> {
type Ok = Bound<'py, PyAny>;
type Error = PythonizeError;
Expand Down
57 changes: 57 additions & 0 deletions tests/test_arbitrary_precision.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#![cfg(feature = "arbitrary_precision")]

use pyo3::prelude::*;
use pythonize::pythonize;
use serde_json::Value;

#[test]
fn test_greater_than_u64_max() {
Python::attach(|py| {
let json_str = r#"18446744073709551616"#;
let value: Value = serde_json::from_str(json_str).unwrap();
let result = pythonize(py, &value).unwrap();
let number_str = result.str().unwrap().to_string();

assert!(result.is_instance_of::<pyo3::types::PyInt>());
assert_eq!(number_str, "18446744073709551616");
});
}

#[test]
fn test_less_than_i64_min() {
Python::attach(|py| {
let json_str = r#"-9223372036854775809"#;
let value: Value = serde_json::from_str(json_str).unwrap();
let result = pythonize(py, &value).unwrap();
let number_str = result.str().unwrap().to_string();

assert!(result.is_instance_of::<pyo3::types::PyInt>());
assert_eq!(number_str, "-9223372036854775809");
});
}

#[test]
fn test_float() {
Python::attach(|py| {
let json_str = r#"3.141592653589793238"#;
let value: Value = serde_json::from_str(json_str).unwrap();
let result = pythonize(py, &value).unwrap();
let num: f32 = result.extract().unwrap();

assert!(result.is_instance_of::<pyo3::types::PyFloat>());
assert_eq!(num, 3.141592653589793238); // not {'$serde_json::private::Number': ...}
});
}

#[test]
fn test_int() {
Python::attach(|py| {
let json_str = r#"2"#;
let value: Value = serde_json::from_str(json_str).unwrap();
let result = pythonize(py, &value).unwrap();
let num: i32 = result.extract().unwrap();

assert!(result.is_instance_of::<pyo3::types::PyInt>());
assert_eq!(num, 2); // not {'$serde_json::private::Number': '2'}
});
}