4343//
4444// The driver should be used via the database/sql package:
4545//
46- // import "database/sql"
47- // import _ "github.com/trinodb/trino-go-client/trino"
48- //
49- // dsn := "http://user@localhost:8080?catalog=default&schema=test"
50- // db, err := sql.Open("trino", dsn)
46+ // import "database/sql"
47+ // import _ "github.com/trinodb/trino-go-client/trino"
5148//
49+ // dsn := "http://user@localhost:8080?catalog=default&schema=test"
50+ // db, err := sql.Open("trino", dsn)
5251package trino
5352
5453import (
@@ -136,6 +135,8 @@ const (
136135 kerberosRealmConfig = "KerberosRealm"
137136 kerberosConfigPathConfig = "KerberosConfigPath"
138137 SSLCertPathConfig = "SSLCertPath"
138+
139+ CHAN_SIZE = 10
139140)
140141
141142var (
@@ -372,7 +373,6 @@ var customClientRegistry = struct {
372373// }
373374// trino.RegisterCustomClient("foobar", foobarClient)
374375// db, err := sql.Open("trino", "https://user@localhost:8080?custom_client=foobar")
375- //
376376func RegisterCustomClient (key string , client * http.Client ) error {
377377 if _ , err := strconv .ParseBool (key ); err == nil {
378378 return fmt .Errorf ("trino: custom client key %q is reserved" , key )
@@ -549,11 +549,15 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed {
549549}
550550
551551type driverStmt struct {
552- conn * Conn
553- query string
554- user string
555- statsCh chan QueryProgressInfo
556- doneCh chan struct {}
552+ conn * Conn
553+ query string
554+ user string
555+ nextURIs chan string
556+ httpResponses chan * http.Response
557+ queryResponses chan queryResponse
558+ statsCh chan QueryProgressInfo
559+ errors chan error
560+ doneCh chan struct {}
557561}
558562
559563var (
@@ -565,12 +569,26 @@ var (
565569
566570// Close closes statement just before releasing connection
567571func (st * driverStmt ) Close () error {
568- if st .doneCh ! = nil {
569- close ( st . doneCh )
572+ if st .doneCh = = nil {
573+ return nil
570574 }
575+ close (st .doneCh )
571576 if st .statsCh != nil {
572577 <- st .statsCh
578+ st .statsCh = nil
573579 }
580+ go func () {
581+ // drain errors chan to allow goroutines to write to it
582+ for range st .errors {
583+ }
584+ }()
585+ for range st .queryResponses {
586+ }
587+ for range st .httpResponses {
588+ }
589+ close (st .nextURIs )
590+ close (st .errors )
591+ st .doneCh = nil
574592 return nil
575593}
576594
@@ -598,7 +616,7 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue)
598616 }
599617 // consume all results, if there are any
600618 for err == nil {
601- err = rows .fetch (true )
619+ err = rows .fetch ()
602620 }
603621
604622 if err != nil && err != io .EOF {
@@ -709,7 +727,7 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
709727 statsCh : st .statsCh ,
710728 doneCh : st .doneCh ,
711729 }
712- if err = rows .fetch (false ); err != nil {
730+ if err = rows .fetch (); err != nil && err != io . EOF {
713731 return nil , err
714732 }
715733 return rows , nil
@@ -782,9 +800,90 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
782800 return nil , fmt .Errorf ("trino: %v" , err )
783801 }
784802
803+ st .doneCh = make (chan struct {})
804+ st .nextURIs = make (chan string , CHAN_SIZE )
805+ st .httpResponses = make (chan * http.Response , CHAN_SIZE )
806+ st .queryResponses = make (chan queryResponse , CHAN_SIZE )
807+ st .errors = make (chan error )
808+ st .nextURIs <- sr .NextURI
809+ go func () {
810+ defer close (st .httpResponses )
811+ for {
812+ select {
813+ case nextURI := <- st .nextURIs :
814+ if nextURI == "" {
815+ st .errors <- io .EOF
816+ return
817+ }
818+ hs := make (http.Header )
819+ hs .Add (trinoUserHeader , st .user )
820+ req , err := st .conn .newRequest ("GET" , nextURI , nil , hs )
821+ if err != nil {
822+ st .errors <- err
823+ return
824+ }
825+ resp , err := st .conn .roundTrip (ctx , req )
826+ if err != nil {
827+ if ctx .Err () == context .Canceled {
828+ st .errors <- context .Canceled
829+ return
830+ }
831+ st .errors <- err
832+ return
833+ }
834+ select {
835+ case st .httpResponses <- resp :
836+ case <- st .doneCh :
837+ return
838+ }
839+ case <- st .doneCh :
840+ return
841+ }
842+ }
843+ }()
844+ go func () {
845+ defer close (st .queryResponses )
846+ for {
847+ select {
848+ case resp := <- st .httpResponses :
849+ if resp == nil {
850+ return
851+ }
852+ var qresp queryResponse
853+ d := json .NewDecoder (resp .Body )
854+ d .UseNumber ()
855+ err = d .Decode (& qresp )
856+ if err != nil {
857+ st .errors <- fmt .Errorf ("trino: %v" , err )
858+ return
859+ }
860+ err = resp .Body .Close ()
861+ if err != nil {
862+ st .errors <- err
863+ return
864+ }
865+ err = handleResponseError (resp .StatusCode , qresp .Error )
866+ if err != nil {
867+ st .errors <- err
868+ return
869+ }
870+ select {
871+ case st .nextURIs <- qresp .NextURI :
872+ case <- st .doneCh :
873+ return
874+ }
875+ select {
876+ case st .queryResponses <- qresp :
877+ case <- st .doneCh :
878+ return
879+ }
880+ case <- st .doneCh :
881+ return
882+ }
883+ }
884+ }()
785885 if st .conn .progressUpdater != nil {
786886 st .statsCh = make (chan QueryProgressInfo )
787- st .doneCh = make (chan struct {})
788887
789888 // progress updater go func
790889 go func () {
@@ -812,7 +911,6 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
812911 st .conn .progressUpdaterPeriod .LastCallbackTime = time .Now ()
813912 st .conn .progressUpdaterPeriod .LastQueryState = sr .Stats .State
814913 }
815-
816914 return & sr , handleResponseError (resp .StatusCode , sr .Error )
817915}
818916
@@ -875,7 +973,7 @@ func (qr *driverRows) Columns() []string {
875973 return []string {}
876974 }
877975 if qr .columns == nil {
878- if err := qr .fetch (false ); err != nil {
976+ if err := qr .fetch (); err != nil && err != io . EOF {
879977 qr .err = err
880978 return []string {}
881979 }
@@ -917,7 +1015,7 @@ func (qr *driverRows) Next(dest []driver.Value) error {
9171015 qr .err = io .EOF
9181016 return qr .err
9191017 }
920- if err := qr .fetch (true ); err != nil {
1018+ if err := qr .fetch (); err != nil {
9211019 qr .err = err
9221020 return err
9231021 }
@@ -947,7 +1045,7 @@ func (qr driverRows) LastInsertId() (int64, error) {
9471045
9481046// RowsAffected returns the number of rows affected by the query.
9491047func (qr driverRows ) RowsAffected () (int64 , error ) {
950- return qr .rowsAffected , qr . err
1048+ return qr .rowsAffected , nil
9511049}
9521050
9531051type queryResponse struct {
@@ -1016,71 +1114,31 @@ func handleResponseError(status int, respErr stmtError) error {
10161114 }
10171115}
10181116
1019- func (qr * driverRows ) fetch (allowEOF bool ) error {
1020- if qr .nextURI == "" {
1021- if allowEOF {
1022- return io .EOF
1023- }
1024- return nil
1025- }
1026-
1027- for qr .nextURI != "" {
1028- var qresp queryResponse
1029- err := qr .executeFetchRequest (& qresp )
1030- if err != nil {
1031- return err
1032- }
1033-
1034- qr .rowindex = 0
1035- qr .data = qresp .Data
1036- qr .nextURI = qresp .NextURI
1037- qr .rowsAffected = qresp .UpdateCount
1038- qr .scheduleProgressUpdate (qresp .ID , qresp .Stats )
1039-
1040- if len (qr .data ) == 0 {
1041- if qr .nextURI != "" {
1042- continue
1043- }
1044- if allowEOF {
1045- qr .err = io .EOF
1046- return qr .err
1047- }
1048- }
1049- if qr .columns == nil && len (qresp .Columns ) > 0 {
1117+ func (qr * driverRows ) fetch () error {
1118+ var qresp queryResponse
1119+ var err error
1120+ for {
1121+ select {
1122+ case qresp = <- qr .stmt .queryResponses :
10501123 err = qr .initColumns (& qresp )
10511124 if err != nil {
10521125 return err
10531126 }
1054- }
1055- return nil
1056- }
1057- return nil
1058- }
1059-
1060- func (qr * driverRows ) executeFetchRequest (qresp * queryResponse ) error {
1061- hs := make (http.Header )
1062- hs .Add (trinoUserHeader , qr .stmt .user )
1063- req , err := qr .stmt .conn .newRequest ("GET" , qr .nextURI , nil , hs )
1064- if err != nil {
1065- return err
1066- }
1067- resp , err := qr .stmt .conn .roundTrip (qr .ctx , req )
1068- if err != nil {
1069- if qr .ctx .Err () == context .Canceled {
1070- qr .Close ()
1127+ qr .rowindex = 0
1128+ qr .data = qresp .Data
1129+ qr .rowsAffected = qresp .UpdateCount
1130+ qr .scheduleProgressUpdate (qresp .ID , qresp .Stats )
1131+ if len (qresp .Data ) != 0 {
1132+ return nil
1133+ }
1134+ case err = <- qr .stmt .errors :
1135+ if err == context .Canceled {
1136+ qr .Close ()
1137+ }
1138+ qr .err = err
10711139 return err
10721140 }
1073- return err
10741141 }
1075- defer resp .Body .Close ()
1076-
1077- d := json .NewDecoder (resp .Body )
1078- d .UseNumber ()
1079- err = d .Decode (& qresp )
1080- if err != nil {
1081- return fmt .Errorf ("trino: %v" , err )
1082- }
1083- return handleResponseError (resp .StatusCode , qresp .Error )
10841142}
10851143
10861144func unmarshalArguments (signature * typeSignature ) error {
@@ -1112,6 +1170,9 @@ func unmarshalArguments(signature *typeSignature) error {
11121170}
11131171
11141172func (qr * driverRows ) initColumns (qresp * queryResponse ) error {
1173+ if qr .columns != nil || len (qresp .Columns ) == 0 {
1174+ return nil
1175+ }
11151176 var err error
11161177 for i := range qresp .Columns {
11171178 err = unmarshalArguments (& (qresp .Columns [i ].TypeSignature ))
@@ -1122,6 +1183,10 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error {
11221183 qr .columns = make ([]string , len (qresp .Columns ))
11231184 qr .coltype = make ([]* typeConverter , len (qresp .Columns ))
11241185 for i , col := range qresp .Columns {
1186+ err = unmarshalArguments (& (qresp .Columns [i ].TypeSignature ))
1187+ if err != nil {
1188+ return fmt .Errorf ("error decoding column type signature: %w" , err )
1189+ }
11251190 qr .columns [i ] = col .Name
11261191 qr .coltype [i ], err = newTypeConverter (col .Type , col .TypeSignature )
11271192 if err != nil {
0 commit comments