Skip to content

Commit 019254e

Browse files
nineinchnickJan Waś
authored andcommitted
Fetch and decode query results concurrently
1 parent d60ab69 commit 019254e

File tree

2 files changed

+142
-71
lines changed

2 files changed

+142
-71
lines changed

trino/trino.go

Lines changed: 141 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -547,11 +547,15 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed {
547547
}
548548

549549
type driverStmt struct {
550-
conn *Conn
551-
query string
552-
user string
553-
statsCh chan QueryProgressInfo
554-
doneCh chan struct{}
550+
conn *Conn
551+
query string
552+
user string
553+
nextURIs chan string
554+
httpResponses chan *http.Response
555+
queryResponses chan queryResponse
556+
statsCh chan QueryProgressInfo
557+
errors chan error
558+
doneCh chan struct{}
555559
}
556560

557561
var (
@@ -563,12 +567,26 @@ var (
563567

564568
// Close closes statement just before releasing connection
565569
func (st *driverStmt) Close() error {
566-
if st.doneCh != nil {
567-
close(st.doneCh)
570+
if st.doneCh == nil {
571+
return nil
568572
}
573+
close(st.doneCh)
569574
if st.statsCh != nil {
570575
<-st.statsCh
576+
st.statsCh = nil
577+
}
578+
go func() {
579+
// drain errors chan to allow goroutines to write to it
580+
for range st.errors {
581+
}
582+
}()
583+
for range st.queryResponses {
584+
}
585+
for range st.httpResponses {
571586
}
587+
close(st.nextURIs)
588+
close(st.errors)
589+
st.doneCh = nil
572590
return nil
573591
}
574592

@@ -596,7 +614,7 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue)
596614
}
597615
// consume all results, if there are any
598616
for err == nil {
599-
err = rows.fetch(true)
617+
err = rows.fetch()
600618
}
601619

602620
if err != nil && err != io.EOF {
@@ -707,7 +725,7 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
707725
statsCh: st.statsCh,
708726
doneCh: st.doneCh,
709727
}
710-
if err = rows.fetch(false); err != nil {
728+
if err = rows.fetch(); err != nil && err != io.EOF {
711729
return nil, err
712730
}
713731
return rows, nil
@@ -780,9 +798,89 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
780798
return nil, fmt.Errorf("trino: %v", err)
781799
}
782800

801+
st.doneCh = make(chan struct{})
802+
st.nextURIs = make(chan string)
803+
st.httpResponses = make(chan *http.Response)
804+
st.queryResponses = make(chan queryResponse)
805+
st.errors = make(chan error)
806+
go func() {
807+
defer close(st.httpResponses)
808+
for {
809+
select {
810+
case nextURI := <-st.nextURIs:
811+
if nextURI == "" {
812+
return
813+
}
814+
hs := make(http.Header)
815+
hs.Add(trinoUserHeader, st.user)
816+
req, err := st.conn.newRequest("GET", nextURI, nil, hs)
817+
if err != nil {
818+
st.errors <- err
819+
return
820+
}
821+
resp, err := st.conn.roundTrip(ctx, req)
822+
if err != nil {
823+
if ctx.Err() == context.Canceled {
824+
st.errors <- context.Canceled
825+
return
826+
}
827+
st.errors <- err
828+
return
829+
}
830+
select {
831+
case st.httpResponses <- resp:
832+
case <-st.doneCh:
833+
return
834+
}
835+
case <-st.doneCh:
836+
return
837+
}
838+
}
839+
}()
840+
go func() {
841+
defer close(st.queryResponses)
842+
for {
843+
select {
844+
case resp := <-st.httpResponses:
845+
if resp == nil {
846+
return
847+
}
848+
var qresp queryResponse
849+
d := json.NewDecoder(resp.Body)
850+
d.UseNumber()
851+
err = d.Decode(&qresp)
852+
if err != nil {
853+
st.errors <- fmt.Errorf("trino: %v", err)
854+
return
855+
}
856+
err = resp.Body.Close()
857+
if err != nil {
858+
st.errors <- err
859+
return
860+
}
861+
err = handleResponseError(resp.StatusCode, qresp.Error)
862+
if err != nil {
863+
st.errors <- err
864+
return
865+
}
866+
select {
867+
case st.nextURIs <- qresp.NextURI:
868+
case <-st.doneCh:
869+
return
870+
}
871+
select {
872+
case st.queryResponses <- qresp:
873+
case <-st.doneCh:
874+
return
875+
}
876+
case <-st.doneCh:
877+
return
878+
}
879+
}
880+
}()
881+
st.nextURIs <- sr.NextURI
783882
if st.conn.progressUpdater != nil {
784883
st.statsCh = make(chan QueryProgressInfo)
785-
st.doneCh = make(chan struct{})
786884

787885
// progress updater go func
788886
go func() {
@@ -810,7 +908,6 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
810908
st.conn.progressUpdaterPeriod.LastCallbackTime = time.Now()
811909
st.conn.progressUpdaterPeriod.LastQueryState = sr.Stats.State
812910
}
813-
814911
return &sr, handleResponseError(resp.StatusCode, sr.Error)
815912
}
816913

@@ -873,7 +970,7 @@ func (qr *driverRows) Columns() []string {
873970
return []string{}
874971
}
875972
if qr.columns == nil {
876-
if err := qr.fetch(false); err != nil {
973+
if err := qr.fetch(); err != nil && err != io.EOF {
877974
qr.err = err
878975
return []string{}
879976
}
@@ -915,7 +1012,7 @@ func (qr *driverRows) Next(dest []driver.Value) error {
9151012
qr.err = io.EOF
9161013
return qr.err
9171014
}
918-
if err := qr.fetch(true); err != nil {
1015+
if err := qr.fetch(); err != nil {
9191016
qr.err = err
9201017
return err
9211018
}
@@ -925,6 +1022,9 @@ func (qr *driverRows) Next(dest []driver.Value) error {
9251022
return qr.err
9261023
}
9271024
for i, v := range qr.coltype {
1025+
if i > len(dest)-1 {
1026+
break
1027+
}
9281028
vv, err := v.ConvertValue(qr.data[qr.rowindex][i])
9291029
if err != nil {
9301030
qr.err = err
@@ -945,7 +1045,7 @@ func (qr driverRows) LastInsertId() (int64, error) {
9451045

9461046
// RowsAffected returns the number of rows affected by the query.
9471047
func (qr driverRows) RowsAffected() (int64, error) {
948-
return qr.rowsAffected, qr.err
1048+
return qr.rowsAffected, nil
9491049
}
9501050

9511051
type queryResponse struct {
@@ -1014,71 +1114,34 @@ func handleResponseError(status int, respErr stmtError) error {
10141114
}
10151115
}
10161116

1017-
func (qr *driverRows) fetch(allowEOF bool) error {
1018-
if qr.nextURI == "" {
1019-
if allowEOF {
1020-
return io.EOF
1021-
}
1022-
return nil
1023-
}
1024-
1025-
for qr.nextURI != "" {
1026-
var qresp queryResponse
1027-
err := qr.executeFetchRequest(&qresp)
1028-
if err != nil {
1029-
return err
1030-
}
1031-
1032-
qr.rowindex = 0
1033-
qr.data = qresp.Data
1034-
qr.nextURI = qresp.NextURI
1035-
qr.rowsAffected = qresp.UpdateCount
1036-
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
1037-
1038-
if len(qr.data) == 0 {
1039-
if qr.nextURI != "" {
1040-
continue
1041-
}
1042-
if allowEOF {
1043-
qr.err = io.EOF
1044-
return qr.err
1117+
func (qr *driverRows) fetch() error {
1118+
var qresp queryResponse
1119+
var err error
1120+
for {
1121+
select {
1122+
case qresp = <-qr.stmt.queryResponses:
1123+
if qresp.ID == "" {
1124+
return io.EOF
10451125
}
1046-
}
1047-
if qr.columns == nil && len(qresp.Columns) > 0 {
10481126
err = qr.initColumns(&qresp)
10491127
if err != nil {
10501128
return err
10511129
}
1052-
}
1053-
return nil
1054-
}
1055-
return nil
1056-
}
1057-
1058-
func (qr *driverRows) executeFetchRequest(qresp *queryResponse) error {
1059-
hs := make(http.Header)
1060-
hs.Add(trinoUserHeader, qr.stmt.user)
1061-
req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs)
1062-
if err != nil {
1063-
return err
1064-
}
1065-
resp, err := qr.stmt.conn.roundTrip(qr.ctx, req)
1066-
if err != nil {
1067-
if qr.ctx.Err() == context.Canceled {
1068-
qr.Close()
1130+
qr.rowindex = 0
1131+
qr.data = qresp.Data
1132+
qr.rowsAffected = qresp.UpdateCount
1133+
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
1134+
if len(qr.data) != 0 {
1135+
return nil
1136+
}
1137+
case err = <-qr.stmt.errors:
1138+
if err == context.Canceled {
1139+
qr.Close()
1140+
}
1141+
qr.err = err
10691142
return err
10701143
}
1071-
return err
1072-
}
1073-
defer resp.Body.Close()
1074-
1075-
d := json.NewDecoder(resp.Body)
1076-
d.UseNumber()
1077-
err = d.Decode(&qresp)
1078-
if err != nil {
1079-
return fmt.Errorf("trino: %v", err)
10801144
}
1081-
return handleResponseError(resp.StatusCode, qresp.Error)
10821145
}
10831146

10841147
func unmarshalArguments(signature *typeSignature) error {
@@ -1110,6 +1173,9 @@ func unmarshalArguments(signature *typeSignature) error {
11101173
}
11111174

11121175
func (qr *driverRows) initColumns(qresp *queryResponse) error {
1176+
if qr.columns != nil || len(qresp.Columns) == 0 {
1177+
return nil
1178+
}
11131179
var err error
11141180
for i := range qresp.Columns {
11151181
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
@@ -1120,6 +1186,10 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error {
11201186
qr.columns = make([]string, len(qresp.Columns))
11211187
qr.coltype = make([]*typeConverter, len(qresp.Columns))
11221188
for i, col := range qresp.Columns {
1189+
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
1190+
if err != nil {
1191+
return fmt.Errorf("error decoding column type signature: %w", err)
1192+
}
11231193
qr.columns[i] = col.Name
11241194
qr.coltype[i], err = newTypeConverter(col.Type, col.TypeSignature)
11251195
if err != nil {

trino/trino_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,7 @@ func TestFetchNoStackOverflow(t *testing.T) {
800800
if buf == nil {
801801
buf = new(bytes.Buffer)
802802
json.NewEncoder(buf).Encode(&stmtResponse{
803+
ID: "fake-query",
803804
NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1",
804805
})
805806
}

0 commit comments

Comments
 (0)