@@ -71,10 +71,14 @@ def __init__(
7171 h2 .events .ResponseReceived
7272 | h2 .events .DataReceived
7373 | h2 .events .StreamEnded
74- | h2 .events .StreamReset ,
74+ | h2 .events .StreamReset
75+ | h2 .events .TrailersReceived ,
7576 ],
7677 ] = {}
7778
79+ # Mapping from stream ID to trailing headers
80+ self ._trailing_headers : dict [int , list [tuple [bytes , bytes ]]] = {}
81+
7882 # Connection terminated events are stored as state since
7983 # we need to handle them for all streams.
8084 self ._connection_terminated : h2 .events .ConnectionTerminated | None = None
@@ -152,15 +156,22 @@ async def handle_async_request(self, request: Request) -> Response:
152156 )
153157 trace .return_value = (status , headers )
154158
159+ extensions = {
160+ "http_version" : b"HTTP/2" ,
161+ "network_stream" : self ._network_stream ,
162+ "stream_id" : stream_id ,
163+ }
164+
155165 return Response (
156166 status = status ,
157167 headers = headers ,
158- content = HTTP2ConnectionByteStream (self , request , stream_id = stream_id ),
159- extensions = {
160- "http_version" : b"HTTP/2" ,
161- "network_stream" : self ._network_stream ,
162- "stream_id" : stream_id ,
163- },
168+ content = HTTP2ConnectionByteStream (
169+ connection = self ,
170+ request = request ,
171+ stream_id = stream_id ,
172+ extensions = extensions ,
173+ ),
174+ extensions = extensions ,
164175 )
165176 except BaseException as exc : # noqa: PIE786
166177 with AsyncShieldCancellation ():
@@ -321,12 +332,21 @@ async def _receive_response_body(
321332 self ._h2_state .acknowledge_received_data (amount , stream_id )
322333 await self ._write_outgoing_data (request )
323334 yield event .data
335+ elif isinstance (event , h2 .events .TrailersReceived ):
336+ # Process trailing headers but continue receiving events
337+ # The trailing headers are already stored in self._trailing_headers
338+ continue
324339 elif isinstance (event , h2 .events .StreamEnded ):
325340 break
326341
327342 async def _receive_stream_event (
328343 self , request : Request , stream_id : int
329- ) -> h2 .events .ResponseReceived | h2 .events .DataReceived | h2 .events .StreamEnded :
344+ ) -> (
345+ h2 .events .ResponseReceived
346+ | h2 .events .DataReceived
347+ | h2 .events .StreamEnded
348+ | h2 .events .TrailersReceived
349+ ):
330350 """
331351 Return the next available event for a given stream ID.
332352
@@ -377,10 +397,19 @@ async def _receive_events(
377397 h2 .events .DataReceived ,
378398 h2 .events .StreamEnded ,
379399 h2 .events .StreamReset ,
400+ h2 .events .TrailersReceived ,
380401 ),
381402 ):
382403 if event .stream_id in self ._events :
383404 self ._events [event .stream_id ].append (event )
405+ if isinstance (event , h2 .events .TrailersReceived ):
406+ self ._trailing_headers [event .stream_id ] = []
407+ if event .headers is not None :
408+ for k , v in event .headers :
409+ if not k .startswith (b":" ):
410+ self ._trailing_headers [
411+ event .stream_id
412+ ].append ((k , v ))
384413
385414 elif isinstance (event , h2 .events .ConnectionTerminated ):
386415 self ._connection_terminated = event
@@ -409,6 +438,8 @@ async def _receive_remote_settings_change(
409438 async def _response_closed (self , stream_id : int ) -> None :
410439 await self ._max_streams_semaphore .release ()
411440 del self ._events [stream_id ]
441+ if stream_id in self ._trailing_headers :
442+ del self ._trailing_headers [stream_id ]
412443 async with self ._state_lock :
413444 if self ._connection_terminated and not self ._events :
414445 await self .aclose ()
@@ -561,12 +592,17 @@ async def __aexit__(
561592
562593class HTTP2ConnectionByteStream :
563594 def __init__ (
564- self , connection : AsyncHTTP2Connection , request : Request , stream_id : int
595+ self ,
596+ connection : AsyncHTTP2Connection ,
597+ request : Request ,
598+ stream_id : int ,
599+ extensions : typing .MutableMapping [str , typing .Any ],
565600 ) -> None :
566601 self ._connection = connection
567602 self ._request = request
568603 self ._stream_id = stream_id
569604 self ._closed = False
605+ self ._extensions = extensions
570606
571607 async def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
572608 kwargs = {"request" : self ._request , "stream_id" : self ._stream_id }
@@ -576,6 +612,11 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
576612 request = self ._request , stream_id = self ._stream_id
577613 ):
578614 yield chunk
615+
616+ if self ._stream_id in self ._connection ._trailing_headers :
617+ self ._extensions ["trailing_headers" ] = (
618+ self ._connection ._trailing_headers [self ._stream_id ]
619+ )
579620 except BaseException as exc :
580621 # If we get an exception while streaming the response,
581622 # we want to close the response (and possibly the connection)
0 commit comments