Skip to content

Commit 2f6cc7a

Browse files
authored
Optional coerce (#367)
* Remove extra line * Add optional configuration to enable optional coercion * Fix circular dependency between conf and utils * Add gcc installation for dev build * Fix parsing bug for coerce value * Fix parsing bug for coerce value (2)
1 parent f94809a commit 2f6cc7a

File tree

14 files changed

+233
-69
lines changed

14 files changed

+233
-69
lines changed

Dockerfile.jupyter

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ FROM jupyter/base-notebook:d0b2d159cc6c
22

33
ARG dev_mode=false
44

5+
USER root
6+
7+
# This is needed because requests-kerberos fails to install on debian due to missing linux headers
8+
RUN conda install requests-kerberos -y
9+
510
USER $NB_USER
611

712
# Install sparkmagic - if DEV_MODE is set, use the one in the host directory.

sparkmagic/example_config.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
},
5858

5959
"use_auto_viz": true,
60+
"coerce_dataframe": true,
6061
"max_results_sql": 2500,
6162
"pyspark_dataframe_encoding": "utf-8",
6263

sparkmagic/sparkmagic/kernels/kernelmagics.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from hdijupyterutils.utils import generate_uuid
1313

1414
import sparkmagic.utils.configuration as conf
15+
from sparkmagic.utils.configuration import get_livy_kind
1516
from sparkmagic.utils import constants
16-
from sparkmagic.utils.utils import get_livy_kind, parse_argstring_or_throw
17+
from sparkmagic.utils.utils import parse_argstring_or_throw, get_coerce_value
1718
from sparkmagic.utils.sparkevents import SparkEvents
1819
from sparkmagic.utils.constants import LANGS_SUPPORTED
1920
from sparkmagic.livyclientlib.command import Command
@@ -211,12 +212,17 @@ def configure(self, line, cell="", local_ns=None):
211212
@argument("-n", "--maxrows", type=int, default=None, help="Maximum number of rows that will be pulled back "
212213
"from the dataframe on the server for storing")
213214
@argument("-r", "--samplefraction", type=float, default=None, help="Sample fraction for sampling from dataframe")
215+
@argument("-c", "--coerce", type=str, default=None, help="Whether to automatically coerce the types (default, pass True if being explicit) "
216+
"of the dataframe or not (pass False)")
214217
@wrap_unexpected_exceptions
215218
@handle_expected_exceptions
216219
def spark(self, line, cell="", local_ns=None):
217220
if self._do_not_call_start_session(u""):
218221
args = parse_argstring_or_throw(self.spark, line)
219-
self.execute_spark(cell, args.output, args.samplemethod, args.maxrows, args.samplefraction, None)
222+
223+
coerce = get_coerce_value(args.coerce)
224+
225+
self.execute_spark(cell, args.output, args.samplemethod, args.maxrows, args.samplefraction, None, coerce)
220226
else:
221227
return
222228

@@ -230,13 +236,18 @@ def spark(self, line, cell="", local_ns=None):
230236
@argument("-n", "--maxrows", type=int, default=None, help="Maximum number of rows that will be pulled back "
231237
"from the server for SQL queries")
232238
@argument("-r", "--samplefraction", type=float, default=None, help="Sample fraction for sampling from SQL queries")
239+
@argument("-c", "--coerce", type=str, default=None, help="Whether to automatically coerce the types (default, pass True if being explicit) "
240+
"of the dataframe or not (pass False)")
233241
@wrap_unexpected_exceptions
234242
@handle_expected_exceptions
235243
def sql(self, line, cell="", local_ns=None):
236244
if self._do_not_call_start_session(""):
237245
args = parse_argstring_or_throw(self.sql, line)
246+
247+
coerce = get_coerce_value(args.coerce)
248+
238249
return self.execute_sqlquery(cell, args.samplemethod, args.maxrows, args.samplefraction,
239-
None, args.output, args.quiet)
250+
None, args.output, args.quiet, coerce)
240251
else:
241252
return
242253

sparkmagic/sparkmagic/livyclientlib/exceptions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class BadUserConfigurationException(LivyClientLibException):
4646
class BadUserDataException(LivyClientLibException):
4747
"""An exception that is thrown when data provided by the user is invalid
4848
in some way."""
49+
4950

5051
class SqlContextNotFoundException(LivyClientLibException):
5152
"""Exception that is thrown when the SQL context is not found."""

sparkmagic/sparkmagic/livyclientlib/sparkstorecommand.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sparkmagic.utils.utils import coerce_pandas_df_to_numeric_datetime, records_to_dataframe
1+
from sparkmagic.utils.utils import records_to_dataframe
22
import sparkmagic.utils.configuration as conf
33
import sparkmagic.utils.constants as constants
44
from sparkmagic.utils.sparkevents import SparkEvents
@@ -8,7 +8,7 @@
88
import ast
99

1010
class SparkStoreCommand(Command):
11-
def __init__(self, output_var, samplemethod=None, maxrows=None, samplefraction=None, spark_events=None):
11+
def __init__(self, output_var, samplemethod=None, maxrows=None, samplefraction=None, spark_events=None, coerce=None):
1212
super(SparkStoreCommand, self).__init__("", spark_events)
1313

1414
if samplemethod is None:
@@ -32,6 +32,7 @@ def __init__(self, output_var, samplemethod=None, maxrows=None, samplefraction=N
3232
if spark_events is None:
3333
spark_events = SparkEvents()
3434
self._spark_events = spark_events
35+
self._coerce = coerce
3536

3637

3738
def execute(self, session):
@@ -40,7 +41,7 @@ def execute(self, session):
4041
(success, records_text) = command.execute(session)
4142
if not success:
4243
raise BadUserDataException(records_text)
43-
result = records_to_dataframe(records_text, session.kind)
44+
result = records_to_dataframe(records_text, session.kind, self._coerce)
4445
except Exception as e:
4546
raise
4647
else:
@@ -114,7 +115,8 @@ def __eq__(self, other):
114115
self.samplemethod == other.samplemethod and \
115116
self.maxrows == other.maxrows and \
116117
self.samplefraction == other.samplefraction and \
117-
self.output_var == other.output_var
118+
self.output_var == other.output_var and \
119+
self._coerce == other._coerce
118120

119121
def __ne__(self, other):
120122
return not (self == other)

sparkmagic/sparkmagic/livyclientlib/sqlquery.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
class SQLQuery(ObjectWithGuid):
12-
def __init__(self, query, samplemethod=None, maxrows=None, samplefraction=None, spark_events=None):
12+
def __init__(self, query, samplemethod=None, maxrows=None, samplefraction=None, spark_events=None, coerce=None):
1313
super(SQLQuery, self).__init__()
1414

1515
if samplemethod is None:
@@ -33,6 +33,7 @@ def __init__(self, query, samplemethod=None, maxrows=None, samplefraction=None,
3333
if spark_events is None:
3434
spark_events = SparkEvents()
3535
self._spark_events = spark_events
36+
self._coerce = coerce
3637

3738
def to_command(self, kind, sql_context_variable_name):
3839
if kind == constants.SESSION_KIND_PYSPARK:
@@ -56,7 +57,7 @@ def execute(self, session):
5657
(success, records_text) = command.execute(session)
5758
if not success:
5859
raise BadUserDataException(records_text)
59-
result = records_to_dataframe(records_text, session.kind)
60+
result = records_to_dataframe(records_text, session.kind, self._coerce)
6061
except Exception as e:
6162
self._spark_events.emit_sql_execution_end_event(session.guid, session.kind, session.id, self.guid,
6263
command_guid, False, e.__class__.__name__, str(e))
@@ -117,7 +118,8 @@ def __eq__(self, other):
117118
return self.query == other.query and \
118119
self.samplemethod == other.samplemethod and \
119120
self.maxrows == other.maxrows and \
120-
self.samplefraction == other.samplefraction
121+
self.samplefraction == other.samplefraction and \
122+
self._coerce == other._coerce
121123

122124
def __ne__(self, other):
123125
return not (self == other)

sparkmagic/sparkmagic/magics/remotesparkmagics.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from hdijupyterutils.ipywidgetfactory import IpyWidgetFactory
1313

1414
import sparkmagic.utils.configuration as conf
15-
from sparkmagic.utils.utils import parse_argstring_or_throw
15+
from sparkmagic.utils.utils import parse_argstring_or_throw, get_coerce_value
1616
from sparkmagic.utils.constants import CONTEXT_NAME_SPARK, CONTEXT_NAME_SQL, LANG_PYTHON, LANG_R, LANG_SCALA
1717
from sparkmagic.controllerwidget.magicscontrollerwidget import MagicsControllerWidget
1818
from sparkmagic.livyclientlib.command import Command
@@ -60,6 +60,8 @@ def manage_spark(self, line, local_ns=None):
6060
@argument("command", type=str, default=[""], nargs="*", help="Commands to execute.")
6161
@argument("-k", "--skip", type=bool, default=False, nargs="?", const=True, help="Skip adding session if it already exists")
6262
@argument("-i", "--id", type=int, default=None, help="Session ID")
63+
@argument("-e", "--coerce", type=str, default=None, help="Whether to automatically coerce the types (default, pass True if being explicit) "
64+
"of the dataframe or not (pass False)")
6365
@needs_local_scope
6466
@line_cell_magic
6567
@handle_expected_exceptions
@@ -160,12 +162,13 @@ def spark(self, line, cell="", local_ns=None):
160162
self.ipython_display.write(self.spark_controller.get_logs(args.session))
161163
# run
162164
elif len(subcommand) == 0:
165+
coerce = get_coerce_value(args.coerce)
163166
if args.context == CONTEXT_NAME_SPARK:
164167
return self.execute_spark(cell, args.output, args.samplemethod,
165-
args.maxrows, args.samplefraction, args.session)
168+
args.maxrows, args.samplefraction, args.session, coerce)
166169
elif args.context == CONTEXT_NAME_SQL:
167170
return self.execute_sqlquery(cell, args.samplemethod, args.maxrows, args.samplefraction,
168-
args.session, args.output, args.quiet)
171+
args.session, args.output, args.quiet, coerce)
169172
else:
170173
self.ipython_display.send_error("Context '{}' not found".format(args.context))
171174
# error

sparkmagic/sparkmagic/magics/sparkmagicsbase.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,24 +37,24 @@ def __init__(self, shell, data=None, spark_events=None):
3737
spark_events = SparkEvents()
3838
spark_events.emit_library_loaded_event()
3939

40-
def execute_spark(self, cell, output_var, samplemethod, maxrows, samplefraction, session_name):
40+
def execute_spark(self, cell, output_var, samplemethod, maxrows, samplefraction, session_name, coerce):
4141
(success, out) = self.spark_controller.run_command(Command(cell), session_name)
4242
if not success:
4343
self.ipython_display.send_error(out)
4444
else:
4545
self.ipython_display.write(out)
4646
if output_var is not None:
47-
spark_store_command = self._spark_store_command(output_var, samplemethod, maxrows, samplefraction)
47+
spark_store_command = self._spark_store_command(output_var, samplemethod, maxrows, samplefraction, coerce)
4848
df = self.spark_controller.run_command(spark_store_command, session_name)
4949
self.shell.user_ns[output_var] = df
5050

5151
@staticmethod
52-
def _spark_store_command(output_var, samplemethod, maxrows, samplefraction):
53-
return SparkStoreCommand(output_var, samplemethod, maxrows, samplefraction)
52+
def _spark_store_command(output_var, samplemethod, maxrows, samplefraction, coerce):
53+
return SparkStoreCommand(output_var, samplemethod, maxrows, samplefraction, coerce=coerce)
5454

5555
def execute_sqlquery(self, cell, samplemethod, maxrows, samplefraction,
56-
session, output_var, quiet):
57-
sqlquery = self._sqlquery(cell, samplemethod, maxrows, samplefraction)
56+
session, output_var, quiet, coerce):
57+
sqlquery = self._sqlquery(cell, samplemethod, maxrows, samplefraction, coerce)
5858
df = self.spark_controller.run_sqlquery(sqlquery, session)
5959
if output_var is not None:
6060
self.shell.user_ns[output_var] = df
@@ -64,8 +64,8 @@ def execute_sqlquery(self, cell, samplemethod, maxrows, samplefraction,
6464
return df
6565

6666
@staticmethod
67-
def _sqlquery(cell, samplemethod, maxrows, samplefraction):
68-
return SQLQuery(cell, samplemethod, maxrows, samplefraction)
67+
def _sqlquery(cell, samplemethod, maxrows, samplefraction, coerce):
68+
return SQLQuery(cell, samplemethod, maxrows, samplefraction, coerce=coerce)
6969

7070
def _print_endpoint_info(self, info_sessions, current_session_id):
7171
if info_sessions:

sparkmagic/sparkmagic/tests/test_kernel_magics.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def test_spark_unexpected_exception_in_storing():
515515
spark_controller.run_command = MagicMock(side_effect=side_effect)
516516

517517
magic.spark(line, cell)
518-
assert spark_controller.run_command.call_count == 2
518+
assert_equals(spark_controller.run_command.call_count, 2)
519519
spark_controller.run_command.assert_any_call(Command(cell), None)
520520
ipython_display.send_error.assert_called_with(constants.INTERNAL_ERROR_MSG
521521
.format(side_effect[1]))
@@ -535,15 +535,24 @@ def test_spark_expected_exception_in_storing():
535535
.format(side_effect[1]))
536536

537537

538-
539538
@with_setup(_setup, _teardown)
540539
def test_spark_sample_options():
541-
line = "-o var_name -m sample -n 142 -r 0.3"
540+
line = "-o var_name -m sample -n 142 -r 0.3 -c True"
541+
cell = ""
542+
magic.execute_spark = MagicMock()
543+
ret = magic.spark(line, cell)
544+
545+
magic.execute_spark.assert_called_once_with(cell, "var_name", "sample", 142, 0.3, None, True)
546+
547+
548+
@with_setup(_setup, _teardown)
549+
def test_spark_false_coerce():
550+
line = "-o var_name -m sample -n 142 -r 0.3 -c False"
542551
cell = ""
543552
magic.execute_spark = MagicMock()
544553
ret = magic.spark(line, cell)
545554

546-
magic.execute_spark.assert_called_once_with(cell, "var_name", "sample", 142, 0.3, None)
555+
magic.execute_spark.assert_called_once_with(cell, "var_name", "sample", 142, 0.3, None, False)
547556

548557

549558
@with_setup(_setup, _teardown)
@@ -556,7 +565,7 @@ def test_sql_without_output():
556565

557566
spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
558567
{"kind": constants.SESSION_KIND_PYSPARK})
559-
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, None, False)
568+
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, None, False, None)
560569

561570

562571
@with_setup(_setup, _teardown)
@@ -569,7 +578,7 @@ def test_sql_with_output():
569578

570579
spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
571580
{"kind": constants.SESSION_KIND_PYSPARK})
572-
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False)
581+
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False, None)
573582

574583

575584
@with_setup(_setup, _teardown)
@@ -579,7 +588,7 @@ def test_sql_exception():
579588
magic.execute_sqlquery = MagicMock(side_effect=ValueError('HAHAHAHAH'))
580589

581590
magic.sql(line, cell)
582-
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False)
591+
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False, None)
583592
ipython_display.send_error.assert_called_once_with(constants.INTERNAL_ERROR_MSG
584593
.format(magic.execute_sqlquery.side_effect))
585594

@@ -591,7 +600,7 @@ def test_sql_expected_exception():
591600
magic.execute_sqlquery = MagicMock(side_effect=HttpClientException('HAHAHAHAH'))
592601

593602
magic.sql(line, cell)
594-
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False)
603+
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "my_var", False, None)
595604
ipython_display.send_error.assert_called_once_with(constants.EXPECTED_ERROR_MSG
596605
.format(magic.execute_sqlquery.side_effect))
597606

@@ -619,20 +628,33 @@ def test_sql_quiet():
619628

620629
spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
621630
{"kind": constants.SESSION_KIND_PYSPARK})
622-
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "Output", True)
631+
magic.execute_sqlquery.assert_called_once_with(cell, None, None, None, None, "Output", True, None)
623632

624633

625634
@with_setup(_setup, _teardown)
626635
def test_sql_sample_options():
627-
line = "-q -m sample -n 142 -r 0.3"
636+
line = "-q -m sample -n 142 -r 0.3 -c True"
637+
cell = ""
638+
magic.execute_sqlquery = MagicMock()
639+
640+
ret = magic.sql(line, cell)
641+
642+
spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
643+
{"kind": constants.SESSION_KIND_PYSPARK})
644+
magic.execute_sqlquery.assert_called_once_with(cell, "sample", 142, 0.3, None, None, True, True)
645+
646+
647+
@with_setup(_setup, _teardown)
648+
def test_sql_false_coerce():
649+
line = "-q -m sample -n 142 -r 0.3 -c False"
628650
cell = ""
629651
magic.execute_sqlquery = MagicMock()
630652

631653
ret = magic.sql(line, cell)
632654

633655
spark_controller.add_session.assert_called_once_with(magic.session_name, magic.endpoint, False,
634656
{"kind": constants.SESSION_KIND_PYSPARK})
635-
magic.execute_sqlquery.assert_called_once_with(cell, "sample", 142, 0.3, None, None, True)
657+
magic.execute_sqlquery.assert_called_once_with(cell, "sample", 142, 0.3, None, None, True, False)
636658

637659

638660
@with_setup(_setup, _teardown)

0 commit comments

Comments (0)