Skip to content

Commit 33a5e0d

Browse files
committed
Move tests to PythonUDFSuite.
1 parent 34531b4 commit 33a5e0d

File tree

2 files changed

+67
-46
lines changed

2 files changed

+67
-46
lines changed

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2189,50 +2189,4 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
21892189
|*(1) Range (0, 10, step=1, splits=2)""".stripMargin))
21902190
}
21912191
}
2192-
2193-
test("SPARK-28445: PythonUDF in grouping key and aggregate expressions") {
2194-
import IntegratedUDFTestUtils._
2195-
2196-
val scalaTestUDF = TestScalaUDF(name = "scalaUDF")
2197-
val pythonTestUDF = TestPythonUDF(name = "pyUDF")
2198-
assume(shouldTestPythonUDFs)
2199-
2200-
withTempView("testData") {
2201-
sql(
2202-
"""CREATE OR REPLACE TEMPORARY VIEW testData AS
2203-
|SELECT * FROM VALUES
2204-
|(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null)
2205-
|AS testData(a, b)""".stripMargin)
2206-
2207-
val base = spark.table("testData")
2208-
2209-
val df = base.groupBy(scalaTestUDF(base("a") + 1))
2210-
.agg(scalaTestUDF(base("a") + 1), scalaTestUDF(count(base("b"))))
2211-
val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
2212-
.agg(pythonTestUDF(base("a") + 1), pythonTestUDF(count(base("b"))))
2213-
checkAnswer(df, df2)
2214-
2215-
val df3 = base.groupBy(scalaTestUDF(base("a") + 1))
2216-
.agg(scalaTestUDF(base("a") + 1) + 1, scalaTestUDF(count(base("b"))))
2217-
val df4 = base.groupBy(pythonTestUDF(base("a") + 1))
2218-
.agg(pythonTestUDF(base("a") + 1) + 1, pythonTestUDF(count(base("b"))))
2219-
checkAnswer(df3, df4)
2220-
2221-
// PythonUDF in aggregate expression has grouping key in its arguments.
2222-
val df5 = base.groupBy(scalaTestUDF(base("a") + 1))
2223-
.agg(scalaTestUDF(scalaTestUDF(base("a") + 1)), scalaTestUDF(count(base("b"))))
2224-
val df6 = base.groupBy(pythonTestUDF(base("a") + 1))
2225-
.agg(pythonTestUDF(pythonTestUDF(base("a") + 1)), pythonTestUDF(count(base("b"))))
2226-
checkAnswer(df5, df6)
2227-
2228-
// PythonUDF over grouping key is argument to aggregate function.
2229-
val df7 = base.groupBy(scalaTestUDF(base("a") + 1))
2230-
.agg(scalaTestUDF(scalaTestUDF(base("a") + 1)),
2231-
scalaTestUDF(count(scalaTestUDF(base("a") + 1))))
2232-
val df8 = base.groupBy(pythonTestUDF(base("a") + 1))
2233-
.agg(pythonTestUDF(pythonTestUDF(base("a") + 1)),
2234-
pythonTestUDF(count(pythonTestUDF(base("a") + 1))))
2235-
checkAnswer(df7, df8)
2236-
}
2237-
}
22382192
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.python
19+
20+
import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest}
21+
import org.apache.spark.sql.functions.count
22+
import org.apache.spark.sql.test.SharedSQLContext
23+
24+
class PythonUDFSuite extends QueryTest with SharedSQLContext {
25+
import testImplicits._
26+
27+
test("SPARK-28445: PythonUDF in grouping key and aggregate expressions") {
28+
import IntegratedUDFTestUtils._
29+
30+
val scalaTestUDF = TestScalaUDF(name = "scalaUDF")
31+
val pythonTestUDF = TestPythonUDF(name = "pyUDF")
32+
assume(shouldTestPythonUDFs)
33+
34+
val base = Seq(
35+
(Some(1), Some(1)), (Some(1), Some(2)), (Some(2), Some(1)),
36+
(Some(2), Some(2)), (Some(3), Some(1)), (Some(3), Some(2)),
37+
(None, Some(1)), (Some(3), None), (None, None)).toDF("a", "b")
38+
39+
val df = base.groupBy(scalaTestUDF(base("a") + 1))
40+
.agg(scalaTestUDF(base("a") + 1), scalaTestUDF(count(base("b"))))
41+
val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
42+
.agg(pythonTestUDF(base("a") + 1), pythonTestUDF(count(base("b"))))
43+
checkAnswer(df, df2)
44+
45+
val df3 = base.groupBy(scalaTestUDF(base("a") + 1))
46+
.agg(scalaTestUDF(base("a") + 1) + 1, scalaTestUDF(count(base("b"))))
47+
val df4 = base.groupBy(pythonTestUDF(base("a") + 1))
48+
.agg(pythonTestUDF(base("a") + 1) + 1, pythonTestUDF(count(base("b"))))
49+
checkAnswer(df3, df4)
50+
51+
// PythonUDF in aggregate expression has grouping key in its arguments.
52+
val df5 = base.groupBy(scalaTestUDF(base("a") + 1))
53+
.agg(scalaTestUDF(scalaTestUDF(base("a") + 1)), scalaTestUDF(count(base("b"))))
54+
val df6 = base.groupBy(pythonTestUDF(base("a") + 1))
55+
.agg(pythonTestUDF(pythonTestUDF(base("a") + 1)), pythonTestUDF(count(base("b"))))
56+
checkAnswer(df5, df6)
57+
58+
// PythonUDF over grouping key is argument to aggregate function.
59+
val df7 = base.groupBy(scalaTestUDF(base("a") + 1))
60+
.agg(scalaTestUDF(scalaTestUDF(base("a") + 1)),
61+
scalaTestUDF(count(scalaTestUDF(base("a") + 1))))
62+
val df8 = base.groupBy(pythonTestUDF(base("a") + 1))
63+
.agg(pythonTestUDF(pythonTestUDF(base("a") + 1)),
64+
pythonTestUDF(count(pythonTestUDF(base("a") + 1))))
65+
checkAnswer(df7, df8)
66+
}
67+
}

0 commit comments

Comments
 (0)