@@ -1008,82 +1008,118 @@ func TestIntegrationUnsupportedHeader(t *testing.T) {
10081008 }
10091009}
10101010
1011- func TestIntegrationQueryContextCancellation (t * testing.T ) {
1012- err := RegisterCustomClient ("uncompressed" , & http.Client {Transport : & http.Transport {DisableCompression : true }})
1013- if err != nil {
1011+ func TestIntegrationQueryContext (t * testing.T ) {
1012+ tests := []struct {
1013+ name string
1014+ timeout time.Duration
1015+ expectedErrMsg string
1016+ }{
1017+ {
1018+ name : "Context Cancellation" ,
1019+ timeout : 0 ,
1020+ expectedErrMsg : "canceled" ,
1021+ },
1022+ {
1023+ name : "Context Deadline Exceeded" ,
1024+ timeout : 3 * time .Second ,
1025+ expectedErrMsg : "context deadline exceeded" ,
1026+ },
1027+ }
1028+
1029+ if err := RegisterCustomClient ("uncompressed" , & http.Client {Transport : & http.Transport {DisableCompression : true }}); err != nil {
10141030 t .Fatal (err )
10151031 }
1016- dsn := * integrationServerFlag
1017- dsn += "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed"
1032+
1033+ dsn := * integrationServerFlag + "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed"
10181034 db := integrationOpen (t , dsn )
10191035 defer db .Close ()
10201036
1021- ctx , cancel := context .WithCancel (context .Background ())
1022- errCh := make (chan error , 3 )
1023- done := make (chan struct {})
1024- longQuery := "SELECT COUNT(*) FROM lineitem"
1025- go func () {
1026- // query will complete in ~7s unless cancelled
1027- rows , err := db .QueryContext (ctx , longQuery )
1028- if err != nil {
1029- errCh <- err
1030- return
1031- }
1032- rows .Next ()
1033- if err = rows .Err (); err != nil {
1034- errCh <- err
1035- return
1036- }
1037- close (done )
1038- }()
1039-
1040- // poll system.runtime.queries and wait for query to start working
1041- var queryID string
1042- pollCtx , pollCancel := context .WithTimeout (context .Background (), 1 * time .Second )
1043- defer pollCancel ()
1044- for {
1045- row := db .QueryRowContext (pollCtx , "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?" , longQuery )
1046- err := row .Scan (& queryID )
1047- if err == nil {
1048- break
1049- }
1050- if err != sql .ErrNoRows {
1051- t .Fatal ("failed to read query id" , err )
1052- }
1053- if err = contextSleep (pollCtx , 100 * time .Millisecond ); err != nil {
1054- t .Fatal ("query did not start in 1 second" )
1055- }
1056- }
1037+ for _ , tt := range tests {
1038+ t .Run (tt .name , func (t * testing.T ) {
1039+ var ctx context.Context
1040+ var cancel context.CancelFunc
10571041
1058- cancel ()
1042+ if tt .timeout == 0 {
1043+ ctx , cancel = context .WithCancel (context .Background ())
1044+ } else {
1045+ ctx , cancel = context .WithTimeout (context .Background (), tt .timeout )
1046+ }
1047+ defer cancel ()
10591048
1060- select {
1061- case <- done :
1062- t .Fatal ("unexpected query with cancelled context succeeded" )
1063- break
1064- case err = <- errCh :
1065- if ! strings .Contains (err .Error (), "canceled" ) {
1066- t .Fatal ("expected err to be canceled but got:" , err )
1067- }
1068- }
1049+ errCh := make (chan error , 1 )
1050+ done := make (chan struct {})
1051+ longQuery := "SELECT COUNT(*) FROM lineitem"
10691052
1070- // poll system.runtime.queries and wait for query to be cancelled
1071- pollCtx , pollCancel = context .WithTimeout (context .Background (), 1 * time .Second )
1072- defer pollCancel ()
1073- for {
1074- row := db .QueryRowContext (pollCtx , "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?" , queryID )
1075- var state string
1076- var code * string
1077- err := row .Scan (& state , & code )
1078- if err != nil {
1079- t .Fatal ("failed to read query id" , err )
1080- }
1081- if state == "FAILED" && code != nil && * code == "USER_CANCELED" {
1082- break
1083- }
1084- if err = contextSleep (pollCtx , 100 * time .Millisecond ); err != nil {
1085- t .Fatal ("query was not cancelled in 1 second; state, code, err are:" , state , code , err )
1086- }
1053+ go func () {
1054+ // query will complete in ~7s unless cancelled
1055+ rows , err := db .QueryContext (ctx , longQuery )
1056+ if err != nil {
1057+ errCh <- err
1058+ return
1059+ }
1060+ defer rows .Close ()
1061+
1062+ rows .Next ()
1063+ if err = rows .Err (); err != nil {
1064+ errCh <- err
1065+ return
1066+ }
1067+ close (done )
1068+ }()
1069+
1070+ // Poll system.runtime.queries to get the query ID
1071+ var queryID string
1072+ pollCtx , pollCancel := context .WithTimeout (context .Background (), 1 * time .Second )
1073+ defer pollCancel ()
1074+
1075+ for {
1076+ row := db .QueryRowContext (pollCtx , "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?" , longQuery )
1077+ err := row .Scan (& queryID )
1078+ if err == nil {
1079+ break
1080+ }
1081+ if err != sql .ErrNoRows {
1082+ t .Fatal ("failed to read query ID:" , err )
1083+ }
1084+ if err = contextSleep (pollCtx , 100 * time .Millisecond ); err != nil {
1085+ t .Fatal ("query did not start in 1 second" )
1086+ }
1087+ }
1088+
1089+ if tt .timeout == 0 {
1090+ cancel ()
1091+ }
1092+
1093+ // Wait for the query to be canceled or completed
1094+ select {
1095+ case <- done :
1096+ t .Fatal ("unexpected query succeeded despite cancellation or deadline" )
1097+ case err := <- errCh :
1098+ if ! strings .Contains (err .Error (), tt .expectedErrMsg ) {
1099+ t .Fatalf ("expected error containing %q, but got: %v" , tt .expectedErrMsg , err )
1100+ }
1101+ }
1102+
1103+ // Poll system.runtime.queries to verify the query was canceled
1104+ pollCtx , pollCancel = context .WithTimeout (context .Background (), 2 * time .Second )
1105+ defer pollCancel ()
1106+
1107+ for {
1108+ row := db .QueryRowContext (pollCtx , "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?" , queryID )
1109+ var state string
1110+ var code * string
1111+ err := row .Scan (& state , & code )
1112+ if err != nil {
1113+ t .Fatal ("failed to read query state:" , err )
1114+ }
1115+ if state == "FAILED" && code != nil && * code == "USER_CANCELED" {
1116+ return
1117+ }
1118+ if err = contextSleep (pollCtx , 100 * time .Millisecond ); err != nil {
1119+ t .Fatalf ("query was not canceled in 2 seconds; state: %s, code: %v, err: %v" , state , code , err )
1120+ }
1121+ }
1122+ })
10871123 }
10881124}
10891125
0 commit comments