Skip to content

Commit d292896

Browse files
committed
Use only distinct predictions
- Apply predictions after flatMap in ml.FPGrowthModel.transform - Add tests
1 parent 05887fc commit d292896

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,10 @@ class FPGrowthModel private[ml] (
245245
rule._2.filter(item => !itemset.contains(item))
246246
} else {
247247
Seq.empty
248-
})
248+
}).distinct
249249
} else {
250250
Seq.empty
251-
}.distinct }, dt)
251+
}}, dt)
252252
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
253253
}
254254

mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,20 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
103103
FPGrowthSuite.allParamSettings, checkModelData)
104104
}
105105

106+
test("FPGrowth prediction should not contain duplicates") {
107+
// This should generate rule 1 -> 3, 2 -> 3
108+
val dataset = spark.createDataFrame(Seq(
109+
Array("1", "3"),
110+
Array("2", "3")
111+
).map(Tuple1(_))).toDF("features")
112+
val model = new FPGrowth().fit(dataset)
113+
114+
val prediction = model.transform(
115+
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
116+
).first().getAs[Seq[String]]("prediction")
117+
118+
assert(prediction === Seq("3"))
119+
}
106120
}
107121

108122
object FPGrowthSuite {

0 commit comments

Comments
 (0)