@@ -923,6 +923,19 @@ async def generate_async(
             The completion (when a prompt is provided) or the next message.
 
         System messages are not yet supported."""
+        # convert options to gen_options of type GenerationOptions
+        gen_options: Optional[GenerationOptions] = None
+
+        if prompt is None and messages is None:
+            raise ValueError("Either prompt or messages must be provided.")
+
+        if prompt is not None and messages is not None:
+            raise ValueError("Only one of prompt or messages can be provided.")
+
+        if prompt is not None:
+            # Currently, we transform the prompt request into a single turn conversation
+            messages = [{"role": "user", "content": prompt}]
+
         # If a state object is specified, then we switch to "generation options" mode.
         # This is because we want the output to be a GenerationResponse which will contain
         # the output state.
@@ -932,15 +945,25 @@ async def generate_async(
                 state = json_to_state(state["state"])
 
             if options is None:
-                options = GenerationOptions()
-
-        # We allow options to be specified both as a dict and as an object.
-        if options and isinstance(options, dict):
-            options = GenerationOptions(**options)
+                gen_options = GenerationOptions()
+            elif isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            else:
+                gen_options = options
+        else:
+            # We allow options to be specified both as a dict and as an object.
+            if options and isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            elif isinstance(options, GenerationOptions):
+                gen_options = options
+            elif options is None:
+                gen_options = None
+            else:
+                raise TypeError("options must be a dict or GenerationOptions")
 
         # Save the generation options in the current async context.
-        # At this point, options is either None or GenerationOptions
-        generation_options_var.set(options if not isinstance(options, dict) else None)
+        # At this point, gen_options is either None or GenerationOptions
+        generation_options_var.set(gen_options)
 
         if streaming_handler:
             streaming_handler_var.set(streaming_handler)
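
A minimal sketch of what the two hunks above imply for callers. The `LLMRails` setup and the import paths are assumptions for illustration, not taken from this diff:

from nemoguardrails import LLMRails                              # assumed import
from nemoguardrails.rails.llm.options import GenerationOptions   # assumed path

async def call_with_options(rails: LLMRails) -> None:
    # Exactly one of `prompt` / `messages` is accepted; passing both, or
    # neither, now raises ValueError.
    await rails.generate_async(prompt="Hello!")

    # `options` may be a dict or a GenerationOptions instance; both forms are
    # normalized into the internal `gen_options: GenerationOptions` above.
    await rails.generate_async(
        messages=[{"role": "user", "content": "Hello!"}],
        options={"rails": {"dialog": False}},
    )
    await rails.generate_async(
        messages=[{"role": "user", "content": "Hello!"}],
        options=GenerationOptions(),
    )
    # When no state is passed, any other type for `options` raises TypeError.
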
@@ -950,23 +973,14 @@ async def generate_async(
         # requests are made.
         self.explain_info = self._ensure_explain_info()
 
-        if prompt is not None:
-            # Currently, we transform the prompt request into a single turn conversation
-            messages = [{"role": "user", "content": prompt}]
-            raw_llm_request.set(prompt)
-        else:
-            raw_llm_request.set(messages)
+        raw_llm_request.set(messages)
 
         # If we have generation options, we also add them to the context
-        if options:
+        if gen_options:
             messages = [
                 {
                     "role": "context",
-                    "content": {
-                        "generation_options": getattr(
-                            options, "dict", lambda: options
-                        )()
-                    },
+                    "content": {"generation_options": gen_options.model_dump()},
                 }
             ] + (messages or [])
 
@@ -976,9 +990,8 @@ async def generate_async(
         if (
             messages
             and messages[-1]["role"] == "assistant"
-            and options
-            and hasattr(options, "rails")
-            and getattr(getattr(options, "rails", None), "dialog", None) is False
+            and gen_options
+            and gen_options.rails.dialog is False
         ):
             # We already have the first message with a context update, so we use that
             messages[0]["content"]["bot_message"] = messages[-1]["content"]
@@ -995,7 +1008,7 @@ async def generate_async(
         processing_log = []
 
         # The array of events corresponding to the provided sequence of messages.
-        events = self._get_events_for_messages(messages or [], state)
+        events = self._get_events_for_messages(messages, state)  # type: ignore
 
         if self.config.colang_version == "1.0":
             # If we had a state object, we also need to prepend the events from the state.
@@ -1114,7 +1127,7 @@ async def generate_async(
             # If a state object is not used, then we use the implicit caching
             if state is None:
                 # Save the new events in the history and update the cache
-                cache_key = get_history_cache_key((messages or []) + [new_message])
+                cache_key = get_history_cache_key((messages) + [new_message])  # type: ignore
                 self.events_history_cache[cache_key] = events
             else:
                 output_state = {"events": events}
@@ -1142,33 +1155,29 @@ async def generate_async(
         # IF tracing is enabled we need to set GenerationLog attrs
         original_log_options = None
         if self.config.tracing.enabled:
-            if options is None:
-                options = GenerationOptions()
+            if gen_options is None:
+                gen_options = GenerationOptions()
             else:
-                # create a copy of the options to avoid modifying the original
-                if isinstance(options, GenerationOptions):
-                    options = options.model_copy(deep=True)
-                else:
-                    # If options is a dict, convert it to GenerationOptions
-                    options = GenerationOptions(**options)
-            original_log_options = options.log.model_copy(deep=True)
+                # create a copy of the gen_options to avoid modifying the original
+                gen_options = gen_options.model_copy(deep=True)
+            original_log_options = gen_options.log.model_copy(deep=True)
 
             # enable log options
             # it is aggressive, but these are required for tracing
             if (
-                not options.log.activated_rails
-                or not options.log.llm_calls
-                or not options.log.internal_events
+                not gen_options.log.activated_rails
+                or not gen_options.log.llm_calls
+                or not gen_options.log.internal_events
             ):
-                options.log.activated_rails = True
-                options.log.llm_calls = True
-                options.log.internal_events = True
+                gen_options.log.activated_rails = True
+                gen_options.log.llm_calls = True
+                gen_options.log.internal_events = True
 
         tool_calls = extract_tool_calls_from_events(new_events)
         llm_metadata = get_and_clear_response_metadata_contextvar()
 
         # If we have generation options, we prepare a GenerationResponse instance.
-        if options:
+        if gen_options:
             # If a prompt was used, we only need to return the content of the message.
             if prompt:
                 res = GenerationResponse(response=new_message["content"])
@@ -1195,9 +1204,9 @@ async def generate_async(
 
             if self.config.colang_version == "1.0":
                 # If output variables are specified, we extract their values
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                     context = compute_context(events)
-                    output_vars = getattr(options, "output_vars", None)
+                    output_vars = gen_options.output_vars
                     if isinstance(output_vars, list):
                         # If we have only a selection of keys, we filter to only that.
                         res.output_data = {k: context.get(k) for k in output_vars}
@@ -1208,65 +1217,64 @@ async def generate_async(
                 _log = compute_generation_log(processing_log)
 
                 # Include information about activated rails and LLM calls if requested
-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
+                    log_options.activated_rails or log_options.llm_calls
                 ):
                     res.log = GenerationLog()
 
                     # We always include the stats
                     res.log.stats = _log.stats
 
-                    if getattr(log_options, "activated_rails", False):
+                    if log_options.activated_rails:
                         res.log.activated_rails = _log.activated_rails
 
-                    if getattr(log_options, "llm_calls", False):
+                    if log_options.llm_calls:
                         res.log.llm_calls = []
                         for activated_rail in _log.activated_rails:
                             for executed_action in activated_rail.executed_actions:
                                 res.log.llm_calls.extend(executed_action.llm_calls)
 
                 # Include internal events if requested
-                if getattr(log_options, "internal_events", False):
+                if log_options and log_options.internal_events:
                     if res.log is None:
                         res.log = GenerationLog()
 
                     res.log.internal_events = new_events
 
                 # Include the Colang history if requested
-                if getattr(log_options, "colang_history", False):
+                if log_options and log_options.colang_history:
                     if res.log is None:
                         res.log = GenerationLog()
 
                     res.log.colang_history = get_colang_history(events)
 
                 # Include the raw llm output if requested
-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     # Currently, we include the output from the generation LLM calls.
                     for activated_rail in _log.activated_rails:
                         if activated_rail.type == "generation":
                             for executed_action in activated_rail.executed_actions:
                                 for llm_call in executed_action.llm_calls:
                                     res.llm_output = llm_call.raw_response
             else:
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                     raise ValueError(
                         "The `output_vars` option is not supported for Colang 2.0 configurations."
                     )
 
-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
-                    or getattr(log_options, "internal_events", False)
-                    or getattr(log_options, "colang_history", False)
+                    log_options.activated_rails
+                    or log_options.llm_calls
+                    or log_options.internal_events
+                    or log_options.colang_history
                 ):
                     raise ValueError(
                         "The `log` option is not supported for Colang 2.0 configurations."
                     )
 
-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     raise ValueError(
                         "The `llm_output` option is not supported for Colang 2.0 configurations."
                     )
@@ -1300,25 +1308,21 @@ async def generate_async(
                 if original_log_options:
                     if not any(
                         (
-                            getattr(original_log_options, "internal_events", False),
-                            getattr(original_log_options, "activated_rails", False),
-                            getattr(original_log_options, "llm_calls", False),
-                            getattr(original_log_options, "colang_history", False),
+                            original_log_options.internal_events,
+                            original_log_options.activated_rails,
+                            original_log_options.llm_calls,
+                            original_log_options.colang_history,
                         )
                     ):
                         res.log = None
                     else:
                         # Ensure res.log exists before setting attributes
                         if res.log is not None:
-                            if not getattr(
-                                original_log_options, "internal_events", False
-                            ):
+                            if not original_log_options.internal_events:
                                 res.log.internal_events = []
-                            if not getattr(
-                                original_log_options, "activated_rails", False
-                            ):
+                            if not original_log_options.activated_rails:
                                 res.log.activated_rails = []
-                            if not getattr(original_log_options, "llm_calls", False):
+                            if not original_log_options.llm_calls:
                                 res.log.llm_calls = []
 
             return res
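
A sketch of how the normalized options surface in the result when log data is requested (API shapes inferred from the code above; the import paths and the `LLMRails` setup are assumptions):

from nemoguardrails import LLMRails                       # assumed import
from nemoguardrails.rails.llm.options import (            # assumed path
    GenerationOptions,
    GenerationResponse,
)

async def inspect_generation_log(rails: LLMRails) -> None:
    # With options supplied, generate_async returns a GenerationResponse, and
    # only the log sections the caller asked for are kept on `res.log`, even
    # when tracing temporarily force-enables the rest.
    res = await rails.generate_async(
        messages=[{"role": "user", "content": "Hello!"}],
        options=GenerationOptions(log={"activated_rails": True, "llm_calls": True}),
    )
    assert isinstance(res, GenerationResponse)
    print(res.log.stats)              # stats always accompany a requested log
    for llm_call in res.log.llm_calls:
        print(llm_call.raw_response)  # per-call raw LLM output, as in the code above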