Skip to content

Commit 816a42f

Browse files
committed
Fix WithContext and add tests
1 parent 73d3c18 commit 816a42f

13 files changed

+111
-88
lines changed

cluster.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
673673
opt: opt,
674674
nodes: newClusterNodes(opt),
675675
},
676+
ctx: context.Background(),
676677
}
677678
c.state = newClusterStateHolder(c.loadState)
678679
c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
@@ -690,10 +691,7 @@ func (c *ClusterClient) init() {
690691
}
691692

692693
func (c *ClusterClient) Context() context.Context {
693-
if c.ctx != nil {
694-
return c.ctx
695-
}
696-
return context.Background()
694+
return c.ctx
697695
}
698696

699697
func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
@@ -702,6 +700,7 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
702700
}
703701
clone := *c
704702
clone.ctx = ctx
703+
clone.init()
705704
return &clone
706705
}
707706

@@ -732,7 +731,7 @@ func (c *ClusterClient) Do(args ...interface{}) *Cmd {
732731

733732
func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
734733
cmd := NewCmd(args...)
735-
c.ProcessContext(ctx, cmd)
734+
_ = c.ProcessContext(ctx, cmd)
736735
return cmd
737736
}
738737

@@ -1035,7 +1034,7 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
10351034
}
10361035

10371036
func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
1038-
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
1037+
return c.hooks.processPipeline(ctx, cmds, c._processPipeline)
10391038
}
10401039

10411040
func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {

cluster_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package redis_test
22

33
import (
4+
"context"
45
"fmt"
56
"net"
67
"strconv"
@@ -241,6 +242,14 @@ var _ = Describe("ClusterClient", func() {
241242
var client *redis.ClusterClient
242243

243244
assertClusterClient := func() {
245+
It("supports WithContext", func() {
246+
c, cancel := context.WithCancel(context.Background())
247+
cancel()
248+
249+
err := client.WithContext(c).Ping().Err()
250+
Expect(err).To(MatchError("context canceled"))
251+
})
252+
244253
It("should GET/SET/DEL", func() {
245254
err := client.Get("A").Err()
246255
Expect(err).To(Equal(redis.Nil))

internal/pool/pool.go

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -250,38 +250,38 @@ func (p *ConnPool) getTurn() {
250250
}
251251

252252
func (p *ConnPool) waitTurn(ctx context.Context) error {
253-
var done <-chan struct{}
254-
if ctx != nil {
255-
done = ctx.Done()
253+
select {
254+
case <-ctx.Done():
255+
return ctx.Err()
256+
default:
256257
}
257258

258259
select {
259-
case <-done:
260-
return ctx.Err()
261260
case p.queue <- struct{}{}:
262261
return nil
263262
default:
264-
timer := timers.Get().(*time.Timer)
265-
timer.Reset(p.opt.PoolTimeout)
263+
}
266264

267-
select {
268-
case <-done:
269-
if !timer.Stop() {
270-
<-timer.C
271-
}
272-
timers.Put(timer)
273-
return ctx.Err()
274-
case p.queue <- struct{}{}:
275-
if !timer.Stop() {
276-
<-timer.C
277-
}
278-
timers.Put(timer)
279-
return nil
280-
case <-timer.C:
281-
timers.Put(timer)
282-
atomic.AddUint32(&p.stats.Timeouts, 1)
283-
return ErrPoolTimeout
265+
timer := timers.Get().(*time.Timer)
266+
timer.Reset(p.opt.PoolTimeout)
267+
268+
select {
269+
case <-ctx.Done():
270+
if !timer.Stop() {
271+
<-timer.C
272+
}
273+
timers.Put(timer)
274+
return ctx.Err()
275+
case p.queue <- struct{}{}:
276+
if !timer.Stop() {
277+
<-timer.C
284278
}
279+
timers.Put(timer)
280+
return nil
281+
case <-timer.C:
282+
timers.Put(timer)
283+
atomic.AddUint32(&p.stats.Timeouts, 1)
284+
return ErrPoolTimeout
285285
}
286286
}
287287

internal/pool/pool_test.go

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package pool_test
22

33
import (
4+
"context"
45
"sync"
56
"testing"
67
"time"
@@ -12,6 +13,7 @@ import (
1213
)
1314

1415
var _ = Describe("ConnPool", func() {
16+
c := context.Background()
1517
var connPool *pool.ConnPool
1618

1719
BeforeEach(func() {
@@ -30,13 +32,13 @@ var _ = Describe("ConnPool", func() {
3032

3133
It("should unblock client when conn is removed", func() {
3234
// Reserve one connection.
33-
cn, err := connPool.Get(nil)
35+
cn, err := connPool.Get(c)
3436
Expect(err).NotTo(HaveOccurred())
3537

3638
// Reserve all other connections.
3739
var cns []*pool.Conn
3840
for i := 0; i < 9; i++ {
39-
cn, err := connPool.Get(nil)
41+
cn, err := connPool.Get(c)
4042
Expect(err).NotTo(HaveOccurred())
4143
cns = append(cns, cn)
4244
}
@@ -47,7 +49,7 @@ var _ = Describe("ConnPool", func() {
4749
defer GinkgoRecover()
4850

4951
started <- true
50-
_, err := connPool.Get(nil)
52+
_, err := connPool.Get(c)
5153
Expect(err).NotTo(HaveOccurred())
5254
done <- true
5355

@@ -80,6 +82,7 @@ var _ = Describe("ConnPool", func() {
8082
})
8183

8284
var _ = Describe("MinIdleConns", func() {
85+
c := context.Background()
8386
const poolSize = 100
8487
var minIdleConns int
8588
var connPool *pool.ConnPool
@@ -110,7 +113,7 @@ var _ = Describe("MinIdleConns", func() {
110113

111114
BeforeEach(func() {
112115
var err error
113-
cn, err = connPool.Get(nil)
116+
cn, err = connPool.Get(c)
114117
Expect(err).NotTo(HaveOccurred())
115118

116119
Eventually(func() int {
@@ -145,7 +148,7 @@ var _ = Describe("MinIdleConns", func() {
145148
perform(poolSize, func(_ int) {
146149
defer GinkgoRecover()
147150

148-
cn, err := connPool.Get(nil)
151+
cn, err := connPool.Get(c)
149152
Expect(err).NotTo(HaveOccurred())
150153
mu.Lock()
151154
cns = append(cns, cn)
@@ -160,7 +163,7 @@ var _ = Describe("MinIdleConns", func() {
160163
It("Get is blocked", func() {
161164
done := make(chan struct{})
162165
go func() {
163-
connPool.Get(nil)
166+
connPool.Get(c)
164167
close(done)
165168
}()
166169

@@ -247,6 +250,8 @@ var _ = Describe("MinIdleConns", func() {
247250
})
248251

249252
var _ = Describe("conns reaper", func() {
253+
c := context.Background()
254+
250255
const idleTimeout = time.Minute
251256
const maxAge = time.Hour
252257

@@ -274,7 +279,7 @@ var _ = Describe("conns reaper", func() {
274279
// add stale connections
275280
staleConns = nil
276281
for i := 0; i < 3; i++ {
277-
cn, err := connPool.Get(nil)
282+
cn, err := connPool.Get(c)
278283
Expect(err).NotTo(HaveOccurred())
279284
switch typ {
280285
case "idle":
@@ -288,7 +293,7 @@ var _ = Describe("conns reaper", func() {
288293

289294
// add fresh connections
290295
for i := 0; i < 3; i++ {
291-
cn, err := connPool.Get(nil)
296+
cn, err := connPool.Get(c)
292297
Expect(err).NotTo(HaveOccurred())
293298
conns = append(conns, cn)
294299
}
@@ -333,7 +338,7 @@ var _ = Describe("conns reaper", func() {
333338
for j := 0; j < 3; j++ {
334339
var freeCns []*pool.Conn
335340
for i := 0; i < 3; i++ {
336-
cn, err := connPool.Get(nil)
341+
cn, err := connPool.Get(c)
337342
Expect(err).NotTo(HaveOccurred())
338343
Expect(cn).NotTo(BeNil())
339344
freeCns = append(freeCns, cn)
@@ -342,7 +347,7 @@ var _ = Describe("conns reaper", func() {
342347
Expect(connPool.Len()).To(Equal(3))
343348
Expect(connPool.IdleLen()).To(Equal(0))
344349

345-
cn, err := connPool.Get(nil)
350+
cn, err := connPool.Get(c)
346351
Expect(err).NotTo(HaveOccurred())
347352
Expect(cn).NotTo(BeNil())
348353
conns = append(conns, cn)
@@ -370,6 +375,7 @@ var _ = Describe("conns reaper", func() {
370375
})
371376

372377
var _ = Describe("race", func() {
378+
c := context.Background()
373379
var connPool *pool.ConnPool
374380
var C, N int
375381

@@ -396,15 +402,15 @@ var _ = Describe("race", func() {
396402

397403
perform(C, func(id int) {
398404
for i := 0; i < N; i++ {
399-
cn, err := connPool.Get(nil)
405+
cn, err := connPool.Get(c)
400406
Expect(err).NotTo(HaveOccurred())
401407
if err == nil {
402408
connPool.Put(cn)
403409
}
404410
}
405411
}, func(id int) {
406412
for i := 0; i < N; i++ {
407-
cn, err := connPool.Get(nil)
413+
cn, err := connPool.Get(c)
408414
Expect(err).NotTo(HaveOccurred())
409415
if err == nil {
410416
connPool.Remove(cn)

pipeline.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func (c *Pipeline) discard() error {
9898
// Exec always returns list of commands and error of the first failed
9999
// command if any.
100100
func (c *Pipeline) Exec() ([]Cmder, error) {
101-
return c.ExecContext(nil)
101+
return c.ExecContext(context.Background())
102102
}
103103

104104
func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {

pool_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package redis_test
22

33
import (
4+
"context"
45
"time"
56

67
"github.com/go-redis/redis"
@@ -81,7 +82,7 @@ var _ = Describe("pool", func() {
8182
})
8283

8384
It("removes broken connections", func() {
84-
cn, err := client.Pool().Get(nil)
85+
cn, err := client.Pool().Get(context.Background())
8586
Expect(err).NotTo(HaveOccurred())
8687
cn.SetNetConn(&badConn{})
8788
client.Pool().Put(cn)

redis.go

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@ type hooks struct {
3232
hooks []Hook
3333
}
3434

35-
func (hs *hooks) lazyCopy() {
36-
hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
37-
}
38-
3935
func (hs *hooks) AddHook(hook Hook) {
4036
hs.hooks = append(hs.hooks, hook)
4137
}
@@ -475,6 +471,7 @@ func NewClient(opt *Options) *Client {
475471
connPool: newConnPool(opt),
476472
},
477473
},
474+
ctx: context.Background(),
478475
}
479476
c.init()
480477

@@ -486,10 +483,7 @@ func (c *Client) init() {
486483
}
487484

488485
func (c *Client) Context() context.Context {
489-
if c.ctx != nil {
490-
return c.ctx
491-
}
492-
return context.Background()
486+
return c.ctx
493487
}
494488

495489
func (c *Client) WithContext(ctx context.Context) *Client {
@@ -498,6 +492,7 @@ func (c *Client) WithContext(ctx context.Context) *Client {
498492
}
499493
clone := *c
500494
clone.ctx = ctx
495+
clone.init()
501496
return &clone
502497
}
503498

redis_test.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ var _ = Describe("Client", func() {
2424
client.Close()
2525
})
2626

27+
It("supports WithContext", func() {
28+
c, cancel := context.WithCancel(context.Background())
29+
cancel()
30+
31+
err := client.WithContext(c).Ping().Err()
32+
Expect(err).To(MatchError("context canceled"))
33+
})
34+
2735
It("should Stringer", func() {
2836
Expect(client.String()).To(Equal("Redis<:6380 db:15>"))
2937
})
@@ -129,7 +137,7 @@ var _ = Describe("Client", func() {
129137

130138
It("processes custom commands", func() {
131139
cmd := redis.NewCmd("PING")
132-
client.Process(cmd)
140+
_ = client.Process(cmd)
133141

134142
// Flush buffers.
135143
Expect(client.Echo("hello").Err()).NotTo(HaveOccurred())
@@ -147,7 +155,7 @@ var _ = Describe("Client", func() {
147155
})
148156

149157
// Put bad connection in the pool.
150-
cn, err := client.Pool().Get(nil)
158+
cn, err := client.Pool().Get(context.Background())
151159
Expect(err).NotTo(HaveOccurred())
152160

153161
cn.SetNetConn(&badConn{})
@@ -185,7 +193,7 @@ var _ = Describe("Client", func() {
185193
})
186194

187195
It("should update conn.UsedAt on read/write", func() {
188-
cn, err := client.Pool().Get(nil)
196+
cn, err := client.Pool().Get(context.Background())
189197
Expect(err).NotTo(HaveOccurred())
190198
Expect(cn.UsedAt).NotTo(BeZero())
191199
createdAt := cn.UsedAt()
@@ -198,7 +206,7 @@ var _ = Describe("Client", func() {
198206
err = client.Ping().Err()
199207
Expect(err).NotTo(HaveOccurred())
200208

201-
cn, err = client.Pool().Get(nil)
209+
cn, err = client.Pool().Get(context.Background())
202210
Expect(err).NotTo(HaveOccurred())
203211
Expect(cn).NotTo(BeNil())
204212
Expect(cn.UsedAt().After(createdAt)).To(BeTrue())

0 commit comments

Comments
 (0)