diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedArrowReader.java b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedArrowReader.java index 2cc7cde4541a..e9ebed2826f4 100644 --- a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedArrowReader.java +++ b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedArrowReader.java @@ -588,10 +588,18 @@ public Optional visit( int bitWidth = intLogicalType.getBitWidth(); if (bitWidth == 8 || bitWidth == 16 || bitWidth == 32) { + // Iceberg has no unsigned integer type. Reading UINT32 into a 32-bit signed value would + // silently produce negative results for inputs above Integer.MAX_VALUE. UINT8 and UINT16 + // both fit losslessly in a signed int32 and are allowed, matching the policy in + // BaseParquetReaders for the non-vectorized path. + Preconditions.checkArgument( + intLogicalType.isSigned() || bitWidth < 32, "Cannot read UINT32 as an int value"); ((IntVector) vector).allocateNew(batchSize); return Optional.of( new LogicalTypeVisitorResult(vector, ReadType.INT, (int) IntVector.TYPE_WIDTH)); } else if (bitWidth == 64) { + Preconditions.checkArgument( + intLogicalType.isSigned(), "Cannot read UINT64 as a long value"); ((BigIntVector) vector).allocateNew(batchSize); return Optional.of( new LogicalTypeVisitorResult(vector, ReadType.LONG, (int) BigIntVector.TYPE_WIDTH)); diff --git a/arrow/src/test/java/org/apache/iceberg/arrow/vectorized/TestArrowReader.java b/arrow/src/test/java/org/apache/iceberg/arrow/vectorized/TestArrowReader.java index 34e83de15207..e5412317ea33 100644 --- a/arrow/src/test/java/org/apache/iceberg/arrow/vectorized/TestArrowReader.java +++ b/arrow/src/test/java/org/apache/iceberg/arrow/vectorized/TestArrowReader.java @@ -21,6 +21,7 @@ import static org.apache.iceberg.Files.localInput; import static org.apache.parquet.schema.Types.primitive; import static org.assertj.core.api.Assertions.assertThat; +import static 
org.assertj.core.api.Assertions.assertThatThrownBy;
 
 import java.io.File;
 import java.io.IOException;
@@ -41,6 +42,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.function.BiFunction;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.DateDayVector;
@@ -101,6 +103,9 @@
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.io.TempDir;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 
 /**
  * Test cases for {@link ArrowReader}.
@@ -383,6 +388,133 @@ public void testTimestampMillisAreReadCorrectly() throws Exception {
     assertThat(totalRowsRead).as("Should read all rows").isEqualTo(millisValues.size());
   }
 
+  /**
+   * Cases the vectorized reader must reject. Each entry is: unsigned bit width, Parquet physical
+   * type, Iceberg schema used to read the column, and the expected error message fragment.
+   */
+  private static Stream<Arguments> rejectedUnsignedIntegerCases() {
+    return Stream.of(
+        Arguments.of(
+            32,
+            PrimitiveType.PrimitiveTypeName.INT32,
+            new Schema(Types.NestedField.optional(1, "col", Types.IntegerType.get())),
+            "Cannot read UINT32 as an int value"),
+        Arguments.of(
+            64,
+            PrimitiveType.PrimitiveTypeName.INT64,
+            new Schema(Types.NestedField.optional(1, "col", Types.LongType.get())),
+            "Cannot read UINT64 as a long value"));
+  }
+
+  /**
+   * Writes a single-row Parquet file whose only column, {@code col} (field id 1), carries an
+   * unsigned integer annotation of the given bit width, then appends that file to {@code table}.
+   *
+   * @param table table that receives the new data file
+   * @param unsignedBitWidth bit width used for the unsigned logical type annotation
+   * @param physicalType Parquet physical type backing the column (INT32 or INT64)
+   * @param value value stored in the single row
+   */
+  private void appendUnsignedIntegerFile(
+      Table table, int unsignedBitWidth, PrimitiveType.PrimitiveTypeName physicalType, long value)
+      throws IOException {
+    MessageType parquetSchema =
+        new MessageType(
+            "test",
+            primitive(physicalType, Type.Repetition.OPTIONAL)
+                .as(LogicalTypeAnnotation.intType(unsignedBitWidth, false))
+                .id(1)
+                .named("col"));
+
+    File testFile =
+        new File(tempDir, "unsigned-int" + unsignedBitWidth + "-" + System.nanoTime() + ".parquet");
+    try (ParquetWriter<Group> writer =
+        ExampleParquetWriter.builder(new Path(testFile.toURI())).withType(parquetSchema).build()) {
+      Group group = new SimpleGroupFactory(parquetSchema).newGroup();
+      if (physicalType == PrimitiveType.PrimitiveTypeName.INT64) {
+        group.add("col", value);
+      } else {
+        group.add("col", Math.toIntExact(value));
+      }
+      writer.write(group);
+    }
+
+    // The file size is read only after the writer above has been closed.
+    DataFile dataFile =
+        DataFiles.builder(PartitionSpec.unpartitioned())
+            .withPath(testFile.getAbsolutePath())
+            .withFileSizeInBytes(testFile.length())
+            .withFormat(FileFormat.PARQUET)
+            .withRecordCount(1)
+            .build();
+    table.newAppend().appendFile(dataFile).commit();
+  }
+
+  /**
+   * Verifies that scanning a UINT32 or UINT64 column through the vectorized reader fails with an
+   * {@link IllegalArgumentException} carrying the expected message.
+   */
+  @ParameterizedTest
+  @MethodSource("rejectedUnsignedIntegerCases")
+  public void testUnsignedIntegerColumnThrowsException(
+      int unsignedBitWidth,
+      PrimitiveType.PrimitiveTypeName physicalType,
+      Schema schema,
+      String expectedMessage)
+      throws Exception {
+    tables = new HadoopTables();
+    Table table = tables.create(schema, tempDir.toURI() + "/uint" + unsignedBitWidth);
+    appendUnsignedIntegerFile(table, unsignedBitWidth, physicalType, 100L);
+
+    assertThatThrownBy(
+            () -> {
+              try (VectorizedTableScanIterable vectorizedReader =
+                  new VectorizedTableScanIterable(table.newScan(), 1024, false)) {
+                for (ColumnarBatch batch : vectorizedReader) {
+                  batch.createVectorSchemaRootFromVectors().close();
+                }
+              }
+            })
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessageContaining(expectedMessage);
+  }
+
+  /**
+   * Verifies that UINT8 and UINT16 columns, which fit in a signed 32-bit int, are read back
+   * unchanged by the vectorized reader.
+   */
+  @Test
+  public void testUnsignedSmallIntegerColumnRoundtrips() throws Exception {
+    tables = new HadoopTables();
+
+    // Both values lie above the signed maximum of their own width (127 and 32767).
+    for (int[] spec : new int[][] {{8, 250}, {16, 50000}}) {
+      int unsignedBitWidth = spec[0];
+      int value = spec[1];
+
+      Schema schema = new Schema(Types.NestedField.optional(1, "col", Types.IntegerType.get()));
+      Table table = tables.create(schema, tempDir.toURI() + "/uint" + unsignedBitWidth);
+      appendUnsignedIntegerFile(
+          table, unsignedBitWidth, PrimitiveType.PrimitiveTypeName.INT32, value);
+
+      int totalRows = 0;
+      try (VectorizedTableScanIterable vectorizedReader =
+          new VectorizedTableScanIterable(table.newScan(), 1024, false)) {
+        for (ColumnarBatch batch : vectorizedReader) {
+          // try-with-resources closes the root even when an assertion below fails
+          try (VectorSchemaRoot root = batch.createVectorSchemaRootFromVectors()) {
+            assertThat(((IntVector) root.getVector("col")).get(0))
+                .as("UINT%d value should round-trip through int", unsignedBitWidth)
+                .isEqualTo(value);
+            totalRows += root.getRowCount();
+          }
+        }
+      }
+
+      assertThat(totalRows).as("Should read the single UINT%d row", unsignedBitWidth).isEqualTo(1);
+    }
+  }
+
   /**
    * Run the following verifications:
    *