Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 290 additions & 5 deletions src/ser.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use half::{bf16, f16};
use serde::{
ser::{
Error as _, Impossible, SerializeMap, SerializeSeq, SerializeStruct,
Expand All @@ -7,7 +8,53 @@ use serde::{
};
use serde_json::error::Error;

use crate::{array::TryCollect, DestructuredRef, IArray, INumber, IObject, IString, IValue};
use crate::{
array::{ArraySliceRef, TryCollect},
DestructuredRef, IArray, INumber, IObject, IString, IValue,
};

/// Rounds an `f64` to the given number of significant decimal digits.
///
/// The value is scaled by a power of ten so that the requested digits land in
/// the integer part, rounded with [`f64::round`], and scaled back. For example
/// `round_to_sig_digits(3.14, 2)` scales by 10 (31.4), rounds (31) and
/// returns 3.1.
///
/// Inputs that cannot be meaningfully rounded are returned unchanged:
/// * zero (it has no order of magnitude, `log10(0)` is `-inf`),
/// * NaN and ±infinity,
/// * any combination of `val` / `sig_digits` for which the power-of-ten scale
///   (or the scaled value) is not representable as a finite, non-zero `f64`.
#[inline]
fn round_to_sig_digits(val: f64, sig_digits: u32) -> f64 {
    // Without this guard, 0.0 yields log10 = -inf, which saturates to i32::MIN
    // on the cast below and overflows the exponent arithmetic; ±inf would turn
    // into NaN via `inf * 0.0`.
    if val == 0.0 || !val.is_finite() {
        return val;
    }
    // How many digits are left of the decimal point, minus one; i.e. the power
    // of ten this number is closest to from below (100 -> 2, 0.001 -> -3).
    let order_of_magnitude = val.abs().log10().floor() as i32;
    // Clamp before casting so absurdly large digit counts cannot wrap negative,
    // and saturate the subtraction so the exponent itself cannot overflow.
    let digits = sig_digits.min(i32::MAX as u32) as i32;
    let exponent = digits.saturating_sub(1).saturating_sub(order_of_magnitude);
    // Multiplier that shifts the desired significant digits into the integer
    // part, so that `round()` can snap to the nearest integer and drop the rest.
    let scale = 10f64.powi(exponent);
    let scaled = val * scale;
    // An unrepresentable scale (0 or inf) would produce NaN/inf after the
    // division; rounding is a no-op at that precision anyway.
    if scale == 0.0 || !scaled.is_finite() {
        return val;
    }
    scaled.round() / scale
}

/// Picks an `f64` whose ryu (`f64`) rendering is the shortest decimal string
/// that still maps back to the same half-precision value (f16 or bf16).
///
/// ryu only knows f32/f64 and serde offers no `serialize_f16`, so half values
/// have to be emitted through `serialize_f64`. Because half types have far
/// fewer distinct values, a much shorter decimal usually identifies the value
/// uniquely: f16(0.3) is stored as 0.300048828125, yet "0.3" parses back to
/// the very same f16 bits, so "0.3" is a valid rendering.
///
/// `roundtrips` must report whether a candidate `f64` converts back to the
/// original half value. Candidates are produced by rounding `f64_val` to 1, 2,
/// ... 5 significant digits; the first accepted candidate is returned. If none
/// is accepted, `f64_val` itself is returned. Non-finite and integral inputs
/// are returned untouched without consulting `roundtrips`.
fn find_shortest_roundtrip_f64(f64_val: f64, roundtrips: impl Fn(f64) -> bool) -> f64 {
    let is_integral = f64_val.fract() == 0.0;
    if is_integral || !f64_val.is_finite() {
        return f64_val;
    }
    // Worked example: f16(3.14159) stores 3.140625
    //   1 digit  -> 3.0  -> f16(3.0)  = 3.0      (rejected)
    //   2 digits -> 3.1  -> f16(3.1)  = 3.099..  (rejected)
    //   3 digits -> 3.14 -> f16(3.14) = 3.140625 (accepted, 3.14 is returned)
    // The chain is lazy, so candidates past the first accepted one are never
    // computed or tested.
    (1..=5u32)
        .map(|digits| round_to_sig_digits(f64_val, digits))
        .find(|&candidate| roundtrips(candidate))
        .unwrap_or(f64_val)
}

impl Serialize for IValue {
#[inline]
Expand Down Expand Up @@ -55,11 +102,50 @@ impl Serialize for IArray {
where
S: Serializer,
{
let mut s = serializer.serialize_seq(Some(self.len()))?;
for v in self {
s.serialize_element(&v)?;
match self.as_slice() {
// Serialize typed float arrays with the shortest representation that
// round-trips through the stored precision. Without this, all floats
// would be promoted to f64 via INumber, and ryu's f64 algorithm would
// emit unnecessarily long strings (e.g. "0.3" stored as f32 would
// serialize as "0.30000001192092896" instead of "0.3").
//
// F32: serialize directly as f32 so ryu uses its f32 algorithm.
// F16/BF16: ryu has no f16 mode and serde has no serialize_f16, so we
// find the shortest decimal that round-trips through the half type and
// pass the corresponding f64 value to serialize_f64.
ArraySliceRef::F32(slice) => {
let mut s = serializer.serialize_seq(Some(slice.len()))?;
for &v in slice {
s.serialize_element(&v)?;
}
s.end()
}
ArraySliceRef::F16(slice) => {
let mut s = serializer.serialize_seq(Some(slice.len()))?;
for &v in slice {
let f64_val = f64::from(v);
let shortest = find_shortest_roundtrip_f64(f64_val, |p| f16::from_f64(p) == v);
s.serialize_element(&shortest)?;
}
s.end()
}
ArraySliceRef::BF16(slice) => {
let mut s = serializer.serialize_seq(Some(slice.len()))?;
for &v in slice {
let f64_val = f64::from(v);
let shortest = find_shortest_roundtrip_f64(f64_val, |p| bf16::from_f64(p) == v);
s.serialize_element(&shortest)?;
}
s.end()
}
_ => {
let mut s = serializer.serialize_seq(Some(self.len()))?;
for v in self {
s.serialize_element(&v)?;
}
s.end()
}
}
s.end()
}
}

Expand Down Expand Up @@ -635,3 +721,202 @@ where
{
value.serialize(ValueSerializer)
}

#[cfg(test)]
mod tests {
    use crate::array::{ArraySliceRef, FloatType};
    use crate::{FPHAConfig, IArray, IValue, IValueDeserSeed};

    /// Float storage types exercised by the cross-type tests, in evaluation order.
    const FLOAT_TYPES: [FloatType; 4] = [
        FloatType::F16,
        FloatType::BF16,
        FloatType::F32,
        FloatType::F64,
    ];

    /// Builds a one-element array holding `value` pushed with storage type `fp_type`.
    fn single_element_array(value: f64, fp_type: FloatType) -> IArray {
        let mut array = IArray::new();
        array
            .push_with_fp_type(IValue::from(value), fp_type)
            .unwrap();
        array
    }

    /// Deserializes `json` into an array using a seed configured for `fp_type`.
    fn parse_as(json: &str, fp_type: FloatType) -> IArray {
        let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type)));
        let mut deserializer = serde_json::Deserializer::from_str(json);
        serde::de::DeserializeSeed::deserialize(seed, &mut deserializer)
            .unwrap()
            .into_array()
            .unwrap()
    }

    /// Deserializes `json` as `fp_type` and serializes the result back to JSON.
    fn reserialize_as(json: &str, fp_type: FloatType) -> String {
        serde_json::to_string(&parse_as(json, fp_type)).unwrap()
    }

    #[test]
    fn test_f32_array_serialization_preserves_short_representation() {
        let arr = single_element_array(0.3, FloatType::F32);
        assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_)));

        assert_eq!(
            serde_json::to_string(&arr).unwrap(),
            "[0.3]",
            "F32 array should serialize 0.3 as '0.3', not with extra f64 precision digits"
        );
    }

    #[test]
    fn test_f64_array_serialization_preserves_short_representation() {
        let arr = single_element_array(0.3, FloatType::F64);
        assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_)));

        assert_eq!(serde_json::to_string(&arr).unwrap(), "[0.3]");
    }

    #[test]
    fn test_f16_array_serialization_preserves_short_representation() {
        // 1.5 is exactly representable in f16.
        let exact = single_element_array(1.5, FloatType::F16);
        assert_eq!(serde_json::to_string(&exact).unwrap(), "[1.5]");

        // 0.3 is not, but must still come back as the short string.
        let inexact = single_element_array(0.3, FloatType::F16);
        assert_eq!(
            serde_json::to_string(&inexact).unwrap(),
            "[0.3]",
            "F16 array should serialize 0.3 as '0.3', not '0.30004883' or '0.300048828125'"
        );
    }

    #[test]
    fn test_bf16_array_serialization_preserves_short_representation() {
        let exact = single_element_array(1.5, FloatType::BF16);
        assert_eq!(serde_json::to_string(&exact).unwrap(), "[1.5]");

        let inexact = single_element_array(0.3, FloatType::BF16);
        assert_eq!(
            serde_json::to_string(&inexact).unwrap(),
            "[0.3]",
            "BF16 array should serialize 0.3 as '0.3'"
        );
    }

    #[test]
    fn test_typed_float_array_serialization_roundtrip() {
        let input = "[0.3,0.1,0.7,1.0,2.5,100.0]";

        // Each storage type must reproduce the input verbatim.
        let mut jsons: Vec<String> = Vec::with_capacity(FLOAT_TYPES.len());
        for &fp_type in &FLOAT_TYPES {
            let json_out = reserialize_as(input, fp_type);
            assert_eq!(
                json_out, input,
                "{fp_type} round-trip should preserve the original JSON string"
            );
            jsons.push(json_out);
        }

        // ...and therefore all storage types must agree with each other.
        for (left, right) in jsons.iter().zip(jsons.iter().skip(1)) {
            assert_eq!(
                left, right,
                "all float types should produce identical JSON"
            );
        }
    }

    #[test]
    fn test_f16_precision_loss_produces_different_but_short_representation() {
        // These inputs carry more significant digits than f16 can hold (~3.3
        // digits). The stored f16 value differs from the input, so the output
        // string differs as well, but it must still be the shortest string
        // that round-trips through f16.
        let cases: &[(&str, &str)] = &[
            ("3.14159", "3.14"), // pi truncated: f16 stores 3.140625
            ("42.42", "42.4"),   // f16 stores 42.40625
            ("12.345", "12.34"), // f16 stores 12.34375
            ("0.5678", "0.568"), // f16 stores 0.56787109375
        ];

        for &(input, expected_f16) in cases {
            let json_input = format!("[{input}]");

            let f16_json = reserialize_as(&json_input, FloatType::F16);
            assert_eq!(
                f16_json,
                format!("[{expected_f16}]"),
                "F16 of {input}: should serialize as shortest f16 representation"
            );

            // F32 has enough precision to keep the original text intact.
            let f32_json = reserialize_as(&json_input, FloatType::F32);
            assert_eq!(
                f32_json, json_input,
                "F32 of {input}: should preserve the original representation"
            );
        }
    }

    #[test]
    fn test_negative_float_array_serialization() {
        let input = "[-0.3,-0.1,-1.0,-2.5,-100.0]";

        for &fp_type in &FLOAT_TYPES {
            let json_out = reserialize_as(input, fp_type);
            assert_eq!(
                json_out, input,
                "{fp_type} negative round-trip should preserve the original JSON string"
            );
        }
    }

    #[test]
    fn test_negative_f16_precision_loss_produces_short_representation() {
        let cases: &[(&str, &str)] = &[
            ("-3.14159", "-3.14"),
            ("-42.42", "-42.4"),
            ("-0.5678", "-0.568"),
        ];

        for &(input, expected_f16) in cases {
            let json_out = reserialize_as(&format!("[{input}]"), FloatType::F16);
            assert_eq!(
                json_out,
                format!("[{expected_f16}]"),
                "F16 of {input}: negative should serialize as shortest f16 representation"
            );
        }
    }
}
Loading