@@ -49,13 +49,6 @@ type Conn struct {
49
49
readDone chan int
50
50
}
51
51
52
- func (c * Conn ) getCloseErr () error {
53
- if c .closeErr != nil {
54
- return c .closeErr
55
- }
56
- return nil
57
- }
58
-
59
52
func (c * Conn ) close (err error ) {
60
53
if err != nil {
61
54
err = xerrors .Errorf ("websocket: connection broken: %w" , err )
@@ -160,8 +153,12 @@ messageLoop:
160
153
masked : c .client ,
161
154
}
162
155
c .writeFrame (h , control .payload )
163
- c .writeDone <- struct {}{}
164
- continue
156
+ select {
157
+ case <- c .closed :
158
+ return
159
+ case c .writeDone <- struct {}{}:
160
+ continue
161
+ }
165
162
case b , ok := <- c .writeBytes :
166
163
h := header {
167
164
fin : ! ok ,
@@ -349,14 +346,14 @@ func (c *Conn) Close(code StatusCode, reason string) error {
349
346
p , _ = closePayload (StatusInternalError , fmt .Sprintf ("websocket: application tried to send code %v but code or reason was invalid" , code ))
350
347
}
351
348
352
- err2 := c .writeClose (p , CloseError {
349
+ cerr := c .writeClose (p , CloseError {
353
350
Code : code ,
354
351
Reason : reason ,
355
352
})
356
353
if err != nil {
357
354
return err
358
355
}
359
- return err2
356
+ return cerr
360
357
}
361
358
362
359
func (c * Conn ) writeClose (p []byte , cerr CloseError ) error {
@@ -381,19 +378,19 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
381
378
func (c * Conn ) writeControl (ctx context.Context , opcode opcode , p []byte ) error {
382
379
select {
383
380
case <- c .closed :
384
- return c .getCloseErr ()
381
+ return c .closeErr
385
382
case c .control <- control {
386
383
opcode : opcode ,
387
384
payload : p ,
388
385
}:
389
386
case <- ctx .Done ():
390
387
c .close (xerrors .New ("force closed: close frame write timed out" ))
391
- return c .getCloseErr ()
388
+ return c .closeErr
392
389
}
393
390
394
391
select {
395
392
case <- c .closed :
396
- return c .getCloseErr ()
393
+ return c .closeErr
397
394
case <- c .writeDone :
398
395
return nil
399
396
case <- ctx .Done ():
@@ -420,34 +417,25 @@ type messageWriter struct {
420
417
ctx context.Context
421
418
c * Conn
422
419
acquiredLock bool
423
- sentFirst bool
424
-
425
- done chan struct {}
426
420
}
427
421
428
422
// Write writes the given bytes to the WebSocket connection.
429
423
// The frame will automatically be fragmented as appropriate
430
424
// with the buffers obtained from http.Hijacker.
431
425
// Please ensure you call Close once you have written the full message.
432
426
func (w * messageWriter ) Write (p []byte ) (int , error ) {
433
- if ! w .acquiredLock {
434
- select {
435
- case <- w .c .closed :
436
- return 0 , w .c .getCloseErr ()
437
- case w .c .write <- w .datatype :
438
- w .acquiredLock = true
439
- case <- w .ctx .Done ():
440
- return 0 , w .ctx .Err ()
441
- }
427
+ err := w .acquire ()
428
+ if err != nil {
429
+ return 0 , err
442
430
}
443
431
444
432
select {
445
433
case <- w .c .closed :
446
- return 0 , w .c .getCloseErr ()
434
+ return 0 , w .c .closeErr
447
435
case w .c .writeBytes <- p :
448
436
select {
449
437
case <- w .c .closed :
450
- return 0 , w .c .getCloseErr ()
438
+ return 0 , w .c .closeErr
451
439
case <- w .c .writeDone :
452
440
return len (p ), nil
453
441
case <- w .ctx .Done ():
@@ -458,23 +446,32 @@ func (w *messageWriter) Write(p []byte) (int, error) {
458
446
}
459
447
}
460
448
461
- // Close flushes the frame to the connection.
462
- // This must be called for every messageWriter.
463
- func (w * messageWriter ) Close () error {
449
+ func (w * messageWriter ) acquire () error {
464
450
if ! w .acquiredLock {
465
451
select {
466
452
case <- w .c .closed :
467
- return w .c .getCloseErr ()
453
+ return w .c .closeErr
468
454
case w .c .write <- w .datatype :
469
455
w .acquiredLock = true
470
456
case <- w .ctx .Done ():
471
457
return w .ctx .Err ()
472
458
}
473
459
}
460
+ return nil
461
+ }
462
+
463
+ // Close flushes the frame to the connection.
464
+ // This must be called for every messageWriter.
465
+ func (w * messageWriter ) Close () error {
466
+ err := w .acquire ()
467
+ if err != nil {
468
+ return err
469
+ }
470
+
474
471
close (w .c .writeBytes )
475
472
select {
476
473
case <- w .c .closed :
477
- return w .c .getCloseErr ()
474
+ return w .c .closeErr
478
475
case <- w .ctx .Done ():
479
476
return w .ctx .Err ()
480
477
case <- w .c .writeDone :
@@ -490,7 +487,7 @@ func (w *messageWriter) Close() error {
490
487
func (c * Conn ) Read (ctx context.Context ) (DataType , io.Reader , error ) {
491
488
select {
492
489
case <- c .closed :
493
- return 0 , nil , xerrors .Errorf ("failed to read message: %w" , c .getCloseErr () )
490
+ return 0 , nil , xerrors .Errorf ("failed to read message: %w" , c .closeErr )
494
491
case opcode := <- c .read :
495
492
return DataType (opcode ), & messageReader {
496
493
ctx : ctx ,
@@ -507,24 +504,17 @@ type messageReader struct {
507
504
c * Conn
508
505
}
509
506
510
- // SetContext bounds the read operation to the ctx.
511
- // By default, the context is the one passed to conn.ReadMessage.
512
- // You still almost always want a separate context for reading the message though.
513
- func (r * messageReader ) SetContext (ctx context.Context ) {
514
- r .ctx = ctx
515
- }
516
-
517
507
// Read reads as many bytes as possible into p.
518
508
func (r * messageReader ) Read (p []byte ) (n int , err error ) {
519
509
select {
520
510
case <- r .c .closed :
521
- return 0 , r .c .getCloseErr ()
511
+ return 0 , r .c .closeErr
522
512
case <- r .c .readDone :
523
513
return 0 , io .EOF
524
514
case r .c .readBytes <- p :
525
515
select {
526
516
case <- r .c .closed :
527
- return 0 , r .c .getCloseErr ()
517
+ return 0 , r .c .closeErr
528
518
case n := <- r .c .readDone :
529
519
return n , nil
530
520
case <- r .ctx .Done ():
0 commit comments