@@ -22,18 +22,16 @@ import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
 import java.util.Locale
 
-import scala.collection.JavaConverters._
-
 import com.google.common.io.Files
 import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.vector.{NullableIntVector, VectorLoader, VectorSchemaRoot}
+import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
 import org.apache.arrow.vector.file.json.JsonFileReader
 import org.apache.arrow.vector.util.Validator
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -1633,39 +1631,29 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
   }
 
   test("roundtrip payloads") {
-    val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
-    val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true)
-      .createVector(allocator).asInstanceOf[NullableIntVector]
-    vector.allocateNew()
-    val mutator = vector.getMutator()
-
-    (0 until 10).foreach { i =>
-      mutator.setSafe(i, i)
-    }
-    mutator.setNull(10)
-    mutator.setValueCount(11)
+    val inputRows = (0 until 9).map { i =>
+      InternalRow(i)
+    } :+ InternalRow(null)
 
-    val schema = StructType(Seq(StructField("int", IntegerType)))
-
-    val batch = new ColumnarBatch(schema, Array[ColumnVector](new ArrowColumnVector(vector)), 11)
-    batch.setNumRows(11)
+    val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
 
     val ctx = TaskContext.empty()
-    val payloadIter = ArrowConverters.toPayloadIterator(batch.rowIterator().asScala, schema, 0, ctx)
-    val rowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
+    val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx)
+    val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
 
-    assert(schema.equals(rowIter.schema))
+    assert(schema.equals(outputRowIter.schema))
 
-    rowIter.zipWithIndex.foreach { case (row, i) =>
-      if (i == 10) {
-        assert(row.isNullAt(0))
-      } else {
+    var count = 0
+    outputRowIter.zipWithIndex.foreach { case (row, i) =>
+      if (i != 9) {
         assert(row.getInt(0) == i)
+      } else {
+        assert(row.isNullAt(0))
       }
+      count += 1
     }
 
-    vector.close()
-    allocator.close()
+    assert(count == inputRows.length)
   }
 
   /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */