Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,18 @@ public Optional<LogicalTypeVisitorResult> visit(
int bitWidth = intLogicalType.getBitWidth();

if (bitWidth == 8 || bitWidth == 16 || bitWidth == 32) {
// Iceberg has no unsigned integer type. Reading UINT32 into a 32-bit signed value would
// silently produce negative results for inputs above Integer.MAX_VALUE. UINT8 and UINT16
// both fit losslessly in a signed int32 and are allowed, matching the policy in
// BaseParquetReaders for the non-vectorized path.
Preconditions.checkArgument(
intLogicalType.isSigned() || bitWidth < 32, "Cannot read UINT32 as an int value");
((IntVector) vector).allocateNew(batchSize);
return Optional.of(
new LogicalTypeVisitorResult(vector, ReadType.INT, (int) IntVector.TYPE_WIDTH));
} else if (bitWidth == 64) {
Preconditions.checkArgument(
intLogicalType.isSigned(), "Cannot read UINT64 as a long value");
((BigIntVector) vector).allocateNew(batchSize);
return Optional.of(
new LogicalTypeVisitorResult(vector, ReadType.LONG, (int) BigIntVector.TYPE_WIDTH));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.apache.iceberg.Files.localInput;
import static org.apache.parquet.schema.Types.primitive;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.io.File;
import java.io.IOException;
Expand All @@ -41,6 +42,7 @@
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
Expand Down Expand Up @@ -101,6 +103,9 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

/**
* Test cases for {@link ArrowReader}.
Expand Down Expand Up @@ -383,6 +388,133 @@ public void testTimestampMillisAreReadCorrectly() throws Exception {
assertThat(totalRowsRead).as("Should read all rows").isEqualTo(millisValues.size());
}

// Parameter source for the unsigned-rejection test: each case carries the unsigned bit
// width, the Parquet physical type used to store it, the Iceberg read schema, and the
// error message the vectorized reader is expected to raise.
private static Stream<Arguments> rejectedUnsignedIntegerCases() {
  Arguments uint32Case =
      Arguments.of(
          32,
          PrimitiveType.PrimitiveTypeName.INT32,
          new Schema(Types.NestedField.optional(1, "col", Types.IntegerType.get())),
          "Cannot read UINT32 as an int value");
  Arguments uint64Case =
      Arguments.of(
          64,
          PrimitiveType.PrimitiveTypeName.INT64,
          new Schema(Types.NestedField.optional(1, "col", Types.LongType.get())),
          "Cannot read UINT64 as a long value");
  return Stream.of(uint32Case, uint64Case);
}

// Verifies the vectorized Arrow reader rejects unsigned Parquet integer widths that cannot
// be represented losslessly in Iceberg's signed types: UINT32 (does not fit in int32) and
// UINT64 (does not fit in int64). Expects an IllegalArgumentException with the message
// produced by the reader's Preconditions check.
@ParameterizedTest
@MethodSource("rejectedUnsignedIntegerCases")
public void testUnsignedIntegerColumnThrowsException(
int unsignedBitWidth,
PrimitiveType.PrimitiveTypeName physicalType,
Schema schema,
String expectedMessage)
throws Exception {
tables = new HadoopTables();
Table table = tables.create(schema, tempDir.toURI() + "/uint" + unsignedBitWidth);

// Build a Parquet schema whose column carries the unsigned (isSigned=false) int logical
// type annotation. Field id 1 matches the Iceberg schema's column id above.
MessageType parquetSchema =
new MessageType(
"test",
primitive(physicalType, Type.Repetition.OPTIONAL)
.as(LogicalTypeAnnotation.intType(unsignedBitWidth, false))
.id(1)
.named("col"));

// Write one row with parquet-mr's example writer directly — presumably because Iceberg's
// own writers never emit unsigned annotations, so the file must be crafted by hand.
File testFile =
new File(tempDir, "unsigned-int" + unsignedBitWidth + "-" + System.nanoTime() + ".parquet");
try (ParquetWriter<Group> writer =
ExampleParquetWriter.builder(new Path(testFile.toURI())).withType(parquetSchema).build()) {
SimpleGroupFactory factory = new SimpleGroupFactory(parquetSchema);
Group group = factory.newGroup();
// INT64-backed columns need a long value; the narrower widths take an int.
if (unsignedBitWidth == 64) {
group.add("col", 100L);
} else {
group.add("col", 100);
}
writer.write(group);
}

// Register the hand-written file with the table so a scan will pick it up.
DataFile dataFile =
DataFiles.builder(PartitionSpec.unpartitioned())
.withPath(testFile.getAbsolutePath())
.withFileSizeInBytes(testFile.length())
.withFormat(FileFormat.PARQUET)
.withRecordCount(1)
.build();
table.newAppend().appendFile(dataFile).commit();

// Iterating the vectorized scan must fail; the root is closed immediately in case a
// batch is ever produced before the failure.
assertThatThrownBy(
() -> {
try (VectorizedTableScanIterable vectorizedReader =
new VectorizedTableScanIterable(table.newScan(), 1024, false)) {
for (ColumnarBatch batch : vectorizedReader) {
batch.createVectorSchemaRootFromVectors().close();
}
}
})
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining(expectedMessage);
}

/**
 * Verifies that unsigned Parquet integer widths that fit losslessly in a signed int32 (UINT8
 * and UINT16) are accepted by the vectorized Arrow reader, using values above the signed
 * maximum for each width (250 > Byte.MAX_VALUE, 50000 > Short.MAX_VALUE) to prove the value
 * round-trips without sign corruption.
 */
@Test
public void testUnsignedSmallIntegerColumnRoundtrips() throws Exception {
  tables = new HadoopTables();

  // {unsigned bit width, test value above that width's signed maximum}
  for (int[] spec : new int[][] {{8, 250}, {16, 50000}}) {
    int unsignedBitWidth = spec[0];
    int value = spec[1];

    Schema schema = new Schema(Types.NestedField.optional(1, "col", Types.IntegerType.get()));
    Table table = tables.create(schema, tempDir.toURI() + "/uint" + unsignedBitWidth);

    // Parquet schema with the unsigned (isSigned=false) annotation; field id 1 matches the
    // Iceberg column id above.
    MessageType parquetSchema =
        new MessageType(
            "test",
            primitive(PrimitiveType.PrimitiveTypeName.INT32, Type.Repetition.OPTIONAL)
                .as(LogicalTypeAnnotation.intType(unsignedBitWidth, false))
                .id(1)
                .named("col"));

    // Write one row with parquet-mr directly so the column actually carries the unsigned
    // annotation in the file footer.
    File testFile =
        new File(
            tempDir, "unsigned-int" + unsignedBitWidth + "-" + System.nanoTime() + ".parquet");
    try (ParquetWriter<Group> writer =
        ExampleParquetWriter.builder(new Path(testFile.toURI()))
            .withType(parquetSchema)
            .build()) {
      SimpleGroupFactory factory = new SimpleGroupFactory(parquetSchema);
      Group group = factory.newGroup();
      group.add("col", value);
      writer.write(group);
    }

    // Register the file with the table so the scan below reads it.
    DataFile dataFile =
        DataFiles.builder(PartitionSpec.unpartitioned())
            .withPath(testFile.getAbsolutePath())
            .withFileSizeInBytes(testFile.length())
            .withFormat(FileFormat.PARQUET)
            .withRecordCount(1)
            .build();
    table.newAppend().appendFile(dataFile).commit();

    int totalRows = 0;
    try (VectorizedTableScanIterable vectorizedReader =
        new VectorizedTableScanIterable(table.newScan(), 1024, false)) {
      for (ColumnarBatch batch : vectorizedReader) {
        // try-with-resources: VectorSchemaRoot is AutoCloseable, and the original manual
        // close() was skipped (leaking Arrow buffers) whenever the assertion failed.
        try (VectorSchemaRoot root = batch.createVectorSchemaRootFromVectors()) {
          assertThat(((IntVector) root.getVector("col")).get(0))
              .as("UINT%d value should round-trip through int", unsignedBitWidth)
              .isEqualTo(value);
          totalRows += root.getRowCount();
        }
      }
    }

    assertThat(totalRows).isEqualTo(1);
  }
}

/**
* Run the following verifications:
*
Expand Down
Loading