1919
2020package org .apache .parquet .variant ;
2121
22-
2322import static org .assertj .core .api .AssertionsForClassTypes .assertThat ;
2423
2524import com .fasterxml .jackson .databind .JsonNode ;
3332import java .nio .file .Paths ;
3433import java .util .List ;
3534import java .util .stream .Stream ;
35+ import org .apache .avro .generic .GenericData ;
36+ import org .apache .avro .generic .GenericRecord ;
3637import org .apache .parquet .Preconditions ;
38+ import org .apache .parquet .avro .AvroParquetReader ;
3739import org .apache .parquet .io .InputFile ;
3840import org .apache .parquet .io .LocalInputFile ;
3941import org .assertj .core .api .Assertions ;
4042import org .junit .jupiter .params .ParameterizedTest ;
4143import org .junit .jupiter .params .provider .Arguments ;
4244import org .junit .jupiter .params .provider .MethodSource ;
43- import org .apache .avro .generic .GenericRecord ;
44- import org .apache .parquet .avro .AvroParquetReader ;
45-
4645
4746public class TestVariantReadsFromFile {
48- // Set this location to generated variant test cases
4947 private static final String CASE_LOCATION = null ;
5048
5149 private static Stream <JsonNode > cases () throws IOException {
@@ -62,50 +60,42 @@ private static Stream<JsonNode> cases() throws IOException {
6260 }
6361
6462 private static Stream <Arguments > errorCases () throws IOException {
65- return cases ()
66- .filter (caseNode -> caseNode .has ("error_message" ) || !caseNode .has ("parquet_file" ))
67- .map (
68- caseNode -> {
69- int caseNumber = JsonUtil .getInt ("case_number" , caseNode );
70- String testName = JsonUtil .getStringOrNull ("test" , caseNode );
71- String parquetFile = JsonUtil .getStringOrNull ("parquet_file" , caseNode );
72- String errorMessage = JsonUtil .getStringOrNull ("error_message" , caseNode );
73- return Arguments .of (caseNumber , testName , parquetFile , errorMessage );
74- });
63+ return cases ().filter (caseNode -> caseNode .has ("error_message" ) || !caseNode .has ("parquet_file" ))
64+ .map (caseNode -> {
65+ int caseNumber = JsonUtil .getInt ("case_number" , caseNode );
66+ String testName = JsonUtil .getStringOrNull ("test" , caseNode );
67+ String parquetFile = JsonUtil .getStringOrNull ("parquet_file" , caseNode );
68+ String errorMessage = JsonUtil .getStringOrNull ("error_message" , caseNode );
69+ return Arguments .of (caseNumber , testName , parquetFile , errorMessage );
70+ });
7571 }
7672
7773 private static Stream <Arguments > singleVariantCases () throws IOException {
78- return cases ()
79- .filter (caseNode -> caseNode .has ("variant_file" ) || !caseNode .has ("parquet_file" ))
80- .map (
81- caseNode -> {
82- int caseNumber = JsonUtil .getInt ("case_number" , caseNode );
83- String testName = JsonUtil .getStringOrNull ("test" , caseNode );
84- String variant = JsonUtil .getStringOrNull ("variant" , caseNode );
85- String parquetFile = JsonUtil .getStringOrNull ("parquet_file" , caseNode );
86- String variantFile = JsonUtil .getStringOrNull ("variant_file" , caseNode );
87- return Arguments .of (caseNumber , testName , variant , parquetFile , variantFile );
88- });
74+ return cases ().filter (caseNode -> caseNode .has ("variant_file" ) || !caseNode .has ("parquet_file" ))
75+ .map (caseNode -> {
76+ int caseNumber = JsonUtil .getInt ("case_number" , caseNode );
77+ String testName = JsonUtil .getStringOrNull ("test" , caseNode );
78+ String variant = JsonUtil .getStringOrNull ("variant" , caseNode );
79+ String parquetFile = JsonUtil .getStringOrNull ("parquet_file" , caseNode );
80+ String variantFile = JsonUtil .getStringOrNull ("variant_file" , caseNode );
81+ return Arguments .of (caseNumber , testName , variant , parquetFile , variantFile );
82+ });
8983 }
9084
9185 private static Stream <Arguments > multiVariantCases () throws IOException {
92- return cases ()
93- .filter (caseNode -> caseNode .has ("variant_files" ) || !caseNode .has ("parquet_file" ))
94- .map (
95- caseNode -> {
96- int caseNumber = JsonUtil .getInt ("case_number" , caseNode );
97- String testName = JsonUtil .getStringOrNull ("test" , caseNode );
98- String parquetFile = JsonUtil .getStringOrNull ("parquet_file" , caseNode );
99- List <String > variantFiles =
100- caseNode .has ("variant_files" )
101- ? Lists .newArrayList (
102- Iterables .transform (
103- caseNode .get ("variant_files" ),
104- node -> node == null || node .isNull () ? null : node .asText ()))
105- : null ;
106- String variants = JsonUtil .getStringOrNull ("variants" , caseNode );
107- return Arguments .of (caseNumber , testName , variants , parquetFile , variantFiles );
108- });
86+ return cases ().filter (caseNode -> caseNode .has ("variant_files" ) || !caseNode .has ("parquet_file" ))
87+ .map (caseNode -> {
88+ int caseNumber = JsonUtil .getInt ("case_number" , caseNode );
89+ String testName = JsonUtil .getStringOrNull ("test" , caseNode );
90+ String parquetFile = JsonUtil .getStringOrNull ("parquet_file" , caseNode );
91+ List <String > variantFiles = caseNode .has ("variant_files" )
92+ ? Lists .newArrayList (Iterables .transform (
93+ caseNode .get ("variant_files" ),
94+ node -> node == null || node .isNull () ? null : node .asText ()))
95+ : null ;
96+ String variants = JsonUtil .getStringOrNull ("variants" , caseNode );
97+ return Arguments .of (caseNumber , testName , variants , parquetFile , variantFiles );
98+ });
10999 }
110100
111101 @ ParameterizedTest
@@ -115,9 +105,8 @@ public void testError(int caseNumber, String testName, String parquetFile, Strin
115105 return ;
116106 }
117107
118- Assertions .assertThatThrownBy (() -> readParquet (parquetFile ))
119- .as ("Test case %s: %s" , caseNumber , testName );
120- //.hasMessage(errorMessage);
108+ Assertions .assertThatThrownBy (() -> readParquet (parquetFile )).as ("Test case %s: %s" , caseNumber , testName );
109+ // .hasMessage(errorMessage);
121110 }
122111
123112 @ ParameterizedTest
@@ -132,19 +121,15 @@ public void testSingleVariant(
132121 Variant expected = readVariant (variantFile );
133122
134123 GenericRecord record = readParquetRecord (parquetFile );
135- Assertions .assertThat (record .get ("var" )).isInstanceOf (Variant .class );
136- Variant actual = (Variant ) record .get ("var" );
124+ Assertions .assertThat (record .get ("var" )).isInstanceOf (GenericData . Record .class );
125+ GenericData . Record actual = (GenericData . Record ) record .get ("var" );
137126 assertEqual (expected , actual );
138127 }
139128
140129 @ ParameterizedTest
141130 @ MethodSource ("multiVariantCases" )
142131 public void testMultiVariant (
143- int caseNumber ,
144- String testName ,
145- String variants ,
146- String parquetFile ,
147- List <String > variantFiles )
132+ int caseNumber , String testName , String variants , String parquetFile , List <String > variantFiles )
148133 throws IOException {
149134 if (parquetFile == null ) {
150135 return ;
@@ -158,8 +143,8 @@ public void testMultiVariant(
158143
159144 if (variantFile != null ) {
160145 Variant expected = readVariant (variantFile );
161- Assertions .assertThat (record .get ("var" )).isInstanceOf (Variant .class );
162- Variant actual = (Variant ) record .get ("var" );
146+ Assertions .assertThat (record .get ("var" )).isInstanceOf (GenericData . Record .class );
147+ GenericData . Record actual = (GenericData . Record ) record .get ("var" );
163148 assertEqual (expected , actual );
164149 } else {
165150 Assertions .assertThat (record .get ("var" )).isNull ();
@@ -187,7 +172,8 @@ private Variant readVariant(String variantFile) throws IOException {
187172 int dictSize = VariantUtil .readUnsigned (variantBuffer , 1 , offsetSize );
188173 int offsetListOffset = 1 + offsetSize ;
189174 int dataOffset = offsetListOffset + ((1 + dictSize ) * offsetSize );
190- int endOffset = dataOffset + VariantUtil .readUnsigned (variantBuffer , offsetListOffset + (offsetSize * dictSize ), offsetSize );
175+ int endOffset = dataOffset
176+ + VariantUtil .readUnsigned (variantBuffer , offsetListOffset + (offsetSize * dictSize ), offsetSize );
191177
192178 return new Variant (VariantUtil .slice (variantBuffer , endOffset ), variantBuffer );
193179 }
@@ -200,7 +186,8 @@ private GenericRecord readParquetRecord(String parquetFile) throws IOException {
200186 private List <GenericRecord > readParquet (String parquetFile ) throws IOException {
201187 org .apache .parquet .io .InputFile inputFile = new LocalInputFile (Paths .get (CASE_LOCATION + "/" + parquetFile ));
202188 List <GenericRecord > records = Lists .newArrayList ();
203- try (org .apache .parquet .hadoop .ParquetReader <GenericRecord > reader = AvroParquetReader .<GenericRecord >builder (inputFile ).build ()) {
189+ try (org .apache .parquet .hadoop .ParquetReader <GenericRecord > reader =
190+ AvroParquetReader .<GenericRecord >builder (inputFile ).build ()) {
204191 GenericRecord record ;
205192 while ((record = reader .read ()) != null ) {
206193 records .add (record );
@@ -209,9 +196,18 @@ private List<GenericRecord> readParquet(String parquetFile) throws IOException {
209196 return records ;
210197 }
211198
199+ private static void assertEqual (Variant expected , GenericData .Record actualRecord ) {
200+ assertThat (actualRecord ).isNotNull ();
201+ assertThat (expected ).isNotNull ();
202+ Variant actual = new Variant ((ByteBuffer ) actualRecord .get ("value" ), (ByteBuffer ) actualRecord .get ("metadata" ));
203+
204+ assertEqual (expected , actual );
205+ }
206+
212207 private static void assertEqual (Variant expected , Variant actual ) {
213208 assertThat (actual ).isNotNull ();
214209 assertThat (expected ).isNotNull ();
210+
215211 assertThat (actual .getType ()).isEqualTo (expected .getType ());
216212
217213 switch (expected .getType ()) {
@@ -266,13 +262,13 @@ private static void assertEqual(Variant expected, Variant actual) {
266262 Variant .ObjectField actualField = actual .getFieldAtIndex (i );
267263
268264 assertThat (actualField .key ).isEqualTo (expectedField .key );
269- assertEqual (actualField .value , actualField .value );
265+ assertEqual (expectedField .value , actualField .value );
270266 }
271267 break ;
272268 case ARRAY :
273269 assertThat (actual .numArrayElements ()).isEqualTo (expected .numArrayElements ());
274270 for (int i = 0 ; i < expected .numArrayElements (); ++i ) {
275- assertEqual (expected .getElementAtIndex (i ),actual .getElementAtIndex (i ));
271+ assertEqual (expected .getElementAtIndex (i ), actual .getElementAtIndex (i ));
276272 }
277273 break ;
278274 default :
0 commit comments