
Commit e044705

Davies Liu (davies) authored and committed
[SPARK-9116] [SQL] [PYSPARK] support Python only UDT in __main__
This also makes it possible to create a Python UDT without a matching Scala one, which is important for Python users. cc mengxr JoshRosen

Author: Davies Liu <[email protected]>

Closes #7453 from davies/class_in_main and squashes the following commits:

4dfd5e1 [Davies Liu] add tests for Python and Scala UDT
793d9b2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main
dc65f19 [Davies Liu] address comment
a9a3c40 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main
a86e1fc [Davies Liu] fix serialization
ad528ba [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main
63f52ef [Davies Liu] fix pylint check
655b8a9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main
316a394 [Davies Liu] support Python UDT with UTF
0bcb3ef [Davies Liu] fix bug in mllib
de986d6 [Davies Liu] fix test
83d65ac [Davies Liu] fix bug in StructType
55bb86e [Davies Liu] support Python UDT in __main__ (without Scala one)
1 parent f5dd113 · commit e044705
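To see what the commit enables end to end, here is a hedged sketch of a UDT declared purely in Python, with no Scala counterpart, living in __main__. It mirrors the PythonOnlyUDT/PythonOnlyPoint pattern added to the tests below; the names PointUDT and Point, and the pre-existing SQLContext sqlCtx, are illustrative assumptions, not part of the commit.

from pyspark.sql.types import UserDefinedType, ArrayType, DoubleType

class PointUDT(UserDefinedType):
    """Python-only UDT: no scalaUDT(), and module() points at __main__."""
    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return '__main__'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return Point(datum[0], datum[1])

class Point(object):
    __UDT__ = PointUDT()

    def __init__(self, x, y):
        self.x, self.y = x, y

# Round-trip through a DataFrame: the UDT class itself is pickled with the rows.
df = sqlCtx.createDataFrame([(1.0, Point(1.0, 2.0))], ["label", "point"])
assert isinstance(df.head().point, Point)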

File tree

9 files changed (+286, -93 lines)


pylintrc

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ enable=
 # If you would like to improve the code quality of pyspark, remove any of these disabled errors
 # run ./dev/lint-python and see if the errors raised by pylint can be fixed.
 
-disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable
+disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable
 
 
 [REPORTS]
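For context on the one check newly added to the disable list: pylint's undefined-loop-variable (W0631) flags reads of a loop variable after the loop, where the name would be unbound if the iterable was empty. A minimal illustration follows; get_items and process are hypothetical helpers, not PySpark code.

def get_items():
    return [1, 2, 3]

def process(item):
    print(item)

for item in get_items():
    process(item)

# pylint would flag this line with W0631 (undefined-loop-variable):
# `item` is unbound here whenever get_items() returns an empty sequence.
print(item)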

python/pyspark/cloudpickle.py

Lines changed: 37 additions & 1 deletion
@@ -350,7 +350,26 @@ def save_global(self, obj, name=None, pack=struct.pack):
             if new_override:
                 d['__new__'] = obj.__new__
 
-            self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
+            self.save(_load_class)
+            self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj)
+            d.pop('__doc__', None)
+            # handle property and staticmethod
+            dd = {}
+            for k, v in d.items():
+                if isinstance(v, property):
+                    k = ('property', k)
+                    v = (v.fget, v.fset, v.fdel, v.__doc__)
+                elif isinstance(v, staticmethod) and hasattr(v, '__func__'):
+                    k = ('staticmethod', k)
+                    v = v.__func__
+                elif isinstance(v, classmethod) and hasattr(v, '__func__'):
+                    k = ('classmethod', k)
+                    v = v.__func__
+                dd[k] = v
+            self.save(dd)
+            self.write(pickle.TUPLE2)
+            self.write(pickle.REDUCE)
+
         else:
             raise pickle.PicklingError("Can't pickle %r" % obj)
 
@@ -708,6 +727,23 @@ def _make_skel_func(code, closures, base_globals = None):
                               None, None, closure)
 
 
+def _load_class(cls, d):
+    """
+    Loads additional properties into class `cls`.
+    """
+    for k, v in d.items():
+        if isinstance(k, tuple):
+            typ, k = k
+            if typ == 'property':
+                v = property(*v)
+            elif typ == 'staticmethod':
+                v = staticmethod(v)
+            elif typ == 'classmethod':
+                v = classmethod(v)
+        setattr(cls, k, v)
+    return cls
+
+
 """Constructors for 3rd party libraries
 Note: These can never be renamed due to client compatibility issues"""
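Why the save path now writes _load_class, a reduced bare class, the attribute dict, and a REDUCE opcode: property, staticmethod and classmethod objects are not directly picklable, so save_global ships their underlying functions tagged by kind, and _load_class re-wraps them onto the class after it is rebuilt. A standalone sketch of that two-stage reconstruction (no Spark required; the Point class and its members are illustrative):

def load_class(cls, d):
    # Mirror of cloudpickle._load_class: re-wrap tagged entries as descriptors.
    for k, v in d.items():
        if isinstance(k, tuple):
            typ, k = k
            if typ == 'property':
                v = property(*v)
            elif typ == 'staticmethod':
                v = staticmethod(v)
            elif typ == 'classmethod':
                v = classmethod(v)
        setattr(cls, k, v)
    return cls

# Stage 1: rebuild a bare class, as save_reduce now does with only __doc__.
Point = type('Point', (object,), {"__doc__": "a 2-D point"})

# Stage 2: the attribute dict as save_global encodes it; descriptors are
# unwrapped to plain functions and tagged with their kind.
attrs = {
    '__init__': lambda self, x: setattr(self, '_x', x),
    ('property', 'x'): (lambda self: self._x, None, None, "x coordinate"),
    ('staticmethod', 'origin'): lambda: Point(0.0),
}
Point = load_class(Point, attrs)

p = Point(3.0)
assert p.x == 3.0
assert Point.origin().x == 0.0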

python/pyspark/shuffle.py

Lines changed: 1 addition & 1 deletion
@@ -606,7 +606,7 @@ def _open_file(self):
         if not os.path.exists(d):
             os.makedirs(d)
         p = os.path.join(d, str(id(self)))
-        self._file = open(p, "wb+", 65536)
+        self._file = open(p, "w+b", 65536)
         self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
         os.unlink(p)
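Both "wb+" and "w+b" open the file for binary reading and writing in CPython; "w+b" is the spelling the open() documentation uses, so this appears to be a normalization made while fixing serialization rather than a behavior change.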

python/pyspark/sql/context.py

Lines changed: 67 additions & 41 deletions
@@ -277,6 +277,66 @@ def applySchema(self, rdd, schema):
 
         return self.createDataFrame(rdd, schema)
 
+    def _createFromRDD(self, rdd, schema, samplingRatio):
+        """
+        Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
+        """
+        if schema is None or isinstance(schema, (list, tuple)):
+            struct = self._inferSchema(rdd, samplingRatio)
+            converter = _create_converter(struct)
+            rdd = rdd.map(converter)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+                    struct.names[i] = name
+            schema = struct
+
+        elif isinstance(schema, StructType):
+            # take the first few rows to verify schema
+            rows = rdd.take(10)
+            for row in rows:
+                _verify_type(row, schema)
+
+        else:
+            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+        # convert python objects to sql data
+        rdd = rdd.map(schema.toInternal)
+        return rdd, schema
+
+    def _createFromLocal(self, data, schema):
+        """
+        Create an RDD for DataFrame from a list or pandas.DataFrame, returns
+        the RDD and schema.
+        """
+        if has_pandas and isinstance(data, pandas.DataFrame):
+            if schema is None:
+                schema = [str(x) for x in data.columns]
+            data = [r.tolist() for r in data.to_records(index=False)]
+
+        # make sure data can be consumed multiple times
+        if not isinstance(data, list):
+            data = list(data)
+
+        if schema is None or isinstance(schema, (list, tuple)):
+            struct = self._inferSchemaFromList(data)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+                    struct.names[i] = name
+            schema = struct
+
+        elif isinstance(schema, StructType):
+            for row in data:
+                _verify_type(row, schema)
+
+        else:
+            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+        # convert python objects to sql data
+        data = [schema.toInternal(row) for row in data]
+        return self._sc.parallelize(data), schema
+
     @since(1.3)
     @ignore_unicode_prefix
     def createDataFrame(self, data, schema=None, samplingRatio=None):
@@ -340,49 +400,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
 
-        if has_pandas and isinstance(data, pandas.DataFrame):
-            if schema is None:
-                schema = [str(x) for x in data.columns]
-            data = [r.tolist() for r in data.to_records(index=False)]
-
-        if not isinstance(data, RDD):
-            if not isinstance(data, list):
-                data = list(data)
-            try:
-                # data could be list, tuple, generator ...
-                rdd = self._sc.parallelize(data)
-            except Exception:
-                raise TypeError("cannot create an RDD from type: %s" % type(data))
+        if isinstance(data, RDD):
+            rdd, schema = self._createFromRDD(data, schema, samplingRatio)
         else:
-            rdd = data
-
-        if schema is None or isinstance(schema, (list, tuple)):
-            if isinstance(data, RDD):
-                struct = self._inferSchema(rdd, samplingRatio)
-            else:
-                struct = self._inferSchemaFromList(data)
-            if isinstance(schema, (list, tuple)):
-                for i, name in enumerate(schema):
-                    struct.fields[i].name = name
-            schema = struct
-            converter = _create_converter(schema)
-            rdd = rdd.map(converter)
-
-        elif isinstance(schema, StructType):
-            # take the first few rows to verify schema
-            rows = rdd.take(10)
-            for row in rows:
-                _verify_type(row, schema)
-
-        else:
-            raise TypeError("schema should be StructType or list or None")
-
-        # convert python objects to sql data
-        rdd = rdd.map(schema.toInternal)
-
+            rdd, schema = self._createFromLocal(data, schema)
         jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
-        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
-        return DataFrame(df, self)
+        jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+        df = DataFrame(jdf, self)
+        df._schema = schema
+        return df
 
     @since(1.3)
     def registerDataFrameAsTable(self, df, tableName):
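A short usage sketch of the refactored dispatch, assuming an existing SparkContext sc and SQLContext sqlCtx (the data and column names are illustrative): a local collection now flows through _createFromLocal, an RDD through _createFromRDD, and both converge on the same applySchemaToPythonRDD call, with the Python-side schema cached on the result via df._schema.

rows = [(1.0, "a"), (2.0, "b")]

# Local list: _createFromLocal materializes the data, infers or verifies the
# schema, converts rows with schema.toInternal, then parallelizes.
df_local = sqlCtx.createDataFrame(rows, ["value", "name"])

# RDD: _createFromRDD infers the schema from the first rows (or a sample when
# samplingRatio is given) without collecting the RDD.
df_rdd = sqlCtx.createDataFrame(sc.parallelize(rows), ["value", "name"])

assert df_local.schema == df_rdd.schema

# The schema is now cached on the Python side, so reading df.schema does not
# need a round-trip to the JVM.
assert df_local._schema is df_local.schema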

python/pyspark/sql/tests.py

Lines changed: 103 additions & 9 deletions
@@ -75,7 +75,7 @@ def sqlType(self):
 
     @classmethod
     def module(cls):
-        return 'pyspark.tests'
+        return 'pyspark.sql.tests'
 
     @classmethod
     def scalaUDT(cls):
@@ -106,10 +106,45 @@ def __str__(self):
         return "(%s,%s)" % (self.x, self.y)
 
     def __eq__(self, other):
-        return isinstance(other, ExamplePoint) and \
+        return isinstance(other, self.__class__) and \
             other.x == self.x and other.y == self.y
 
 
+class PythonOnlyUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for ExamplePoint.
+    """
+
+    @classmethod
+    def sqlType(self):
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        return '__main__'
+
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return PythonOnlyPoint(datum[0], datum[1])
+
+    @staticmethod
+    def foo():
+        pass
+
+    @property
+    def props(self):
+        return {}
+
+
+class PythonOnlyPoint(ExamplePoint):
+    """
+    An example class to demonstrate UDT in only Python
+    """
+    __UDT__ = PythonOnlyUDT()
+
+
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
     def test_data_type_eq(self):
@@ -395,47 +430,106 @@ def test_convert_row_to_dict(self):
         self.assertEqual(1, row.asDict()["l"][0].a)
         self.assertEqual(1.0, row.asDict()['d']['key'].c)
 
+    def test_udt(self):
+        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
+        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
+
+        def check_datatype(datatype):
+            pickled = pickle.loads(pickle.dumps(datatype))
+            assert datatype == pickled
+            scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
+            python_datatype = _parse_datatype_json_string(scala_datatype.json())
+            assert datatype == python_datatype
+
+        check_datatype(ExamplePointUDT())
+        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+                                          StructField("point", ExamplePointUDT(), False)])
+        check_datatype(structtype_with_udt)
+        p = ExamplePoint(1.0, 2.0)
+        self.assertEqual(_infer_type(p), ExamplePointUDT())
+        _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
+
+        check_datatype(PythonOnlyUDT())
+        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+                                          StructField("point", PythonOnlyUDT(), False)])
+        check_datatype(structtype_with_udt)
+        p = PythonOnlyPoint(1.0, 2.0)
+        self.assertEqual(_infer_type(p), PythonOnlyUDT())
+        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
+        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.sc.parallelize([row]).toDF()
+        df = self.sqlCtx.createDataFrame([row])
         schema = df.schema
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), ExamplePointUDT)
         df.registerTempTable("labeled_point")
         point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), PythonOnlyUDT)
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
     def test_apply_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = (1.0, ExamplePoint(1.0, 2.0))
-        rdd = self.sc.parallelize([row])
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
-        df = rdd.toDF(schema)
+        df = self.sqlCtx.createDataFrame([row], schema)
         point = df.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
+        row = (1.0, PythonOnlyPoint(1.0, 2.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", PythonOnlyUDT(), False)])
+        df = self.sqlCtx.createDataFrame([row], schema)
+        point = df.head().point
+        self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
+
     def test_udf_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.sc.parallelize([row]).toDF()
+        df = self.sqlCtx.createDataFrame([row])
         self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
         udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
         self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
 
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
+        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
     def test_parquet_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df0 = self.sc.parallelize([row]).toDF()
+        df0 = self.sqlCtx.createDataFrame([row])
         output_dir = os.path.join(self.tempdir.name, "labeled_point")
-        df0.saveAsParquetFile(output_dir)
+        df0.write.parquet(output_dir)
         df1 = self.sqlCtx.parquetFile(output_dir)
         point = df1.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df0 = self.sqlCtx.createDataFrame([row])
+        df0.write.parquet(output_dir, mode='overwrite')
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        point = df1.head().point
+        self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
+
     def test_column_operators(self):
         ci = self.df.key
         cs = self.df.value
