From 5a8ca2e3e081b141838bde4ebd20315c081a67a3 Mon Sep 17 00:00:00 2001 From: charlesdong1991 Date: Sat, 7 Mar 2026 16:59:53 +0100 Subject: [PATCH 1/7] add array data type support --- bindings/cpp/src/types.rs | 2 + crates/fluss/src/record/arrow.rs | 77 +- crates/fluss/src/row/binary/binary_writer.rs | 10 +- crates/fluss/src/row/binary_array.rs | 736 ++++++++++++++++++ crates/fluss/src/row/column.rs | 194 ++++- .../src/row/compacted/compacted_key_writer.rs | 13 + .../fluss/src/row/compacted/compacted_row.rs | 202 ++++- .../src/row/compacted/compacted_row_reader.rs | 52 +- .../src/row/compacted/compacted_row_writer.rs | 4 + crates/fluss/src/row/datum.rs | 117 ++- .../src/row/encode/compacted_key_encoder.rs | 17 +- crates/fluss/src/row/field_getter.rs | 55 +- crates/fluss/src/row/mod.rs | 14 + 13 files changed, 1448 insertions(+), 45 deletions(-) create mode 100644 crates/fluss/src/row/binary_array.rs diff --git a/bindings/cpp/src/types.rs b/bindings/cpp/src/types.rs index f8efe677..f33034bc 100644 --- a/bindings/cpp/src/types.rs +++ b/bindings/cpp/src/types.rs @@ -351,6 +351,7 @@ pub fn resolve_row_types( Datum::Time(t) => Datum::Time(*t), Datum::TimestampNtz(ts) => Datum::TimestampNtz(*ts), Datum::TimestampLtz(ts) => Datum::TimestampLtz(*ts), + Datum::Array(a) => Datum::Array(a.clone()), }; out.set_field(idx, resolved); } @@ -408,6 +409,7 @@ pub fn compacted_row_to_owned( fcore::metadata::DataType::Binary(dt) => { Datum::Blob(Cow::Owned(row.get_binary(i, dt.length())?.to_vec())) } + fcore::metadata::DataType::Array(_) => Datum::Array(row.get_array(i)?), other => return Err(anyhow!("Unsupported data type for column {i}: {other:?}")), }; diff --git a/crates/fluss/src/record/arrow.rs b/crates/fluss/src/record/arrow.rs index ea27836e..2b499a02 100644 --- a/crates/fluss/src/record/arrow.rs +++ b/crates/fluss/src/record/arrow.rs @@ -25,10 +25,10 @@ use crate::row::{ColumnarRow, InternalRow}; use arrow::array::{ ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, 
Date32Builder, Decimal128Builder, FixedSizeBinaryBuilder, Float32Builder, Float64Builder, Int8Builder, Int16Builder, - Int32Builder, Int64Builder, StringBuilder, Time32MillisecondBuilder, Time32SecondBuilder, - Time64MicrosecondBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder, - TimestampMillisecondBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, UInt8Builder, - UInt16Builder, UInt32Builder, UInt64Builder, + Int32Builder, Int64Builder, ListBuilder, StringBuilder, Time32MillisecondBuilder, + Time32SecondBuilder, Time64MicrosecondBuilder, Time64NanosecondBuilder, + TimestampMicrosecondBuilder, TimestampMillisecondBuilder, TimestampNanosecondBuilder, + TimestampSecondBuilder, UInt8Builder, UInt16Builder, UInt32Builder, UInt64Builder, }; use arrow::{ array::RecordBatch, @@ -314,6 +314,10 @@ impl RowAppendRecordBatchBuilder { arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, _) => { Ok(Box::new(TimestampNanosecondBuilder::new())) } + arrow_schema::DataType::List(field) => { + let inner_builder = Self::create_builder(field.data_type())?; + Ok(Box::new(ListBuilder::new(inner_builder))) + } dt => Err(Error::IllegalArgument { message: format!("Unsupported data type: {dt:?}"), }), @@ -1159,6 +1163,71 @@ pub fn to_arrow_type(fluss_type: &DataType) -> Result { }) } +/// Converts an Arrow data type back to a Fluss `DataType`. +/// Used for reading array elements from Arrow ListArray back into Fluss types. 
+///
+/// Lists are converted recursively through their element field. Any Arrow type without a
+/// counterpart in the match below is rejected with `Error::IllegalArgument`.
+pub fn from_arrow_type(arrow_type: &ArrowDataType) -> Result<DataType> {
+    use crate::metadata::DataTypes;
+
+    Ok(match arrow_type {
+        ArrowDataType::Boolean => DataTypes::boolean(),
+        ArrowDataType::Int8 => DataTypes::tinyint(),
+        ArrowDataType::Int16 => DataTypes::smallint(),
+        ArrowDataType::Int32 => DataTypes::int(),
+        ArrowDataType::Int64 => DataTypes::bigint(),
+        ArrowDataType::Float32 => DataTypes::float(),
+        ArrowDataType::Float64 => DataTypes::double(),
+        ArrowDataType::Utf8 => DataTypes::string(),
+        ArrowDataType::Binary => DataTypes::bytes(),
+        ArrowDataType::Date32 => DataTypes::date(),
+        ArrowDataType::FixedSizeBinary(len) => {
+            // Arrow models the width as a signed integer; reject negatives instead of wrapping.
+            let length = usize::try_from(*len).map_err(|_| Error::IllegalArgument {
+                message: format!("FixedSizeBinary length must be >= 0, got {len}"),
+            })?;
+            DataTypes::binary(length)
+        }
+        ArrowDataType::Decimal128(p, s) => {
+            // Arrow allows a negative scale; the Fluss constructor takes an unsigned one.
+            let scale = u32::try_from(*s).map_err(|_| Error::IllegalArgument {
+                message: format!("Decimal scale must be >= 0, got {s}"),
+            })?;
+            DataTypes::decimal(u32::from(*p), scale)
+        }
+        // Time and timestamp units map to a fractional-second precision of 0 / 3 / 6 / 9.
+        ArrowDataType::Time32(arrow_schema::TimeUnit::Second) => DataTypes::time_with_precision(0),
+        ArrowDataType::Time32(arrow_schema::TimeUnit::Millisecond) => {
+            DataTypes::time_with_precision(3)
+        }
+        ArrowDataType::Time64(arrow_schema::TimeUnit::Microsecond) => {
+            DataTypes::time_with_precision(6)
+        }
+        ArrowDataType::Time64(arrow_schema::TimeUnit::Nanosecond) => {
+            DataTypes::time_with_precision(9)
+        }
+        ArrowDataType::Timestamp(unit, tz) => {
+            let precision = match unit {
+                arrow_schema::TimeUnit::Second => 0,
+                arrow_schema::TimeUnit::Millisecond => 3,
+                arrow_schema::TimeUnit::Microsecond => 6,
+                arrow_schema::TimeUnit::Nanosecond => 9,
+            };
+
+            // A timestamp carrying a time zone becomes the local-time-zone variant.
+            if tz.is_some() {
+                DataTypes::timestamp_ltz_with_precision(precision)
+            } else {
+                DataTypes::timestamp_with_precision(precision)
+            }
+        }
+        ArrowDataType::List(field) => DataTypes::array(from_arrow_type(field.data_type())?),
+        other => {
+            return Err(Error::IllegalArgument {
+                message: format!("Cannot convert Arrow type to Fluss type: {other:?}"),
+            });
+        }
+    })
+}
+
 #[derive(Clone)]
 pub struct ReadContext {
     target_schema: SchemaRef,
diff --git a/crates/fluss/src/row/binary/binary_writer.rs b/crates/fluss/src/row/binary/binary_writer.rs
index af2765c4..f51a6e80 100644
--- a/crates/fluss/src/row/binary/binary_writer.rs
+++ b/crates/fluss/src/row/binary/binary_writer.rs
@@ -67,8 +67,7 @@ pub trait BinaryWriter {
 
     fn write_timestamp_ltz(&mut self, value: &crate::row::datum::TimestampLtz, precision: u32);
 
-    // TODO InternalArray, ArraySerializer
-    // fn write_array(&mut self, pos: i32, value: i64);
+    fn write_array(&mut self, value: &[u8]);
 
     // TODO Row serializer
     // fn write_row(&mut self, pos: i32, value: &InternalRow);
@@ -136,7 +135,8 @@ pub enum InnerValueWriter {
     Time(u32),         // precision (not used in wire format, but kept for consistency)
     TimestampNtz(u32), // precision
     TimestampLtz(u32), // precision
-                       // TODO Array, Row
+    Array,
+    // TODO Row
 }
 
 /// Accessor for writing the fields/elements of a binary writer during runtime, the
@@ -175,6 +175,7 @@ impl InnerValueWriter {
                 // Validation is done at TimestampLTzType construction time
                 Ok(InnerValueWriter::TimestampLtz(t.precision()))
             }
+            DataType::Array(_) => Ok(InnerValueWriter::Array),
             _ => unimplemented!(
                 "ValueWriter for DataType {:?} is currently not implemented",
                 data_type
@@ -237,6 +238,9 @@ impl InnerValueWriter {
             (InnerValueWriter::TimestampLtz(p), Datum::TimestampLtz(ts)) => {
                 writer.write_timestamp_ltz(ts, *p);
             }
+            (InnerValueWriter::Array, Datum::Array(arr)) => {
+                writer.write_array(arr.as_bytes());
+            }
             _ => {
                 return Err(IllegalArgument {
                     message: format!("{self:?} used to write value {value:?}"),
diff --git a/crates/fluss/src/row/binary_array.rs b/crates/fluss/src/row/binary_array.rs
new file mode 100644
index 00000000..0975f0a6
--- /dev/null
+++ b/crates/fluss/src/row/binary_array.rs
@@ -0,0 +1,736 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Binary array format matching Java's `BinaryArray.java` layout.
+//!
+//! Binary layout:
+//! ```text
+//! [size(4B)] + [null bits (4-byte word aligned)] + [fixed-length part] + [variable-length part]
+//! ```
+//!
+//! Java reference: `BinaryArray.java`, `BinaryArrayWriter.java`
+
+use crate::error::Error::IllegalArgument;
+use crate::error::Result;
+use crate::metadata::DataType;
+use crate::row::Decimal;
+use crate::row::datum::{Date, Time, TimestampLtz, TimestampNtz};
+use serde::Serialize;
+use std::fmt;
+use std::hash::{Hash, Hasher};
+
+// Largest payload (in bytes) that is stored inline in an 8-byte slot.
+const MAX_FIX_PART_DATA_SIZE: usize = 7;
+// Top bit of a slot: set when the payload is stored inline.
+const HIGHEST_FIRST_BIT: u64 = 0x80_u64 << 56;
+// Remaining seven bits of the top byte: the inline payload length.
+const HIGHEST_SECOND_TO_EIGHTH_BIT: u64 = 0x7F_u64 << 56;
+
+/// Calculates the header size in bytes: 4 (for element count) + null bits (4-byte word aligned).
+/// Matches Java's `BinaryArray.calculateHeaderInBytes(numFields)`.
+pub fn calculate_header_in_bytes(num_elements: usize) -> usize {
+    // One null bit per element, rounded up to whole 32-bit words of 4 bytes each.
+    4 + num_elements.div_ceil(32) * 4
+}
+
+/// Calculates the fixed-length part size per element for a given data type.
+/// Matches Java's `BinaryArray.calculateFixLengthPartSize(DataType)`.
+pub fn calculate_fix_length_part_size(element_type: &DataType) -> usize {
+    match element_type {
+        // 64-bit values, plus every type whose getter reads a packed 8-byte slot.
+        DataType::BigInt(_)
+        | DataType::Double(_)
+        | DataType::Char(_)
+        | DataType::String(_)
+        | DataType::Binary(_)
+        | DataType::Bytes(_)
+        | DataType::Decimal(_)
+        | DataType::Timestamp(_)
+        | DataType::TimestampLTz(_)
+        | DataType::Array(_)
+        | DataType::Map(_)
+        | DataType::Row(_) => 8,
+        // 32-bit values.
+        DataType::Int(_) | DataType::Float(_) | DataType::Date(_) | DataType::Time(_) => 4,
+        // 16-bit values.
+        DataType::SmallInt(_) => 2,
+        // Single-byte values.
+        DataType::Boolean(_) | DataType::TinyInt(_) => 1,
+    }
+}
+
+/// Pads `num_bytes` up to the next multiple of 8 (already-aligned values are returned as is).
+/// Mirrors Java's `roundNumberOfBytesToNearestWord`.
+fn round_to_nearest_word(num_bytes: usize) -> usize {
+    let remainder = num_bytes % 8;
+    if remainder == 0 {
+        num_bytes
+    } else {
+        num_bytes + (8 - remainder)
+    }
+}
+
+/// A Fluss binary array, wire-compatible with Java's `BinaryArray`.
+///
+/// Stores elements in a flat byte buffer with a header (element count + null bitmap)
+/// followed by fixed-length slots and an optional variable-length section.
+#[derive(Clone)]
+pub struct FlussArray {
+    /// Complete binary representation: header, fixed-length slots, variable-length part.
+    data: Vec<u8>,
+    /// Number of elements, decoded from the first four bytes of `data`.
+    size: usize,
+    /// Byte offset in `data` at which the fixed-length slots begin (the header size).
+    element_offset: usize,
+}
+
+/// Prints only the element count and buffer length, not the raw bytes.
+impl fmt::Debug for FlussArray {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("FlussArray")
+            .field("size", &self.size)
+            .field("data_len", &self.data.len())
+            .finish()
+    }
+}
+
+impl fmt::Display for FlussArray {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "FlussArray[size={}]", self.size)
+    }
+}
+
+/// Equality, ordering and hashing are all defined on the raw bytes, so they stay consistent
+/// with each other.
+impl PartialEq for FlussArray {
+    fn eq(&self, other: &Self) -> bool {
+        self.data == other.data
+    }
+}
+
+impl Eq for FlussArray {}
+
+impl PartialOrd for FlussArray {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for FlussArray {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.data.cmp(&other.data)
+    }
+}
+
+impl Hash for FlussArray {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.data.hash(state);
+    }
+}
+
+/// Serializes the array as a single byte string holding its binary representation.
+impl Serialize for FlussArray {
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        serializer.serialize_bytes(&self.data)
+    }
+}
+
+impl FlussArray {
+    /// Creates a FlussArray from a copy of `data`, validating the element count and header size.
+    ///
+    /// # Errors
+    /// Returns `IllegalArgument` when `data` is shorter than the 4-byte element count, when the
+    /// count is negative, or when the header implied by the count does not fit in `data`.
+    /// The element slots themselves are not validated, because the element type is unknown here.
+    pub fn from_bytes(data: &[u8]) -> Result<Self> {
+        if data.len() < 4 {
+            return Err(IllegalArgument {
+                message: format!(
+                    "FlussArray data too short: need at least 4 bytes, got {}",
+                    data.len()
+                ),
+            });
+        }
+        // The count is stored as a native-endian i32, matching what the writer emits.
+        let raw_size = i32::from_ne_bytes(
+            data[0..4]
+                .try_into()
+                .expect("slice of length 4 converts to [u8; 4]"),
+        );
+        let size = usize::try_from(raw_size).map_err(|_| IllegalArgument {
+            message: format!("FlussArray size must be non-negative, got {raw_size}"),
+        })?;
+        let element_offset = calculate_header_in_bytes(size);
+        if element_offset > data.len() {
+            return Err(IllegalArgument {
+                message: format!(
+                    "FlussArray header exceeds payload: header={}, payload={}",
+                    element_offset,
+                    data.len()
+                ),
+            });
+        }
+
+        Ok(FlussArray {
+            data: data.to_vec(),
+            size,
+            element_offset,
+        })
+    }
+
+    /// Returns the number of elements.
+    pub fn size(&self) -> usize {
+        self.size
+    }
+
+    /// Returns the raw bytes of this array (the complete binary representation).
+    pub fn as_bytes(&self) -> &[u8] {
+        &self.data
+    }
+
+    /// Returns true if the element at position `pos` is null.
+    ///
+    /// # Panics
+    /// Panics if the bitmap byte for `pos` lies beyond the end of the buffer.
+    pub fn is_null_at(&self, pos: usize) -> bool {
+        // The null bitmap starts right after the 4-byte element count, one bit per element.
+        let byte_index = pos >> 3;
+        let bit = pos & 7;
+        (self.data[4 + byte_index] & (1u8 << bit)) != 0
+    }
+
+    /// Byte offset of the fixed-length slot of element `ordinal`, given the slot width.
+    fn element_offset(&self, ordinal: usize, element_size: usize) -> usize {
+        self.element_offset + ordinal * element_size
+    }
+
+    /// Returns `data[start..start + len]`, or an `IllegalArgument` error naming `context`
+    /// when the range overflows or runs past the end of the buffer.
+    fn checked_slice(&self, start: usize, len: usize, context: &str) -> Result<&[u8]> {
+        let end = start.checked_add(len).ok_or_else(|| IllegalArgument {
+            message: format!("Overflow while reading {context}: start={start}, len={len}"),
+        })?;
+        if end > self.data.len() {
+            return Err(IllegalArgument {
+                message: format!(
+                    "Out-of-bounds while reading {context}: start={start}, len={len}, payload={}",
+                    self.data.len()
+                ),
+            });
+        }
+        Ok(&self.data[start..end])
+    }
+
+    /// Resolves the 8-byte slot of element `pos` to the bytes it denotes.
+    ///
+    /// If the top bit of the slot is clear, the slot packs `offset << 32 | length` and the bytes
+    /// live in the variable-length part. If it is set, the next seven bits hold the length and
+    /// the bytes are stored inline in the slot itself.
+    ///
+    /// # Errors
+    /// Returns `IllegalArgument` when the slot or the range it refers to lies outside the
+    /// buffer, or when an inline length exceeds `MAX_FIX_PART_DATA_SIZE`.
+    fn read_var_len_bytes(&self, pos: usize) -> Result<&[u8]> {
+        let field_offset = self.element_offset(pos, 8);
+        // Read the slot through the bounds-checked path: this function reports malformed input
+        // as an error, so it must not panic on a truncated buffer.
+        let slot = self.checked_slice(field_offset, 8, "variable-length array slot")?;
+        let packed = u64::from_ne_bytes(
+            slot.try_into()
+                .expect("checked_slice returned exactly 8 bytes"),
+        );
+        let mark = packed & HIGHEST_FIRST_BIT;
+
+        if mark == 0 {
+            let offset = (packed >> 32) as usize;
+            let len = (packed & 0xFFFF_FFFF) as usize;
+            self.checked_slice(offset, len, "variable-length array element")
+        } else {
+            let len = ((packed & HIGHEST_SECOND_TO_EIGHTH_BIT) >> 56) as usize;
+            if len > MAX_FIX_PART_DATA_SIZE {
+                return Err(IllegalArgument {
+                    message: format!(
+                        "Inline array element length must be <= {MAX_FIX_PART_DATA_SIZE}, got {len}"
+                    ),
+                });
+            }
+            // Java stores inline bytes in the 8-byte slot itself.
+            // On little-endian, bytes start at field_offset; on big-endian they start at +1.
+            let start = if cfg!(target_endian = "little") {
+                field_offset
+            } else {
+                field_offset + 1
+            };
+            self.checked_slice(start, len, "inline array element")
+        }
+    }
+
+    // The fixed-width getters below index the buffer directly and panic when the slot of `pos`
+    // lies outside it. All multi-byte values are read in native byte order.
+
+    pub fn get_boolean(&self, pos: usize) -> bool {
+        let offset = self.element_offset(pos, 1);
+        self.data[offset] != 0
+    }
+
+    pub fn get_byte(&self, pos: usize) -> i8 {
+        let offset = self.element_offset(pos, 1);
+        self.data[offset] as i8
+    }
+
+    pub fn get_short(&self, pos: usize) -> i16 {
+        let offset = self.element_offset(pos, 2);
+        i16::from_ne_bytes(self.data[offset..offset + 2].try_into().unwrap())
+    }
+
+    pub fn get_int(&self, pos: usize) -> i32 {
+        let offset = self.element_offset(pos, 4);
+        i32::from_ne_bytes(self.data[offset..offset + 4].try_into().unwrap())
+    }
+
+    pub fn get_long(&self, pos: usize) -> i64 {
+        let offset = self.element_offset(pos, 8);
+        i64::from_ne_bytes(self.data[offset..offset + 8].try_into().unwrap())
+    }
+
+    pub fn get_float(&self, pos: usize) -> f32 {
+        let offset = self.element_offset(pos, 4);
+        f32::from_ne_bytes(self.data[offset..offset + 4].try_into().unwrap())
+    }
+
+    pub fn get_double(&self, pos: usize) -> f64 {
+        let offset = self.element_offset(pos, 8);
+        f64::from_ne_bytes(self.data[offset..offset + 8].try_into().unwrap())
+    }
+
+    /// Reads the offset_and_size packed long for variable-length elements.
+    /// Returns `(high 32 bits, low 32 bits)` of the slot.
+    fn get_offset_and_size(&self, pos: usize) -> (usize, usize) {
+        let packed = self.get_long(pos) as u64;
+        let offset = (packed >> 32) as usize;
+        let size = (packed & 0xFFFF_FFFF) as usize;
+        (offset, size)
+    }
+
+    /// Decodes a non-compact timestamp slot: the high 32 bits locate an 8-byte native-endian
+    /// milliseconds value, the low 32 bits carry the nanos-of-millisecond.
+    fn read_millis_and_nanos(&self, pos: usize, context: &str) -> Result<(i64, i32)> {
+        let (offset, nanos_of_milli) = self.get_offset_and_size(pos);
+        let millis_bytes = self.checked_slice(offset, 8, context)?;
+        let millis = i64::from_ne_bytes(
+            millis_bytes
+                .try_into()
+                .expect("checked_slice returned exactly 8 bytes"),
+        );
+        // The value was masked to 32 bits above; reinterpret it as the signed nanos field.
+        Ok((millis, nanos_of_milli as i32))
+    }
+
+    /// Returns the element at `pos` as a string slice borrowed from this array.
+    ///
+    /// # Errors
+    /// Returns `IllegalArgument` when the element bytes are out of bounds or not valid UTF-8.
+    pub fn get_string(&self, pos: usize) -> Result<&str> {
+        let bytes = self.read_var_len_bytes(pos)?;
+        std::str::from_utf8(bytes).map_err(|e| IllegalArgument {
+            message: format!("Invalid UTF-8 in array element at position {pos}: {e}"),
+        })
+    }
+
+    /// Returns the element at `pos` as a byte slice borrowed from this array.
+    pub fn get_binary(&self, pos: usize) -> Result<&[u8]> {
+        self.read_var_len_bytes(pos)
+    }
+
+    /// Returns the decimal at `pos`. Compact precisions are stored directly in the slot as an
+    /// unscaled long; larger ones as unscaled bytes in the variable-length part.
+    pub fn get_decimal(&self, pos: usize, precision: u32, scale: u32) -> Result<Decimal> {
+        if Decimal::is_compact_precision(precision) {
+            let unscaled = self.get_long(pos);
+            Decimal::from_unscaled_long(unscaled, precision, scale)
+        } else {
+            let (offset, size) = self.get_offset_and_size(pos);
+            let bytes = self.checked_slice(offset, size, "decimal bytes")?;
+            Decimal::from_unscaled_bytes(bytes, precision, scale)
+        }
+    }
+
+    pub fn get_date(&self, pos: usize) -> Date {
+        Date::new(self.get_int(pos))
+    }
+
+    pub fn get_time(&self, pos: usize) -> Time {
+        Time::new(self.get_int(pos))
+    }
+
+    /// Returns the timestamp at `pos`. Compact precisions store the milliseconds in the slot;
+    /// otherwise the slot refers to the milliseconds and carries the nanos-of-millisecond.
+    pub fn get_timestamp_ntz(&self, pos: usize, precision: u32) -> Result<TimestampNtz> {
+        if TimestampNtz::is_compact(precision) {
+            Ok(TimestampNtz::new(self.get_long(pos)))
+        } else {
+            let (millis, nanos) = self.read_millis_and_nanos(pos, "timestamp ntz millis")?;
+            TimestampNtz::from_millis_nanos(millis, nanos)
+        }
+    }
+
+    /// Same encoding as [`Self::get_timestamp_ntz`], for the local-time-zone timestamp type.
+    pub fn get_timestamp_ltz(&self, pos: usize, precision: u32) -> Result<TimestampLtz> {
+        if TimestampLtz::is_compact(precision) {
+            Ok(TimestampLtz::new(self.get_long(pos)))
+        } else {
+            let (millis, nanos) = self.read_millis_and_nanos(pos, "timestamp ltz millis")?;
+            TimestampLtz::from_millis_nanos(millis, nanos)
+        }
+    }
+
+    /// Returns the nested array at `pos` as an independent (copied) `FlussArray`.
+    pub fn get_array(&self, pos: usize) -> Result<FlussArray> {
+        let bytes = self.read_var_len_bytes(pos)?;
+        FlussArray::from_bytes(bytes)
+    }
+}
+
+/// Writer for building a `FlussArray` element by element.
+/// Matches Java's `BinaryArrayWriter`.
+pub struct FlussArrayWriter {
+    /// Output buffer: header and fixed-length slots up front, variable-length data appended.
+    data: Vec<u8>,
+    /// Offset of the null bitmap (directly after the 4-byte element count).
+    null_bits_offset: usize,
+    /// Offset of the first fixed-length slot.
+    element_offset: usize,
+    /// Width in bytes of one fixed-length slot.
+    element_size: usize,
+    /// Offset at which the next variable-length payload will be written.
+    cursor: usize,
+    num_elements: usize,
+}
+
+impl FlussArrayWriter {
+    /// Creates a new writer for an array with `num_elements` elements of the given element type.
+    pub fn new(num_elements: usize, element_type: &DataType) -> Self {
+        let element_size = calculate_fix_length_part_size(element_type);
+        Self::with_element_size(num_elements, element_size)
+    }
+
+    /// Creates a new writer with an explicit element size (in bytes).
+    ///
+    /// The buffer is zero-initialised, so unwritten slots and null bits read back as zero.
+    pub fn with_element_size(num_elements: usize, element_size: usize) -> Self {
+        let header_in_bytes = calculate_header_in_bytes(num_elements);
+        let fixed_size = round_to_nearest_word(header_in_bytes + element_size * num_elements);
+        let mut data = vec![0u8; fixed_size];
+
+        // Write element count at offset 0 (native endian, matches Java Unsafe behavior).
+        // NOTE(review): the cast truncates counts above i32::MAX — confirm callers never
+        // build arrays that large.
+        data[0..4].copy_from_slice(&(num_elements as i32).to_ne_bytes());
+
+        FlussArrayWriter {
+            data,
+            null_bits_offset: 4,
+            element_offset: header_in_bytes,
+            element_size,
+            cursor: fixed_size,
+            num_elements,
+        }
+    }
+
+    /// Byte offset of the fixed-length slot of element `pos`.
+    fn get_element_offset(&self, pos: usize) -> usize {
+        self.element_offset + self.element_size * pos
+    }
+
+    /// Sets the null bit for the element at position `pos`.
+    pub fn set_null_at(&mut self, pos: usize) {
+        // One bit per element: byte `pos / 8` of the bitmap, bit `pos % 8` within it.
+        let mask = 1u8 << (pos & 7);
+        let bitmap_byte = self.null_bits_offset + (pos >> 3);
+        self.data[bitmap_byte] |= mask;
+    }
+
+    /// Stores `true` as 1 and `false` as 0 in the one-byte slot of element `pos`.
+    pub fn write_boolean(&mut self, pos: usize, value: bool) {
+        let slot = self.get_element_offset(pos);
+        self.data[slot] = u8::from(value);
+    }
+
+    /// Stores the byte in the slot of element `pos`.
+    pub fn write_byte(&mut self, pos: usize, value: i8) {
+        let slot = self.get_element_offset(pos);
+        self.data[slot] = value as u8;
+    }
+
+    // The writers below store the value in native byte order at the slot of element `pos`.
+
+    pub fn write_short(&mut self, pos: usize, value: i16) {
+        let slot = self.get_element_offset(pos);
+        let encoded = value.to_ne_bytes();
+        self.data[slot..slot + 2].copy_from_slice(&encoded);
+    }
+
+    pub fn write_int(&mut self, pos: usize, value: i32) {
+        let slot = self.get_element_offset(pos);
+        let encoded = value.to_ne_bytes();
+        self.data[slot..slot + 4].copy_from_slice(&encoded);
+    }
+
+    pub fn write_long(&mut self, pos: usize, value: i64) {
+        let slot = self.get_element_offset(pos);
+        let encoded = value.to_ne_bytes();
+        self.data[slot..slot + 8].copy_from_slice(&encoded);
+    }
+
+    pub fn write_float(&mut self, pos: usize, value: f32) {
+        let slot = self.get_element_offset(pos);
+        let encoded = value.to_ne_bytes();
+        self.data[slot..slot + 4].copy_from_slice(&encoded);
+    }
+
+    pub fn write_double(&mut self, pos: usize, value: f64) {
+        let slot = self.get_element_offset(pos);
+        let encoded = value.to_ne_bytes();
+        self.data[slot..slot + 8].copy_from_slice(&encoded);
+    }
+
+    /// Writes variable-length bytes to the variable part and stores offset+size in the fixed slot.
+    /// The payload is padded with zeros to a multiple of 8 bytes; the recorded size is the
+    /// unpadded length.
+    fn write_bytes_to_var_len_part(&mut self, pos: usize, bytes: &[u8]) {
+        let rounded = round_to_nearest_word(bytes.len());
+        let var_offset = self.cursor;
+        self.data.resize(self.data.len() + rounded, 0);
+        self.data[var_offset..var_offset + bytes.len()].copy_from_slice(bytes);
+        self.set_offset_and_size(pos, var_offset, bytes.len());
+        self.cursor += rounded;
+    }
+
+    /// Packs `offset` into the high 32 bits and `size` into the low 32 bits of the slot.
+    fn set_offset_and_size(&mut self, pos: usize, offset: usize, size: usize) {
+        let packed = ((offset as i64) << 32) | (size as i64);
+        self.write_long(pos, packed);
+    }
+
+    /// Stores up to `MAX_FIX_PART_DATA_SIZE` bytes inline in the slot: the top byte holds
+    /// `0x80 | len`, the remaining seven bytes hold the payload.
+    fn write_bytes_to_fix_len_part(&mut self, pos: usize, bytes: &[u8]) {
+        let len = bytes.len();
+        debug_assert!(len <= MAX_FIX_PART_DATA_SIZE);
+        let first_byte = (len as u64) | 0x80;
+        let mut seven_bytes = 0_u64;
+        // Place the payload so that, once the long is written in native byte order, the bytes
+        // appear in memory in their original order (at slot+0 on little-endian, slot+1 on
+        // big-endian).
+        if cfg!(target_endian = "little") {
+            for (i, b) in bytes.iter().enumerate() {
+                seven_bytes |= ((*b as u64) & 0xFF) << (i * 8);
+            }
+        } else {
+            for (i, b) in bytes.iter().enumerate() {
+                seven_bytes |= ((*b as u64) & 0xFF) << ((6 - i) * 8);
+            }
+        }
+        let packed = ((first_byte << 56) | seven_bytes) as i64;
+        self.write_long(pos, packed);
+    }
+
+    /// Writes the UTF-8 bytes of `value`; strings use the same encoding as binary values.
+    pub fn write_string(&mut self, pos: usize, value: &str) {
+        self.write_binary_bytes(pos, value.as_bytes());
+    }
+
+    /// Writes `value` inline when it fits in the slot, otherwise into the variable-length part.
+    pub fn write_binary_bytes(&mut self, pos: usize, value: &[u8]) {
+        if value.len() <= MAX_FIX_PART_DATA_SIZE {
+            self.write_bytes_to_fix_len_part(pos, value);
+        } else {
+            self.write_bytes_to_var_len_part(pos, value);
+        }
+    }
+
+    /// Writes a decimal: compact precisions as an unscaled long in the slot, larger ones as
+    /// unscaled bytes in the variable-length part.
+    ///
+    /// # Panics
+    /// Panics if `precision` is compact but the value has no unscaled-long representation.
+    pub fn write_decimal(&mut self, pos: usize, value: &Decimal, precision: u32) {
+        if Decimal::is_compact_precision(precision) {
+            self.write_long(
+                pos,
+                value
+                    .to_unscaled_long()
+                    .expect("Decimal should fit in i64 for compact precision"),
+            );
+        } else {
+            // NOTE(review): this pads to the next 8-byte word; confirm whether the Java writer
+            // reserves a fixed-size region here if byte-for-byte parity with Java matters.
+            let bytes = value.to_unscaled_bytes();
+            self.write_bytes_to_var_len_part(pos, &bytes);
+        }
+    }
+
+    pub fn write_date(&mut self, pos: usize, value: Date) {
+        self.write_int(pos, value.get_inner());
+    }
+
+    pub fn write_time(&mut self, pos: usize, value: Time) {
+        self.write_int(pos, value.get_inner());
+    }
+
+    /// Encodes a non-compact timestamp: the 8-byte native-endian milliseconds go into the
+    /// variable-length part, and the slot packs their offset with the nanos-of-millisecond.
+    fn write_millis_and_nanos(&mut self, pos: usize, millis: i64, nanos_of_milli: usize) {
+        let var_offset = self.cursor;
+        let rounded = round_to_nearest_word(8);
+        self.data.resize(self.data.len() + rounded, 0);
+        self.data[var_offset..var_offset + 8].copy_from_slice(&millis.to_ne_bytes());
+        self.set_offset_and_size(pos, var_offset, nanos_of_milli);
+        self.cursor += rounded;
+    }
+
+    /// Writes a timestamp without time zone; compact precisions store only the milliseconds.
+    pub fn write_timestamp_ntz(&mut self, pos: usize, value: &TimestampNtz, precision: u32) {
+        if TimestampNtz::is_compact(precision) {
+            self.write_long(pos, value.get_millisecond());
+        } else {
+            self.write_millis_and_nanos(
+                pos,
+                value.get_millisecond(),
+                value.get_nano_of_millisecond() as usize,
+            );
+        }
+    }
+
+    /// Writes a local-time-zone timestamp; compact precisions store only the milliseconds.
+    pub fn write_timestamp_ltz(&mut self, pos: usize, value: &TimestampLtz, precision: u32) {
+        if TimestampLtz::is_compact(precision) {
+            self.write_long(pos, value.get_epoch_millisecond());
+        } else {
+            self.write_millis_and_nanos(
+                pos,
+                value.get_epoch_millisecond(),
+                value.get_nano_of_millisecond() as usize,
+            );
+        }
+    }
+
+    /// Writes a nested FlussArray into this array at position `pos`.
+    pub fn write_array(&mut self, pos: usize, value: &FlussArray) {
+        self.write_bytes_to_var_len_part(pos, value.as_bytes());
+    }
+
+    /// Finalizes the writer and returns the completed FlussArray.
+    ///
+    /// The buffer is re-validated through `FlussArray::from_bytes`, so an invalid element
+    /// count is reported as an error.
+    pub fn complete(self) -> Result<FlussArray> {
+        let mut data = self.data;
+        data.truncate(self.cursor);
+        FlussArray::from_bytes(&data)
+    }
+
+    /// Returns the number of elements this writer was initialized with.
+ pub fn num_elements(&self) -> usize { + self.num_elements + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::metadata::DataTypes; + + #[test] + fn test_header_calculation() { + assert_eq!(calculate_header_in_bytes(0), 4); + assert_eq!(calculate_header_in_bytes(1), 8); + assert_eq!(calculate_header_in_bytes(31), 8); + assert_eq!(calculate_header_in_bytes(32), 8); + assert_eq!(calculate_header_in_bytes(33), 12); + assert_eq!(calculate_header_in_bytes(64), 12); + assert_eq!(calculate_header_in_bytes(65), 16); + } + + #[test] + fn test_fix_length_part_size() { + assert_eq!(calculate_fix_length_part_size(&DataTypes::boolean()), 1); + assert_eq!(calculate_fix_length_part_size(&DataTypes::tinyint()), 1); + assert_eq!(calculate_fix_length_part_size(&DataTypes::smallint()), 2); + assert_eq!(calculate_fix_length_part_size(&DataTypes::int()), 4); + assert_eq!(calculate_fix_length_part_size(&DataTypes::bigint()), 8); + assert_eq!(calculate_fix_length_part_size(&DataTypes::float()), 4); + assert_eq!(calculate_fix_length_part_size(&DataTypes::double()), 8); + assert_eq!(calculate_fix_length_part_size(&DataTypes::string()), 8); + assert_eq!( + calculate_fix_length_part_size(&DataTypes::array(DataTypes::int())), + 8 + ); + } + + #[test] + fn test_round_trip_int_array() { + let elem_type = DataTypes::int(); + let mut writer = FlussArrayWriter::new(3, &elem_type); + writer.write_int(0, 10); + writer.write_int(1, 20); + writer.write_int(2, 30); + let array = writer.complete().unwrap(); + + assert_eq!(array.size(), 3); + assert!(!array.is_null_at(0)); + assert_eq!(array.get_int(0), 10); + assert_eq!(array.get_int(1), 20); + assert_eq!(array.get_int(2), 30); + } + + #[test] + fn test_round_trip_with_nulls() { + let elem_type = DataTypes::int(); + let mut writer = FlussArrayWriter::new(3, &elem_type); + writer.write_int(0, 1); + writer.set_null_at(1); + writer.write_int(2, 3); + let array = writer.complete().unwrap(); + + assert_eq!(array.size(), 3); + 
assert!(!array.is_null_at(0)); + assert!(array.is_null_at(1)); + assert!(!array.is_null_at(2)); + assert_eq!(array.get_int(0), 1); + assert_eq!(array.get_int(2), 3); + } + + #[test] + fn test_round_trip_string_array() { + let elem_type = DataTypes::string(); + let mut writer = FlussArrayWriter::new(3, &elem_type); + writer.write_string(0, "hello"); + writer.write_string(1, "world"); + writer.write_string(2, "!"); + let array = writer.complete().unwrap(); + + assert_eq!(array.size(), 3); + assert_eq!(array.get_string(0).unwrap(), "hello"); + assert_eq!(array.get_string(1).unwrap(), "world"); + assert_eq!(array.get_string(2).unwrap(), "!"); + } + + #[test] + fn test_java_inline_short_string_decoding() { + // Manually construct Java-style inline encoded short string ("abc") + // slot payload: [len|0x80 in top byte] + [bytes in low 7 bytes on little-endian] + let mut data = vec![0_u8; 16]; + data[0..4].copy_from_slice(&(1_i32).to_ne_bytes()); + // null bits remain 0 + let first_byte = (3_u64 | 0x80) << 56; + let seven_bytes = (b'a' as u64) | ((b'b' as u64) << 8) | ((b'c' as u64) << 16); + let packed = first_byte | seven_bytes; + data[8..16].copy_from_slice(&packed.to_ne_bytes()); + + let arr = FlussArray::from_bytes(&data).unwrap(); + assert_eq!(arr.size(), 1); + assert_eq!(arr.get_string(0).unwrap(), "abc"); + } + + #[test] + fn test_java_inline_short_binary_decoding() { + let elem_type = DataTypes::bytes(); + let mut writer = FlussArrayWriter::new(1, &elem_type); + writer.write_binary_bytes(0, b"abc"); + let arr = writer.complete().unwrap(); + assert_eq!(arr.get_binary(0).unwrap(), b"abc"); + } + + #[test] + fn test_round_trip_empty_array() { + let elem_type = DataTypes::int(); + let writer = FlussArrayWriter::new(0, &elem_type); + let array = writer.complete().unwrap(); + assert_eq!(array.size(), 0); + } + + #[test] + fn test_round_trip_boolean_array() { + let elem_type = DataTypes::boolean(); + let mut writer = FlussArrayWriter::new(3, &elem_type); + 
writer.write_boolean(0, true); + writer.write_boolean(1, false); + writer.write_boolean(2, true); + let array = writer.complete().unwrap(); + + assert_eq!(array.size(), 3); + assert!(array.get_boolean(0)); + assert!(!array.get_boolean(1)); + assert!(array.get_boolean(2)); + } + + #[test] + fn test_round_trip_long_array() { + let elem_type = DataTypes::bigint(); + let mut writer = FlussArrayWriter::new(2, &elem_type); + writer.write_long(0, i64::MAX); + writer.write_long(1, i64::MIN); + let array = writer.complete().unwrap(); + + assert_eq!(array.get_long(0), i64::MAX); + assert_eq!(array.get_long(1), i64::MIN); + } + + #[test] + fn test_round_trip_double_array() { + let elem_type = DataTypes::double(); + let mut writer = FlussArrayWriter::new(2, &elem_type); + writer.write_double(0, 1.23); + writer.write_double(1, -4.56); + let array = writer.complete().unwrap(); + + assert_eq!(array.get_double(0), 1.23); + assert_eq!(array.get_double(1), -4.56); + } + + #[test] + fn test_round_trip_nested_array() { + let inner_type = DataTypes::int(); + let outer_type = DataTypes::array(DataTypes::int()); + + // Build inner array [1, 2] + let mut inner_writer = FlussArrayWriter::new(2, &inner_type); + inner_writer.write_int(0, 1); + inner_writer.write_int(1, 2); + let inner_array = inner_writer.complete().unwrap(); + + // Build outer array containing the inner array + let mut outer_writer = FlussArrayWriter::new(1, &outer_type); + outer_writer.write_array(0, &inner_array); + let outer_array = outer_writer.complete().unwrap(); + + assert_eq!(outer_array.size(), 1); + let nested = outer_array.get_array(0).unwrap(); + assert_eq!(nested.size(), 2); + assert_eq!(nested.get_int(0), 1); + assert_eq!(nested.get_int(1), 2); + } + + #[test] + fn test_binary_layout_matches_java() { + // Verify exact byte layout for a simple [1, 2, 3] int array + let elem_type = DataTypes::int(); + let mut writer = FlussArrayWriter::new(3, &elem_type); + writer.write_int(0, 1); + writer.write_int(1, 2); + 
writer.write_int(2, 3); + let array = writer.complete().unwrap(); + let bytes = array.as_bytes(); + + // size = 3 at offset 0 (4 bytes, native endian) + assert_eq!(i32::from_ne_bytes(bytes[0..4].try_into().unwrap()), 3); + // null bits: 4 bytes starting at offset 4, should be all zeros + assert_eq!(&bytes[4..8], &[0, 0, 0, 0]); + // elements start at offset 8 (header = 4 + 4), each 4 bytes + assert_eq!(i32::from_ne_bytes(bytes[8..12].try_into().unwrap()), 1); + assert_eq!(i32::from_ne_bytes(bytes[12..16].try_into().unwrap()), 2); + assert_eq!(i32::from_ne_bytes(bytes[16..20].try_into().unwrap()), 3); + } +} diff --git a/crates/fluss/src/row/column.rs b/crates/fluss/src/row/column.rs index c07fe97c..be2cc78b 100644 --- a/crates/fluss/src/row/column.rs +++ b/crates/fluss/src/row/column.rs @@ -407,17 +407,115 @@ impl InternalRow for ColumnarRow { })? .value(self.row_id)) } + + fn get_array(&self, pos: usize) -> Result { + use crate::record::from_arrow_type; + use crate::row::binary_array::FlussArrayWriter; + use arrow::array::ListArray; + + let column = self.column(pos)?; + let list_array = + column + .as_any() + .downcast_ref::() + .ok_or_else(|| IllegalArgument { + message: format!("expected List array at position {pos}"), + })?; + + let values = list_array.value(self.row_id); + let num_elements = values.len(); + let element_arrow_type = values.data_type(); + let element_fluss_type = from_arrow_type(element_arrow_type)?; + + let mut writer = FlussArrayWriter::new(num_elements, &element_fluss_type); + let element_row = ColumnarRow::new(std::sync::Arc::new( + arrow::array::RecordBatch::try_from_iter(vec![("v", values)]).map_err(|e| { + IllegalArgument { + message: format!("Failed to create RecordBatch from list values: {e}"), + } + })?, + )); + + for i in 0..num_elements { + let mut row = element_row.clone(); + row.set_row_id(i); + if row.is_null_at(0)? 
{ + writer.set_null_at(i); + } else { + write_arrow_value_to_fluss_array(&row, 0, &element_fluss_type, &mut writer, i)?; + } + } + + writer.complete() + } +} + +fn write_arrow_value_to_fluss_array( + row: &ColumnarRow, + col: usize, + element_type: &crate::metadata::DataType, + writer: &mut crate::row::binary_array::FlussArrayWriter, + pos: usize, +) -> Result<()> { + use crate::metadata::DataType; + + match element_type { + DataType::Boolean(_) => writer.write_boolean(pos, row.get_boolean(col)?), + DataType::TinyInt(_) => writer.write_byte(pos, row.get_byte(col)?), + DataType::SmallInt(_) => writer.write_short(pos, row.get_short(col)?), + DataType::Int(_) => writer.write_int(pos, row.get_int(col)?), + DataType::BigInt(_) => writer.write_long(pos, row.get_long(col)?), + DataType::Float(_) => writer.write_float(pos, row.get_float(col)?), + DataType::Double(_) => writer.write_double(pos, row.get_double(col)?), + DataType::Char(t) => writer.write_string(pos, row.get_char(col, t.length() as usize)?), + DataType::String(_) => writer.write_string(pos, row.get_string(col)?), + DataType::Binary(t) => writer.write_binary_bytes(pos, row.get_binary(col, t.length())?), + DataType::Bytes(_) => writer.write_binary_bytes(pos, row.get_bytes(col)?), + DataType::Decimal(dt) => { + let d = row.get_decimal(col, dt.precision() as usize, dt.scale() as usize)?; + writer.write_decimal(pos, &d, dt.precision()); + } + DataType::Date(_) => writer.write_date(pos, row.get_date(col)?), + DataType::Time(_) => writer.write_time(pos, row.get_time(col)?), + DataType::Timestamp(t) => { + let ts = row.get_timestamp_ntz(col, t.precision())?; + writer.write_timestamp_ntz(pos, &ts, t.precision()); + } + DataType::TimestampLTz(t) => { + let ts = row.get_timestamp_ltz(col, t.precision())?; + writer.write_timestamp_ltz(pos, &ts, t.precision()); + } + DataType::Array(_) => { + let nested = row.get_array(col)?; + writer.write_array(pos, &nested); + } + _ => { + return Err(IllegalArgument { + message: 
format!( + "Unsupported element type for Arrow → FlussArray conversion: {element_type:?}" + ), + }); + } + } + Ok(()) } #[cfg(test)] mod tests { use super::*; use arrow::array::{ - BinaryArray, BooleanArray, Decimal128Array, Float32Array, Float64Array, Int8Array, - Int16Array, Int32Array, Int64Array, StringArray, + ArrayRef, BinaryArray, BooleanArray, Decimal128Array, Float32Array, Float64Array, + Int8Array, Int16Array, Int32Array, Int32Builder, Int64Array, ListBuilder, StringArray, + UInt32Builder, }; use arrow::datatypes::{DataType, Field, Schema}; + fn single_column_row(array: ArrayRef) -> ColumnarRow { + let batch = + RecordBatch::try_from_iter(vec![("arr", array)]).expect("record batch with one column"); + ColumnarRow::new(Arc::new(batch)) + } + #[test] fn columnar_row_reads_values() { let schema = Arc::new(Schema::new(vec![ @@ -533,4 +631,96 @@ mod tests { .unwrap() ); } + + #[test] + fn columnar_row_get_array_int_roundtrip() { + let mut builder = ListBuilder::new(Int32Builder::new()); + builder.values().append_value(1); + builder.values().append_value(2); + builder.values().append_value(3); + builder.append(true); + let array = Arc::new(builder.finish()) as ArrayRef; + + let row = single_column_row(array); + let arr = row.get_array(0).unwrap(); + assert_eq!(arr.size(), 3); + assert_eq!(arr.get_int(0), 1); + assert_eq!(arr.get_int(1), 2); + assert_eq!(arr.get_int(2), 3); + } + + #[test] + fn columnar_row_get_array_with_nulls() { + let mut builder = ListBuilder::new(Int32Builder::new()); + builder.values().append_value(1); + builder.values().append_null(); + builder.values().append_value(3); + builder.append(true); + let array = Arc::new(builder.finish()) as ArrayRef; + + let row = single_column_row(array); + let arr = row.get_array(0).unwrap(); + assert_eq!(arr.size(), 3); + assert_eq!(arr.get_int(0), 1); + assert!(arr.is_null_at(1)); + assert_eq!(arr.get_int(2), 3); + } + + #[test] + fn columnar_row_get_array_nested_array() { + let mut outer = 
ListBuilder::new(ListBuilder::new(Int32Builder::new())); + + // first nested array: [1, 2] + outer.values().values().append_value(1); + outer.values().values().append_value(2); + outer.values().append(true); + + // second nested array: [99] + outer.values().values().append_value(99); + outer.values().append(true); + + // one row containing two nested arrays + outer.append(true); + let array = Arc::new(outer.finish()) as ArrayRef; + + let row = single_column_row(array); + let arr = row.get_array(0).unwrap(); + assert_eq!(arr.size(), 2); + + let nested0 = arr.get_array(0).unwrap(); + assert_eq!(nested0.size(), 2); + assert_eq!(nested0.get_int(0), 1); + assert_eq!(nested0.get_int(1), 2); + + let nested1 = arr.get_array(1).unwrap(); + assert_eq!(nested1.size(), 1); + assert_eq!(nested1.get_int(0), 99); + } + + #[test] + fn columnar_row_get_array_non_list_column_returns_error() { + let array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let row = single_column_row(array); + let err = row.get_array(0).unwrap_err(); + assert!( + err.to_string().contains("expected List array"), + "unexpected error: {err}" + ); + } + + #[test] + fn columnar_row_get_array_unsupported_element_type_returns_error() { + let mut builder = ListBuilder::new(UInt32Builder::new()); + builder.values().append_value(7); + builder.append(true); + let array = Arc::new(builder.finish()) as ArrayRef; + + let row = single_column_row(array); + let err = row.get_array(0).unwrap_err(); + assert!( + err.to_string() + .contains("Cannot convert Arrow type to Fluss type"), + "unexpected error: {err}" + ); + } } diff --git a/crates/fluss/src/row/compacted/compacted_key_writer.rs b/crates/fluss/src/row/compacted/compacted_key_writer.rs index 339e3661..47d6853e 100644 --- a/crates/fluss/src/row/compacted/compacted_key_writer.rs +++ b/crates/fluss/src/row/compacted/compacted_key_writer.rs @@ -47,6 +47,17 @@ impl CompactedKeyWriter { } pub fn create_value_writer(field_type: &DataType) -> Result { + // Key 
columns are scalar-only. We reject Array/Map/Row explicitly + // here, so future complex-type writer support does not + // silently widen key semantics. + if matches!( + field_type, + DataType::Array(_) | DataType::Map(_) | DataType::Row(_) + ) { + return Err(crate::error::Error::IllegalArgument { + message: format!("Cannot use {field_type:?} as a key column type"), + }); + } ValueWriter::create_value_writer(field_type, Some(&BinaryRowFormat::Compacted)) } @@ -101,6 +112,8 @@ impl BinaryWriter for CompactedKeyWriter { fn write_timestamp_ntz(&mut self, value: &crate::row::datum::TimestampNtz, precision: u32); fn write_timestamp_ltz(&mut self, value: &crate::row::datum::TimestampLtz, precision: u32); + + fn write_array(&mut self, value: &[u8]); } } diff --git a/crates/fluss/src/row/compacted/compacted_row.rs b/crates/fluss/src/row/compacted/compacted_row.rs index 918ebdfd..12a05a0c 100644 --- a/crates/fluss/src/row/compacted/compacted_row.rs +++ b/crates/fluss/src/row/compacted/compacted_row.rs @@ -29,7 +29,7 @@ use std::sync::{Arc, OnceLock}; pub struct CompactedRow<'a> { arity: usize, size_in_bytes: usize, - decoded_row: OnceLock>, + decoded_row: OnceLock>>, deserializer: Arc>, reader: CompactedRowReader<'a>, data: &'a [u8], @@ -68,9 +68,16 @@ impl<'a> CompactedRow<'a> { self.size_in_bytes } - fn decoded_row(&self) -> &GenericRow<'_> { - self.decoded_row + fn decoded_row(&self) -> Result<&GenericRow<'_>> { + match self + .decoded_row .get_or_init(|| self.deserializer.deserialize(&self.reader)) + { + Ok(row) => Ok(row), + Err(err) => Err(crate::error::Error::IllegalArgument { + message: format!("Failed to deserialize compacted row: {err}"), + }), + } } pub fn as_bytes(&self) -> &[u8] { @@ -97,67 +104,71 @@ impl<'a> InternalRow for CompactedRow<'a> { } fn get_boolean(&self, pos: usize) -> Result { - self.decoded_row().get_boolean(pos) + self.decoded_row()?.get_boolean(pos) } fn get_byte(&self, pos: usize) -> Result { - self.decoded_row().get_byte(pos) + 
self.decoded_row()?.get_byte(pos) } fn get_short(&self, pos: usize) -> Result { - self.decoded_row().get_short(pos) + self.decoded_row()?.get_short(pos) } fn get_int(&self, pos: usize) -> Result { - self.decoded_row().get_int(pos) + self.decoded_row()?.get_int(pos) } fn get_long(&self, pos: usize) -> Result { - self.decoded_row().get_long(pos) + self.decoded_row()?.get_long(pos) } fn get_float(&self, pos: usize) -> Result { - self.decoded_row().get_float(pos) + self.decoded_row()?.get_float(pos) } fn get_double(&self, pos: usize) -> Result { - self.decoded_row().get_double(pos) + self.decoded_row()?.get_double(pos) } fn get_char(&self, pos: usize, length: usize) -> Result<&str> { - self.decoded_row().get_char(pos, length) + self.decoded_row()?.get_char(pos, length) } fn get_string(&self, pos: usize) -> Result<&str> { - self.decoded_row().get_string(pos) + self.decoded_row()?.get_string(pos) } fn get_decimal(&self, pos: usize, precision: usize, scale: usize) -> Result { - self.decoded_row().get_decimal(pos, precision, scale) + self.decoded_row()?.get_decimal(pos, precision, scale) } fn get_date(&self, pos: usize) -> Result { - self.decoded_row().get_date(pos) + self.decoded_row()?.get_date(pos) } fn get_time(&self, pos: usize) -> Result