Skip to content

Commit e3f9e9a

Browse files
authored
Pass fewer snippets to suspicious commands (#1151)
Signed-off-by: nigel brown <[email protected]>
1 parent 1cbee55 commit e3f9e9a

File tree

3 files changed

+100
-19
lines changed

3 files changed

+100
-19
lines changed

src/codegate/pipeline/comment/output.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
)
1313
from codegate.pipeline.base import PipelineContext
1414
from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep
15-
16-
# from codegate.pipeline.suspicious_commands.suspicious_commands import check_suspicious_code
15+
from codegate.pipeline.suspicious_commands.suspicious_commands import check_suspicious_code
1716
from codegate.storage import StorageEngine
1817
from codegate.utils.package_extractor import PackageExtractor
1918

@@ -53,10 +52,15 @@ async def _snippet_comment(self, snippet: CodeSnippet, context: PipelineContext)
5352
"""Create a comment for a snippet"""
5453
comment = ""
5554

56-
# Remove this for now. We need to find a better place for it.
57-
# comment, is_suspicious = await check_suspicious_code(snippet.code, snippet.language)
58-
# if is_suspicious:
59-
# comment += comment
55+
if (
56+
snippet.filepath is None
57+
and snippet.file_extension is None
58+
and "filepath" not in snippet.code
59+
and "existing code" not in snippet.code
60+
):
61+
new_comment, is_suspicious = await check_suspicious_code(snippet.code, snippet.language)
62+
if is_suspicious:
63+
comment += new_comment
6064

6165
snippet.libraries = PackageExtractor.extract_packages(snippet.code, snippet.language)
6266

src/codegate/pipeline/suspicious_commands/suspicious_commands.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212

1313
import numpy as np # Add this import
1414
import onnxruntime as ort
15+
import structlog
1516

1617
from codegate.config import Config
1718
from codegate.inference.inference_engine import LlamaCppInferenceEngine
1819

20+
logger = structlog.get_logger("codegate")
21+
1922

2023
class SuspiciousCommands:
2124
"""
@@ -123,22 +126,29 @@ async def check_suspicious_code(code, language=None):
123126
Returns:
124127
tuple: A comment string and a boolean indicating if the code is suspicious.
125128
"""
129+
if language is None:
130+
language = "code"
131+
if language in [
132+
"python",
133+
"javascript",
134+
"typescript",
135+
"go",
136+
"rust",
137+
"java",
138+
]:
139+
logger.debug(f"Skipping suspicious command check for {language}")
140+
return "", False
141+
logger.debug("Checking code for suspicious commands")
126142
sc = SuspiciousCommands.get_instance()
127143
comment = ""
128144
class_, prob = await sc.classify_phrase(code)
129-
if class_ == 1:
145+
is_suspicious = class_ == 1
146+
if is_suspicious:
130147
liklihood = "possibly"
131148
if prob > 0.9:
132149
liklihood = "likely"
133-
if language is None:
134-
language = "code"
135-
if language not in [
136-
"python",
137-
"javascript",
138-
"typescript",
139-
"go",
140-
"rust",
141-
"java",
142-
]:
143-
comment = f"{comment}\n\n🛡️ CodeGate: The {language} supplied is {liklihood} unsafe. Please check carefully!\n\n" # noqa: E501
144-
return comment, class_ == 1
150+
comment = f"{comment}\n\n🛡️ CodeGate: The {language} supplied is {liklihood} unsafe. Please check carefully!\n\n" # noqa: E501
151+
logger.info(f"Suspicious: {code}")
152+
else:
153+
logger.debug("Not Suspicious")
154+
return comment, is_suspicious

tests/test_suspicious_commands.py

+67
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
"""
55
import csv
66
import os
7+
from unittest.mock import AsyncMock, patch
78

89
import pytest
910

1011
from codegate.pipeline.suspicious_commands.suspicious_commands import (
1112
SuspiciousCommands,
13+
check_suspicious_code,
1214
)
1315

1416
try:
@@ -189,3 +191,68 @@ async def test_classify_phrase_confident(sc):
189191
else:
190192
print(f"{command['cmd']} {prob} {prediction} 1")
191193
check_results(tp, tn, fp, fn)
194+
195+
196+
@pytest.mark.asyncio
197+
@patch("codegate.pipeline.suspicious_commands.suspicious_commands.SuspiciousCommands.get_instance")
198+
async def test_check_suspicious_code_safe(mock_get_instance):
199+
"""
200+
Test check_suspicious_code with safe code.
201+
"""
202+
mock_instance = mock_get_instance.return_value
203+
mock_instance.classify_phrase = AsyncMock(return_value=(0, 0.5))
204+
205+
code = "print('Hello, world!')"
206+
comment, is_suspicious = await check_suspicious_code(code, "python")
207+
208+
assert comment == ""
209+
assert is_suspicious is False
210+
211+
212+
@pytest.mark.asyncio
213+
@patch("codegate.pipeline.suspicious_commands.suspicious_commands.SuspiciousCommands.get_instance")
214+
async def test_check_suspicious_code_suspicious(mock_get_instance):
215+
"""
216+
Test check_suspicious_code with suspicious code.
217+
"""
218+
mock_instance = mock_get_instance.return_value
219+
mock_instance.classify_phrase = AsyncMock(return_value=(1, 0.95))
220+
221+
code = "rm -rf /"
222+
comment, is_suspicious = await check_suspicious_code(code, "bash")
223+
224+
assert "🛡️ CodeGate: The bash supplied is likely unsafe." in comment
225+
assert is_suspicious is True
226+
227+
228+
@pytest.mark.asyncio
229+
@patch("codegate.pipeline.suspicious_commands.suspicious_commands.SuspiciousCommands.get_instance")
230+
async def test_check_suspicious_code_skipped_language(mock_get_instance):
231+
"""
232+
Test check_suspicious_code with a language that should be skipped.
233+
"""
234+
mock_instance = mock_get_instance.return_value
235+
mock_instance.classify_phrase = AsyncMock()
236+
237+
code = "print('Hello, world!')"
238+
comment, is_suspicious = await check_suspicious_code(code, "python")
239+
240+
assert comment == ""
241+
assert is_suspicious is False
242+
mock_instance.classify_phrase.assert_not_called()
243+
244+
245+
@pytest.mark.asyncio
246+
@patch("codegate.pipeline.suspicious_commands.suspicious_commands.SuspiciousCommands.get_instance")
247+
async def test_check_suspicious_code_no_language(mock_get_instance):
248+
"""
249+
Test check_suspicious_code with no language specified.
250+
"""
251+
mock_instance = mock_get_instance.return_value
252+
mock_instance.classify_phrase = AsyncMock(return_value=(1, 0.85))
253+
254+
code = "rm -rf /"
255+
comment, is_suspicious = await check_suspicious_code(code)
256+
257+
assert "🛡️ CodeGate: The code supplied is possibly unsafe." in comment
258+
assert is_suspicious is True

0 commit comments

Comments
 (0)