
Commit 8bb0546

bersprockets authored and dongjoon-hyun committed
[SPARK-44805][SQL] getBytes/getShorts/getInts/etc. should work in a column vector that has a dictionary
Change getBytes/getShorts/getInts/getLongs/getFloats/getDoubles in `OnHeapColumnVector` and `OffHeapColumnVector` to use the dictionary, if present.

Why are the changes needed? The following query gets incorrect results:

```
drop table if exists t1;

create table t1 using parquet as
select * from values
(named_struct('f1', array(1, 2, 3), 'f2', array(1, 1, 2)))
as (value);

select cast(value as struct<f1:array<double>,f2:array<int>>) AS value from t1;

{"f1":[1.0,2.0,3.0],"f2":[0,0,0]}
```

The result should be:

```
{"f1":[1.0,2.0,3.0],"f2":[1,2,3]}
```

The cast operation copies the second array by calling `ColumnarArray#copy`, which in turn calls `ColumnarArray#toIntArray`, which in turn calls `ColumnVector#getInts` on the underlying column vector (either an `OnHeapColumnVector` or an `OffHeapColumnVector`). The implementation of `getInts` in both concrete classes assumes there is no dictionary and does not use one if it is present (in fact, it even asserts that there is no dictionary). However, in the above example, the column vector backing the second array does have a dictionary:

```
java -cp ~/github/parquet-mr/parquet-tools/target/parquet-tools-1.10.1.jar org.apache.parquet.tools.Main meta ./spark-warehouse/t1/part-00000-122fdd53-8166-407b-aec5-08e0c2845c3d-c000.snappy.parquet
...
row group 1: RC:1 TS:112 OFFSET:4
-------------------------------------------------------------------------------------------------------------------------------------------------------
value:
.f1:
..list:
...element: INT32 SNAPPY DO:0 FPO:4 SZ:47/47/1.00 VC:3 ENC:RLE,PLAIN ST:[min: 1, max: 3, num_nulls: 0]
.f2:
..list:
...element: INT32 SNAPPY DO:51 FPO:80 SZ:69/65/0.94 VC:3 ENC:RLE,PLAIN_DICTIONARY ST:[min: 1, max: 2, num_nulls: 0]
```

The same bug also occurs when field f2 is a map; this PR fixes that case as well.

Does this PR introduce any user-facing change? No, except for fixing the correctness issue.

How was this patch tested? New tests.

Was this patch authored or co-authored using generative AI tooling? No.

Closes #42850 from bersprockets/vector_oddity.

Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit fac236e)
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent a4d40e8 commit 8bb0546
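
For context, here is a minimal sketch of the failure mechanism, written in the style of the `ColumnVectorSuite` tests added below (the setup calls are borrowed from those tests; the snippet itself is illustrative and not part of the commit). A dictionary-encoded vector stores dictionary ids, so any read that bypasses the dictionary returns ids rather than values:

```scala
import org.apache.spark.sql.execution.columnar.ColumnDictionary
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.vectorized.ColumnarArray

val vector = new OnHeapColumnVector(3, IntegerType)
// The vector holds dictionary ids (0, 1, 2); the actual values live in the dictionary.
vector.setDictionary(new ColumnDictionary(Array[Int](1, 2, 3)))
vector.reserveDictionaryIds(3)
(0 until 3).foreach(i => vector.getDictionaryIds.putInt(i, i))

// Per-element reads decode through the dictionary and were always correct.
assert(vector.getInt(1) == 2)

// The cast path goes ColumnarArray#copy -> toIntArray -> getInts. Before this
// fix, getInts asserted dictionary == null and copied the raw id buffer instead.
val copied = new ColumnarArray(vector, 0, 3).copy()
assert(copied.toIntArray().sameElements(Array(1, 2, 3)))
```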

File tree

5 files changed: +186 -31 lines changed


sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java

Lines changed: 12 additions & 6 deletions
```diff
@@ -22,6 +22,8 @@
 public final class ColumnDictionary implements Dictionary {
   private int[] intDictionary;
   private long[] longDictionary;
+  private float[] floatDictionary;
+  private double[] doubleDictionary;

   public ColumnDictionary(int[] dictionary) {
     this.intDictionary = dictionary;
@@ -31,6 +33,14 @@ public ColumnDictionary(long[] dictionary) {
     this.longDictionary = dictionary;
   }

+  public ColumnDictionary(float[] dictionary) {
+    this.floatDictionary = dictionary;
+  }
+
+  public ColumnDictionary(double[] dictionary) {
+    this.doubleDictionary = dictionary;
+  }
+
   @Override
   public int decodeToInt(int id) {
     return intDictionary[id];
@@ -42,14 +52,10 @@ public long decodeToLong(int id) {
   }

   @Override
-  public float decodeToFloat(int id) {
-    throw new UnsupportedOperationException("Dictionary encoding does not support float");
-  }
+  public float decodeToFloat(int id) { return floatDictionary[id]; }

   @Override
-  public double decodeToDouble(int id) {
-    throw new UnsupportedOperationException("Dictionary encoding does not support double");
-  }
+  public double decodeToDouble(int id) { return doubleDictionary[id]; }

   @Override
   public byte[] decodeToBinary(int id) {
```
sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java

Lines changed: 43 additions & 12 deletions
```diff
@@ -210,9 +210,14 @@ public byte getByte(int rowId) {

   @Override
   public byte[] getBytes(int rowId, int count) {
-    assert(dictionary == null);
     byte[] array = new byte[count];
-    Platform.copyMemory(null, data + rowId, array, Platform.BYTE_ARRAY_OFFSET, count);
+    if (dictionary == null) {
+      Platform.copyMemory(null, data + rowId, array, Platform.BYTE_ARRAY_OFFSET, count);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getByte(rowId + i);
+      }
+    }
     return array;
   }

@@ -266,9 +271,14 @@ public short getShort(int rowId) {

   @Override
   public short[] getShorts(int rowId, int count) {
-    assert(dictionary == null);
     short[] array = new short[count];
-    Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L);
+    if (dictionary == null) {
+      Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getShort(rowId + i);
+      }
+    }
     return array;
   }

@@ -327,9 +337,14 @@ public int getInt(int rowId) {

   @Override
   public int[] getInts(int rowId, int count) {
-    assert(dictionary == null);
     int[] array = new int[count];
-    Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L);
+    if (dictionary == null) {
+      Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getInt(rowId + i);
+      }
+    }
     return array;
   }

@@ -399,9 +414,14 @@ public long getLong(int rowId) {

   @Override
   public long[] getLongs(int rowId, int count) {
-    assert(dictionary == null);
     long[] array = new long[count];
-    Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L);
+    if (dictionary == null) {
+      Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getLong(rowId + i);
+      }
+    }
     return array;
   }

@@ -458,9 +478,14 @@ public float getFloat(int rowId) {

   @Override
   public float[] getFloats(int rowId, int count) {
-    assert(dictionary == null);
     float[] array = new float[count];
-    Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L);
+    if (dictionary == null) {
+      Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getFloat(rowId + i);
+      }
+    }
     return array;
   }

@@ -518,9 +543,15 @@ public double getDouble(int rowId) {

   @Override
   public double[] getDoubles(int rowId, int count) {
-    assert(dictionary == null);
     double[] array = new double[count];
-    Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L);
+    if (dictionary == null) {
+      Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET,
+        count * 8L);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getDouble(rowId + i);
+      }
+    }
     return array;
   }
```
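
The fallback applies to the off-heap vector as well; a minimal sketch (setup borrowed from the tests below; the snippet is illustrative rather than part of the commit, and the explicit `close()` is there because off-heap vectors own native memory):

```scala
import org.apache.spark.sql.execution.columnar.ColumnDictionary
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector
import org.apache.spark.sql.types.DoubleType

val vector = new OffHeapColumnVector(3, DoubleType)
try {
  vector.setDictionary(new ColumnDictionary(Array[Double](1.5d, 2.5d, 3.5d)))
  vector.reserveDictionaryIds(3)
  (0 until 3).foreach(i => vector.getDictionaryIds.putInt(i, i))
  // The batched getter now falls back to per-element, dictionary-aware reads.
  assert(vector.getDoubles(0, 3).sameElements(Array(1.5d, 2.5d, 3.5d)))
} finally {
  vector.close() // release the native allocation
}
```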

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java

Lines changed: 42 additions & 12 deletions
```diff
@@ -208,9 +208,14 @@ public byte getByte(int rowId) {

   @Override
   public byte[] getBytes(int rowId, int count) {
-    assert(dictionary == null);
     byte[] array = new byte[count];
-    System.arraycopy(byteData, rowId, array, 0, count);
+    if (dictionary == null) {
+      System.arraycopy(byteData, rowId, array, 0, count);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getByte(rowId + i);
+      }
+    }
     return array;
   }

@@ -263,9 +268,14 @@ public short getShort(int rowId) {

   @Override
   public short[] getShorts(int rowId, int count) {
-    assert(dictionary == null);
     short[] array = new short[count];
-    System.arraycopy(shortData, rowId, array, 0, count);
+    if (dictionary == null) {
+      System.arraycopy(shortData, rowId, array, 0, count);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getShort(rowId + i);
+      }
+    }
     return array;
   }

@@ -319,9 +329,14 @@ public int getInt(int rowId) {

   @Override
   public int[] getInts(int rowId, int count) {
-    assert(dictionary == null);
     int[] array = new int[count];
-    System.arraycopy(intData, rowId, array, 0, count);
+    if (dictionary == null) {
+      System.arraycopy(intData, rowId, array, 0, count);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getInt(rowId + i);
+      }
+    }
     return array;
   }

@@ -385,9 +400,14 @@ public long getLong(int rowId) {

   @Override
   public long[] getLongs(int rowId, int count) {
-    assert(dictionary == null);
     long[] array = new long[count];
-    System.arraycopy(longData, rowId, array, 0, count);
+    if (dictionary == null) {
+      System.arraycopy(longData, rowId, array, 0, count);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getLong(rowId + i);
+      }
+    }
     return array;
   }

@@ -437,9 +457,14 @@ public float getFloat(int rowId) {

   @Override
   public float[] getFloats(int rowId, int count) {
-    assert(dictionary == null);
     float[] array = new float[count];
-    System.arraycopy(floatData, rowId, array, 0, count);
+    if (dictionary == null) {
+      System.arraycopy(floatData, rowId, array, 0, count);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getFloat(rowId + i);
+      }
+    }
     return array;
   }

@@ -491,9 +516,14 @@ public double getDouble(int rowId) {

   @Override
   public double[] getDoubles(int rowId, int count) {
-    assert(dictionary == null);
     double[] array = new double[count];
-    System.arraycopy(doubleData, rowId, array, 0, count);
+    if (dictionary == null) {
+      System.arraycopy(doubleData, rowId, array, 0, count);
+    } else {
+      for (int i = 0; i < count; i++) {
+        array[i] = getDouble(rowId + i);
+      }
+    }
     return array;
   }
```
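
A note on the shape of the fix, as far as it can be read from the two diffs above: the bulk copy (`Platform.copyMemory` off-heap, `System.arraycopy` on-heap) remains the fast path when no dictionary is set, and the per-element `getByte`/`getShort`/`getInt`/`getLong`/`getFloat`/`getDouble` loop is used only when a dictionary is present, since in that case the underlying buffer holds dictionary ids rather than values and a raw memory copy cannot produce the right result.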

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -1014,6 +1014,16 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
       checkAnswer(sql("select * from tbl"), expected)
     }
   }
+
+  test("SPARK-44805: cast of struct with two arrays") {
+    withTable("tbl") {
+      sql("create table tbl (value struct<f1:array<int>,f2:array<int>>) using parquet")
+      sql("insert into tbl values (named_struct('f1', array(1, 2, 3), 'f2', array(1, 1, 2)))")
+      val df = sql("select cast(value as struct<f1:array<double>,f2:array<int>>) AS value from tbl")
+      val expected = Row(Row(Array(1.0d, 2.0d, 3.0d), Array(1, 1, 2))) :: Nil
+      checkAnswer(df, expected)
+    }
+  }
 }

 class ParquetV1QuerySuite extends ParquetQuerySuite {
```

sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala

Lines changed: 79 additions & 1 deletion
```diff
@@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterEach

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
-import org.apache.spark.sql.execution.columnar.ColumnAccessor
+import org.apache.spark.sql.execution.columnar.{ColumnAccessor, ColumnDictionary}
 import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarArray
@@ -383,6 +383,84 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(testVector.getStruct(1).get(1, DoubleType) === 5.67)
   }

+  testVectors("SPARK-44805: getInts with dictionary", 3, IntegerType) { testVector =>
+    val dict = new ColumnDictionary(Array[Int](7, 8, 9))
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 0)
+    testVector.getDictionaryIds.putInt(1, 1)
+    testVector.getDictionaryIds.putInt(2, 2)
+
+    assert(testVector.getInts(0, 3)(0) == 7)
+    assert(testVector.getInts(0, 3)(1) == 8)
+    assert(testVector.getInts(0, 3)(2) == 9)
+  }
+
+  testVectors("SPARK-44805: getShorts with dictionary", 3, ShortType) { testVector =>
+    val dict = new ColumnDictionary(Array[Int](7, 8, 9))
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 0)
+    testVector.getDictionaryIds.putInt(1, 1)
+    testVector.getDictionaryIds.putInt(2, 2)
+
+    assert(testVector.getShorts(0, 3)(0) == 7)
+    assert(testVector.getShorts(0, 3)(1) == 8)
+    assert(testVector.getShorts(0, 3)(2) == 9)
+  }
+
+  testVectors("SPARK-44805: getBytes with dictionary", 3, ByteType) { testVector =>
+    val dict = new ColumnDictionary(Array[Int](7, 8, 9))
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 0)
+    testVector.getDictionaryIds.putInt(1, 1)
+    testVector.getDictionaryIds.putInt(2, 2)
+
+    assert(testVector.getBytes(0, 3)(0) == 7)
+    assert(testVector.getBytes(0, 3)(1) == 8)
+    assert(testVector.getBytes(0, 3)(2) == 9)
+  }
+
+  testVectors("SPARK-44805: getLongs with dictionary", 3, LongType) { testVector =>
+    val dict = new ColumnDictionary(Array[Long](2147483648L, 2147483649L, 2147483650L))
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 0)
+    testVector.getDictionaryIds.putInt(1, 1)
+    testVector.getDictionaryIds.putInt(2, 2)
+
+    assert(testVector.getLongs(0, 3)(0) == 2147483648L)
+    assert(testVector.getLongs(0, 3)(1) == 2147483649L)
+    assert(testVector.getLongs(0, 3)(2) == 2147483650L)
+  }
+
+  testVectors("SPARK-44805: getFloats with dictionary", 3, FloatType) { testVector =>
+    val dict = new ColumnDictionary(Array[Float](0.1f, 0.2f, 0.3f))
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 0)
+    testVector.getDictionaryIds.putInt(1, 1)
+    testVector.getDictionaryIds.putInt(2, 2)
+
+    assert(testVector.getFloats(0, 3)(0) == 0.1f)
+    assert(testVector.getFloats(0, 3)(1) == 0.2f)
+    assert(testVector.getFloats(0, 3)(2) == 0.3f)
+  }
+
+  testVectors("SPARK-44805: getDoubles with dictionary", 3, DoubleType) { testVector =>
+    val dict = new ColumnDictionary(Array[Double](1342.17727d, 1342.17728d, 1342.17729d))
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 0)
+    testVector.getDictionaryIds.putInt(1, 1)
+    testVector.getDictionaryIds.putInt(2, 2)
+
+    assert(testVector.getDoubles(0, 3)(0) == 1342.17727d)
+    assert(testVector.getDoubles(0, 3)(1) == 1342.17728d)
+    assert(testVector.getDoubles(0, 3)(2) == 1342.17729d)
+  }
+
   test("[SPARK-22092] off-heap column vector reallocation corrupts array data") {
     withVector(new OffHeapColumnVector(8, arrayType)) { testVector =>
       val data = testVector.arrayData()
```
