3939import copy
4040import functools
4141import os
42+ import queue
4243import random
4344import re
4445import threading
4546import urllib .parse
47+ from concurrent .futures import ThreadPoolExecutor
4648import warnings
4749from datetime import date , datetime , time , timedelta , timezone , tzinfo
4850from decimal import Decimal
4951from time import sleep
50- from typing import Any , Dict , Generic , List , Optional , Tuple , TypeVar , Union
52+ from typing import Any , Callable , Dict , Generic , List , Optional , Tuple , TypeVar , Union
5153
5254import pytz
5355import requests
@@ -684,6 +686,27 @@ def _verify_extra_credential(self, header):
684686 raise ValueError (f"only ASCII characters are allowed in extra credential '{ key } '" )
685687
686688
class _DownloadFailure:
    """Queue item that carries an exception raised by a download task."""

    def __init__(self, error: BaseException) -> None:
        self.error = error


class _ResultQueue(queue.Queue):
    """``queue.Queue`` whose ``get`` re-raises failures recorded by the worker."""

    def get(self, block=True, timeout=None):
        item = super().get(block, timeout)
        if isinstance(item, _DownloadFailure):
            # Surface the worker's exception in the thread that consumes results.
            raise item.error
        return item


class ResultDownloader():
    """
    Run fetch functions on a single background worker and queue their results.

    Use it as a context manager: the worker thread exists only between
    ``__enter__`` and ``__exit__``. Each submitted function produces exactly
    one item on ``queue``; since there is a single worker, items arrive in
    submission order. If the function raises, the exception is re-raised by
    the ``queue.get()`` call that would have returned its result, instead of
    leaving the consumer blocked forever.
    """

    def __init__(self):
        self.queue: queue.Queue = _ResultQueue()
        self.executor: Optional[ThreadPoolExecutor] = None

    def submit(self, fetch_func: Callable[[], List[Any]]):
        """
        Schedule ``fetch_func`` on the worker thread.

        :raises RuntimeError: if called outside of the ``with`` block.
        """
        # Explicit check instead of ``assert``: assertions are stripped under ``python -O``.
        if self.executor is None:
            raise RuntimeError("ResultDownloader must be entered as a context manager before submit()")
        self.executor.submit(self.download_task, fetch_func)

    def download_task(self, fetch_func):
        """Call ``fetch_func`` and put its result, or its failure, on the queue."""
        try:
            item = fetch_func()
        except Exception as error:
            # The future returned by the executor is discarded, so an exception
            # escaping here would be lost and the consumer would wait forever.
            item = _DownloadFailure(error)
        self.queue.put(item)

    def __enter__(self):
        # One worker keeps the fetches strictly sequential.
        self.executor = ThreadPoolExecutor(max_workers=1)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if self.executor is not None:
            # Waits for an in-flight task to complete before returning.
            self.executor.shutdown()
            self.executor = None
709+
687710class TrinoResult (object ):
688711 """
689712 Represent the result of a Trino query as an iterator on rows.
@@ -711,16 +734,21 @@ def rownumber(self) -> int:
711734 return self ._rownumber
712735
def __iter__(self):
    """
    Yield the rows of the query, prefetching the next page in the background.

    While the rows buffered in ``self._rows`` are handed to the caller, a
    single worker thread already runs ``self._query.fetch`` for the following
    page. ``self._rownumber`` is incremented for every row yielded.
    """
    with ResultDownloader() as result_downloader:
        # A query only transitions to a FINISHED state when the results are fully consumed:
        # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.

        # Track the in-flight fetch locally. The worker thread is what marks the
        # query as finished, so re-reading ``self._query.finished`` while a fetch
        # is running races with it: when the last fetch completes first, its
        # rows would be left in the queue and never yielded.
        fetch_pending = not self._query.finished
        if fetch_pending:
            # Nothing to fetch when the query already finished before iteration started.
            result_downloader.submit(self._query.fetch)

        while fetch_pending or self._rows is not None:
            next_rows = None
            if fetch_pending:
                # fetch() returns the raw rows; map them like every other caller of fetch() does.
                next_rows = self._query.map_rows(result_downloader.queue.get())
                # The submitted fetch has completed, so no worker is running and
                # the finished flag is stable until the next submit.
                fetch_pending = not self._query.finished
                if fetch_pending:
                    result_downloader.submit(self._query.fetch)

            for row in self._rows:
                self._rownumber += 1
                logger.debug("row %s", row)
                yield row

            self._rows = next_rows
725753
726754class TrinoQuery (object ):
@@ -753,7 +781,7 @@ def columns(self):
753781 while not self ._columns and not self .finished and not self .cancelled :
754782 # Columns are not returned immediately after query is submitted.
755783 # Continue fetching data until columns information is available and push fetched rows into buffer.
756- self ._result .rows += self .fetch ()
784+ self ._result .rows += self .map_rows ( self . fetch () )
757785 return self ._columns
758786
759787 @property
@@ -802,7 +830,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
802830
803831 # Execute should block until at least one row is received or query is finished or cancelled
804832 while not self .finished and not self .cancelled and len (self ._result .rows ) == 0 :
805- self ._result .rows += self .fetch ()
833+ self ._result .rows += self .map_rows ( self . fetch () )
806834 return self ._result
807835
808836 def _update_state (self , status ):
@@ -822,11 +850,12 @@ def fetch(self) -> List[List[Any]]:
822850 logger .debug (status )
823851 if status .next_uri is None :
824852 self ._finished = True
853+ return status .rows
825854
def map_rows(self, rows: List[List[Any]]) -> List[List[Any]]:
    """
    Convert raw rows using the query's row mapper.

    :param rows: rows as returned by ``fetch``.
    :return: the result of ``self._row_mapper.map(rows)``, or an empty list
        when no row mapper is set.
    """
    mapper = self._row_mapper
    return mapper.map(rows) if mapper else []
830859
831860 def cancel (self ) -> None :
832861 """Cancel the current query"""
0 commit comments