@@ -873,6 +873,19 @@ async def generate_async(
             The completion (when a prompt is provided) or the next message.
 
         System messages are not yet supported."""
+        # convert options to gen_options of type GenerationOptions
+        gen_options: Optional[GenerationOptions] = None
+
+        if prompt is None and messages is None:
+            raise ValueError("Either prompt or messages must be provided.")
+
+        if prompt is not None and messages is not None:
+            raise ValueError("Only one of prompt or messages can be provided.")
+
+        if prompt is not None:
+            # Currently, we transform the prompt request into a single turn conversation
+            messages = [{"role": "user", "content": prompt}]
+
         # If a state object is specified, then we switch to "generation options" mode.
         # This is because we want the output to be a GenerationResponse which will contain
         # the output state.
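For reference, a minimal sketch of the calling contract this hunk enforces (exactly one of `prompt` or `messages`). The `./config` path and the `main()` wrapper are illustrative assumptions, not part of the change:

```python
# Illustration only, not part of the diff.
import asyncio

from nemoguardrails import LLMRails, RailsConfig


async def main():
    rails = LLMRails(RailsConfig.from_path("./config"))  # hypothetical config directory

    # Exactly one of `prompt` or `messages` must be provided.
    await rails.generate_async(prompt="Hello!")  # the prompt becomes a single user turn
    await rails.generate_async(messages=[{"role": "user", "content": "Hello!"}])

    # Both of the following now raise ValueError:
    #   await rails.generate_async()
    #   await rails.generate_async(prompt="Hi", messages=[{"role": "user", "content": "Hi"}])


asyncio.run(main())
```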
@@ -882,15 +895,25 @@ async def generate_async(
                 state = json_to_state(state["state"])
 
             if options is None:
-                options = GenerationOptions()
-
-        # We allow options to be specified both as a dict and as an object.
-        if options and isinstance(options, dict):
-            options = GenerationOptions(**options)
+                gen_options = GenerationOptions()
+            elif isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            else:
+                gen_options = options
+        else:
+            # We allow options to be specified both as a dict and as an object.
+            if options and isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            elif isinstance(options, GenerationOptions):
+                gen_options = options
+            elif options is None:
+                gen_options = None
+            else:
+                raise TypeError("options must be a dict or GenerationOptions")
 
         # Save the generation options in the current async context.
-        # At this point, options is either None or GenerationOptions
-        generation_options_var.set(options if not isinstance(options, dict) else None)
+        # At this point, gen_options is either None or GenerationOptions
+        generation_options_var.set(gen_options)
 
         if streaming_handler:
             streaming_handler_var.set(streaming_handler)
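Read on its own, the branching above normalizes `options` into `gen_options` differently depending on whether a `state` object was passed. The hypothetical helper below (the `_normalize_options` name and the import path are assumptions, not code from this PR) distills that logic:

```python
# Hypothetical distillation of the normalization above (illustration only).
from typing import Optional, Union

from nemoguardrails.rails.llm.options import GenerationOptions  # assumed import path


def _normalize_options(
    options: Optional[Union[dict, GenerationOptions]],
    state_provided: bool,
) -> Optional[GenerationOptions]:
    if state_provided:
        # "Generation options" mode: always end up with a GenerationOptions object.
        if options is None:
            return GenerationOptions()
        if isinstance(options, dict):
            return GenerationOptions(**options)
        return options

    # No state object: mirror the hunk's branches (note that an empty dict falls
    # through to the TypeError branch here, exactly as in the diff).
    if options and isinstance(options, dict):
        return GenerationOptions(**options)
    if isinstance(options, GenerationOptions):
        return options
    if options is None:
        return None
    raise TypeError("options must be a dict or GenerationOptions")
```

When a state object is present the call is forced into "generation options" mode, so `None` becomes a default `GenerationOptions()`; without a state object, `None` stays `None` so the plain-message return path is preserved.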
@@ -900,23 +923,14 @@ async def generate_async(
         # requests are made.
         self.explain_info = self._ensure_explain_info()
 
-        if prompt is not None:
-            # Currently, we transform the prompt request into a single turn conversation
-            messages = [{"role": "user", "content": prompt}]
-            raw_llm_request.set(prompt)
-        else:
-            raw_llm_request.set(messages)
+        raw_llm_request.set(messages)
 
         # If we have generation options, we also add them to the context
-        if options:
+        if gen_options:
             messages = [
                 {
                     "role": "context",
-                    "content": {
-                        "generation_options": getattr(
-                            options, "dict", lambda: options
-                        )()
-                    },
+                    "content": {"generation_options": gen_options.model_dump()},
                 }
             ] + (messages or [])
 
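Since `gen_options.model_dump()` replaces the earlier `getattr(options, "dict", ...)` duck-typing, the prepended context turn now carries a plain dict. Roughly (keys abbreviated to those referenced elsewhere in this diff, values illustrative):

```python
# Shape of the context turn prepended above (illustration; the exact keys come
# from GenerationOptions.model_dump() and are abbreviated here).
messages = [
    {
        "role": "context",
        "content": {
            "generation_options": {
                "rails": {"dialog": True},  # abbreviated
                "output_vars": None,
                "llm_output": False,
                "log": {
                    "activated_rails": False,
                    "llm_calls": False,
                    "internal_events": False,
                    "colang_history": False,
                },
            }
        },
    },
    {"role": "user", "content": "Hello!"},
]
```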
@@ -926,9 +940,8 @@ async def generate_async(
         if (
             messages
             and messages[-1]["role"] == "assistant"
-            and options
-            and hasattr(options, "rails")
-            and getattr(getattr(options, "rails", None), "dialog", None) is False
+            and gen_options
+            and gen_options.rails.dialog is False
         ):
             # We already have the first message with a context update, so we use that
             messages[0]["content"]["bot_message"] = messages[-1]["content"]
@@ -945,7 +958,7 @@ async def generate_async(
         processing_log = []
 
         # The array of events corresponding to the provided sequence of messages.
-        events = self._get_events_for_messages(messages or [], state)
+        events = self._get_events_for_messages(messages, state)  # type: ignore
 
         if self.config.colang_version == "1.0":
             # If we had a state object, we also need to prepend the events from the state.
@@ -1064,7 +1077,7 @@ async def generate_async(
         # If a state object is not used, then we use the implicit caching
         if state is None:
             # Save the new events in the history and update the cache
-            cache_key = get_history_cache_key((messages or []) + [new_message])
+            cache_key = get_history_cache_key((messages) + [new_message])  # type: ignore
             self.events_history_cache[cache_key] = events
         else:
             output_state = {"events": events}
@@ -1092,30 +1105,26 @@ async def generate_async(
         # IF tracing is enabled we need to set GenerationLog attrs
         original_log_options = None
         if self.config.tracing.enabled:
-            if options is None:
-                options = GenerationOptions()
+            if gen_options is None:
+                gen_options = GenerationOptions()
             else:
-                # create a copy of the options to avoid modifying the original
-                if isinstance(options, GenerationOptions):
-                    options = options.model_copy(deep=True)
-                else:
-                    # If options is a dict, convert it to GenerationOptions
-                    options = GenerationOptions(**options)
-            original_log_options = options.log.model_copy(deep=True)
+                # create a copy of the gen_options to avoid modifying the original
+                gen_options = gen_options.model_copy(deep=True)
+            original_log_options = gen_options.log.model_copy(deep=True)
 
             # enable log options
             # it is aggressive, but these are required for tracing
             if (
-                not options.log.activated_rails
-                or not options.log.llm_calls
-                or not options.log.internal_events
+                not gen_options.log.activated_rails
+                or not gen_options.log.llm_calls
+                or not gen_options.log.internal_events
             ):
-                options.log.activated_rails = True
-                options.log.llm_calls = True
-                options.log.internal_events = True
+                gen_options.log.activated_rails = True
+                gen_options.log.llm_calls = True
+                gen_options.log.internal_events = True
 
         # If we have generation options, we prepare a GenerationResponse instance.
-        if options:
+        if gen_options:
             # If a prompt was used, we only need to return the content of the message.
             if prompt:
                 res = GenerationResponse(response=new_message["content"])
@@ -1136,9 +1145,9 @@ async def generate_async(
 
             if self.config.colang_version == "1.0":
                 # If output variables are specified, we extract their values
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                     context = compute_context(events)
-                    output_vars = getattr(options, "output_vars", None)
+                    output_vars = gen_options.output_vars
                     if isinstance(output_vars, list):
                         # If we have only a selection of keys, we filter to only that.
                         res.output_data = {k: context.get(k) for k in output_vars}
@@ -1149,65 +1158,64 @@ async def generate_async(
                 _log = compute_generation_log(processing_log)
 
                 # Include information about activated rails and LLM calls if requested
-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
+                    log_options.activated_rails or log_options.llm_calls
                 ):
                     res.log = GenerationLog()
 
                     # We always include the stats
                     res.log.stats = _log.stats
 
-                    if getattr(log_options, "activated_rails", False):
+                    if log_options.activated_rails:
                         res.log.activated_rails = _log.activated_rails
 
-                    if getattr(log_options, "llm_calls", False):
+                    if log_options.llm_calls:
                         res.log.llm_calls = []
                         for activated_rail in _log.activated_rails:
                             for executed_action in activated_rail.executed_actions:
                                 res.log.llm_calls.extend(executed_action.llm_calls)
 
                 # Include internal events if requested
-                if getattr(log_options, "internal_events", False):
+                if log_options and log_options.internal_events:
                     if res.log is None:
                         res.log = GenerationLog()
 
                     res.log.internal_events = new_events
 
                 # Include the Colang history if requested
-                if getattr(log_options, "colang_history", False):
+                if log_options and log_options.colang_history:
                     if res.log is None:
                         res.log = GenerationLog()
 
                     res.log.colang_history = get_colang_history(events)
 
                 # Include the raw llm output if requested
-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     # Currently, we include the output from the generation LLM calls.
                     for activated_rail in _log.activated_rails:
                         if activated_rail.type == "generation":
                             for executed_action in activated_rail.executed_actions:
                                 for llm_call in executed_action.llm_calls:
                                     res.llm_output = llm_call.raw_response
             else:
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                     raise ValueError(
                         "The `output_vars` option is not supported for Colang 2.0 configurations."
                     )
 
-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
-                    or getattr(log_options, "internal_events", False)
-                    or getattr(log_options, "colang_history", False)
+                    log_options.activated_rails
+                    or log_options.llm_calls
+                    or log_options.internal_events
+                    or log_options.colang_history
                 ):
                     raise ValueError(
                         "The `log` option is not supported for Colang 2.0 configurations."
                     )
 
-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     raise ValueError(
                         "The `llm_output` option is not supported for Colang 2.0 configurations."
                     )
@@ -1241,25 +1249,21 @@ async def generate_async(
             if original_log_options:
                 if not any(
                     (
-                        getattr(original_log_options, "internal_events", False),
-                        getattr(original_log_options, "activated_rails", False),
-                        getattr(original_log_options, "llm_calls", False),
-                        getattr(original_log_options, "colang_history", False),
+                        original_log_options.internal_events,
+                        original_log_options.activated_rails,
+                        original_log_options.llm_calls,
+                        original_log_options.colang_history,
                     )
                 ):
                     res.log = None
                 else:
                     # Ensure res.log exists before setting attributes
                     if res.log is not None:
-                        if not getattr(
-                            original_log_options, "internal_events", False
-                        ):
+                        if not original_log_options.internal_events:
                             res.log.internal_events = []
-                        if not getattr(
-                            original_log_options, "activated_rails", False
-                        ):
+                        if not original_log_options.activated_rails:
                             res.log.activated_rails = []
-                        if not getattr(original_log_options, "llm_calls", False):
+                        if not original_log_options.llm_calls:
                             res.log.llm_calls = []
 
         return res
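Putting the pieces together, a sketch of how a caller might exercise the normalized options on a Colang 1.0 configuration; the config path and the output variable name are illustrative, while the `GenerationResponse` fields (`response`, `output_data`, `log`) follow the attributes set above:

```python
# Illustration only: requesting logs and output variables (Colang 1.0 configs).
import asyncio

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions  # assumed import path


async def main():
    rails = LLMRails(RailsConfig.from_path("./config"))  # hypothetical config path

    options = GenerationOptions(
        log={"activated_rails": True, "llm_calls": True},
        output_vars=["relevant_chunks"],  # illustrative context variable name
    )
    # A plain dict with the same structure is also accepted and coerced to GenerationOptions.
    res = await rails.generate_async(
        messages=[{"role": "user", "content": "Hello!"}],
        options=options,
    )

    print(res.response)               # the bot message(s)
    print(res.output_data)            # values of the requested output_vars
    print(res.log.stats)              # stats are included when rails/LLM-call logging is requested
    print(len(res.log.llm_calls))     # populated because llm_calls=True


asyncio.run(main())
```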