Skip to content

Commit 0b81c56

Browse files
committed
feat: Support serde_json's arbitrary_precision feature
1 parent 43c714f commit 0b81c56

File tree

3 files changed

+161
-7
lines changed

3 files changed

+161
-7
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ documentation = "https://docs.rs/crate/pythonize/"
1313

1414
[dependencies]
1515
serde = { version = "1.0", default-features = false, features = ["std"] }
16+
serde_json = { version = "1.0", optional = true }
1617
pyo3 = { version = "0.27", default-features = false }
1718

1819
[dev-dependencies]
@@ -22,3 +23,6 @@ serde_json = "1.0"
2223
serde_bytes = "0.11"
2324
maplit = "1.0.2"
2425
serde_path_to_error = "0.1.15"
26+
27+
[features]
28+
arbitrary_precision = ["serde_json", "serde_json/arbitrary_precision"]

src/ser.rs

Lines changed: 96 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use pyo3::types::{
44
PyDict, PyDictMethods, PyList, PyListMethods, PyMapping, PySequence, PyString, PyTuple,
55
PyTupleMethods,
66
};
7+
#[cfg(feature = "arbitrary_precision")]
8+
use pyo3::types::{PyAnyMethods, PyFloat, PyInt};
79
use pyo3::{Bound, BoundObject, IntoPyObject, PyAny, PyResult, Python};
810
use serde::{ser, Serialize};
911

@@ -229,6 +231,21 @@ pub struct PythonStructVariantSerializer<'py, P: PythonizeTypes> {
229231
inner: PythonStructDictSerializer<'py, P>,
230232
}
231233

234+
#[cfg(feature = "arbitrary_precision")]
235+
#[doc(hidden)]
236+
pub enum StructSerializer<'py, P: PythonizeTypes> {
237+
Struct(PythonStructDictSerializer<'py, P>),
238+
Number {
239+
py: Python<'py>,
240+
number_string: Option<String>,
241+
_types: PhantomData<P>,
242+
},
243+
}
244+
245+
#[cfg(not(feature = "arbitrary_precision"))]
246+
#[doc(hidden)]
247+
pub type StructSerializer<'py, P> = PythonStructDictSerializer<'py, P>;
248+
232249
#[doc(hidden)]
233250
pub struct PythonStructDictSerializer<'py, P: PythonizeTypes> {
234251
py: Python<'py>,
@@ -266,7 +283,7 @@ impl<'py, P: PythonizeTypes> ser::Serializer for Pythonizer<'py, P> {
266283
type SerializeTupleStruct = PythonCollectionSerializer<'py, P>;
267284
type SerializeTupleVariant = PythonTupleVariantSerializer<'py, P>;
268285
type SerializeMap = PythonMapSerializer<'py, P>;
269-
type SerializeStruct = PythonStructDictSerializer<'py, P>;
286+
type SerializeStruct = StructSerializer<'py, P>;
270287
type SerializeStructVariant = PythonStructVariantSerializer<'py, P>;
271288

272289
fn serialize_bool(self, v: bool) -> Result<Bound<'py, PyAny>> {
@@ -439,12 +456,34 @@ impl<'py, P: PythonizeTypes> ser::Serializer for Pythonizer<'py, P> {
439456
self,
440457
name: &'static str,
441458
len: usize,
442-
) -> Result<PythonStructDictSerializer<'py, P>> {
443-
Ok(PythonStructDictSerializer {
444-
py: self.py,
445-
builder: P::NamedMap::builder(self.py, len, name)?,
446-
_types: PhantomData,
447-
})
459+
) -> Result<StructSerializer<'py, P>> {
460+
#[cfg(feature = "arbitrary_precision")]
461+
{
462+
// With arbitrary_precision enabled, a serde_json::Number serializes as a "$serde_json::private::Number"
463+
// struct with a "$serde_json::private::Number" field, whose value is the String in Number::n.
464+
if name == "$serde_json::private::Number" && len == 1 {
465+
return Ok(StructSerializer::Number {
466+
py: self.py,
467+
number_string: None,
468+
_types: PhantomData,
469+
});
470+
}
471+
472+
Ok(StructSerializer::Struct(PythonStructDictSerializer {
473+
py: self.py,
474+
builder: P::NamedMap::builder(self.py, len, name)?,
475+
_types: PhantomData,
476+
}))
477+
}
478+
479+
#[cfg(not(feature = "arbitrary_precision"))]
480+
{
481+
Ok(PythonStructDictSerializer {
482+
py: self.py,
483+
builder: P::NamedMap::builder(self.py, len, name)?,
484+
_types: PhantomData,
485+
})
486+
}
448487
}
449488

450489
fn serialize_struct_variant(
@@ -569,6 +608,56 @@ impl<'py, P: PythonizeTypes> ser::SerializeMap for PythonMapSerializer<'py, P> {
569608
}
570609
}
571610

611+
#[cfg(feature = "arbitrary_precision")]
612+
impl<'py, P: PythonizeTypes> ser::SerializeStruct for StructSerializer<'py, P> {
613+
type Ok = Bound<'py, PyAny>;
614+
type Error = PythonizeError;
615+
616+
fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
617+
where
618+
T: ?Sized + Serialize,
619+
{
620+
match self {
621+
StructSerializer::Struct(s) => s.serialize_field(key, value),
622+
StructSerializer::Number { number_string, .. } => {
623+
let serde_json::Value::String(s) = value
624+
.serialize(serde_json::value::Serializer)
625+
.map_err(|e| PythonizeError::msg(format!("Failed to serialize number: {}", e)))?
626+
else {
627+
return Err(PythonizeError::msg("Expected string in serde_json::Number"));
628+
};
629+
630+
*number_string = Some(s);
631+
Ok(())
632+
}
633+
}
634+
}
635+
636+
fn end(self) -> Result<Bound<'py, PyAny>> {
637+
match self {
638+
StructSerializer::Struct(s) => s.end(),
639+
StructSerializer::Number {
640+
py, number_string: Some(s), ..
641+
} => {
642+
if let Ok(i) = s.parse::<i64>() {
643+
return Ok(PyInt::new(py, i).into_any());
644+
}
645+
if let Ok(u) = s.parse::<u64>() {
646+
return Ok(PyInt::new(py, u).into_any());
647+
}
648+
if let Ok(f) = s.parse::<f64>() {
649+
return Ok(PyFloat::new(py, f).into_any());
650+
}
651+
// Fall back to Python's int() constructor, which supports arbitrary precision.
652+
py.get_type::<PyInt>()
653+
.call1((s.as_str(),))
654+
.map_err(|e| PythonizeError::msg(format!("Invalid number: {}", e)))
655+
}
656+
StructSerializer::Number { .. } => Err(PythonizeError::msg("Empty serde_json::Number")),
657+
}
658+
}
659+
}
660+
572661
impl<'py, P: PythonizeTypes> ser::SerializeStruct for PythonStructDictSerializer<'py, P> {
573662
type Ok = Bound<'py, PyAny>;
574663
type Error = PythonizeError;

tests/test_arbitrary_precision.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#![cfg(feature = "arbitrary_precision")]
2+
3+
use pyo3::prelude::*;
4+
use pythonize::pythonize;
5+
6+
#[test]
7+
fn test_serde_json_number_with_arbitrary_precision() {
8+
use serde_json::Value;
9+
10+
Python::attach(|py| {
11+
// Parse JSON with arbitrary_precision feature
12+
let json_str = r#"{"number": 12345678901234567890}"#;
13+
let value: Value = serde_json::from_str(json_str).unwrap();
14+
15+
// Serialize to Python
16+
let result = pythonize(py, &value).unwrap();
17+
18+
// Extract the number field
19+
let number = result.get_item("number").unwrap();
20+
21+
// Should be a Python int, not a dict
22+
assert!(number.is_instance_of::<pyo3::types::PyInt>());
23+
24+
// Should have the correct value
25+
let number_str = number.str().unwrap().to_string();
26+
assert_eq!(number_str, "12345678901234567890");
27+
});
28+
}
29+
30+
#[test]
31+
fn test_serde_json_float_with_arbitrary_precision() {
32+
use serde_json::Value;
33+
34+
Python::attach(|py| {
35+
let json_str = r#"{"number": 3.141592653589793238}"#;
36+
let value: Value = serde_json::from_str(json_str).unwrap();
37+
38+
let result = pythonize(py, &value).unwrap();
39+
let number = result.get_item("number").unwrap();
40+
41+
// Should be a Python float, not a dict
42+
assert!(number.is_instance_of::<pyo3::types::PyFloat>());
43+
});
44+
}
45+
46+
#[test]
47+
fn test_serde_json_simple_number_with_arbitrary_precision() {
48+
use serde_json::Value;
49+
50+
Python::attach(|py| {
51+
let json_str = r#"2"#;
52+
let value: Value = serde_json::from_str(json_str).unwrap();
53+
54+
let result = pythonize(py, &value).unwrap();
55+
56+
// Should be a Python int 2, not {'$serde_json::private::Number': '2'}
57+
assert!(result.is_instance_of::<pyo3::types::PyInt>());
58+
let num: i32 = result.extract().unwrap();
59+
assert_eq!(num, 2);
60+
});
61+
}

0 commit comments

Comments
 (0)