diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index dc7bac53..f5b62c8c 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -1,9 +1,12 @@ +import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Dict, List, Optional from litellm import ChatCompletionRequest +from codegate.pipeline.secrets.manager import SecretsManager + @dataclass class CodeSnippet: @@ -24,10 +27,25 @@ def __post_init__(self): self.language = self.language.strip().lower() +@dataclass +class PipelineSensitiveData: + manager: SecretsManager + session_id: str + + def secure_cleanup(self): + """Securely cleanup sensitive data for this session""" + if self.manager is None or self.session_id == "": + return + + self.manager.cleanup_session(self.session_id) + self.session_id = "" + + @dataclass class PipelineContext: code_snippets: List[CodeSnippet] = field(default_factory=list) metadata: Dict[str, Any] = field(default_factory=dict) + sensitive: Optional[PipelineSensitiveData] = field(default_factory=lambda: None) def add_code_snippet(self, snippet: CodeSnippet): self.code_snippets.append(snippet) @@ -139,6 +157,7 @@ def __init__(self, pipeline_steps: List[PipelineStep]): async def process_request( self, + secret_manager: SecretsManager, request: ChatCompletionRequest, ) -> PipelineResult: """ @@ -146,11 +165,15 @@ async def process_request( Args: request: The chat completion request to process + secret_manager: The secrets manager instance to gather sensitive data from the request Returns: PipelineResult containing either a modified request or response structure """ context = PipelineContext() + context.sensitive = PipelineSensitiveData( + manager=secret_manager, session_id=str(uuid.uuid4()) + ) # Generate a new session ID for each request current_request = request for step in self.pipeline_steps: diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py new file mode 100644 index 00000000..b74ce2d0 --- /dev/null +++ b/src/codegate/pipeline/output.py @@ -0,0 +1,173 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import AsyncIterator, Optional + +from litellm import ModelResponse +from litellm.types.utils import Delta, StreamingChoices + +from codegate.pipeline.base import PipelineContext + + +@dataclass +class OutputPipelineContext: + """ + Context passed between output pipeline steps. + + Does not include the input context, that one is separate. + """ + + # We store the messages that are not yet sent to the client in the buffer. + # One reason for this might be that the buffer contains a secret that we want to de-obfuscate + buffer: list[str] = field(default_factory=list) + + +class OutputPipelineStep(ABC): + """ + Base class for output pipeline steps + The process method should be implemented by subclasses and handles + processing of a single chunk of the stream. + """ + + @property + @abstractmethod + def name(self) -> str: + """Returns the name of this pipeline step""" + pass + + @abstractmethod + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: Optional[PipelineContext] = None, + ) -> Optional[ModelResponse]: + """ + Process a single chunk of the stream. + + Args: + - chunk: The input chunk to process, normalized to ModelResponse + - context: The output pipeline context. Can be used to store state between steps, mainly + the buffer. + - input_context: The input context from processing the user's input. Can include the secrets + obfuscated in the user message or code snippets in the user message. + + Return: + - None to pause the stream + - Modified or unmodified input chunk to pass through + """ + pass + + +class OutputPipelineInstance: + """ + Handles processing of a single stream + Think of this class as steps + buffer + """ + + def __init__( + self, + pipeline_steps: list[OutputPipelineStep], + input_context: Optional[PipelineContext] = None, + ): + self._input_context = input_context + self._pipeline_steps = pipeline_steps + self._context = OutputPipelineContext() + # we won't actually buffer the chunk, but in case we need to pass + # the remaining content in the buffer when the stream ends, we need + # to store the parameters like model, timestamp, etc. + self._buffered_chunk = None + + def _buffer_chunk(self, chunk: ModelResponse) -> None: + """ + Add chunk content to buffer. + """ + self._buffered_chunk = chunk + for choice in chunk.choices: + # the last choice has no delta or content, let's not buffer it + if choice.delta is not None and choice.delta.content is not None: + self._context.buffer.append(choice.delta.content) + + async def process_stream( + self, stream: AsyncIterator[ModelResponse] + ) -> AsyncIterator[ModelResponse]: + """ + Process a stream through all pipeline steps + """ + try: + async for chunk in stream: + # Store chunk content in buffer + self._buffer_chunk(chunk) + + # Process chunk through each step of the pipeline + current_chunk = chunk + for step in self._pipeline_steps: + if current_chunk is None: + # Stop processing if a step returned None previously + # this means that the pipeline step requested to pause the stream + # instead, let's try again with the next chunk + break + + processed_chunk = await step.process_chunk( + current_chunk, self._context, self._input_context + ) + # the returned chunk becomes the input for the next chunk in the pipeline + current_chunk = processed_chunk + + # we have either gone through all the steps in the pipeline and have a chunk + # to return or we are paused in which case we don't yield + if current_chunk is not None: + # Step processed successfully, yield the chunk and clear buffer + self._context.buffer.clear() + yield current_chunk + # else: keep buffering for next iteration + + except Exception as e: + # Log exception and stop processing + raise e + finally: + # Process any remaining content in buffer when stream ends + if self._context.buffer: + final_content = "".join(self._context.buffer) + yield ModelResponse( + id=self._buffered_chunk.id, + choices=[ + StreamingChoices( + finish_reason=None, + # we just put one choice in the buffer, so 0 is fine + index=0, + delta=Delta(content=final_content, role="assistant"), + # umm..is this correct? + logprobs=self._buffered_chunk.choices[0].logprobs, + ) + ], + created=self._buffered_chunk.created, + model=self._buffered_chunk.model, + object="chat.completion.chunk", + ) + self._context.buffer.clear() + + # Cleanup sensitive data through the input context + if self._input_context and self._input_context.sensitive: + self._input_context.sensitive.secure_cleanup() + + +class OutputPipelineProcessor: + """ + Since we want to provide each run of the pipeline with a fresh context, + we need a factory to create new instances of the pipeline. + """ + + def __init__(self, pipeline_steps: list[OutputPipelineStep]): + self.pipeline_steps = pipeline_steps + + def _create_instance(self) -> OutputPipelineInstance: + """Create a new pipeline instance for processing a stream""" + return OutputPipelineInstance(self.pipeline_steps) + + async def process_stream( + self, stream: AsyncIterator[ModelResponse] + ) -> AsyncIterator[ModelResponse]: + """Create a new pipeline instance and process the stream""" + instance = self._create_instance() + async for chunk in instance.process_stream(stream): + yield chunk diff --git a/src/codegate/pipeline/secrets/manager.py b/src/codegate/pipeline/secrets/manager.py new file mode 100644 index 00000000..a7b32319 --- /dev/null +++ b/src/codegate/pipeline/secrets/manager.py @@ -0,0 +1,112 @@ +from typing import NamedTuple, Optional + +import structlog + +from codegate.pipeline.secrets.gatecrypto import CodeGateCrypto + +logger = structlog.get_logger("codegate") + + +class SecretEntry(NamedTuple): + """Represents a stored secret""" + + original: str + encrypted: str + service: str + secret_type: str + + +class SecretsManager: + """Manages encryption, storage and retrieval of secrets""" + + def __init__(self): + self.crypto = CodeGateCrypto() + self._session_store: dict[str, SecretEntry] = {} + self._encrypted_to_session: dict[str, str] = {} # Reverse lookup index + + def store_secret(self, value: str, service: str, secret_type: str, session_id: str) -> str: + """ + Encrypts and stores a secret value. + Returns the encrypted value. + """ + if not value: + raise ValueError("Value must be provided") + if not service: + raise ValueError("Service must be provided") + if not secret_type: + raise ValueError("Secret type must be provided") + if not session_id: + raise ValueError("Session ID must be provided") + + encrypted_value = self.crypto.encrypt_token(value, session_id) + + # Store mappings + self._session_store[session_id] = SecretEntry( + original=value, + encrypted=encrypted_value, + service=service, + secret_type=secret_type, + ) + self._encrypted_to_session[encrypted_value] = session_id + + logger.debug("Stored secret", service=service, type=secret_type, encrypted=encrypted_value) + + return encrypted_value + + def get_original_value(self, encrypted_value: str, session_id: str) -> Optional[str]: + """Retrieve original value for an encrypted value""" + try: + stored_session_id = self._encrypted_to_session.get(encrypted_value) + if stored_session_id == session_id: + return self._session_store[session_id].original + except Exception as e: + logger.error("Error retrieving secret", error=str(e)) + return None + + def get_by_session_id(self, session_id: str) -> Optional[SecretEntry]: + """Get stored data by session ID""" + return self._session_store.get(session_id) + + def cleanup(self): + """Securely wipe sensitive data""" + try: + # Convert and wipe original values + for entry in self._session_store.values(): + original_bytes = bytearray(entry.original.encode()) + self.crypto.wipe_bytearray(original_bytes) + + # Clear the dictionaries + self._session_store.clear() + self._encrypted_to_session.clear() + + logger.info("Secrets manager data securely wiped") + except Exception as e: + logger.error("Error during secure cleanup", error=str(e)) + + def cleanup_session(self, session_id: str): + """ + Remove a specific session's secrets and perform secure cleanup. + + Args: + session_id (str): The session identifier to remove + """ + try: + # Get the secret entry for the session + entry = self._session_store.get(session_id) + + if entry: + # Securely wipe the original value + original_bytes = bytearray(entry.original.encode()) + self.crypto.wipe_bytearray(original_bytes) + + # Remove the encrypted value from the reverse lookup index + self._encrypted_to_session.pop(entry.encrypted, None) + + # Remove the session from the store + self._session_store.pop(session_id, None) + + logger.debug("Session secrets securely removed", session_id=session_id) + else: + logger.debug("No secrets found for session", session_id=session_id) + except Exception as e: + logger.error("Error during session cleanup", session_id=session_id, error=str(e)) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index f9c3d0df..782211a6 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -1,14 +1,17 @@ import re +from typing import Optional import structlog -from litellm import ChatCompletionRequest +from litellm import ChatCompletionRequest, ModelResponse +from litellm.types.utils import Delta, StreamingChoices from codegate.pipeline.base import ( PipelineContext, PipelineResult, PipelineStep, ) -from codegate.pipeline.secrets.gatecrypto import CodeGateCrypto +from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep +from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.signatures import CodegateSignatures logger = structlog.get_logger("codegate") @@ -20,9 +23,6 @@ class CodegateSecrets(PipelineStep): def __init__(self): """Initialize the CodegateSecrets pipeline step.""" super().__init__() - self.crypto = CodeGateCrypto() - self._session_store = {} - self._encrypted_to_session = {} # Reverse lookup index @property def name(self) -> str: @@ -73,7 +73,7 @@ def _extend_match_boundaries(self, text: str, start: int, end: int) -> tuple[int return start, end - def _redeact_text(self, text: str) -> str: + def _redeact_text(self, text: str, secrets_manager: SecretsManager, session_id: str) -> str: """ Find and encrypt secrets in the given text. @@ -114,34 +114,19 @@ def _redeact_text(self, text: str) -> str: # Replace each match with its encrypted value for start, end, match in absolute_matches: - # Generate session key and encrypt the value - session_id = self.crypto.generate_session_key(None).hex() - encrypted_value = self.crypto.encrypt_token(match.value, session_id) - - print("Original value: ", match.value) - print("Encrypted value: ", encrypted_value) - print("Service: ", match.service) - print("Type: ", match.type) - - # Store the mapping - self._session_store[session_id] = { - "original": match.value, - "encrypted": encrypted_value, - "service": match.service, - "type": match.type, - } - # Store reverse lookup - self._encrypted_to_session[encrypted_value] = session_id - - # Print the session store - logger.info(f"Session store: {self._session_store}") + # Encrypt and store the value + encrypted_value = secrets_manager.store_secret( + match.value, + match.service, + match.type, + session_id, + ) # Create the replacement string replacement = f"REDACTED<${encrypted_value}>" # Replace the secret in the text protected_text[start:end] = replacement - # Store for logging found_secrets.append( { @@ -152,110 +137,20 @@ def _redeact_text(self, text: str) -> str: } ) - # Convert back to string - protected_string = "".join(protected_text) - - # Log the findings - logger.info("\nFound secrets:") - for secret in found_secrets: - logger.info(f"\nService: {secret['service']}") - logger.info(f"Type: {secret['type']}") - logger.info(f"Original: {secret['original']}") - logger.info(f"Encrypted: REDACTED<${secret['encrypted']}>") - - (f"\nProtected text:\n{protected_string}") - return protected_string + # Convert back to string + protected_string = "".join(protected_text) - def _get_original_value(self, encrypted_value: str) -> str: - """ - Get the original value for an encrypted value from the session store. - - Args: - encrypted_value: The encrypted value to look up - - Returns: - Original value if found, or the encrypted value if not found - """ - try: - # Use reverse lookup index to get session_id - session_id = self._encrypted_to_session.get(encrypted_value) - if session_id: - return self._session_store[session_id]["original"] - except Exception as e: - logger.error(f"Error looking up original value: {e}") - return encrypted_value - - def get_by_session_id(self, session_id: str) -> dict | None: - """ - Get stored data directly by session ID. - - Args: - session_id: The session ID to look up + # Log the findings + logger.info("\nFound secrets:") - Returns: - Dict containing the stored data if found, None otherwise - """ - try: - return self._session_store.get(session_id) - except Exception as e: - logger.error(f"Error looking up by session ID: {e}") - return None - - def _cleanup_session_store(self): - """ - Securely wipe sensitive data from session stores. - """ - try: - # Convert and wipe original values - for session_data in self._session_store.values(): - if "original" in session_data: - original_bytes = bytearray(session_data["original"].encode()) - self.crypto.wipe_bytearray(original_bytes) - - # Clear the dictionaries - self._session_store.clear() - self._encrypted_to_session.clear() - - logger.info("Session stores securely wiped") - except Exception as e: - logger.error(f"Error during secure cleanup: {e}") - - def _unredact_text(self, protected_text: str) -> str: - """ - Decrypt and restore the original text from protected text. - - Args: - protected_text: The protected text containing encrypted values - - Returns: - Original text with decrypted values - """ - # Find all REDACTED markers - pattern = r"REDACTED<\$([^>]+)>" + for secret in found_secrets: + logger.info(f"\nService: {secret['service']}") + logger.info(f"Type: {secret['type']}") + logger.info(f"Original: {secret['original']}") + logger.info(f"Encrypted: REDACTED<${secret['encrypted']}>") - # Start from the beginning of the text - result = [] - last_end = 0 - - # Find each REDACTED section and replace with original value - for match in re.finditer(pattern, protected_text): - # Add text before this match - result.append(protected_text[last_end : match.start()]) - - # Get and add the original value - encrypted_value = match.group(1) - original_value = self._get_original_value(encrypted_value) - result.append(original_value) - - last_end = match.end() - - # Add any remaining text - result.append(protected_text[last_end:]) - - # Join all parts together - unprotected_text = "".join(result) - logger.info(f"\nUnprotected text:\n{unprotected_text}") - return unprotected_text + print(f"\nProtected text:\n{protected_string}") + return "".join(protected_text) async def process( self, request: ChatCompletionRequest, context: PipelineContext @@ -270,32 +165,118 @@ async def process( Returns: PipelineResult containing the processed request """ + secrets_manager = context.sensitive.manager + if not secrets_manager or not isinstance(secrets_manager, SecretsManager): + # Should this be an error? + raise ValueError("Secrets manager not found in context") + session_id = context.sensitive.session_id + if not session_id: + raise ValueError("Session ID not found in context") + last_user_message = self.get_last_user_message(request) - extracted_string = last_user_message[0] if last_user_message else None - print(f"Original text:\n{extracted_string}") + extracted_string = None + extracted_index = None + if last_user_message: + extracted_string = last_user_message[0] + extracted_index = last_user_message[1] if not extracted_string: return PipelineResult(request=request) - try: - # Protect the text - protected_string = self._redeact_text(extracted_string) - print(f"\nProtected text:\n{protected_string}") + # Protect the text + protected_string = self._redeact_text(extracted_string, secrets_manager, session_id) - # LLM - unprotected_string = self._unredact_text(protected_string) - print(f"\nUnprotected text:\n{unprotected_string}") + # Update the user message + new_request = request.copy() + new_request["messages"][extracted_index]["content"] = protected_string + return PipelineResult(request=new_request) - # Update the user message with protected text - if isinstance(request["messages"], list): - for msg in request["messages"]: - if msg.get("role") == "user" and msg.get("content") == extracted_string: - msg["content"] = protected_string - return PipelineResult(request=request) - except Exception as e: - logger.error(f"CodegateSecrets operation failed: {e}") +class SecretUnredactionStep(OutputPipelineStep): + """Pipeline step that unredacts protected content in the stream""" + + def __init__(self): + self.redacted_pattern = re.compile(r"REDACTED<\$([^>]+)>") + self.marker_start = "REDACTED<$" + self.marker_end = ">" + + @property + def name(self) -> str: + return "secret-unredaction" + + def _is_partial_marker_prefix(self, text: str) -> bool: + """Check if text ends with a partial marker prefix""" + for i in range(1, len(self.marker_start) + 1): + if text.endswith(self.marker_start[:i]): + return True + return False + + def _find_complete_redaction(self, text: str) -> tuple[Optional[re.Match[str]], str]: + """ + Find the first complete REDACTED marker in text. + Returns (match, remaining_text) if found, (None, original_text) if not. + """ + matches = list(self.redacted_pattern.finditer(text)) + if not matches: + return None, text + + # Get the first complete match + match = matches[0] + return match, text[match.end() :] + + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: Optional[PipelineContext] = None, + ) -> Optional[ModelResponse]: + """Process a single chunk of the stream""" + if not input_context: + raise ValueError("Input context not found") + if input_context.sensitive is None or input_context.sensitive.manager is None: + raise ValueError("Secrets manager not found in input context") + if input_context.sensitive.session_id == "": + raise ValueError("Session ID not found in input context") + + if not chunk.choices[0].delta.content: + return chunk + + # Check the buffered content + buffered_content = "".join(context.buffer) + + # Look for complete REDACTED markers first + match, remaining = self._find_complete_redaction(buffered_content) + if match: + # Found a complete marker, process it + encrypted_value = match.group(1) + original_value = input_context.sensitive.manager.get_original_value( + encrypted_value, + input_context.sensitive.session_id, + ) + + if original_value is None: + # If value not found, leave as is + original_value = match.group(0) # Keep the REDACTED marker + + # Return the unredacted content up to this point + chunk.choices = [ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=buffered_content[: match.start()] + original_value + remaining, + role="assistant", + ), + logprobs=None, + ) + ] + return chunk + + # If we have a partial marker at the end, keep buffering + if self.marker_start in buffered_content or self._is_partial_marker_prefix( + buffered_content + ): + return None - finally: - # Clean up sensitive data - self._cleanup_session_store() + # No markers or partial markers, let pipeline handle the chunk normally + return chunk diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 4d7eba59..32909260 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -3,6 +3,8 @@ from fastapi import Header, HTTPException, Request +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer from codegate.providers.anthropic.completion_handler import AnthropicCompletion from codegate.providers.base import BaseProvider, SequentialPipelineProcessor @@ -12,16 +14,20 @@ class AnthropicProvider(BaseProvider): def __init__( self, + secrets_manager: SecretsManager, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator) super().__init__( + secrets_manager, AnthropicInputNormalizer(), AnthropicOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) @property diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 585c1d9a..509f8e9c 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -7,7 +7,9 @@ from litellm.types.llms.openai import ChatCompletionRequest from codegate.db.connection import DbRecorder -from codegate.pipeline.base import PipelineResult, SequentialPipelineProcessor +from codegate.pipeline.base import PipelineContext, PipelineResult, SequentialPipelineProcessor +from codegate.pipeline.output import OutputPipelineInstance, OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.completion.base import BaseCompletionHandler from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer @@ -25,18 +27,22 @@ class BaseProvider(ABC): def __init__( self, + secrets_manager: Optional[SecretsManager], input_normalizer: ModelInputNormalizer, output_normalizer: ModelOutputNormalizer, completion_handler: BaseCompletionHandler, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): self.router = APIRouter() + self._secrets_manager = secrets_manager self._completion_handler = completion_handler self._input_normalizer = input_normalizer self._output_normalizer = output_normalizer self._pipeline_processor = pipeline_processor self._fim_pipelin_processor = fim_pipeline_processor + self._output_pipeline_processor = output_pipeline_processor self._pipeline_response_formatter = PipelineResponseFormatter(output_normalizer) self.db_recorder = DbRecorder() @@ -53,16 +59,20 @@ def provider_route_name(self) -> str: async def _run_output_stream_pipeline( self, + input_context: PipelineContext, normalized_stream: AsyncIterator[ModelResponse], ) -> AsyncIterator[ModelResponse]: - # we don't have a pipeline for output stream yet - return normalized_stream + output_pipeline_instance = OutputPipelineInstance( + self._output_pipeline_processor.pipeline_steps, + input_context=input_context, + ) + return output_pipeline_instance.process_stream(normalized_stream) def _run_output_pipeline( self, normalized_response: ModelResponse, ) -> ModelResponse: - # we don't have a pipeline for output yet + # we don't have a pipeline for non-streamed output yet return normalized_response async def _run_input_pipeline( @@ -78,7 +88,9 @@ async def _run_input_pipeline( if pipeline_processor is None: return PipelineResult(request=normalized_request) - result = await pipeline_processor.process_request(normalized_request) + result = await pipeline_processor.process_request( + secret_manager=self._secrets_manager, request=normalized_request + ) # TODO(jakub): handle this by returning a message to the client if result.error_message: @@ -135,6 +147,18 @@ def _is_fim_request(self, request: Request, data: Dict) -> bool: return self._is_fim_request_body(data) + async def _cleanup_after_streaming( + self, stream: AsyncIterator[ModelResponse], context: PipelineContext + ) -> AsyncIterator[ModelResponse]: + """Wraps the stream to ensure cleanup after consumption""" + try: + async for item in stream: + yield item + finally: + # Ensure sensitive data is cleaned up after the stream is consumed + if context and context.sensitive: + context.sensitive.secure_cleanup() + async def complete( self, data: Dict, api_key: Optional[str], is_fim_request: bool ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: @@ -175,8 +199,12 @@ async def complete( return self._output_normalizer.denormalize(pipeline_output) normalized_stream = self._output_normalizer.normalize_streaming(model_response) - pipeline_output_stream = await self._run_output_stream_pipeline(normalized_stream) - return self._output_normalizer.denormalize_streaming(pipeline_output_stream) + pipeline_output_stream = await self._run_output_stream_pipeline( + input_pipeline_result.context, + normalized_stream, + ) + denormalized_stream = self._output_normalizer.denormalize_streaming(pipeline_output_stream) + return self._cleanup_after_streaming(denormalized_stream, input_pipeline_result.context) def get_routes(self) -> APIRouter: return self.router diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index efe06f09..d97feb79 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -3,6 +3,8 @@ from fastapi import Request +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer @@ -11,16 +13,20 @@ class LlamaCppProvider(BaseProvider): def __init__( self, + secrets_manager: SecretsManager, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = LlamaCppCompletionHandler() super().__init__( + secrets_manager, LLamaCppInputNormalizer(), LLamaCppOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) @property diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index 5b8c9a4b..95c7fea8 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -4,6 +4,8 @@ from fastapi import Request from codegate.config import Config +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer from codegate.providers.ollama.completion_handler import OllamaCompletionHandler @@ -12,16 +14,20 @@ class OllamaProvider(BaseProvider): def __init__( self, + secrets_manager: Optional[SecretsManager], pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = OllamaCompletionHandler() super().__init__( + secrets_manager, OllamaInputNormalizer(), OllamaOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) # Get the Ollama base URL config = Config.get_config() diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 741d3143..649805a9 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -3,6 +3,8 @@ from fastapi import Header, HTTPException, Request +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer @@ -11,16 +13,20 @@ class OpenAIProvider(BaseProvider): def __init__( self, + secrets_manager: SecretsManager, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) super().__init__( + secrets_manager, OpenAIInputNormalizer(), OpenAIOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) @property diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index a342ac6f..242ce05f 100644 --- a/src/codegate/providers/vllm/provider.py +++ b/src/codegate/providers/vllm/provider.py @@ -5,6 +5,8 @@ from litellm import atext_completion from codegate.config import Config +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer @@ -13,18 +15,22 @@ class VLLMProvider(BaseProvider): def __init__( self, + secrets_manager: SecretsManager, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = LiteLLmShim( stream_generator=sse_stream_generator, fim_completion_func=atext_completion ) super().__init__( + secrets_manager, VLLMInputNormalizer(), VLLMOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) @property diff --git a/src/codegate/server.py b/src/codegate/server.py index f8a953f4..a45e6b65 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -8,7 +8,9 @@ from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor -from codegate.pipeline.secrets.secrets import CodegateSecrets +from codegate.pipeline.output import OutputPipelineProcessor, OutputPipelineStep +from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.secrets.secrets import CodegateSecrets, SecretUnredactionStep from codegate.pipeline.secrets.signatures import CodegateSignatures from codegate.pipeline.version.version import CodegateVersion from codegate.providers.anthropic.provider import AnthropicProvider @@ -27,6 +29,12 @@ def init_app() -> FastAPI: version=__version__, ) + # Initialize secrets manager + # TODO: we need to clean up the secrets manager + # after the conversation is concluded + # this was done in the pipeline step but I just removed it for now + secrets_manager = SecretsManager() + steps: List[PipelineStep] = [ CodegateVersion(), CodeSnippetExtractor(), @@ -39,6 +47,11 @@ def init_app() -> FastAPI: pipeline = SequentialPipelineProcessor(steps) fim_pipeline = SequentialPipelineProcessor(fim_steps) + output_steps: List[OutputPipelineStep] = [ + SecretUnredactionStep(), + ] + output_pipeline = OutputPipelineProcessor(output_steps) + # Create provider registry registry = ProviderRegistry(app) @@ -47,21 +60,49 @@ def init_app() -> FastAPI: # Register all known providers registry.add_provider( - "openai", OpenAIProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) + "openai", + OpenAIProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) registry.add_provider( "anthropic", - AnthropicProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline), + AnthropicProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) registry.add_provider( "llamacpp", - LlamaCppProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline), + LlamaCppProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) registry.add_provider( - "vllm", VLLMProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) + "vllm", + VLLMProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) registry.add_provider( - "ollama", OllamaProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) + "ollama", + OllamaProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) # Create and add system routes diff --git a/tests/pipeline/secrets/test_manager.py b/tests/pipeline/secrets/test_manager.py new file mode 100644 index 00000000..5cb06ade --- /dev/null +++ b/tests/pipeline/secrets/test_manager.py @@ -0,0 +1,148 @@ +import pytest + +from codegate.pipeline.secrets.manager import SecretEntry, SecretsManager + + +class TestSecretsManager: + def setup_method(self): + """Setup a fresh SecretsManager for each test""" + self.manager = SecretsManager() + self.test_session = "test_session_id" + self.test_value = "super_secret_value" + self.test_service = "test_service" + self.test_type = "api_key" + + def test_store_secret(self): + """Test basic secret storage and retrieval""" + # Store a secret + encrypted = self.manager.store_secret( + self.test_value, self.test_service, self.test_type, self.test_session + ) + + # Verify the secret was stored + stored = self.manager.get_by_session_id(self.test_session) + assert isinstance(stored, SecretEntry) + assert stored.original == self.test_value + assert stored.encrypted == encrypted + assert stored.service == self.test_service + assert stored.secret_type == self.test_type + + # Verify encrypted value can be retrieved + retrieved = self.manager.get_original_value(encrypted, self.test_session) + assert retrieved == self.test_value + + def test_get_original_value_wrong_session(self): + """Test that secrets can't be accessed with wrong session ID""" + encrypted = self.manager.store_secret( + self.test_value, self.test_service, self.test_type, self.test_session + ) + + # Try to retrieve with wrong session ID + wrong_session = "wrong_session_id" + retrieved = self.manager.get_original_value(encrypted, wrong_session) + assert retrieved is None + + def test_get_original_value_nonexistent(self): + """Test handling of non-existent encrypted values""" + retrieved = self.manager.get_original_value("nonexistent", self.test_session) + assert retrieved is None + + def test_cleanup_session(self): + """Test that session cleanup properly removes secrets""" + # Store multiple secrets in different sessions + session1 = "session1" + session2 = "session2" + + encrypted1 = self.manager.store_secret("secret1", "service1", "type1", session1) + encrypted2 = self.manager.store_secret("secret2", "service2", "type2", session2) + + # Clean up session1 + self.manager.cleanup_session(session1) + + # Verify session1 secrets are gone + assert self.manager.get_by_session_id(session1) is None + assert self.manager.get_original_value(encrypted1, session1) is None + + # Verify session2 secrets remain + assert self.manager.get_by_session_id(session2) is not None + assert self.manager.get_original_value(encrypted2, session2) == "secret2" + + def test_cleanup(self): + """Test that cleanup properly wipes all data""" + # Store multiple secrets + self.manager.store_secret("secret1", "service1", "type1", "session1") + self.manager.store_secret("secret2", "service2", "type2", "session2") + + # Perform cleanup + self.manager.cleanup() + + # Verify all data is wiped + assert len(self.manager._session_store) == 0 + assert len(self.manager._encrypted_to_session) == 0 + + def test_multiple_secrets_same_session(self): + """Test storing multiple secrets in the same session""" + # Store multiple secrets in same session + encrypted1 = self.manager.store_secret("secret1", "service1", "type1", self.test_session) + encrypted2 = self.manager.store_secret("secret2", "service2", "type2", self.test_session) + + # Latest secret should be retrievable + stored = self.manager.get_by_session_id(self.test_session) + assert stored.original == "secret2" + assert stored.encrypted == encrypted2 + + # Both encrypted values should map to the session + assert self.manager._encrypted_to_session[encrypted1] == self.test_session + assert self.manager._encrypted_to_session[encrypted2] == self.test_session + + def test_error_handling(self): + """Test error handling in secret operations""" + # Test with None values + with pytest.raises(ValueError): + self.manager.store_secret(None, self.test_service, self.test_type, self.test_session) + + with pytest.raises(ValueError): + self.manager.store_secret(self.test_value, None, self.test_type, self.test_session) + + with pytest.raises(ValueError): + self.manager.store_secret(self.test_value, self.test_service, None, self.test_session) + + with pytest.raises(ValueError): + self.manager.store_secret(self.test_value, self.test_service, self.test_type, None) + + def test_secure_cleanup(self): + """Test that cleanup securely wipes sensitive data""" + # Store a secret + self.manager.store_secret( + self.test_value, self.test_service, self.test_type, self.test_session + ) + + # Get reference to stored data before cleanup + stored = self.manager.get_by_session_id(self.test_session) + original_value = stored.original + + # Perform cleanup + self.manager.cleanup() + + # Verify the original string was overwritten, not just removed + # This test is a bit tricky since Python strings are immutable, + # but we can at least verify the data is no longer accessible + assert original_value not in str(self.manager._session_store) + assert self.test_value not in str(self.manager._session_store) + + def test_session_isolation(self): + """Test that sessions are properly isolated""" + session1 = "session1" + session2 = "session2" + + # Store secrets in different sessions + encrypted1 = self.manager.store_secret("secret1", "service1", "type1", session1) + encrypted2 = self.manager.store_secret("secret2", "service2", "type2", session2) + + # Verify cross-session access is not possible + assert self.manager.get_original_value(encrypted1, session2) is None + assert self.manager.get_original_value(encrypted2, session1) is None + + # Verify correct session access works + assert self.manager.get_original_value(encrypted1, session1) == "secret1" + assert self.manager.get_original_value(encrypted2, session2) == "secret2" diff --git a/tests/pipeline/secrets/test_secrets.py b/tests/pipeline/secrets/test_secrets.py new file mode 100644 index 00000000..52be4eaf --- /dev/null +++ b/tests/pipeline/secrets/test_secrets.py @@ -0,0 +1,147 @@ +import pytest +from litellm import ModelResponse +from litellm.types.utils import Delta, StreamingChoices + +from codegate.pipeline.base import PipelineContext, PipelineSensitiveData +from codegate.pipeline.output import OutputPipelineContext +from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.secrets.secrets import SecretUnredactionStep + + +def create_model_response(content: str) -> ModelResponse: + """Helper to create test ModelResponse objects""" + return ModelResponse( + id="test", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content=content, role="assistant"), + logprobs=None, + ) + ], + created=0, + model="test-model", + object="chat.completion.chunk", + ) + + +class TestSecretUnredactionStep: + def setup_method(self): + """Setup fresh instances for each test""" + self.step = SecretUnredactionStep() + self.context = OutputPipelineContext() + self.secrets_manager = SecretsManager() + self.session_id = "test_session" + + # Setup input context with secrets manager + self.input_context = PipelineContext() + self.input_context.sensitive = PipelineSensitiveData( + manager=self.secrets_manager, session_id=self.session_id + ) + + @pytest.mark.asyncio + async def test_complete_marker_processing(self): + """Test processing of a complete REDACTED marker""" + # Store a secret + encrypted = self.secrets_manager.store_secret( + "secret_value", "test_service", "api_key", self.session_id + ) + + # Add content with REDACTED marker to buffer + self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text") + + # Process a chunk + result = await self.step.process_chunk( + create_model_response("more text"), self.context, self.input_context + ) + + # Verify unredaction + assert result is not None + assert result.choices[0].delta.content == "Here is the secret_value in text" + + @pytest.mark.asyncio + async def test_partial_marker_buffering(self): + """Test handling of partial REDACTED markers""" + # Add partial marker to buffer + self.context.buffer.append("Here is REDACTED<$") + + # Process a chunk + result = await self.step.process_chunk( + create_model_response("partial"), self.context, self.input_context + ) + + # Should return None to continue buffering + assert result is None + + @pytest.mark.asyncio + async def test_invalid_encrypted_value(self): + """Test handling of invalid encrypted values""" + # Add content with invalid encrypted value + self.context.buffer.append("Here is REDACTED<$invalid_value> in text") + + # Process chunk + result = await self.step.process_chunk( + create_model_response("text"), self.context, self.input_context + ) + + # Should keep the REDACTED marker for invalid values + assert result is not None + assert result.choices[0].delta.content == "Here is REDACTED<$invalid_value> in text" + + @pytest.mark.asyncio + async def test_missing_context(self): + """Test handling of missing input context or secrets manager""" + # Test with None input context + with pytest.raises(ValueError, match="Input context not found"): + await self.step.process_chunk(create_model_response("text"), self.context, None) + + # Test with missing secrets manager + self.input_context.sensitive.manager = None + with pytest.raises(ValueError, match="Secrets manager not found in input context"): + await self.step.process_chunk( + create_model_response("text"), self.context, self.input_context + ) + + @pytest.mark.asyncio + async def test_empty_content(self): + """Test handling of empty content chunks""" + result = await self.step.process_chunk( + create_model_response(""), self.context, self.input_context + ) + + # Should pass through empty chunks + assert result is not None + assert result.choices[0].delta.content == "" + + @pytest.mark.asyncio + async def test_no_markers(self): + """Test processing of content without any REDACTED markers""" + # Create chunk with content + chunk = create_model_response("Regular text without any markers") + + # Process chunk + result = await self.step.process_chunk(chunk, self.context, self.input_context) + + # Should pass through unchanged + assert result is not None + assert result.choices[0].delta.content == "Regular text without any markers" + + @pytest.mark.asyncio + async def test_wrong_session(self): + """Test unredaction with wrong session ID""" + # Store secret with one session + encrypted = self.secrets_manager.store_secret( + "secret_value", "test_service", "api_key", "different_session" + ) + + # Try to unredact with different session + self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text") + + result = await self.step.process_chunk( + create_model_response("text"), self.context, self.input_context + ) + + # Should keep REDACTED marker when session doesn't match + assert result is not None + assert result.choices[0].delta.content == f"Here is the REDACTED<${encrypted}> in text" diff --git a/tests/pipeline/test_output.py b/tests/pipeline/test_output.py new file mode 100644 index 00000000..eeb42085 --- /dev/null +++ b/tests/pipeline/test_output.py @@ -0,0 +1,289 @@ +from typing import Optional + +import pytest +from litellm import ModelResponse +from litellm.types.utils import Delta, StreamingChoices + +from codegate.pipeline.base import PipelineContext +from codegate.pipeline.output import ( + OutputPipelineContext, + OutputPipelineInstance, + OutputPipelineStep, +) + + +class MockOutputPipelineStep(OutputPipelineStep): + """Mock pipeline step for testing""" + + def __init__(self, name: str, should_pause: bool = False, modify_content: bool = False): + self._name = name + self._should_pause = should_pause + self._modify_content = modify_content + + @property + def name(self) -> str: + return self._name + + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: PipelineContext = None, + ) -> ModelResponse: + if self._should_pause: + return None + + if self._modify_content and chunk.choices[0].delta.content: + # Append step name to content to track modifications + modified_content = f"{chunk.choices[0].delta.content}_{self.name}" + chunk.choices[0].delta.content = modified_content + + return chunk + + +def create_model_response(content: str, id: str = "test") -> ModelResponse: + """Helper to create test ModelResponse objects""" + return ModelResponse( + id=id, + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content=content, role="assistant"), + logprobs=None, + ) + ], + created=0, + model="test-model", + object="chat.completion.chunk", + ) + + +class TestOutputPipelineContext: + def test_buffer_initialization(self): + """Test that buffer is properly initialized""" + context = OutputPipelineContext() + assert isinstance(context.buffer, list) + assert len(context.buffer) == 0 + + def test_buffer_operations(self): + """Test adding and clearing buffer content""" + context = OutputPipelineContext() + context.buffer.append("test1") + context.buffer.append("test2") + + assert len(context.buffer) == 2 + assert context.buffer == ["test1", "test2"] + + context.buffer.clear() + assert len(context.buffer) == 0 + + +class TestOutputPipelineInstance: + @pytest.mark.asyncio + async def test_single_step_processing(self): + """Test processing a stream through a single step""" + step = MockOutputPipelineStep("test_step", modify_content=True) + instance = OutputPipelineInstance([step]) + + async def mock_stream(): + yield create_model_response("Hello") + yield create_model_response("World") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + assert len(chunks) == 2 + assert chunks[0].choices[0].delta.content == "Hello_test_step" + assert chunks[1].choices[0].delta.content == "World_test_step" + # Buffer should be cleared after each successful chunk + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_multiple_steps_processing(self): + """Test processing a stream through multiple steps""" + steps = [ + MockOutputPipelineStep("step1", modify_content=True), + MockOutputPipelineStep("step2", modify_content=True), + ] + instance = OutputPipelineInstance(steps) + + async def mock_stream(): + yield create_model_response("Hello") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + assert len(chunks) == 1 + # Content should be modified by both steps + assert chunks[0].choices[0].delta.content == "Hello_step1_step2" + # Buffer should be cleared after successful processing + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_step_pausing(self): + """Test that a step can pause the stream and content is buffered until flushed""" + steps = [ + MockOutputPipelineStep("step1", should_pause=True), + MockOutputPipelineStep("step2", modify_content=True), + ] + instance = OutputPipelineInstance(steps) + + async def mock_stream(): + yield create_model_response("he") + yield create_model_response("ll") + yield create_model_response("o") + yield create_model_response(" wo") + yield create_model_response("rld") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + # Should get one chunk at the end with all buffered content + assert len(chunks) == 1 + # Content should be buffered and combined + assert chunks[0].choices[0].delta.content == "hello world" + # Buffer should be cleared after flush + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_step_pausing_with_replacement(self): + """Test that a step can pause the stream and modify the buffered content before flushing""" + + class ReplacementStep(OutputPipelineStep): + """Step that replaces 'world' with 'moon' when found in buffer""" + + def __init__(self, should_pause: bool = True): + self._should_pause = should_pause + + @property + def name(self) -> str: + return "replacement" + + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: PipelineContext = None, + ) -> Optional[ModelResponse]: + # Replace 'world' with 'moon' in buffered content + content = "".join(context.buffer) + if "world" in content: + content = content.replace("world", "moon") + chunk.choices = [ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content=content, role="assistant"), + logprobs=None, + ) + ] + return chunk + return None + + instance = OutputPipelineInstance([ReplacementStep()]) + + async def mock_stream(): + yield create_model_response("he") + yield create_model_response("ll") + yield create_model_response("o") + yield create_model_response("wo") + yield create_model_response("rld") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + # Should get one chunk at the end with modified content + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "hellomoon" + # Buffer should be cleared after flush + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_buffer_processing(self): + """Test that content is properly buffered and cleared""" + step = MockOutputPipelineStep("test_step") + instance = OutputPipelineInstance([step]) + + async def mock_stream(): + yield create_model_response("Hello") + yield create_model_response("World") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + # Buffer should be cleared after each successful chunk + assert len(instance._context.buffer) == 0 + + assert len(chunks) == 2 + assert chunks[0].choices[0].delta.content == "Hello" + assert chunks[1].choices[0].delta.content == "World" + + @pytest.mark.asyncio + async def test_empty_stream(self): + """Test handling of an empty stream""" + step = MockOutputPipelineStep("test_step") + instance = OutputPipelineInstance([step]) + + async def mock_stream(): + if False: + yield # Empty stream + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + assert len(chunks) == 0 + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_input_context_passing(self): + """Test that input context is properly passed to steps""" + input_context = PipelineContext() + input_context.metadata["test"] = "value" + + class ContextCheckingStep(OutputPipelineStep): + @property + def name(self) -> str: + return "context_checker" + + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: PipelineContext = None, + ) -> ModelResponse: + assert input_context.metadata["test"] == "value" + return chunk + + instance = OutputPipelineInstance([ContextCheckingStep()], input_context=input_context) + + async def mock_stream(): + yield create_model_response("test") + + async for _ in instance.process_stream(mock_stream()): + pass + + @pytest.mark.asyncio + async def test_buffer_flush_on_stream_end(self): + """Test that buffer is properly flushed when stream ends""" + step = MockOutputPipelineStep("test_step", should_pause=True) + instance = OutputPipelineInstance([step]) + + async def mock_stream(): + yield create_model_response("Hello") + yield create_model_response("World") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + # Should get one chunk with combined buffer content + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "HelloWorld" + # Buffer should be cleared after flush + assert len(instance._context.buffer) == 0 diff --git a/tests/providers/ollama/test_ollama_provider.py b/tests/providers/ollama/test_ollama_provider.py index 5fd5cf4e..ed10e7fd 100644 --- a/tests/providers/ollama/test_ollama_provider.py +++ b/tests/providers/ollama/test_ollama_provider.py @@ -18,7 +18,7 @@ def __init__(self): def app(): """Create FastAPI app with Ollama provider.""" app = FastAPI() - provider = OllamaProvider() + provider = OllamaProvider(None) app.include_router(provider.get_routes()) return app