Skip to content

Commit a6657b2

Browse files
authored
Merge pull request #535 from mjibson/ctx
Add context methods
2 parents 67c3f2a + 9c80e00 commit a6657b2

File tree

4 files changed

+232
-7
lines changed

4 files changed

+232
-7
lines changed

.travis.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
language: go
22

33
go:
4-
- 1.5
5-
- 1.6
6-
- 1.7
7-
- tip
4+
- 1.5.x
5+
- 1.6.x
6+
- 1.7.x
7+
- 1.8.x
8+
- master
89

910
sudo: true
1011

conn.go

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ type conn struct {
9898
namei int
9999
scratch [512]byte
100100
txnStatus transactionStatus
101+
txnClosed chan<- struct{}
102+
103+
// Save connection arguments to use during CancelRequest.
104+
dialer Dialer
105+
opts values
106+
107+
// Cancellation key data for use with CancelRequest messages.
108+
processID int
109+
secretKey int
101110

102111
parameterStatus parameterStatus
103112

@@ -307,7 +316,10 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
307316
}
308317
}
309318

310-
cn := &conn{}
319+
cn := &conn{
320+
opts: o,
321+
dialer: d,
322+
}
311323
err = cn.handleDriverSettings(o)
312324
if err != nil {
313325
return nil, err
@@ -529,7 +541,15 @@ func (cn *conn) Begin() (_ driver.Tx, err error) {
529541
return cn, nil
530542
}
531543

544+
func (cn *conn) closeTxn() {
545+
if cn.txnClosed != nil {
546+
close(cn.txnClosed)
547+
cn.txnClosed = nil
548+
}
549+
}
550+
532551
func (cn *conn) Commit() (err error) {
552+
defer cn.closeTxn()
533553
if cn.bad {
534554
return driver.ErrBadConn
535555
}
@@ -565,6 +585,7 @@ func (cn *conn) Commit() (err error) {
565585
}
566586

567587
func (cn *conn) Rollback() (err error) {
588+
defer cn.closeTxn()
568589
if cn.bad {
569590
return driver.ErrBadConn
570591
}
@@ -796,7 +817,11 @@ func (cn *conn) Close() (err error) {
796817
}
797818

798819
// Implement the "Queryer" interface
799-
func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) {
820+
func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
821+
return cn.query(query, args)
822+
}
823+
824+
func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
800825
if cn.bad {
801826
return nil, driver.ErrBadConn
802827
}
@@ -1074,6 +1099,7 @@ func (cn *conn) startup(o values) {
10741099
t, r := cn.recv()
10751100
switch t {
10761101
case 'K':
1102+
cn.processBackendKeyData(r)
10771103
case 'S':
10781104
cn.processParameterStatus(r)
10791105
case 'R':
@@ -1301,6 +1327,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
13011327

13021328
type rows struct {
13031329
cn *conn
1330+
closed chan<- struct{}
13041331
colNames []string
13051332
colTyps []oid.Oid
13061333
colFmts []format
@@ -1309,6 +1336,9 @@ type rows struct {
13091336
}
13101337

13111338
func (rs *rows) Close() error {
1339+
if rs.closed != nil {
1340+
defer close(rs.closed)
1341+
}
13121342
// no need to look at cn.bad as Next() will
13131343
for {
13141344
err := rs.Next(nil)
@@ -1513,6 +1543,11 @@ func (cn *conn) readReadyForQuery() {
15131543
}
15141544
}
15151545

1546+
func (c *conn) processBackendKeyData(r *readBuf) {
1547+
c.processID = r.int32()
1548+
c.secretKey = r.int32()
1549+
}
1550+
15161551
func (cn *conn) readParseResponse() {
15171552
t, r := cn.recv1()
15181553
switch t {

conn_go18.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// +build go1.8
2+
3+
package pq
4+
5+
import (
6+
"context"
7+
"database/sql/driver"
8+
"errors"
9+
)
10+
11+
// Implement the "QueryerContext" interface
12+
func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
13+
list := make([]driver.Value, len(args))
14+
for i, nv := range args {
15+
list[i] = nv.Value
16+
}
17+
var closed chan<- struct{}
18+
if ctx.Done() != nil {
19+
closed = watchCancel(ctx, cn.cancel)
20+
}
21+
r, err := cn.query(query, list)
22+
if err != nil {
23+
return nil, err
24+
}
25+
r.closed = closed
26+
return r, nil
27+
}
28+
29+
// Implement the "ExecerContext" interface
30+
func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
31+
list := make([]driver.Value, len(args))
32+
for i, nv := range args {
33+
list[i] = nv.Value
34+
}
35+
36+
if ctx.Done() != nil {
37+
closed := watchCancel(ctx, cn.cancel)
38+
defer close(closed)
39+
}
40+
41+
return cn.Exec(query, list)
42+
}
43+
44+
// Implement the "ConnBeginTx" interface
45+
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
46+
if opts.Isolation != 0 {
47+
return nil, errors.New("isolation levels not supported")
48+
}
49+
if opts.ReadOnly {
50+
return nil, errors.New("read-only transactions not supported")
51+
}
52+
tx, err := cn.Begin()
53+
if err != nil {
54+
return nil, err
55+
}
56+
if ctx.Done() != nil {
57+
cn.txnClosed = watchCancel(ctx, cn.cancel)
58+
}
59+
return tx, nil
60+
}
61+
62+
func watchCancel(ctx context.Context, cancel func()) chan<- struct{} {
63+
closed := make(chan struct{})
64+
go func() {
65+
select {
66+
case <-ctx.Done():
67+
cancel()
68+
case <-closed:
69+
}
70+
}()
71+
return closed
72+
}
73+
74+
func (cn *conn) cancel() {
75+
var err error
76+
can := &conn{}
77+
can.c, err = dial(cn.dialer, cn.opts)
78+
if err != nil {
79+
return
80+
}
81+
can.ssl(cn.opts)
82+
83+
defer can.errRecover(&err)
84+
85+
w := can.writeBuf(0)
86+
w.int32(80877102) // cancel request code
87+
w.int32(cn.processID)
88+
w.int32(cn.secretKey)
89+
90+
can.sendStartupPacket(w)
91+
_ = can.c.Close()
92+
}

go18_test.go

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
package pq
44

5-
import "testing"
5+
import (
6+
"context"
7+
"database/sql"
8+
"testing"
9+
"time"
10+
)
611

712
func TestMultipleSimpleQuery(t *testing.T) {
813
db := openTestConn(t)
@@ -66,3 +71,95 @@ func TestMultipleSimpleQuery(t *testing.T) {
6671
t.Fatal("unexpected result set")
6772
}
6873
}
74+
75+
func TestContextCancelExec(t *testing.T) {
76+
db := openTestConn(t)
77+
defer db.Close()
78+
79+
ctx, cancel := context.WithCancel(context.Background())
80+
81+
// Delay execution for just a bit until db.ExecContext has begun.
82+
go func() {
83+
time.Sleep(time.Millisecond * 10)
84+
cancel()
85+
}()
86+
87+
// Not canceled until after the exec has started.
88+
if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil {
89+
t.Fatal("expected error")
90+
} else if err.Error() != "pq: canceling statement due to user request" {
91+
t.Fatalf("unexpected error: %s", err)
92+
}
93+
94+
// Context is already canceled, so error should come before execution.
95+
if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil {
96+
t.Fatal("expected error")
97+
} else if err.Error() != "context canceled" {
98+
t.Fatalf("unexpected error: %s", err)
99+
}
100+
}
101+
102+
func TestContextCancelQuery(t *testing.T) {
103+
db := openTestConn(t)
104+
defer db.Close()
105+
106+
ctx, cancel := context.WithCancel(context.Background())
107+
108+
// Delay execution for just a bit until db.QueryContext has begun.
109+
go func() {
110+
time.Sleep(time.Millisecond * 10)
111+
cancel()
112+
}()
113+
114+
// Not canceled until after the exec has started.
115+
if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil {
116+
t.Fatal("expected error")
117+
} else if err.Error() != "pq: canceling statement due to user request" {
118+
t.Fatalf("unexpected error: %s", err)
119+
}
120+
121+
// Context is already canceled, so error should come before execution.
122+
if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil {
123+
t.Fatal("expected error")
124+
} else if err.Error() != "context canceled" {
125+
t.Fatalf("unexpected error: %s", err)
126+
}
127+
}
128+
129+
func TestContextCancelBegin(t *testing.T) {
130+
db := openTestConn(t)
131+
defer db.Close()
132+
133+
ctx, cancel := context.WithCancel(context.Background())
134+
tx, err := db.BeginTx(ctx, nil)
135+
if err != nil {
136+
t.Fatal(err)
137+
}
138+
139+
// Delay execution for just a bit until tx.Exec has begun.
140+
go func() {
141+
time.Sleep(time.Millisecond * 10)
142+
cancel()
143+
}()
144+
145+
// Not canceled until after the exec has started.
146+
if _, err := tx.Exec("select pg_sleep(1)"); err == nil {
147+
t.Fatal("expected error")
148+
} else if err.Error() != "pq: canceling statement due to user request" {
149+
t.Fatalf("unexpected error: %s", err)
150+
}
151+
152+
// Transaction is canceled, so expect an error.
153+
if _, err := tx.Query("select pg_sleep(1)"); err == nil {
154+
t.Fatal("expected error")
155+
} else if err != sql.ErrTxDone {
156+
t.Fatalf("unexpected error: %s", err)
157+
}
158+
159+
// Context is canceled, so cannot begin a transaction.
160+
if _, err := db.BeginTx(ctx, nil); err == nil {
161+
t.Fatal("expected error")
162+
} else if err.Error() != "context canceled" {
163+
t.Fatalf("unexpected error: %s", err)
164+
}
165+
}

0 commit comments

Comments
 (0)