Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 5601d63

Browse files
committed
Handle multiple requests in one data_received call
We've seen instances where the data delivered in a single data_received call (typically for a FIM request) contained more than one request. Let's dispatch them individually, one at a time. Also, instead of having the tasks read self.buffer, let's pass the request to them as a parameter.
1 parent d455728 commit 5601d63

File tree

1 file changed

+46
-27
lines changed

1 file changed

+46
-27
lines changed

src/codegate/providers/copilot/provider.py

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,15 @@ def connection_made(self, transport: asyncio.Transport) -> None:
229229
self.peername = transport.get_extra_info("peername")
230230
logger.debug(f"Client connected from {self.peername}")
231231

232-
def get_headers_dict(self) -> Dict[str, str]:
232+
def get_headers_dict(self, complete_request) -> Dict[str, str]:
233233
"""Convert raw headers to dictionary format"""
234234
headers_dict = {}
235235
try:
236-
if b"\r\n\r\n" not in self.buffer:
236+
if b"\r\n\r\n" not in complete_request:
237237
return {}
238238

239-
headers_end = self.buffer.index(b"\r\n\r\n")
240-
headers = self.buffer[:headers_end].split(b"\r\n")[1:]
239+
headers_end = complete_request.index(b"\r\n\r\n")
240+
headers = complete_request[:headers_end].split(b"\r\n")[1:]
241241

242242
for header in headers:
243243
try:
@@ -448,31 +448,50 @@ def data_received(self, data: bytes) -> None:
448448

449449
self.buffer.extend(data)
450450

451-
if not self.headers_parsed:
452-
self.headers_parsed = self.parse_headers()
453-
if self.headers_parsed:
454-
self._ensure_pipelines()
455-
if self.request.method == "CONNECT":
456-
self.handle_connect()
457-
self.buffer.clear()
458-
else:
459-
# Only process the request once we have the complete body
460-
asyncio.create_task(self.handle_http_request())
461-
else:
462-
if self._has_complete_body():
463-
# Process the complete request through the pipeline
464-
complete_request = bytes(self.buffer)
465-
self.buffer.clear()
466-
asyncio.create_task(self._forward_data_to_target(complete_request))
451+
while self.buffer: # Process as many complete requests as we have
452+
if not self.headers_parsed:
453+
self.headers_parsed = self.parse_headers()
454+
if self.headers_parsed:
455+
self._ensure_pipelines()
456+
if self.request.method == "CONNECT":
457+
if self._has_complete_body():
458+
self.handle_connect()
459+
self.buffer.clear() # CONNECT requests are handled differently
460+
break # CONNECT handling complete
461+
elif self._has_complete_body():
462+
# Find where this request ends
463+
headers_end = self.buffer.index(b"\r\n\r\n")
464+
headers = self.buffer[:headers_end].split(b"\r\n")[1:]
465+
content_length = 0
466+
for header in headers:
467+
if header.lower().startswith(b"content-length:"):
468+
content_length = int(header.split(b":", 1)[1])
469+
break
470+
471+
request_end = headers_end + 4 + content_length
472+
complete_request = self.buffer[:request_end]
473+
474+
self.buffer = self.buffer[request_end:] # Keep remaining data
475+
476+
self.headers_parsed = False # Reset for next request
477+
478+
asyncio.create_task(self.handle_http_request(complete_request))
479+
break # Either processing request or need more data
480+
else:
481+
if self._has_complete_body():
482+
complete_request = bytes(self.buffer)
483+
self.buffer.clear() # Clear buffer for next request
484+
asyncio.create_task(self._forward_data_to_target(complete_request))
485+
break # Either processing request or need more data
467486

468487
except Exception as e:
469488
logger.error(f"Error processing received data: {e}")
470489
self.send_error_response(502, str(e).encode())
471490

472-
async def handle_http_request(self) -> None:
491+
async def handle_http_request(self, complete_request: bytes) -> None:
473492
"""Handle standard HTTP request"""
474493
try:
475-
target_url = await self._get_target_url()
494+
target_url = await self._get_target_url(complete_request)
476495
except Exception as e:
477496
logger.error(f"Error getting target URL: {e}")
478497
self.send_error_response(404, b"Not Found")
@@ -516,9 +535,9 @@ async def handle_http_request(self) -> None:
516535
new_headers.append(f"Host: {self.target_host}")
517536

518537
if self.target_transport:
519-
if self.buffer:
520-
body_start = self.buffer.index(b"\r\n\r\n") + 4
521-
body = self.buffer[body_start:]
538+
if complete_request:
539+
body_start = complete_request.index(b"\r\n\r\n") + 4
540+
body = complete_request[body_start:]
522541
await self._request_to_target(new_headers, body)
523542
else:
524543
# just skip it
@@ -530,9 +549,9 @@ async def handle_http_request(self) -> None:
530549
logger.error(f"Error preparing or sending request to target: {e}")
531550
self.send_error_response(502, b"Bad Gateway")
532551

533-
async def _get_target_url(self) -> Optional[str]:
552+
async def _get_target_url(self, complete_request) -> Optional[str]:
534553
"""Determine target URL based on request path and headers"""
535-
headers_dict = self.get_headers_dict()
554+
headers_dict = self.get_headers_dict(complete_request)
536555
auth_header = headers_dict.get("authorization", "")
537556

538557
if auth_header:

0 commit comments

Comments
 (0)