Skip to content

Commit 32e48d9

Browse files
committed
Do some refactoring
1 parent 8e1a2f4 commit 32e48d9

File tree

1 file changed

+65
-95
lines changed

1 file changed

+65
-95
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala

Lines changed: 65 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class SymmetricHashJoinStateManager(
9797
keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue =>
9898
val joinedRow = generateJoinedRow(keyIdxToValue.value)
9999
if (predicate(joinedRow)) {
100-
keyWithIndexToMatched.put(key, keyIdxToValue.valueIndex, matched = true)
100+
keyWithIndexToMatched.put(key, keyIdxToValue.valueIndex, Some(true))
101101
joinedRow
102102
} else {
103103
null
@@ -110,7 +110,7 @@ class SymmetricHashJoinStateManager(
110110
val numExistingValues = keyToNumValues.get(key)
111111
keyWithIndexToValue.put(key, numExistingValues, value)
112112
keyToNumValues.put(key, numExistingValues + 1)
113-
keyWithIndexToMatched.put(key, numExistingValues, matched)
113+
keyWithIndexToMatched.put(key, numExistingValues, Some(matched))
114114
}
115115

116116
/**
@@ -129,7 +129,7 @@ class SymmetricHashJoinStateManager(
129129
private val allKeyToNumValues = keyToNumValues.iterator
130130

131131
private var currentKeyToNumValue: KeyAndNumValues = null
132-
private var currentValues: Iterator[KeyWithIndexAndValue] = null
132+
private var currentValues: Iterator[keyWithIndexToValue.KeyWithIndexAndValue] = null
133133

134134
private def currentKey = currentKeyToNumValue.key
135135

@@ -266,13 +266,9 @@ class SymmetricHashJoinStateManager(
266266
keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex)
267267
keyWithIndexToValue.remove(currentKey, numValues - 1)
268268

269-
keyWithIndexToMatched.get(currentKey, numValues - 1) match {
270-
case Some(matchedAtMaxIndex) =>
271-
keyWithIndexToMatched.put(currentKey, index, matchedAtMaxIndex)
272-
keyWithIndexToMatched.remove(currentKey, numValues - 1)
273-
274-
case None =>
275-
}
269+
val matchedAtMaxIndex = keyWithIndexToMatched.get(currentKey, numValues - 1)
270+
keyWithIndexToMatched.put(currentKey, index, matchedAtMaxIndex)
271+
keyWithIndexToMatched.remove(currentKey, numValues - 1)
276272
} else {
277273
keyWithIndexToValue.remove(currentKey, 0)
278274
keyWithIndexToMatched.remove(currentKey, 0)
@@ -337,7 +333,7 @@ class SymmetricHashJoinStateManager(
337333
joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
338334
private val keyAttributes = keySchema.toAttributes
339335
private val keyToNumValues = new KeyToNumValuesStore()
340-
private val keyWithIndexToValue = new KeyWithIndexToValueStore()
336+
private val keyWithIndexToValue = new KeyWithIndexToRowValueStore()
341337
private val keyWithIndexToMatched = new KeyWithIndexToMatchedStore()
342338

343339
// Clean up any state store resources if necessary at the end of the task
@@ -387,7 +383,6 @@ class SymmetricHashJoinStateManager(
387383
}
388384
}
389385

390-
391386
/** A wrapper around a [[StateStore]] that stores [key -> number of values]. */
392387
private class KeyToNumValuesStore extends StateStoreHandler(KeyToNumValuesType) {
393388
private val longValueSchema = new StructType().add("value", "long")
@@ -420,37 +415,41 @@ class SymmetricHashJoinStateManager(
420415
}
421416
}
422417

423-
/**
424-
* Helper class for representing data returned by [[KeyWithIndexToValueStore]].
425-
* Designed for object reuse.
426-
*/
427-
private case class KeyWithIndexAndValue(
428-
var key: UnsafeRow = null, var valueIndex: Long = -1, var value: UnsafeRow = null) {
429-
def withNew(newKey: UnsafeRow, newIndex: Long, newValue: UnsafeRow): this.type = {
430-
this.key = newKey
431-
this.valueIndex = newIndex
432-
this.value = newValue
433-
this
418+
private abstract class KeyWithIndexToValueStore[T](
419+
storeType: StateStoreType,
420+
valueSchema: StructType)
421+
extends StateStoreHandler(storeType) {
422+
423+
/**
424+
* Helper class for representing data returned by [[KeyWithIndexToValueStore]].
425+
* Designed for object reuse.
426+
*/
427+
case class KeyWithIndexAndValue(
428+
var key: UnsafeRow = null,
429+
var valueIndex: Long = -1,
430+
var value: T = null.asInstanceOf[T]) {
431+
def withNew(newKey: UnsafeRow, newIndex: Long, newValue: T): this.type = {
432+
this.key = newKey
433+
this.valueIndex = newIndex
434+
this.value = newValue
435+
this
436+
}
434437
}
435-
}
436438

437-
/** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. */
438-
private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValueType) {
439439
private val keyWithIndexExprs = keyAttributes :+ Literal(1L)
440440
private val keyWithIndexSchema = keySchema.add("index", LongType)
441441
private val indexOrdinalInKeyWithIndexRow = keyAttributes.size
442442

443-
// Projection to generate (key + index) row from key row
444443
private val keyWithIndexRowGenerator = UnsafeProjection.create(keyWithIndexExprs, keyAttributes)
445444

446445
// Projection to generate key row from (key + index) row
447446
private val keyRowGenerator = UnsafeProjection.create(
448447
keyAttributes, keyAttributes :+ AttributeReference("index", LongType)())
449448

450-
protected val stateStore = getStateStore(keyWithIndexSchema, inputValueAttributes.toStructType)
449+
protected val stateStore = getStateStore(keyWithIndexSchema, valueSchema)
451450

452-
def get(key: UnsafeRow, valueIndex: Long): UnsafeRow = {
453-
stateStore.get(keyWithIndexRow(key, valueIndex))
451+
def get(key: UnsafeRow, valueIndex: Long): T = {
452+
convertValue(stateStore.get(keyWithIndexRow(key, valueIndex)))
454453
}
455454

456455
/**
@@ -468,7 +467,7 @@ class SymmetricHashJoinStateManager(
468467
} else {
469468
val keyWithIndex = keyWithIndexRow(key, index)
470469
val value = stateStore.get(keyWithIndex)
471-
keyWithIndexAndValue.withNew(key, index, value)
470+
keyWithIndexAndValue.withNew(key, index, convertValue(value))
472471
index += 1
473472
keyWithIndexAndValue
474473
}
@@ -479,9 +478,12 @@ class SymmetricHashJoinStateManager(
479478
}
480479

481480
/** Put new value for key at the given index */
482-
def put(key: UnsafeRow, valueIndex: Long, value: UnsafeRow): Unit = {
481+
def put(key: UnsafeRow, valueIndex: Long, value: T): Unit = {
483482
val keyWithIndex = keyWithIndexRow(key, valueIndex)
484-
stateStore.put(keyWithIndex, value)
483+
val row = convertToValueRow(value)
484+
if (row != null) {
485+
stateStore.put(keyWithIndex, row)
486+
}
485487
}
486488

487489
/**
@@ -504,90 +506,57 @@ class SymmetricHashJoinStateManager(
504506
def iterator: Iterator[KeyWithIndexAndValue] = {
505507
val keyWithIndexAndValue = new KeyWithIndexAndValue()
506508
stateStore.getRange(None, None).map { pair =>
507-
keyWithIndexAndValue.withNew(
508-
keyRowGenerator(pair.key), pair.key.getLong(indexOrdinalInKeyWithIndexRow), pair.value)
509+
keyWithIndexAndValue.withNew(keyRowGenerator(pair.key),
510+
pair.key.getLong(indexOrdinalInKeyWithIndexRow), convertValue(pair.value))
509511
keyWithIndexAndValue
510512
}
511513
}
512514

513515
/** Generate a row using the key and index */
514-
private def keyWithIndexRow(key: UnsafeRow, valueIndex: Long): UnsafeRow = {
516+
protected def keyWithIndexRow(key: UnsafeRow, valueIndex: Long): UnsafeRow = {
515517
val row = keyWithIndexRowGenerator(key)
516518
row.setLong(indexOrdinalInKeyWithIndexRow, valueIndex)
517519
row
518520
}
519-
}
520521

521-
/**
522-
* Helper class for representing data returned by [[KeyWithIndexToMatchedStore]].
523-
* Designed for object reuse.
524-
*/
525-
private case class KeyWithIndexAndMatched(
526-
var key: UnsafeRow = null, var valueIndex: Long = -1, var matched: Boolean = false) {
527-
def withNew(newKey: UnsafeRow, newIndex: Long, newMatched: Boolean): this.type = {
528-
this.key = newKey
529-
this.valueIndex = newIndex
530-
this.matched = newMatched
531-
this
532-
}
522+
protected def convertValue(value: UnsafeRow): T
523+
protected def convertToValueRow(value: T): UnsafeRow
533524
}
534525

535-
// TODO: clean up KeyWithIndexToValueStore and KeyWithIndexToMatchedStore
526+
/** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. */
527+
private class KeyWithIndexToRowValueStore
528+
extends KeyWithIndexToValueStore[UnsafeRow](
529+
KeyWithIndexToRowValueType,
530+
inputValueAttributes.toStructType) {
536531

537-
/** A wrapper around a [[StateStore]] that stores [(key, index) -> matched]. */
538-
private class KeyWithIndexToMatchedStore extends StateStoreHandler(KeyWithIndexToMatchedType) {
539-
private val keyWithIndexExprs = keyAttributes :+ Literal(1L)
540-
private val keyWithIndexSchema = keySchema.add("index", LongType)
541-
private val indexOrdinalInKeyWithIndexRow = keyAttributes.size
532+
override protected def convertValue(value: UnsafeRow): UnsafeRow = value
542533

543-
// Projection to generate (key + index) row from key row
544-
private val keyWithIndexRowGenerator = UnsafeProjection.create(keyWithIndexExprs, keyAttributes)
534+
override protected def convertToValueRow(value: UnsafeRow): UnsafeRow = value
535+
}
545536

546-
// Projection to generate key row from (key + index) row
547-
private val keyRowGenerator = UnsafeProjection.create(
548-
keyAttributes, keyAttributes :+ AttributeReference("index", LongType)())
537+
private class KeyWithIndexToMatchedStore extends KeyWithIndexToValueStore[Option[Boolean]](
538+
KeyWithIndexToMatchedType,
539+
KeyWithIndexToMatchedStore.booleanValueSchema) {
549540

550541
private val booleanValueSchema = new StructType().add("value", "boolean")
551542
private val booleanToUnsafeRow = UnsafeProjection.create(booleanValueSchema)
552543
private val valueRow = booleanToUnsafeRow(new SpecificInternalRow(booleanValueSchema))
553544

554-
protected val stateStore = getStateStore(keyWithIndexSchema, booleanValueSchema)
555-
556-
def get(key: UnsafeRow, valueIndex: Long): Option[Boolean] = {
557-
val row = stateStore.get(keyWithIndexRow(key, valueIndex))
558-
if (row != null) Some(row.getBoolean(0)) else None
545+
override protected def convertValue(value: UnsafeRow): Option[Boolean] = {
546+
if (value != null) Some(value.getBoolean(0)) else None
559547
}
560548

561-
/** Put matched for key at the given index */
562-
def put(key: UnsafeRow, valueIndex: Long, matched: Boolean): Unit = {
563-
val keyWithIndex = keyWithIndexRow(key, valueIndex)
564-
valueRow.setBoolean(0, matched)
565-
stateStore.put(keyWithIndex, valueRow)
566-
}
549+
override protected def convertToValueRow(value: Option[Boolean]): UnsafeRow = value match {
550+
case Some(matched) =>
551+
valueRow.setBoolean(0, matched)
552+
valueRow
567553

568-
/**
569-
* Remove key and value at given index. Note that this will create a hole in
570-
* (key, index) and it is upto the caller to deal with it.
571-
*/
572-
def remove(key: UnsafeRow, valueIndex: Long): Unit = {
573-
stateStore.remove(keyWithIndexRow(key, valueIndex))
574-
}
575-
576-
/** Remove all values (i.e. all the indices) for the given key. */
577-
def removeAllValues(key: UnsafeRow, numValues: Long): Unit = {
578-
var index = 0
579-
while (index < numValues) {
580-
stateStore.remove(keyWithIndexRow(key, index))
581-
index += 1
582-
}
554+
case None => null
583555
}
556+
}
584557

585-
/** Generated a row using the key and index */
586-
private def keyWithIndexRow(key: UnsafeRow, valueIndex: Long): UnsafeRow = {
587-
val row = keyWithIndexRowGenerator(key)
588-
row.setLong(indexOrdinalInKeyWithIndexRow, valueIndex)
589-
row
590-
}
558+
private object KeyWithIndexToMatchedStore {
559+
val booleanValueSchema = new StructType().add("value", "boolean")
591560
}
592561
}
593562

@@ -610,7 +579,8 @@ object SymmetricHashJoinStateManager {
610579
}
611580

612581
def allStateStoreNames(joinSides: JoinSide*): Seq[String] = {
613-
val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType)
582+
val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType,
583+
KeyWithIndexToRowValueType, KeyWithIndexToMatchedType)
614584
for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield {
615585
getStateStoreName(joinSide, stateStoreType)
616586
}
@@ -622,8 +592,8 @@ object SymmetricHashJoinStateManager {
622592
override def toString(): String = "keyToNumValues"
623593
}
624594

625-
private case object KeyWithIndexToValueType extends StateStoreType {
626-
override def toString(): String = "keyWithIndexToValue"
595+
private case object KeyWithIndexToRowValueType extends StateStoreType {
596+
override def toString(): String = "keyWithIndexToRowValue"
627597
}
628598

629599
private case object KeyWithIndexToMatchedType extends StateStoreType {

0 commit comments

Comments
 (0)