
Commit 6459a67

aray authored and yhuai committed
[SPARK-11690][PYSPARK] Add pivot to python api
This PR adds pivot to the Python API of GroupedData with the same syntax as Scala/Java.

Author: Andrew Ray <[email protected]>

Closes #9653 from aray/sql-pivot-python.

(cherry picked from commit a244779)
Signed-off-by: Yin Huai <[email protected]>
1 parent 4a1bcb2 commit 6459a67
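
For context, here is a minimal usage sketch of the new GroupedData.pivot method (an illustration, not part of the commit itself). It assumes an active SparkContext named sc with an associated SQLContext so that RDD.toDF() works, and it reuses the df4 fixture data that the patch adds to the doctest globals below.

# Usage sketch (illustration only, not part of the commit). Assumes an active
# SparkContext named sc and a SQLContext so that RDD.toDF() works.
from pyspark.sql import Row

df = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
                     Row(course="Java", year=2012, earnings=20000),
                     Row(course="dotNET", year=2012, earnings=5000),
                     Row(course="dotNET", year=2013, earnings=48000),
                     Row(course="Java", year=2013, earnings=30000)]).toDF()

# With explicit pivot values: no extra distinct() pass is needed, and the
# output columns follow the order of the values given.
df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect()
# [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]

# Without explicit values: the distinct values of "course" are computed first.
df.groupBy("year").pivot("course").sum("earnings").collect()
# [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]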

File tree

1 file changed: +23 -1 lines changed

python/pyspark/sql/group.py

Lines changed: 23 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 from pyspark import since
 from pyspark.rdd import ignore_unicode_prefix
-from pyspark.sql.column import Column, _to_seq
+from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import *
 

@@ -167,6 +167,23 @@ def sum(self, *cols):
         [Row(sum(age)=7, sum(height)=165)]
         """
 
+    @since(1.6)
+    def pivot(self, pivot_col, *values):
+        """Pivots a column of the current DataFrame and performs the specified aggregation.
+
+        :param pivot_col: Column to pivot
+        :param values: Optional list of values of the pivot column that will be translated to
+            columns in the output DataFrame. If values are not provided, the method will do an
+            immediate call to .distinct() on the pivot column.
+        >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect()
+        [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]
+        >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
+        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
+        """
+        jgd = self._jdf.pivot(_to_java_column(pivot_col),
+                              _to_seq(self.sql_ctx._sc, values, _create_column_from_literal))
+        return GroupedData(jgd, self.sql_ctx)
+
 
 def _test():
     import doctest

@@ -182,6 +199,11 @@ def _test():
                           StructField('name', StringType())]))
     globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                    Row(name='Bob', age=5, height=85)]).toDF()
+    globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
+                                   Row(course="Java", year=2012, earnings=20000),
+                                   Row(course="dotNET", year=2012, earnings=5000),
+                                   Row(course="dotNET", year=2013, earnings=48000),
+                                   Row(course="Java", year=2013, earnings=30000)]).toDF()
 
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.group, globs=globs,

0 commit comments