DictionaryFilter.java
@@ -29,6 +29,7 @@
import org.apache.parquet.filter2.predicate.UserDefinedPredicate;
import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData;
import org.apache.parquet.hadoop.metadata.ColumnPath;
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -39,8 +40,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.function.IntFunction;

-import static org.apache.parquet.Preconditions.checkArgument;
import static org.apache.parquet.Preconditions.checkNotNull;


@@ -86,26 +87,36 @@ private <T extends Comparable<T>> Set<T> expandDictionary(ColumnChunkMetaData me

Dictionary dict = page.getEncoding().initDictionary(col, page);

-    Set dictSet = new HashSet<T>();
-
-    for (int i=0; i<=dict.getMaxId(); i++) {
-      switch(meta.getType()) {
-        case BINARY: dictSet.add(dict.decodeToBinary(i));
-          break;
-        case INT32: dictSet.add(dict.decodeToInt(i));
-          break;
-        case INT64: dictSet.add(dict.decodeToLong(i));
-          break;
-        case FLOAT: dictSet.add(dict.decodeToFloat(i));
-          break;
-        case DOUBLE: dictSet.add(dict.decodeToDouble(i));
-          break;
-        default:
-          LOG.warn("Unknown dictionary type{}", meta.getType());
-      }
-    }
+    IntFunction<Object> dictValueProvider;
+    PrimitiveTypeName type = meta.getPrimitiveType().getPrimitiveTypeName();
+    switch (type) {
+      case FIXED_LEN_BYTE_ARRAY: // Same as BINARY
+      case BINARY:
+        dictValueProvider = dict::decodeToBinary;
+        break;
+      case INT32:
+        dictValueProvider = dict::decodeToInt;
+        break;
+      case INT64:
+        dictValueProvider = dict::decodeToLong;
+        break;
+      case FLOAT:
+        dictValueProvider = dict::decodeToFloat;
+        break;
+      case DOUBLE:
+        dictValueProvider = dict::decodeToDouble;
+        break;
+      default:
+        LOG.warn("Unsupported dictionary type: {}", type);
+        return null;
+    }

-    return (Set<T>) dictSet;
+    Set<T> dictSet = new HashSet<>();
+    for (int i = 0; i <= dict.getMaxId(); i++) {
+      dictSet.add((T) dictValueProvider.apply(i));
+    }
+
+    return dictSet;
}
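For context, a minimal sketch of how a predicate visitor can consume the set returned by expandDictionary(); getMetadata() and hasNonDictionaryPages() are hypothetical stand-ins for DictionaryFilter's own lookups, not code from this PR. The important contract is that a null return (unsupported type, such as INT96) must be read as "cannot decide", so the row group is kept:

    // Hedged sketch with hypothetical helpers, not the PR's verbatim code
    public <T extends Comparable<T>> boolean canDropForEq(Operators.Eq<T> pred) throws IOException {
      ColumnChunkMetaData meta = getMetadata(pred.getColumn().getColumnPath()); // hypothetical lookup
      Set<T> dictSet = expandDictionary(meta);
      if (dictSet == null || hasNonDictionaryPages(meta)) { // hypothetical encoding check
        return false; // no usable dictionary: the block might match, keep it
      }
      return !dictSet.contains(pred.getValue()); // value absent from all pages: safe to drop
    }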

@Override
DictionaryFilterTest.java
@@ -26,6 +26,8 @@
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.column.Encoding;
+import org.apache.parquet.column.EncodingStats;
+import org.apache.parquet.column.ParquetProperties.WriterVersion;
import org.apache.parquet.column.page.DictionaryPageReadStore;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.simple.SimpleGroupFactory;
@@ -51,16 +53,21 @@
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;

import java.io.IOException;
import java.io.Serializable;
+import java.math.BigInteger;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;

import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_1_0;
+import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_2_0;
import static org.apache.parquet.filter2.dictionarylevel.DictionaryFilter.canDrop;
import static org.apache.parquet.filter2.predicate.FilterApi.*;
import static org.apache.parquet.hadoop.metadata.CompressionCodecName.GZIP;
@@ -70,21 +77,25 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyZeroInteractions;

+@RunWith(Parameterized.class)
public class DictionaryFilterTest {

private static final int nElements = 1000;
private static final Configuration conf = new Configuration();
-  private static Path file = new Path("target/test/TestDictionaryFilter/testParquetFile");
+  private static final Path FILE_V1 = new Path("target/test/TestDictionaryFilter/testParquetFileV1.parquet");
+  private static final Path FILE_V2 = new Path("target/test/TestDictionaryFilter/testParquetFileV2.parquet");
private static final MessageType schema = parseMessageType(
"message test { "
+ "required binary binary_field; "
+ "required binary single_value_field; "
+ "required fixed_len_byte_array(17) fixed_field (DECIMAL(40,4)); "
+ "required int32 int32_field; "
+ "required int64 int64_field; "
+ "required double double_field; "
+ "required float float_field; "
+ "required int32 plain_int32_field; "
+ "required binary fallback_binary_field; "
+ "required int96 int96_field; "
+ "} ");

private static final String ALPHABET = "abcdefghijklmnopqrstuvwxyz";
@@ -96,6 +107,46 @@ public class DictionaryFilterTest {
-100L, 302L, 3333333L, 7654321L, 1234567L, -2000L, -77775L, 0L,
75L, 22223L, 77L, 22221L, -444443L, 205L, 12L, 44444L, 889L, 66665L,
-777889L, -7L, 52L, 33L, -257L, 1111L, 775L, 26L};
+  private static final Binary[] DECIMAL_VALUES = new Binary[] {
+      toBinary("-9999999999999999999999999999999999999999", 17),
+      toBinary("-9999999999999999999999999999999999999998", 17),
+      toBinary(BigInteger.valueOf(Long.MIN_VALUE).subtract(BigInteger.ONE), 17),
+      toBinary(BigInteger.valueOf(Long.MIN_VALUE), 17),
+      toBinary(BigInteger.valueOf(Long.MIN_VALUE).add(BigInteger.ONE), 17),
+      toBinary("-1", 17),
+      toBinary("0", 17),
+      toBinary(BigInteger.valueOf(Long.MAX_VALUE).subtract(BigInteger.ONE), 17),
+      toBinary(BigInteger.valueOf(Long.MAX_VALUE), 17),
+      toBinary(BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE), 17),
+      toBinary("999999999999999999999999999999999999999", 17),
+      toBinary("9999999999999999999999999999999999999998", 17),
+      toBinary("9999999999999999999999999999999999999999", 17)
+  };
+  private static final Binary[] INT96_VALUES = new Binary[] {
+      toBinary("-9999999999999999999999999999", 12),
+      toBinary("-9999999999999999999999999998", 12),
+      toBinary("-1234567890", 12),
+      toBinary("-1", 12),
+      toBinary("-0", 12),
+      toBinary("1", 12),
+      toBinary("1234567890", 12),
toBinary("-9999999999999999999999999998", 12),
toBinary("9999999999999999999999999999", 12)
};

+  private static Binary toBinary(String decimalWithoutScale, int byteCount) {
+    return toBinary(new BigInteger(decimalWithoutScale), byteCount);
+  }

+  private static Binary toBinary(BigInteger decimalWithoutScale, int byteCount) {
+    byte[] src = decimalWithoutScale.toByteArray();
+    if (src.length > byteCount) {
+      throw new IllegalArgumentException("Too large decimal value for byte count " + byteCount);
+    }
+    byte[] dest = new byte[byteCount];
+    // Sign-extend negative values so the padding keeps the two's-complement encoding valid
+    if (decimalWithoutScale.signum() < 0) {
+      Arrays.fill(dest, 0, dest.length - src.length, (byte) 0xFF);
+    }
+    System.arraycopy(src, 0, dest, dest.length - src.length, src.length);
+    return Binary.fromConstantByteArray(dest);
+  }
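A quick illustration of the byte patterns this helper produces, worked out by hand under the sign extension noted above (these values are not taken from the PR):

    toBinary("1", 17);    // 16 bytes of 0x00 followed by 0x01
    toBinary("-1", 17);   // 17 bytes of 0xFF (sign-extended two's complement)
    toBinary("256", 17);  // 15 bytes of 0x00, then 0x01 0x00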

private static void writeData(SimpleGroupFactory f, ParquetWriter<Group> writer) throws IOException {
for (int i = 0; i < nElements; i++) {
@@ -104,13 +155,15 @@ private static void writeData(SimpleGroupFactory f, ParquetWriter<Group> writer)
Group group = f.newGroup()
.append("binary_field", ALPHABET.substring(index, index+1))
.append("single_value_field", "sharp")
.append("fixed_field", DECIMAL_VALUES[i % DECIMAL_VALUES.length])
.append("int32_field", intValues[i % intValues.length])
.append("int64_field", longValues[i % longValues.length])
.append("double_field", toDouble(intValues[i % intValues.length]))
.append("float_field", toFloat(intValues[i % intValues.length]))
.append("plain_int32_field", i)
.append("fallback_binary_field", i < (nElements / 2) ?
-            ALPHABET.substring(index, index+1) : UUID.randomUUID().toString());
+            ALPHABET.substring(index, index+1) : UUID.randomUUID().toString())
+        .append("int96_field", INT96_VALUES[i % INT96_VALUES.length]);

writer.write(group);
}
@@ -120,11 +173,15 @@ private static void writeData(SimpleGroupFactory f, ParquetWriter<Group> writer)
@BeforeClass
public static void prepareFile() throws IOException {
cleanup();
+    prepareFile(PARQUET_1_0, FILE_V1);
+    prepareFile(PARQUET_2_0, FILE_V2);
+  }
+
+  private static void prepareFile(WriterVersion version, Path file) throws IOException {
GroupWriteSupport.setSchema(schema, conf);
SimpleGroupFactory f = new SimpleGroupFactory(schema);
ParquetWriter<Group> writer = ExampleParquetWriter.builder(file)
-        .withWriterVersion(PARQUET_1_0)
+        .withWriterVersion(version)
.withCompressionCodec(GZIP)
.withRowGroupSize(1024*1024)
.withPageSize(1024)
@@ -137,16 +194,39 @@ public static void prepareFile() throws IOException {

@AfterClass
public static void cleanup() throws IOException {
+    deleteFile(FILE_V1);
+    deleteFile(FILE_V2);
+  }
+
+  private static void deleteFile(Path file) throws IOException {
FileSystem fs = file.getFileSystem(conf);
if (fs.exists(file)) {
fs.delete(file, true);
}
}

+  @Parameters
+  public static Object[] params() {
+    return new Object[] {PARQUET_1_0, PARQUET_2_0};
+  }

List<ColumnChunkMetaData> ccmd;
ParquetFileReader reader;
DictionaryPageReadStore dictionaries;
+  private Path file;
+  private WriterVersion version;

+  public DictionaryFilterTest(WriterVersion version) {
+    this.version = version;
+    switch (version) {
+      case PARQUET_1_0:
+        file = FILE_V1;
+        break;
+      case PARQUET_2_0:
+        file = FILE_V2;
+        break;
+    }
+  }
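With @RunWith(Parameterized.class), JUnit builds a fresh fixture per value returned by params(), so every @Test in this class runs once against the V1 file and once against the V2 file. Conceptually the runner does something like this (a sketch, not actual JUnit internals):

    for (Object p : params()) {                       // {PARQUET_1_0, PARQUET_2_0}
      DictionaryFilterTest test = new DictionaryFilterTest((WriterVersion) p);
      // then: setUp(), each @Test method, tearDown()
    }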

@Before
public void setUp() throws Exception {
@@ -162,25 +242,34 @@ public void tearDown() throws Exception {
}

@Test
-  @SuppressWarnings("deprecation")
public void testDictionaryEncodedColumns() throws Exception {
+    switch (version) {
+      case PARQUET_1_0:
+        testDictionaryEncodedColumnsV1();
+        break;
+      case PARQUET_2_0:
+        testDictionaryEncodedColumnsV2();
+        break;
+    }
+  }
+
+  @SuppressWarnings("deprecation")
+  private void testDictionaryEncodedColumnsV1() throws Exception {
Set<String> dictionaryEncodedColumns = new HashSet<String>(Arrays.asList(
"binary_field", "single_value_field", "int32_field", "int64_field",
"double_field", "float_field"));
"double_field", "float_field", "int96_field"));
for (ColumnChunkMetaData column : ccmd) {
String name = column.getPath().toDotString();
if (dictionaryEncodedColumns.contains(name)) {
assertTrue("Column should be dictionary encoded: " + name,
column.getEncodings().contains(Encoding.PLAIN_DICTIONARY));
assertFalse("Column should not have plain data pages" + name,
column.getEncodings().contains(Encoding.PLAIN));

} else {
assertTrue("Column should have plain encoding: " + name,
column.getEncodings().contains(Encoding.PLAIN));

if (name.startsWith("fallback")) {
assertTrue("Column should be have some dictionary encoding: " + name,
assertTrue("Column should have some dictionary encoding: " + name,
column.getEncodings().contains(Encoding.PLAIN_DICTIONARY));
} else {
assertFalse("Column should have no dictionary encoding: " + name,
@@ -190,6 +279,32 @@ public void testDictionaryEncodedColumns() throws Exception {
}
}

+  private void testDictionaryEncodedColumnsV2() throws Exception {
+    Set<String> dictionaryEncodedColumns = new HashSet<String>(Arrays.asList(
+        "binary_field", "single_value_field", "fixed_field", "int32_field",
+        "int64_field", "double_field", "float_field", "int96_field"));
+    for (ColumnChunkMetaData column : ccmd) {
+      EncodingStats encStats = column.getEncodingStats();
+      String name = column.getPath().toDotString();
+      if (dictionaryEncodedColumns.contains(name)) {
+        assertTrue("Column should have dictionary pages: " + name, encStats.hasDictionaryPages());
+        assertTrue("Column should have dictionary encoded pages: " + name, encStats.hasDictionaryEncodedPages());
+        assertFalse("Column should not have non-dictionary encoded pages: " + name,
+            encStats.hasNonDictionaryEncodedPages());
+      } else {
+        assertTrue("Column should have non-dictionary encoded pages: " + name,
+            encStats.hasNonDictionaryEncodedPages());
+        if (name.startsWith("fallback")) {
+          assertTrue("Column should have dictionary pages: " + name, encStats.hasDictionaryPages());
+          assertTrue("Column should have dictionary encoded pages: " + name, encStats.hasDictionaryEncodedPages());
+        } else {
+          assertFalse("Column should not have dictionary pages: " + name, encStats.hasDictionaryPages());
+          assertFalse("Column should not have dictionary encoded pages: " + name, encStats.hasDictionaryEncodedPages());
+        }
+      }
+    }
+  }
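The V2 variant checks EncodingStats rather than the raw encoding set because PARQUET_2_0 files encode dictionary data pages with RLE_DICTIONARY rather than PLAIN_DICTIONARY, so the V1-style PLAIN_DICTIONARY check would not match. The dictionary-only condition the assertions above express boils down to:

    EncodingStats stats = column.getEncodingStats();
    boolean dictionaryOnly = stats.hasDictionaryPages()
        && stats.hasDictionaryEncodedPages()
        && !stats.hasNonDictionaryEncodedPages(); // every data page used the dictionary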

@Test
public void testEqBinary() throws Exception {
BinaryColumn b = binaryColumn("binary_field");
@@ -205,6 +320,38 @@ public void testEqBinary() throws Exception {
canDrop(eq(b, null), ccmd, dictionaries));
}

+  @Test
+  public void testEqFixed() throws Exception {
+    BinaryColumn b = binaryColumn("fixed_field");
+
+    // Only V2 supports dictionary encoding for FIXED_LEN_BYTE_ARRAY values
+    if (version == PARQUET_2_0) {
+      assertTrue("Should drop block for -2",
+          canDrop(eq(b, toBinary("-2", 17)), ccmd, dictionaries));
+    }
+
+    assertFalse("Should not drop block for -1",
+        canDrop(eq(b, toBinary("-1", 17)), ccmd, dictionaries));
+
+    assertFalse("Should not drop block for null",
+        canDrop(eq(b, null), ccmd, dictionaries));
+  }

+  @Test
+  public void testEqInt96() throws Exception {
+    BinaryColumn b = binaryColumn("int96_field");
+
+    // INT96 ordering is undefined, so no filtering is done
+    assertFalse("Should not drop block for -2",
+        canDrop(eq(b, toBinary("-2", 12)), ccmd, dictionaries));
+
+    assertFalse("Should not drop block for -1",
+        canDrop(eq(b, toBinary("-1", 12)), ccmd, dictionaries));
+
+    assertFalse("Should not drop block for null",
+        canDrop(eq(b, null), ccmd, dictionaries));
+  }
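This expectation follows from the expandDictionary() change in the first file: the switch has no INT96 case, so the default branch logs a warning and returns null, and a null dictionary set tells the filter it cannot prove anything about the block. Roughly (paraphrased, not the PR's exact code):

    Set<Binary> dictSet = expandDictionary(meta); // null for INT96 columns
    if (dictSet == null) {
      return false; // canDrop == false: the block might match
    }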

@Test
public void testNotEqBinary() throws Exception {
BinaryColumn sharp = binaryColumn("single_value_field");
@@ -243,6 +390,20 @@ public void testLtInt() throws Exception {
canDrop(lt(i32, Integer.MAX_VALUE), ccmd, dictionaries));
}

+  @Test
+  public void testLtFixed() throws Exception {
+    BinaryColumn fixed = binaryColumn("fixed_field");
+
+    // Only V2 supports dictionary encoding for FIXED_LEN_BYTE_ARRAY values
+    if (version == PARQUET_2_0) {
+      assertTrue("Should drop: < lowest value",
+          canDrop(lt(fixed, DECIMAL_VALUES[0]), ccmd, dictionaries));
+    }
+
+    assertFalse("Should not drop: < 2nd lowest value",
+        canDrop(lt(fixed, DECIMAL_VALUES[1]), ccmd, dictionaries));
+  }
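This ordering check is also why the sign extension in toBinary matters: with plain zero padding, the negative DECIMAL_VALUES entries would encode as large positive numbers under the signed two's-complement comparator, and DECIMAL_VALUES[0] would no longer be the column's minimum.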

@Test
public void testLtEqLong() throws Exception {
LongColumn i64 = longColumn("int64_field");