Skip to content

Commit acbf9e4

Browse files
committed
address comments
1 parent 0c3eb12 commit acbf9e4

File tree

3 files changed

+71
-31
lines changed

3 files changed

+71
-31
lines changed

mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ import org.apache.spark.annotation.{Experimental, Since}
2121
import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
2222
import org.apache.spark.sql.{DataFrame, Dataset, Row}
2323
import org.apache.spark.sql.functions.col
24-
import org.apache.spark.sql.types.{LongType, StructField, StructType}
25-
import org.apache.spark.storage.StorageLevel
24+
import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
2625

2726
/**
2827
* :: Experimental ::
@@ -44,26 +43,37 @@ object PrefixSpan {
4443
*
4544
* @param dataset A dataset or a dataframe containing a sequence column which is
4645
* {{{Seq[Seq[_]]}}} type
47-
* @param sequenceCol the name of the sequence column in dataset
46+
* @param sequenceCol the name of the sequence column in dataset; rows with nulls in this column
47+
* are ignored
4848
* @param minSupport the minimal support level of the sequential pattern, any pattern that
4949
* appears more than (minSupport * size-of-the-dataset) times will be output
50-
* (default: `0.1`).
51-
* @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears
52-
* less than maxPatternLength will be output (default: `10`).
50+
* (recommended value: `0.1`).
51+
* @param maxPatternLength the maximal length of the sequential pattern
52+
* (recommended value: `10`).
5353
* @param maxLocalProjDBSize The maximum number of items (including delimiters used in the
5454
* internal storage format) allowed in a projected database before
5555
* local processing. If a projected database exceeds this size, another
56-
* iteration of distributed prefix growth is run (default: `32000000`).
57-
* @return A dataframe that contains columns of sequence and corresponding frequency.
56+
* iteration of distributed prefix growth is run
57+
* (recommended value: `32000000`).
58+
* @return A `DataFrame` that contains columns of sequence and corresponding frequency.
59+
* Its schema will be:
60+
* - `sequence: Seq[Seq[T]]` (T is the item type)
61+
* - `frequency: Long`
5862
*/
5963
@Since("2.4.0")
60-
def findFrequentSequentPatterns(
64+
def findFrequentSequentialPatterns(
6165
dataset: Dataset[_],
6266
sequenceCol: String,
63-
minSupport: Double = 0.1,
64-
maxPatternLength: Int = 10,
65-
maxLocalProjDBSize: Long = 32000000L): DataFrame = {
66-
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
67+
minSupport: Double,
68+
maxPatternLength: Int,
69+
maxLocalProjDBSize: Long): DataFrame = {
70+
71+
val inputType = dataset.schema(sequenceCol).dataType
72+
require(inputType.isInstanceOf[ArrayType] &&
73+
inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
74+
s"The input column must be ArrayType and the array element type must also be ArrayType, " +
75+
s"but got $inputType.")
76+
6777

6878
val data = dataset.select(sequenceCol)
6979
val sequences = data.where(col(sequenceCol).isNotNull).rdd
@@ -73,18 +83,13 @@ object PrefixSpan {
7383
.setMinSupport(minSupport)
7484
.setMaxPatternLength(maxPatternLength)
7585
.setMaxLocalProjDBSize(maxLocalProjDBSize)
76-
if (handlePersistence) {
77-
sequences.persist(StorageLevel.MEMORY_AND_DISK)
78-
}
86+
7987
val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
8088
val schema = StructType(Seq(
8189
StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false),
82-
StructField("freq", LongType, nullable = false)))
90+
StructField("frequency", LongType, nullable = false)))
8391
val freqSequences = dataset.sparkSession.createDataFrame(rows, schema)
8492

85-
if (handlePersistence) {
86-
sequences.unpersist()
87-
}
8893
freqSequences
8994
}
9095

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel
4949
*
5050
* @param minSupport the minimal support level of the sequential pattern, any pattern that appears
5151
* more than (minSupport * size-of-the-dataset) times will be output
52-
* @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears
53-
* less than maxPatternLength will be output
52+
* @param maxPatternLength the maximal length of the sequential pattern
5453
* @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal
5554
* storage format) allowed in a projected database before local
5655
* processing. If a projected database exceeds this size, another

mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,13 @@ class PrefixSpanSuite extends MLTest {
2525

2626
override def beforeAll(): Unit = {
2727
super.beforeAll()
28-
smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
2928
}
3029

31-
@transient var smallDataset: DataFrame = _
32-
3330
test("PrefixSpan projections with multiple partial starts") {
34-
val result = PrefixSpan.findFrequentSequentPatterns(smallDataset, "sequence",
35-
minSupport = 1.0, maxPatternLength = 2).as[(Seq[Seq[Int]], Long)].collect()
31+
val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
32+
val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence",
33+
minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000)
34+
.as[(Seq[Seq[Int]], Long)].collect()
3635
val expected = Array(
3736
(Seq(Seq(1)), 1L),
3837
(Seq(Seq(1, 2)), 1L),
@@ -49,6 +48,32 @@ class PrefixSpanSuite extends MLTest {
4948
compareResults[Int](expected, result)
5049
}
5150

51+
/*
52+
To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content
53+
(format = (transactionID, idxInTransaction, numItemsInItemset, itemset)):
54+
1 1 2 1 2
55+
1 2 1 3
56+
2 1 1 1
57+
2 2 2 3 2
58+
2 3 2 1 2
59+
3 1 2 1 2
60+
3 2 1 5
61+
4 1 1 6
62+
In R, run:
63+
library("arulesSequences")
64+
prefixSpanSeqs = read_baskets("prefixSpanSeqs2", info = c("sequenceID","eventID","SIZE"))
65+
freqItemSeq = cspade(prefixSpanSeqs,
66+
parameter = list(support = 0.5, maxlen = 5))
67+
resSeq = as(freqItemSeq, "data.frame")
68+
resSeq
69+
70+
sequence support
71+
1 <{1}> 0.75
72+
2 <{2}> 0.75
73+
3 <{3}> 0.50
74+
4 <{1},{3}> 0.50
75+
5 <{1,2}> 0.75
76+
*/
5277
val smallTestData = Seq(
5378
Seq(Seq(1, 2), Seq(3)),
5479
Seq(Seq(1), Seq(3, 2), Seq(1, 2)),
@@ -65,8 +90,18 @@ class PrefixSpanSuite extends MLTest {
6590

6691
test("PrefixSpan Integer type, variable-size itemsets") {
6792
val df = smallTestData.toDF("sequence")
68-
val result = PrefixSpan.findFrequentSequentPatterns(df, "sequence",
69-
minSupport = 0.5, maxPatternLength = 5).as[(Seq[Seq[Int]], Long)].collect()
93+
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
94+
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
95+
.as[(Seq[Seq[Int]], Long)].collect()
96+
97+
compareResults[Int](smallTestDataExpectedResult, result)
98+
}
99+
100+
test("PrefixSpan input row with nulls") {
101+
val df = (smallTestData :+ null).toDF("sequence")
102+
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
103+
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
104+
.as[(Seq[Seq[Int]], Long)].collect()
70105

71106
compareResults[Int](smallTestDataExpectedResult, result)
72107
}
@@ -76,8 +111,9 @@ class PrefixSpanSuite extends MLTest {
76111
val df = smallTestData
77112
.map(seq => seq.map(itemSet => itemSet.map(intToString)))
78113
.toDF("sequence")
79-
val result = PrefixSpan.findFrequentSequentPatterns(df, "sequence",
80-
minSupport = 0.5, maxPatternLength = 5).as[(Seq[Seq[String]], Long)].collect()
114+
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
115+
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
116+
.as[(Seq[Seq[String]], Long)].collect()
81117

82118
val expected = smallTestDataExpectedResult.map { case (seq, freq) =>
83119
(seq.map(itemSet => itemSet.map(intToString)), freq)

0 commit comments

Comments
 (0)