@@ -229,15 +229,15 @@ def connection_made(self, transport: asyncio.Transport) -> None:
229
229
self .peername = transport .get_extra_info ("peername" )
230
230
logger .debug (f"Client connected from { self .peername } " )
231
231
232
- def get_headers_dict (self ) -> Dict [str , str ]:
232
+ def get_headers_dict (self , complete_request ) -> Dict [str , str ]:
233
233
"""Convert raw headers to dictionary format"""
234
234
headers_dict = {}
235
235
try :
236
- if b"\r \n \r \n " not in self . buffer :
236
+ if b"\r \n \r \n " not in complete_request :
237
237
return {}
238
238
239
- headers_end = self . buffer .index (b"\r \n \r \n " )
240
- headers = self . buffer [:headers_end ].split (b"\r \n " )[1 :]
239
+ headers_end = complete_request .index (b"\r \n \r \n " )
240
+ headers = complete_request [:headers_end ].split (b"\r \n " )[1 :]
241
241
242
242
for header in headers :
243
243
try :
@@ -449,32 +449,50 @@ def data_received(self, data: bytes) -> None:
449
449
450
450
self .buffer .extend (data )
451
451
452
- if not self .headers_parsed :
453
- self .headers_parsed = self .parse_headers ()
454
- if self .headers_parsed :
455
- self ._ensure_pipelines ()
456
- if self .request .method == "CONNECT" :
457
- self .handle_connect ()
458
- self .buffer .clear ()
459
- else :
460
- # Only process the request once we have the complete body
461
- asyncio .create_task (self .handle_http_request ())
462
- else :
463
- if self ._has_complete_body ():
464
- # Process the complete request through the pipeline
465
- complete_request = bytes (self .buffer )
466
- # logger.debug(f"Complete request: {complete_request}")
467
- self .buffer .clear ()
468
- asyncio .create_task (self ._forward_data_to_target (complete_request ))
452
+ while self .buffer : # Process as many complete requests as we have
453
+ if not self .headers_parsed :
454
+ self .headers_parsed = self .parse_headers ()
455
+ if self .headers_parsed :
456
+ self ._ensure_pipelines ()
457
+ if self .request .method == "CONNECT" :
458
+ if self ._has_complete_body ():
459
+ self .handle_connect ()
460
+ self .buffer .clear () # CONNECT requests are handled differently
461
+ break # CONNECT handling complete
462
+ elif self ._has_complete_body ():
463
+ # Find where this request ends
464
+ headers_end = self .buffer .index (b"\r \n \r \n " )
465
+ headers = self .buffer [:headers_end ].split (b"\r \n " )[1 :]
466
+ content_length = 0
467
+ for header in headers :
468
+ if header .lower ().startswith (b"content-length:" ):
469
+ content_length = int (header .split (b":" , 1 )[1 ])
470
+ break
471
+
472
+ request_end = headers_end + 4 + content_length
473
+ complete_request = self .buffer [:request_end ]
474
+
475
+ self .buffer = self .buffer [request_end :] # Keep remaining data
476
+
477
+ self .headers_parsed = False # Reset for next request
478
+
479
+ asyncio .create_task (self .handle_http_request (complete_request ))
480
+ break # Either processing request or need more data
481
+ else :
482
+ if self ._has_complete_body ():
483
+ complete_request = bytes (self .buffer )
484
+ self .buffer .clear () # Clear buffer for next request
485
+ asyncio .create_task (self ._forward_data_to_target (complete_request ))
486
+ break # Either processing request or need more data
469
487
470
488
except Exception as e :
471
489
logger .error (f"Error processing received data: { e } " )
472
490
self .send_error_response (502 , str (e ).encode ())
473
491
474
- async def handle_http_request (self ) -> None :
492
+ async def handle_http_request (self , complete_request : bytes ) -> None :
475
493
"""Handle standard HTTP request"""
476
494
try :
477
- target_url = await self ._get_target_url ()
495
+ target_url = await self ._get_target_url (complete_request )
478
496
except Exception as e :
479
497
logger .error (f"Error getting target URL: { e } " )
480
498
self .send_error_response (404 , b"Not Found" )
@@ -518,9 +536,9 @@ async def handle_http_request(self) -> None:
518
536
new_headers .append (f"Host: { self .target_host } " )
519
537
520
538
if self .target_transport :
521
- if self . buffer :
522
- body_start = self . buffer .index (b"\r \n \r \n " ) + 4
523
- body = self . buffer [body_start :]
539
+ if complete_request :
540
+ body_start = complete_request .index (b"\r \n \r \n " ) + 4
541
+ body = complete_request [body_start :]
524
542
await self ._request_to_target (new_headers , body )
525
543
else :
526
544
# just skip it
@@ -532,9 +550,9 @@ async def handle_http_request(self) -> None:
532
550
logger .error (f"Error preparing or sending request to target: { e } " )
533
551
self .send_error_response (502 , b"Bad Gateway" )
534
552
535
- async def _get_target_url (self ) -> Optional [str ]:
553
+ async def _get_target_url (self , complete_request ) -> Optional [str ]:
536
554
"""Determine target URL based on request path and headers"""
537
- headers_dict = self .get_headers_dict ()
555
+ headers_dict = self .get_headers_dict (complete_request )
538
556
auth_header = headers_dict .get ("authorization" , "" )
539
557
540
558
if auth_header :
0 commit comments