@@ -130,14 +130,40 @@ def __init__(
130
130
# Extract warehouse ID from http_path
131
131
self .warehouse_id = self ._extract_warehouse_id (http_path )
132
132
133
+ # Extract retry policy parameters
134
+ retry_policy = kwargs .get ("_retry_policy" , None )
135
+ retry_stop_after_attempts_count = kwargs .get (
136
+ "_retry_stop_after_attempts_count" , 30
137
+ )
138
+ retry_stop_after_attempts_duration = kwargs .get (
139
+ "_retry_stop_after_attempts_duration" , 600
140
+ )
141
+ retry_delay_min = kwargs .get ("_retry_delay_min" , 1 )
142
+ retry_delay_max = kwargs .get ("_retry_delay_max" , 60 )
143
+ retry_delay_default = kwargs .get ("_retry_delay_default" , 5 )
144
+ retry_dangerous_codes = kwargs .get ("_retry_dangerous_codes" , [])
145
+
146
+ # Create retry policy if not provided
147
+ if not retry_policy :
148
+ from databricks .sql .auth .retry import DatabricksRetryPolicy
149
+
150
+ retry_policy = DatabricksRetryPolicy (
151
+ delay_min = retry_delay_min ,
152
+ delay_max = retry_delay_max ,
153
+ stop_after_attempts_count = retry_stop_after_attempts_count ,
154
+ stop_after_attempts_duration = retry_stop_after_attempts_duration ,
155
+ delay_default = retry_delay_default ,
156
+ force_dangerous_codes = retry_dangerous_codes ,
157
+ )
158
+
133
159
# Initialize ThriftHttpClient
134
160
thrift_client = THttpClient (
135
161
auth_provider = auth_provider ,
136
162
uri_or_host = f"https://{ server_hostname } :{ port } " ,
137
163
path = http_path ,
138
164
ssl_options = ssl_options ,
139
165
max_connections = kwargs .get ("max_connections" , 1 ),
140
- retry_policy = kwargs . get ( "_retry_stop_after_attempts_count" , 30 ),
166
+ retry_policy = retry_policy , # Use the configured retry policy
141
167
)
142
168
143
169
# Set custom headers
@@ -229,22 +255,31 @@ def open_session(
229
255
schema = schema ,
230
256
)
231
257
232
- response = self .http_client .post (
233
- path = self .SESSION_PATH , data = request_data .to_dict ()
234
- )
235
-
236
- session_response = CreateSessionResponse .from_dict (response )
237
- session_id = session_response .session_id
238
- if not session_id :
239
- raise ServerOperationError (
240
- "Failed to create session: No session ID returned" ,
241
- {
242
- "operation-id" : None ,
243
- "diagnostic-info" : None ,
244
- },
258
+ try :
259
+ response = self .http_client .post (
260
+ path = self .SESSION_PATH , data = request_data .to_dict ()
245
261
)
246
262
247
- return SessionId .from_sea_session_id (session_id )
263
+ session_response = CreateSessionResponse .from_dict (response )
264
+ session_id = session_response .session_id
265
+ if not session_id :
266
+ raise ServerOperationError (
267
+ "Failed to create session: No session ID returned" ,
268
+ {
269
+ "operation-id" : None ,
270
+ "diagnostic-info" : None ,
271
+ },
272
+ )
273
+
274
+ return SessionId .from_sea_session_id (session_id )
275
+ except Exception as e :
276
+ # Map exceptions to match Thrift behavior
277
+ from databricks .sql .exc import RequestError , OperationalError
278
+
279
+ if isinstance (e , (RequestError , ServerOperationError )):
280
+ raise
281
+ else :
282
+ raise OperationalError (f"Error opening session: { str (e )} " )
248
283
249
284
def close_session (self , session_id : SessionId ) -> None :
250
285
"""
@@ -269,10 +304,25 @@ def close_session(self, session_id: SessionId) -> None:
269
304
session_id = sea_session_id ,
270
305
)
271
306
272
- self .http_client .delete (
273
- path = self .SESSION_PATH_WITH_ID .format (sea_session_id ),
274
- data = request_data .to_dict (),
275
- )
307
+ try :
308
+ self .http_client .delete (
309
+ path = self .SESSION_PATH_WITH_ID .format (sea_session_id ),
310
+ data = request_data .to_dict (),
311
+ )
312
+ except Exception as e :
313
+ # Map exceptions to match Thrift behavior
314
+ from databricks .sql .exc import (
315
+ RequestError ,
316
+ OperationalError ,
317
+ SessionAlreadyClosedError ,
318
+ )
319
+
320
+ if isinstance (e , RequestError ) and "404" in str (e ):
321
+ raise SessionAlreadyClosedError ("Session is already closed" )
322
+ elif isinstance (e , (RequestError , ServerOperationError )):
323
+ raise
324
+ else :
325
+ raise OperationalError (f"Error closing session: { str (e )} " )
276
326
277
327
@staticmethod
278
328
def get_default_session_configuration_value (name : str ) -> Optional [str ]:
@@ -475,48 +525,57 @@ def execute_command(
475
525
result_compression = result_compression ,
476
526
)
477
527
478
- response_data = self .http_client .post (
479
- path = self .STATEMENT_PATH , data = request .to_dict ()
480
- )
481
- response = ExecuteStatementResponse .from_dict (response_data )
482
- statement_id = response .statement_id
483
- if not statement_id :
484
- raise ServerOperationError (
485
- "Failed to execute command: No statement ID returned" ,
486
- {
487
- "operation-id" : None ,
488
- "diagnostic-info" : None ,
489
- },
528
+ try :
529
+ response_data = self .http_client .post (
530
+ path = self .STATEMENT_PATH , data = request .to_dict ()
490
531
)
532
+ response = ExecuteStatementResponse .from_dict (response_data )
533
+ statement_id = response .statement_id
534
+ if not statement_id :
535
+ raise ServerOperationError (
536
+ "Failed to execute command: No statement ID returned" ,
537
+ {
538
+ "operation-id" : None ,
539
+ "diagnostic-info" : None ,
540
+ },
541
+ )
491
542
492
- command_id = CommandId .from_sea_statement_id (statement_id )
543
+ command_id = CommandId .from_sea_statement_id (statement_id )
493
544
494
- # Store the command ID in the cursor
495
- cursor .active_command_id = command_id
545
+ # Store the command ID in the cursor
546
+ cursor .active_command_id = command_id
496
547
497
- # If async operation, return and let the client poll for results
498
- if async_op :
499
- return None
548
+ # If async operation, return and let the client poll for results
549
+ if async_op :
550
+ return None
500
551
501
- # For synchronous operation, wait for the statement to complete
502
- status = response .status
503
- state = status .state
552
+ # For synchronous operation, wait for the statement to complete
553
+ status = response .status
554
+ state = status .state
504
555
505
- # Keep polling until we reach a terminal state
506
- while state in [CommandState .PENDING , CommandState .RUNNING ]:
507
- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
508
- state = self .get_query_state (command_id )
556
+ # Keep polling until we reach a terminal state
557
+ while state in [CommandState .PENDING , CommandState .RUNNING ]:
558
+ time .sleep (0.5 ) # add a small delay to avoid excessive API calls
559
+ state = self .get_query_state (command_id )
509
560
510
- if state != CommandState .SUCCEEDED :
511
- raise ServerOperationError (
512
- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
513
- {
514
- "operation-id" : command_id .to_sea_statement_id (),
515
- "diagnostic-info" : None ,
516
- },
517
- )
561
+ if state != CommandState .SUCCEEDED :
562
+ raise ServerOperationError (
563
+ f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
564
+ {
565
+ "operation-id" : command_id .to_sea_statement_id (),
566
+ "diagnostic-info" : None ,
567
+ },
568
+ )
518
569
519
- return self .get_execution_result (command_id , cursor )
570
+ return self .get_execution_result (command_id , cursor )
571
+ except Exception as e :
572
+ # Map exceptions to match Thrift behavior
573
+ from databricks .sql .exc import RequestError , OperationalError
574
+
575
+ if isinstance (e , (RequestError , ServerOperationError )):
576
+ raise
577
+ else :
578
+ raise OperationalError (f"Error executing command: { str (e )} " )
520
579
521
580
def cancel_command (self , command_id : CommandId ) -> None :
522
581
"""
@@ -535,10 +594,25 @@ def cancel_command(self, command_id: CommandId) -> None:
535
594
sea_statement_id = command_id .to_sea_statement_id ()
536
595
537
596
request = CancelStatementRequest (statement_id = sea_statement_id )
538
- self .http_client .post (
539
- path = self .CANCEL_STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
540
- data = request .to_dict (),
541
- )
597
+ try :
598
+ self .http_client .post (
599
+ path = self .CANCEL_STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
600
+ data = request .to_dict (),
601
+ )
602
+ except Exception as e :
603
+ # Map exceptions to match Thrift behavior
604
+ from databricks .sql .exc import RequestError , OperationalError
605
+
606
+ if isinstance (e , RequestError ) and "404" in str (e ):
607
+ # Operation was already closed, so we can ignore this
608
+ logger .warning (
609
+ f"Attempted to cancel a command that was already closed: { sea_statement_id } "
610
+ )
611
+ return
612
+ elif isinstance (e , (RequestError , ServerOperationError )):
613
+ raise
614
+ else :
615
+ raise OperationalError (f"Error canceling command: { str (e )} " )
542
616
543
617
def close_command (self , command_id : CommandId ) -> None :
544
618
"""
@@ -557,10 +631,25 @@ def close_command(self, command_id: CommandId) -> None:
557
631
sea_statement_id = command_id .to_sea_statement_id ()
558
632
559
633
request = CloseStatementRequest (statement_id = sea_statement_id )
560
- self .http_client .delete (
561
- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
562
- data = request .to_dict (),
563
- )
634
+ try :
635
+ self .http_client .delete (
636
+ path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
637
+ data = request .to_dict (),
638
+ )
639
+ except Exception as e :
640
+ # Map exceptions to match Thrift behavior
641
+ from databricks .sql .exc import (
642
+ RequestError ,
643
+ OperationalError ,
644
+ CursorAlreadyClosedError ,
645
+ )
646
+
647
+ if isinstance (e , RequestError ) and "404" in str (e ):
648
+ raise CursorAlreadyClosedError ("Cursor is already closed" )
649
+ elif isinstance (e , (RequestError , ServerOperationError )):
650
+ raise
651
+ else :
652
+ raise OperationalError (f"Error closing command: { str (e )} " )
564
653
565
654
def get_query_state (self , command_id : CommandId ) -> CommandState :
566
655
"""
@@ -582,13 +671,28 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
582
671
sea_statement_id = command_id .to_sea_statement_id ()
583
672
584
673
request = GetStatementRequest (statement_id = sea_statement_id )
585
- response_data = self .http_client .get (
586
- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
587
- )
674
+ try :
675
+ response_data = self .http_client .get (
676
+ path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
677
+ )
588
678
589
- # Parse the response
590
- response = GetStatementResponse .from_dict (response_data )
591
- return response .status .state
679
+ # Parse the response
680
+ response = GetStatementResponse .from_dict (response_data )
681
+ return response .status .state
682
+ except Exception as e :
683
+ # Map exceptions to match Thrift behavior
684
+ from databricks .sql .exc import RequestError , OperationalError
685
+
686
+ if isinstance (e , RequestError ) and "404" in str (e ):
687
+ # If the operation is not found, it was likely already closed
688
+ logger .warning (
689
+ f"Operation not found when checking state: { sea_statement_id } "
690
+ )
691
+ return CommandState .CANCELLED
692
+ elif isinstance (e , (RequestError , ServerOperationError )):
693
+ raise
694
+ else :
695
+ raise OperationalError (f"Error getting query state: { str (e )} " )
592
696
593
697
def get_execution_result (
594
698
self ,
@@ -617,30 +721,39 @@ def get_execution_result(
617
721
# Create the request model
618
722
request = GetStatementRequest (statement_id = sea_statement_id )
619
723
620
- # Get the statement result
621
- response_data = self .http_client .get (
622
- path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
623
- )
724
+ try :
725
+ # Get the statement result
726
+ response_data = self .http_client .get (
727
+ path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
728
+ )
624
729
625
- # Create and return a SeaResultSet
626
- from databricks .sql .result_set import SeaResultSet
627
-
628
- # Convert the response to an ExecuteResponse and extract result data
629
- (
630
- execute_response ,
631
- result_data ,
632
- manifest ,
633
- ) = self ._results_message_to_execute_response (response_data , command_id )
634
-
635
- return SeaResultSet (
636
- connection = cursor .connection ,
637
- execute_response = execute_response ,
638
- sea_client = self ,
639
- buffer_size_bytes = cursor .buffer_size_bytes ,
640
- arraysize = cursor .arraysize ,
641
- result_data = result_data ,
642
- manifest = manifest ,
643
- )
730
+ # Create and return a SeaResultSet
731
+ from databricks .sql .result_set import SeaResultSet
732
+
733
+ # Convert the response to an ExecuteResponse and extract result data
734
+ (
735
+ execute_response ,
736
+ result_data ,
737
+ manifest ,
738
+ ) = self ._results_message_to_execute_response (response_data , command_id )
739
+
740
+ return SeaResultSet (
741
+ connection = cursor .connection ,
742
+ execute_response = execute_response ,
743
+ sea_client = self ,
744
+ buffer_size_bytes = cursor .buffer_size_bytes ,
745
+ arraysize = cursor .arraysize ,
746
+ result_data = result_data ,
747
+ manifest = manifest ,
748
+ )
749
+ except Exception as e :
750
+ # Map exceptions to match Thrift behavior
751
+ from databricks .sql .exc import RequestError , OperationalError
752
+
753
+ if isinstance (e , (RequestError , ServerOperationError )):
754
+ raise
755
+ else :
756
+ raise OperationalError (f"Error getting execution result: { str (e )} " )
644
757
645
758
# == Metadata Operations ==
646
759
0 commit comments