Add integration checks for comparing the result of calling the model API directly vs via CodeGate #1032

Merged · 2 commits · Feb 13, 2025
Changes from all commits
47 changes: 34 additions & 13 deletions tests/integration/checks.py
@@ -29,7 +29,8 @@ def load(test_data: dict) -> List[BaseCheck]:
             checks.append(ContainsCheck(test_name))
         if test_data.get(DoesNotContainCheck.KEY):
             checks.append(DoesNotContainCheck(test_name))
-
+        if test_data.get(CodeGateEnrichment.KEY) is not None:
+            checks.append(CodeGateEnrichment(test_name))
         return checks


@@ -51,11 +52,10 @@ async def run_check(self, parsed_response: str, test_data: dict) -> bool:
         similarity = await self._calculate_string_similarity(
             parsed_response, test_data[DistanceCheck.KEY]
         )
-        logger.debug(f"Similarity: {similarity}")
-        logger.debug(f"Response: {parsed_response}")
-        logger.debug(f"Expected Response: {test_data[DistanceCheck.KEY]}")
         if similarity < 0.8:
             logger.error(f"Test {self.test_name} failed")
+            logger.error(f"Similarity: {similarity}")
+            logger.error(f"Response: {parsed_response}")
+            logger.error(f"Expected Response: {test_data[DistanceCheck.KEY]}")
             return False
         return True

@@ -64,10 +64,9 @@ class ContainsCheck(BaseCheck):
     KEY = "contains"

     async def run_check(self, parsed_response: str, test_data: dict) -> bool:
-        logger.debug(f"Response: {parsed_response}")
-        logger.debug(f"Expected Response to contain: {test_data[ContainsCheck.KEY]}")
         if test_data[ContainsCheck.KEY].strip() not in parsed_response:
             logger.error(f"Test {self.test_name} failed")
+            logger.error(f"Response: {parsed_response}")
+            logger.error(f"Expected Response to contain: '{test_data[ContainsCheck.KEY]}'")
             return False
         return True

@@ -76,11 +75,33 @@ class DoesNotContainCheck(BaseCheck):
     KEY = "does_not_contain"

     async def run_check(self, parsed_response: str, test_data: dict) -> bool:
-        logger.debug(f"Response: {parsed_response}")
-        logger.debug(f"Expected Response to not contain: '{test_data[DoesNotContainCheck.KEY]}'")
         if test_data[DoesNotContainCheck.KEY].strip() in parsed_response:
             logger.error(f"Test {self.test_name} failed")
+            logger.error(f"Response: {parsed_response}")
+            logger.error(
+                f"Expected Response to not contain: '{test_data[DoesNotContainCheck.KEY]}'"
+            )
             return False
         return True
+
+
+class CodeGateEnrichment(BaseCheck):
+    KEY = "codegate_enrichment"
+
+    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
+        direct_response = test_data["direct_response"]
+        logger.debug(f"Response (CodeGate): {parsed_response}")
+        logger.debug(f"Response (Raw model): {direct_response}")
+
+        # Use the DistanceCheck to compare the two responses
+        distance_check = DistanceCheck(self.test_name)
+        are_similar = await distance_check.run_check(
+            parsed_response, {DistanceCheck.KEY: direct_response}
+        )
+
+        # Check if the response is enriched by CodeGate.
+        # If it is, there should be a difference in the similarity score.
+        expect_enrichment = test_data.get(CodeGateEnrichment.KEY).get("expect_difference", False)
+        if expect_enrichment:
+            logger.info("CodeGate enrichment check: Expecting difference")
+            return not are_similar
+        # If the response is not enriched, the similarity score should be the same.
+        logger.info("CodeGate enrichment check: Not expecting difference")
+        return are_similar
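
For orientation, here is a minimal sketch of how the new check could be exercised in isolation. The driver, the literal response strings, and the pre-filled direct_response key are illustrative assumptions; in the real suite the runner populates direct_response before any checks execute (see integration_tests.py below):

import asyncio

from checks import CodeGateEnrichment  # assumes tests/integration is on sys.path

async def demo() -> None:
    # Hypothetical test_data; the runner builds this from a testcases.yaml
    # entry and stores the raw model output under "direct_response".
    test_data = {
        CodeGateEnrichment.KEY: {"expect_difference": True},
        "direct_response": "Raw model answer, without CodeGate context.",
    }
    check = CodeGateEnrichment("demo-enrichment-test")
    passed = await check.run_check(
        "CodeGate-enriched answer with extra security context.", test_data
    )
    print(f"enrichment check passed: {passed}")

if __name__ == "__main__":
    asyncio.run(demo())

Note that DistanceCheck's similarity computation runs under the hood, so this still needs whatever embedding backend the suite uses for _calculate_string_similarity.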
33 changes: 28 additions & 5 deletions tests/integration/integration_tests.py
@@ -9,7 +9,7 @@
 import requests
 import structlog
 import yaml
-from checks import CheckLoader
+from checks import CheckLoader, CodeGateEnrichment
 from dotenv import find_dotenv, load_dotenv
 from requesters import RequesterFactory

@@ -21,7 +21,7 @@ def __init__(self):
         self.requester_factory = RequesterFactory()
         self.failed_tests = []  # Track failed tests

-    def call_codegate(
+    def call_provider(
         self, url: str, headers: dict, data: dict, provider: str, method: str = "POST"
     ) -> Optional[requests.Response]:
         logger.debug(f"Creating requester for provider: {provider}")
@@ -132,18 +132,29 @@ def replacement(match):

     async def run_test(self, test: dict, test_headers: dict) -> bool:
         test_name = test["name"]
-        url = test["url"]
         data = json.loads(test["data"])
+        codegate_url = test["url"]
         streaming = data.get("stream", False)
         provider = test["provider"]

         logger.info(f"Starting test: {test_name}")

-        response = self.call_codegate(url, test_headers, data, provider)
+        # Call Codegate
+        response = self.call_provider(codegate_url, test_headers, data, provider)
         if not response:
             logger.error(f"Test {test_name} failed: No response received")
             return False

+        # Call model directly if specified
+        direct_response = None
+        if test.get(CodeGateEnrichment.KEY) is not None:
+            direct_provider_url = test.get(CodeGateEnrichment.KEY)["provider_url"]
+            direct_response = self.call_provider(
+                direct_provider_url, test_headers, data, "not-codegate"
+            )
+            if not direct_response:
+                logger.error(f"Test {test_name} failed: No direct response received")
+                return False
+
         # Debug response info
         logger.debug(f"Response status: {response.status_code}")
         logger.debug(f"Response headers: {dict(response.headers)}")
@@ -152,13 +163,24 @@ async def run_test(self, test: dict, test_headers: dict) -> bool:
         parsed_response = self.parse_response_message(response, streaming=streaming)
         logger.debug(f"Response message: {parsed_response}")

+        if direct_response:
+            # Dirty hack to pass direct response to checks
+            test["direct_response"] = self.parse_response_message(
+                direct_response, streaming=streaming
+            )
+            logger.debug(f"Direct response message: {test['direct_response']}")
+
         # Load appropriate checks for this test
         checks = CheckLoader.load(test)

         # Run all checks
         all_passed = True
         for check in checks:
+            logger.info(f"Running check: {check.__class__.__name__}")
             passed_check = await check.run_check(parsed_response, test)
+            logger.info(
+                f"Check {check.__class__.__name__} {'passed' if passed_check else 'failed'}"
+            )
             if not passed_check:
                 all_passed = False
@@ -379,6 +401,7 @@ async def main():
     # Exit with status code 1 if any tests failed
     if not all_tests_passed:
         sys.exit(1)
+    logger.info("All tests passed")


 if __name__ == "__main__":
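
To make the renamed call_provider flow concrete, here is a hedged sketch of the two requests run_test issues when a testcase carries a codegate_enrichment block. The helper name and the bare requests.post calls are simplifications (the suite actually routes through RequesterFactory); the URLs are the ones from the Ollama testcases below:

import requests

def fetch_both_responses(data: dict, headers: dict) -> tuple[str, str]:
    # One identical payload, two destinations: CodeGate's proxy endpoint
    # and the model's own API. The enrichment check compares the results.
    codegate_url = "http://127.0.0.1:8989/ollama/chat/completions"
    direct_url = "http://127.0.0.1:11434/api/chat"
    via_codegate = requests.post(codegate_url, headers=headers, json=data, timeout=60)
    direct = requests.post(direct_url, headers=headers, json=data, timeout=60)
    return via_codegate.text, direct.text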
12 changes: 12 additions & 0 deletions tests/integration/ollama/testcases.yaml
@@ -31,6 +31,9 @@ testcases:
     name: Ollama Chat
     provider: ollama
     url: http://127.0.0.1:8989/ollama/chat/completions
+    codegate_enrichment:
+      provider_url: http://127.0.0.1:11434/api/chat
+      expect_difference: false
     data: |
       {
         "max_tokens":4096,
@@ -55,6 +58,9 @@
     name: Ollama FIM
     provider: ollama
     url: http://127.0.0.1:8989/ollama/api/generate
+    codegate_enrichment:
+      provider_url: http://127.0.0.1:11434/api/generate
+      expect_difference: false
     data: |
       {
         "stream": true,
@@ -88,6 +94,9 @@
     name: Ollama Malicious Package
     provider: ollama
     url: http://127.0.0.1:8989/ollama/chat/completions
+    codegate_enrichment:
+      provider_url: http://127.0.0.1:11434/api/chat
+      expect_difference: true
     data: |
       {
         "max_tokens":4096,
@@ -112,6 +121,9 @@
     name: Ollama secret redacting chat
     provider: ollama
     url: http://127.0.0.1:8989/ollama/chat/completions
+    codegate_enrichment:
+      provider_url: http://127.0.0.1:11434/api/chat
+      expect_difference: true
     data: |
       {
         "max_tokens":4096,
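
The expect_difference flag is the whole contract here: false asserts that CodeGate acts as a pass-through for that request, true asserts that it changed the answer (malicious-package warnings, secret redaction). For intuition, the check's verdict reduces to the following paraphrase of CodeGateEnrichment.run_check above, not new logic:

# Paraphrase of the enrichment verdict, for intuition only.
def enrichment_verdict(are_similar: bool, expect_difference: bool) -> bool:
    if expect_difference:
        return not are_similar  # enrichment must change the response
    return are_similar          # pass-through must not change it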
10 changes: 10 additions & 0 deletions tests/integration/vllm/testcases.yaml
@@ -31,6 +31,9 @@ testcases:
     name: VLLM Chat
     provider: vllm
     url: http://127.0.0.1:8989/vllm/chat/completions
+    codegate_enrichment:
+      provider_url: http://127.0.0.1:8000/v1/chat/completions
+      expect_difference: false
     data: |
       {
         "max_tokens":4096,
@@ -55,6 +58,10 @@
     name: VLLM FIM
     provider: vllm
     url: http://127.0.0.1:8989/vllm/completions
+    # This is commented out for now as there's some issue with parsing the streamed response from the model (on the vllm side, not codegate)
+    # codegate_enrichment:
+    #   provider_url: http://127.0.0.1:8000/v1/completions
+    #   expect_difference: false
     data: |
       {
         "model": "Qwen/Qwen2.5-Coder-0.5B-Instruct",
@@ -84,6 +91,9 @@
     name: VLLM Malicious Package
     provider: vllm
    url: http://127.0.0.1:8989/vllm/chat/completions
+    codegate_enrichment:
+      provider_url: http://127.0.0.1:8000/v1/chat/completions
+      expect_difference: true
     data: |
       {
         "max_tokens":4096,
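
Taken together, each codegate_enrichment block above is just an extra key on the parsed testcase dict, so CheckLoader picks up the new check alongside the existing ones. A sketch of that wiring with an abridged testcase follows; the contains line is a hypothetical expectation added for illustration, not copied from the real file:

import yaml

from checks import CheckLoader

# Abridged testcase, shaped like the entries in testcases.yaml above.
test = yaml.safe_load("""
name: VLLM Malicious Package
provider: vllm
url: http://127.0.0.1:8989/vllm/chat/completions
codegate_enrichment:
  provider_url: http://127.0.0.1:8000/v1/chat/completions
  expect_difference: true
contains: malicious  # hypothetical, for illustration only
""")

checks = CheckLoader.load(test)
# Both keys are present, so two checks load:
# ['ContainsCheck', 'CodeGateEnrichment']
print([check.__class__.__name__ for check in checks])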