11import time
2- from typing import Dict , Tuple , List , Optional , Any , Union , Sequence
2+ from typing import Dict , Tuple , List , Optional , Any , Union , Sequence , BinaryIO
33import pandas
44
55try :
@@ -662,7 +662,9 @@ def _check_not_closed(self):
662662 )
663663
664664 def _handle_staging_operation (
665- self , staging_allowed_local_path : Union [None , str , List [str ]]
665+ self ,
666+ staging_allowed_local_path : Union [None , str , List [str ]],
667+ input_stream : Optional [BinaryIO ] = None ,
666668 ):
667669 """Fetch the HTTP request instruction from a staging ingestion command
668670 and call the designated handler.
@@ -671,6 +673,28 @@ def _handle_staging_operation(
671673 is not descended from staging_allowed_local_path.
672674 """
673675
676+ assert self .active_result_set is not None
677+ row = self .active_result_set .fetchone ()
678+ assert row is not None
679+
680+ # May be real headers, or could be json string
681+ headers = (
682+ json .loads (row .headers ) if isinstance (row .headers , str ) else row .headers
683+ )
684+ headers = dict (headers ) if headers else {}
685+
686+ # Handle __input_stream__ token for PUT operations
687+ if (
688+ row .operation == "PUT"
689+ and getattr (row , "localFile" , None ) == "__input_stream__"
690+ ):
691+ return self ._handle_staging_put_stream (
692+ presigned_url = row .presignedUrl ,
693+ stream = input_stream ,
694+ headers = headers ,
695+ )
696+
697+ # For non-streaming operations, validate staging_allowed_local_path
674698 if isinstance (staging_allowed_local_path , type (str ())):
675699 _staging_allowed_local_paths = [staging_allowed_local_path ]
676700 elif isinstance (staging_allowed_local_path , type (list ())):
@@ -685,10 +709,6 @@ def _handle_staging_operation(
685709 os .path .abspath (i ) for i in _staging_allowed_local_paths
686710 ]
687711
688- assert self .active_result_set is not None
689- row = self .active_result_set .fetchone ()
690- assert row is not None
691-
692712 # Must set to None in cases where server response does not include localFile
693713 abs_localFile = None
694714
@@ -711,19 +731,16 @@ def _handle_staging_operation(
711731 session_id_hex = self .connection .get_session_id_hex (),
712732 )
713733
714- # May be real headers, or could be json string
715- headers = (
716- json .loads (row .headers ) if isinstance (row .headers , str ) else row .headers
717- )
718-
719734 handler_args = {
720735 "presigned_url" : row .presignedUrl ,
721736 "local_file" : abs_localFile ,
722- "headers" : dict ( headers ) or {} ,
737+ "headers" : headers ,
723738 }
724739
725740 logger .debug (
726- f"Attempting staging operation indicated by server: { row .operation } - { getattr (row , 'localFile' , '' )} "
741+ "Attempting staging operation indicated by server: %s - %s" ,
742+ row .operation ,
743+ getattr (row , "localFile" , "" ),
727744 )
728745
729746 # TODO: Create a retry loop here to re-attempt if the request times out or fails
@@ -762,6 +779,10 @@ def _handle_staging_put(
762779 HttpMethod .PUT , presigned_url , body = fh .read (), headers = headers
763780 )
764781
782+ self ._handle_staging_http_response (r )
783+
784+ def _handle_staging_http_response (self , r ):
785+
765786 # fmt: off
766787 # HTTP status codes
767788 OK = 200
@@ -784,6 +805,37 @@ def _handle_staging_put(
784805 + "but not yet applied on the server. It's possible this command may fail later."
785806 )
786807
808+ @log_latency (StatementType .SQL )
809+ def _handle_staging_put_stream (
810+ self ,
811+ presigned_url : str ,
812+ stream : BinaryIO ,
813+ headers : dict = {},
814+ ) -> None :
815+ """Handle PUT operation with streaming data.
816+
817+ Args:
818+ presigned_url: The presigned URL for upload
819+ stream: Binary stream to upload
820+ headers: HTTP headers
821+
822+ Raises:
823+ ProgrammingError: If no input stream is provided
824+ OperationalError: If the upload fails
825+ """
826+
827+ if not stream :
828+ raise ProgrammingError (
829+ "No input stream provided for streaming operation" ,
830+ session_id_hex = self .connection .get_session_id_hex (),
831+ )
832+
833+ r = self .connection .http_client .request (
834+ HttpMethod .PUT , presigned_url , body = stream .read (), headers = headers
835+ )
836+
837+ self ._handle_staging_http_response (r )
838+
787839 @log_latency (StatementType .SQL )
788840 def _handle_staging_get (
789841 self , local_file : str , presigned_url : str , headers : Optional [dict ] = None
@@ -840,6 +892,7 @@ def execute(
840892 operation : str ,
841893 parameters : Optional [TParameterCollection ] = None ,
842894 enforce_embedded_schema_correctness = False ,
895+ input_stream : Optional [BinaryIO ] = None ,
843896 ) -> "Cursor" :
844897 """
845898 Execute a query and wait for execution to complete.
@@ -914,7 +967,8 @@ def execute(
914967
915968 if self .active_result_set and self .active_result_set .is_staging_operation :
916969 self ._handle_staging_operation (
917- staging_allowed_local_path = self .connection .staging_allowed_local_path
970+ staging_allowed_local_path = self .connection .staging_allowed_local_path ,
971+ input_stream = input_stream ,
918972 )
919973
920974 return self
0 commit comments