@@ -97,7 +97,7 @@ class SymmetricHashJoinStateManager(
9797 keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue =>
9898 val joinedRow = generateJoinedRow(keyIdxToValue.value)
9999 if (predicate(joinedRow)) {
100- keyWithIndexToMatched.put(key, keyIdxToValue.valueIndex, matched = true )
100+ keyWithIndexToMatched.put(key, keyIdxToValue.valueIndex, Some ( true ) )
101101 joinedRow
102102 } else {
103103 null
@@ -110,7 +110,7 @@ class SymmetricHashJoinStateManager(
110110 val numExistingValues = keyToNumValues.get(key)
111111 keyWithIndexToValue.put(key, numExistingValues, value)
112112 keyToNumValues.put(key, numExistingValues + 1 )
113- keyWithIndexToMatched.put(key, numExistingValues, matched)
113+ keyWithIndexToMatched.put(key, numExistingValues, Some ( matched) )
114114 }
115115
116116 /**
@@ -129,7 +129,7 @@ class SymmetricHashJoinStateManager(
129129 private val allKeyToNumValues = keyToNumValues.iterator
130130
131131 private var currentKeyToNumValue : KeyAndNumValues = null
132- private var currentValues : Iterator [KeyWithIndexAndValue ] = null
132+ private var currentValues : Iterator [keyWithIndexToValue. KeyWithIndexAndValue ] = null
133133
134134 private def currentKey = currentKeyToNumValue.key
135135
@@ -266,13 +266,9 @@ class SymmetricHashJoinStateManager(
266266 keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex)
267267 keyWithIndexToValue.remove(currentKey, numValues - 1 )
268268
269- keyWithIndexToMatched.get(currentKey, numValues - 1 ) match {
270- case Some (matchedAtMaxIndex) =>
271- keyWithIndexToMatched.put(currentKey, index, matchedAtMaxIndex)
272- keyWithIndexToMatched.remove(currentKey, numValues - 1 )
273-
274- case None =>
275- }
269+ val matchedAtMaxIndex = keyWithIndexToMatched.get(currentKey, numValues - 1 )
270+ keyWithIndexToMatched.put(currentKey, index, matchedAtMaxIndex)
271+ keyWithIndexToMatched.remove(currentKey, numValues - 1 )
276272 } else {
277273 keyWithIndexToValue.remove(currentKey, 0 )
278274 keyWithIndexToMatched.remove(currentKey, 0 )
@@ -337,7 +333,7 @@ class SymmetricHashJoinStateManager(
337333 joinKeys.zipWithIndex.map { case (k, i) => StructField (s " field $i" , k.dataType, k.nullable) })
338334 private val keyAttributes = keySchema.toAttributes
339335 private val keyToNumValues = new KeyToNumValuesStore ()
340- private val keyWithIndexToValue = new KeyWithIndexToValueStore ()
336+ private val keyWithIndexToValue = new KeyWithIndexToRowValueStore ()
341337 private val keyWithIndexToMatched = new KeyWithIndexToMatchedStore ()
342338
343339 // Clean up any state store resources if necessary at the end of the task
@@ -387,7 +383,6 @@ class SymmetricHashJoinStateManager(
387383 }
388384 }
389385
390-
391386 /** A wrapper around a [[StateStore ]] that stores [key -> number of values]. */
392387 private class KeyToNumValuesStore extends StateStoreHandler (KeyToNumValuesType ) {
393388 private val longValueSchema = new StructType ().add(" value" , " long" )
@@ -420,37 +415,41 @@ class SymmetricHashJoinStateManager(
420415 }
421416 }
422417
423- /**
424- * Helper class for representing data returned by [[KeyWithIndexToValueStore ]].
425- * Designed for object reuse.
426- */
427- private case class KeyWithIndexAndValue (
428- var key : UnsafeRow = null , var valueIndex : Long = - 1 , var value : UnsafeRow = null ) {
429- def withNew (newKey : UnsafeRow , newIndex : Long , newValue : UnsafeRow ): this .type = {
430- this .key = newKey
431- this .valueIndex = newIndex
432- this .value = newValue
433- this
418+ private abstract class KeyWithIndexToValueStore [T ](
419+ storeType : StateStoreType ,
420+ valueSchema : StructType )
421+ extends StateStoreHandler (storeType) {
422+
423+ /**
424+ * Helper class for representing data returned by [[KeyWithIndexToValueStore ]].
425+ * Designed for object reuse.
426+ */
427+ case class KeyWithIndexAndValue (
428+ var key : UnsafeRow = null ,
429+ var valueIndex : Long = - 1 ,
430+ var value : T = null .asInstanceOf [T ]) {
431+ def withNew (newKey : UnsafeRow , newIndex : Long , newValue : T ): this .type = {
432+ this .key = newKey
433+ this .valueIndex = newIndex
434+ this .value = newValue
435+ this
436+ }
434437 }
435- }
436438
437- /** A wrapper around a [[StateStore ]] that stores [(key, index) -> value]. */
438- private class KeyWithIndexToValueStore extends StateStoreHandler (KeyWithIndexToValueType ) {
439439 private val keyWithIndexExprs = keyAttributes :+ Literal (1L )
440440 private val keyWithIndexSchema = keySchema.add(" index" , LongType )
441441 private val indexOrdinalInKeyWithIndexRow = keyAttributes.size
442442
443- // Projection to generate (key + index) row from key row
444443 private val keyWithIndexRowGenerator = UnsafeProjection .create(keyWithIndexExprs, keyAttributes)
445444
446445 // Projection to generate key row from (key + index) row
447446 private val keyRowGenerator = UnsafeProjection .create(
448447 keyAttributes, keyAttributes :+ AttributeReference (" index" , LongType )())
449448
450- protected val stateStore = getStateStore(keyWithIndexSchema, inputValueAttributes.toStructType )
449+ protected val stateStore = getStateStore(keyWithIndexSchema, valueSchema )
451450
452- def get (key : UnsafeRow , valueIndex : Long ): UnsafeRow = {
453- stateStore.get(keyWithIndexRow(key, valueIndex))
451+ def get (key : UnsafeRow , valueIndex : Long ): T = {
452+ convertValue( stateStore.get(keyWithIndexRow(key, valueIndex) ))
454453 }
455454
456455 /**
@@ -468,7 +467,7 @@ class SymmetricHashJoinStateManager(
468467 } else {
469468 val keyWithIndex = keyWithIndexRow(key, index)
470469 val value = stateStore.get(keyWithIndex)
471- keyWithIndexAndValue.withNew(key, index, value)
470+ keyWithIndexAndValue.withNew(key, index, convertValue( value) )
472471 index += 1
473472 keyWithIndexAndValue
474473 }
@@ -479,9 +478,12 @@ class SymmetricHashJoinStateManager(
479478 }
480479
481480 /** Put new value for key at the given index */
482- def put (key : UnsafeRow , valueIndex : Long , value : UnsafeRow ): Unit = {
481+ def put (key : UnsafeRow , valueIndex : Long , value : T ): Unit = {
483482 val keyWithIndex = keyWithIndexRow(key, valueIndex)
484- stateStore.put(keyWithIndex, value)
483+ val row = convertToValueRow(value)
484+ if (row != null ) {
485+ stateStore.put(keyWithIndex, row)
486+ }
485487 }
486488
487489 /**
@@ -504,90 +506,57 @@ class SymmetricHashJoinStateManager(
504506 def iterator : Iterator [KeyWithIndexAndValue ] = {
505507 val keyWithIndexAndValue = new KeyWithIndexAndValue ()
506508 stateStore.getRange(None , None ).map { pair =>
507- keyWithIndexAndValue.withNew(
508- keyRowGenerator( pair.key), pair.key. getLong(indexOrdinalInKeyWithIndexRow), pair.value)
509+ keyWithIndexAndValue.withNew(keyRowGenerator(pair.key),
510+ pair.key. getLong(indexOrdinalInKeyWithIndexRow), convertValue( pair.value) )
509511 keyWithIndexAndValue
510512 }
511513 }
512514
513515 /** Generated a row using the key and index */
514- private def keyWithIndexRow (key : UnsafeRow , valueIndex : Long ): UnsafeRow = {
516+ protected def keyWithIndexRow (key : UnsafeRow , valueIndex : Long ): UnsafeRow = {
515517 val row = keyWithIndexRowGenerator(key)
516518 row.setLong(indexOrdinalInKeyWithIndexRow, valueIndex)
517519 row
518520 }
519- }
520521
521- /**
522- * Helper class for representing data returned by [[KeyWithIndexToMatchedStore ]].
523- * Designed for object reuse.
524- */
525- private case class KeyWithIndexAndMatched (
526- var key : UnsafeRow = null , var valueIndex : Long = - 1 , var matched : Boolean = false ) {
527- def withNew (newKey : UnsafeRow , newIndex : Long , newMatched : Boolean ): this .type = {
528- this .key = newKey
529- this .valueIndex = newIndex
530- this .matched = newMatched
531- this
532- }
522+ protected def convertValue (value : UnsafeRow ): T
523+ protected def convertToValueRow (value : T ): UnsafeRow
533524 }
534525
535- // TODO: clean up KeyWithIndexToValueStore and KeyWithIndexToMatchedStore
526+ /** A wrapper around a [[StateStore ]] that stores [(key, index) -> value]. */
527+ private class KeyWithIndexToRowValueStore
528+ extends KeyWithIndexToValueStore [UnsafeRow ](
529+ KeyWithIndexToRowValueType ,
530+ inputValueAttributes.toStructType) {
536531
537- /** A wrapper around a [[StateStore ]] that stores [(key, index) -> matched]. */
538- private class KeyWithIndexToMatchedStore extends StateStoreHandler (KeyWithIndexToMatchedType ) {
539- private val keyWithIndexExprs = keyAttributes :+ Literal (1L )
540- private val keyWithIndexSchema = keySchema.add(" index" , LongType )
541- private val indexOrdinalInKeyWithIndexRow = keyAttributes.size
532+ override protected def convertValue (value : UnsafeRow ): UnsafeRow = value
542533
543- // Projection to generate (key + index) row from key row
544- private val keyWithIndexRowGenerator = UnsafeProjection .create(keyWithIndexExprs, keyAttributes)
534+ override protected def convertToValueRow ( value : UnsafeRow ) : UnsafeRow = value
535+ }
545536
546- // Projection to generate key row from (key + index) row
547- private val keyRowGenerator = UnsafeProjection .create(
548- keyAttributes, keyAttributes :+ AttributeReference ( " index " , LongType )())
537+ private class KeyWithIndexToMatchedStore extends KeyWithIndexToValueStore [ Option [ Boolean ]](
538+ KeyWithIndexToMatchedType ,
539+ KeyWithIndexToMatchedStore .booleanValueSchema) {
549540
550541 private val booleanValueSchema = new StructType ().add(" value" , " boolean" )
551542 private val booleanToUnsafeRow = UnsafeProjection .create(booleanValueSchema)
552543 private val valueRow = booleanToUnsafeRow(new SpecificInternalRow (booleanValueSchema))
553544
554- protected val stateStore = getStateStore(keyWithIndexSchema, booleanValueSchema)
555-
556- def get (key : UnsafeRow , valueIndex : Long ): Option [Boolean ] = {
557- val row = stateStore.get(keyWithIndexRow(key, valueIndex))
558- if (row != null ) Some (row.getBoolean(0 )) else None
545+ override protected def convertValue (value : UnsafeRow ): Option [Boolean ] = {
546+ if (value != null ) Some (value.getBoolean(0 )) else None
559547 }
560548
561- /** Put matched for key at the given index */
562- def put (key : UnsafeRow , valueIndex : Long , matched : Boolean ): Unit = {
563- val keyWithIndex = keyWithIndexRow(key, valueIndex)
564- valueRow.setBoolean(0 , matched)
565- stateStore.put(keyWithIndex, valueRow)
566- }
549+ override protected def convertToValueRow (value : Option [Boolean ]): UnsafeRow = value match {
550+ case Some (matched) =>
551+ valueRow.setBoolean(0 , matched)
552+ valueRow
567553
568- /**
569- * Remove key and value at given index. Note that this will create a hole in
570- * (key, index) and it is upto the caller to deal with it.
571- */
572- def remove (key : UnsafeRow , valueIndex : Long ): Unit = {
573- stateStore.remove(keyWithIndexRow(key, valueIndex))
574- }
575-
576- /** Remove all values (i.e. all the indices) for the given key. */
577- def removeAllValues (key : UnsafeRow , numValues : Long ): Unit = {
578- var index = 0
579- while (index < numValues) {
580- stateStore.remove(keyWithIndexRow(key, index))
581- index += 1
582- }
554+ case None => null
583555 }
556+ }
584557
585- /** Generated a row using the key and index */
586- private def keyWithIndexRow (key : UnsafeRow , valueIndex : Long ): UnsafeRow = {
587- val row = keyWithIndexRowGenerator(key)
588- row.setLong(indexOrdinalInKeyWithIndexRow, valueIndex)
589- row
590- }
558+ private object KeyWithIndexToMatchedStore {
559+ val booleanValueSchema = new StructType ().add(" value" , " boolean" )
591560 }
592561}
593562
@@ -610,7 +579,8 @@ object SymmetricHashJoinStateManager {
610579 }
611580
612581 def allStateStoreNames (joinSides : JoinSide * ): Seq [String ] = {
613- val allStateStoreTypes : Seq [StateStoreType ] = Seq (KeyToNumValuesType , KeyWithIndexToValueType )
582+ val allStateStoreTypes : Seq [StateStoreType ] = Seq (KeyToNumValuesType ,
583+ KeyWithIndexToRowValueType , KeyWithIndexToMatchedType )
614584 for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield {
615585 getStateStoreName(joinSide, stateStoreType)
616586 }
@@ -622,8 +592,8 @@ object SymmetricHashJoinStateManager {
622592 override def toString (): String = " keyToNumValues"
623593 }
624594
625- private case object KeyWithIndexToValueType extends StateStoreType {
626- override def toString (): String = " keyWithIndexToValue "
595+ private case object KeyWithIndexToRowValueType extends StateStoreType {
596+ override def toString (): String = " keyWithIndexToRowValue "
627597 }
628598
629599 private case object KeyWithIndexToMatchedType extends StateStoreType {
0 commit comments