Skip to content

Commit f3cdf24

Browse files
Adding and modifying tests
Signed-off-by: Mohit Singla <[email protected]>
1 parent 8cdf57d commit f3cdf24

File tree

3 files changed

+66
-23
lines changed

3 files changed

+66
-23
lines changed

tests/e2e/common/large_queries_mixin.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,13 @@ def test_query_with_large_wide_result_set(self):
4949
# This is used by PyHive tests to determine the buffer size
5050
self.arraysize = 1000
5151
with self.cursor() as cursor:
52-
uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
53-
cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
54-
for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
55-
self.assertEqual(row[0], row_id) # Verify no rows are dropped in the middle.
56-
self.assertEqual(len(row[1]), 36)
52+
for lz4_compression in [False, True]:
53+
cursor.setLZ4Compression(lz4_compression)
54+
uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
55+
cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
56+
for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
57+
self.assertEqual(row[0], row_id) # Verify no rows are dropped in the middle.
58+
self.assertEqual(len(row[1]), 36)
5759

5860
def test_query_with_large_narrow_result_set(self):
5961
resultSize = 300 * 1000 * 1000 # 300 MB
@@ -65,9 +67,11 @@ def test_query_with_large_narrow_result_set(self):
6567
# This is used by PyHive tests to determine the buffer size
6668
self.arraysize = 10000000
6769
with self.cursor() as cursor:
68-
cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows))
69-
for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
70-
self.assertEqual(row[0], row_id)
70+
for lz4_compression in [False, True]:
71+
cursor.setLZ4Compression(lz4_compression)
72+
cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows))
73+
for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
74+
self.assertEqual(row[0], row_id)
7175

7276
def test_long_running_query(self):
7377
""" Incrementally increase query size until it takes at least 5 minutes,
@@ -80,21 +84,23 @@ def test_long_running_query(self):
8084
scale0 = 10000
8185
scale_factor = 1
8286
with self.cursor() as cursor:
83-
while duration < min_duration:
84-
self.assertLess(scale_factor, 512, msg="Detected infinite loop")
85-
start = time.time()
87+
for lz4_compression in [False, True]:
88+
cursor.setLZ4Compression(lz4_compression)
89+
while duration < min_duration:
90+
self.assertLess(scale_factor, 512, msg="Detected infinite loop")
91+
start = time.time()
8692

87-
cursor.execute("""SELECT count(*)
88-
FROM RANGE({scale}) x
89-
JOIN RANGE({scale0}) y
90-
ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
91-
""".format(scale=scale_factor * scale0, scale0=scale0))
93+
cursor.execute("""SELECT count(*)
94+
FROM RANGE({scale}) x
95+
JOIN RANGE({scale0}) y
96+
ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
97+
""".format(scale=scale_factor * scale0, scale0=scale0))
9298

93-
n, = cursor.fetchone()
94-
self.assertEqual(n, 0)
99+
n, = cursor.fetchone()
100+
self.assertEqual(n, 0)
95101

96-
duration = time.time() - start
97-
current_fraction = duration / min_duration
98-
print('Took {} s with scale factor={}'.format(duration, scale_factor))
99-
# Extrapolate linearly to reach 5 min and add 50% padding to push over the limit
100-
scale_factor = math.ceil(1.5 * scale_factor / current_fraction)
102+
duration = time.time() - start
103+
current_fraction = duration / min_duration
104+
print('Took {} s with scale factor={}'.format(duration, scale_factor))
105+
# Extrapolate linearly to reach 5 min and add 50% padding to push over the limit
106+
scale_factor = math.ceil(1.5 * scale_factor / current_fraction)

tests/e2e/driver_tests.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,20 @@ def test_timezone_with_timestamp(self):
510510
self.assertEqual(arrow_result_table.field(0).type, ts_type)
511511
self.assertEqual(arrow_result_value, expected.timestamp() * 1000000)
512512

513+
@skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
514+
def test_can_flip_compression(self):
515+
with self.cursor() as cursor:
516+
cursor.execute("SELECT array(1,2,3,4)")
517+
cursor.fetchall()
518+
lz4_compressed = cursor.active_result_set.lz4_compressed
519+
# The endpoint should support compression
520+
self.assertEqual(lz4_compressed, True)
521+
cursor.setLZ4Compression(False)
522+
cursor.execute("SELECT array(1,2,3,4)")
523+
cursor.fetchall()
524+
lz4_compressed = cursor.active_result_set.lz4_compressed
525+
self.assertEqual(lz4_compressed, False)
526+
513527
def _should_have_native_complex_types(self):
514528
return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments)
515529

tests/unit/test_thrift_backend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,29 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self):
309309
thrift_backend._handle_execute_response(t_execute_resp, Mock())
310310
self.assertIn("some information about the error", str(cm.exception))
311311

312+
def test_handle_execute_response_sets_compression_in_direct_results(self):
313+
for resp_type in self.execute_response_types:
314+
lz4Compressed=Mock()
315+
resultSet=MagicMock()
316+
resultSet.results.startRowOffset = 0
317+
t_execute_resp = resp_type(
318+
status=Mock(),
319+
operationHandle=Mock(),
320+
directResults=ttypes.TSparkDirectResults(
321+
operationStatus= Mock(),
322+
resultSetMetadata=ttypes.TGetResultSetMetadataResp(
323+
status=self.okay_status,
324+
resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET,
325+
schema=MagicMock(),
326+
arrowSchema=MagicMock(),
327+
lz4Compressed=lz4Compressed),
328+
resultSet=resultSet,
329+
closeOperation=None))
330+
thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider())
331+
332+
execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock())
333+
self.assertEqual(execute_response.lz4_compressed, lz4Compressed)
334+
312335
@patch("databricks.sql.thrift_backend.TCLIService.Client")
313336
def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_service_class):
314337
tcli_service_instance = tcli_service_class.return_value

0 commit comments

Comments
 (0)