Skip to content

Commit 5716445

Browse files
committed
Simplify parts of websocket.go
1 parent 64e7470 commit 5716445

File tree

1 file changed

+33
-43
lines changed

1 file changed

+33
-43
lines changed

websocket.go

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,6 @@ type Conn struct {
4949
readDone chan int
5050
}
5151

52-
func (c *Conn) getCloseErr() error {
53-
if c.closeErr != nil {
54-
return c.closeErr
55-
}
56-
return nil
57-
}
58-
5952
func (c *Conn) close(err error) {
6053
if err != nil {
6154
err = xerrors.Errorf("websocket: connection broken: %w", err)
@@ -160,8 +153,12 @@ messageLoop:
160153
masked: c.client,
161154
}
162155
c.writeFrame(h, control.payload)
163-
c.writeDone <- struct{}{}
164-
continue
156+
select {
157+
case <-c.closed:
158+
return
159+
case c.writeDone <- struct{}{}:
160+
continue
161+
}
165162
case b, ok := <-c.writeBytes:
166163
h := header{
167164
fin: !ok,
@@ -349,14 +346,14 @@ func (c *Conn) Close(code StatusCode, reason string) error {
349346
p, _ = closePayload(StatusInternalError, fmt.Sprintf("websocket: application tried to send code %v but code or reason was invalid", code))
350347
}
351348

352-
err2 := c.writeClose(p, CloseError{
349+
cerr := c.writeClose(p, CloseError{
353350
Code: code,
354351
Reason: reason,
355352
})
356353
if err != nil {
357354
return err
358355
}
359-
return err2
356+
return cerr
360357
}
361358

362359
func (c *Conn) writeClose(p []byte, cerr CloseError) error {
@@ -381,19 +378,19 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
381378
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
382379
select {
383380
case <-c.closed:
384-
return c.getCloseErr()
381+
return c.closeErr
385382
case c.control <- control{
386383
opcode: opcode,
387384
payload: p,
388385
}:
389386
case <-ctx.Done():
390387
c.close(xerrors.New("force closed: close frame write timed out"))
391-
return c.getCloseErr()
388+
return c.closeErr
392389
}
393390

394391
select {
395392
case <-c.closed:
396-
return c.getCloseErr()
393+
return c.closeErr
397394
case <-c.writeDone:
398395
return nil
399396
case <-ctx.Done():
@@ -420,34 +417,25 @@ type messageWriter struct {
420417
ctx context.Context
421418
c *Conn
422419
acquiredLock bool
423-
sentFirst bool
424-
425-
done chan struct{}
426420
}
427421

428422
// Write writes the given bytes to the WebSocket connection.
429423
// The frame will automatically be fragmented as appropriate
430424
// with the buffers obtained from http.Hijacker.
431425
// Please ensure you call Close once you have written the full message.
432426
func (w *messageWriter) Write(p []byte) (int, error) {
433-
if !w.acquiredLock {
434-
select {
435-
case <-w.c.closed:
436-
return 0, w.c.getCloseErr()
437-
case w.c.write <- w.datatype:
438-
w.acquiredLock = true
439-
case <-w.ctx.Done():
440-
return 0, w.ctx.Err()
441-
}
427+
err := w.acquire()
428+
if err != nil {
429+
return 0, err
442430
}
443431

444432
select {
445433
case <-w.c.closed:
446-
return 0, w.c.getCloseErr()
434+
return 0, w.c.closeErr
447435
case w.c.writeBytes <- p:
448436
select {
449437
case <-w.c.closed:
450-
return 0, w.c.getCloseErr()
438+
return 0, w.c.closeErr
451439
case <-w.c.writeDone:
452440
return len(p), nil
453441
case <-w.ctx.Done():
@@ -458,23 +446,32 @@ func (w *messageWriter) Write(p []byte) (int, error) {
458446
}
459447
}
460448

461-
// Close flushes the frame to the connection.
462-
// This must be called for every messageWriter.
463-
func (w *messageWriter) Close() error {
449+
func (w *messageWriter) acquire() error {
464450
if !w.acquiredLock {
465451
select {
466452
case <-w.c.closed:
467-
return w.c.getCloseErr()
453+
return w.c.closeErr
468454
case w.c.write <- w.datatype:
469455
w.acquiredLock = true
470456
case <-w.ctx.Done():
471457
return w.ctx.Err()
472458
}
473459
}
460+
return nil
461+
}
462+
463+
// Close flushes the frame to the connection.
464+
// This must be called for every messageWriter.
465+
func (w *messageWriter) Close() error {
466+
err := w.acquire()
467+
if err != nil {
468+
return err
469+
}
470+
474471
close(w.c.writeBytes)
475472
select {
476473
case <-w.c.closed:
477-
return w.c.getCloseErr()
474+
return w.c.closeErr
478475
case <-w.ctx.Done():
479476
return w.ctx.Err()
480477
case <-w.c.writeDone:
@@ -490,7 +487,7 @@ func (w *messageWriter) Close() error {
490487
func (c *Conn) Read(ctx context.Context) (DataType, io.Reader, error) {
491488
select {
492489
case <-c.closed:
493-
return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr())
490+
return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr)
494491
case opcode := <-c.read:
495492
return DataType(opcode), &messageReader{
496493
ctx: ctx,
@@ -507,24 +504,17 @@ type messageReader struct {
507504
c *Conn
508505
}
509506

510-
// SetContext bounds the read operation to the ctx.
511-
// By default, the context is the one passed to conn.ReadMessage.
512-
// You still almost always want a separate context for reading the message though.
513-
func (r *messageReader) SetContext(ctx context.Context) {
514-
r.ctx = ctx
515-
}
516-
517507
// Read reads as many bytes as possible into p.
518508
func (r *messageReader) Read(p []byte) (n int, err error) {
519509
select {
520510
case <-r.c.closed:
521-
return 0, r.c.getCloseErr()
511+
return 0, r.c.closeErr
522512
case <-r.c.readDone:
523513
return 0, io.EOF
524514
case r.c.readBytes <- p:
525515
select {
526516
case <-r.c.closed:
527-
return 0, r.c.getCloseErr()
517+
return 0, r.c.closeErr
528518
case n := <-r.c.readDone:
529519
return n, nil
530520
case <-r.ctx.Done():

0 commit comments

Comments
 (0)