@@ -407,6 +407,8 @@ EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
407
407
408
408
static bool vsock_use_local_transport (unsigned int remote_cid )
409
409
{
410
+ lockdep_assert_held (& vsock_register_mutex );
411
+
410
412
if (!transport_local )
411
413
return false;
412
414
@@ -464,6 +466,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
464
466
465
467
remote_flags = vsk -> remote_addr .svm_flags ;
466
468
469
+ mutex_lock (& vsock_register_mutex );
470
+
467
471
switch (sk -> sk_type ) {
468
472
case SOCK_DGRAM :
469
473
new_transport = transport_dgram ;
@@ -479,12 +483,15 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
479
483
new_transport = transport_h2g ;
480
484
break ;
481
485
default :
482
- return - ESOCKTNOSUPPORT ;
486
+ ret = - ESOCKTNOSUPPORT ;
487
+ goto err ;
483
488
}
484
489
485
490
if (vsk -> transport ) {
486
- if (vsk -> transport == new_transport )
487
- return 0 ;
491
+ if (vsk -> transport == new_transport ) {
492
+ ret = 0 ;
493
+ goto err ;
494
+ }
488
495
489
496
/* transport->release() must be called with sock lock acquired.
490
497
* This path can only be taken during vsock_connect(), where we
@@ -508,8 +515,16 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
508
515
/* We increase the module refcnt to prevent the transport unloading
509
516
* while there are open sockets assigned to it.
510
517
*/
511
- if (!new_transport || !try_module_get (new_transport -> module ))
512
- return - ENODEV ;
518
+ if (!new_transport || !try_module_get (new_transport -> module )) {
519
+ ret = - ENODEV ;
520
+ goto err ;
521
+ }
522
+
523
+ /* It's safe to release the mutex after a successful try_module_get().
524
+ * Whichever transport `new_transport` points at, it won't go away until
525
+ * the last module_put() below or in vsock_deassign_transport().
526
+ */
527
+ mutex_unlock (& vsock_register_mutex );
513
528
514
529
if (sk -> sk_type == SOCK_SEQPACKET ) {
515
530
if (!new_transport -> seqpacket_allow ||
@@ -528,6 +543,9 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
528
543
vsk -> transport = new_transport ;
529
544
530
545
return 0 ;
546
+ err :
547
+ mutex_unlock (& vsock_register_mutex );
548
+ return ret ;
531
549
}
532
550
EXPORT_SYMBOL_GPL (vsock_assign_transport );
533
551
0 commit comments