@@ -29,7 +29,7 @@ class Broadcast:
2929    def  __init__ (self , url : str  |  None  =  None , * , backend : BroadcastBackend  |  None  =  None ) ->  None :
3030        assert  url  or  backend , "Either `url` or `backend` must be provided." 
3131        self ._backend  =  backend  or  self ._create_backend (cast (str , url ))
32-         self ._subscribers : dict [str , set [asyncio .Queue [Event  |  None ]]] =  {}
32+         self ._subscribers : dict [str , set [asyncio .Queue [Event  |  BaseException   |   None ]]] =  {}
3333
3434    def  _create_backend (self , url : str ) ->  BroadcastBackend :
3535        parsed_url  =  urlparse (url )
@@ -69,10 +69,19 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
6969    async  def  connect (self ) ->  None :
7070        await  self ._backend .connect ()
7171        self ._listener_task  =  asyncio .create_task (self ._listener ())
72+         self ._listener_task .add_done_callback (self .drop )
73+ 
74+     def  drop (self , task : asyncio .Task [None ]) ->  None :
75+         exc  =  task .exception ()
76+         for  queues  in  self ._subscribers .values ():
77+             for  queue  in  queues :
78+                 queue .put_nowait (exc )
7279
7380    async  def  disconnect (self ) ->  None :
7481        if  self ._listener_task .done ():
75-             self ._listener_task .result ()
82+             exc  =  self ._listener_task .exception ()
83+             if  exc  is  None :
84+                 self ._listener_task .result ()
7685        else :
7786            self ._listener_task .cancel ()
7887        await  self ._backend .disconnect ()
@@ -88,7 +97,7 @@ async def publish(self, channel: str, message: Any) -> None:
8897
8998    @asynccontextmanager  
9099    async  def  subscribe (self , channel : str ) ->  AsyncIterator [Subscriber ]:
91-         queue : asyncio .Queue [Event  |  None ] =  asyncio .Queue ()
100+         queue : asyncio .Queue [Event  |  BaseException   |   None ] =  asyncio .Queue ()
92101
93102        try :
94103            if  not  self ._subscribers .get (channel ):
@@ -107,7 +116,7 @@ async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
107116
108117
109118class  Subscriber :
110-     def  __init__ (self , queue : asyncio .Queue [Event  |  None ]) ->  None :
119+     def  __init__ (self , queue : asyncio .Queue [Event  |  BaseException   |   None ]) ->  None :
111120        self ._queue  =  queue 
112121
113122    async  def  __aiter__ (self ) ->  AsyncGenerator [Event  |  None , None ]:
@@ -119,6 +128,8 @@ async def __aiter__(self) -> AsyncGenerator[Event | None, None]:
119128
120129    async  def  get (self ) ->  Event :
121130        item  =  await  self ._queue .get ()
131+         if  isinstance (item , BaseException ):
132+             raise  item 
122133        if  item  is  None :
123134            raise  Unsubscribed ()
124135        return  item 
0 commit comments