Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 380ec74

Browse files
committed
Create the pipelines only once in the copilot provider
Since the copilot provider class instance is created once per connection, let's create the pipelines when establishing the connection and reuse them.
1 parent 767a465 commit 380ec74

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

src/codegate/providers/copilot/pipeline.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class CopilotPipeline(ABC):
2424

2525
def __init__(self, pipeline_factory: PipelineFactory):
2626
self.pipeline_factory = pipeline_factory
27+
self.instance = self._create_pipeline()
2728
self.normalizer = self._create_normalizer()
2829
self.provider_name = "openai"
2930

@@ -33,7 +34,7 @@ def _create_normalizer(self):
3334
pass
3435

3536
@abstractmethod
36-
def create_pipeline(self) -> SequentialPipelineProcessor:
37+
def _create_pipeline(self) -> SequentialPipelineProcessor:
3738
"""Each strategy defines which pipeline to create"""
3839
pass
3940

@@ -84,7 +85,11 @@ def _create_shortcut_response(result: PipelineResult, model: str) -> bytes:
8485
body = response.model_dump_json(exclude_none=True, exclude_unset=True).encode()
8586
return body
8687

87-
async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]:
88+
async def process_body(
89+
self,
90+
headers: list[str],
91+
body: bytes,
92+
) -> Tuple[bytes, PipelineContext | None]:
8893
"""Common processing logic for all strategies"""
8994
try:
9095
normalized_body = self.normalizer.normalize(body)
@@ -97,8 +102,7 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, Pi
97102
except ValueError:
98103
continue
99104

100-
pipeline = self.create_pipeline()
101-
result = await pipeline.process_request(
105+
result = await self.instance.process_request(
102106
request=normalized_body,
103107
provider=self.provider_name,
104108
model=normalized_body.get("model", "gpt-4o-mini"),
@@ -168,10 +172,13 @@ class CopilotFimPipeline(CopilotPipeline):
168172
format and the FIM pipeline used by all providers.
169173
"""
170174

175+
def __init__(self, pipeline_factory: PipelineFactory):
176+
super().__init__(pipeline_factory)
177+
171178
def _create_normalizer(self):
172179
return CopilotFimNormalizer()
173180

174-
def create_pipeline(self) -> SequentialPipelineProcessor:
181+
def _create_pipeline(self) -> SequentialPipelineProcessor:
175182
return self.pipeline_factory.create_fim_pipeline()
176183

177184

@@ -181,8 +188,11 @@ class CopilotChatPipeline(CopilotPipeline):
181188
format and the chat pipeline used by all providers.
182189
"""
183190

191+
def __init__(self, pipeline_factory: PipelineFactory):
192+
super().__init__(pipeline_factory)
193+
184194
def _create_normalizer(self):
185195
return CopilotChatNormalizer()
186196

187-
def create_pipeline(self) -> SequentialPipelineProcessor:
197+
def _create_pipeline(self) -> SequentialPipelineProcessor:
188198
return self.pipeline_factory.create_input_pipeline()

src/codegate/providers/copilot/provider.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,16 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
150150
self.cert_manager = TLSCertDomainManager(self.ca)
151151
self._closing = False
152152
self.pipeline_factory = PipelineFactory(SecretsManager())
153+
self.input_pipeline: Optional[CopilotPipeline] = None
154+
self.fim_pipeline: Optional[CopilotPipeline] = None
155+
# the context as provided by the pipeline
153156
self.context_tracking: Optional[PipelineContext] = None
154157

158+
def _ensure_pipelines(self):
159+
if not self.input_pipeline or not self.fim_pipeline:
160+
self.input_pipeline = CopilotChatPipeline(self.pipeline_factory)
161+
self.fim_pipeline = CopilotFimPipeline(self.pipeline_factory)
162+
155163
def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
156164
if method != "POST":
157165
logger.debug("Not a POST request, no pipeline selected")
@@ -161,10 +169,10 @@ def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
161169
if path == route.path:
162170
if route.pipeline_type == PipelineType.FIM:
163171
logger.debug("Selected FIM pipeline")
164-
return CopilotFimPipeline(self.pipeline_factory)
172+
return self.fim_pipeline
165173
elif route.pipeline_type == PipelineType.CHAT:
166174
logger.debug("Selected CHAT pipeline")
167-
return CopilotChatPipeline(self.pipeline_factory)
175+
return self.input_pipeline
168176

169177
logger.debug("No pipeline selected")
170178
return None
@@ -181,7 +189,6 @@ async def _body_through_pipeline(
181189
# if we didn't select any strategy that would change the request
182190
# let's just pass through the body as-is
183191
return body, None
184-
logger.debug(f"Processing body through pipeline: {len(body)} bytes")
185192
return await strategy.process_body(headers, body)
186193

187194
async def _request_to_target(self, headers: list[str], body: bytes):
@@ -288,6 +295,9 @@ async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest
288295
http_request.headers,
289296
http_request.body,
290297
)
298+
# TODO: it's weird that we're overwriting the context.
299+
# Should we set the context once? Maybe when
300+
# creating the pipeline instance?
291301
self.context_tracking = context
292302

293303
if context and context.shortcut_response:
@@ -431,7 +441,6 @@ def data_received(self, data: bytes) -> None:
431441
Handle received data from client. Since we need to process the complete body
432442
through our pipeline before forwarding, we accumulate the entire request first.
433443
"""
434-
logger.info(f"Received data from {self.peername}: {data}")
435444
try:
436445
if not self._check_buffer_size(data):
437446
self.send_error_response(413, b"Request body too large")
@@ -442,6 +451,7 @@ def data_received(self, data: bytes) -> None:
442451
if not self.headers_parsed:
443452
self.headers_parsed = self.parse_headers()
444453
if self.headers_parsed:
454+
self._ensure_pipelines()
445455
if self.request.method == "CONNECT":
446456
self.handle_connect()
447457
self.buffer.clear()
@@ -452,7 +462,6 @@ def data_received(self, data: bytes) -> None:
452462
if self._has_complete_body():
453463
# Process the complete request through the pipeline
454464
complete_request = bytes(self.buffer)
455-
logger.debug(f"Complete request: {complete_request}")
456465
self.buffer.clear()
457466
asyncio.create_task(self._forward_data_to_target(complete_request))
458467

@@ -756,10 +765,12 @@ def connection_made(self, transport: asyncio.Transport) -> None:
756765

757766
def _ensure_output_processor(self) -> None:
758767
if self.proxy.context_tracking is None:
768+
logger.debug("No context tracking, no need to process pipeline")
759769
# No context tracking, no need to process pipeline
760770
return
761771

762772
if self.sse_processor is not None:
773+
logger.debug("Already initialized, no need to reinitialize")
763774
# Already initialized, no need to reinitialize
764775
return
765776

0 commit comments

Comments (0)