Skip to content

Commit 62c12bc

Browse files
nineinchnickJan Waś
authored andcommitted
Fetch and decode query results concurrently
1 parent 79e6ed1 commit 62c12bc

File tree

1 file changed

+143
-78
lines changed

1 file changed

+143
-78
lines changed

trino/trino.go

Lines changed: 143 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@
4343
//
4444
// The driver should be used via the database/sql package:
4545
//
46-
// import "database/sql"
47-
// import _ "github.com/trinodb/trino-go-client/trino"
48-
//
49-
// dsn := "http://user@localhost:8080?catalog=default&schema=test"
50-
// db, err := sql.Open("trino", dsn)
46+
// import "database/sql"
47+
// import _ "github.com/trinodb/trino-go-client/trino"
5148
//
49+
// dsn := "http://user@localhost:8080?catalog=default&schema=test"
50+
// db, err := sql.Open("trino", dsn)
5251
package trino
5352

5453
import (
@@ -136,6 +135,8 @@ const (
136135
kerberosRealmConfig = "KerberosRealm"
137136
kerberosConfigPathConfig = "KerberosConfigPath"
138137
SSLCertPathConfig = "SSLCertPath"
138+
139+
CHAN_SIZE = 10
139140
)
140141

141142
var (
@@ -372,7 +373,6 @@ var customClientRegistry = struct {
372373
// }
373374
// trino.RegisterCustomClient("foobar", foobarClient)
374375
// db, err := sql.Open("trino", "https://user@localhost:8080?custom_client=foobar")
375-
//
376376
func RegisterCustomClient(key string, client *http.Client) error {
377377
if _, err := strconv.ParseBool(key); err == nil {
378378
return fmt.Errorf("trino: custom client key %q is reserved", key)
@@ -549,11 +549,15 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed {
549549
}
550550

551551
type driverStmt struct {
552-
conn *Conn
553-
query string
554-
user string
555-
statsCh chan QueryProgressInfo
556-
doneCh chan struct{}
552+
conn *Conn
553+
query string
554+
user string
555+
nextURIs chan string
556+
httpResponses chan *http.Response
557+
queryResponses chan queryResponse
558+
statsCh chan QueryProgressInfo
559+
errors chan error
560+
doneCh chan struct{}
557561
}
558562

559563
var (
@@ -565,12 +569,26 @@ var (
565569

566570
// Close closes statement just before releasing connection
567571
func (st *driverStmt) Close() error {
568-
if st.doneCh != nil {
569-
close(st.doneCh)
572+
if st.doneCh == nil {
573+
return nil
570574
}
575+
close(st.doneCh)
571576
if st.statsCh != nil {
572577
<-st.statsCh
578+
st.statsCh = nil
573579
}
580+
go func() {
581+
// drain errors chan to allow goroutines to write to it
582+
for range st.errors {
583+
}
584+
}()
585+
for range st.queryResponses {
586+
}
587+
for range st.httpResponses {
588+
}
589+
close(st.nextURIs)
590+
close(st.errors)
591+
st.doneCh = nil
574592
return nil
575593
}
576594

@@ -598,7 +616,7 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue)
598616
}
599617
// consume all results, if there are any
600618
for err == nil {
601-
err = rows.fetch(true)
619+
err = rows.fetch()
602620
}
603621

604622
if err != nil && err != io.EOF {
@@ -709,7 +727,7 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
709727
statsCh: st.statsCh,
710728
doneCh: st.doneCh,
711729
}
712-
if err = rows.fetch(false); err != nil {
730+
if err = rows.fetch(); err != nil && err != io.EOF {
713731
return nil, err
714732
}
715733
return rows, nil
@@ -782,9 +800,90 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
782800
return nil, fmt.Errorf("trino: %v", err)
783801
}
784802

803+
st.doneCh = make(chan struct{})
804+
st.nextURIs = make(chan string, CHAN_SIZE)
805+
st.httpResponses = make(chan *http.Response, CHAN_SIZE)
806+
st.queryResponses = make(chan queryResponse, CHAN_SIZE)
807+
st.errors = make(chan error)
808+
st.nextURIs <- sr.NextURI
809+
go func() {
810+
defer close(st.httpResponses)
811+
for {
812+
select {
813+
case nextURI := <-st.nextURIs:
814+
if nextURI == "" {
815+
st.errors <- io.EOF
816+
return
817+
}
818+
hs := make(http.Header)
819+
hs.Add(trinoUserHeader, st.user)
820+
req, err := st.conn.newRequest("GET", nextURI, nil, hs)
821+
if err != nil {
822+
st.errors <- err
823+
return
824+
}
825+
resp, err := st.conn.roundTrip(ctx, req)
826+
if err != nil {
827+
if ctx.Err() == context.Canceled {
828+
st.errors <- context.Canceled
829+
return
830+
}
831+
st.errors <- err
832+
return
833+
}
834+
select {
835+
case st.httpResponses <- resp:
836+
case <-st.doneCh:
837+
return
838+
}
839+
case <-st.doneCh:
840+
return
841+
}
842+
}
843+
}()
844+
go func() {
845+
defer close(st.queryResponses)
846+
for {
847+
select {
848+
case resp := <-st.httpResponses:
849+
if resp == nil {
850+
return
851+
}
852+
var qresp queryResponse
853+
d := json.NewDecoder(resp.Body)
854+
d.UseNumber()
855+
err = d.Decode(&qresp)
856+
if err != nil {
857+
st.errors <- fmt.Errorf("trino: %v", err)
858+
return
859+
}
860+
err = resp.Body.Close()
861+
if err != nil {
862+
st.errors <- err
863+
return
864+
}
865+
err = handleResponseError(resp.StatusCode, qresp.Error)
866+
if err != nil {
867+
st.errors <- err
868+
return
869+
}
870+
select {
871+
case st.nextURIs <- qresp.NextURI:
872+
case <-st.doneCh:
873+
return
874+
}
875+
select {
876+
case st.queryResponses <- qresp:
877+
case <-st.doneCh:
878+
return
879+
}
880+
case <-st.doneCh:
881+
return
882+
}
883+
}
884+
}()
785885
if st.conn.progressUpdater != nil {
786886
st.statsCh = make(chan QueryProgressInfo)
787-
st.doneCh = make(chan struct{})
788887

789888
// progress updater go func
790889
go func() {
@@ -812,7 +911,6 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
812911
st.conn.progressUpdaterPeriod.LastCallbackTime = time.Now()
813912
st.conn.progressUpdaterPeriod.LastQueryState = sr.Stats.State
814913
}
815-
816914
return &sr, handleResponseError(resp.StatusCode, sr.Error)
817915
}
818916

@@ -875,7 +973,7 @@ func (qr *driverRows) Columns() []string {
875973
return []string{}
876974
}
877975
if qr.columns == nil {
878-
if err := qr.fetch(false); err != nil {
976+
if err := qr.fetch(); err != nil && err != io.EOF {
879977
qr.err = err
880978
return []string{}
881979
}
@@ -917,7 +1015,7 @@ func (qr *driverRows) Next(dest []driver.Value) error {
9171015
qr.err = io.EOF
9181016
return qr.err
9191017
}
920-
if err := qr.fetch(true); err != nil {
1018+
if err := qr.fetch(); err != nil {
9211019
qr.err = err
9221020
return err
9231021
}
@@ -947,7 +1045,7 @@ func (qr driverRows) LastInsertId() (int64, error) {
9471045

9481046
// RowsAffected returns the number of rows affected by the query.
9491047
func (qr driverRows) RowsAffected() (int64, error) {
950-
return qr.rowsAffected, qr.err
1048+
return qr.rowsAffected, nil
9511049
}
9521050

9531051
type queryResponse struct {
@@ -1016,71 +1114,31 @@ func handleResponseError(status int, respErr stmtError) error {
10161114
}
10171115
}
10181116

1019-
func (qr *driverRows) fetch(allowEOF bool) error {
1020-
if qr.nextURI == "" {
1021-
if allowEOF {
1022-
return io.EOF
1023-
}
1024-
return nil
1025-
}
1026-
1027-
for qr.nextURI != "" {
1028-
var qresp queryResponse
1029-
err := qr.executeFetchRequest(&qresp)
1030-
if err != nil {
1031-
return err
1032-
}
1033-
1034-
qr.rowindex = 0
1035-
qr.data = qresp.Data
1036-
qr.nextURI = qresp.NextURI
1037-
qr.rowsAffected = qresp.UpdateCount
1038-
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
1039-
1040-
if len(qr.data) == 0 {
1041-
if qr.nextURI != "" {
1042-
continue
1043-
}
1044-
if allowEOF {
1045-
qr.err = io.EOF
1046-
return qr.err
1047-
}
1048-
}
1049-
if qr.columns == nil && len(qresp.Columns) > 0 {
1117+
func (qr *driverRows) fetch() error {
1118+
var qresp queryResponse
1119+
var err error
1120+
for {
1121+
select {
1122+
case qresp = <-qr.stmt.queryResponses:
10501123
err = qr.initColumns(&qresp)
10511124
if err != nil {
10521125
return err
10531126
}
1054-
}
1055-
return nil
1056-
}
1057-
return nil
1058-
}
1059-
1060-
func (qr *driverRows) executeFetchRequest(qresp *queryResponse) error {
1061-
hs := make(http.Header)
1062-
hs.Add(trinoUserHeader, qr.stmt.user)
1063-
req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs)
1064-
if err != nil {
1065-
return err
1066-
}
1067-
resp, err := qr.stmt.conn.roundTrip(qr.ctx, req)
1068-
if err != nil {
1069-
if qr.ctx.Err() == context.Canceled {
1070-
qr.Close()
1127+
qr.rowindex = 0
1128+
qr.data = qresp.Data
1129+
qr.rowsAffected = qresp.UpdateCount
1130+
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
1131+
if len(qresp.Data) != 0 {
1132+
return nil
1133+
}
1134+
case err = <-qr.stmt.errors:
1135+
if err == context.Canceled {
1136+
qr.Close()
1137+
}
1138+
qr.err = err
10711139
return err
10721140
}
1073-
return err
10741141
}
1075-
defer resp.Body.Close()
1076-
1077-
d := json.NewDecoder(resp.Body)
1078-
d.UseNumber()
1079-
err = d.Decode(&qresp)
1080-
if err != nil {
1081-
return fmt.Errorf("trino: %v", err)
1082-
}
1083-
return handleResponseError(resp.StatusCode, qresp.Error)
10841142
}
10851143

10861144
func unmarshalArguments(signature *typeSignature) error {
@@ -1112,6 +1170,9 @@ func unmarshalArguments(signature *typeSignature) error {
11121170
}
11131171

11141172
func (qr *driverRows) initColumns(qresp *queryResponse) error {
1173+
if qr.columns != nil || len(qresp.Columns) == 0 {
1174+
return nil
1175+
}
11151176
var err error
11161177
for i := range qresp.Columns {
11171178
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
@@ -1122,6 +1183,10 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error {
11221183
qr.columns = make([]string, len(qresp.Columns))
11231184
qr.coltype = make([]*typeConverter, len(qresp.Columns))
11241185
for i, col := range qresp.Columns {
1186+
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
1187+
if err != nil {
1188+
return fmt.Errorf("error decoding column type signature: %w", err)
1189+
}
11251190
qr.columns[i] = col.Name
11261191
qr.coltype[i], err = newTypeConverter(col.Type, col.TypeSignature)
11271192
if err != nil {

0 commit comments

Comments
 (0)