Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ async def submit_request(


async def submit_request_onedp(
client: AIProjectClient, data: dict, metric: str, token: str, annotation_task: str, evaluator_name: str
client: AIProjectClient, data: dict, metric: str, token: str, annotation_task: str, evaluator_name: str, scan_session_id: Optional[str] = None
) -> str:
"""Submit request to Responsible AI service for evaluation and return operation ID

Expand All @@ -275,12 +275,16 @@ async def submit_request_onedp(
:type annotation_task: str
:param evaluator_name: The evaluator name.
:type evaluator_name: str
:param scan_session_id: The scan session ID to use for the evaluation.
:type scan_session_id: Optional[str]
:return: The operation ID.
:rtype: str
"""
normalized_user_text = get_formatted_template(data, annotation_task)
payload = generate_payload(normalized_user_text, metric, annotation_task=annotation_task)
headers = get_common_headers(token, evaluator_name)
if scan_session_id:
headers["client_request_id"] = scan_session_id
response = client.evaluations.submit_annotation(payload, headers=headers)
result = json.loads(response)
operation_id = result["location"].split("/")[-1]
Expand Down Expand Up @@ -631,6 +635,7 @@ async def evaluate_with_rai_service(
annotation_task: str = Tasks.CONTENT_HARM,
metric_display_name=None,
evaluator_name=None,
scan_session_id: Optional[str]=None
) -> Dict[str, Union[str, float]]:
"""Evaluate the content safety of the response using Responsible AI service

Expand All @@ -649,6 +654,8 @@ async def evaluate_with_rai_service(
:type metric_display_name: str
:param evaluator_name: The evaluator name to use.
:type evaluator_name: str
:param scan_session_id: The scan session ID to use for the evaluation.
:type scan_session_id: Optional[str]
:return: The parsed annotation result.
:rtype: Dict[str, Union[str, float]]
"""
Expand All @@ -661,7 +668,7 @@ async def evaluate_with_rai_service(
)
token = await fetch_or_reuse_token(credential=credential, workspace=COG_SRV_WORKSPACE)
await ensure_service_availability_onedp(client, token, annotation_task)
operation_id = await submit_request_onedp(client, data, metric_name, token, annotation_task, evaluator_name)
operation_id = await submit_request_onedp(client, data, metric_name, token, annotation_task, evaluator_name, scan_session_id)
annotation_response = cast(List[Dict], await fetch_result_onedp(client, operation_id, token))
result = parse_response(annotation_response, metric_name, metric_display_name)
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -582,18 +582,18 @@ def get_jail_break_dataset_with_type(self, type: str, **kwargs: Any) -> List[str
def get_attack_objectives(
self,
*,
risk_category: str,
risk_types: Optional[List[str]] = None,
risk_categories: Optional[List[str]] = None,
lang: Optional[str] = None,
strategy: Optional[str] = None,
**kwargs: Any
) -> List[_models.AttackObjective]:
"""Get the attack objectives.

:keyword risk_category: Risk category for the attack objectives. Required.
:paramtype risk_category: str
:keyword risk_types: Risk types for the attack objectives dataset. Default value is None.
:paramtype risk_types: list[str]
:keyword risk_categories: Risk categories for the attack objectives dataset. Default value is None.
:paramtype risk_categories: list[str]
:keyword lang: The language for the attack objectives dataset, defaults to 'en'. Default value
is None.
:paramtype lang: str
Expand All @@ -615,10 +615,10 @@ def get_attack_objectives(
_params = kwargs.pop("params", {}) or {}

cls: ClsType[List[_models.AttackObjective]] = kwargs.pop("cls", None)

_request = build_rai_svc_get_attack_objectives_request(
risk_categories=[risk_category],
risk_types=risk_types,
risk_categories=risk_categories,
lang=lang,
strategy=strategy,
api_version=self._config.api_version,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def __init__(
self.failed_tasks = 0
self.start_time = None
self.scan_id = None
self.scan_session_id = None
self.scan_output_dir = None

self.generated_rai_client = GeneratedRAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager.credential) # type: ignore
Expand Down Expand Up @@ -342,7 +343,7 @@ def _start_redteam_mlflow_run(
if self._one_dp_project:
response = self.generated_rai_client._evaluation_onedp_client.start_red_team_run(
red_team=RedTeamUpload(
scan_name=run_name or f"redteam-agent-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
display_name=run_name or f"redteam-agent-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
)
)

Expand Down Expand Up @@ -561,7 +562,7 @@ async def _log_redteam_results_to_mlflow(
name=eval_run.id,
red_team=RedTeamUpload(
id=eval_run.id,
scan_name=eval_run.scan_name or f"redteam-agent-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
display_name=eval_run.display_name or f"redteam-agent-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
status="Completed",
outputs={
"evaluationResultId": create_evaluation_result_response.id,
Expand Down Expand Up @@ -741,7 +742,7 @@ async def get_jailbreak_prefixes_with_retry():

else:
content_harm_risk = None
other_risk = None
other_risk = ""
if risk_cat_value in ["hate_unfairness", "violence", "self_harm", "sexual"]:
content_harm_risk = risk_cat_value
else:
Expand All @@ -759,13 +760,15 @@ async def get_jailbreak_prefixes_with_retry():
risk_category=other_risk,
application_scenario=application_scenario or "",
strategy="tense",
scan_session_id=self.scan_session_id
)
else:
objectives_response = await self.generated_rai_client.get_attack_objectives(
risk_type=content_harm_risk,
risk_category=other_risk,
application_scenario=application_scenario or "",
strategy=None,
scan_session_id=self.scan_session_id
)
if isinstance(objectives_response, list):
self.logger.debug(f"API returned {len(objectives_response)} objectives")
Expand All @@ -775,7 +778,7 @@ async def get_jailbreak_prefixes_with_retry():
# Handle jailbreak strategy - need to apply jailbreak prefixes to messages
if strategy == "jailbreak":
self.logger.debug("Applying jailbreak prefixes to objectives")
jailbreak_prefixes = await self.generated_rai_client.get_jailbreak_prefixes()
jailbreak_prefixes = await self.generated_rai_client.get_jailbreak_prefixes(scan_session_id=self.scan_session_id)
for objective in objectives_response:
if "messages" in objective and len(objective["messages"]) > 0:
message = objective["messages"][0]
Expand Down Expand Up @@ -2350,6 +2353,7 @@ async def evaluate_with_rai_service_with_retry():
project_scope=self.azure_ai_project,
credential=self.credential,
annotation_task=annotation_task,
scan_session_id=self.scan_session_id,
)
except (
httpx.ConnectTimeout,
Expand Down Expand Up @@ -2692,7 +2696,7 @@ async def scan(
application_scenario: Optional[str] = None,
parallel_execution: bool = True,
max_parallel_tasks: int = 5,
timeout: int = 120,
timeout: int = 3600,
skip_evals: bool = False,
**kwargs: Any,
) -> RedTeamResult:
Expand Down Expand Up @@ -2737,13 +2741,20 @@ async def scan(
)
self.scan_id = self.scan_id.replace(" ", "_")

self.scan_session_id = str(uuid.uuid4()) # Unique session ID for this scan

# Create output directory for this scan
# If DEBUG environment variable is set, use a regular folder name; otherwise, use a hidden folder
is_debug = os.environ.get("DEBUG", "").lower() in ("true", "1", "yes", "y")
folder_prefix = "" if is_debug else "."
self.scan_output_dir = os.path.join(self.output_dir or ".", f"{folder_prefix}{self.scan_id}")
os.makedirs(self.scan_output_dir, exist_ok=True)

if not is_debug:
gitignore_path = os.path.join(self.scan_output_dir, ".gitignore")
with open(gitignore_path, "w", encoding="utf-8") as f:
f.write("*\n")

# Re-initialize logger with the scan output directory
self.logger = setup_logger(output_dir=self.scan_output_dir)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,12 @@ def _get_service_discovery_url(self):

async def get_attack_objectives(
self,
*,
risk_type: Optional[str] = None,
risk_category: Optional[str] = None,
application_scenario: str = None,
strategy: Optional[str] = None,
scan_session_id: Optional[str] = None,
) -> Dict:
"""Get attack objectives using the auto-generated operations.

Expand All @@ -113,16 +115,19 @@ async def get_attack_objectives(
:type application_scenario: str
:param strategy: Optional strategy to filter the attack objectives
:type strategy: Optional[str]
:param scan_session_id: Optional unique session ID for the scan
:type scan_session_id: Optional[str]
:return: The attack objectives
:rtype: Dict
"""
try:
# Send the request using the autogenerated client
response = self._client.get_attack_objectives(
risk_types=[risk_type],
risk_categories=[risk_category],
risk_category=risk_category,
lang="en",
strategy=strategy,
headers={"client_request_id": scan_session_id},
)
return response

Expand All @@ -133,15 +138,19 @@ async def get_attack_objectives(
logging.error(f"Error in get_attack_objectives: {str(e)}")
raise

async def get_jailbreak_prefixes(self) -> List[str]:
async def get_jailbreak_prefixes(self, scan_session_id: Optional[str] = None) -> List[str]:
"""Get jailbreak prefixes using the auto-generated operations.

:param scan_session_id: Optional unique session ID for the scan
:type scan_session_id: Optional[str]
:return: The jailbreak prefixes
:rtype: List[str]
"""
try:
# Send the request using the autogenerated client
response = self._client.get_jail_break_dataset_with_type(type="upia")
response = self._client.get_jail_break_dataset_with_type(
type="upia", headers={"client_request_id": scan_session_id}
)
if isinstance(response, list):
return response
else:
Expand Down