diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java
index eaba2c1cb8..ecd104327f 100644
--- a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java
+++ b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java
@@ -29,6 +29,7 @@ import org.apache.parquet.filter2.predicate.UserDefinedPredicate;
 import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData;
 import org.apache.parquet.hadoop.metadata.ColumnPath;
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -39,8 +40,8 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.IntFunction;
 
-import static org.apache.parquet.Preconditions.checkArgument;
 import static org.apache.parquet.Preconditions.checkNotNull;
 
@@ -86,26 +87,36 @@ private <T extends Comparable<T>> Set<T> expandDictionary(ColumnChunkMetaData me
 
     Dictionary dict = page.getEncoding().initDictionary(col, page);
 
-    Set dictSet = new HashSet();
-
-    for (int i=0; i<=dict.getMaxId(); i++) {
-      switch(meta.getType()) {
-        case BINARY: dictSet.add(dict.decodeToBinary(i));
-          break;
-        case INT32: dictSet.add(dict.decodeToInt(i));
-          break;
-        case INT64: dictSet.add(dict.decodeToLong(i));
-          break;
-        case FLOAT: dictSet.add(dict.decodeToFloat(i));
-          break;
-        case DOUBLE: dictSet.add(dict.decodeToDouble(i));
-          break;
-        default:
-          LOG.warn("Unknown dictionary type{}", meta.getType());
-      }
+    IntFunction<Object> dictValueProvider;
+    PrimitiveTypeName type = meta.getPrimitiveType().getPrimitiveTypeName();
+    switch (type) {
+      case FIXED_LEN_BYTE_ARRAY: // Same as BINARY
+      case BINARY:
+        dictValueProvider = dict::decodeToBinary;
+        break;
+      case INT32:
+        dictValueProvider = dict::decodeToInt;
+        break;
+      case INT64:
+        dictValueProvider = dict::decodeToLong;
+        break;
+      case FLOAT:
+        dictValueProvider = dict::decodeToFloat;
+        break;
+      case DOUBLE:
+        dictValueProvider = dict::decodeToDouble;
+        break;
+      default:
+        LOG.warn("Unsupported dictionary type: {}", type);
+        return null;
     }
 
-    return (Set<T>) dictSet;
+    Set<T> dictSet = new HashSet<>();
+    for (int i = 0; i <= dict.getMaxId(); i++) {
+      dictSet.add((T) dictValueProvider.apply(i));
+    }
+
+    return dictSet;
   }
 
   @Override
diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java
index 3883d8727f..39db6d4be8 100644
--- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java
+++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java
@@ -26,6 +26,8 @@ import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.parquet.column.Encoding;
+import org.apache.parquet.column.EncodingStats;
+import org.apache.parquet.column.ParquetProperties.WriterVersion;
 import org.apache.parquet.column.page.DictionaryPageReadStore;
 import org.apache.parquet.example.data.Group;
 import org.apache.parquet.example.data.simple.SimpleGroupFactory;
@@ -51,9 +53,13 @@ import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.math.BigInteger;
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
@@ -61,6 +67,7 @@ import java.util.UUID;
 
 import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_1_0;
+import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_2_0;
 import static org.apache.parquet.filter2.dictionarylevel.DictionaryFilter.canDrop;
 import static org.apache.parquet.filter2.predicate.FilterApi.*;
 import static org.apache.parquet.hadoop.metadata.CompressionCodecName.GZIP;
@@ -70,21 +77,25 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verifyZeroInteractions;
 
+@RunWith(Parameterized.class)
 public class DictionaryFilterTest {
 
   private static final int nElements = 1000;
   private static final Configuration conf = new Configuration();
-  private static Path file = new Path("target/test/TestDictionaryFilter/testParquetFile");
+  private static final Path FILE_V1 = new Path("target/test/TestDictionaryFilter/testParquetFileV1.parquet");
+  private static final Path FILE_V2 = new Path("target/test/TestDictionaryFilter/testParquetFileV2.parquet");
   private static final MessageType schema = parseMessageType(
       "message test { "
           + "required binary binary_field; "
           + "required binary single_value_field; "
+          + "required fixed_len_byte_array(17) fixed_field (DECIMAL(40,4)); "
           + "required int32 int32_field; "
           + "required int64 int64_field; "
           + "required double double_field; "
           + "required float float_field; "
           + "required int32 plain_int32_field; "
           + "required binary fallback_binary_field; "
+          + "required int96 int96_field; "
           + "} ");
 
   private static final String ALPHABET = "abcdefghijklmnopqrstuvwxyz";
@@ -96,6 +107,46 @@ public class DictionaryFilterTest {
       -100L, 302L, 3333333L, 7654321L, 1234567L, -2000L, -77775L, 0L, 75L,
       22223L, 77L, 22221L, -444443L, 205L, 12L, 44444L, 889L, 66665L, -777889L,
       -7L, 52L, 33L, -257L, 1111L, 775L, 26L};
+  private static final Binary[] DECIMAL_VALUES = new Binary[] {
+      toBinary("-9999999999999999999999999999999999999999", 17),
+      toBinary("-9999999999999999999999999999999999999998", 17),
+      toBinary(BigInteger.valueOf(Long.MIN_VALUE).subtract(BigInteger.ONE), 17),
+      toBinary(BigInteger.valueOf(Long.MIN_VALUE), 17),
+      toBinary(BigInteger.valueOf(Long.MIN_VALUE).add(BigInteger.ONE), 17),
+      toBinary("-1", 17),
+      toBinary("0", 17),
+      toBinary(BigInteger.valueOf(Long.MAX_VALUE).subtract(BigInteger.ONE), 17),
+      toBinary(BigInteger.valueOf(Long.MAX_VALUE), 17),
+      toBinary(BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE), 17),
+      toBinary("999999999999999999999999999999999999999", 17),
+      toBinary("9999999999999999999999999999999999999998", 17),
+      toBinary("9999999999999999999999999999999999999999", 17)
+  };
+  private static final Binary[] INT96_VALUES = new Binary[] {
+      toBinary("-9999999999999999999999999999", 12),
+      toBinary("-9999999999999999999999999998", 12),
+      toBinary("-1234567890", 12),
+      toBinary("-1", 12),
+      toBinary("-0", 12),
+      toBinary("1", 12),
+      toBinary("1234567890", 12),
+      toBinary("-9999999999999999999999999998", 12),
+      toBinary("9999999999999999999999999999", 12)
+  };
+
+  private static Binary toBinary(String decimalWithoutScale, int byteCount) {
+    return toBinary(new BigInteger(decimalWithoutScale), byteCount);
+  }
+
+  private static Binary toBinary(BigInteger decimalWithoutScale, int byteCount) {
+    byte[] src = decimalWithoutScale.toByteArray();
+    if (src.length > byteCount) {
+      throw new IllegalArgumentException("Too large decimal value for byte count " + byteCount);
+    }
+    byte[] dest = new byte[byteCount];
+    System.arraycopy(src, 0, dest, dest.length - src.length, src.length);
+    return Binary.fromConstantByteArray(dest);
+  }
 
   private static void writeData(SimpleGroupFactory f, ParquetWriter<Group> writer)
       throws IOException {
     for (int i = 0; i < nElements; i++) {
@@ -104,13 +155,15 @@ private static void writeData(SimpleGroupFactory f, ParquetWriter<Group> writer)
       Group group = f.newGroup()
           .append("binary_field", ALPHABET.substring(index, index+1))
           .append("single_value_field", "sharp")
+          .append("fixed_field", DECIMAL_VALUES[i % DECIMAL_VALUES.length])
           .append("int32_field", intValues[i % intValues.length])
           .append("int64_field", longValues[i % longValues.length])
           .append("double_field", toDouble(intValues[i % intValues.length]))
           .append("float_field", toFloat(intValues[i % intValues.length]))
           .append("plain_int32_field", i)
           .append("fallback_binary_field", i < (nElements / 2) ?
-              ALPHABET.substring(index, index+1) : UUID.randomUUID().toString());
+              ALPHABET.substring(index, index+1) : UUID.randomUUID().toString())
+          .append("int96_field", INT96_VALUES[i % INT96_VALUES.length]);
 
       writer.write(group);
     }
@@ -120,11 +173,15 @@ private static void writeData(SimpleGroupFactory f, ParquetWriter<Group> writer)
 
   @BeforeClass
   public static void prepareFile() throws IOException {
     cleanup();
+    prepareFile(PARQUET_1_0, FILE_V1);
+    prepareFile(PARQUET_2_0, FILE_V2);
+  }
+
+  private static void prepareFile(WriterVersion version, Path file) throws IOException {
     GroupWriteSupport.setSchema(schema, conf);
     SimpleGroupFactory f = new SimpleGroupFactory(schema);
     ParquetWriter<Group> writer = ExampleParquetWriter.builder(file)
-        .withWriterVersion(PARQUET_1_0)
+        .withWriterVersion(version)
         .withCompressionCodec(GZIP)
         .withRowGroupSize(1024*1024)
         .withPageSize(1024)
@@ -137,16 +194,39 @@ public static void prepareFile() throws IOException {
 
   @AfterClass
   public static void cleanup() throws IOException {
+    deleteFile(FILE_V1);
+    deleteFile(FILE_V2);
+  }
+
+  private static void deleteFile(Path file) throws IOException {
     FileSystem fs = file.getFileSystem(conf);
     if (fs.exists(file)) {
       fs.delete(file, true);
     }
   }
 
+  @Parameters
+  public static Object[] params() {
+    return new Object[] {PARQUET_1_0, PARQUET_2_0};
+  }
+
   List<ColumnChunkMetaData> ccmd;
   ParquetFileReader reader;
   DictionaryPageReadStore dictionaries;
+  private Path file;
+  private WriterVersion version;
+
+  public DictionaryFilterTest(WriterVersion version) {
+    this.version = version;
+    switch (version) {
+      case PARQUET_1_0:
+        file = FILE_V1;
+        break;
+      case PARQUET_2_0:
+        file = FILE_V2;
+        break;
+    }
+  }
 
   @Before
   public void setUp() throws Exception {
@@ -162,11 +242,22 @@ public void tearDown() throws Exception {
   }
 
   @Test
-  @SuppressWarnings("deprecation")
   public void testDictionaryEncodedColumns() throws Exception {
+    switch (version) {
+      case PARQUET_1_0:
+        testDictionaryEncodedColumnsV1();
+        break;
+      case PARQUET_2_0:
+        testDictionaryEncodedColumnsV2();
+        break;
+    }
+  }
+
+  @SuppressWarnings("deprecation")
+  private void testDictionaryEncodedColumnsV1() throws Exception {
     Set<String> dictionaryEncodedColumns = new HashSet<String>(Arrays.asList(
         "binary_field", "single_value_field", "int32_field", "int64_field",
-        "double_field", "float_field"));
+        "double_field", "float_field", "int96_field"));
     for (ColumnChunkMetaData column : ccmd) {
       String name = column.getPath().toDotString();
       if (dictionaryEncodedColumns.contains(name)) {
@@ -174,13 +265,11 @@ public void testDictionaryEncodedColumns() throws Exception {
             column.getEncodings().contains(Encoding.PLAIN_DICTIONARY));
         assertFalse("Column should not have plain data pages" + name,
             column.getEncodings().contains(Encoding.PLAIN));
-
       } else {
         assertTrue("Column should have plain encoding: " + name,
             column.getEncodings().contains(Encoding.PLAIN));
-
         if (name.startsWith("fallback")) {
-          assertTrue("Column should be have some dictionary encoding: " + name,
+          assertTrue("Column should have some dictionary encoding: " + name,
               column.getEncodings().contains(Encoding.PLAIN_DICTIONARY));
         } else {
           assertFalse("Column should have no dictionary encoding: " + name,
@@ -190,6 +279,32 @@ public void testDictionaryEncodedColumns() throws Exception {
     }
   }
 
+  private void testDictionaryEncodedColumnsV2() throws Exception {
+    Set<String> dictionaryEncodedColumns = new HashSet<String>(Arrays.asList(
+        "binary_field", "single_value_field", "fixed_field", "int32_field",
+        "int64_field", "double_field", "float_field", "int96_field"));
+    for (ColumnChunkMetaData column : ccmd) {
+      EncodingStats encStats = column.getEncodingStats();
+      String name = column.getPath().toDotString();
+      if (dictionaryEncodedColumns.contains(name)) {
+        assertTrue("Column should have dictionary pages: " + name, encStats.hasDictionaryPages());
+        assertTrue("Column should have dictionary encoded pages: " + name, encStats.hasDictionaryEncodedPages());
+        assertFalse("Column should not have non-dictionary encoded pages: " + name,
+            encStats.hasNonDictionaryEncodedPages());
+      } else {
+        assertTrue("Column should have non-dictionary encoded pages: " + name,
+            encStats.hasNonDictionaryEncodedPages());
+        if (name.startsWith("fallback")) {
+          assertTrue("Column should have dictionary pages: " + name, encStats.hasDictionaryPages());
+          assertTrue("Column should have dictionary encoded pages: " + name, encStats.hasDictionaryEncodedPages());
+        } else {
+          assertFalse("Column should not have dictionary pages: " + name, encStats.hasDictionaryPages());
+          assertFalse("Column should not have dictionary encoded pages: " + name, encStats.hasDictionaryEncodedPages());
+        }
+      }
+    }
+  }
+
   @Test
   public void testEqBinary() throws Exception {
     BinaryColumn b = binaryColumn("binary_field");
@@ -205,6 +320,38 @@ public void testEqBinary() throws Exception {
         canDrop(eq(b, null), ccmd, dictionaries));
   }
 
+  @Test
+  public void testEqFixed() throws Exception {
+    BinaryColumn b = binaryColumn("fixed_field");
+
+    // Only V2 supports dictionary encoding for FIXED_LEN_BYTE_ARRAY values
+    if (version == PARQUET_2_0) {
+      assertTrue("Should drop block for -2",
+          canDrop(eq(b, toBinary("-2", 17)), ccmd, dictionaries));
+    }
+
+    assertFalse("Should not drop block for -1",
+        canDrop(eq(b, toBinary("-1", 17)), ccmd, dictionaries));
+
+    assertFalse("Should not drop block for null",
+        canDrop(eq(b, null), ccmd, dictionaries));
+  }
+
+  @Test
+  public void testEqInt96() throws Exception {
+    BinaryColumn b = binaryColumn("int96_field");
+
+    // INT96 ordering is undefined => no filtering shall be done
+    assertFalse("Should not drop block for -2",
+        canDrop(eq(b, toBinary("-2", 12)), ccmd, dictionaries));
+
+    assertFalse("Should not drop block for -1",
+        canDrop(eq(b, toBinary("-1", 12)), ccmd, dictionaries));
+
+    assertFalse("Should not drop block for null",
+        canDrop(eq(b, null), ccmd, dictionaries));
+  }
+
   @Test
   public void testNotEqBinary() throws Exception {
     BinaryColumn sharp = binaryColumn("single_value_field");
@@ -243,6 +390,20 @@ public void testLtInt() throws Exception {
         canDrop(lt(i32, Integer.MAX_VALUE), ccmd, dictionaries));
   }
 
+  @Test
+  public void testLtFixed() throws Exception {
+    BinaryColumn fixed = binaryColumn("fixed_field");
+
+    // Only V2 supports dictionary encoding for FIXED_LEN_BYTE_ARRAY values
+    if (version == PARQUET_2_0) {
+      assertTrue("Should drop: < lowest value",
+          canDrop(lt(fixed, DECIMAL_VALUES[0]), ccmd, dictionaries));
+    }
+
+    assertFalse("Should not drop: < 2nd lowest value",
+        canDrop(lt(fixed, DECIMAL_VALUES[1]), ccmd, dictionaries));
+  }
+
   @Test
   public void testLtEqLong() throws Exception {
     LongColumn i64 = longColumn("int64_field");