@@ -547,11 +547,15 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed {
547547}
548548
549549type driverStmt struct {
550- conn * Conn
551- query string
552- user string
553- statsCh chan QueryProgressInfo
554- doneCh chan struct {}
550+ conn * Conn
551+ query string
552+ user string
553+ nextURIs chan string
554+ httpResponses chan * http.Response
555+ queryResponses chan queryResponse
556+ statsCh chan QueryProgressInfo
557+ errors chan error
558+ doneCh chan struct {}
555559}
556560
557561var (
@@ -563,12 +567,26 @@ var (
563567
564568// Close closes statement just before releasing connection
565569func (st * driverStmt ) Close () error {
566- if st .doneCh ! = nil {
567- close ( st . doneCh )
570+ if st .doneCh = = nil {
571+ return nil
568572 }
573+ close (st .doneCh )
569574 if st .statsCh != nil {
570575 <- st .statsCh
576+ st .statsCh = nil
577+ }
578+ go func () {
579+ // drain errors chan to allow goroutines to write to it
580+ for range st .errors {
581+ }
582+ }()
583+ for range st .queryResponses {
584+ }
585+ for range st .httpResponses {
571586 }
587+ close (st .nextURIs )
588+ close (st .errors )
589+ st .doneCh = nil
572590 return nil
573591}
574592
@@ -596,7 +614,7 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue)
596614 }
597615 // consume all results, if there are any
598616 for err == nil {
599- err = rows .fetch (true )
617+ err = rows .fetch ()
600618 }
601619
602620 if err != nil && err != io .EOF {
@@ -707,7 +725,7 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
707725 statsCh : st .statsCh ,
708726 doneCh : st .doneCh ,
709727 }
710- if err = rows .fetch (false ); err != nil {
728+ if err = rows .fetch (); err != nil && err != io . EOF {
711729 return nil , err
712730 }
713731 return rows , nil
@@ -780,9 +798,89 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
780798 return nil , fmt .Errorf ("trino: %v" , err )
781799 }
782800
801+ st .doneCh = make (chan struct {})
802+ st .nextURIs = make (chan string )
803+ st .httpResponses = make (chan * http.Response )
804+ st .queryResponses = make (chan queryResponse )
805+ st .errors = make (chan error )
806+ go func () {
807+ defer close (st .httpResponses )
808+ for {
809+ select {
810+ case nextURI := <- st .nextURIs :
811+ if nextURI == "" {
812+ return
813+ }
814+ hs := make (http.Header )
815+ hs .Add (trinoUserHeader , st .user )
816+ req , err := st .conn .newRequest ("GET" , nextURI , nil , hs )
817+ if err != nil {
818+ st .errors <- err
819+ return
820+ }
821+ resp , err := st .conn .roundTrip (ctx , req )
822+ if err != nil {
823+ if ctx .Err () == context .Canceled {
824+ st .errors <- context .Canceled
825+ return
826+ }
827+ st .errors <- err
828+ return
829+ }
830+ select {
831+ case st .httpResponses <- resp :
832+ case <- st .doneCh :
833+ return
834+ }
835+ case <- st .doneCh :
836+ return
837+ }
838+ }
839+ }()
840+ go func () {
841+ defer close (st .queryResponses )
842+ for {
843+ select {
844+ case resp := <- st .httpResponses :
845+ if resp == nil {
846+ return
847+ }
848+ var qresp queryResponse
849+ d := json .NewDecoder (resp .Body )
850+ d .UseNumber ()
851+ err = d .Decode (& qresp )
852+ if err != nil {
853+ st .errors <- fmt .Errorf ("trino: %v" , err )
854+ return
855+ }
856+ err = resp .Body .Close ()
857+ if err != nil {
858+ st .errors <- err
859+ return
860+ }
861+ err = handleResponseError (resp .StatusCode , qresp .Error )
862+ if err != nil {
863+ st .errors <- err
864+ return
865+ }
866+ select {
867+ case st .nextURIs <- qresp .NextURI :
868+ case <- st .doneCh :
869+ return
870+ }
871+ select {
872+ case st .queryResponses <- qresp :
873+ case <- st .doneCh :
874+ return
875+ }
876+ case <- st .doneCh :
877+ return
878+ }
879+ }
880+ }()
881+ st .nextURIs <- sr .NextURI
783882 if st .conn .progressUpdater != nil {
784883 st .statsCh = make (chan QueryProgressInfo )
785- st .doneCh = make (chan struct {})
786884
787885 // progress updater go func
788886 go func () {
@@ -810,7 +908,6 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
810908 st .conn .progressUpdaterPeriod .LastCallbackTime = time .Now ()
811909 st .conn .progressUpdaterPeriod .LastQueryState = sr .Stats .State
812910 }
813-
814911 return & sr , handleResponseError (resp .StatusCode , sr .Error )
815912}
816913
@@ -873,7 +970,7 @@ func (qr *driverRows) Columns() []string {
873970 return []string {}
874971 }
875972 if qr .columns == nil {
876- if err := qr .fetch (false ); err != nil {
973+ if err := qr .fetch (); err != nil && err != io . EOF {
877974 qr .err = err
878975 return []string {}
879976 }
@@ -915,7 +1012,7 @@ func (qr *driverRows) Next(dest []driver.Value) error {
9151012 qr .err = io .EOF
9161013 return qr .err
9171014 }
918- if err := qr .fetch (true ); err != nil {
1015+ if err := qr .fetch (); err != nil {
9191016 qr .err = err
9201017 return err
9211018 }
@@ -925,6 +1022,9 @@ func (qr *driverRows) Next(dest []driver.Value) error {
9251022 return qr .err
9261023 }
9271024 for i , v := range qr .coltype {
1025+ if i > len (dest )- 1 {
1026+ break
1027+ }
9281028 vv , err := v .ConvertValue (qr.data [qr.rowindex ][i ])
9291029 if err != nil {
9301030 qr .err = err
@@ -945,7 +1045,7 @@ func (qr driverRows) LastInsertId() (int64, error) {
9451045
9461046// RowsAffected returns the number of rows affected by the query.
9471047func (qr driverRows ) RowsAffected () (int64 , error ) {
948- return qr .rowsAffected , qr . err
1048+ return qr .rowsAffected , nil
9491049}
9501050
9511051type queryResponse struct {
@@ -1014,71 +1114,34 @@ func handleResponseError(status int, respErr stmtError) error {
10141114 }
10151115}
10161116
1017- func (qr * driverRows ) fetch (allowEOF bool ) error {
1018- if qr .nextURI == "" {
1019- if allowEOF {
1020- return io .EOF
1021- }
1022- return nil
1023- }
1024-
1025- for qr .nextURI != "" {
1026- var qresp queryResponse
1027- err := qr .executeFetchRequest (& qresp )
1028- if err != nil {
1029- return err
1030- }
1031-
1032- qr .rowindex = 0
1033- qr .data = qresp .Data
1034- qr .nextURI = qresp .NextURI
1035- qr .rowsAffected = qresp .UpdateCount
1036- qr .scheduleProgressUpdate (qresp .ID , qresp .Stats )
1037-
1038- if len (qr .data ) == 0 {
1039- if qr .nextURI != "" {
1040- continue
1041- }
1042- if allowEOF {
1043- qr .err = io .EOF
1044- return qr .err
1117+ func (qr * driverRows ) fetch () error {
1118+ var qresp queryResponse
1119+ var err error
1120+ for {
1121+ select {
1122+ case qresp = <- qr .stmt .queryResponses :
1123+ if qresp .ID == "" {
1124+ return io .EOF
10451125 }
1046- }
1047- if qr .columns == nil && len (qresp .Columns ) > 0 {
10481126 err = qr .initColumns (& qresp )
10491127 if err != nil {
10501128 return err
10511129 }
1052- }
1053- return nil
1054- }
1055- return nil
1056- }
1057-
1058- func (qr * driverRows ) executeFetchRequest (qresp * queryResponse ) error {
1059- hs := make (http.Header )
1060- hs .Add (trinoUserHeader , qr .stmt .user )
1061- req , err := qr .stmt .conn .newRequest ("GET" , qr .nextURI , nil , hs )
1062- if err != nil {
1063- return err
1064- }
1065- resp , err := qr .stmt .conn .roundTrip (qr .ctx , req )
1066- if err != nil {
1067- if qr .ctx .Err () == context .Canceled {
1068- qr .Close ()
1130+ qr .rowindex = 0
1131+ qr .data = qresp .Data
1132+ qr .rowsAffected = qresp .UpdateCount
1133+ qr .scheduleProgressUpdate (qresp .ID , qresp .Stats )
1134+ if len (qr .data ) != 0 {
1135+ return nil
1136+ }
1137+ case err = <- qr .stmt .errors :
1138+ if err == context .Canceled {
1139+ qr .Close ()
1140+ }
1141+ qr .err = err
10691142 return err
10701143 }
1071- return err
1072- }
1073- defer resp .Body .Close ()
1074-
1075- d := json .NewDecoder (resp .Body )
1076- d .UseNumber ()
1077- err = d .Decode (& qresp )
1078- if err != nil {
1079- return fmt .Errorf ("trino: %v" , err )
10801144 }
1081- return handleResponseError (resp .StatusCode , qresp .Error )
10821145}
10831146
10841147func unmarshalArguments (signature * typeSignature ) error {
@@ -1110,6 +1173,9 @@ func unmarshalArguments(signature *typeSignature) error {
11101173}
11111174
11121175func (qr * driverRows ) initColumns (qresp * queryResponse ) error {
1176+ if qr .columns != nil || len (qresp .Columns ) == 0 {
1177+ return nil
1178+ }
11131179 var err error
11141180 for i := range qresp .Columns {
11151181 err = unmarshalArguments (& (qresp .Columns [i ].TypeSignature ))
@@ -1120,6 +1186,10 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error {
11201186 qr .columns = make ([]string , len (qresp .Columns ))
11211187 qr .coltype = make ([]* typeConverter , len (qresp .Columns ))
11221188 for i , col := range qresp .Columns {
1189+ err = unmarshalArguments (& (qresp .Columns [i ].TypeSignature ))
1190+ if err != nil {
1191+ return fmt .Errorf ("error decoding column type signature: %w" , err )
1192+ }
11231193 qr .columns [i ] = col .Name
11241194 qr .coltype [i ], err = newTypeConverter (col .Type , col .TypeSignature )
11251195 if err != nil {
0 commit comments