Skip to content

Commit 0b2ef6c

Browse files
integrate simple retries in Sea client
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 6bf5995 commit 0b2ef6c

File tree

3 files changed

+408
-93
lines changed

3 files changed

+408
-93
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 203 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,40 @@ def __init__(
130130
# Extract warehouse ID from http_path
131131
self.warehouse_id = self._extract_warehouse_id(http_path)
132132

133+
# Extract retry policy parameters
134+
retry_policy = kwargs.get("_retry_policy", None)
135+
retry_stop_after_attempts_count = kwargs.get(
136+
"_retry_stop_after_attempts_count", 30
137+
)
138+
retry_stop_after_attempts_duration = kwargs.get(
139+
"_retry_stop_after_attempts_duration", 600
140+
)
141+
retry_delay_min = kwargs.get("_retry_delay_min", 1)
142+
retry_delay_max = kwargs.get("_retry_delay_max", 60)
143+
retry_delay_default = kwargs.get("_retry_delay_default", 5)
144+
retry_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
145+
146+
# Create retry policy if not provided
147+
if not retry_policy:
148+
from databricks.sql.auth.retry import DatabricksRetryPolicy
149+
150+
retry_policy = DatabricksRetryPolicy(
151+
delay_min=retry_delay_min,
152+
delay_max=retry_delay_max,
153+
stop_after_attempts_count=retry_stop_after_attempts_count,
154+
stop_after_attempts_duration=retry_stop_after_attempts_duration,
155+
delay_default=retry_delay_default,
156+
force_dangerous_codes=retry_dangerous_codes,
157+
)
158+
133159
# Initialize ThriftHttpClient
134160
thrift_client = THttpClient(
135161
auth_provider=auth_provider,
136162
uri_or_host=f"https://{server_hostname}:{port}",
137163
path=http_path,
138164
ssl_options=ssl_options,
139165
max_connections=kwargs.get("max_connections", 1),
140-
retry_policy=kwargs.get("_retry_stop_after_attempts_count", 30),
166+
retry_policy=retry_policy, # Use the configured retry policy
141167
)
142168

143169
# Set custom headers
@@ -229,22 +255,31 @@ def open_session(
229255
schema=schema,
230256
)
231257

232-
response = self.http_client.post(
233-
path=self.SESSION_PATH, data=request_data.to_dict()
234-
)
235-
236-
session_response = CreateSessionResponse.from_dict(response)
237-
session_id = session_response.session_id
238-
if not session_id:
239-
raise ServerOperationError(
240-
"Failed to create session: No session ID returned",
241-
{
242-
"operation-id": None,
243-
"diagnostic-info": None,
244-
},
258+
try:
259+
response = self.http_client.post(
260+
path=self.SESSION_PATH, data=request_data.to_dict()
245261
)
246262

247-
return SessionId.from_sea_session_id(session_id)
263+
session_response = CreateSessionResponse.from_dict(response)
264+
session_id = session_response.session_id
265+
if not session_id:
266+
raise ServerOperationError(
267+
"Failed to create session: No session ID returned",
268+
{
269+
"operation-id": None,
270+
"diagnostic-info": None,
271+
},
272+
)
273+
274+
return SessionId.from_sea_session_id(session_id)
275+
except Exception as e:
276+
# Map exceptions to match Thrift behavior
277+
from databricks.sql.exc import RequestError, OperationalError
278+
279+
if isinstance(e, (RequestError, ServerOperationError)):
280+
raise
281+
else:
282+
raise OperationalError(f"Error opening session: {str(e)}")
248283

249284
def close_session(self, session_id: SessionId) -> None:
250285
"""
@@ -269,10 +304,25 @@ def close_session(self, session_id: SessionId) -> None:
269304
session_id=sea_session_id,
270305
)
271306

272-
self.http_client.delete(
273-
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
274-
data=request_data.to_dict(),
275-
)
307+
try:
308+
self.http_client.delete(
309+
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
310+
data=request_data.to_dict(),
311+
)
312+
except Exception as e:
313+
# Map exceptions to match Thrift behavior
314+
from databricks.sql.exc import (
315+
RequestError,
316+
OperationalError,
317+
SessionAlreadyClosedError,
318+
)
319+
320+
if isinstance(e, RequestError) and "404" in str(e):
321+
raise SessionAlreadyClosedError("Session is already closed")
322+
elif isinstance(e, (RequestError, ServerOperationError)):
323+
raise
324+
else:
325+
raise OperationalError(f"Error closing session: {str(e)}")
276326

277327
@staticmethod
278328
def get_default_session_configuration_value(name: str) -> Optional[str]:
@@ -475,48 +525,57 @@ def execute_command(
475525
result_compression=result_compression,
476526
)
477527

478-
response_data = self.http_client.post(
479-
path=self.STATEMENT_PATH, data=request.to_dict()
480-
)
481-
response = ExecuteStatementResponse.from_dict(response_data)
482-
statement_id = response.statement_id
483-
if not statement_id:
484-
raise ServerOperationError(
485-
"Failed to execute command: No statement ID returned",
486-
{
487-
"operation-id": None,
488-
"diagnostic-info": None,
489-
},
528+
try:
529+
response_data = self.http_client.post(
530+
path=self.STATEMENT_PATH, data=request.to_dict()
490531
)
532+
response = ExecuteStatementResponse.from_dict(response_data)
533+
statement_id = response.statement_id
534+
if not statement_id:
535+
raise ServerOperationError(
536+
"Failed to execute command: No statement ID returned",
537+
{
538+
"operation-id": None,
539+
"diagnostic-info": None,
540+
},
541+
)
491542

492-
command_id = CommandId.from_sea_statement_id(statement_id)
543+
command_id = CommandId.from_sea_statement_id(statement_id)
493544

494-
# Store the command ID in the cursor
495-
cursor.active_command_id = command_id
545+
# Store the command ID in the cursor
546+
cursor.active_command_id = command_id
496547

497-
# If async operation, return and let the client poll for results
498-
if async_op:
499-
return None
548+
# If async operation, return and let the client poll for results
549+
if async_op:
550+
return None
500551

501-
# For synchronous operation, wait for the statement to complete
502-
status = response.status
503-
state = status.state
552+
# For synchronous operation, wait for the statement to complete
553+
status = response.status
554+
state = status.state
504555

505-
# Keep polling until we reach a terminal state
506-
while state in [CommandState.PENDING, CommandState.RUNNING]:
507-
time.sleep(0.5) # add a small delay to avoid excessive API calls
508-
state = self.get_query_state(command_id)
556+
# Keep polling until we reach a terminal state
557+
while state in [CommandState.PENDING, CommandState.RUNNING]:
558+
time.sleep(0.5) # add a small delay to avoid excessive API calls
559+
state = self.get_query_state(command_id)
509560

510-
if state != CommandState.SUCCEEDED:
511-
raise ServerOperationError(
512-
f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
513-
{
514-
"operation-id": command_id.to_sea_statement_id(),
515-
"diagnostic-info": None,
516-
},
517-
)
561+
if state != CommandState.SUCCEEDED:
562+
raise ServerOperationError(
563+
f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
564+
{
565+
"operation-id": command_id.to_sea_statement_id(),
566+
"diagnostic-info": None,
567+
},
568+
)
518569

519-
return self.get_execution_result(command_id, cursor)
570+
return self.get_execution_result(command_id, cursor)
571+
except Exception as e:
572+
# Map exceptions to match Thrift behavior
573+
from databricks.sql.exc import RequestError, OperationalError
574+
575+
if isinstance(e, (RequestError, ServerOperationError)):
576+
raise
577+
else:
578+
raise OperationalError(f"Error executing command: {str(e)}")
520579

521580
def cancel_command(self, command_id: CommandId) -> None:
522581
"""
@@ -535,10 +594,25 @@ def cancel_command(self, command_id: CommandId) -> None:
535594
sea_statement_id = command_id.to_sea_statement_id()
536595

537596
request = CancelStatementRequest(statement_id=sea_statement_id)
538-
self.http_client.post(
539-
path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
540-
data=request.to_dict(),
541-
)
597+
try:
598+
self.http_client.post(
599+
path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
600+
data=request.to_dict(),
601+
)
602+
except Exception as e:
603+
# Map exceptions to match Thrift behavior
604+
from databricks.sql.exc import RequestError, OperationalError
605+
606+
if isinstance(e, RequestError) and "404" in str(e):
607+
# Operation was already closed, so we can ignore this
608+
logger.warning(
609+
f"Attempted to cancel a command that was already closed: {sea_statement_id}"
610+
)
611+
return
612+
elif isinstance(e, (RequestError, ServerOperationError)):
613+
raise
614+
else:
615+
raise OperationalError(f"Error canceling command: {str(e)}")
542616

543617
def close_command(self, command_id: CommandId) -> None:
544618
"""
@@ -557,10 +631,25 @@ def close_command(self, command_id: CommandId) -> None:
557631
sea_statement_id = command_id.to_sea_statement_id()
558632

559633
request = CloseStatementRequest(statement_id=sea_statement_id)
560-
self.http_client.delete(
561-
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
562-
data=request.to_dict(),
563-
)
634+
try:
635+
self.http_client.delete(
636+
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
637+
data=request.to_dict(),
638+
)
639+
except Exception as e:
640+
# Map exceptions to match Thrift behavior
641+
from databricks.sql.exc import (
642+
RequestError,
643+
OperationalError,
644+
CursorAlreadyClosedError,
645+
)
646+
647+
if isinstance(e, RequestError) and "404" in str(e):
648+
raise CursorAlreadyClosedError("Cursor is already closed")
649+
elif isinstance(e, (RequestError, ServerOperationError)):
650+
raise
651+
else:
652+
raise OperationalError(f"Error closing command: {str(e)}")
564653

565654
def get_query_state(self, command_id: CommandId) -> CommandState:
566655
"""
@@ -582,13 +671,28 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
582671
sea_statement_id = command_id.to_sea_statement_id()
583672

584673
request = GetStatementRequest(statement_id=sea_statement_id)
585-
response_data = self.http_client.get(
586-
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
587-
)
674+
try:
675+
response_data = self.http_client.get(
676+
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
677+
)
588678

589-
# Parse the response
590-
response = GetStatementResponse.from_dict(response_data)
591-
return response.status.state
679+
# Parse the response
680+
response = GetStatementResponse.from_dict(response_data)
681+
return response.status.state
682+
except Exception as e:
683+
# Map exceptions to match Thrift behavior
684+
from databricks.sql.exc import RequestError, OperationalError
685+
686+
if isinstance(e, RequestError) and "404" in str(e):
687+
# If the operation is not found, it was likely already closed
688+
logger.warning(
689+
f"Operation not found when checking state: {sea_statement_id}"
690+
)
691+
return CommandState.CANCELLED
692+
elif isinstance(e, (RequestError, ServerOperationError)):
693+
raise
694+
else:
695+
raise OperationalError(f"Error getting query state: {str(e)}")
592696

593697
def get_execution_result(
594698
self,
@@ -617,30 +721,39 @@ def get_execution_result(
617721
# Create the request model
618722
request = GetStatementRequest(statement_id=sea_statement_id)
619723

620-
# Get the statement result
621-
response_data = self.http_client.get(
622-
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
623-
)
724+
try:
725+
# Get the statement result
726+
response_data = self.http_client.get(
727+
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
728+
)
624729

625-
# Create and return a SeaResultSet
626-
from databricks.sql.result_set import SeaResultSet
627-
628-
# Convert the response to an ExecuteResponse and extract result data
629-
(
630-
execute_response,
631-
result_data,
632-
manifest,
633-
) = self._results_message_to_execute_response(response_data, command_id)
634-
635-
return SeaResultSet(
636-
connection=cursor.connection,
637-
execute_response=execute_response,
638-
sea_client=self,
639-
buffer_size_bytes=cursor.buffer_size_bytes,
640-
arraysize=cursor.arraysize,
641-
result_data=result_data,
642-
manifest=manifest,
643-
)
730+
# Create and return a SeaResultSet
731+
from databricks.sql.result_set import SeaResultSet
732+
733+
# Convert the response to an ExecuteResponse and extract result data
734+
(
735+
execute_response,
736+
result_data,
737+
manifest,
738+
) = self._results_message_to_execute_response(response_data, command_id)
739+
740+
return SeaResultSet(
741+
connection=cursor.connection,
742+
execute_response=execute_response,
743+
sea_client=self,
744+
buffer_size_bytes=cursor.buffer_size_bytes,
745+
arraysize=cursor.arraysize,
746+
result_data=result_data,
747+
manifest=manifest,
748+
)
749+
except Exception as e:
750+
# Map exceptions to match Thrift behavior
751+
from databricks.sql.exc import RequestError, OperationalError
752+
753+
if isinstance(e, (RequestError, ServerOperationError)):
754+
raise
755+
else:
756+
raise OperationalError(f"Error getting execution result: {str(e)}")
644757

645758
# == Metadata Operations ==
646759

0 commit comments

Comments
 (0)