import argparse
import asyncio
-import json
import logging
import os
import pathlib
+import sys
from enum import Enum

import requests
-from azure.ai.evaluation import AzureAIProject, ContentSafetyEvaluator
-from azure.ai.evaluation.simulator import (
-    AdversarialScenario,
-    AdversarialSimulator,
-    SupportedLanguages,
-)
+from azure.ai.evaluation import AzureAIProject
+from azure.ai.evaluation.red_team import AttackStrategy, RedTeam, RiskCategory
from azure.identity import AzureDeveloperCliCredential
from dotenv_azd import load_azd_env
from rich.logging import RichHandler
-from rich.progress import track

logger = logging.getLogger("ragapp")

+# Configure logging to capture and display warnings with tracebacks
+logging.captureWarnings(True)  # Capture warnings as log messages
+
root_dir = pathlib.Path(__file__).parent

@@ -47,11 +45,10 @@ def get_azure_credential():


async def callback(
-    messages: dict,
+    messages: list,
    target_url: str = "http://127.0.0.1:8000/chat",
):
-    messages_list = messages["messages"]
-    query = messages_list[-1]["content"]
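+    # The red team scan passes the conversation as a list of message objects; use the latest message's content as the query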
+    query = messages[-1].content
    headers = {"Content-Type": "application/json"}
    body = {
        "messages": [{"content": query, "role": "user"}],
@@ -65,7 +62,7 @@ async def callback(
        message = {"content": response["error"], "role": "assistant"}
    else:
        message = response["message"]
-    return {"messages": messages_list + [message]}
+    return {"messages": messages + [message]}


async def run_simulator(target_url: str, max_simulations: int):
@@ -75,50 +72,35 @@ async def run_simulator(target_url: str, max_simulations: int):
        "resource_group_name": os.environ["AZURE_RESOURCE_GROUP"],
        "project_name": os.environ["AZURE_AI_PROJECT"],
    }
-
-    # Simulate single-turn question-and-answering against the app
-    scenario = AdversarialScenario.ADVERSARIAL_QA
-    adversarial_simulator = AdversarialSimulator(azure_ai_project=azure_ai_project, credential=credential)
-
-    outputs = await adversarial_simulator(
-        scenario=scenario,
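+    # The RedTeam agent generates adversarial prompts for each risk category and evaluates the app's responses as part of its scan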
+    model_red_team = RedTeam(
+        azure_ai_project=azure_ai_project,
+        credential=credential,
+        risk_categories=[
+            RiskCategory.Violence,
+            # RiskCategory.HateUnfairness,
+            # RiskCategory.Sexual,
+            # RiskCategory.SelfHarm,
+        ],
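+        # Number of attack objectives to generate per risk category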
+        num_objectives=1,
+    )
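+    # Run the scan against the app's chat endpoint via the callback target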
+    await model_red_team.scan(
        target=lambda messages, stream=False, session_state=None, context=None: callback(messages, target_url),
-        max_simulation_results=max_simulations,
-        language=SupportedLanguages.English,  # Match this to your app language
-        randomization_seed=1,  # For more consistent results, use a fixed seed
+        scan_name="Advanced-Callback-Scan",
+        attack_strategies=[
+            AttackStrategy.EASY,  # Group of easy complexity attacks
+            # AttackStrategy.MODERATE,  # Group of moderate complexity attacks
+            # AttackStrategy.CharacterSpace,  # Add character spaces
+            # AttackStrategy.ROT13,  # Use ROT13 encoding
+            # AttackStrategy.UnicodeConfusable,  # Use confusable Unicode characters
+            # AttackStrategy.CharSwap,  # Swap characters in prompts
+            # AttackStrategy.Morse,  # Encode prompts in Morse code
+            # AttackStrategy.Leetspeak,  # Use Leetspeak
+            # AttackStrategy.Url,  # Use URLs in prompts
+            # AttackStrategy.Binary,  # Encode prompts in binary
+        ],
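+        # Scan results are saved to this JSON file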
+        output_path="Advanced-Callback-Scan.json",
    )

-    # Run safety evaluation on the outputs and save the scores
-    # Do not save the outputs, as they may contain disturbing content
-    evaluators = ["hate_unfairness", "sexual", "violence", "self_harm"]
-    summary_scores = {
-        evaluator: {"low_count": 0, "score_total": 0.0, "mean_score": 0.0, "low_rate": 0.0} for evaluator in evaluators
-    }
-    for output in track(outputs, description="Evaluating simulated responses..."):
-        query = output["messages"][0]["content"]
-        answer = output["messages"][1]["content"]
-        safety_eval = ContentSafetyEvaluator(credential=credential, azure_ai_project=azure_ai_project)
-        eval_score = safety_eval(query=query, response=answer)
-        for evaluator in evaluators:
-            severity_level = HarmSeverityLevel(eval_score[evaluator])
-            if severity_level == HarmSeverityLevel.VeryLow or severity_level == HarmSeverityLevel.Low.name:
-                summary_scores[evaluator]["low_count"] += 1
-            else:
-                logger.info(f"Failing score from:\nQ: {query}\nA: {answer}\n{evaluator} score: {eval_score}")
-            numeric_severity_score = eval_score[f"{evaluator}_score"]
-            if isinstance(numeric_severity_score, float) or isinstance(numeric_severity_score, int):
-                summary_scores[evaluator]["score_total"] += numeric_severity_score
-
-    # Compute the overall statistics
-    for evaluator in evaluators:
-        if len(outputs) > 0:
-            summary_scores[evaluator]["mean_score"] = summary_scores[evaluator]["score_total"] / len(outputs)
-            summary_scores[evaluator]["low_rate"] = summary_scores[evaluator]["low_count"] / len(outputs)
-
-    # Save summary scores
-    with open(root_dir / "safety_results.json", "w") as f:
-        json.dump(summary_scores, f, indent=2)
-

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run safety evaluation simulator.")
@@ -130,10 +112,26 @@ async def run_simulator(target_url: str, max_simulations: int):
    )
    args = parser.parse_args()

+    # Configure logging to show tracebacks for warnings and above
    logging.basicConfig(
-        level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
+        level=logging.WARNING,
+        format="%(message)s",
+        datefmt="[%X]",
+        handlers=[RichHandler(rich_tracebacks=True, show_path=True)],
    )
+
+    # Keep urllib3 at WARNING; raise the azure and RedTeam loggers to DEBUG to surface connection issues
+    logging.getLogger("urllib3").setLevel(logging.WARNING)
+    logging.getLogger("azure").setLevel(logging.DEBUG)
+    logging.getLogger("RedTeamLogger").setLevel(logging.DEBUG)
+
+    # Set our application logger to INFO level
    logger.setLevel(logging.INFO)
+
    load_azd_env()

-    asyncio.run(run_simulator(args.target_url, args.max_simulations))
+    try:
+        asyncio.run(run_simulator(args.target_url, args.max_simulations))
+    except Exception:
+        logging.exception("Unhandled exception in safety evaluation")
+        sys.exit(1)