Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 3e4790d

Browse files
committed
Handle multiple requests in one data_received call
We've seen instances where the request (typically a FIM one) contained more than one request. Let's dispatch them one by one individually. Also let's not pass around self.buffer into the tasks but a parameter.
1 parent 5b621a5 commit 3e4790d

File tree

1 file changed

+46
-28
lines changed

1 file changed

+46
-28
lines changed

src/codegate/providers/copilot/provider.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,15 @@ def connection_made(self, transport: asyncio.Transport) -> None:
229229
self.peername = transport.get_extra_info("peername")
230230
logger.debug(f"Client connected from {self.peername}")
231231

232-
def get_headers_dict(self) -> Dict[str, str]:
232+
def get_headers_dict(self, complete_request) -> Dict[str, str]:
233233
"""Convert raw headers to dictionary format"""
234234
headers_dict = {}
235235
try:
236-
if b"\r\n\r\n" not in self.buffer:
236+
if b"\r\n\r\n" not in complete_request:
237237
return {}
238238

239-
headers_end = self.buffer.index(b"\r\n\r\n")
240-
headers = self.buffer[:headers_end].split(b"\r\n")[1:]
239+
headers_end = complete_request.index(b"\r\n\r\n")
240+
headers = complete_request[:headers_end].split(b"\r\n")[1:]
241241

242242
for header in headers:
243243
try:
@@ -449,32 +449,50 @@ def data_received(self, data: bytes) -> None:
449449

450450
self.buffer.extend(data)
451451

452-
if not self.headers_parsed:
453-
self.headers_parsed = self.parse_headers()
454-
if self.headers_parsed:
455-
self._ensure_pipelines()
456-
if self.request.method == "CONNECT":
457-
self.handle_connect()
458-
self.buffer.clear()
459-
else:
460-
# Only process the request once we have the complete body
461-
asyncio.create_task(self.handle_http_request())
462-
else:
463-
if self._has_complete_body():
464-
# Process the complete request through the pipeline
465-
complete_request = bytes(self.buffer)
466-
# logger.debug(f"Complete request: {complete_request}")
467-
self.buffer.clear()
468-
asyncio.create_task(self._forward_data_to_target(complete_request))
452+
while self.buffer: # Process as many complete requests as we have
453+
if not self.headers_parsed:
454+
self.headers_parsed = self.parse_headers()
455+
if self.headers_parsed:
456+
self._ensure_pipelines()
457+
if self.request.method == "CONNECT":
458+
if self._has_complete_body():
459+
self.handle_connect()
460+
self.buffer.clear() # CONNECT requests are handled differently
461+
break # CONNECT handling complete
462+
elif self._has_complete_body():
463+
# Find where this request ends
464+
headers_end = self.buffer.index(b"\r\n\r\n")
465+
headers = self.buffer[:headers_end].split(b"\r\n")[1:]
466+
content_length = 0
467+
for header in headers:
468+
if header.lower().startswith(b"content-length:"):
469+
content_length = int(header.split(b":", 1)[1])
470+
break
471+
472+
request_end = headers_end + 4 + content_length
473+
complete_request = self.buffer[:request_end]
474+
475+
self.buffer = self.buffer[request_end:] # Keep remaining data
476+
477+
self.headers_parsed = False # Reset for next request
478+
479+
asyncio.create_task(self.handle_http_request(complete_request))
480+
break # Either processing request or need more data
481+
else:
482+
if self._has_complete_body():
483+
complete_request = bytes(self.buffer)
484+
self.buffer.clear() # Clear buffer for next request
485+
asyncio.create_task(self._forward_data_to_target(complete_request))
486+
break # Either processing request or need more data
469487

470488
except Exception as e:
471489
logger.error(f"Error processing received data: {e}")
472490
self.send_error_response(502, str(e).encode())
473491

474-
async def handle_http_request(self) -> None:
492+
async def handle_http_request(self, complete_request: bytes) -> None:
475493
"""Handle standard HTTP request"""
476494
try:
477-
target_url = await self._get_target_url()
495+
target_url = await self._get_target_url(complete_request)
478496
except Exception as e:
479497
logger.error(f"Error getting target URL: {e}")
480498
self.send_error_response(404, b"Not Found")
@@ -518,9 +536,9 @@ async def handle_http_request(self) -> None:
518536
new_headers.append(f"Host: {self.target_host}")
519537

520538
if self.target_transport:
521-
if self.buffer:
522-
body_start = self.buffer.index(b"\r\n\r\n") + 4
523-
body = self.buffer[body_start:]
539+
if complete_request:
540+
body_start = complete_request.index(b"\r\n\r\n") + 4
541+
body = complete_request[body_start:]
524542
await self._request_to_target(new_headers, body)
525543
else:
526544
# just skip it
@@ -532,9 +550,9 @@ async def handle_http_request(self) -> None:
532550
logger.error(f"Error preparing or sending request to target: {e}")
533551
self.send_error_response(502, b"Bad Gateway")
534552

535-
async def _get_target_url(self) -> Optional[str]:
553+
async def _get_target_url(self, complete_request) -> Optional[str]:
536554
"""Determine target URL based on request path and headers"""
537-
headers_dict = self.get_headers_dict()
555+
headers_dict = self.get_headers_dict(complete_request)
538556
auth_header = headers_dict.get("authorization", "")
539557

540558
if auth_header:

0 commit comments

Comments
 (0)